diff --git a/tests/sys/kern/ktls_test.c b/tests/sys/kern/ktls_test.c
--- a/tests/sys/kern/ktls_test.c
+++ b/tests/sys/kern/ktls_test.c
@@ -262,6 +262,75 @@
 	return (true);
 }
 
+/*
+ * Encrypt 'size' bytes at 'input' into 'output' with the AEAD cipher
+ * 'cipher' using 'key' and 'nonce'.  If 'aad' is non-NULL, 'aad_len'
+ * bytes at 'aad' are authenticated as additional data.  The
+ * 'tag_len'-byte authentication tag is stored at 'tag'.  Returns true
+ * on success; on failure, prints a warning and returns false.
+ */
+static bool
+aead_encrypt(const EVP_CIPHER *cipher, const char *key, const char *nonce,
+    const void *aad, size_t aad_len, const char *input, char *output,
+    size_t size, char *tag, size_t tag_len)
+{
+	EVP_CIPHER_CTX *ctx;
+	int outl, total;
+
+	ctx = EVP_CIPHER_CTX_new();
+	if (ctx == NULL) {
+		warnx("EVP_CIPHER_CTX_new failed: %s",
+		    ERR_error_string(ERR_get_error(), NULL));
+		return (false);
+	}
+	if (EVP_EncryptInit_ex(ctx, cipher, NULL, (const u_char *)key,
+	    (const u_char *)nonce) != 1) {
+		warnx("EVP_EncryptInit_ex failed: %s",
+		    ERR_error_string(ERR_get_error(), NULL));
+		EVP_CIPHER_CTX_free(ctx);
+		return (false);
+	}
+	EVP_CIPHER_CTX_set_padding(ctx, 0);
+	if (aad != NULL) {
+		if (EVP_EncryptUpdate(ctx, NULL, &outl, (const u_char *)aad,
+		    aad_len) != 1) {
+			warnx("EVP_EncryptUpdate for AAD failed: %s",
+			    ERR_error_string(ERR_get_error(), NULL));
+			EVP_CIPHER_CTX_free(ctx);
+			return (false);
+		}
+	}
+	if (EVP_EncryptUpdate(ctx, (u_char *)output, &outl,
+	    (const u_char *)input, size) != 1) {
+		warnx("EVP_EncryptUpdate failed: %s",
+		    ERR_error_string(ERR_get_error(), NULL));
+		EVP_CIPHER_CTX_free(ctx);
+		return (false);
+	}
+	total = outl;
+	if (EVP_EncryptFinal_ex(ctx, (u_char *)output + outl, &outl) != 1) {
+		warnx("EVP_EncryptFinal_ex failed: %s",
+		    ERR_error_string(ERR_get_error(), NULL));
+		EVP_CIPHER_CTX_free(ctx);
+		return (false);
+	}
+	total += outl;
+	if ((size_t)total != size) {
+		warnx("encrypt size mismatch: %zu vs %d", size, total);
+		EVP_CIPHER_CTX_free(ctx);
+		return (false);
+	}
+	if (EVP_CIPHER_CTX_ctrl(ctx, EVP_CTRL_AEAD_GET_TAG, tag_len, tag) !=
+	    1) {
+		warnx("EVP_CIPHER_CTX_ctrl(EVP_CTRL_AEAD_GET_TAG) failed: %s",
+		    ERR_error_string(ERR_get_error(), NULL));
+		EVP_CIPHER_CTX_free(ctx);
+		return (false);
+	}
+	EVP_CIPHER_CTX_free(ctx);
+	return (true);
+}
+
 static bool
 aead_decrypt(const EVP_CIPHER *cipher, const char *key, const char *nonce,
     const void *aad, size_t aad_len, const char *input, char *output,
@@ -714,6 +783,82 @@
 	    record_type));
 }
 
+/*
+ * Encrypt a TLS 1.2 AEAD record of type 'record_type' with payload 'len'
+ * bytes long at 'src' and store the header, ciphertext, and tag at
+ * 'dst'.  The caller must ensure 'dst' has room for the whole record.
+ * Returns the length of the encrypted record.
+ */
+static size_t
+encrypt_tls_12_aead(struct tls_enable *en, uint8_t record_type, uint64_t seqno,
+    const void *src, size_t len, void *dst)
+{
+	struct tls_record_layer *hdr;
+	struct tls_aead_data aad;
+	char nonce[12];
+	size_t hdr_len, mac_len, record_len;
+
+	hdr = dst;
+
+	hdr_len = tls_header_len(en);
+	mac_len = tls_mac_len(en);
+	record_len = hdr_len + len + mac_len;
+
+	hdr->tls_type = record_type;
+	hdr->tls_vmajor = TLS_MAJOR_VER_ONE;
+	hdr->tls_vminor = TLS_MINOR_VER_TWO;
+	hdr->tls_length = htons(record_len - sizeof(*hdr));
+	/* AES-GCM records carry an 8-byte explicit nonce after the header. */
+	if (en->cipher_algorithm == CRYPTO_AES_NIST_GCM_16)
+		memcpy(hdr + 1, &seqno, sizeof(seqno));
+
+	tls_12_aead_aad(en, len, hdr, seqno, &aad);
+	/* ChaCha20-Poly1305 uses the TLS 1.3 nonce construction (RFC 7905). */
+	if (en->cipher_algorithm == CRYPTO_AES_NIST_GCM_16)
+		tls_12_gcm_nonce(en, hdr, nonce);
+	else
+		tls_13_nonce(en, seqno, nonce);
+
+	ATF_REQUIRE(aead_encrypt(tls_EVP_CIPHER(en), en->cipher_key, nonce,
+	    &aad, sizeof(aad), src, (char *)dst + hdr_len, len,
+	    (char *)dst + hdr_len + len, mac_len));
+
+	return (record_len);
+}
+
+/*
+ * Encrypt an AEAD TLS record of type 'record_type' with payload 'len'
+ * bytes long at 'src' and store the result at 'dst'.  If 'dst' doesn't
+ * have sufficient room ('avail'), fail the test.  Returns the length of
+ * the encrypted record.
+ */
+static size_t
+encrypt_tls_aead(struct tls_enable *en, uint8_t record_type, uint64_t seqno,
+    const void *src, size_t len, void *dst, size_t avail)
+{
+	size_t record_len;
+
+	record_len = tls_header_len(en) + len + tls_trailer_len(en);
+	ATF_REQUIRE(record_len <= avail);
+
+	ATF_REQUIRE(encrypt_tls_12_aead(en, record_type, seqno, src, len,
+	    dst) == record_len);
+
+	return (record_len);
+}
+
+/*
+ * Encrypt a TLS record using the cipher suite described by 'en'.  Only
+ * AEAD cipher suites are currently handled.
+ */
+static size_t
+encrypt_tls_record(struct tls_enable *en, uint8_t record_type, uint64_t seqno,
+    const void *src, size_t len, void *dst, size_t avail)
+{
+	return (encrypt_tls_aead(en, record_type, seqno, src, len, dst,
+	    avail));
+}
+
 static void
 test_ktls_transmit_app_data(struct tls_enable *en, uint64_t seqno, size_t len)
 {
@@ -963,12 +1108,177 @@
 	close(sockets[0]);
 }
 
+/*
+ * Receive one TLS record of type 'record_type' from 'fd' into 'data'
+ * (at most 'len' bytes) and validate its TLS_GET_RECORD control
+ * message.  Returns the number of payload bytes received.
+ */
+static size_t
+ktls_receive_tls_record(struct tls_enable *en, int fd, uint8_t record_type,
+    void *data, size_t len)
+{
+	struct msghdr msg;
+	struct cmsghdr *cmsg;
+	struct tls_get_record *tgr;
+	char cbuf[CMSG_SPACE(sizeof(*tgr))];
+	struct iovec iov;
+	ssize_t rv;
+
+	memset(&msg, 0, sizeof(msg));
+
+	msg.msg_control = cbuf;
+	msg.msg_controllen = sizeof(cbuf);
+
+	iov.iov_base = data;
+	iov.iov_len = len;
+	msg.msg_iov = &iov;
+	msg.msg_iovlen = 1;
+
+	ATF_REQUIRE((rv = recvmsg(fd, &msg, 0)) > 0);
+
+	/* Require a complete record and an untruncated control message. */
+	ATF_REQUIRE((msg.msg_flags & (MSG_EOR | MSG_CTRUNC)) == MSG_EOR);
+
+	cmsg = CMSG_FIRSTHDR(&msg);
+	ATF_REQUIRE(cmsg != NULL);
+	ATF_REQUIRE(cmsg->cmsg_level == IPPROTO_TCP);
+	ATF_REQUIRE(cmsg->cmsg_type == TLS_GET_RECORD);
+	ATF_REQUIRE(cmsg->cmsg_len == CMSG_LEN(sizeof(*tgr)));
+
+	tgr = (struct tls_get_record *)CMSG_DATA(cmsg);
+	ATF_REQUIRE(tgr->tls_type == record_type);
+	ATF_REQUIRE(tgr->tls_vmajor == en->tls_vmajor);
+	ATF_REQUIRE(tgr->tls_vminor == en->tls_vminor);
+	ATF_REQUIRE(tgr->tls_length == htons(rv));
+
+	return (rv);
+}
+
+/*
+ * Enable KTLS receive on one end of a TCP socket pair, write 'len' bytes
+ * of application data to the other end as encrypted TLS records starting
+ * at sequence number 'seqno', and verify that the plaintext read back
+ * matches what was sent.
+ */
+static void
+test_ktls_receive_app_data(struct tls_enable *en, uint64_t seqno, size_t len)
+{
+	struct kevent ev;
+	char *plaintext, *received, *outbuf;
+	size_t outbuf_cap, outbuf_len, outbuf_sent, received_len, todo, written;
+	ssize_t rv;
+	int kq, sockets[2];
+
+	plaintext = alloc_buffer(len);
+	received = malloc(len);
+	ATF_REQUIRE(received != NULL);
+	outbuf_cap = tls_header_len(en) + TLS_MAX_MSG_SIZE_V10_2 +
+	    tls_trailer_len(en);
+	outbuf = malloc(outbuf_cap);
+	ATF_REQUIRE(outbuf != NULL);
+
+	ATF_REQUIRE((kq = kqueue()) != -1);
+
+	ATF_REQUIRE_MSG(socketpair_tcp(sockets), "failed to create sockets");
+
+	ATF_REQUIRE(setsockopt(sockets[0], IPPROTO_TCP, TCP_RXTLS_ENABLE, en,
+	    sizeof(*en)) == 0);
+
+	EV_SET(&ev, sockets[0], EVFILT_READ, EV_ADD, 0, 0, NULL);
+	ATF_REQUIRE(kevent(kq, &ev, 1, NULL, 0, NULL) == 0);
+	EV_SET(&ev, sockets[1], EVFILT_WRITE, EV_ADD, 0, 0, NULL);
+	ATF_REQUIRE(kevent(kq, &ev, 1, NULL, 0, NULL) == 0);
+
+	received_len = 0;
+	outbuf_len = 0;
+	outbuf_sent = 0;
+	written = 0;
+
+	while (received_len != len) {
+		ATF_REQUIRE(kevent(kq, NULL, 0, &ev, 1, NULL) == 1);
+
+		switch (ev.filter) {
+		case EVFILT_WRITE:
+			/*
+			 * Compose the next TLS record to send.
+			 */
+			if (outbuf_len == 0) {
+				ATF_REQUIRE(written < len);
+				todo = len - written;
+				if (todo > TLS_MAX_MSG_SIZE_V10_2)
+					todo = TLS_MAX_MSG_SIZE_V10_2;
+				outbuf_len = encrypt_tls_record(en,
+				    TLS_RLTYPE_APP, seqno, plaintext + written,
+				    todo, outbuf, outbuf_cap);
+				outbuf_sent = 0;
+				written += todo;
+				seqno++;
+			}
+
+			/*
+			 * Try to write the remainder of the current
+			 * TLS record.
+			 */
+			rv = write(ev.ident, outbuf + outbuf_sent,
+			    outbuf_len - outbuf_sent);
+			ATF_REQUIRE_MSG(rv > 0,
+			    "failed to write to socket");
+			outbuf_sent += rv;
+			if (outbuf_sent == outbuf_len) {
+				outbuf_len = 0;
+				if (written == len) {
+					ev.flags = EV_DISABLE;
+					ATF_REQUIRE(kevent(kq, &ev, 1, NULL, 0,
+					    NULL) == 0);
+				}
+			}
+			break;
+
+		case EVFILT_READ:
+			ATF_REQUIRE((ev.flags & EV_EOF) == 0);
+
+			rv = ktls_receive_tls_record(en, ev.ident,
+			    TLS_RLTYPE_APP, received + received_len,
+			    len - received_len);
+			received_len += rv;
+			break;
+
+		default:
+			ATF_REQUIRE_MSG(false, "unexpected kevent filter %d",
+			    ev.filter);
+			break;
+		}
+	}
+
+	ATF_REQUIRE_MSG(written == received_len,
+	    "read %zu decrypted bytes, but wrote %zu", received_len, written);
+
+	ATF_REQUIRE(memcmp(plaintext, received, len) == 0);
+
+	free(outbuf);
+	free(received);
+	free(plaintext);
+
+	close(sockets[1]);
+	close(sockets[0]);
+	close(kq);
+}
+
 #define TLS_10_TESTS(M)							\
 	M(aes128_cbc_1_0_sha1, CRYPTO_AES_CBC, 128 / 8,			\
 	    CRYPTO_SHA1_HMAC)						\
 	M(aes256_cbc_1_0_sha1, CRYPTO_AES_CBC, 256 / 8,			\
 	    CRYPTO_SHA1_HMAC)
 
+/* TLS 1.2 AEAD cipher suites (currently used by the receive tests). */
+#define TLS_12_TESTS(M)							\
+	M(aes128_gcm_1_2, CRYPTO_AES_NIST_GCM_16, 128 / 8, 0,		\
+	    TLS_MINOR_VER_TWO)						\
+	M(aes256_gcm_1_2, CRYPTO_AES_NIST_GCM_16, 256 / 8, 0,		\
+	    TLS_MINOR_VER_TWO)						\
+	M(chacha20_poly1305_1_2, CRYPTO_CHACHA20_POLY1305, 256 / 8, 0,	\
+	    TLS_MINOR_VER_TWO)
+
 #define AES_CBC_TESTS(M)						\
 	M(aes128_cbc_1_0_sha1, CRYPTO_AES_CBC, 128 / 8,			\
 	    CRYPTO_SHA1_HMAC, TLS_MINOR_VER_ZERO)				\
@@ -1251,8 +1561,62 @@
  */
 INVALID_CIPHER_SUITES(GEN_INVALID_TRANSMIT_TEST);
 
+/*
+ * Generate a receive test named ktls_receive_<cipher_name>_<name> that
+ * sends 'len' bytes of application data starting from a random TLS
+ * sequence number.
+ */
+#define GEN_RECEIVE_APP_DATA_TEST(cipher_name, cipher_alg, key_size,	\
+	    auth_alg, minor, name, len)					\
+ATF_TC_WITHOUT_HEAD(ktls_receive_##cipher_name##_##name);		\
+ATF_TC_BODY(ktls_receive_##cipher_name##_##name, tc)			\
+{									\
+	struct tls_enable en;						\
+	uint64_t seqno;							\
+									\
+	ATF_REQUIRE_KTLS();						\
+	seqno = random();						\
+	build_tls_enable(cipher_alg, key_size, auth_alg, minor, seqno,	\
+	    &en);							\
+	test_ktls_receive_app_data(&en, seqno, len);			\
+	free_tls_enable(&en);						\
+}
+
+#define ADD_RECEIVE_APP_DATA_TEST(cipher_name, cipher_alg, key_size,	\
+	    auth_alg, minor, name)					\
+	ATF_TP_ADD_TC(tp, ktls_receive_##cipher_name##_##name);
+
+#define GEN_RECEIVE_TESTS(cipher_name, cipher_alg, key_size, auth_alg,	\
+	    minor)							\
+	GEN_RECEIVE_APP_DATA_TEST(cipher_name, cipher_alg, key_size,	\
+	    auth_alg, minor, short, 64)					\
+	GEN_RECEIVE_APP_DATA_TEST(cipher_name, cipher_alg, key_size,	\
+	    auth_alg, minor, long, 64 * 1024)
+
+#define ADD_RECEIVE_TESTS(cipher_name, cipher_alg, key_size, auth_alg,	\
+	    minor)							\
+	ADD_RECEIVE_APP_DATA_TEST(cipher_name, cipher_alg, key_size,	\
+	    auth_alg, minor, short)					\
+	ADD_RECEIVE_APP_DATA_TEST(cipher_name, cipher_alg, key_size,	\
+	    auth_alg, minor, long)
+
+/*
+ * For each supported cipher suite, run two receive tests:
+ *
+ * - a short test which sends 64 bytes of application data (likely as
+ *   a single TLS record)
+ *
+ * - a long test which sends 64KB of application data (split across
+ *   multiple TLS records)
+ *
+ * Note that receive is currently only supported for TLS 1.2 AEAD
+ * cipher suites.
+ */
+TLS_12_TESTS(GEN_RECEIVE_TESTS);
+
 ATF_TP_ADD_TCS(tp)
 {
+	/* Transmit tests */
 	AES_CBC_TESTS(ADD_TRANSMIT_TESTS);
 	AES_GCM_TESTS(ADD_TRANSMIT_TESTS);
 	CHACHA20_TESTS(ADD_TRANSMIT_TESTS);
@@ -1260,5 +1624,7 @@
 	TLS_10_TESTS(ADD_TRANSMIT_EMPTY_FRAGMENT_TEST);
 	INVALID_CIPHER_SUITES(ADD_INVALID_TRANSMIT_TEST);
 
+	/* Receive tests */
+	TLS_12_TESTS(ADD_RECEIVE_TESTS);
 	return (atf_no_error());
 }