diff --git a/sys/rpc/clnt_rc.c b/sys/rpc/clnt_rc.c --- a/sys/rpc/clnt_rc.c +++ b/sys/rpc/clnt_rc.c @@ -133,7 +133,6 @@ int one = 1; struct ucred *oldcred; CLIENT *newclient = NULL; - uint64_t ssl[3]; uint32_t reterr; mtx_lock(&rc->rc_lock); @@ -200,8 +199,10 @@ (struct sockaddr *) &rc->rc_addr, rc->rc_prog, rc->rc_vers, rc->rc_sendsz, rc->rc_recvsz, rc->rc_intr); if (rc->rc_tls && newclient != NULL) { + CURVNET_SET(so->so_vnet); stat = rpctls_connect(newclient, rc->rc_tlscertname, so, - ssl, &reterr); + &reterr); + CURVNET_RESTORE(); if (stat != RPC_SUCCESS || reterr != RPCTLSERR_OK) { if (stat == RPC_SUCCESS) stat = RPC_FAILED; @@ -213,6 +214,8 @@ td->td_ucred = oldcred; goto out; } + CLNT_CONTROL(newclient, CLSET_TLS, + &(int){RPCTLS_COMPLETE}); } if (newclient != NULL) { int optval = 1; @@ -239,8 +242,6 @@ CLNT_CONTROL(newclient, CLSET_RETRY_TIMEOUT, &rc->rc_retry); CLNT_CONTROL(newclient, CLSET_WAITCHAN, rc->rc_waitchan); CLNT_CONTROL(newclient, CLSET_INTERRUPTIBLE, &rc->rc_intr); - if (rc->rc_tls) - CLNT_CONTROL(newclient, CLSET_TLS, ssl); if (rc->rc_backchannel != NULL) CLNT_CONTROL(newclient, CLSET_BACKCHANNEL, rc->rc_backchannel); stat = RPC_SUCCESS; diff --git a/sys/rpc/clnt_vc.c b/sys/rpc/clnt_vc.c --- a/sys/rpc/clnt_vc.c +++ b/sys/rpc/clnt_vc.c @@ -265,7 +265,7 @@ ct->ct_raw = NULL; ct->ct_record = NULL; ct->ct_record_resid = 0; - ct->ct_sslrefno = 0; + ct->ct_tlsstate = RPCTLS_NONE; TAILQ_INIT(&ct->ct_pending); return (cl); @@ -413,7 +413,7 @@ TAILQ_INSERT_TAIL(&ct->ct_pending, cr, cr_link); mtx_unlock(&ct->ct_lock); - if (ct->ct_sslrefno != 0) { + if (ct->ct_tlsstate > RPCTLS_NONE) { /* * Copy the mbuf chain to a chain of ext_pgs mbuf(s) * as required by KERN_TLS.
@@ -632,7 +632,6 @@ struct ct_data *ct = (struct ct_data *)cl->cl_private; void *infop = info; SVCXPRT *xprt; - uint64_t *p; int error; static u_int thrdnum = 0; @@ -751,18 +750,15 @@ if (ct->ct_backchannelxprt == NULL) { SVC_ACQUIRE(xprt); xprt->xp_p2 = ct; - if (ct->ct_sslrefno != 0) + if (ct->ct_tlsstate > RPCTLS_NONE) xprt->xp_tls = RPCTLS_FLAGS_HANDSHAKE; ct->ct_backchannelxprt = xprt; } break; case CLSET_TLS: - p = (uint64_t *)info; - ct->ct_sslsec = *p++; - ct->ct_sslusec = *p++; - ct->ct_sslrefno = *p; - if (ct->ct_sslrefno != RPCTLS_REFNO_HANDSHAKE) { + ct->ct_tlsstate = *(int *)info; + if (ct->ct_tlsstate == RPCTLS_COMPLETE) { /* cl ref cnt is released by clnt_vc_dotlsupcall(). */ CLNT_ACQUIRE(cl); mtx_unlock(&ct->ct_lock); @@ -843,7 +839,7 @@ ct->ct_closing = FALSE; ct->ct_closed = TRUE; - wakeup(&ct->ct_sslrefno); + wakeup(&ct->ct_tlsstate); mtx_unlock(&ct->ct_lock); wakeup(ct); } @@ -872,37 +868,35 @@ /* Wait for the upcall kthread to terminate. */ while ((ct->ct_rcvstate & RPCRCVSTATE_UPCALLTHREAD) != 0) - msleep(&ct->ct_sslrefno, &ct->ct_lock, 0, + msleep(&ct->ct_tlsstate, &ct->ct_lock, 0, "clntvccl", hz); mtx_unlock(&ct->ct_lock); mtx_destroy(&ct->ct_lock); so = ct->ct_closeit ? ct->ct_socket : NULL; if (so) { - if (ct->ct_sslrefno != 0) { - /* - * If the TLS handshake is in progress, the upcall - * will fail, but the socket should be closed by the - * daemon, since the connect upcall has just failed. - */ - if (ct->ct_sslrefno != RPCTLS_REFNO_HANDSHAKE) { - /* - * If the upcall fails, the socket has - * probably been closed via the rpctlscd - * daemon having crashed or been - * restarted, so ignore return stat. - */ - rpctls_cl_disconnect(ct->ct_sslsec, - ct->ct_sslusec, ct->ct_sslrefno, - &reterr); - } + /* + * RPCTLS_INHANDSHAKE: the connect upcall has just failed and the + * daemon should close the socket, so only drop our reference.
+ * RPCTLS_COMPLETE: a failed disconnect upcall probably means the + * socket was closed by rpctlscd crashing/restarting; ignore it. + * RPCTLS_NONE: no daemon involvement; shut down and close here. + */ + CURVNET_SET(so->so_vnet); + switch (ct->ct_tlsstate) { + case RPCTLS_COMPLETE: + rpctls_cl_disconnect(so, &reterr); + /* FALLTHROUGH */ + case RPCTLS_INHANDSHAKE: /* Must sorele() to get rid of reference. */ - CURVNET_SET(so->so_vnet); sorele(so); CURVNET_RESTORE(); - } else { + break; + case RPCTLS_NONE: + CURVNET_RESTORE(); soshutdown(so, SHUT_WR); soclose(so); + break; } } m_freem(ct->ct_record); @@ -978,7 +972,7 @@ uio.uio_td = curthread; m2 = m = NULL; rcvflag = MSG_DONTWAIT | MSG_SOCALLBCK; - if (ct->ct_sslrefno != 0 && (ct->ct_rcvstate & + if (ct->ct_tlsstate > RPCTLS_NONE && (ct->ct_rcvstate & RPCRCVSTATE_NORMAL) != 0) rcvflag |= MSG_TLSAPPDATA; SOCK_RECVBUF_UNLOCK(so); @@ -1013,7 +1007,7 @@ * This record needs to be handled in userland * via an SSL_read() call, so do an upcall to the daemon. */ - if (ct->ct_sslrefno != 0 && error == ENXIO) { + if (ct->ct_tlsstate > RPCTLS_NONE && error == ENXIO) { /* Disable reception, marking an upcall needed. */ mtx_lock(&ct->ct_lock); ct->ct_rcvstate |= RPCRCVSTATE_UPCALLNEEDED; @@ -1021,7 +1015,7 @@ * If an upcall in needed, wake up the kthread * that runs clnt_vc_dotlsupcall().
*/ - wakeup(&ct->ct_sslrefno); + wakeup(&ct->ct_tlsstate); mtx_unlock(&ct->ct_lock); break; } @@ -1275,11 +1269,10 @@ if ((ct->ct_rcvstate & RPCRCVSTATE_UPCALLNEEDED) != 0) { ct->ct_rcvstate &= ~RPCRCVSTATE_UPCALLNEEDED; ct->ct_rcvstate |= RPCRCVSTATE_UPCALLINPROG; - if (ct->ct_sslrefno != 0 && ct->ct_sslrefno != - RPCTLS_REFNO_HANDSHAKE) { + if (ct->ct_tlsstate == RPCTLS_COMPLETE) { mtx_unlock(&ct->ct_lock); - ret = rpctls_cl_handlerecord(ct->ct_sslsec, - ct->ct_sslusec, ct->ct_sslrefno, &reterr); + ret = rpctls_cl_handlerecord(ct->ct_socket, + &reterr); mtx_lock(&ct->ct_lock); } ct->ct_rcvstate &= ~RPCRCVSTATE_UPCALLINPROG; @@ -1297,10 +1290,10 @@ SOCK_RECVBUF_UNLOCK(ct->ct_socket); mtx_lock(&ct->ct_lock); } - msleep(&ct->ct_sslrefno, &ct->ct_lock, 0, "clntvcdu", hz); + msleep(&ct->ct_tlsstate, &ct->ct_lock, 0, "clntvcdu", hz); } ct->ct_rcvstate &= ~RPCRCVSTATE_UPCALLTHREAD; - wakeup(&ct->ct_sslrefno); + wakeup(&ct->ct_tlsstate); mtx_unlock(&ct->ct_lock); CLNT_RELEASE(cl); CURVNET_RESTORE(); diff --git a/sys/rpc/krpc.h b/sys/rpc/krpc.h --- a/sys/rpc/krpc.h +++ b/sys/rpc/krpc.h @@ -114,9 +114,11 @@ struct ct_request_list ct_pending; int ct_upcallrefs; /* Ref cnt of upcalls in prog. */ SVCXPRT *ct_backchannelxprt; /* xprt for backchannel */ - uint64_t ct_sslsec; /* RPC-over-TLS connection. */ - uint64_t ct_sslusec; - uint64_t ct_sslrefno; + enum tlsstate { + RPCTLS_NONE = 0, /* TLS not in use (or not started yet) */ + RPCTLS_INHANDSHAKE, /* fd given to the daemon, handshake in progress */ + RPCTLS_COMPLETE, /* rpctlscd_connect() upcall reported success */ + } ct_tlsstate; uint32_t ct_rcvstate; /* Handle receiving for TLS upcalls */ struct mbuf *ct_raw; /* Raw mbufs recv'd */ }; diff --git a/sys/rpc/rpcsec_tls.h b/sys/rpc/rpcsec_tls.h --- a/sys/rpc/rpcsec_tls.h +++ b/sys/rpc/rpcsec_tls.h @@ -56,13 +56,11 @@ #ifdef _KERNEL /* Functions that perform upcalls to the rpctlsd daemon.
*/ enum clnt_stat rpctls_connect(CLIENT *newclient, char *certname, - struct socket *so, uint64_t *sslp, uint32_t *reterr); -enum clnt_stat rpctls_cl_handlerecord(uint64_t sec, uint64_t usec, - uint64_t ssl, uint32_t *reterr); + struct socket *so, uint32_t *reterr); +enum clnt_stat rpctls_cl_handlerecord(void *socookie, uint32_t *reterr); enum clnt_stat rpctls_srv_handlerecord(uint64_t sec, uint64_t usec, uint64_t ssl, int procpos, uint32_t *reterr); -enum clnt_stat rpctls_cl_disconnect(uint64_t sec, uint64_t usec, - uint64_t ssl, uint32_t *reterr); +enum clnt_stat rpctls_cl_disconnect(void *socookie, uint32_t *reterr); enum clnt_stat rpctls_srv_disconnect(uint64_t sec, uint64_t usec, uint64_t ssl, int procpos, uint32_t *reterr); @@ -76,9 +74,6 @@ /* String for AUTH_TLS reply verifier. */ #define RPCTLS_START_STRING "STARTTLS" -/* ssl refno value to indicate TLS handshake being done. */ -#define RPCTLS_REFNO_HANDSHAKE 0xFFFFFFFFFFFFFFFFULL - /* Macros for VIMAGE. */ /* Just define the KRPC_VNETxxx() macros as VNETxxx() macros. */ #define KRPC_VNET_NAME(n) VNET_NAME(n) diff --git a/sys/rpc/rpcsec_tls/rpctls_impl.c b/sys/rpc/rpcsec_tls/rpctls_impl.c --- a/sys/rpc/rpcsec_tls/rpctls_impl.c +++ b/sys/rpc/rpcsec_tls/rpctls_impl.c @@ -55,6 +55,7 @@ #include #include +#include #include #include @@ -172,7 +173,6 @@ struct file *fp; struct upsock *ups; int fd = -1, error; - uint64_t ssl[3]; error = priv_check(td, PRIV_NFS_DAEMON); if (error != 0) @@ -200,13 +200,12 @@ switch (uap->op) { case RPCTLS_SYSC_CLSOCKET: /* - * Set ssl refno so that clnt_vc_destroy() will - * not close the socket and will leave that for - * the daemon to do. + * Mark the TLS handshake as in progress, so that + * clnt_vc_destroy() will not close the socket and will + * leave that for the daemon to do.
*/ - ssl[0] = ssl[1] = 0; - ssl[2] = RPCTLS_REFNO_HANDSHAKE; - CLNT_CONTROL(ups->cl, CLSET_TLS, ssl); + CLNT_CONTROL(ups->cl, CLSET_TLS, + &(int){RPCTLS_INHANDSHAKE}); break; case RPCTLS_SYSC_SRVSOCKET: /* @@ -234,27 +233,23 @@ /* Do an upcall for a new socket connect using TLS. */ enum clnt_stat rpctls_connect(CLIENT *newclient, char *certname, struct socket *so, - uint64_t *sslp, uint32_t *reterr) + uint32_t *reterr) { struct rpctlscd_connect_arg arg; struct rpctlscd_connect_res res; struct rpc_callextra ext; - struct timeval utimeout; enum clnt_stat stat; struct upsock ups = { .so = so, .cl = newclient, }; CLIENT *cl = KRPC_VNET(rpctls_connect_handle); - int val; /* First, do the AUTH_TLS NULL RPC. */ memset(&ext, 0, sizeof(ext)); - utimeout.tv_sec = 30; - utimeout.tv_usec = 0; ext.rc_auth = authtls_create(); stat = clnt_call_private(newclient, &ext, NULLPROC, (xdrproc_t)xdr_void, - NULL, (xdrproc_t)xdr_void, NULL, utimeout); + NULL, (xdrproc_t)xdr_void, NULL, (struct timeval){ .tv_sec = 30 }); AUTH_DESTROY(ext.rc_auth); if (stat == RPC_AUTHERROR) return (stat); @@ -266,8 +261,7 @@ mtx_unlock(&rpctls_lock); /* Temporarily block reception during the handshake upcall. */ - val = 1; - CLNT_CONTROL(newclient, CLSET_BLOCKRCV, &val); + CLNT_CONTROL(newclient, CLSET_BLOCKRCV, &(int){1}); /* Do the connect handshake upcall. */ if (certname != NULL) { @@ -275,18 +269,13 @@ arg.certname.certname_val = certname; } else arg.certname.certname_len = 0; - arg.socookie = (uintptr_t)so; - stat = rpctlscd_connect_1(&arg, &res, cl); + arg.socookie = (uint64_t)(uintptr_t)so; + stat = rpctlscd_connect_2(&arg, &res, cl); if (stat == RPC_SUCCESS) { #ifdef INVARIANTS MPASS((RB_FIND(upsock_t, &upcall_sockets, &ups) == NULL)); #endif *reterr = res.reterr; - if (res.reterr == 0) { - *sslp++ = res.sec; - *sslp++ = res.usec; - *sslp = res.ssl; - } } else { mtx_lock(&rpctls_lock); if (RB_FIND(upsock_t, &upcall_sockets, &ups)) { @@ -313,16 +302,14 @@ } /* Unblock reception.
*/ - val = 0; - CLNT_CONTROL(newclient, CLSET_BLOCKRCV, &val); + CLNT_CONTROL(newclient, CLSET_BLOCKRCV, &(int){0}); return (stat); } /* Do an upcall to handle an non-application data record using TLS. */ enum clnt_stat -rpctls_cl_handlerecord(uint64_t sec, uint64_t usec, uint64_t ssl, - uint32_t *reterr) +rpctls_cl_handlerecord(void *socookie, uint32_t *reterr) { struct rpctlscd_handlerecord_arg arg; struct rpctlscd_handlerecord_res res; @@ -330,10 +317,8 @@ CLIENT *cl = KRPC_VNET(rpctls_connect_handle); /* Do the handlerecord upcall. */ - arg.sec = sec; - arg.usec = usec; - arg.ssl = ssl; - stat = rpctlscd_handlerecord_1(&arg, &res, cl); + arg.socookie = (uint64_t)(uintptr_t)socookie; + stat = rpctlscd_handlerecord_2(&arg, &res, cl); if (stat == RPC_SUCCESS) *reterr = res.reterr; return (stat); @@ -360,8 +345,7 @@ /* Do an upcall to shut down a socket using TLS. */ enum clnt_stat -rpctls_cl_disconnect(uint64_t sec, uint64_t usec, uint64_t ssl, - uint32_t *reterr) +rpctls_cl_disconnect(void *socookie, uint32_t *reterr) { struct rpctlscd_disconnect_arg arg; struct rpctlscd_disconnect_res res; @@ -369,10 +353,8 @@ CLIENT *cl = KRPC_VNET(rpctls_connect_handle); /* Do the disconnect upcall. */ - arg.sec = sec; - arg.usec = usec; - arg.ssl = ssl; - stat = rpctlscd_disconnect_1(&arg, &res, cl); + arg.socookie = (uint64_t)(uintptr_t)socookie; + stat = rpctlscd_disconnect_2(&arg, &res, cl); if (stat == RPC_SUCCESS) *reterr = res.reterr; return (stat); @@ -420,7 +402,7 @@ /* Do the server upcall.
*/ res.gid.gid_val = NULL; - arg.socookie = (uintptr_t)so; + arg.socookie = (uint64_t)(uintptr_t)so; stat = rpctlssd_connect_1(&arg, &res, cl); if (stat == RPC_SUCCESS) { #ifdef INVARIANTS diff --git a/sys/rpc/rpcsec_tls/rpctlscd.x b/sys/rpc/rpcsec_tls/rpctlscd.x --- a/sys/rpc/rpcsec_tls/rpctlscd.x +++ b/sys/rpc/rpcsec_tls/rpctlscd.x @@ -35,15 +35,10 @@ struct rpctlscd_connect_res { uint32_t reterr; - uint64_t sec; - uint64_t usec; - uint64_t ssl; }; struct rpctlscd_handlerecord_arg { - uint64_t sec; - uint64_t usec; - uint64_t ssl; + uint64_t socookie; }; struct rpctlscd_handlerecord_res { @@ -51,9 +46,7 @@ }; struct rpctlscd_disconnect_arg { - uint64_t sec; - uint64_t usec; - uint64_t ssl; + uint64_t socookie; }; struct rpctlscd_disconnect_res { @@ -72,5 +65,5 @@ rpctlscd_disconnect_res RPCTLSCD_DISCONNECT(rpctlscd_disconnect_arg) = 3; - } = 1; + } = 2; } = 0x40677374;