diff --git a/sys/kern/uipc_ktls.c b/sys/kern/uipc_ktls.c
--- a/sys/kern/uipc_ktls.c
+++ b/sys/kern/uipc_ktls.c
@@ -1223,8 +1223,6 @@
 
 	if (!ktls_offload_enable)
 		return (ENOTSUP);
-	if (SOLISTENING(so))
-		return (EINVAL);
 
 	counter_u64_add(ktls_offload_enable_calls, 1);
 
@@ -1256,7 +1254,12 @@
 	}
 
 	/* Mark the socket as using TLS offload. */
-	SOCKBUF_LOCK(&so->so_rcv);
+	SOCK_RECVBUF_LOCK(so);
+	if (SOLISTENING(so)) {
+		SOCK_RECVBUF_UNLOCK(so);
+		ktls_free(tls);
+		return (EINVAL);
+	}
 	so->so_rcv.sb_tls_seqno = be64dec(en->rec_seq);
 	so->so_rcv.sb_tls_info = tls;
 	so->so_rcv.sb_flags |= SB_TLS_RX;
@@ -1264,7 +1267,7 @@
 	/* Mark existing data as not ready until it can be decrypted. */
 	sb_mark_notready(&so->so_rcv);
 	ktls_check_rx(&so->so_rcv);
-	SOCKBUF_UNLOCK(&so->so_rcv);
+	SOCK_RECVBUF_UNLOCK(so);
 
 	/* Prefer TOE -> ifnet TLS -> software TLS. */
 #ifdef TCP_OFFLOAD
@@ -1290,8 +1293,6 @@
 
 	if (!ktls_offload_enable)
 		return (ENOTSUP);
-	if (SOLISTENING(so))
-		return (EINVAL);
 
 	counter_u64_add(ktls_offload_enable_calls, 1);
 
@@ -1334,6 +1335,10 @@
 		return (error);
 	}
 
+	/*
+	 * Serialize with sosend_generic() and make sure that we're not
+	 * operating on a listening socket.
+	 */
 	error = SOCK_IO_SEND_LOCK(so, SBL_WAIT);
 	if (error) {
 		ktls_free(tls);
@@ -1347,7 +1352,7 @@
 	 */
 	inp = so->so_pcb;
 	INP_WLOCK(inp);
-	SOCKBUF_LOCK(&so->so_snd);
+	SOCK_SENDBUF_LOCK(so);
 	so->so_snd.sb_tls_seqno = be64dec(en->rec_seq);
 	so->so_snd.sb_tls_info = tls;
 	if (tls->mode != TCP_TLS_MODE_SW) {
@@ -1357,7 +1362,7 @@
 		if (tp->t_fb->tfb_hwtls_change != NULL)
 			(*tp->t_fb->tfb_hwtls_change)(tp, 1);
 	}
-	SOCKBUF_UNLOCK(&so->so_snd);
+	SOCK_SENDBUF_UNLOCK(so);
 	INP_WUNLOCK(inp);
 	SOCK_IO_SEND_UNLOCK(so);
 
diff --git a/sys/kern/uipc_socket.c b/sys/kern/uipc_socket.c
--- a/sys/kern/uipc_socket.c
+++ b/sys/kern/uipc_socket.c
@@ -987,13 +987,24 @@
 	mtx_lock(&so->so_snd_mtx);
 	mtx_lock(&so->so_rcv_mtx);
 
-	/* Interlock with soo_aio_queue(). */
-	if (!SOLISTENING(so) &&
-	    ((so->so_snd.sb_flags & (SB_AIO | SB_AIO_RUNNING)) != 0 ||
-	    (so->so_rcv.sb_flags & (SB_AIO | SB_AIO_RUNNING)) != 0)) {
-		solisten_proto_abort(so);
-		return (EINVAL);
+	/* Interlock with soo_aio_queue() and KTLS. */
+	if (!SOLISTENING(so)) {
+		bool ktls;
+
+#ifdef KERN_TLS
+		ktls = so->so_snd.sb_tls_info != NULL ||
+		    so->so_rcv.sb_tls_info != NULL;
+#else
+		ktls = false;
+#endif
+		if (ktls ||
+		    (so->so_snd.sb_flags & (SB_AIO | SB_AIO_RUNNING)) != 0 ||
+		    (so->so_rcv.sb_flags & (SB_AIO | SB_AIO_RUNNING)) != 0) {
+			solisten_proto_abort(so);
+			return (EINVAL);
+		}
 	}
+
 	return (0);
 }
 
diff --git a/tests/sys/kern/ktls_test.c b/tests/sys/kern/ktls_test.c
--- a/tests/sys/kern/ktls_test.c
+++ b/tests/sys/kern/ktls_test.c
@@ -2767,6 +2767,54 @@
 	ATF_REQUIRE(close(s) == 0);
 }
 
+/*
+ * Make sure that listen(2) returns an error for KTLS-enabled sockets, and
+ * verify that an attempt to enable KTLS on a listening socket fails.
+ */
+ATF_TC_WITHOUT_HEAD(ktls_listening_socket);
+ATF_TC_BODY(ktls_listening_socket, tc)
+{
+	struct tls_enable en;
+	struct sockaddr_in sin;
+	int s;
+
+	ATF_REQUIRE_KTLS();
+
+	s = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
+	ATF_REQUIRE(s >= 0);
+	build_tls_enable(tc, CRYPTO_AES_NIST_GCM_16, 128 / 8, 0,
+	    TLS_MINOR_VER_THREE, (uint64_t)random(), &en);
+	ATF_REQUIRE(setsockopt(s, IPPROTO_TCP, TCP_TXTLS_ENABLE, &en,
+	    sizeof(en)) == 0);
+	ATF_REQUIRE_ERRNO(EINVAL, listen(s, 1) == -1);
+	ATF_REQUIRE(close(s) == 0);
+
+	s = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
+	ATF_REQUIRE(s >= 0);
+	build_tls_enable(tc, CRYPTO_AES_NIST_GCM_16, 128 / 8, 0,
+	    TLS_MINOR_VER_THREE, (uint64_t)random(), &en);
+	ATF_REQUIRE(setsockopt(s, IPPROTO_TCP, TCP_RXTLS_ENABLE, &en,
+	    sizeof(en)) == 0);
+	ATF_REQUIRE_ERRNO(EINVAL, listen(s, 1) == -1);
+	ATF_REQUIRE(close(s) == 0);
+
+	s = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
+	ATF_REQUIRE(s >= 0);
+	memset(&sin, 0, sizeof(sin));
+	sin.sin_len = sizeof(sin);
+	sin.sin_family = AF_INET;
+	sin.sin_addr.s_addr = htonl(INADDR_LOOPBACK);
+	ATF_REQUIRE(bind(s, (struct sockaddr *)&sin, sizeof(sin)) == 0);
+	ATF_REQUIRE(listen(s, 1) == 0);
+	build_tls_enable(tc, CRYPTO_AES_NIST_GCM_16, 128 / 8, 0,
+	    TLS_MINOR_VER_THREE, (uint64_t)random(), &en);
+	ATF_REQUIRE_ERRNO(ENOTCONN,
+	    setsockopt(s, IPPROTO_TCP, TCP_TXTLS_ENABLE, &en, sizeof(en)) != 0);
+	ATF_REQUIRE_ERRNO(EINVAL,
+	    setsockopt(s, IPPROTO_TCP, TCP_RXTLS_ENABLE, &en, sizeof(en)) != 0);
+	ATF_REQUIRE(close(s) == 0);
+}
+
 ATF_TP_ADD_TCS(tp)
 {
 	/* Transmit tests */
@@ -2792,6 +2840,7 @@
 
 	/* Miscellaneous */
 	ATF_TP_ADD_TC(tp, ktls_sendto_baddst);
+	ATF_TP_ADD_TC(tp, ktls_listening_socket);
 
 	return (atf_no_error());
 }