diff --git a/sys/rpc/rpcsec_tls.h b/sys/rpc/rpcsec_tls.h --- a/sys/rpc/rpcsec_tls.h +++ b/sys/rpc/rpcsec_tls.h @@ -29,9 +29,7 @@ #define _RPC_RPCSEC_TLS_H_ /* Operation values for rpctls syscall. */ -#define RPCTLS_SYSC_CLSETPATH 1 #define RPCTLS_SYSC_CLSOCKET 2 -#define RPCTLS_SYSC_CLSHUTDOWN 3 #define RPCTLS_SYSC_SRVSETPATH 4 #define RPCTLS_SYSC_SRVSOCKET 5 #define RPCTLS_SYSC_SRVSHUTDOWN 6 diff --git a/sys/rpc/rpcsec_tls/rpctls_impl.c b/sys/rpc/rpcsec_tls/rpctls_impl.c --- a/sys/rpc/rpcsec_tls/rpctls_impl.c +++ b/sys/rpc/rpcsec_tls/rpctls_impl.c @@ -49,6 +49,7 @@ #include #include #include +#include #include @@ -71,16 +72,14 @@ SYSCALL_INIT_LAST }; -static CLIENT *rpctls_connect_handle; static struct mtx rpctls_connect_lock; -static struct socket *rpctls_connect_so = NULL; -static CLIENT *rpctls_connect_cl = NULL; static struct mtx rpctls_server_lock; static struct opaque_auth rpctls_null_verf; KRPC_VNET_DECLARE(uint64_t, svc_vc_tls_handshake_success); KRPC_VNET_DECLARE(uint64_t, svc_vc_tls_handshake_failed); +KRPC_VNET_DEFINE_STATIC(CLIENT *, rpctls_connect_handle); KRPC_VNET_DEFINE_STATIC(CLIENT **, rpctls_server_handle); KRPC_VNET_DEFINE_STATIC(struct socket *, rpctls_server_so) = NULL; KRPC_VNET_DEFINE_STATIC(SVCXPRT *, rpctls_server_xprt) = NULL; @@ -88,7 +87,20 @@ KRPC_VNET_DEFINE_STATIC(int, rpctls_srv_prevproc) = 0; KRPC_VNET_DEFINE_STATIC(bool *, rpctls_server_busy); -static CLIENT *rpctls_connect_client(void); +struct upsock { + RB_ENTRY(upsock) tree; + struct socket *so; + CLIENT *cl; +}; + +static RB_HEAD(upsock_t, upsock) upcall_sockets; +static intptr_t +upsock_compare(const struct upsock *a, const struct upsock *b) +{ + return ((intptr_t)((uintptr_t)a->so/2 - (uintptr_t)b->so/2)); +} +RB_GENERATE_STATIC(upsock_t, upsock, tree, upsock_compare); + static CLIENT *rpctls_server_client(int procpos); static enum clnt_stat rpctls_server(SVCXPRT *xprt, struct socket *so, uint32_t *flags, uint64_t *sslp, @@ -98,6 +110,7 @@ static void rpctls_vnetinit(const void *unused __unused) { + CLIENT *cl; int i; KRPC_VNET(rpctls_server_handle) = malloc(sizeof(CLIENT *) * @@ -106,6 +119,22 @@ RPCTLS_SRV_MAXNPROCS, M_RPC, M_WAITOK | M_ZERO); for (i = 0; i < RPCTLS_SRV_MAXNPROCS; i++) KRPC_VNET(rpctls_server_busy)[i] = false; + + cl = client_nl_create("tlsclnt", RPCTLSCD, RPCTLSCDVERS); + KASSERT(cl, ("%s: netlink client already exist", __func__)); + /* + * Set the try_count to 1 so that no retries of the RPC occur. Since + * it is an upcall to a local daemon, requests should not be lost and + * doing one of these RPCs multiple times is not correct. If the + * server is not working correctly, the daemon can get stuck in + * SSL_connect() trying to read data from the socket during the upcall. + * Set a timeout (currently 15sec) and assume the daemon is hung when + * the timeout occurs. + */ + clnt_control(cl, CLSET_RETRIES, &(int){1}); + clnt_control(cl, CLSET_TIMEOUT, &(struct timeval){.tv_sec = 15}); + clnt_control(cl, CLSET_WAITCHAN, "tlsclntd"); + KRPC_VNET(rpctls_connect_handle) = cl; } VNET_SYSINIT(rpctls_vnetinit, SI_SUB_VNET_DONE, SI_ORDER_ANY, rpctls_vnetinit, NULL); @@ -147,10 +176,11 @@ struct netconfig *nconf; struct file *fp; struct socket *so; + struct upsock *ups; SVCXPRT *xprt; char path[MAXPATHLEN]; int fd = -1, error, i, try_count; - CLIENT *cl, *oldcl[RPCTLS_SRV_MAXNPROCS], *concl; + CLIENT *cl, *oldcl[RPCTLS_SRV_MAXNPROCS]; uint64_t ssl[3]; struct timeval timeo; #ifdef KERN_TLS @@ -186,65 +216,6 @@ } } break; - case RPCTLS_SYSC_CLSETPATH: - if (jailed(curthread->td_ucred)) - error = EPERM; - if (error == 0) - error = copyinstr(uap->path, path, sizeof(path), NULL); - if (error == 0) { - error = ENXIO; -#ifdef KERN_TLS - if (rpctls_getinfo(&maxlen, false, false)) - error = 0; -#endif - } - if (error == 0 && (strlen(path) + 1 > sizeof(sun.sun_path) || - strlen(path) == 0)) - error = EINVAL; - - cl = NULL; - if (error == 0) { - sun.sun_family = AF_LOCAL; - strlcpy(sun.sun_path, path, sizeof(sun.sun_path)); - sun.sun_len = SUN_LEN(&sun); - - nconf = getnetconfigent("local"); - cl = clnt_reconnect_create(nconf, - (struct sockaddr *)&sun, RPCTLSCD, RPCTLSCDVERS, - RPC_MAXDATASIZE, RPC_MAXDATASIZE); - /* - * The number of retries defaults to INT_MAX, which - * effectively means an infinite, uninterruptable loop. - * Set the try_count to 1 so that no retries of the - * RPC occur. Since it is an upcall to a local daemon, - * requests should not be lost and doing one of these - * RPCs multiple times is not correct. - * If the server is not working correctly, the - * daemon can get stuck in SSL_connect() trying - * to read data from the socket during the upcall. - * Set a timeout (currently 15sec) and assume the - * daemon is hung when the timeout occurs. - */ - if (cl != NULL) { - try_count = 1; - CLNT_CONTROL(cl, CLSET_RETRIES, &try_count); - timeo.tv_sec = 15; - timeo.tv_usec = 0; - CLNT_CONTROL(cl, CLSET_TIMEOUT, &timeo); - } else - error = EINVAL; - } - - mtx_lock(&rpctls_connect_lock); - oldcl[0] = rpctls_connect_handle; - rpctls_connect_handle = cl; - mtx_unlock(&rpctls_connect_lock); - - if (oldcl[0] != NULL) { - CLNT_CLOSE(oldcl[0]); - CLNT_RELEASE(oldcl[0]); - } - break; case RPCTLS_SYSC_SRVSETPATH: if (jailed(curthread->td_ucred) && !prison_check_nfsd(curthread->td_ucred)) @@ -327,17 +298,6 @@ } } break; - case RPCTLS_SYSC_CLSHUTDOWN: - mtx_lock(&rpctls_connect_lock); - oldcl[0] = rpctls_connect_handle; - rpctls_connect_handle = NULL; - mtx_unlock(&rpctls_connect_lock); - - if (oldcl[0] != NULL) { - CLNT_CLOSE(oldcl[0]); - CLNT_RELEASE(oldcl[0]); - } - break; case RPCTLS_SYSC_SRVSHUTDOWN: mtx_lock(&rpctls_server_lock); for (i = 0; i < RPCTLS_SRV_MAXNPROCS; i++) { @@ -356,30 +316,33 @@ break; case RPCTLS_SYSC_CLSOCKET: mtx_lock(&rpctls_connect_lock); - so = rpctls_connect_so; - rpctls_connect_so = NULL; - concl = rpctls_connect_cl; - rpctls_connect_cl = NULL; + ups = RB_FIND(upsock_t, &upcall_sockets, + &(struct upsock){ + .so = __DECONST(struct socket *, uap->path) }); + if (__predict_true(ups != NULL)) + RB_REMOVE(upsock_t, &upcall_sockets, ups); mtx_unlock(&rpctls_connect_lock); - if (so != NULL) { - error = falloc(td, &fp, &fd, 0); - if (error == 0) { - /* - * Set ssl refno so that clnt_vc_destroy() will - * not close the socket and will leave that for - * the daemon to do. - */ - soref(so); - ssl[0] = ssl[1] = 0; - ssl[2] = RPCTLS_REFNO_HANDSHAKE; - CLNT_CONTROL(concl, CLSET_TLS, ssl); - finit(fp, FREAD | FWRITE, DTYPE_SOCKET, so, - &socketops); - fdrop(fp, td); /* Drop fp reference. */ - td->td_retval[0] = fd; - } - } else + if (ups == NULL) { + printf("%s: socket lookup failed\n", __func__); error = EPERM; + break; + } + error = falloc(td, &fp, &fd, 0); + if (error == 0) { + /* + * Set ssl refno so that clnt_vc_destroy() will + * not close the socket and will leave that for + * the daemon to do. + */ + soref(ups->so); + ssl[0] = ssl[1] = 0; + ssl[2] = RPCTLS_REFNO_HANDSHAKE; + CLNT_CONTROL(ups->cl, CLSET_TLS, ssl); + finit(fp, FREAD | FWRITE, DTYPE_SOCKET, ups->so, + &socketops); + fdrop(fp, td); /* Drop fp reference. */ + td->td_retval[0] = fd; + } break; case RPCTLS_SYSC_SRVSOCKET: mtx_lock(&rpctls_server_lock); @@ -416,23 +379,6 @@ return (error); } -/* - * Acquire the rpctls_connect_handle and return it with a reference count, - * if it is available. - */ -static CLIENT * -rpctls_connect_client(void) -{ - CLIENT *cl; - - mtx_lock(&rpctls_connect_lock); - cl = rpctls_connect_handle; - if (cl != NULL) - CLNT_ACQUIRE(cl); - mtx_unlock(&rpctls_connect_lock); - return (cl); -} - /* * Acquire the rpctls_server_handle and return it with a reference count, * if it is available. @@ -462,13 +408,12 @@ struct rpc_callextra ext; struct timeval utimeout; enum clnt_stat stat; - CLIENT *cl; + struct upsock ups = { + .so = so, + .cl = newclient, + }; + CLIENT *cl = KRPC_VNET(rpctls_connect_handle); int val; - static bool rpctls_connect_busy = false; - - cl = rpctls_connect_client(); - if (cl == NULL) - return (RPC_AUTHERROR); /* First, do the AUTH_TLS NULL RPC. */ memset(&ext, 0, sizeof(ext)); @@ -483,14 +428,8 @@ if (stat != RPC_SUCCESS) return (RPC_SYSTEMERROR); - /* Serialize the connect upcalls. */ mtx_lock(&rpctls_connect_lock); - while (rpctls_connect_busy) - msleep(&rpctls_connect_busy, &rpctls_connect_lock, PVFS, - "rtlscn", 0); - rpctls_connect_busy = true; - rpctls_connect_so = so; - rpctls_connect_cl = newclient; + RB_INSERT(upsock_t, &upcall_sockets, &ups); mtx_unlock(&rpctls_connect_lock); /* Temporarily block reception during the handshake upcall. */ @@ -503,37 +442,47 @@ arg.certname.certname_val = certname; } else arg.certname.certname_len = 0; + arg.socookie = (uintptr_t)so; stat = rpctlscd_connect_1(&arg, &res, cl); if (stat == RPC_SUCCESS) { +#ifdef INVARIANTS + MPASS((RB_FIND(upsock_t, &upcall_sockets, &ups) == NULL)); +#endif *reterr = res.reterr; if (res.reterr == 0) { *sslp++ = res.sec; *sslp++ = res.usec; *sslp = res.ssl; } - } else if (stat == RPC_TIMEDOUT) { - /* - * Do a shutdown on the socket, since the daemon is probably - * stuck in SSL_connect() trying to read the socket. - * Do not soclose() the socket, since the daemon will close() - * the socket after SSL_connect() returns an error. - */ - soshutdown(so, SHUT_RD); + } else { + mtx_lock(&rpctls_connect_lock); + if (RB_FIND(upsock_t, &upcall_sockets, &ups)) { + struct upsock *removed __diagused; + + removed = RB_REMOVE(upsock_t, &upcall_sockets, &ups); + mtx_unlock(&rpctls_connect_lock); + MPASS(removed == &ups); + /* + * Do a shutdown on the socket, since the daemon is + * probably stuck in SSL_accept() trying to read the + * socket. Do not soclose() the socket, since the + * daemon will close() the socket after SSL_accept() + * returns an error. + */ + soshutdown(so, SHUT_RD); + } else { + /* + * The daemon has taken the socket from the tree, but + * failed to do the handshake. + */ + mtx_unlock(&rpctls_connect_lock); + } } - CLNT_RELEASE(cl); /* Unblock reception. */ val = 0; CLNT_CONTROL(newclient, CLSET_BLOCKRCV, &val); - /* Once the upcall is done, the daemon is done with the fp and so. */ - mtx_lock(&rpctls_connect_lock); - rpctls_connect_so = NULL; - rpctls_connect_cl = NULL; - rpctls_connect_busy = false; - wakeup(&rpctls_connect_busy); - mtx_unlock(&rpctls_connect_lock); - return (stat); } @@ -545,20 +494,13 @@ struct rpctlscd_handlerecord_arg arg; struct rpctlscd_handlerecord_res res; enum clnt_stat stat; - CLIENT *cl; - - cl = rpctls_connect_client(); - if (cl == NULL) { - *reterr = RPCTLSERR_NOSSL; - return (RPC_SUCCESS); - } + CLIENT *cl = KRPC_VNET(rpctls_connect_handle); /* Do the handlerecord upcall. */ arg.sec = sec; arg.usec = usec; arg.ssl = ssl; stat = rpctlscd_handlerecord_1(&arg, &res, cl); - CLNT_RELEASE(cl); if (stat == RPC_SUCCESS) *reterr = res.reterr; return (stat); @@ -598,20 +540,13 @@ struct rpctlscd_disconnect_arg arg; struct rpctlscd_disconnect_res res; enum clnt_stat stat; - CLIENT *cl; - - cl = rpctls_connect_client(); - if (cl == NULL) { - *reterr = RPCTLSERR_NOSSL; - return (RPC_SUCCESS); - } + CLIENT *cl = KRPC_VNET(rpctls_connect_handle); /* Do the disconnect upcall. */ arg.sec = sec; arg.usec = usec; arg.ssl = ssl; stat = rpctlscd_disconnect_1(&arg, &res, cl); - CLNT_RELEASE(cl); if (stat == RPC_SUCCESS) *reterr = res.reterr; return (stat); @@ -854,8 +789,6 @@ &maxlen, &siz, NULL, 0, NULL, 0); if (error != 0) return (false); - if (rpctlscd_run && rpctls_connect_handle == NULL) - return (false); KRPC_CURVNET_SET_QUIET(KRPC_TD_TO_VNET(curthread)); if (rpctlssd_run && KRPC_VNET(rpctls_server_handle)[0] == NULL) { KRPC_CURVNET_RESTORE(); diff --git a/sys/rpc/rpcsec_tls/rpctlscd.x b/sys/rpc/rpcsec_tls/rpctlscd.x --- a/sys/rpc/rpcsec_tls/rpctlscd.x +++ b/sys/rpc/rpcsec_tls/rpctlscd.x @@ -29,6 +29,7 @@ struct rpctlscd_connect_arg { + uint64_t socookie; char certname<>; };