diff --git a/sys/netinet/in_pcb.h b/sys/netinet/in_pcb.h --- a/sys/netinet/in_pcb.h +++ b/sys/netinet/in_pcb.h @@ -500,9 +500,10 @@ struct inpcblbgroup { CK_LIST_ENTRY(inpcblbgroup) il_list; struct epoch_context il_epoch_ctx; + struct ucred *il_cred; /* (c) */ uint16_t il_lport; /* (c) */ u_char il_vflag; /* (c) */ - u_int8_t il_numa_domain; + uint8_t il_numa_domain; uint32_t il_pad2; union in_dependaddr il_dependladdr; /* (c) */ #define il_laddr il_dependladdr.id46_addr.ia46_addr4 diff --git a/sys/netinet/in_pcb.c b/sys/netinet/in_pcb.c --- a/sys/netinet/in_pcb.c +++ b/sys/netinet/in_pcb.c @@ -250,8 +250,8 @@ */ static struct inpcblbgroup * -in_pcblbgroup_alloc(struct inpcblbgrouphead *hdr, u_char vflag, - uint16_t port, const union in_dependaddr *addr, int size, +in_pcblbgroup_alloc(struct inpcblbgrouphead *hdr, struct ucred *cred, + u_char vflag, uint16_t port, const union in_dependaddr *addr, int size, uint8_t numa_domain) { struct inpcblbgroup *grp; @@ -259,8 +259,9 @@ bytes = __offsetof(struct inpcblbgroup, il_inp[size]); grp = malloc(bytes, M_PCB, M_ZERO | M_NOWAIT); - if (!grp) + if (grp == NULL) return (NULL); + grp->il_cred = crhold(cred); grp->il_vflag = vflag; grp->il_lport = port; grp->il_numa_domain = numa_domain; @@ -276,6 +277,7 @@ struct inpcblbgroup *grp; grp = __containerof(ctx, struct inpcblbgroup, il_epoch_ctx); + crfree(grp->il_cred); free(grp, M_PCB); } @@ -294,7 +296,7 @@ struct inpcblbgroup *grp; int i; - grp = in_pcblbgroup_alloc(hdr, old_grp->il_vflag, + grp = in_pcblbgroup_alloc(hdr, old_grp->il_cred, old_grp->il_vflag, old_grp->il_lport, &old_grp->il_dependladdr, size, old_grp->il_numa_domain); if (grp == NULL) @@ -353,12 +355,6 @@ INP_WLOCK_ASSERT(inp); INP_HASH_WLOCK_ASSERT(pcbinfo); - /* - * Don't allow jailed socket to join local group. - */ - if (inp->inp_socket != NULL && jailed(inp->inp_socket->so_cred)) - return (0); - #ifdef INET6 /* * Don't allow IPv4 mapped INET6 wild socket. 
@@ -373,17 +369,19 @@ idx = INP_PCBPORTHASH(inp->inp_lport, pcbinfo->ipi_lbgrouphashmask); hdr = &pcbinfo->ipi_lbgrouphashbase[idx]; CK_LIST_FOREACH(grp, hdr, il_list) { - if (grp->il_vflag == inp->inp_vflag && + if (grp->il_cred->cr_prison == inp->inp_cred->cr_prison && + grp->il_vflag == inp->inp_vflag && grp->il_lport == inp->inp_lport && grp->il_numa_domain == numa_domain && memcmp(&grp->il_dependladdr, &inp->inp_inc.inc_ie.ie_dependladdr, - sizeof(grp->il_dependladdr)) == 0) + sizeof(grp->il_dependladdr)) == 0) { break; + } } if (grp == NULL) { /* Create new load balance group. */ - grp = in_pcblbgroup_alloc(hdr, inp->inp_vflag, + grp = in_pcblbgroup_alloc(hdr, inp->inp_cred, inp->inp_vflag, inp->inp_lport, &inp->inp_inc.inc_ie.ie_dependladdr, INPCBLBGROUP_SIZMIN, numa_domain); if (grp == NULL) @@ -2145,15 +2143,20 @@ } #undef INP_LOOKUP_MAPPED_PCB_COST +static bool +in_pcblookup_lb_numa_match(const struct inpcblbgroup *grp, int domain) +{ + return (domain == M_NODOM || domain == grp->il_numa_domain); +} + static struct inpcb * in_pcblookup_lbgroup(const struct inpcbinfo *pcbinfo, const struct in_addr *laddr, uint16_t lport, const struct in_addr *faddr, - uint16_t fport, int lookupflags, int numa_domain) + uint16_t fport, int lookupflags, int domain) { - struct inpcb *local_wild, *numa_wild; const struct inpcblbgrouphead *hdr; struct inpcblbgroup *grp; - uint32_t idx; + struct inpcblbgroup *jail_exact, *jail_wild, *local_exact, *local_wild; INP_HASH_LOCK_ASSERT(pcbinfo); @@ -2161,17 +2164,15 @@ INP_PCBPORTHASH(lport, pcbinfo->ipi_lbgrouphashmask)]; /* - * Order of socket selection: - * 1. non-wild. - * 2. wild (if lookupflags contains INPLOOKUP_WILDCARD).
- * - * NOTE: - * - Load balanced group does not contain jailed sockets - * - Load balanced group does not contain IPv4 mapped INET6 wild sockets + * Search for an LB group match based on the following criteria: + * - prefer jailed groups to non-jailed groups + * - prefer exact local address matches to wildcard matches + * - prefer groups bound to the specified NUMA domain */ - local_wild = NULL; - numa_wild = NULL; + jail_exact = jail_wild = local_exact = local_wild = NULL; CK_LIST_FOREACH(grp, hdr, il_list) { + bool injail; + #ifdef INET6 if (!(grp->il_vflag & INP_IPV4)) continue; @@ -2179,27 +2180,47 @@ if (grp->il_lport != lport) continue; - idx = INP_PCBLBGROUP_PKTHASH(faddr, lport, fport) % - grp->il_inpcnt; + injail = prison_flag(grp->il_cred, PR_IP4) != 0; + if (injail && prison_check_ip4_locked(grp->il_cred->cr_prison, + laddr) != 0) + continue; + if (grp->il_laddr.s_addr == laddr->s_addr) { - if (numa_domain == M_NODOM || - grp->il_numa_domain == numa_domain) { - return (grp->il_inp[idx]); - } else { - numa_wild = grp->il_inp[idx]; + if (injail) { + jail_exact = grp; + if (in_pcblookup_lb_numa_match(grp, domain)) + /* This is a perfect match.
*/ + goto out; + } else if (local_exact == NULL || + in_pcblookup_lb_numa_match(grp, domain)) { + local_exact = grp; + } + } else if (grp->il_laddr.s_addr == INADDR_ANY && + (lookupflags & INPLOOKUP_WILDCARD) != 0) { + if (injail) { + if (jail_wild == NULL || + in_pcblookup_lb_numa_match(grp, domain)) + jail_wild = grp; + } else if (local_wild == NULL || + in_pcblookup_lb_numa_match(grp, domain)) { + local_wild = grp; } - } - if (grp->il_laddr.s_addr == INADDR_ANY && - (lookupflags & INPLOOKUP_WILDCARD) != 0 && - (local_wild == NULL || numa_domain == M_NODOM || - grp->il_numa_domain == numa_domain)) { - local_wild = grp->il_inp[idx]; } } - if (numa_wild != NULL) - return (numa_wild); - return (local_wild); + if (jail_exact != NULL) + grp = jail_exact; + else if (jail_wild != NULL) + grp = jail_wild; + else if (local_exact != NULL) + grp = local_exact; + else + grp = local_wild; + if (grp == NULL) + return (NULL); +out: + return (grp->il_inp[INP_PCBLBGROUP_PKTHASH(faddr, lport, fport) % + grp->il_inpcnt]); } /* @@ -2251,16 +2272,6 @@ if (tmpinp != NULL) return (tmpinp); - /* - * Then look in lb group (for wildcard match). - */ - if ((lookupflags & INPLOOKUP_WILDCARD) != 0) { - inp = in_pcblookup_lbgroup(pcbinfo, &laddr, lport, &faddr, - fport, lookupflags, numa_domain); - if (inp != NULL) - return (inp); - } - /* * Then look for a wildcard match, if requested. */ @@ -2272,6 +2283,15 @@ struct inpcb *jail_wild = NULL; int injail; + /* + * First see if an LB group matches the request before scanning + * all sockets on this port. + */ + inp = in_pcblookup_lbgroup(pcbinfo, &laddr, lport, &faddr, + fport, lookupflags, numa_domain); + if (inp != NULL) + return (inp); + /* * Order of socket selection - we always prefer jails. * 1. jailed, non-wild. @@ -2472,8 +2492,12 @@ MPASS(inp->inp_flags & INP_INHASHLIST); INP_HASH_WLOCK(inp->inp_pcbinfo); - /* XXX: Only do if SO_REUSEPORT_LB set?
*/ + /* + * INP_REUSEPORT_LB may be cleared by setsockopt(SO_REUSEPORT_LB) + * after the inpcb has joined an LB group, so the flag does not + * reliably indicate group membership. Always attempt the removal. + */ in_pcbremlbgrouphash(inp); CK_LIST_REMOVE(inp, inp_hash); CK_LIST_REMOVE(inp, inp_portlist); if (CK_LIST_FIRST(&phd->phd_pcblist) == NULL) { diff --git a/sys/netinet6/in6_pcb.c b/sys/netinet6/in6_pcb.c --- a/sys/netinet6/in6_pcb.c +++ b/sys/netinet6/in6_pcb.c @@ -887,15 +887,20 @@ return inp; } +static bool +in6_pcblookup_lb_numa_match(const struct inpcblbgroup *grp, int domain) +{ + return (domain == M_NODOM || domain == grp->il_numa_domain); +} + static struct inpcb * in6_pcblookup_lbgroup(const struct inpcbinfo *pcbinfo, const struct in6_addr *laddr, uint16_t lport, const struct in6_addr *faddr, - uint16_t fport, int lookupflags, uint8_t numa_domain) + uint16_t fport, int lookupflags, uint8_t domain) { - struct inpcb *local_wild, *numa_wild; const struct inpcblbgrouphead *hdr; struct inpcblbgroup *grp; - uint32_t idx; + struct inpcblbgroup *jail_exact, *jail_wild, *local_exact, *local_wild; INP_HASH_LOCK_ASSERT(pcbinfo); @@ -903,17 +908,15 @@ INP_PCBPORTHASH(lport, pcbinfo->ipi_lbgrouphashmask)]; /* - * Order of socket selection: - * 1. non-wild. - * 2. wild (if lookupflags contains INPLOOKUP_WILDCARD). - * - * NOTE: - * - Load balanced group does not contain jailed sockets. - * - Load balanced does not contain IPv4 mapped INET6 wild sockets.
+ * Search for an LB group match based on the following criteria: + * - prefer jailed groups to non-jailed groups + * - prefer exact local address matches to wildcard matches + * - prefer groups bound to the specified NUMA domain */ - local_wild = NULL; - numa_wild = NULL; + jail_exact = jail_wild = local_exact = local_wild = NULL; CK_LIST_FOREACH(grp, hdr, il_list) { + bool injail; + #ifdef INET if (!(grp->il_vflag & INP_IPV6)) continue; @@ -921,26 +924,47 @@ if (grp->il_lport != lport) continue; - idx = INP6_PCBLBGROUP_PKTHASH(faddr, lport, fport) % - grp->il_inpcnt; + injail = prison_flag(grp->il_cred, PR_IP6) != 0; + if (injail && prison_check_ip6_locked(grp->il_cred->cr_prison, + laddr) != 0) + continue; + if (IN6_ARE_ADDR_EQUAL(&grp->il6_laddr, laddr)) { - if (numa_domain == M_NODOM || - grp->il_numa_domain == numa_domain) { - return (grp->il_inp[idx]); + if (injail) { + jail_exact = grp; + if (in6_pcblookup_lb_numa_match(grp, domain)) + /* This is a perfect match. */ + goto out; + } else if (local_exact == NULL || + in6_pcblookup_lb_numa_match(grp, domain)) { + local_exact = grp; + } + } else if (IN6_IS_ADDR_UNSPECIFIED(&grp->il6_laddr) && + (lookupflags & INPLOOKUP_WILDCARD) != 0) { + if (injail) { + if (jail_wild == NULL || + in6_pcblookup_lb_numa_match(grp, domain)) + jail_wild = grp; + } else if (local_wild == NULL || + in6_pcblookup_lb_numa_match(grp, domain)) { + local_wild = grp; } - else - numa_wild = grp->il_inp[idx]; - } - if (IN6_IS_ADDR_UNSPECIFIED(&grp->il6_laddr) && - (lookupflags & INPLOOKUP_WILDCARD) != 0 && - (local_wild == NULL || numa_domain == M_NODOM || - grp->il_numa_domain == numa_domain)) { - local_wild = grp->il_inp[idx]; } } - if (numa_wild != NULL) - return (numa_wild); - return (local_wild); + + if (jail_exact != NULL) + grp = jail_exact; + else if (jail_wild != NULL) + grp = jail_wild; + else if (local_exact != NULL) + grp = local_exact; + else + grp = local_wild; + if (grp == NULL) + return (NULL); +out: + return
(grp->il_inp[INP6_PCBLBGROUP_PKTHASH(faddr, lport, fport) % + grp->il_inpcnt]); } /* @@ -988,16 +1012,6 @@ if (tmpinp != NULL) return (tmpinp); - /* - * Then look in lb group (for wildcard match). - */ - if ((lookupflags & INPLOOKUP_WILDCARD) != 0) { - inp = in6_pcblookup_lbgroup(pcbinfo, laddr, lport, faddr, - fport, lookupflags, numa_domain); - if (inp != NULL) - return (inp); - } - /* * Then look for a wildcard match, if requested. */ @@ -1006,6 +1020,15 @@ struct inpcb *jail_wild = NULL; int injail; + /* + * First see if an LB group matches the request before scanning + * all sockets on this port. + */ + inp = in6_pcblookup_lbgroup(pcbinfo, laddr, lport, faddr, + fport, lookupflags, numa_domain); + if (inp != NULL) + return (inp); + /* * Order of socket selection - we always prefer jails. * 1. jailed, non-wild.