diff --git a/sys/kern/uipc_ktls.c b/sys/kern/uipc_ktls.c
--- a/sys/kern/uipc_ktls.c
+++ b/sys/kern/uipc_ktls.c
@@ -302,6 +302,137 @@
 static void ktls_work_thread(void *ctx);
 static void ktls_reclaim_thread(void *ctx);
 
+/*
+ * Copy one key or IV buffer described by a struct tls_enable into a new
+ * M_KTLS allocation.  The source is a user address when the socket option
+ * came from userland (sopt_td != NULL) and a kernel address otherwise.
+ * On success *dstp points at the copy, or is NULL when len is 0; on
+ * failure nothing is left allocated.
+ */
+static int
+ktls_copyin_param(struct sockopt *sopt, const uint8_t *src, int len,
+    uint8_t **dstp)
+{
+	uint8_t *dst;
+	int error;
+
+	*dstp = NULL;
+	if (len == 0)
+		return (0);
+	dst = malloc(len, M_KTLS, M_WAITOK);
+	if (sopt->sopt_td != NULL) {
+		error = copyin(src, dst, len);
+		if (error != 0) {
+			zfree(dst, M_KTLS);
+			return (error);
+		}
+	} else
+		bcopy(src, dst, len);
+	*dstp = dst;
+	return (0);
+}
+
+/*
+ * Copy in a struct tls_enable (either the current layout or the older
+ * tls_enable_v0 layout, chosen by the option size) and deep-copy its key
+ * and IV buffers, so that subsequent consumers can reliably assume every
+ * pointer in *tls refers to kernel memory.
+ *
+ * On success the caller owns the copied buffers and must release them
+ * with cleanup_tls_enable().  On failure nothing is left allocated and
+ * the three buffer pointers in *tls are NULL, so a stray
+ * cleanup_tls_enable() call is harmless.
+ */
+int
+copyin_tls_enable(struct sockopt *sopt, struct tls_enable *tls)
+{
+	struct tls_enable_v0 tls_v0;
+	uint8_t *cipher_key = NULL, *iv = NULL, *auth_key = NULL;
+	int error;
+
+	if (sopt->sopt_valsize == sizeof(tls_v0)) {
+		error = sooptcopyin(sopt, &tls_v0, sizeof(tls_v0),
+		    sizeof(tls_v0));
+		if (error != 0)
+			goto done;
+		memset(tls, 0, sizeof(*tls));
+		tls->cipher_key = tls_v0.cipher_key;
+		tls->iv = tls_v0.iv;
+		tls->auth_key = tls_v0.auth_key;
+		tls->cipher_algorithm = tls_v0.cipher_algorithm;
+		tls->cipher_key_len = tls_v0.cipher_key_len;
+		tls->iv_len = tls_v0.iv_len;
+		tls->auth_algorithm = tls_v0.auth_algorithm;
+		tls->auth_key_len = tls_v0.auth_key_len;
+		tls->flags = tls_v0.flags;
+		tls->tls_vmajor = tls_v0.tls_vmajor;
+		tls->tls_vminor = tls_v0.tls_vminor;
+	} else {
+		error = sooptcopyin(sopt, tls, sizeof(*tls), sizeof(*tls));
+		if (error != 0)
+			goto done;
+	}
+
+	/*
+	 * The lengths are caller-supplied ints.  Bound them before they are
+	 * used as allocation sizes: a negative value would turn into a huge
+	 * size_t, and an M_WAITOK malloc() never returns NULL, so an unchecked
+	 * length could wedge or panic the kernel.  Stricter per-algorithm
+	 * checks are still expected when the session is created.
+	 */
+	if (tls->cipher_key_len < 0 ||
+	    tls->cipher_key_len > TLS_MAX_PARAM_SIZE ||
+	    tls->iv_len < 0 || tls->iv_len > TLS_MAX_PARAM_SIZE ||
+	    tls->auth_key_len < 0 || tls->auth_key_len > TLS_MAX_PARAM_SIZE) {
+		error = EINVAL;
+		goto done;
+	}
+
+	/*
+	 * Deep-copy the variable-length arrays.  Zero-length arrays are not
+	 * allocated and are left as NULL pointers.
+	 */
+	error = ktls_copyin_param(sopt, tls->cipher_key, tls->cipher_key_len,
+	    &cipher_key);
+	if (error != 0)
+		goto done;
+	error = ktls_copyin_param(sopt, tls->iv, tls->iv_len, &iv);
+	if (error != 0)
+		goto done;
+	error = ktls_copyin_param(sopt, tls->auth_key, tls->auth_key_len,
+	    &auth_key);
+	if (error != 0)
+		goto done;
+
+	tls->cipher_key = cipher_key;
+	tls->iv = iv;
+	tls->auth_key = auth_key;
+
+done:
+	if (error != 0) {
+		/* zfree() scrubs the key material and ignores NULL. */
+		zfree(cipher_key, M_KTLS);
+		zfree(iv, M_KTLS);
+		zfree(auth_key, M_KTLS);
+		tls->cipher_key = NULL;
+		tls->iv = NULL;
+		tls->auth_key = NULL;
+	}
+	return (error);
+}
+
+/*
+ * Release the key and IV buffers produced by a successful
+ * copyin_tls_enable().  NULL pointers are ignored.
+ */
+void
+cleanup_tls_enable(struct tls_enable *tls)
+{
+	zfree(__DECONST(void *, tls->cipher_key), M_KTLS);
+	zfree(__DECONST(void *, tls->iv), M_KTLS);
+	zfree(__DECONST(void *, tls->auth_key), M_KTLS);
+}
+
 static u_int
 ktls_get_cpu(struct socket *so)
 {
@@ -702,18 +833,12 @@
 		tls->params.auth_key_len = en->auth_key_len;
 		tls->params.auth_key = malloc(en->auth_key_len, M_KTLS,
 		    M_WAITOK);
-		error = copyin(en->auth_key, tls->params.auth_key,
-		    en->auth_key_len);
-		if (error)
-			goto out;
+		bcopy(en->auth_key, tls->params.auth_key, en->auth_key_len);
 	}
 
 	tls->params.cipher_key_len = en->cipher_key_len;
 	tls->params.cipher_key = malloc(en->cipher_key_len, M_KTLS, M_WAITOK);
-	error = copyin(en->cipher_key, tls->params.cipher_key,
-	    en->cipher_key_len);
-	if (error)
-		goto out;
+	bcopy(en->cipher_key, tls->params.cipher_key, en->cipher_key_len);
 
 	/*
 	 * This holds the implicit portion of the nonce for AEAD
@@ -722,9 +847,7 @@
 	 */
 	if (en->iv_len != 0) {
 		tls->params.iv_len = en->iv_len;
-		error = copyin(en->iv, tls->params.iv, en->iv_len);
-		if (error)
-			goto out;
+		bcopy(en->iv, tls->params.iv, en->iv_len);
 
 		/*
 		 * For TLS 1.2 with GCM, generate an 8-byte nonce as a
@@ -740,10 +863,6 @@
 
 	*tlsp = tls;
 	return (0);
-
-out:
-	ktls_free(tls);
-	return (error);
 }
 
 static struct ktls_session *
diff --git a/sys/netinet/tcp_usrreq.c b/sys/netinet/tcp_usrreq.c
--- a/sys/netinet/tcp_usrreq.c
+++ b/sys/netinet/tcp_usrreq.c
@@ -1907,37 +1907,6 @@
 CTASSERT(TCP_LOG_REASON_LEN <= TCP_LOG_ID_LEN);
 #endif
 
-#ifdef KERN_TLS
-static int
-copyin_tls_enable(struct sockopt *sopt, struct tls_enable *tls)
-{
-	struct tls_enable_v0 tls_v0;
-	int error;
-
-	if (sopt->sopt_valsize == sizeof(tls_v0)) {
-		error = sooptcopyin(sopt, &tls_v0, sizeof(tls_v0),
-		    sizeof(tls_v0));
-		if (error)
-			return (error);
-		memset(tls, 0, sizeof(*tls));
-		tls->cipher_key = tls_v0.cipher_key;
-		tls->iv = tls_v0.iv;
-		tls->auth_key = tls_v0.auth_key;
-		tls->cipher_algorithm = tls_v0.cipher_algorithm;
-		tls->cipher_key_len = tls_v0.cipher_key_len;
-		tls->iv_len = tls_v0.iv_len;
-		tls->auth_algorithm = tls_v0.auth_algorithm;
-		tls->auth_key_len = tls_v0.auth_key_len;
-		tls->flags = tls_v0.flags;
-		tls->tls_vmajor = tls_v0.tls_vmajor;
-		tls->tls_vminor = tls_v0.tls_vminor;
-		return (0);
-	}
-
-	return (sooptcopyin(sopt, tls, sizeof(*tls), sizeof(*tls)));
-}
-#endif
-
 extern struct cc_algo newreno_cc_algo;
 
 static int
@@ -2289,6 +2258,7 @@
 			if (error)
 				break;
 			error = ktls_enable_tx(so, &tls);
+			cleanup_tls_enable(&tls);
 			break;
 		case TCP_TXTLS_MODE:
 			INP_WUNLOCK(inp);
@@ -2302,11 +2272,11 @@
 			break;
 		case TCP_RXTLS_ENABLE:
 			INP_WUNLOCK(inp);
-			error = sooptcopyin(sopt, &tls, sizeof(tls),
-			    sizeof(tls));
+			error = copyin_tls_enable(sopt, &tls);
 			if (error)
 				break;
 			error = ktls_enable_rx(so, &tls);
+			cleanup_tls_enable(&tls);
 			break;
 #endif
 		case TCP_MAXUNACKTIME:
diff --git a/sys/sys/ktls.h b/sys/sys/ktls.h
--- a/sys/sys/ktls.h
+++ b/sys/sys/ktls.h
@@ -174,6 +174,7 @@
 struct mbuf;
 struct sockbuf;
 struct socket;
+struct sockopt;
 
 struct ktls_session {
 	struct ktls_ocf_session *ocf_session;
@@ -212,6 +213,9 @@
 	KTLS_MBUF_CRYPTO_ST_DECRYPTED = -1,
 } ktls_mbuf_crypto_st_t;
 
+int copyin_tls_enable(struct sockopt *sopt, struct tls_enable *tls);
+void cleanup_tls_enable(struct tls_enable *tls);
+
 void ktls_check_rx(struct sockbuf *sb);
 ktls_mbuf_crypto_st_t ktls_mbuf_crypto_state(struct mbuf *mb, int offset, int len);
 void ktls_disable_ifnet(void *arg);