inpcb: do not expose unfilled lbgroup slots to concurrent lookups

Reviewer notes for this patch.  Text above the first "diff --git" line
is not part of any hunk and is skipped by git-apply(1) and patch(1).

What the patch changes:

- in_pcblbgroup_alloc() no longer takes the hash-chain head and no
  longer links the new group into it.  Each caller now links the group
  itself once it is populated:
  - in_pcblbgroup_resize() links it after copying the old members
    and the old count.
  - The "Create new load balance group" path links it only after
    storing the first member.

- The new static helper in_pcblbgroup_insert() does the following:
  - checks that the group has a free slot (KASSERT) and checks
    INP_WLOCK_ASSERT(inp);
  - sets INP_INLBGROUP;
  - stores the inpcb in il_inp[il_inpcnt];
  - only then bumps il_inpcnt with atomic_store_rel_int().
  That release store pairs with the atomic_load_acq_int() in the
  lookups.  A lookup that observes the new count therefore also
  observes the pointer written into the new slot.

- When the group is full and must grow, in_pcblbgroup_resize() links
  the grown group before the joining inpcb is added.  That inpcb
  becomes visible through the release store in in_pcblbgroup_insert().

- Removal no longer shifts the following members down.  Instead, the
  last member is copied into the vacated slot and il_inpcnt is
  decremented with a release store.  Consequences:
  - member order in il_inp[] is no longer preserved;
  - the old last slot is not cleared, so a reader still using the
    previous count finds a non-NULL pointer there.
  in_pcblbgroup_reorder() is deleted, and with it the shrink-on-remove
  logic.  Nothing in this patch shrinks a group after a removal.

- Both lookups (IPv4 in in_pcb.c, IPv6 in in6_pcb.c) now handle the
  count as follows:
  - read il_inpcnt exactly once with atomic_load_acq_int();
  - return NULL if the count is 0;
  - use that same snapshot as the modulus, so the chosen index is
    always below the count that was read.
  A KASSERT then checks that the selected slot is non-NULL.  The IPv4
  lookup additionally gains NET_EPOCH_ASSERT().

NOTE(review): the IPv6 lookup hunk does not add NET_EPOCH_ASSERT(),
unlike the IPv4 hunk.  Confirm whether this asymmetry is intended.

NOTE(review): atomic_load_acq_int() and atomic_store_rel_int() operate
on u_int, and the lookups keep the snapshot in a u_int.  The declaration
of il_inpcnt is not part of this patch; confirm that its type matches.

NOTE(review): with shrink-on-remove gone, a group presumably keeps its
grown il_inp[] allocation until it is freed when its last member
leaves.  Confirm that this memory trade-off is acceptable.

diff --git a/sys/netinet/in_pcb.c b/sys/netinet/in_pcb.c
--- a/sys/netinet/in_pcb.c
+++ b/sys/netinet/in_pcb.c
@@ -253,9 +253,8 @@
  */
 
 static struct inpcblbgroup *
-in_pcblbgroup_alloc(struct inpcblbgrouphead *hdr, struct ucred *cred,
-    u_char vflag, uint16_t port, const union in_dependaddr *addr, int size,
-    uint8_t numa_domain)
+in_pcblbgroup_alloc(struct ucred *cred, u_char vflag, uint16_t port,
+    const union in_dependaddr *addr, int size, uint8_t numa_domain)
 {
 	struct inpcblbgroup *grp;
 	size_t bytes;
@@ -270,7 +269,6 @@
 	grp->il_numa_domain = numa_domain;
 	grp->il_dependladdr = *addr;
 	grp->il_inpsiz = size;
-	CK_LIST_INSERT_HEAD(hdr, grp, il_list);
 	return (grp);
 }
 
@@ -292,6 +290,24 @@
 	NET_EPOCH_CALL(in_pcblbgroup_free_deferred, &grp->il_epoch_ctx);
 }
 
+static void
+in_pcblbgroup_insert(struct inpcblbgroup *grp, struct inpcb *inp)
+{
+	KASSERT(grp->il_inpcnt < grp->il_inpsiz,
+	    ("invalid local group size %d and count %d", grp->il_inpsiz,
+	    grp->il_inpcnt));
+	INP_WLOCK_ASSERT(inp);
+
+	inp->inp_flags |= INP_INLBGROUP;
+	grp->il_inp[grp->il_inpcnt] = inp;
+
+	/*
+	 * Synchronize with in_pcblookup_lbgroup(): make sure that we don't
+	 * expose a null slot to the lookup path.
+	 */
+	atomic_store_rel_int(&grp->il_inpcnt, grp->il_inpcnt + 1);
+}
+
 static struct inpcblbgroup *
 in_pcblbgroup_resize(struct inpcblbgrouphead *hdr,
     struct inpcblbgroup *old_grp, int size)
@@ -299,7 +315,7 @@
 	struct inpcblbgroup *grp;
 	int i;
 
-	grp = in_pcblbgroup_alloc(hdr, old_grp->il_cred, old_grp->il_vflag,
+	grp = in_pcblbgroup_alloc(old_grp->il_cred, old_grp->il_vflag,
 	    old_grp->il_lport, &old_grp->il_dependladdr, size,
 	    old_grp->il_numa_domain);
 	if (grp == NULL)
@@ -312,34 +328,11 @@
 	for (i = 0; i < old_grp->il_inpcnt; ++i)
 		grp->il_inp[i] = old_grp->il_inp[i];
 	grp->il_inpcnt = old_grp->il_inpcnt;
+	CK_LIST_INSERT_HEAD(hdr, grp, il_list);
 	in_pcblbgroup_free(old_grp);
 	return (grp);
 }
 
-/*
- * PCB at index 'i' is removed from the group. Pull up the ones below il_inp[i]
- * and shrink group if possible.
- */
-static void
-in_pcblbgroup_reorder(struct inpcblbgrouphead *hdr, struct inpcblbgroup **grpp,
-    int i)
-{
-	struct inpcblbgroup *grp, *new_grp;
-
-	grp = *grpp;
-	for (; i + 1 < grp->il_inpcnt; ++i)
-		grp->il_inp[i] = grp->il_inp[i + 1];
-	grp->il_inpcnt--;
-
-	if (grp->il_inpsiz > INPCBLBGROUP_SIZMIN &&
-	    grp->il_inpcnt <= grp->il_inpsiz / 4) {
-		/* Shrink this group. */
-		new_grp = in_pcblbgroup_resize(hdr, grp, grp->il_inpsiz / 2);
-		if (new_grp != NULL)
-			*grpp = new_grp;
-	}
-}
-
 /*
  * Add PCB to load balance group for SO_REUSEPORT_LB option.
  */
@@ -384,11 +377,13 @@
 	}
 	if (grp == NULL) {
 		/* Create new load balance group. */
-		grp = in_pcblbgroup_alloc(hdr, inp->inp_cred, inp->inp_vflag,
+		grp = in_pcblbgroup_alloc(inp->inp_cred, inp->inp_vflag,
 		    inp->inp_lport, &inp->inp_inc.inc_ie.ie_dependladdr,
 		    INPCBLBGROUP_SIZMIN, numa_domain);
 		if (grp == NULL)
 			return (ENOBUFS);
+		in_pcblbgroup_insert(grp, inp);
+		CK_LIST_INSERT_HEAD(hdr, grp, il_list);
 	} else if (grp->il_inpcnt == grp->il_inpsiz) {
 		if (grp->il_inpsiz >= INPCBLBGROUP_SIZMAX) {
 			if (ratecheck(&lastprint, &interval))
@@ -401,15 +396,10 @@
 		grp = in_pcblbgroup_resize(hdr, grp, grp->il_inpsiz * 2);
 		if (grp == NULL)
 			return (ENOBUFS);
+		in_pcblbgroup_insert(grp, inp);
+	} else {
+		in_pcblbgroup_insert(grp, inp);
 	}
-
-	KASSERT(grp->il_inpcnt < grp->il_inpsiz,
-	    ("invalid local group size %d and count %d", grp->il_inpsiz,
-	    grp->il_inpcnt));
-
-	grp->il_inp[grp->il_inpcnt] = inp;
-	grp->il_inpcnt++;
-	inp->inp_flags |= INP_INLBGROUP;
 	return (0);
 }
 
@@ -441,8 +431,17 @@
 				/* We are the last, free this local group. */
 				in_pcblbgroup_free(grp);
 			} else {
-				/* Pull up inpcbs, shrink group if possible. */
-				in_pcblbgroup_reorder(hdr, &grp, i);
+				KASSERT(grp->il_inpcnt >= 2,
+				    ("invalid local group count %d",
+				    grp->il_inpcnt));
+				grp->il_inp[i] =
+				    grp->il_inp[grp->il_inpcnt - 1];
+
+				/*
+				 * Synchronize with in_pcblookup_lbgroup().
+				 */
+				atomic_store_rel_int(&grp->il_inpcnt,
+				    grp->il_inpcnt - 1);
 			}
 			inp->inp_flags &= ~INP_INLBGROUP;
 			return;
@@ -2068,8 +2067,11 @@
 	const struct inpcblbgrouphead *hdr;
 	struct inpcblbgroup *grp;
 	struct inpcblbgroup *jail_exact, *jail_wild, *local_exact, *local_wild;
+	struct inpcb *inp;
+	u_int count;
 
 	INP_HASH_LOCK_ASSERT(pcbinfo);
+	NET_EPOCH_ASSERT();
 
 	hdr = &pcbinfo->ipi_lbgrouphashbase[
 	    INP_PCBPORTHASH(lport, pcbinfo->ipi_lbgrouphashmask)];
@@ -2128,9 +2130,17 @@
 		grp = local_wild;
 	if (grp == NULL)
 		return (NULL);
+
 out:
-	return (grp->il_inp[INP_PCBLBGROUP_PKTHASH(faddr, lport, fport) %
-	    grp->il_inpcnt]);
+	/*
+	 * Synchronize with in_pcblbgroup_insert().
+	 */
+	count = atomic_load_acq_int(&grp->il_inpcnt);
+	if (count == 0)
+		return (NULL);
+	inp = grp->il_inp[INP_PCBLBGROUP_PKTHASH(faddr, lport, fport) % count];
+	KASSERT(inp != NULL, ("%s: inp == NULL", __func__));
+	return (inp);
 }
 
 static bool
diff --git a/sys/netinet6/in6_pcb.c b/sys/netinet6/in6_pcb.c
--- a/sys/netinet6/in6_pcb.c
+++ b/sys/netinet6/in6_pcb.c
@@ -893,6 +893,8 @@
 	const struct inpcblbgrouphead *hdr;
 	struct inpcblbgroup *grp;
 	struct inpcblbgroup *jail_exact, *jail_wild, *local_exact, *local_wild;
+	struct inpcb *inp;
+	u_int count;
 
 	INP_HASH_LOCK_ASSERT(pcbinfo);
 
@@ -954,8 +956,15 @@
 	if (grp == NULL)
 		return (NULL);
 out:
-	return (grp->il_inp[INP6_PCBLBGROUP_PKTHASH(faddr, lport, fport) %
-	    grp->il_inpcnt]);
+	/*
+	 * Synchronize with in_pcblbgroup_insert().
+	 */
+	count = atomic_load_acq_int(&grp->il_inpcnt);
+	if (count == 0)
+		return (NULL);
+	inp = grp->il_inp[INP6_PCBLBGROUP_PKTHASH(faddr, lport, fport) % count];
+	KASSERT(inp != NULL, ("%s: inp == NULL", __func__));
+	return (inp);
 }
 
 static bool