libnv: pass all descriptors in a single SCM_RIGHTS control message

fd_package_send() now fills one cmsghdr with all nfds descriptors
instead of emitting one cmsghdr per descriptor, so the per-descriptor
helper msghdr_add_fd() is removed.  fd_package_recv() expects exactly
one SCM_RIGHTS message carrying exactly nfds descriptors; on a count
mismatch it closes every descriptor it received and bails out with
errno set to EINVAL.  On non-Linux systems PKG_MAX_SIZE becomes the
number of pointer-sized slots that fit in MCLBYTES after one cmsghdr
(presumably because the kernel keeps each in-flight descriptor as a
pointer; TODO confirm).

fd_package_recv() also rejects a message without any control data
(errno EINVAL): CMSG_FIRSTHDR() returns NULL in that case, and the new
single-message code path must not dereference it.

The new test nvlist_send_recv__send_many_fds has a child send nvlists
carrying 1 to NFDS - 1 descriptors and checks on both sides, with the
KERN_PROC_NFDS sysctl, that the number of open descriptors is the same
before and after each transfer.  The child reports failures through
err(3)/errx(3) and its exit status instead of ATF_REQUIRE(), and it
checks the return value of nvlist_send().

NOTE(review): this copy of the diff has lost its line breaks, its
indentation and the <...> targets of every #include line, so it cannot
be applied as-is; regenerate it from the source tree.  Hunk headers
below were adjusted by hand for the added lines.

Index: lib/libnv/msgio.c =================================================================== --- lib/libnv/msgio.c +++ lib/libnv/msgio.c @@ -63,23 +63,9 @@ /* Linux: arbitrary size, but must be lower than SCM_MAX_FD. */ #define PKG_MAX_SIZE ((64U - 1) * CMSG_SPACE(sizeof(int))) #else -#define PKG_MAX_SIZE (MCLBYTES / CMSG_SPACE(sizeof(int)) - 1) +#define PKG_MAX_SIZE ((MCLBYTES - CMSG_SPACE(0)) / sizeof(void *)) #endif -static int -msghdr_add_fd(struct cmsghdr *cmsg, int fd) -{ - - PJDLOG_ASSERT(fd >= 0); - - cmsg->cmsg_level = SOL_SOCKET; - cmsg->cmsg_type = SCM_RIGHTS; - cmsg->cmsg_len = CMSG_LEN(sizeof(fd)); - bcopy(&fd, CMSG_DATA(cmsg), sizeof(fd)); - - return (0); -} - static void fd_wait(int fd, bool doread) { @@ -222,7 +208,6 @@ struct msghdr msg; struct cmsghdr *cmsg; struct iovec iov; - unsigned int i; int serrno, ret; uint8_t dummy; @@ -241,24 +226,19 @@ msg.msg_iov = &iov; msg.msg_iovlen = 1; - msg.msg_controllen = nfds * CMSG_SPACE(sizeof(int)); + msg.msg_controllen = CMSG_SPACE(nfds * sizeof(int)); msg.msg_control = calloc(1, msg.msg_controllen); if (msg.msg_control == NULL) return (-1); - ret = -1; + cmsg = msg.msg_control; + cmsg->cmsg_level = SOL_SOCKET; + cmsg->cmsg_type = SCM_RIGHTS; + cmsg->cmsg_len = CMSG_LEN(nfds * sizeof(int)); + memcpy(CMSG_DATA(cmsg), fds, nfds * sizeof(int)); - for (i = 0, cmsg = CMSG_FIRSTHDR(&msg); i < nfds && cmsg != NULL; - i++, cmsg = CMSG_NXTHDR(&msg, cmsg)) { - if (msghdr_add_fd(cmsg, fds[i]) == -1) - goto end; - } + ret = msg_send(sock, &msg); - if (msg_send(sock, &msg) == -1) - goto end; - - ret = 0; -end: serrno = errno; free(msg.msg_control); errno = serrno; @@ -268,11 +248,11 @@ static int fd_package_recv(int sock, int *fds, size_t nfds) { + struct iovec iov; struct msghdr msg; struct cmsghdr *cmsg; - unsigned int i; + unsigned int i, n; int serrno, ret; - struct iovec iov; uint8_t dummy; PJDLOG_ASSERT(sock >= 0); @@ -290,7 +270,7 @@ msg.msg_iov = &iov; msg.msg_iovlen = 1; - msg.msg_controllen = nfds * 
CMSG_SPACE(sizeof(int)); + msg.msg_controllen = CMSG_SPACE(nfds * sizeof(int)); msg.msg_control = calloc(1, msg.msg_controllen); if (msg.msg_control == NULL) return (-1); @@ -300,42 +280,21 @@ if (msg_recv(sock, &msg) == -1) goto end; - i = 0; cmsg = CMSG_FIRSTHDR(&msg); - while (cmsg && i < nfds) { - unsigned int n; - - if (cmsg->cmsg_level != SOL_SOCKET || - cmsg->cmsg_type != SCM_RIGHTS) { - errno = EINVAL; - break; - } - n = (cmsg->cmsg_len - CMSG_LEN(0)) / sizeof(int); - if (i + n > nfds) { - errno = EINVAL; - break; - } - bcopy(CMSG_DATA(cmsg), fds + i, sizeof(int) * n); - cmsg = CMSG_NXTHDR(&msg, cmsg); - i += n; + /* CMSG_FIRSTHDR() returns NULL if no control data was received. */ + if (cmsg == NULL || cmsg->cmsg_level != SOL_SOCKET || + cmsg->cmsg_type != SCM_RIGHTS) { + errno = EINVAL; + goto end; } - - if (cmsg != NULL || i < nfds) { - unsigned int last; - - /* - * We need to close all received descriptors, even if we have - * different control message (eg. SCM_CREDS) in between. - */ - last = i; - for (i = 0; i < last; i++) { - if (fds[i] >= 0) { - close(fds[i]); - } - } + n = (cmsg->cmsg_len - CMSG_LEN(0)) / sizeof(int); + if (n != nfds) { errno = EINVAL; + for (i = 0; i < n; i++) + (void)close(*((int *)(void *)CMSG_DATA(cmsg) + i)); goto end; } + memcpy(fds, CMSG_DATA(cmsg), sizeof(int) * nfds); #ifndef MSG_CMSG_CLOEXEC /* @@ -344,7 +303,7 @@ * consistency.
*/ for (i = 0; i < nfds; i++) { - (void) fcntl(fds[i], F_SETFD, FD_CLOEXEC); + (void)fcntl(fds[i], F_SETFD, FD_CLOEXEC); } #endif Index: lib/libnv/tests/nvlist_send_recv_test.c =================================================================== --- lib/libnv/tests/nvlist_send_recv_test.c +++ lib/libnv/tests/nvlist_send_recv_test.c @@ -30,8 +30,9 @@ #include __FBSDID("$FreeBSD$"); -#include +#include #include +#include #include #include @@ -39,6 +40,7 @@ #include #include #include +#include #include #include #include @@ -375,11 +377,108 @@ ATF_REQUIRE_ERRNO(EBADF, nvlist_send(socks[1], nvl) != 0); } +static int +nopenfds(void) +{ + size_t len; + int error, mib[4], n; + + mib[0] = CTL_KERN; + mib[1] = KERN_PROC; + mib[2] = KERN_PROC_NFDS; + mib[3] = 0; + + len = sizeof(n); + error = sysctl(mib, nitems(mib), &n, &len, NULL, 0); + if (error != 0) + return (-1); + return (n); +} + +#define NFDS 512 + +static void +send_many_fds_child(int sock) +{ + char name[16]; + nvlist_t *nvl; + int anfds, bnfds, fd, i, j; + + fd = open(_PATH_DEVNULL, O_RDONLY); + if (fd < 0) + err(EXIT_FAILURE, "open"); + + for (i = 1; i < NFDS; i++) { + nvl = nvlist_create(0); + bnfds = nopenfds(); + if (bnfds == -1) + err(EXIT_FAILURE, "sysctl"); + + for (j = 0; j < i; j++) { + snprintf(name, sizeof(name), "fd%d", j); + nvlist_add_descriptor(nvl, name, fd); + } + if (nvlist_send(sock, nvl) != 0) + err(EXIT_FAILURE, "nvlist_send"); + nvlist_destroy(nvl); + + anfds = nopenfds(); + if (anfds == -1) + err(EXIT_FAILURE, "sysctl"); + if (anfds != bnfds) + errx(EXIT_FAILURE, "fd count mismatch"); + } +} + +ATF_TC_WITHOUT_HEAD(nvlist_send_recv__send_many_fds); +ATF_TC_BODY(nvlist_send_recv__send_many_fds, tc) +{ + char name[16]; + nvlist_t *nvl; + int anfds, bnfds, fd, i, j, socks[2], status; + pid_t pid; + + ATF_REQUIRE(socketpair(PF_UNIX, SOCK_STREAM, 0, socks) == 0); + + pid = fork(); + ATF_REQUIRE(pid >= 0); + if (pid == 0) { + /* Child.
*/ + (void)close(socks[0]); + send_many_fds_child(socks[1]); + _exit(0); + } + + (void)close(socks[1]); + + for (i = 1; i < NFDS; i++) { + bnfds = nopenfds(); + ATF_REQUIRE(bnfds != -1); + + nvl = nvlist_recv(socks[0], 0); + ATF_REQUIRE(nvl != NULL); + for (j = 0; j < i; j++) { + snprintf(name, sizeof(name), "fd%d", j); + fd = nvlist_take_descriptor(nvl, name); + ATF_REQUIRE(close(fd) == 0); + } + nvlist_destroy(nvl); + + anfds = nopenfds(); + ATF_REQUIRE(anfds != -1); + ATF_REQUIRE(anfds == bnfds); + } + + ATF_REQUIRE(waitpid(pid, &status, 0) == pid); + ATF_REQUIRE(status == 0); +} + ATF_TP_ADD_TCS(tp) { ATF_TP_ADD_TC(tp, nvlist_send_recv__send_nvlist); ATF_TP_ADD_TC(tp, nvlist_send_recv__send_closed_fd); + ATF_TP_ADD_TC(tp, nvlist_send_recv__send_many_fds); return (atf_no_error()); }