diff --git a/sys/kern/uipc_socket.c b/sys/kern/uipc_socket.c --- a/sys/kern/uipc_socket.c +++ b/sys/kern/uipc_socket.c @@ -197,6 +197,9 @@ so_gen_t so_gencnt; /* generation count for sockets */ +so_checkfiballowed_t so_checkfiballowed; +soopt_validfib_t soopt_validfib; + MALLOC_DEFINE(M_SONAME, "soname", "socket name"); MALLOC_DEFINE(M_PCB, "pcb", "protocol control block"); @@ -3143,24 +3146,6 @@ SOCK_UNLOCK(so); break; - case SO_SETFIB: - error = sooptcopyin(sopt, &optval, sizeof optval, - sizeof optval); - if (error) - goto bad; - - if (optval < 0 || optval >= rt_numfibs) { - error = EINVAL; - goto bad; - } - if (((so->so_proto->pr_domain->dom_family == PF_INET) || - (so->so_proto->pr_domain->dom_family == PF_INET6) || - (so->so_proto->pr_domain->dom_family == PF_ROUTE))) - so->so_fibnum = optval; - else - so->so_fibnum = 0; - break; - case SO_USER_COOKIE: error = sooptcopyin(sopt, &val32, sizeof val32, sizeof val32); @@ -3246,6 +3231,15 @@ so->so_max_pacing_rate = val32; break; + case SO_SETFIB: + /* Delegate to the protocol's pr_ctloutput (e.g. via sosetfib()). */ + if (so->so_proto->pr_ctloutput != NULL) { + error = (*so->so_proto->pr_ctloutput)(so, + sopt); + goto bad; + } + + /* FALLTHROUGH: no protocol handler, use the default path. */ default: if (V_socket_hhh[HHOOK_SOCKET_OPT]->hhh_nhooks > 0) error = hhook_run_socket(so, sopt, @@ -3589,6 +3583,25 @@ return (0); } +int +soopt_getfib(struct sockopt *sopt, int *fibnum) +{ + int error, optval; + + error = sooptcopyin(sopt, &optval, sizeof optval, sizeof(optval)); + if (error) + return (error); + + if (soopt_validfib != NULL) { /* Validate via the registered hook. */ + if (!soopt_validfib(sopt, optval)) + return (EINVAL); + } else if (optval != 0) /* No hook: only FIB 0 is valid. */ + return (EINVAL); + + *fibnum = optval; + return (0); +} + /* * sohasoutofband(): protocol notifies socket layer of the arrival of new * out-of-band data, which will then notify socket consumers.
@@ -4378,3 +4391,43 @@ SOCK_UNLOCK(so); } + +/* + * Ask the registered so_checkfiballowed hook whether "so" may be placed in + * a FIB other than 0.  Returns false when no hook is registered. + */ +bool +so_setfiballowed(struct socket *so) +{ + + if (so_checkfiballowed != NULL) + return (so_checkfiballowed(so)); + return (false); +} + +/* + * Common SO_SETFIB handler for protocol pr_ctloutput methods.  Returns + * ENOPROTOOPT if "sopt" is not a SOL_SOCKET/SO_SETFIB set request, the + * error from soopt_getfib() if the FIB number is rejected, or 0 after + * updating so_fibnum (forced to 0 when so_setfiballowed() is false). + */ +int +sosetfib(struct socket *so, struct sockopt *sopt) +{ + int error, fibnum; + + if (sopt->sopt_dir != SOPT_SET || + sopt->sopt_level != SOL_SOCKET || + sopt->sopt_name != SO_SETFIB) + return (ENOPROTOOPT); + + error = soopt_getfib(sopt, &fibnum); + if (error != 0) + return (error); + + if (so_setfiballowed(so)) + so->so_fibnum = fibnum; + else + so->so_fibnum = 0; + return (0); +} diff --git a/sys/kern/uipc_syscalls.c b/sys/kern/uipc_syscalls.c --- a/sys/kern/uipc_syscalls.c +++ b/sys/kern/uipc_syscalls.c @@ -1606,3 +1606,28 @@ m_chtype(m, MT_CONTROL); } } + +/* + * Set the FIB of the current process.  The number is checked by the + * soopt_validfib hook within the thread's VNET; as in soopt_getfib(), only + * FIB 0 is accepted when no hook is registered. + */ +int +sys_setfib(struct thread *td, struct setfib_args *uap) +{ + bool valid; + int error = 0; + + CURVNET_SET(TD_TO_VNET(td)); + if (soopt_validfib != NULL) + valid = soopt_validfib(NULL, uap->fibnum); + else + valid = (uap->fibnum == 0); + if (valid) + td->td_proc->p_fibnum = uap->fibnum; + else + error = EINVAL; + CURVNET_RESTORE(); + + return (error); +} diff --git a/sys/kern/uipc_usrreq.c b/sys/kern/uipc_usrreq.c --- a/sys/kern/uipc_usrreq.c +++ b/sys/kern/uipc_usrreq.c @@ -1735,8 +1735,12 @@ struct xucred xu; int error, optval; - if (sopt->sopt_level != SOL_LOCAL) + if (sopt->sopt_level != SOL_LOCAL) { + error = sosetfib(so, sopt); + if (error != ENOPROTOOPT) + return (error); return (EINVAL); + } unp = sotounpcb(so); KASSERT(unp != NULL, ("uipc_ctloutput: unp == NULL")); diff --git a/sys/net/route/route_tables.c b/sys/net/route/route_tables.c --- a/sys/net/route/route_tables.c +++ b/sys/net/route/route_tables.c @@ -42,9 +42,16 @@ #include #include #include +#include #include #include +#include +#include +#include +#include #include +#include +#include #include #include #include @@ -85,6 +92,9 @@ VNET_DEFINE(uint32_t, _rt_numfibs) = RT_NUMFIBS; +static bool rtsocheckfiballowed(struct socket *so); +static bool rtsooptvalidfib(struct sockopt *sopt, int fibnum); + /* * Handler for
net.my_fibnum. * Returns current fib of the process. @@ -145,24 +155,6 @@ NULL, 0, &sysctl_fibs, "IU", "set number of fibs"); -/* - * Sets fib of a current process. - */ -int -sys_setfib(struct thread *td, struct setfib_args *uap) -{ - int error = 0; - - CURVNET_SET(TD_TO_VNET(td)); - if (uap->fibnum >= 0 && uap->fibnum < V_rt_numfibs) - td->td_proc->p_fibnum = uap->fibnum; - else - error = EINVAL; - CURVNET_RESTORE(); - - return (error); -} - static int rtables_check_proc_fib(void *obj, void *data) { @@ -192,6 +184,8 @@ [PR_METHOD_ATTACH] = rtables_check_proc_fib, }; osd_jail_register(rtables_prison_destructor, methods); + so_checkfiballowed = rtsocheckfiballowed; + soopt_validfib = rtsooptvalidfib; } SYSINIT(rtables_init, SI_SUB_PROTO_DOMAIN, SI_ORDER_THIRD, rtables_init, NULL); @@ -397,3 +391,25 @@ __func__, table, family)); return (rnh->rnh_gen); } + +/* + * so_checkfiballowed hook: only PF_INET, PF_INET6 and PF_ROUTE sockets may + * be placed in a FIB other than 0. + */ +static bool +rtsocheckfiballowed(struct socket *so) +{ + return (so->so_proto->pr_domain->dom_family == PF_INET || + so->so_proto->pr_domain->dom_family == PF_INET6 || + so->so_proto->pr_domain->dom_family == PF_ROUTE); +} + +/* + * soopt_validfib hook: accept FIB numbers in [0, V_rt_numfibs).  The + * unsigned cast makes negative values fail the bound check. + */ +static bool +rtsooptvalidfib(struct sockopt *sopt __unused, int fibnum) +{ + return ((u_int)fibnum < V_rt_numfibs); +} diff --git a/sys/net/rtsock.c b/sys/net/rtsock.c --- a/sys/net/rtsock.c +++ b/sys/net/rtsock.c @@ -380,6 +380,14 @@ soisdisconnected(so); } +/* Only SO_SETFIB is handled here; sosetfib() returns ENOPROTOOPT otherwise. */ +static int +rts_ctloutput(struct socket *so, struct sockopt *sopt) +{ + + return (sosetfib(so, sopt)); +} + static SYSCTL_NODE(_net, OID_AUTO, rtsock, CTLFLAG_RW | CTLFLAG_MPSAFE, 0, "Routing socket infrastructure"); static u_long rts_sendspace = 8192; @@ -2691,6 +2699,7 @@ .pr_shutdown = rts_shutdown, .pr_disconnect = rts_disconnect, .pr_close = rts_close, + .pr_ctloutput = rts_ctloutput, }; static struct domain routedomain = { diff --git a/sys/netinet/ip_output.c b/sys/netinet/ip_output.c --- a/sys/netinet/ip_output.c +++ b/sys/netinet/ip_output.c @@ -1105,10 +1105,14 @@ error = 0; break; case SO_SETFIB: - INP_WLOCK(inp); - 
inp->inp_inc.inc_fibnum = so->so_fibnum; - INP_WUNLOCK(inp); - error = 0; + error = sosetfib(so, sopt); + if (error == 0) { + INP_WLOCK(inp); + inp->inp_inc.inc_fibnum = so->so_fibnum; + INP_WUNLOCK(inp); + } + if (error == ENOPROTOOPT) + error = 0; break; case SO_MAX_PACING_RATE: #ifdef RATELIMIT diff --git a/sys/netinet/raw_ip.c b/sys/netinet/raw_ip.c --- a/sys/netinet/raw_ip.c +++ b/sys/netinet/raw_ip.c @@ -639,11 +639,14 @@ int error, optval; if (sopt->sopt_level != IPPROTO_IP) { - if ((sopt->sopt_level == SOL_SOCKET) && - (sopt->sopt_name == SO_SETFIB)) { + error = sosetfib(so, sopt); + if (error == 0) { + INP_WLOCK(inp); inp->inp_inc.inc_fibnum = so->so_fibnum; - return (0); - } + INP_WUNLOCK(inp); + } + if (error != ENOPROTOOPT) + return (error); return (EINVAL); } diff --git a/sys/netinet6/ip6_output.c b/sys/netinet6/ip6_output.c --- a/sys/netinet6/ip6_output.c +++ b/sys/netinet6/ip6_output.c @@ -1663,10 +1663,15 @@ error = 0; break; case SO_SETFIB: - INP_WLOCK(inp); - inp->inp_inc.inc_fibnum = so->so_fibnum; - INP_WUNLOCK(inp); - error = 0; + error = sosetfib(so, sopt); + if (error == 0) { + INP_WLOCK(inp); + inp->inp_inc.inc_fibnum = + so->so_fibnum; + INP_WUNLOCK(inp); + } + if (error == ENOPROTOOPT) + error = 0; break; case SO_MAX_PACING_RATE: #ifdef RATELIMIT diff --git a/sys/netinet6/raw_ip6.c b/sys/netinet6/raw_ip6.c --- a/sys/netinet6/raw_ip6.c +++ b/sys/netinet6/raw_ip6.c @@ -580,13 +580,14 @@ */ return (icmp6_ctloutput(so, sopt)); else if (sopt->sopt_level != IPPROTO_IPV6) { - if (sopt->sopt_level == SOL_SOCKET && - sopt->sopt_name == SO_SETFIB) { + error = sosetfib(so, sopt); + if (error == 0) { INP_WLOCK(inp); inp->inp_inc.inc_fibnum = so->so_fibnum; INP_WUNLOCK(inp); - return (0); } + if (error != ENOPROTOOPT) + return (error); return (EINVAL); } diff --git a/sys/sys/socket.h b/sys/sys/socket.h --- a/sys/sys/socket.h +++ b/sys/sys/socket.h @@ -722,6 +722,10 @@ #ifdef _KERNEL struct socket; +typedef bool (*so_checkfiballowed_t)(struct socket *so); + +extern so_checkfiballowed_t
so_checkfiballowed; + struct inpcb *so_sotoinpcb(struct socket *so); struct sockbuf *so_sockbuf_snd(struct socket *); struct sockbuf *so_sockbuf_rcv(struct socket *); @@ -741,6 +745,8 @@ struct protosw *so_protosw_get(const struct socket *); void so_protosw_set(struct socket *, struct protosw *); +bool so_setfiballowed(struct socket *so); + void so_sorwakeup_locked(struct socket *so); void so_sowwakeup_locked(struct socket *so); diff --git a/sys/sys/sockopt.h b/sys/sys/sockopt.h --- a/sys/sys/sockopt.h +++ b/sys/sys/sockopt.h @@ -57,10 +57,20 @@ struct thread *sopt_td; /* calling thread or null if kernel */ }; +/* + * FIB number validator, installed by the routing code.  "sopt" may be NULL + * (sys_setfib() passes NULL).  Returns true if "fibnum" may be used. + */ +typedef bool (*soopt_validfib_t)(struct sockopt *sopt, int fibnum); + +extern soopt_validfib_t soopt_validfib; + int sosetopt(struct socket *so, struct sockopt *sopt); int sogetopt(struct socket *so, struct sockopt *sopt); +int sosetfib(struct socket *so, struct sockopt *sopt); int sooptcopyin(struct sockopt *sopt, void *buf, size_t len, size_t minlen); int sooptcopyout(struct sockopt *sopt, const void *buf, size_t len); +int soopt_getfib(struct sockopt *sopt, int *fibnum); int soopt_getm(struct sockopt *sopt, struct mbuf **mp); int soopt_mcopyin(struct sockopt *sopt, struct mbuf *m); int soopt_mcopyout(struct sockopt *sopt, struct mbuf *m);