diff --git a/sys/dev/gve/gve_adminq.c b/sys/dev/gve/gve_adminq.c
--- a/sys/dev/gve/gve_adminq.c
+++ b/sys/dev/gve/gve_adminq.c
@@ -28,6 +28,7 @@
  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
  * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
  */
+#include
 #include
 #include
 #include
diff --git a/sys/kern/uipc_domain.c b/sys/kern/uipc_domain.c
--- a/sys/kern/uipc_domain.c
+++ b/sys/kern/uipc_domain.c
@@ -103,6 +103,12 @@
 	return (EOPNOTSUPP);
 }
 
+static int
+pr_ctloutput_notsupp(struct socket *so __unused, struct sockopt *sopt __unused)
+{
+	return (ENOPROTOOPT);
+}
+
 static int
 pr_disconnect_notsupp(struct socket *so)
 {
@@ -207,6 +213,7 @@
 	NOTSUPP(pr_connect2);
 	NOTSUPP(pr_connectat);
 	NOTSUPP(pr_control);
+	NOTSUPP(pr_ctloutput);
 	NOTSUPP(pr_disconnect);
 	NOTSUPP(pr_listen);
 	NOTSUPP(pr_peeraddr);
diff --git a/sys/kern/uipc_socket.c b/sys/kern/uipc_socket.c
--- a/sys/kern/uipc_socket.c
+++ b/sys/kern/uipc_socket.c
@@ -158,6 +158,8 @@
 #include
 #endif
 
+static bool so_default_checkfiballowed(struct socket *so);
+static bool so_default_validfib(struct sockopt *sopt, int fibnum);
 static int	soreceive_rcvoob(struct socket *so, struct uio *uio,
 		    int flags);
 static void	so_rdknl_lock(void *);
@@ -193,6 +195,9 @@
 
 so_gen_t	so_gencnt;	/* generation count for sockets */
 
+so_checkfiballowed_t so_checkfiballowed = so_default_checkfiballowed;
+soopt_validfib_t soopt_validfib = so_default_validfib;
+
 MALLOC_DEFINE(M_SONAME, "soname", "socket name");
 MALLOC_DEFINE(M_PCB, "pcb", "protocol control block");
 
@@ -268,6 +273,19 @@
 static uma_zone_t socket_zone;
 int	maxsockets;
 
+static bool
+so_default_checkfiballowed(struct socket *so __unused)
+{
+	return (false);
+}
+
+static bool
+so_default_validfib(struct sockopt *sopt __unused, int fibnum)
+{
+	/* Default only allow FIB 0 */
+	return (fibnum == 0);
+}
+
 static void
 socket_zone_change(void *tag)
 {
@@ -3144,24 +3162,6 @@
 			SOCK_UNLOCK(so);
 			break;
 
-		case SO_SETFIB:
-			error = sooptcopyin(sopt, &optval, sizeof optval,
-			    sizeof optval);
-			if (error)
-				goto bad;
-
-			if (optval < 0 || optval >= rt_numfibs) {
-				error = EINVAL;
-				goto bad;
-			}
-			if (((so->so_proto->pr_domain->dom_family == PF_INET) ||
-			   (so->so_proto->pr_domain->dom_family == PF_INET6) ||
-			   (so->so_proto->pr_domain->dom_family == PF_ROUTE)))
-				so->so_fibnum = optval;
-			else
-				so->so_fibnum = 0;
-			break;
-
 		case SO_USER_COOKIE:
 			error = sooptcopyin(sopt, &val32, sizeof val32,
 			    sizeof val32);
@@ -3247,6 +3247,13 @@
 			so->so_max_pacing_rate = val32;
 			break;
 
+		case SO_SETFIB:
+			/* Let the protocol-specific ctloutput handle it */
+			error = (*so->so_proto->pr_ctloutput)(so, sopt);
+			if (error != ENOPROTOOPT)
+				goto bad;
+
+			/* Fall through */
 		default:
 			if (V_socket_hhh[HHOOK_SOCKET_OPT]->hhh_nhooks > 0)
 				error = hhook_run_socket(so, sopt,
@@ -3590,6 +3597,27 @@
 	return (0);
 }
 
+/*
+ * Copy in an integer FIB number from a socket option request and check it
+ * with the soopt_validfib hook.  Returns the sooptcopyin() error, or EINVAL
+ * when the hook rejects the FIB; *fibnum is written only on success.
+ */
+int
+soopt_getfib(struct sockopt *sopt, int *fibnum)
+{
+	int error, optval;
+
+	error = sooptcopyin(sopt, &optval, sizeof optval, sizeof optval);
+	if (error)
+		return (error);
+
+	if (!soopt_validfib(sopt, optval))
+		return (EINVAL);
+
+	*fibnum = optval;
+	return (0);
+}
+
 /*
  * sohasoutofband(): protocol notifies socket layer of the arrival of new
  * out-of-band data, which will then notify socket consumers.
@@ -4278,3 +4306,42 @@
 
 	so->so_error = val;
 }
+
+/*
+ * Report whether the socket may use a non-default FIB, as decided by the
+ * so_checkfiballowed hook.
+ */
+bool
+so_setfiballowed(struct socket *so)
+{
+
+	return (so_checkfiballowed(so));
+}
+
+/*
+ * Apply a SOL_SOCKET/SO_SETFIB set request on behalf of a protocol's
+ * ctloutput method.  Some callers (e.g. the raw IP and local-domain
+ * ctloutput methods) pass every request at a level they do not own, so any
+ * other level, option or direction is rejected with ENOPROTOOPT rather than
+ * having its value misread as a FIB number.  Sockets that may not select a
+ * FIB are pinned to FIB 0.
+ */
+int
+sosetfib(struct socket *so, struct sockopt *sopt)
+{
+	int error, fibnum;
+
+	if (sopt->sopt_level != SOL_SOCKET || sopt->sopt_name != SO_SETFIB ||
+	    sopt->sopt_dir != SOPT_SET)
+		return (ENOPROTOOPT);
+
+	error = soopt_getfib(sopt, &fibnum);
+	if (error != 0)
+		return (error);
+
+	if (so_setfiballowed(so))
+		so->so_fibnum = fibnum;
+	else
+		so->so_fibnum = 0;
+	return (0);
+}
diff --git a/sys/kern/uipc_syscalls.c b/sys/kern/uipc_syscalls.c
--- a/sys/kern/uipc_syscalls.c
+++ b/sys/kern/uipc_syscalls.c
@@ -1554,3 +1554,22 @@
 		m_chtype(m, MT_CONTROL);
 	}
 }
+
+/*
+ * Set the FIB of the current process.
+ */
+int
+sys_setfib(struct thread *td, struct setfib_args *uap)
+{
+	int error = 0;
+
+	CURVNET_SET(TD_TO_VNET(td));
+	/* soopt_validfib is statically initialized and never NULL. */
+	if (soopt_validfib(NULL, uap->fibnum))
+		td->td_proc->p_fibnum = uap->fibnum;
+	else
+		error = EINVAL;
+	CURVNET_RESTORE();
+
+	return (error);
+}
diff --git a/sys/kern/uipc_usrreq.c b/sys/kern/uipc_usrreq.c
--- a/sys/kern/uipc_usrreq.c
+++ b/sys/kern/uipc_usrreq.c
@@ -1757,8 +1757,12 @@
 	struct xucred xu;
 	int error, optval;
 
-	if (sopt->sopt_level != SOL_LOCAL)
+	if (sopt->sopt_level != SOL_LOCAL) {
+		error = sosetfib(so, sopt);
+		if (error != ENOPROTOOPT)
+			return (error);
 		return (EINVAL);
+	}
 
 	unp = sotounpcb(so);
 	KASSERT(unp != NULL, ("uipc_ctloutput: unp == NULL"));
diff --git a/sys/net/route/route_tables.c b/sys/net/route/route_tables.c
--- a/sys/net/route/route_tables.c
+++ b/sys/net/route/route_tables.c
@@ -37,12 +37,14 @@
 #include "opt_route.h"
 
 #include
-#include
 #include
 #include
 #include
 #include
 #include
+#include
+#include
+#include
 #include
 #include
 #include
@@ -85,6 +87,9 @@
 
 VNET_DEFINE(uint32_t, _rt_numfibs) = RT_NUMFIBS;
 
+static bool rtsocheckfiballowed(struct socket *so);
+static bool rtsooptvalidfib(struct sockopt *sopt, int fibnum);
+
 /*
  * Handler for net.my_fibnum.
  * Returns current fib of the process.
@@ -145,24 +150,6 @@
     NULL, 0, &sysctl_fibs, "IU", "set number of fibs");
 
-/*
- * Sets fib of a current process.
- */
-int
-sys_setfib(struct thread *td, struct setfib_args *uap)
-{
-	int error = 0;
-
-	CURVNET_SET(TD_TO_VNET(td));
-	if (uap->fibnum >= 0 && uap->fibnum < V_rt_numfibs)
-		td->td_proc->p_fibnum = uap->fibnum;
-	else
-		error = EINVAL;
-	CURVNET_RESTORE();
-
-	return (error);
-}
-
 
 static int
 rtables_check_proc_fib(void *obj, void *data)
 {
@@ -192,6 +179,8 @@
 		[PR_METHOD_ATTACH] = rtables_check_proc_fib,
 	};
 	osd_jail_register(rtables_prison_destructor, methods);
+	so_checkfiballowed = rtsocheckfiballowed;
+	soopt_validfib = rtsooptvalidfib;
 }
 SYSINIT(rtables_init, SI_SUB_PROTO_DOMAIN, SI_ORDER_THIRD, rtables_init, NULL);
 
@@ -407,3 +396,17 @@
 	    __func__, table, family));
 	return (rnh->rnh_gen);
 }
+
+static bool
+rtsocheckfiballowed(struct socket *so)
+{
+	int family = so->so_proto->pr_domain->dom_family;
+
+	return (family == PF_INET || family == PF_INET6 || family == PF_ROUTE);
+}
+
+static bool
+rtsooptvalidfib(struct sockopt *sopt __unused, int fibnum)
+{
+	return ((u_int)fibnum < V_rt_numfibs);
+}
diff --git a/sys/net/rtsock.c b/sys/net/rtsock.c
--- a/sys/net/rtsock.c
+++ b/sys/net/rtsock.c
@@ -377,6 +377,17 @@
 	soisdisconnected(so);
 }
 
+static int
+rts_ctloutput(struct socket *so, struct sockopt *sopt)
+{
+
+	if (sopt->sopt_level != SOL_SOCKET ||
+	    sopt->sopt_name != SO_SETFIB)
+		return (ENOPROTOOPT);
+
+	return (sosetfib(so, sopt));
+}
+
 static SYSCTL_NODE(_net, OID_AUTO, rtsock, CTLFLAG_RW | CTLFLAG_MPSAFE, 0,
     "Routing socket infrastructure");
 static u_long rts_sendspace = 8192;
@@ -2700,6 +2711,7 @@
 	.pr_shutdown =		rts_shutdown,
 	.pr_disconnect =	rts_disconnect,
 	.pr_close =		rts_close,
+	.pr_ctloutput =		rts_ctloutput,
 };
 
 static struct domain routedomain = {
diff --git a/sys/netinet/ip_output.c b/sys/netinet/ip_output.c
--- a/sys/netinet/ip_output.c
+++ b/sys/netinet/ip_output.c
@@ -1081,10 +1081,14 @@
 		    sopt->sopt_dir == SOPT_SET) {
 			switch (sopt->sopt_name) {
 			case SO_SETFIB:
-				INP_WLOCK(inp);
-				inp->inp_inc.inc_fibnum = so->so_fibnum;
-				INP_WUNLOCK(inp);
-				error = 0;
+				error = sosetfib(so, sopt);
+				if (error == 0) {
+					INP_WLOCK(inp);
+					inp->inp_inc.inc_fibnum = so->so_fibnum;
+					INP_WUNLOCK(inp);
+				}
+				if (error == ENOPROTOOPT)
+					error = 0;
 				break;
 			case SO_MAX_PACING_RATE:
 #ifdef RATELIMIT
diff --git a/sys/netinet/raw_ip.c b/sys/netinet/raw_ip.c
--- a/sys/netinet/raw_ip.c
+++ b/sys/netinet/raw_ip.c
@@ -635,11 +635,11 @@
 	int	error, optval;
 
 	if (sopt->sopt_level != IPPROTO_IP) {
-		if ((sopt->sopt_level == SOL_SOCKET) &&
-		    (sopt->sopt_name == SO_SETFIB)) {
+		error = sosetfib(so, sopt);
+		if (error == 0)
 			inp->inp_inc.inc_fibnum = so->so_fibnum;
-			return (0);
-		}
+		if (error != ENOPROTOOPT)
+			return (error);
 		return (EINVAL);
 	}
 
diff --git a/sys/netinet6/ip6_output.c b/sys/netinet6/ip6_output.c
--- a/sys/netinet6/ip6_output.c
+++ b/sys/netinet6/ip6_output.c
@@ -1643,10 +1643,15 @@
 		    sopt->sopt_dir == SOPT_SET) {
 			switch (sopt->sopt_name) {
 			case SO_SETFIB:
-				INP_WLOCK(inp);
-				inp->inp_inc.inc_fibnum = so->so_fibnum;
-				INP_WUNLOCK(inp);
-				error = 0;
+				error = sosetfib(so, sopt);
+				if (error == 0) {
+					INP_WLOCK(inp);
+					inp->inp_inc.inc_fibnum =
+					    so->so_fibnum;
+					INP_WUNLOCK(inp);
+				}
+				if (error == ENOPROTOOPT)
+					error = 0;
 				break;
 			case SO_MAX_PACING_RATE:
 #ifdef RATELIMIT
diff --git a/sys/netinet6/raw_ip6.c b/sys/netinet6/raw_ip6.c
--- a/sys/netinet6/raw_ip6.c
+++ b/sys/netinet6/raw_ip6.c
@@ -576,13 +576,14 @@
 		 */
 		return (icmp6_ctloutput(so, sopt));
 	else if (sopt->sopt_level != IPPROTO_IPV6) {
-		if (sopt->sopt_level == SOL_SOCKET &&
-		    sopt->sopt_name == SO_SETFIB) {
+		error = sosetfib(so, sopt);
+		if (error == 0) {
 			INP_WLOCK(inp);
 			inp->inp_inc.inc_fibnum = so->so_fibnum;
 			INP_WUNLOCK(inp);
-			return (0);
 		}
+		if (error != ENOPROTOOPT)
+			return (error);
 		return (EINVAL);
 	}
 
diff --git a/sys/sys/socket.h b/sys/sys/socket.h
--- a/sys/sys/socket.h
+++ b/sys/sys/socket.h
@@ -716,10 +716,16 @@
 #ifdef _KERNEL
 struct socket;
 
+typedef bool (*so_checkfiballowed_t)(struct socket *so);
+
+extern so_checkfiballowed_t so_checkfiballowed;
+
 int	so_options_get(const struct socket *);
 void	so_options_set(struct socket *, int);
 int	so_error_get(const struct socket *);
 void	so_error_set(struct socket *, int);
+
+bool	so_setfiballowed(struct socket *so);
 #endif /* _KERNEL */
 
 #endif /* !_SYS_SOCKET_H_ */
diff --git a/sys/sys/sockopt.h b/sys/sys/sockopt.h
--- a/sys/sys/sockopt.h
+++ b/sys/sys/sockopt.h
@@ -53,10 +53,16 @@
 	struct	thread *sopt_td; /* calling thread or null if kernel */
 };
 
+typedef bool (*soopt_validfib_t)(struct sockopt *sopt, int fibnum);
+
+extern soopt_validfib_t soopt_validfib;
+
 int	sosetopt(struct socket *so, struct sockopt *sopt);
 int	sogetopt(struct socket *so, struct sockopt *sopt);
+int	sosetfib(struct socket *so, struct sockopt *sopt);
 int	sooptcopyin(struct sockopt *sopt, void *buf, size_t len, size_t minlen);
 int	sooptcopyout(struct sockopt *sopt, const void *buf, size_t len);
+int	soopt_getfib(struct sockopt *sopt, int *fibnum);
 int	soopt_getm(struct sockopt *sopt, struct mbuf **mp);
 int	soopt_mcopyin(struct sockopt *sopt, struct mbuf *m);
 int	soopt_mcopyout(struct sockopt *sopt, struct mbuf *m);