Socket shutdown: take the socket lock in soshutdown(), tolerate a failed
receive I/O lock in sorflush(), and narrow the region of so_rcv that
sorflush() copies and clears.

* soshutdown(): so_state and SOLISTENING() are now tested with the socket
  lock held.  The lock is dropped on every exit path; in the listening
  SHUT_RD/SHUT_RDWR case solisten_wakeup() releases it.  pr is loaded
  from so->so_proto only after the lock has been dropped.
* sorflush(): a failure of SOCK_IO_RECV_LOCK() no longer trips a KASSERT
  unconditionally; it is asserted to happen only for listening sockets
  and the function returns early.  Only the sb_startzero..sb_endzero
  range of so_rcv is copied to the temporary socket and zeroed.
* sockbuf.h: sb_tls_seqno and sb_tls_info move after sb_upcallarg and
  sb_endzero is defined as sb_tls_seqno, so sb_tls_seqno, sb_tls_info,
  sb_aiojobq and sb_aiotask lie outside the region sorflush() clears.
* aio_test: new aio_socket_shutdown test.  aio_read() on a SHUT_RD
  socket must complete with 0 bytes.  aio_write() on a SHUT_WR socket
  must fail with EPIPE (SIGPIPE blocked).

NOTE(review): line breaks, indentation and the #include targets were
lost in the pasted copy of this patch.  Whitespace inside context lines
and the header names in the aio_test.c include hunk are reconstructed
(<signal.h> is the header being added, for sigset_t/sigprocmask()).
Confirm against the tree before applying.

Index: sys/kern/uipc_socket.c
===================================================================
--- sys/kern/uipc_socket.c
+++ sys/kern/uipc_socket.c
@@ -2847,13 +2847,14 @@
 int
 soshutdown(struct socket *so, int how)
 {
-	struct protosw *pr = so->so_proto;
+	struct protosw *pr;
 	int error, soerror_enotconn;
 
 	if (!(how == SHUT_RD || how == SHUT_WR || how == SHUT_RDWR))
 		return (EINVAL);
 
 	soerror_enotconn = 0;
+	SOCK_LOCK(so);
 	if ((so->so_state &
 	    (SS_ISCONNECTED | SS_ISCONNECTING | SS_ISDISCONNECTING)) == 0) {
 		/*
@@ -2865,21 +2866,26 @@
 		 * both backward-compatibility and POSIX requirements by forcing
 		 * ENOTCONN but still asking protocol to perform pru_shutdown().
 		 */
-		if (so->so_type != SOCK_DGRAM && !SOLISTENING(so))
+		if (so->so_type != SOCK_DGRAM && !SOLISTENING(so)) {
+			SOCK_UNLOCK(so);
 			return (ENOTCONN);
+		}
 		soerror_enotconn = 1;
 	}
 
 	if (SOLISTENING(so)) {
 		if (how != SHUT_WR) {
-			SOLISTEN_LOCK(so);
 			so->so_error = ECONNABORTED;
 			solisten_wakeup(so);	/* unlocks so */
+		} else {
+			SOCK_UNLOCK(so);
 		}
 		goto done;
 	}
+	SOCK_UNLOCK(so);
 
 	CURVNET_SET(so->so_vnet);
+	pr = so->so_proto;
 	if (pr->pr_usrreqs->pru_flush != NULL)
 		(*pr->pr_usrreqs->pru_flush)(so, how);
 	if (how != SHUT_WR)
@@ -2900,20 +2906,21 @@
 void
 sorflush(struct socket *so)
 {
-	struct sockbuf *sb = &so->so_rcv;
-	struct protosw *pr = so->so_proto;
 	struct socket aso;
+	struct protosw *pr;
 	int error;
 
 	VNET_SO_ASSERT(so);
 
 	/*
 	 * In order to avoid calling dom_dispose with the socket buffer mutex
-	 * held, and in order to generally avoid holding the lock for a long
-	 * time, we make a copy of the socket buffer and clear the original
-	 * (except locks, state).  The new socket buffer copy won't have
-	 * initialized locks so we can only call routines that won't use or
-	 * assert those locks.
+	 * held, we make a partial copy of the socket buffer and clear the
+	 * original.  The new socket buffer copy won't have initialized locks so
+	 * we can only call routines that won't use or assert those locks.
+	 * Ideally calling socantrcvmore() would prevent data from being added
+	 * to the buffer, but currently it merely prevents buffered data from
+	 * being read by userspace.  We make this effort to free buffered data
+	 * nonetheless.
 	 *
 	 * Dislodge threads currently blocked in receive and wait to acquire
 	 * a lock against other simultaneous readers before clearing the
@@ -2921,28 +2928,31 @@
 	 * despite any existing socket disposition on interruptable waiting.
 	 */
 	socantrcvmore(so);
+
 	error = SOCK_IO_RECV_LOCK(so, SBL_WAIT | SBL_NOINTR);
-	KASSERT(error == 0, ("%s: cannot lock sock %p recv buffer",
-	    __func__, so));
+	if (error != 0) {
+		KASSERT(SOLISTENING(so),
+		    ("%s: soiolock(%p) failed", __func__, so));
+		return;
+	}
 
-	/*
-	 * Invalidate/clear most of the sockbuf structure, but leave selinfo
-	 * and mutex data unchanged.
-	 */
-	SOCKBUF_LOCK(sb);
+	SOCK_RECVBUF_LOCK(so);
 	bzero(&aso, sizeof(aso));
 	aso.so_pcb = so->so_pcb;
-	bcopy(&sb->sb_startzero, &aso.so_rcv.sb_startzero,
-	    sizeof(*sb) - offsetof(struct sockbuf, sb_startzero));
-	bzero(&sb->sb_startzero,
-	    sizeof(*sb) - offsetof(struct sockbuf, sb_startzero));
-	SOCKBUF_UNLOCK(sb);
+	bcopy(&so->so_rcv.sb_startzero, &aso.so_rcv.sb_startzero,
+	    offsetof(struct sockbuf, sb_endzero) -
+	    offsetof(struct sockbuf, sb_startzero));
+	bzero(&so->so_rcv.sb_startzero,
+	    offsetof(struct sockbuf, sb_endzero) -
+	    offsetof(struct sockbuf, sb_startzero));
+	SOCK_RECVBUF_UNLOCK(so);
 	SOCK_IO_RECV_UNLOCK(so);
 
 	/*
 	 * Dispose of special rights and flush the copied socket.  Don't call
 	 * any unsafe routines (that rely on locks being initialized) on aso.
 	 */
+	pr = so->so_proto;
 	if (pr->pr_flags & PR_RIGHTS && pr->pr_domain->dom_dispose != NULL)
 		(*pr->pr_domain->dom_dispose)(&aso);
 	sbrelease_internal(&aso.so_rcv, so);
Index: sys/sys/sockbuf.h
===================================================================
--- sys/sys/sockbuf.h
+++ sys/sys/sockbuf.h
@@ -104,12 +104,13 @@
 	u_int	sb_tlsdcc;	/* (a) TLS characters being decrypted */
 	int	sb_lowat;	/* (a) low water mark */
 	sbintime_t	sb_timeo;	/* (a) timeout for read/write */
-	uint64_t sb_tls_seqno;	/* (a) TLS seqno */
-	struct	ktls_session *sb_tls_info; /* (a + b) TLS state */
 	struct	mbuf *sb_mtls;	/* (a) TLS mbuf chain */
 	struct	mbuf *sb_mtlstail; /* (a) last mbuf in TLS chain */
 	int	(*sb_upcall)(struct socket *, void *, int); /* (a) */
 	void	*sb_upcallarg;	/* (a) */
+#define	sb_endzero	sb_tls_seqno
+	uint64_t sb_tls_seqno;	/* (a) TLS seqno */
+	struct	ktls_session *sb_tls_info; /* (a + b) TLS state */
 	TAILQ_HEAD(, kaiocb) sb_aiojobq; /* (a) pending AIO ops */
 	struct	task sb_aiotask; /* AIO task */
 };
Index: tests/sys/aio/aio_test.c
===================================================================
--- tests/sys/aio/aio_test.c
+++ tests/sys/aio/aio_test.c
@@ -54,6 +54,7 @@
 #include <libutil.h>
 #include <limits.h>
 #include <semaphore.h>
+#include <signal.h>
 #include <stdint.h>
 #include <stdio.h>
 #include <stdlib.h>
@@ -1319,6 +1320,71 @@
 	close(s[0]);
 }
 
+/*
+ * Test handling of aio_read() and aio_write() on shut-down sockets.
+ */
+ATF_TC_WITHOUT_HEAD(aio_socket_shutdown);
+ATF_TC_BODY(aio_socket_shutdown, tc)
+{
+	struct aiocb iocb;
+	sigset_t set;
+	char *buffer;
+	ssize_t len;
+	size_t bsz;
+	int error, s[2];
+
+	ATF_REQUIRE_KERNEL_MODULE("aio");
+
+	ATF_REQUIRE(socketpair(PF_UNIX, SOCK_STREAM, 0, s) != -1);
+
+	bsz = 1024;
+	buffer = malloc(bsz);
+	ATF_REQUIRE(buffer != NULL);
+	memset(buffer, 0, bsz);
+
+	/* Put some data in s[0]'s recv buffer. */
+	ATF_REQUIRE(send(s[1], buffer, bsz, 0) == (ssize_t)bsz);
+
+	/* No more reading from s[0]. */
+	ATF_REQUIRE(shutdown(s[0], SHUT_RD) != -1);
+
+	memset(&iocb, 0, sizeof(iocb));
+	iocb.aio_fildes = s[0];
+	iocb.aio_buf = buffer;
+	iocb.aio_nbytes = bsz;
+	ATF_REQUIRE(aio_read(&iocb) == 0);
+
+	/* Expect to see zero bytes, analogous to recv(2). */
+	while ((error = aio_error(&iocb)) == EINPROGRESS)
+		usleep(25000);
+	ATF_REQUIRE_MSG(error == 0, "aio_error() returned %d", error);
+	len = aio_return(&iocb);
+	ATF_REQUIRE_MSG(len == 0, "read job returned %zd bytes", len);
+
+	/* No more writing to s[1]. */
+	ATF_REQUIRE(shutdown(s[1], SHUT_WR) != -1);
+
+	/* Block SIGPIPE so that we can detect the error in-band. */
+	sigemptyset(&set);
+	sigaddset(&set, SIGPIPE);
+	ATF_REQUIRE(sigprocmask(SIG_BLOCK, &set, NULL) == 0);
+
+	memset(&iocb, 0, sizeof(iocb));
+	iocb.aio_fildes = s[1];
+	iocb.aio_buf = buffer;
+	iocb.aio_nbytes = bsz;
+	ATF_REQUIRE(aio_write(&iocb) == 0);
+
+	/* Expect an error, analogous to send(2). */
+	while ((error = aio_error(&iocb)) == EINPROGRESS)
+		usleep(25000);
+	ATF_REQUIRE_MSG(error == EPIPE, "aio_error() returned %d", error);
+
+	ATF_REQUIRE(close(s[0]) != -1);
+	ATF_REQUIRE(close(s[1]) != -1);
+	free(buffer);
+}
+
 /*
  * test aio_fsync's behavior with bad inputs
  */
@@ -1885,6 +1951,7 @@
 	ATF_TP_ADD_TC(tp, aio_socket_listen_fail);
 	ATF_TP_ADD_TC(tp, aio_socket_listen_pending);
 	ATF_TP_ADD_TC(tp, aio_socket_short_write_cancel);
+	ATF_TP_ADD_TC(tp, aio_socket_shutdown);
 	ATF_TP_ADD_TC(tp, aio_writev_dos_iov_len);
 	ATF_TP_ADD_TC(tp, aio_writev_dos_iovcnt);
 	ATF_TP_ADD_TC(tp, aio_writev_efault);