diff --git a/sys/kern/uipc_ktls.c b/sys/kern/uipc_ktls.c --- a/sys/kern/uipc_ktls.c +++ b/sys/kern/uipc_ktls.c @@ -2073,7 +2073,7 @@ SBCHECK(sb); SOCKBUF_UNLOCK(sb); - error = tls->sw_decrypt(tls, hdr, data, seqno, &trail_len); + error = ktls_ocf_decrypt(tls, hdr, data, seqno, &trail_len); if (error == 0) { if (tls13) error = tls13_find_record_type(tls, data, @@ -2262,7 +2262,7 @@ /* Anonymous mbufs are encrypted in place. */ if ((m->m_epg_flags & EPG_FLAG_ANON) != 0) - return (tls->sw_encrypt(state, tls, m, NULL, 0)); + return (ktls_ocf_encrypt(state, tls, m, NULL, 0)); /* * For file-backed mbufs (from sendfile), anonymous wired @@ -2292,7 +2292,7 @@ state->dst_iov[i].iov_base = m->m_epg_trail; state->dst_iov[i].iov_len = m->m_epg_trllen; - error = tls->sw_encrypt(state, tls, m, state->dst_iov, i + 1); + error = ktls_ocf_encrypt(state, tls, m, state->dst_iov, i + 1); if (__predict_false(error != 0)) { /* Free the anonymous pages. */ diff --git a/sys/opencrypto/ktls.h b/sys/opencrypto/ktls.h --- a/sys/opencrypto/ktls.h +++ b/sys/opencrypto/ktls.h @@ -49,5 +49,11 @@ void ktls_encrypt_cb(struct ktls_ocf_encrypt_state *state, int error); void ktls_ocf_free(struct ktls_session *tls); int ktls_ocf_try(struct socket *so, struct ktls_session *tls, int direction); +int ktls_ocf_encrypt(struct ktls_ocf_encrypt_state *state, + struct ktls_session *tls, struct mbuf *m, struct iovec *outiov, + int outiovcnt); +int ktls_ocf_decrypt(struct ktls_session *tls, + const struct tls_record_layer *hdr, struct mbuf *m, uint64_t seqno, + int *trailer_len); #endif /* !__OPENCRYPTO_KTLS_H__ */ diff --git a/sys/opencrypto/ktls_ocf.c b/sys/opencrypto/ktls_ocf.c --- a/sys/opencrypto/ktls_ocf.c +++ b/sys/opencrypto/ktls_ocf.c @@ -47,7 +47,20 @@ #include #include +struct ktls_ocf_sw { + /* Encrypt a single outbound TLS record.
*/ + int (*encrypt)(struct ktls_ocf_encrypt_state *state, + struct ktls_session *tls, struct mbuf *m, + struct iovec *outiov, int outiovcnt); + + /* Decrypt a received TLS record; NULL if the suite has no RX support. */ + int (*decrypt)(struct ktls_session *tls, + const struct tls_record_layer *hdr, struct mbuf *m, + uint64_t seqno, int *trailer_len); +}; + struct ktls_ocf_session { + const struct ktls_ocf_sw *sw; crypto_session_t sid; crypto_session_t mac_sid; struct mtx lock; @@ -386,6 +399,10 @@ return (error); } +static const struct ktls_ocf_sw ktls_ocf_tls_cbc_sw = { + .encrypt = ktls_ocf_tls_cbc_encrypt, +}; + static int ktls_ocf_tls12_aead_encrypt(struct ktls_ocf_encrypt_state *state, struct ktls_session *tls, struct mbuf *m, struct iovec *outiov, @@ -532,6 +549,11 @@ return (error); } +static const struct ktls_ocf_sw ktls_ocf_tls12_aead_sw = { + .encrypt = ktls_ocf_tls12_aead_encrypt, + .decrypt = ktls_ocf_tls12_aead_decrypt, +}; + static int ktls_ocf_tls13_aead_encrypt(struct ktls_ocf_encrypt_state *state, struct ktls_session *tls, struct mbuf *m, struct iovec *outiov, @@ -662,6 +684,11 @@ return (error); } +static const struct ktls_ocf_sw ktls_ocf_tls13_aead_sw = { + .encrypt = ktls_ocf_tls13_aead_encrypt, + .decrypt = ktls_ocf_tls13_aead_decrypt, +}; + void ktls_ocf_free(struct ktls_session *tls) { @@ -806,19 +833,12 @@ tls->ocf_session = os; if (tls->params.cipher_algorithm == CRYPTO_AES_NIST_GCM_16 || tls->params.cipher_algorithm == CRYPTO_CHACHA20_POLY1305) { - if (direction == KTLS_TX) { - if (tls->params.tls_vminor == TLS_MINOR_VER_THREE) - tls->sw_encrypt = ktls_ocf_tls13_aead_encrypt; - else - tls->sw_encrypt = ktls_ocf_tls12_aead_encrypt; - } else { - if (tls->params.tls_vminor == TLS_MINOR_VER_THREE) - tls->sw_decrypt = ktls_ocf_tls13_aead_decrypt; - else - tls->sw_decrypt = ktls_ocf_tls12_aead_decrypt; - } + if (tls->params.tls_vminor == TLS_MINOR_VER_THREE) + os->sw = &ktls_ocf_tls13_aead_sw; + else + os->sw = &ktls_ocf_tls12_aead_sw; } else { - tls->sw_encrypt =
ktls_ocf_tls_cbc_encrypt; + os->sw = &ktls_ocf_tls_cbc_sw; if (tls->params.tls_vminor == TLS_MINOR_VER_ZERO) { os->implicit_iv = true; memcpy(os->iv, tls->params.iv, AES_BLOCK_LEN); @@ -837,3 +857,24 @@ tls->params.cipher_algorithm == CRYPTO_AES_CBC; return (0); } + +int +ktls_ocf_encrypt(struct ktls_ocf_encrypt_state *state, + struct ktls_session *tls, struct mbuf *m, struct iovec *outiov, + int outiovcnt) +{ + return (tls->ocf_session->sw->encrypt(state, tls, m, outiov, + outiovcnt)); +} + +int +ktls_ocf_decrypt(struct ktls_session *tls, const struct tls_record_layer *hdr, + struct mbuf *m, uint64_t seqno, int *trailer_len) +{ + const struct ktls_ocf_sw *sw = tls->ocf_session->sw; + + /* AES-CBC sessions (ktls_ocf_tls_cbc_sw) provide no decrypt hook. */ + KASSERT(sw->decrypt != NULL, + ("%s: cipher suite has no decrypt handler", __func__)); + return (sw->decrypt(tls, hdr, m, seqno, trailer_len)); +} diff --git a/sys/sys/ktls.h b/sys/sys/ktls.h --- a/sys/sys/ktls.h +++ b/sys/sys/ktls.h @@ -167,8 +167,8 @@ #define KTLS_RX 2 struct iovec; -struct ktls_ocf_session; struct ktls_ocf_encrypt_state; +struct ktls_ocf_session; struct ktls_session; struct m_snd_tag; struct mbuf; @@ -176,14 +176,6 @@ struct socket; struct ktls_session { - union { - int (*sw_encrypt)(struct ktls_ocf_encrypt_state *state, - struct ktls_session *tls, struct mbuf *m, - struct iovec *outiov, int outiovcnt); - int (*sw_decrypt)(struct ktls_session *tls, - const struct tls_record_layer *hdr, struct mbuf *m, - uint64_t seqno, int *trailer_len); - }; struct ktls_ocf_session *ocf_session; struct m_snd_tag *snd_tag; struct tls_session_params params;