diff --git a/sys/kern/uipc_domain.c b/sys/kern/uipc_domain.c --- a/sys/kern/uipc_domain.c +++ b/sys/kern/uipc_domain.c @@ -109,7 +109,7 @@ return (EOPNOTSUPP); } -static int +int pr_listen_notsupp(struct socket *so, int backlog, struct thread *td) { return (EOPNOTSUPP); diff --git a/sys/netinet/in_pcb.h b/sys/netinet/in_pcb.h --- a/sys/netinet/in_pcb.h +++ b/sys/netinet/in_pcb.h @@ -167,7 +167,10 @@ struct m_snd_tag; struct inpcb { /* Cache line #1 (amd64) */ - CK_LIST_ENTRY(inpcb) inp_hash_exact; /* hash table linkage */ + union { + CK_LIST_ENTRY(inpcb) inp_hash_exact; /* hash table linkage */ + LIST_ENTRY(inpcb) inp_lbgroup_list; /* lb group list */ + }; CK_LIST_ENTRY(inpcb) inp_hash_wild; /* hash table linkage */ struct rwlock inp_lock; /* Cache line #2 (amd64) */ @@ -428,6 +431,7 @@ */ struct inpcblbgroup { CK_LIST_ENTRY(inpcblbgroup) il_list; + LIST_HEAD(, inpcb) il_pending; /* PCBs waiting for listen() */ struct epoch_context il_epoch_ctx; struct ucred *il_cred; uint16_t il_lport; /* (c) */ @@ -671,6 +675,7 @@ int in_pcbladdr(struct inpcb *, struct in_addr *, struct in_addr *, struct ucred *); int in_pcblbgroup_numa(struct inpcb *, int arg); +void in_pcblisten(struct inpcb *); struct inpcb * in_pcblookup(struct inpcbinfo *, struct in_addr, u_int, struct in_addr, u_int, int, struct ifnet *); diff --git a/sys/netinet/in_pcb.c b/sys/netinet/in_pcb.c --- a/sys/netinet/in_pcb.c +++ b/sys/netinet/in_pcb.c @@ -263,6 +263,7 @@ grp = malloc(bytes, M_PCB, M_ZERO | M_NOWAIT); if (grp == NULL) return (NULL); + LIST_INIT(&grp->il_pending); grp->il_cred = crhold(cred); grp->il_vflag = vflag; grp->il_lport = port; @@ -285,11 +286,45 @@ static void in_pcblbgroup_free(struct inpcblbgroup *grp) { + KASSERT(LIST_EMPTY(&grp->il_pending), + ("local group %p still has pending inps", grp)); CK_LIST_REMOVE(grp, il_list); NET_EPOCH_CALL(in_pcblbgroup_free_deferred, &grp->il_epoch_ctx); } +static struct inpcblbgroup * +in_pcblbgroup_find(struct inpcb *inp) +{ + struct inpcbinfo
*pcbinfo; + struct inpcblbgroup *grp; + struct inpcblbgrouphead *hdr; + + INP_LOCK_ASSERT(inp); + + pcbinfo = inp->inp_pcbinfo; + INP_HASH_LOCK_ASSERT(pcbinfo); + KASSERT((inp->inp_flags & INP_INLBGROUP) != 0, + ("inpcb %p is not in a load balance group", inp)); + + hdr = &pcbinfo->ipi_lbgrouphashbase[ + INP_PCBPORTHASH(inp->inp_lport, pcbinfo->ipi_lbgrouphashmask)]; + CK_LIST_FOREACH(grp, hdr, il_list) { + struct inpcb *inp1; + + for (unsigned int i = 0; i < grp->il_inpcnt; i++) { + if (inp == grp->il_inp[i]) + goto found; + } + LIST_FOREACH(inp1, &grp->il_pending, inp_lbgroup_list) { + if (inp == inp1) + goto found; + } + } +found: + return (grp); +} + static void in_pcblbgroup_insert(struct inpcblbgroup *grp, struct inpcb *inp) { @@ -298,14 +333,24 @@ grp->il_inpcnt)); INP_WLOCK_ASSERT(inp); - inp->inp_flags |= INP_INLBGROUP; - grp->il_inp[grp->il_inpcnt] = inp; + if (inp->inp_socket->so_proto->pr_listen != pr_listen_notsupp && + !SOLISTENING(inp->inp_socket)) { + /* + * If this is a TCP socket, it should not be visible to lbgroup + * lookups until listen() has been called. + */ + LIST_INSERT_HEAD(&grp->il_pending, inp, inp_lbgroup_list); + } else { + grp->il_inp[grp->il_inpcnt] = inp; - /* - * Synchronize with in_pcblookup_lbgroup(): make sure that we don't - * expose a null slot to the lookup path. - */ - atomic_store_rel_int(&grp->il_inpcnt, grp->il_inpcnt + 1); + /* + * Synchronize with in_pcblookup_lbgroup(): make sure that we + * don't expose a null slot to the lookup path.
*/ + atomic_store_rel_int(&grp->il_inpcnt, grp->il_inpcnt + 1); + } + + inp->inp_flags |= INP_INLBGROUP; } static struct inpcblbgroup * @@ -329,6 +374,8 @@ grp->il_inp[i] = old_grp->il_inp[i]; grp->il_inpcnt = old_grp->il_inpcnt; CK_LIST_INSERT_HEAD(hdr, grp, il_list); + LIST_SWAP(&old_grp->il_pending, &grp->il_pending, inpcb, + inp_lbgroup_list); in_pcblbgroup_free(old_grp); return (grp); } @@ -412,6 +459,7 @@ struct inpcbinfo *pcbinfo; struct inpcblbgrouphead *hdr; struct inpcblbgroup *grp; + struct inpcb *inp1; int i; pcbinfo = inp->inp_pcbinfo; @@ -427,13 +475,11 @@ if (grp->il_inp[i] != inp) continue; - if (grp->il_inpcnt == 1) { + if (grp->il_inpcnt == 1 && + LIST_EMPTY(&grp->il_pending)) { /* We are the last, free this local group. */ in_pcblbgroup_free(grp); } else { - KASSERT(grp->il_inpcnt >= 2, - ("invalid local group count %d", - grp->il_inpcnt)); grp->il_inp[i] = grp->il_inp[grp->il_inpcnt - 1]; @@ -446,17 +492,27 @@ inp->inp_flags &= ~INP_INLBGROUP; return; } + LIST_FOREACH(inp1, &grp->il_pending, inp_lbgroup_list) { + if (inp == inp1) { + LIST_REMOVE(inp, inp_lbgroup_list); + if (grp->il_inpcnt == 0 && + LIST_EMPTY(&grp->il_pending)) { + /* The group is now empty, free it. */ + in_pcblbgroup_free(grp); + } + inp->inp_flags &= ~INP_INLBGROUP; + return; + } + } } - KASSERT(0, ("%s: did not find %p", __func__, inp)); + __assert_unreachable(); } int in_pcblbgroup_numa(struct inpcb *inp, int arg) { struct inpcbinfo *pcbinfo; - struct inpcblbgrouphead *hdr; - struct inpcblbgroup *grp; - int err, i; + int error; uint8_t numa_domain; switch (arg) { @@ -472,33 +528,25 @@ numa_domain = arg; } - err = 0; pcbinfo = inp->inp_pcbinfo; INP_WLOCK_ASSERT(inp); INP_HASH_WLOCK(pcbinfo); - hdr = &pcbinfo->ipi_lbgrouphashbase[ - INP_PCBPORTHASH(inp->inp_lport, pcbinfo->ipi_lbgrouphashmask)]; - CK_LIST_FOREACH(grp, hdr, il_list) { - for (i = 0; i < grp->il_inpcnt; ++i) { - if (grp->il_inp[i] != inp) - continue; - - if (grp->il_numa_domain == numa_domain) { - goto abort_with_hash_wlock; - } - - /* Remove it from the old group.
*/ - in_pcbremlbgrouphash(inp); - - /* Add it to the new group based on numa domain. */ - in_pcbinslbgrouphash(inp, numa_domain); - goto abort_with_hash_wlock; - } + /* + * The PCB need not be in an lbgroup, in which case we return ENOENT; + * in_pcblbgroup_find() asserts membership, so check the flag first. + */ + if ((inp->inp_flags & INP_INLBGROUP) != 0 && + in_pcblbgroup_find(inp) != NULL) { + /* Remove it from the old group. */ + in_pcbremlbgrouphash(inp); + /* Add it to the new group based on numa domain. */ + in_pcbinslbgrouphash(inp, numa_domain); + error = 0; + } else { + error = ENOENT; } - err = ENOENT; -abort_with_hash_wlock: INP_HASH_WUNLOCK(pcbinfo); - return (err); + return (error); } /* Make sure it is safe to use hashinit(9) on CK_LIST. */ @@ -1437,6 +1485,41 @@ } #endif /* INET */ +/* + * Let the lbgroup know that the PCB is now listening, so that it becomes + * visible to lbgroup lookups. tcp_usr_listen() calls this after every + * successful listen(2), including on a socket that was already listening; + * in that case the PCB is no longer on the pending list and there is nothing + * to do. + * XXX: PCBs on il_pending do not occupy a slot in il_inp[], so moving one + * into the array relies on the group having spare capacity; confirm that + * in_pcbinslbgrouphash() accounts for pending PCBs when growing the group. + */ +void +in_pcblisten(struct inpcb *inp) +{ + struct inpcblbgroup *grp; + struct inpcb *inp1; + + INP_WLOCK_ASSERT(inp); + + if ((inp->inp_flags & INP_INLBGROUP) != 0) { + struct inpcbinfo *pcbinfo; + + pcbinfo = inp->inp_pcbinfo; + INP_HASH_WLOCK(pcbinfo); + grp = in_pcblbgroup_find(inp); + LIST_FOREACH(inp1, &grp->il_pending, inp_lbgroup_list) { + if (inp1 == inp) { + LIST_REMOVE(inp, inp_lbgroup_list); + in_pcblbgroup_insert(grp, inp); + break; + } + } + INP_HASH_WUNLOCK(pcbinfo); + } +} + /* * inpcb hash lookups are protected by SMR section.
* diff --git a/sys/netinet/tcp_usrreq.c b/sys/netinet/tcp_usrreq.c --- a/sys/netinet/tcp_usrreq.c +++ b/sys/netinet/tcp_usrreq.c @@ -391,6 +391,8 @@ } SOCK_UNLOCK(so); + if (error == 0) + in_pcblisten(inp); if (tp->t_flags & TF_FASTOPEN) tp->t_tfo_pending = tcp_fastopen_alloc_counter(); @@ -448,6 +450,8 @@ } SOCK_UNLOCK(so); + if (error == 0) + in_pcblisten(inp); if (tp->t_flags & TF_FASTOPEN) tp->t_tfo_pending = tcp_fastopen_alloc_counter(); diff --git a/sys/sys/socketvar.h b/sys/sys/socketvar.h --- a/sys/sys/socketvar.h +++ b/sys/sys/socketvar.h @@ -596,6 +596,8 @@ int accept_filt_generic_mod_event(module_t mod, int event, void *data); #endif +int pr_listen_notsupp(struct socket *so, int backlog, struct thread *td); + #endif /* _KERNEL */ /* diff --git a/tests/sys/netinet/Makefile b/tests/sys/netinet/Makefile --- a/tests/sys/netinet/Makefile +++ b/tests/sys/netinet/Makefile @@ -27,6 +27,8 @@ ATF_TESTS_PYTEST+= carp.py ATF_TESTS_PYTEST+= igmp.py +LIBADD.so_reuseport_lb_test= pthread + # Some of the arp tests look for log messages in the dmesg buffer, so run them # serially to avoid problems with interleaved output.
TEST_METADATA.arp+= is_exclusive="true" diff --git a/tests/sys/netinet/so_reuseport_lb_test.c b/tests/sys/netinet/so_reuseport_lb_test.c --- a/tests/sys/netinet/so_reuseport_lb_test.c +++ b/tests/sys/netinet/so_reuseport_lb_test.c @@ -28,12 +28,16 @@ */ #include +#include #include #include +#include #include #include +#include +#include #include #include @@ -235,10 +239,151 @@ } } +struct concurrent_add_softc { + struct sockaddr_storage ss; + int socks[128]; + int kq; +}; + +static void * +listener(void *arg) +{ + for (struct concurrent_add_softc *sc = arg;;) { + struct kevent kev; + ssize_t n; + int error, count, cs, s; + uint8_t b; + + count = kevent(sc->kq, NULL, 0, &kev, 1, NULL); + ATF_REQUIRE_MSG(count == 1, + "kevent() failed: %s", strerror(errno)); + + s = (int)kev.ident; + cs = accept(s, NULL, NULL); + ATF_REQUIRE_MSG(cs >= 0, + "accept() failed: %s", strerror(errno)); + + b = 'M'; + n = write(cs, &b, sizeof(b)); + ATF_REQUIRE_MSG(n >= 0, "write() failed: %s", strerror(errno)); + ATF_REQUIRE(n == 1); + + error = close(cs); + ATF_REQUIRE_MSG(error == 0 || errno == ECONNRESET, + "close() failed: %s", strerror(errno)); + } +} + +static void * +connector(void *arg) +{ + for (struct concurrent_add_softc *sc = arg;;) { + ssize_t n; + int error, s; + uint8_t b; + + s = socket(sc->ss.ss_family, SOCK_STREAM, 0); + ATF_REQUIRE_MSG(s >= 0, "socket() failed: %s", strerror(errno)); + + error = setsockopt(s, SOL_SOCKET, SO_REUSEADDR, (int[]){1}, + sizeof(int)); + ATF_REQUIRE_MSG(error == 0, + "setsockopt(SO_REUSEADDR) failed: %s", strerror(errno)); + + error = connect(s, (struct sockaddr *)&sc->ss, sc->ss.ss_len); + ATF_REQUIRE_MSG(error == 0, "connect() failed: %s", + strerror(errno)); + + n = read(s, &b, sizeof(b)); + ATF_REQUIRE_MSG(n >= 0, "read() failed: %s", + strerror(errno)); + ATF_REQUIRE(n == 1); + ATF_REQUIRE(b == 'M'); + error = close(s); + ATF_REQUIRE_MSG(error == 0, + "close() failed: %s", strerror(errno)); + } +} + +/* + * Run several threads: one accepts connections from listening sockets on a + * kqueue, the others make connections, and the main thread slowly adds
+ * sockets to the LB group. This is meant to help flush out race conditions. + */ +ATF_TC_WITHOUT_HEAD(concurrent_add); +ATF_TC_BODY(concurrent_add, tc) +{ + struct concurrent_add_softc sc; + struct sockaddr_in *sin; + pthread_t threads[4]; + int error; + + sc.kq = kqueue(); + ATF_REQUIRE_MSG(sc.kq >= 0, "kqueue() failed: %s", strerror(errno)); + + error = pthread_create(&threads[0], NULL, listener, &sc); + ATF_REQUIRE_MSG(error == 0, "pthread_create() failed: %s", + strerror(error)); + + sin = (struct sockaddr_in *)&sc.ss; + memset(sin, 0, sizeof(*sin)); + sin->sin_len = sizeof(*sin); + sin->sin_family = AF_INET; + sin->sin_port = htons(0); + sin->sin_addr.s_addr = htonl(INADDR_LOOPBACK); + + for (size_t i = 0; i < nitems(sc.socks); i++) { + struct kevent kev; + int s; + + sc.socks[i] = s = socket(AF_INET, SOCK_STREAM, 0); + ATF_REQUIRE_MSG(s >= 0, "socket() failed: %s", strerror(errno)); + + error = setsockopt(s, SOL_SOCKET, SO_REUSEPORT_LB, (int[]){1}, + sizeof(int)); + ATF_REQUIRE_MSG(error == 0, + "setsockopt(SO_REUSEPORT_LB) failed: %s", strerror(errno)); + + error = bind(s, (struct sockaddr *)sin, sizeof(*sin)); + ATF_REQUIRE_MSG(error == 0, "bind() failed: %s", + strerror(errno)); + + error = listen(s, 5); + ATF_REQUIRE_MSG(error == 0, "listen() failed: %s", + strerror(errno)); + + EV_SET(&kev, s, EVFILT_READ, EV_ADD | EV_ENABLE, 0, 0, 0); + error = kevent(sc.kq, &kev, 1, NULL, 0, NULL); + ATF_REQUIRE_MSG(error == 0, "kevent() failed: %s", + strerror(errno)); + + if (i == 0) { + socklen_t slen = sizeof(sc.ss); + + error = getsockname(sc.socks[i], + (struct sockaddr *)&sc.ss, &slen); + ATF_REQUIRE_MSG(error == 0, "getsockname() failed: %s", + strerror(errno)); + ATF_REQUIRE(sc.ss.ss_family == AF_INET); + + for (size_t j = 1; j < nitems(threads); j++) { + error = pthread_create(&threads[j], NULL, + connector, &sc); + ATF_REQUIRE_MSG(error == 0, + "pthread_create() failed: %s", + strerror(error)); + } + } + + usleep(20000); + } +}
+ ATF_TP_ADD_TCS(tp) { ATF_TP_ADD_TC(tp, basic_ipv4); ATF_TP_ADD_TC(tp, basic_ipv6); + ATF_TP_ADD_TC(tp, concurrent_add); return (atf_no_error()); }