diff --git a/sys/modules/ktest/Makefile b/sys/modules/ktest/Makefile new file mode 100644 index 000000000000..21c94caabc30 --- /dev/null +++ b/sys/modules/ktest/Makefile @@ -0,0 +1,7 @@ +SYSDIR?=${SRCTOP}/sys +.include "${SYSDIR}/conf/kern.opts.mk" + +SUBDIR= ktest \ + ktest_example + +.include diff --git a/sys/modules/ktest/ktest/Makefile b/sys/modules/ktest/ktest/Makefile new file mode 100644 index 000000000000..86ed957ac2b7 --- /dev/null +++ b/sys/modules/ktest/ktest/Makefile @@ -0,0 +1,14 @@ +# $FreeBSD$ + +PACKAGE= tests + +SYSDIR?=${SRCTOP}/sys +.include "${SYSDIR}/conf/kern.opts.mk" + +.PATH: ${SYSDIR}/tests + +KMOD= ktest +SRCS= ktest.c +SRCS+= opt_netlink.h + +.include diff --git a/sys/modules/ktest/ktest_example/Makefile b/sys/modules/ktest/ktest_example/Makefile new file mode 100644 index 000000000000..b4a3e778e2ed --- /dev/null +++ b/sys/modules/ktest/ktest_example/Makefile @@ -0,0 +1,13 @@ +# $FreeBSD$ + +PACKAGE= tests + +SYSDIR?=${SRCTOP}/sys +.include "${SYSDIR}/conf/kern.opts.mk" + +.PATH: ${SYSDIR}/tests + +KMOD= ktest_example +SRCS= ktest_example.c + +.include diff --git a/sys/tests/ktest.c b/sys/tests/ktest.c new file mode 100644 index 000000000000..fcb40130bcef --- /dev/null +++ b/sys/tests/ktest.c @@ -0,0 +1,414 @@ +/*- + * SPDX-License-Identifier: BSD-2-Clause + * + * Copyright (c) 2023 Alexander V. Chernikov + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * + * THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS ``AS IS'' AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS + * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) + * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT + * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY + * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF + * SUCH DAMAGE. + */ + +#include "opt_netlink.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#include +#include + +struct mtx ktest_mtx; +#define KTEST_LOCK() mtx_lock(&ktest_mtx) +#define KTEST_UNLOCK() mtx_unlock(&ktest_mtx) +#define KTEST_LOCK_ASSERT() mtx_assert(&ktest_mtx, MA_OWNED) + +MTX_SYSINIT(ktest_mtx, &ktest_mtx, "ktest mutex", MTX_DEF); + +struct ktest_module { + struct ktest_module_info *info; + volatile u_int refcount; + TAILQ_ENTRY(ktest_module) entries; +}; +static TAILQ_HEAD(, ktest_module) module_list = TAILQ_HEAD_INITIALIZER(module_list); + +struct nl_ktest_parsed { + char *mod_name; + char *test_name; + struct nlattr *test_meta; +}; + +#define _IN(_field) offsetof(struct genlmsghdr, _field) +#define _OUT(_field) offsetof(struct nl_ktest_parsed, _field) + +static const struct nlattr_parser nla_p_get[] = { + { .type = KTEST_ATTR_MOD_NAME, .off = _OUT(mod_name), .cb = nlattr_get_string }, + { .type = KTEST_ATTR_TEST_NAME, .off = _OUT(test_name), .cb = nlattr_get_string }, + { .type = KTEST_ATTR_TEST_META, .off = _OUT(test_meta), .cb = nlattr_get_nla }, +}; +static const struct nlfield_parser nlf_p_get[] = { +}; +NL_DECLARE_PARSER(ktest_parser, struct genlmsghdr, nlf_p_get, nla_p_get); +#undef _IN +#undef _OUT + +static bool +create_reply(struct nl_writer *nw, struct nlmsghdr *hdr, int cmd) +{ + if (!nlmsg_reply(nw, hdr, sizeof(struct genlmsghdr))) + return (false); + + struct genlmsghdr *ghdr_new = nlmsg_reserve_object(nw, struct genlmsghdr); + ghdr_new->cmd = cmd; + ghdr_new->version = 0; + ghdr_new->reserved = 0; + + return (true); +} + +static int +dump_mod_test(struct nlmsghdr *hdr, struct nl_pstate *npt, + struct ktest_module *mod, const struct ktest_test_info *test_info) +{ + struct nl_writer *nw = npt->nw; + + if (!create_reply(nw, hdr, KTEST_CMD_NEWTEST)) + goto enomem; + + nlattr_add_string(nw, KTEST_ATTR_MOD_NAME, mod->info->name); + nlattr_add_string(nw, KTEST_ATTR_TEST_NAME, test_info->name); + nlattr_add_string(nw, KTEST_ATTR_TEST_DESCR, test_info->desc); + + if (nlmsg_end(nw)) + return (0); +enomem: + nlmsg_abort(nw); + return (ENOMEM); +} + +static int +dump_mod_tests(struct nlmsghdr *hdr, struct nl_pstate *npt, + struct ktest_module *mod, struct nl_ktest_parsed *attrs) +{ + for (int i = 0; i < mod->info->num_tests; i++) { + const struct ktest_test_info *test_info = &mod->info->tests[i]; + if (attrs->test_name != NULL && strcmp(attrs->test_name, test_info->name)) + continue; + int error = dump_mod_test(hdr, npt, mod, test_info); + if (error != 0) + return (error); + } + + return (0); +} + +static int +dump_tests(struct nlmsghdr *hdr, struct nl_pstate *npt) +{ + struct nl_ktest_parsed attrs = { }; + struct ktest_module *mod; + int error; + + error = nl_parse_nlmsg(hdr, &ktest_parser, npt, &attrs); + if (error != 0) + return (error); + + hdr->nlmsg_flags |= NLM_F_MULTI; + + KTEST_LOCK(); + TAILQ_FOREACH(mod, &module_list, entries) { + if (attrs.mod_name && strcmp(attrs.mod_name, mod->info->name)) + continue; + error = dump_mod_tests(hdr, npt, mod, &attrs); + if (error != 0) + break; + } + KTEST_UNLOCK(); + + if (!nlmsg_end_dump(npt->nw, error, hdr)) { + //NL_LOG(LOG_DEBUG, "Unable to finalize the dump"); + return (ENOMEM); + } + + return (error); +} + +static int +run_test(struct nlmsghdr *hdr, struct nl_pstate *npt) +{ + struct nl_ktest_parsed attrs = { }; + struct ktest_module *mod; + int error; + + error = nl_parse_nlmsg(hdr, &ktest_parser, npt, &attrs); + if (error != 0) + return (error); + + if (attrs.mod_name == NULL) { + nlmsg_report_err_msg(npt, "KTEST_ATTR_MOD_NAME not set"); + return (EINVAL); + } + + if (attrs.test_name == NULL) { + nlmsg_report_err_msg(npt, "KTEST_ATTR_TEST_NAME not set"); + return (EINVAL); + } + + const struct ktest_test_info *test = NULL; + + KTEST_LOCK(); + TAILQ_FOREACH(mod, &module_list, entries) { + if (strcmp(attrs.mod_name, mod->info->name)) + continue; + + const struct ktest_module_info *info = mod->info; + + for (int i = 0; i < info->num_tests; i++) { + const struct ktest_test_info *test_info = &info->tests[i]; + + if (!strcmp(attrs.test_name, test_info->name)) { + test = test_info; + break; + } + } + break; + } + if (test != NULL) + refcount_acquire(&mod->refcount); + KTEST_UNLOCK(); + + if (test == NULL) + return (ESRCH); + + /* Run the test */ + struct ktest_test_context ctx = { + .npt = npt, + .hdr = hdr, + .buf = npt_alloc(npt, KTEST_MAX_BUF), + .bufsize = KTEST_MAX_BUF, + }; + + if (ctx.buf == NULL) { + //NL_LOG(LOG_DEBUG, "unable to allocate temporary buffer"); + return (ENOMEM); + } + + if (test->parse != NULL && attrs.test_meta != NULL) { + error = test->parse(&ctx, attrs.test_meta); + if (error != 0) + return (error); + } + + hdr->nlmsg_flags |= NLM_F_MULTI; + + KTEST_LOG_LEVEL(&ctx, LOG_INFO, "start running %s", test->name); + error = test->func(&ctx); + KTEST_LOG_LEVEL(&ctx, LOG_INFO, "end running %s", test->name); + + refcount_release(&mod->refcount); + + if (!nlmsg_end_dump(npt->nw, error, hdr)) { + //NL_LOG(LOG_DEBUG, "Unable to finalize the dump"); + return (ENOMEM); + } + + return (error); +} + + +/* USER API */ +static void +register_test_module(struct ktest_module_info *info) +{ + struct ktest_module *mod = malloc(sizeof(*mod), M_TEMP, M_WAITOK | M_ZERO); + + mod->info = info; + info->module_ptr = mod; + KTEST_LOCK(); + TAILQ_INSERT_TAIL(&module_list, mod, entries); + KTEST_UNLOCK(); +} + +static void +unregister_test_module(struct ktest_module_info *info) +{ + struct ktest_module *mod = info->module_ptr; + + info->module_ptr = NULL; + + KTEST_LOCK(); + TAILQ_REMOVE(&module_list, mod, entries); + KTEST_UNLOCK(); + + free(mod, M_TEMP); +} + +static bool +can_unregister(struct ktest_module_info *info) +{ + struct ktest_module *mod = info->module_ptr; + + return (refcount_load(&mod->refcount) == 0); +} + +int +ktest_default_modevent(module_t mod, int type, void *arg) +{ + struct ktest_module_info *info = (struct ktest_module_info *)arg; + int error = 0; + + switch (type) { + case MOD_LOAD: + register_test_module(info); + break; + case MOD_UNLOAD: + if (!can_unregister(info)) + return (EBUSY); + unregister_test_module(info); + break; + default: + error = EOPNOTSUPP; + break; + } + return (error); +} + +bool +ktest_start_msg(struct ktest_test_context *ctx) +{ + return (create_reply(ctx->npt->nw, ctx->hdr, KTEST_CMD_NEWMESSAGE)); +} + +void +ktest_add_msg_meta(struct ktest_test_context *ctx, const char *func, + const char *fname, int line) +{ + struct nl_writer *nw = ctx->npt->nw; + struct timespec ts; + + nanouptime(&ts); + nlattr_add(nw, KTEST_MSG_ATTR_TS, sizeof(ts), &ts); + + nlattr_add_string(nw, KTEST_MSG_ATTR_FUNC, func); + nlattr_add_string(nw, KTEST_MSG_ATTR_FILE, fname); + nlattr_add_u32(nw, KTEST_MSG_ATTR_LINE, line); +} + +void +ktest_add_msg_text(struct ktest_test_context *ctx, int msg_level, + const char *fmt, ...) +{ + va_list ap; + + va_start(ap, fmt); + vsnprintf(ctx->buf, ctx->bufsize, fmt, ap); + va_end(ap); + + nlattr_add_u8(ctx->npt->nw, KTEST_MSG_ATTR_LEVEL, msg_level); + nlattr_add_string(ctx->npt->nw, KTEST_MSG_ATTR_TEXT, ctx->buf); +} + +void +ktest_end_msg(struct ktest_test_context *ctx) +{ + nlmsg_end(ctx->npt->nw); +} + +/* Module glue */ + +static const struct nlhdr_parser *all_parsers[] = { &ktest_parser }; + +static const struct genl_cmd ktest_cmds[] = { + { + .cmd_num = KTEST_CMD_LIST, + .cmd_name = "KTEST_CMD_LIST", + .cmd_cb = dump_tests, + .cmd_flags = GENL_CMD_CAP_DO | GENL_CMD_CAP_DUMP | GENL_CMD_CAP_HASPOL, + }, + { + .cmd_num = KTEST_CMD_RUN, + .cmd_name = "KTEST_CMD_RUN", + .cmd_cb = run_test, + .cmd_flags = GENL_CMD_CAP_DO | GENL_CMD_CAP_HASPOL, + .cmd_priv = PRIV_KLD_LOAD, + }, +}; + +static void +ktest_nl_register(void) +{ + bool ret __diagused; + int family_id __diagused; + + NL_VERIFY_PARSERS(all_parsers); + family_id = genl_register_family(KTEST_FAMILY_NAME, 0, 1, KTEST_CMD_MAX); + MPASS(family_id != 0); + + ret = genl_register_cmds(KTEST_FAMILY_NAME, ktest_cmds, NL_ARRAY_LEN(ktest_cmds)); + MPASS(ret); +} + +static void +ktest_nl_unregister(void) +{ + MPASS(TAILQ_EMPTY(&module_list)); + + genl_unregister_family(KTEST_FAMILY_NAME); +} + +static int +ktest_modevent(module_t mod, int type, void *unused) +{ + int error = 0; + + switch (type) { + case MOD_LOAD: + ktest_nl_register(); + break; + case MOD_UNLOAD: + ktest_nl_unregister(); + break; + default: + error = EOPNOTSUPP; + break; + } + return (error); +} + +static moduledata_t ktestmod = { + "ktest", + ktest_modevent, + 0 +}; + +DECLARE_MODULE(ktestmod, ktestmod, SI_SUB_PSEUDO, SI_ORDER_ANY); +MODULE_VERSION(ktestmod, 1); +MODULE_DEPEND(ktestmod, netlink, 1, 1, 1); + diff --git a/sys/tests/ktest.h b/sys/tests/ktest.h new file mode 100644 index 000000000000..feadb800551b --- /dev/null +++ b/sys/tests/ktest.h @@ -0,0 +1,141 @@ +/*- + * SPDX-License-Identifier: BSD-2-Clause + * + * Copyright (c) 2023 Alexander V. Chernikov + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * + * THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS ``AS IS'' AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS + * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) + * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT + * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY + * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF + * SUCH DAMAGE. + */ + +#ifndef SYS_TESTS_KTEST_H_ +#define SYS_TESTS_KTEST_H_ + +#ifdef _KERNEL + +#include +#include +#include +#include + +struct nlattr; +struct nl_pstate; +struct nlmsghdr; + +struct ktest_test_context { + void *arg; + struct nl_pstate *npt; + struct nlmsghdr *hdr; + char *buf; + size_t bufsize; +}; + +typedef int (*ktest_run_t)(struct ktest_test_context *ctx); +typedef int (*ktest_parse_t)(struct ktest_test_context *ctx, struct nlattr *container); + +struct ktest_test_info { + const char *name; + const char *desc; + ktest_run_t func; + ktest_parse_t parse; +}; + +struct ktest_module_info { + const char *name; + const struct ktest_test_info *tests; + int num_tests; + void *module_ptr; +}; + +int ktest_default_modevent(module_t mod, int type, void *arg); + +bool ktest_start_msg(struct ktest_test_context *ctx); +void ktest_add_msg_meta(struct ktest_test_context *ctx, const char *func, + const char *fname, int line); +void ktest_add_msg_text(struct ktest_test_context *ctx, int msg_level, + const char *fmt, ...); +void ktest_end_msg(struct ktest_test_context *ctx); + +#define KTEST_LOG_LEVEL(_ctx, _l, _fmt, ...) { \ + if (ktest_start_msg(_ctx)) { \ + ktest_add_msg_meta(_ctx, __func__, __FILE__, __LINE__); \ + ktest_add_msg_text(_ctx, _l, _fmt, ## __VA_ARGS__); \ + ktest_end_msg(_ctx); \ + } \ +} + +#define KTEST_LOG(_ctx, _fmt, ...) \ + KTEST_LOG_LEVEL(_ctx, LOG_DEBUG, _fmt, ## __VA_ARGS__) + +#define KTEST_MAX_BUF 512 + +#define KTEST_MODULE_DECLARE(_n, _t) \ +static struct ktest_module_info _module_info = { \ + .name = #_n, \ + .tests = _t, \ + .num_tests = nitems(_t), \ +}; \ + \ +static moduledata_t _module_data = { \ + "__" #_n "_module", \ + ktest_default_modevent, \ + &_module_info, \ +}; \ + \ +DECLARE_MODULE(ktest_##_n, _module_data, SI_SUB_PSEUDO, SI_ORDER_ANY); \ +MODULE_VERSION(ktest_##_n, 1); \ +MODULE_DEPEND(ktest_##_n, ktestmod, 1, 1, 1); \ + +#endif /* _KERNEL */ + +/* genetlink definitions */ +#define KTEST_FAMILY_NAME "ktest" + +/* commands */ +enum { + KTEST_CMD_UNSPEC = 0, + KTEST_CMD_LIST = 1, + KTEST_CMD_RUN = 2, + KTEST_CMD_NEWTEST = 3, + KTEST_CMD_NEWMESSAGE = 4, + __KTEST_CMD_MAX, +}; +#define KTEST_CMD_MAX (__KTEST_CMD_MAX - 1) + +enum ktest_attr_type_t { + KTEST_ATTR_UNSPEC, + KTEST_ATTR_MOD_NAME = 1, /* string: test module name */ + KTEST_ATTR_TEST_NAME = 2, /* string: test name */ + KTEST_ATTR_TEST_DESCR = 3, /* string: test description */ + KTEST_ATTR_TEST_META = 4, /* nested: container with test-specific metadata */ +}; + +enum ktest_msg_attr_type_t { + KTEST_MSG_ATTR_UNSPEC, + KTEST_MSG_ATTR_TS = 1, /* struct timespec */ + KTEST_MSG_ATTR_FUNC = 2, /* string: function name */ + KTEST_MSG_ATTR_FILE = 3, /* string: file name */ + KTEST_MSG_ATTR_LINE = 4, /* u32: line in the file */ + KTEST_MSG_ATTR_TEXT = 5, /* string: actual message data */ + KTEST_MSG_ATTR_LEVEL = 6, /* u8: syslog loglevel */ + KTEST_MSG_ATTR_META = 7, /* nested: message metadata */ +}; + +#endif diff --git a/sys/tests/ktest_example.c b/sys/tests/ktest_example.c new file mode 100644 index 000000000000..7cccaad7a855 --- /dev/null +++ b/sys/tests/ktest_example.c @@ -0,0 +1,134 @@ +/*- + * SPDX-License-Identifier: BSD-2-Clause + * + * Copyright (c) 2023 Alexander V. Chernikov + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * + * THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS ``AS IS'' AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS + * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) + * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT + * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY + * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF + * SUCH DAMAGE. + */ + +#include +#include +#include + + +static int +test_something(struct ktest_test_context *ctx) +{ + KTEST_LOG(ctx, "I'm here, [%s]", __func__); + + pause("sleeping...", hz / 10); + + KTEST_LOG(ctx, "done"); + + return (0); +} + +static int +test_something_else(struct ktest_test_context *ctx) +{ + return (0); +} + +static int +test_failed(struct ktest_test_context *ctx) +{ + return (EBUSY); +} + +static int +test_failed2(struct ktest_test_context *ctx) +{ + KTEST_LOG(ctx, "failed because it always fails"); + return (EBUSY); +} + +#include +#include +#include + +struct test1_attrs { + uint32_t arg1; + uint32_t arg2; + char *text; +}; + +#define _OUT(_field) offsetof(struct test1_attrs, _field) +static const struct nlattr_parser nla_p_test1[] = { + { .type = 1, .off = _OUT(arg1), .cb = nlattr_get_uint32 }, + { .type = 2, .off = _OUT(arg2), .cb = nlattr_get_uint32 }, + { .type = 3, .off = _OUT(text), .cb = nlattr_get_string }, +}; +#undef _OUT +NL_DECLARE_ATTR_PARSER(test1_parser, nla_p_test1); + +static int +test_with_params_parser(struct ktest_test_context *ctx, struct nlattr *nla) +{ + struct test1_attrs *attrs = npt_alloc(ctx->npt, sizeof(*attrs)); + + ctx->arg = attrs; + if (attrs != NULL) + return (nl_parse_nested(nla, &test1_parser, ctx->npt, attrs)); + return (ENOMEM); +} + +static int +test_with_params(struct ktest_test_context *ctx) +{ + struct test1_attrs *attrs = ctx->arg; + + if (attrs->text != NULL) + KTEST_LOG(ctx, "Get '%s'", attrs->text); + KTEST_LOG(ctx, "%u + %u = %u", attrs->arg1, attrs->arg2, + attrs->arg1 + attrs->arg2); + return (0); +} + +static const struct ktest_test_info tests[] = { + { + .name = "test_something", + .desc = "example description", + .func = &test_something, + }, + { + .name = "test_something_else", + .desc = "example description 2", + .func = &test_something_else, + }, + { + .name = "test_failed", + .desc = "always failing test", + .func = &test_failed, + }, + { + .name = "test_failed2", + .desc = "always failing test", + .func = &test_failed2, + }, + { + .name = "test_with_params", + .desc = "test summing integers", + .func = &test_with_params, + .parse = &test_with_params_parser, + }, +}; +KTEST_MODULE_DECLARE(ktest_example, tests); diff --git a/tests/atf_python/Makefile b/tests/atf_python/Makefile index 1a2fec387eda..889cdcdf9592 100644 --- a/tests/atf_python/Makefile +++ b/tests/atf_python/Makefile @@ -1,12 +1,12 @@ .include .PATH: ${.CURDIR} -FILES= __init__.py atf_pytest.py utils.py +FILES= __init__.py atf_pytest.py ktest.py utils.py SUBDIR= sys .include FILESDIR= ${TESTSBASE}/atf_python .include diff --git a/tests/atf_python/atf_pytest.py b/tests/atf_python/atf_pytest.py index 0dd3a225b73d..19b5f88fa200 100644 --- a/tests/atf_python/atf_pytest.py +++ b/tests/atf_python/atf_pytest.py @@ -1,292 +1,298 @@ import types from typing import Any from typing import Dict from typing import List from typing import NamedTuple from typing import Optional from typing import Tuple +from atf_python.ktest import generate_ktests from atf_python.utils import nodeid_to_method_name import pytest import os class ATFCleanupItem(pytest.Item): def runtest(self): """Runs cleanup procedure for the test instead of the test itself""" instance = self.parent.cls() cleanup_name = "cleanup_{}".format(nodeid_to_method_name(self.nodeid)) if hasattr(instance, cleanup_name): cleanup = getattr(instance, cleanup_name) cleanup(self.nodeid) elif hasattr(instance, "cleanup"): instance.cleanup(self.nodeid) def setup_method_noop(self, method): """Overrides runtest setup method""" pass def teardown_method_noop(self, method): """Overrides runtest teardown method""" pass class ATFTestObj(object): def __init__(self, obj, has_cleanup): # Use nodeid without name to properly name class-derived tests self.ident = obj.nodeid.split("::", 1)[1] self.description = self._get_test_description(obj) self.has_cleanup = has_cleanup self.obj = obj def _get_test_description(self, obj): """Returns first non-empty line from func docstring or func name""" + if getattr(obj, "descr", None) is not None: + return getattr(obj, "descr") docstr = obj.function.__doc__ if docstr: for line in docstr.split("\n"): if line: return line return obj.name @staticmethod def _convert_user_mark(mark, obj, ret: Dict): username = mark.args[0] if username == "unprivileged": # Special unprivileged user requested. # First, require the unprivileged-user config option presence key = "require.config" if key not in ret: ret[key] = "unprivileged_user" else: ret[key] = "{} {}".format(ret[key], "unprivileged_user") # Check if the framework requires root test_cls = ATFHandler.get_test_class(obj) if test_cls and getattr(test_cls, "NEED_ROOT", False): # Yes, so we ask kyua to run us under root instead # It is up to the implementation to switch back to the desired # user ret["require.user"] = "root" else: ret["require.user"] = username def _convert_marks(self, obj) -> Dict[str, Any]: wj_func = lambda x: " ".join(x) # noqa: E731 _map: Dict[str, Dict] = { "require_user": {"handler": self._convert_user_mark}, "require_arch": {"name": "require.arch", "fmt": wj_func}, "require_diskspace": {"name": "require.diskspace"}, "require_files": {"name": "require.files", "fmt": wj_func}, "require_machine": {"name": "require.machine", "fmt": wj_func}, "require_memory": {"name": "require.memory"}, "require_progs": {"name": "require.progs", "fmt": wj_func}, "timeout": {}, } ret = {} for mark in obj.iter_markers(): if mark.name in _map: if "handler" in _map[mark.name]: _map[mark.name]["handler"](mark, obj, ret) continue name = _map[mark.name].get("name", mark.name) if "fmt" in _map[mark.name]: val = _map[mark.name]["fmt"](mark.args[0]) else: val = mark.args[0] ret[name] = val return ret def as_lines(self) -> List[str]: """Output test definition in ATF-specific format""" ret = [] ret.append("ident: {}".format(self.ident)) ret.append("descr: {}".format(self._get_test_description(self.obj))) if self.has_cleanup: ret.append("has.cleanup: true") for key, value in self._convert_marks(self.obj).items(): ret.append("{}: {}".format(key, value)) return ret class ATFHandler(object): class ReportState(NamedTuple): state: str reason: str def __init__(self, report_file_name: Optional[str]): self._tests_state_map: Dict[str, ReportStatus] = {} self._report_file_name = report_file_name self._report_file_handle = None def setup_configure(self): fname = self._report_file_name if fname: self._report_file_handle = open(fname, mode="w") def setup_method_pre(self, item): """Called before actually running the test setup_method""" # Check if we need to manually drop the privileges for mark in item.iter_markers(): if mark.name == "require_user": cls = self.get_test_class(item) cls.TARGET_USER = mark.args[0] break def override_runtest(self, obj): # Override basic runtest command obj.runtest = types.MethodType(ATFCleanupItem.runtest, obj) # Override class setup/teardown obj.parent.cls.setup_method = ATFCleanupItem.setup_method_noop obj.parent.cls.teardown_method = ATFCleanupItem.teardown_method_noop @staticmethod def get_test_class(obj): if hasattr(obj, "parent") and obj.parent is not None: if hasattr(obj.parent, "cls"): return obj.parent.cls def has_object_cleanup(self, obj): cls = self.get_test_class(obj) if cls is not None: method_name = nodeid_to_method_name(obj.nodeid) cleanup_name = "cleanup_{}".format(method_name) if hasattr(cls, "cleanup") or hasattr(cls, cleanup_name): return True return False def _generate_test_cleanups(self, items): new_items = [] for obj in items: if self.has_object_cleanup(obj): self.override_runtest(obj) new_items.append(obj) items.clear() items.extend(new_items) + def expand_tests(self, collector, name, obj): + return generate_ktests(collector, name, obj) + def modify_tests(self, items, config): if config.option.atf_cleanup: self._generate_test_cleanups(items) def list_tests(self, tests: List[str]): print('Content-Type: application/X-atf-tp; version="1"') print() for test_obj in tests: has_cleanup = self.has_object_cleanup(test_obj) atf_test = ATFTestObj(test_obj, has_cleanup) for line in atf_test.as_lines(): print(line) print() def set_report_state(self, test_name: str, state: str, reason: str): self._tests_state_map[test_name] = self.ReportState(state, reason) def _extract_report_reason(self, report): data = report.longrepr if data is None: return None if isinstance(data, Tuple): # ('/path/to/test.py', 23, 'Skipped: unable to test') reason = data[2] for prefix in "Skipped: ": if reason.startswith(prefix): reason = reason[len(prefix):] return reason else: # string/ traceback / exception report. Capture the last line return str(data).split("\n")[-1] return None def add_report(self, report): # MAP pytest report state to the atf-desired state # # ATF test states: # (1) expected_death, (2) expected_exit, (3) expected_failure # (4) expected_signal, (5) expected_timeout, (6) passed # (7) skipped, (8) failed # # Note that ATF don't have the concept of "soft xfail" - xpass # is a failure. It also calls teardown routine in a separate # process, thus teardown states (pytest-only) are handled as # body continuation. # (stage, state, wasxfail) # Just a passing test: WANT: passed # GOT: (setup, passed, F), (call, passed, F), (teardown, passed, F) # # Failing body test: WHAT: failed # GOT: (setup, passed, F), (call, failed, F), (teardown, passed, F) # # pytest.skip test decorator: WANT: skipped # GOT: (setup,skipped, False), (teardown, passed, False) # # pytest.skip call inside test function: WANT: skipped # GOT: (setup, passed, F), (call, skipped, F), (teardown,passed, F) # # mark.xfail decorator+pytest.xfail: WANT: expected_failure # GOT: (setup, passed, F), (call, skipped, T), (teardown, passed, F) # # mark.xfail decorator+pass: WANT: failed # GOT: (setup, passed, F), (call, passed, T), (teardown, passed, F) test_name = report.location[2] stage = report.when state = report.outcome reason = self._extract_report_reason(report) # We don't care about strict xfail - it gets translated to False if stage == "setup": if state in ("skipped", "failed"): # failed init -> failed test, skipped setup -> xskip # for the whole test self.set_report_state(test_name, state, reason) elif stage == "call": # "call" stage shouldn't matter if setup failed if test_name in self._tests_state_map: if self._tests_state_map[test_name].state == "failed": return if state == "failed": # Record failure & override "skipped" state self.set_report_state(test_name, state, reason) elif state == "skipped": if hasattr(reason, "wasxfail"): # xfail() called in the test body state = "expected_failure" else: # skip inside the body pass self.set_report_state(test_name, state, reason) elif state == "passed": if hasattr(reason, "wasxfail"): # the test was expected to fail but didn't # mark as hard failure state = "failed" self.set_report_state(test_name, state, reason) elif stage == "teardown": if state == "failed": # teardown should be empty, as the cleanup # procedures should be implemented as a separate # function/method, so mark teardown failure as # global failure self.set_report_state(test_name, state, reason) def write_report(self): if self._report_file_handle is None: return if self._tests_state_map: # If we're executing in ATF mode, there has to be just one test # Anyway, deterministically pick the first one first_test_name = next(iter(self._tests_state_map)) test = self._tests_state_map[first_test_name] if test.state == "passed": line = test.state else: line = "{}: {}".format(test.state, test.reason) print(line, file=self._report_file_handle) self._report_file_handle.close() @staticmethod def get_atf_vars() -> Dict[str, str]: px = "_ATF_VAR_" return {k[len(px):]: v for k, v in os.environ.items() if k.startswith(px)} diff --git a/tests/atf_python/ktest.py b/tests/atf_python/ktest.py new file mode 100644 index 000000000000..4cd9970aaec1 --- /dev/null +++ b/tests/atf_python/ktest.py @@ -0,0 +1,173 @@ +import logging +import time +from typing import NamedTuple + +import pytest +from atf_python.sys.netlink.attrs import NlAttrNested +from atf_python.sys.netlink.attrs import NlAttrStr +from atf_python.sys.netlink.netlink import NetlinkMultipartIterator +from atf_python.sys.netlink.netlink import NlHelper +from atf_python.sys.netlink.netlink import Nlsock +from atf_python.sys.netlink.netlink_generic import KtestAttrType +from atf_python.sys.netlink.netlink_generic import KtestInfoMessage +from atf_python.sys.netlink.netlink_generic import KtestLogMsgType +from atf_python.sys.netlink.netlink_generic import KtestMsgAttrType +from atf_python.sys.netlink.netlink_generic import KtestMsgType +from atf_python.sys.netlink.netlink_generic import timespec +from atf_python.sys.netlink.utils import NlConst +from atf_python.utils import BaseTest +from atf_python.utils import libc +from atf_python.utils import nodeid_to_method_name + + +datefmt = "%H:%M:%S" +fmt = "%(asctime)s.%(msecs)03d %(filename)s:%(funcName)s:%(lineno)d %(message)s" +logging.basicConfig(level=logging.DEBUG, format=fmt, datefmt=datefmt) +logger = logging.getLogger("ktest") + + +NETLINK_FAMILY = "ktest" + + +class KtestItem(pytest.Item): + def __init__(self, *, descr, kcls, **kwargs): + super().__init__(**kwargs) + self.descr = descr + self._kcls = kcls + + def runtest(self): + self._kcls().runtest() + + +class KtestCollector(pytest.Class): + def collect(self): + obj = self.obj + exclude_names = set([n for n in dir(obj) if not n.startswith("_")]) + + autoload = obj.KTEST_MODULE_AUTOLOAD + module_name = obj.KTEST_MODULE_NAME + loader = KtestLoader(module_name, autoload) + ktests = loader.load_ktests() + if not ktests: + return + + orig = pytest.Class.from_parent(self.parent, name=self.name, obj=obj) + for py_test in orig.collect(): + yield py_test + + for ktest in ktests: + name = ktest["name"] + descr = ktest["desc"] + if name in exclude_names: + continue + yield KtestItem.from_parent(self, name=name, descr=descr, kcls=obj) + + +class KtestLoader(object): + def __init__(self, module_name: str, autoload: bool): + self.module_name = module_name + self.autoload = autoload + self.helper = NlHelper() + self.nlsock = Nlsock(NlConst.NETLINK_GENERIC, self.helper) + self.family_id = self._get_family_id() + + def _get_family_id(self): + try: + family_id = self.nlsock.get_genl_family_id(NETLINK_FAMILY) + except ValueError: + if self.autoload: + libc.kldload(self.module_name) + family_id = self.nlsock.get_genl_family_id(NETLINK_FAMILY) + else: + raise + return family_id + + def _load_ktests(self): + msg = KtestInfoMessage(self.helper, self.family_id, KtestMsgType.KTEST_CMD_LIST) + msg.set_request() + msg.add_nla(NlAttrStr(KtestAttrType.KTEST_ATTR_MOD_NAME, self.module_name)) + self.nlsock.write_message(msg, verbose=False) + nlmsg_seq = msg.nl_hdr.nlmsg_seq + + ret = [] + for rx_msg in NetlinkMultipartIterator(self.nlsock, nlmsg_seq, self.family_id): + # test_msg.print_message() + tst = { + "mod_name": rx_msg.get_nla(KtestAttrType.KTEST_ATTR_MOD_NAME).text, + "name": rx_msg.get_nla(KtestAttrType.KTEST_ATTR_TEST_NAME).text, + "desc": rx_msg.get_nla(KtestAttrType.KTEST_ATTR_TEST_DESCR).text, + } + ret.append(tst) + return ret + + def load_ktests(self): + ret = self._load_ktests() + if not ret and self.autoload: + libc.kldload(self.module_name) + ret = self._load_ktests() + return ret + + +def generate_ktests(collector, name, obj): + if getattr(obj, "KTEST_MODULE_NAME", None) is not None: + return KtestCollector.from_parent(collector, name=name, obj=obj) + return None + + +class BaseKernelTest(BaseTest): + KTEST_MODULE_AUTOLOAD = True + KTEST_MODULE_NAME = None + + def _get_record_time(self, msg) -> float: + timespec = msg.get_nla(KtestMsgAttrType.KTEST_MSG_ATTR_TS).ts + epoch_ktime = timespec.tv_sec * 1.0 + timespec.tv_nsec * 1.0 / 1000000000 + if not hasattr(self, "_start_epoch"): + self._start_ktime = epoch_ktime + self._start_time = time.time() + epoch_time = self._start_time + else: + epoch_time = time.time() - self._start_time + epoch_ktime + return epoch_time + + def _log_message(self, msg): + # Convert syslog-type l + syslog_level = msg.get_nla(KtestMsgAttrType.KTEST_MSG_ATTR_LEVEL).u8 + if syslog_level <= 6: + loglevel = logging.INFO + else: + loglevel = logging.DEBUG + rec = logging.LogRecord( + self.KTEST_MODULE_NAME, + loglevel, + msg.get_nla(KtestMsgAttrType.KTEST_MSG_ATTR_FILE).text, + msg.get_nla(KtestMsgAttrType.KTEST_MSG_ATTR_LINE).u32, + "%s", + (msg.get_nla(KtestMsgAttrType.KTEST_MSG_ATTR_TEXT).text), + None, + msg.get_nla(KtestMsgAttrType.KTEST_MSG_ATTR_FUNC).text, + None, + ) + rec.created = self._get_record_time(msg) + logger.handle(rec) + + def _runtest_name(self, test_name: str, test_data): + module_name = self.KTEST_MODULE_NAME + # print("Running kernel test {} for module {}".format(test_name, module_name)) + helper = NlHelper() + nlsock = Nlsock(NlConst.NETLINK_GENERIC, helper) + family_id = nlsock.get_genl_family_id(NETLINK_FAMILY) + msg = KtestInfoMessage(helper, family_id, KtestMsgType.KTEST_CMD_RUN) + msg.set_request() + msg.add_nla(NlAttrStr(KtestAttrType.KTEST_ATTR_MOD_NAME, module_name)) + msg.add_nla(NlAttrStr(KtestAttrType.KTEST_ATTR_TEST_NAME, test_name)) + if test_data is not None: + msg.add_nla(NlAttrNested(KtestAttrType.KTEST_ATTR_TEST_META, test_data)) + nlsock.write_message(msg, verbose=False) + + for log_msg in NetlinkMultipartIterator( + nlsock, msg.nl_hdr.nlmsg_seq, family_id + ): + self._log_message(log_msg) + + def runtest(self, test_data=None): + self._runtest_name(nodeid_to_method_name(self.test_id), test_data) diff --git a/tests/atf_python/sys/netlink/attrs.py b/tests/atf_python/sys/netlink/attrs.py index f6fe9ee43c98..58fbab7fc8db 100644 --- a/tests/atf_python/sys/netlink/attrs.py +++ b/tests/atf_python/sys/netlink/attrs.py @@ -1,287 +1,289 @@ import socket import struct from enum import Enum from atf_python.sys.netlink.utils import align4 from atf_python.sys.netlink.utils import enum_or_int class NlAttr(object): + HDR_LEN = 4 # sizeof(struct nlattr) + def __init__(self, nla_type, data): if isinstance(nla_type, Enum): self._nla_type = nla_type.value self._enum = nla_type else: self._nla_type = nla_type self._enum = None self.nla_list = [] self._data = data @property def nla_type(self): return self._nla_type & 0x3F @property def nla_len(self): return len(self._data) + 4 def add_nla(self, nla): self.nla_list.append(nla) def print_attr(self, prepend=""): if self._enum is not None: type_str = self._enum.name else: type_str = "nla#{}".format(self.nla_type) print( "{}len={} type={}({}){}".format( prepend, self.nla_len, type_str, self.nla_type, self._print_attr_value() ) ) @staticmethod def _validate(data): if len(data) < 4: raise ValueError("attribute too short") nla_len, nla_type = struct.unpack("@HH", data[:4]) if nla_len > len(data): raise ValueError("attribute length too big") if nla_len < 4: raise ValueError("attribute length too short") @classmethod def _parse(cls, data): nla_len, nla_type = struct.unpack("@HH", data[:4]) return cls(nla_type, data[4:]) @classmethod def from_bytes(cls, data, attr_type_enum=None): cls._validate(data) attr = cls._parse(data) attr._enum = attr_type_enum return attr def _to_bytes(self, data: bytes): ret = data if align4(len(ret)) != len(ret): ret = data + bytes(align4(len(ret)) - len(ret)) return struct.pack("@HH", len(data) + 4, self._nla_type) + ret def __bytes__(self): return self._to_bytes(self._data) def _print_attr_value(self): return " " + " ".join(["x{:02X}".format(b) for b in self._data]) class NlAttrNested(NlAttr): def __init__(self, nla_type, val): super().__init__(nla_type, b"") self.nla_list = val @property def nla_len(self): return align4(len(b"".join([bytes(nla) for nla in self.nla_list]))) + 4 def print_attr(self, prepend=""): if self._enum is not None: type_str = self._enum.name else: type_str = "nla#{}".format(self.nla_type) print( "{}len={} type={}({}) {{".format( prepend, self.nla_len, type_str, self.nla_type ) ) for nla in self.nla_list: nla.print_attr(prepend + " ") print("{}}}".format(prepend)) def __bytes__(self): return self._to_bytes(b"".join([bytes(nla) for nla in self.nla_list])) class NlAttrU32(NlAttr): def __init__(self, nla_type, val): self.u32 = enum_or_int(val) super().__init__(nla_type, b"") @property def nla_len(self): return 8 def _print_attr_value(self): return " val={}".format(self.u32) @staticmethod def _validate(data): assert len(data) == 8 nla_len, nla_type = struct.unpack("@HH", data[:4]) assert nla_len == 8 @classmethod def _parse(cls, data): nla_len, nla_type, val = struct.unpack("@HHI", data) return cls(nla_type, val) def __bytes__(self): return self._to_bytes(struct.pack("@I", self.u32)) class NlAttrU16(NlAttr): def __init__(self, nla_type, val): self.u16 = enum_or_int(val) super().__init__(nla_type, b"") @property def nla_len(self): return 6 def _print_attr_value(self): return " val={}".format(self.u16) @staticmethod def _validate(data): assert len(data) == 6 nla_len, nla_type = struct.unpack("@HH", data[:4]) assert nla_len == 6 @classmethod def _parse(cls, data): nla_len, nla_type, val = struct.unpack("@HHH", data) return cls(nla_type, val) def __bytes__(self): return self._to_bytes(struct.pack("@H", self.u16)) class NlAttrU8(NlAttr): def __init__(self, nla_type, val): self.u8 = enum_or_int(val) super().__init__(nla_type, b"") @property def nla_len(self): return 5 def _print_attr_value(self): return " val={}".format(self.u8) @staticmethod def _validate(data): assert len(data) == 5 nla_len, nla_type = struct.unpack("@HH", data[:4]) assert nla_len == 5 @classmethod def _parse(cls, data): nla_len, nla_type, val = struct.unpack("@HHB", data) return cls(nla_type, val) def __bytes__(self): return self._to_bytes(struct.pack("@B", self.u8)) class NlAttrIp(NlAttr): def __init__(self, nla_type, addr: str): super().__init__(nla_type, b"") self.addr = addr if ":" in self.addr: self.family = socket.AF_INET6 else: self.family = socket.AF_INET @staticmethod def _validate(data): nla_len, nla_type = struct.unpack("@HH", data[:4]) data_len = nla_len - 4 if data_len != 4 and data_len != 16: raise ValueError( "Error validating attr {}: nla_len is not valid".format( # noqa: E501 nla_type ) ) @property def nla_len(self): if self.family == socket.AF_INET6: return 20 else: return 8 return align4(len(self._data)) + 4 @classmethod def _parse(cls, data): nla_len, nla_type = struct.unpack("@HH", data[:4]) data_len = len(data) - 4 if data_len == 4: addr = socket.inet_ntop(socket.AF_INET, data[4:8]) else: addr = socket.inet_ntop(socket.AF_INET6, data[4:20]) return cls(nla_type, addr) def __bytes__(self): return self._to_bytes(socket.inet_pton(self.family, self.addr)) def _print_attr_value(self): return " addr={}".format(self.addr) class NlAttrStr(NlAttr): def __init__(self, nla_type, text): super().__init__(nla_type, b"") self.text = text @staticmethod def _validate(data): NlAttr._validate(data) try: data[4:].decode("utf-8") except Exception as e: raise ValueError("wrong utf-8 string: {}".format(e)) @property def nla_len(self): return len(self.text) + 5 @classmethod def _parse(cls, data): text = data[4:-1].decode("utf-8") nla_len, nla_type = struct.unpack("@HH", data[:4]) return cls(nla_type, text) def __bytes__(self): return self._to_bytes(bytes(self.text, encoding="utf-8") + bytes(1)) def _print_attr_value(self): return ' val="{}"'.format(self.text) class NlAttrStrn(NlAttr): def __init__(self, nla_type, text): super().__init__(nla_type, b"") self.text = text @staticmethod def _validate(data): NlAttr._validate(data) try: data[4:].decode("utf-8") except Exception as e: raise ValueError("wrong utf-8 string: {}".format(e)) @property def nla_len(self): return len(self.text) + 4 @classmethod def _parse(cls, data): text = data[4:].decode("utf-8") nla_len, nla_type = struct.unpack("@HH", data[:4]) return cls(nla_type, text) def __bytes__(self): return self._to_bytes(bytes(self.text, encoding="utf-8")) def _print_attr_value(self): return ' val="{}"'.format(self.text) diff --git a/tests/atf_python/sys/netlink/base_headers.py b/tests/atf_python/sys/netlink/base_headers.py index 759d8827fb3c..71771a249b3d 100644 --- a/tests/atf_python/sys/netlink/base_headers.py +++ b/tests/atf_python/sys/netlink/base_headers.py @@ -1,65 +1,72 @@ from ctypes import c_ubyte from ctypes import c_uint from ctypes import c_ushort from ctypes import Structure from enum import Enum class Nlmsghdr(Structure): _fields_ = [ ("nlmsg_len", c_uint), ("nlmsg_type", c_ushort), ("nlmsg_flags", c_ushort), ("nlmsg_seq", c_uint), ("nlmsg_pid", c_uint), ] +class Nlattr(Structure): + _fields_ = [ + ("nla_len", c_ushort), + ("nla_type", c_ushort), + ] + + class NlMsgType(Enum): NLMSG_NOOP = 1 NLMSG_ERROR = 2 NLMSG_DONE = 3 NLMSG_OVERRUN = 4 class NlmBaseFlags(Enum): NLM_F_REQUEST = 0x01 NLM_F_MULTI = 0x02 NLM_F_ACK = 0x04 NLM_F_ECHO = 0x08 NLM_F_DUMP_INTR = 0x10 NLM_F_DUMP_FILTERED = 0x20 # XXX: in python3.8 it is possible to # class NlmGetFlags(Enum, NlmBaseFlags): class NlmGetFlags(Enum): NLM_F_ROOT = 0x100 NLM_F_MATCH = 0x200 NLM_F_ATOMIC = 0x400 class NlmNewFlags(Enum): NLM_F_REPLACE = 0x100 NLM_F_EXCL = 0x200 NLM_F_CREATE = 0x400 NLM_F_APPEND = 0x800 class NlmDeleteFlags(Enum): NLM_F_NONREC = 0x100 class NlmAckFlags(Enum): NLM_F_CAPPED = 0x100 NLM_F_ACK_TLVS = 0x200 class GenlMsgHdr(Structure): _fields_ = [ ("cmd", c_ubyte), ("version", c_ubyte), ("reserved", c_ushort), ] diff --git a/tests/atf_python/sys/netlink/netlink.py b/tests/atf_python/sys/netlink/netlink.py index f813727d55b4..4bdefc2d5014 100644 --- a/tests/atf_python/sys/netlink/netlink.py +++ b/tests/atf_python/sys/netlink/netlink.py @@ -1,407 +1,407 @@ #!/usr/local/bin/python3 import os import socket import sys from ctypes import c_int from ctypes import c_ubyte from ctypes import c_uint from ctypes import c_ushort from ctypes import sizeof from ctypes import Structure from enum import auto from enum import Enum from atf_python.sys.netlink.attrs import NlAttr from atf_python.sys.netlink.attrs import NlAttrStr from atf_python.sys.netlink.attrs import NlAttrU32 from atf_python.sys.netlink.base_headers import GenlMsgHdr from atf_python.sys.netlink.base_headers import NlmBaseFlags from atf_python.sys.netlink.base_headers import Nlmsghdr from atf_python.sys.netlink.base_headers import NlMsgType from atf_python.sys.netlink.message import BaseNetlinkMessage from atf_python.sys.netlink.message import NlMsgCategory from atf_python.sys.netlink.message import NlMsgProps from atf_python.sys.netlink.message import StdNetlinkMessage -from atf_python.sys.netlink.netlink_generic import GenlCtrlMsgType from atf_python.sys.netlink.netlink_generic import GenlCtrlAttrType +from atf_python.sys.netlink.netlink_generic import GenlCtrlMsgType from atf_python.sys.netlink.netlink_generic import handler_classes as genl_classes from atf_python.sys.netlink.netlink_route import handler_classes as rt_classes from atf_python.sys.netlink.utils import align4 from atf_python.sys.netlink.utils import AttrDescr from atf_python.sys.netlink.utils import build_propmap from atf_python.sys.netlink.utils import enum_or_int from atf_python.sys.netlink.utils import get_bitmask_map from atf_python.sys.netlink.utils import NlConst from atf_python.sys.netlink.utils import prepare_attrs_map class SockaddrNl(Structure): _fields_ = [ ("nl_len", c_ubyte), ("nl_family", c_ubyte), ("nl_pad", c_ushort), ("nl_pid", c_uint), ("nl_groups", c_uint), ] class Nlmsgdone(Structure): _fields_ = [ ("error", c_int), ] class Nlmsgerr(Structure): _fields_ = [ ("error", c_int), ("msg", Nlmsghdr), ] class NlErrattrType(Enum): NLMSGERR_ATTR_UNUSED = 0 NLMSGERR_ATTR_MSG = auto() NLMSGERR_ATTR_OFFS = auto() NLMSGERR_ATTR_COOKIE = auto() NLMSGERR_ATTR_POLICY = auto() class AddressFamilyLinux(Enum): AF_INET = socket.AF_INET AF_INET6 = socket.AF_INET6 AF_NETLINK = 16 class AddressFamilyBsd(Enum): AF_INET = socket.AF_INET AF_INET6 = socket.AF_INET6 AF_NETLINK = 38 class NlHelper: def __init__(self): self._pmap = {} self._af_cls = self.get_af_cls() self._seq_counter = 1 self.pid = os.getpid() def get_seq(self): ret = self._seq_counter self._seq_counter += 1 return ret def get_af_cls(self): if sys.platform.startswith("freebsd"): cls = AddressFamilyBsd else: cls = AddressFamilyLinux return cls def get_propmap(self, cls): if cls not in self._pmap: self._pmap[cls] = build_propmap(cls) return self._pmap[cls] def get_name_propmap(self, cls): ret = {} for prop in dir(cls): if not prop.startswith("_"): ret[prop] = getattr(cls, prop).value return ret def get_attr_byval(self, cls, attr_val): propmap = self.get_propmap(cls) return propmap.get(attr_val) def get_af_name(self, family): v = self.get_attr_byval(self._af_cls, family) if v is not None: return v return "af#{}".format(family) def get_af_value(self, family_str: str) -> int: propmap = self.get_name_propmap(self._af_cls) return propmap.get(family_str) def get_bitmask_str(self, cls, val): bmap = get_bitmask_map(self.get_propmap(cls), val) return ",".join([v for k, v in bmap.items()]) @staticmethod def get_bitmask_str_uncached(cls, val): pmap = NlHelper.build_propmap(cls) bmap = NlHelper.get_bitmask_map(pmap, val) return ",".join([v for k, v in bmap.items()]) nldone_attrs = prepare_attrs_map([]) nlerr_attrs = prepare_attrs_map( [ AttrDescr(NlErrattrType.NLMSGERR_ATTR_MSG, NlAttrStr), AttrDescr(NlErrattrType.NLMSGERR_ATTR_OFFS, NlAttrU32), AttrDescr(NlErrattrType.NLMSGERR_ATTR_COOKIE, NlAttr), ] ) class NetlinkDoneMessage(StdNetlinkMessage): messages = [NlMsgProps(NlMsgType.NLMSG_DONE, NlMsgCategory.ACK)] nl_attrs_map = nldone_attrs @property def error_code(self): return self.base_hdr.error def parse_base_header(self, data): if len(data) < sizeof(Nlmsgdone): raise ValueError("length less than nlmsgdone header") done_hdr = Nlmsgdone.from_buffer_copy(data) sz = sizeof(Nlmsgdone) return (done_hdr, sz) def print_base_header(self, hdr, prepend=""): print("{}error={}".format(prepend, hdr.error)) class NetlinkErrorMessage(StdNetlinkMessage): messages = [NlMsgProps(NlMsgType.NLMSG_ERROR, NlMsgCategory.ACK)] nl_attrs_map = nlerr_attrs @property def error_code(self): return self.base_hdr.error @property def error_str(self): nla = self.get_nla(NlErrattrType.NLMSGERR_ATTR_MSG) if nla: return nla.text return None @property def error_offset(self): nla = self.get_nla(NlErrattrType.NLMSGERR_ATTR_OFFS) if nla: return nla.u32 return None @property def cookie(self): return self.get_nla(NlErrattrType.NLMSGERR_ATTR_COOKIE) def parse_base_header(self, data): if len(data) < sizeof(Nlmsgerr): raise ValueError("length less than nlmsgerr header") err_hdr = Nlmsgerr.from_buffer_copy(data) sz = sizeof(Nlmsgerr) if (self.nl_hdr.nlmsg_flags & 0x100) == 0: sz += align4(err_hdr.msg.nlmsg_len - sizeof(Nlmsghdr)) return (err_hdr, sz) def print_base_header(self, errhdr, prepend=""): print("{}error={}, ".format(prepend, errhdr.error), end="") hdr = errhdr.msg print( "{}len={}, type={}, flags={}(0x{:X}), seq={}, pid={}".format( prepend, hdr.nlmsg_len, "msg#{}".format(hdr.nlmsg_type), self.helper.get_bitmask_str(NlmBaseFlags, hdr.nlmsg_flags), hdr.nlmsg_flags, hdr.nlmsg_seq, hdr.nlmsg_pid, ) ) core_classes = { "netlink_core": [ NetlinkDoneMessage, NetlinkErrorMessage, ], } class Nlsock: HANDLER_CLASSES = [core_classes, rt_classes, genl_classes] def __init__(self, family, helper): self.helper = helper self.sock_fd = self._setup_netlink(family) self._sock_family = family self._data = bytes() self.msgmap = self.build_msgmap() self._family_map = { NlConst.GENL_ID_CTRL: "nlctrl", } def build_msgmap(self): handler_classes = {} for d in self.HANDLER_CLASSES: handler_classes.update(d) xmap = {} # 'family_name': [class.messages[MsgProps.msg], ] for family_id, family_classes in handler_classes.items(): xmap[family_id] = {} for cls in family_classes: for msg_props in cls.messages: xmap[family_id][enum_or_int(msg_props.msg)] = cls return xmap def _setup_netlink(self, netlink_family) -> int: family = self.helper.get_af_value("AF_NETLINK") s = socket.socket(family, socket.SOCK_RAW, netlink_family) s.setsockopt(270, 10, 1) # NETLINK_CAP_ACK s.setsockopt(270, 11, 1) # NETLINK_EXT_ACK return s def set_groups(self, mask: int): self.sock_fd.setsockopt(socket.SOL_SOCKET, 1, mask) # snl = SockaddrNl(nl_len = sizeof(SockaddrNl), nl_family=38, # nl_pid=self.pid, nl_groups=mask) # xbuffer = create_string_buffer(sizeof(SockaddrNl)) # memmove(xbuffer, addressof(snl), sizeof(SockaddrNl)) # k = struct.pack("@BBHII", 12, 38, 0, self.pid, mask) # self.sock_fd.bind(k) def write_message(self, msg, verbose=True): if verbose: print("vvvvvvvv OUT vvvvvvvv") msg.print_message() msg_bytes = bytes(msg) try: ret = os.write(self.sock_fd.fileno(), msg_bytes) assert ret == len(msg_bytes) except Exception as e: print("write({}) -> {}".format(len(msg_bytes), e)) def parse_message(self, data: bytes): if len(data) < sizeof(Nlmsghdr): raise Exception("Short read from nl: {} bytes".format(len(data))) hdr = Nlmsghdr.from_buffer_copy(data) if hdr.nlmsg_type < 16: family_name = "netlink_core" nlmsg_type = hdr.nlmsg_type elif self._sock_family == NlConst.NETLINK_ROUTE: family_name = "netlink_route" nlmsg_type = hdr.nlmsg_type else: # Genetlink if len(data) < sizeof(Nlmsghdr) + sizeof(GenlMsgHdr): raise Exception("Short read from genl: {} bytes".format(len(data))) family_name = self._family_map.get(hdr.nlmsg_type, "") ghdr = GenlMsgHdr.from_buffer_copy(data[sizeof(Nlmsghdr):]) nlmsg_type = ghdr.cmd cls = self.msgmap.get(family_name, {}).get(nlmsg_type) if not cls: cls = BaseNetlinkMessage return cls.from_bytes(self.helper, data) def get_genl_family_id(self, family_name): hdr = Nlmsghdr( nlmsg_type=NlConst.GENL_ID_CTRL, nlmsg_flags=NlmBaseFlags.NLM_F_REQUEST.value, nlmsg_seq = self.helper.get_seq(), ) ghdr = GenlMsgHdr(cmd=GenlCtrlMsgType.CTRL_CMD_GETFAMILY.value) nla = NlAttrStr(GenlCtrlAttrType.CTRL_ATTR_FAMILY_NAME, family_name) hdr.nlmsg_len = sizeof(Nlmsghdr) + sizeof(GenlMsgHdr) + len(bytes(nla)) msg_bytes = bytes(hdr) + bytes(ghdr) + bytes(nla) self.write_data(msg_bytes) while True: rx_msg = self.read_message() if hdr.nlmsg_seq == rx_msg.nl_hdr.nlmsg_seq: if rx_msg.is_type(NlMsgType.NLMSG_ERROR): if rx_msg.error_code != 0: raise ValueError("unable to get family {}".format(family_name)) else: family_id = rx_msg.get_nla(GenlCtrlAttrType.CTRL_ATTR_FAMILY_ID).u16 self._family_map[family_id] = family_name return family_id raise ValueError("unable to get family {}".format(family_name)) def write_data(self, data: bytes): self.sock_fd.send(data) def read_data(self): while True: data = self.sock_fd.recv(65535) self._data += data if len(self._data) >= sizeof(Nlmsghdr): break def read_message(self) -> bytes: if len(self._data) < sizeof(Nlmsghdr): self.read_data() hdr = Nlmsghdr.from_buffer_copy(self._data) while hdr.nlmsg_len > len(self._data): self.read_data() raw_msg = self._data[: hdr.nlmsg_len] self._data = self._data[hdr.nlmsg_len:] return self.parse_message(raw_msg) class NetlinkMultipartIterator(object): def __init__(self, obj, seq_number: int, msg_type): self._obj = obj self._seq = seq_number self._msg_type = msg_type def __iter__(self): return self def __next__(self): msg = self._obj.read_message() if self._seq != msg.nl_hdr.nlmsg_seq: raise ValueError("bad sequence number") if msg.is_type(NlMsgType.NLMSG_ERROR): raise ValueError( "error while handling multipart msg: {}".format(msg.error_code) ) elif msg.is_type(NlMsgType.NLMSG_DONE): if msg.error_code == 0: raise StopIteration raise ValueError( "error listing some parts of the multipart msg: {}".format( msg.error_code ) ) elif not msg.is_type(self._msg_type): raise ValueError("bad message type: {}".format(msg)) return msg class NetlinkTestTemplate(object): REQUIRED_MODULES = ["netlink"] def setup_netlink(self, netlink_family: NlConst): self.helper = NlHelper() self.nlsock = Nlsock(netlink_family, self.helper) def write_message(self, msg, silent=False): if not silent: print("") print("============= >> TX MESSAGE =============") msg.print_message() msg.print_as_bytes(bytes(msg), "-- DATA --") self.nlsock.write_data(bytes(msg)) def read_message(self, silent=False): msg = self.nlsock.read_message() if not silent: print("") print("============= << RX MESSAGE =============") msg.print_message() return msg def get_reply(self, tx_msg): self.write_message(tx_msg) while True: rx_msg = self.read_message() if tx_msg.nl_hdr.nlmsg_seq == rx_msg.nl_hdr.nlmsg_seq: return rx_msg def read_msg_list(self, seq, msg_type): return list(NetlinkMultipartIterator(self, seq, msg_type)) diff --git a/tests/atf_python/sys/netlink/netlink_generic.py b/tests/atf_python/sys/netlink/netlink_generic.py index ee75d5bf37f3..06dc8704fe07 100644 --- a/tests/atf_python/sys/netlink/netlink_generic.py +++ b/tests/atf_python/sys/netlink/netlink_generic.py @@ -1,110 +1,228 @@ #!/usr/local/bin/python3 +from ctypes import c_int64 +from ctypes import c_long from ctypes import sizeof +from ctypes import Structure from enum import Enum +import struct +from atf_python.sys.netlink.attrs import NlAttr from atf_python.sys.netlink.attrs import NlAttrStr from atf_python.sys.netlink.attrs import NlAttrU16 from atf_python.sys.netlink.attrs import NlAttrU32 +from atf_python.sys.netlink.attrs import NlAttrU8 from atf_python.sys.netlink.base_headers import GenlMsgHdr from atf_python.sys.netlink.message import NlMsgCategory from atf_python.sys.netlink.message import NlMsgProps from atf_python.sys.netlink.message import StdNetlinkMessage from atf_python.sys.netlink.utils import AttrDescr from atf_python.sys.netlink.utils import prepare_attrs_map from atf_python.sys.netlink.utils import enum_or_int class NetlinkGenlMessage(StdNetlinkMessage): messages = [] nl_attrs_map = {} family_name = None def __init__(self, helper, family_id, cmd=0): super().__init__(helper, family_id) self.base_hdr = GenlMsgHdr(cmd=enum_or_int(cmd)) def parse_base_header(self, data): if len(data) < sizeof(GenlMsgHdr): raise ValueError("length less than GenlMsgHdr header") ghdr = GenlMsgHdr.from_buffer_copy(data) return (ghdr, sizeof(GenlMsgHdr)) def _get_msg_type(self): return self.base_hdr.cmd def print_nl_header(self, prepend=""): # len=44, type=RTM_DELROUTE, flags=NLM_F_REQUEST|NLM_F_ACK, seq=1641163704, pid=0 # noqa: E501 hdr = self.nl_hdr print( "{}len={}, family={}, flags={}(0x{:X}), seq={}, pid={}".format( prepend, hdr.nlmsg_len, self.family_name, self.get_nlm_flags_str(), hdr.nlmsg_flags, hdr.nlmsg_seq, hdr.nlmsg_pid, ) ) def print_base_header(self, hdr, prepend=""): print( "{}cmd={} version={} reserved={}".format( prepend, self.msg_name, hdr.version, hdr.reserved ) ) GenlCtrlFamilyName = "nlctrl" class GenlCtrlMsgType(Enum): CTRL_CMD_UNSPEC = 0 CTRL_CMD_NEWFAMILY = 1 CTRL_CMD_DELFAMILY = 2 CTRL_CMD_GETFAMILY = 3 CTRL_CMD_NEWOPS = 4 CTRL_CMD_DELOPS = 5 CTRL_CMD_GETOPS = 6 CTRL_CMD_NEWMCAST_GRP = 7 CTRL_CMD_DELMCAST_GRP = 8 CTRL_CMD_GETMCAST_GRP = 9 CTRL_CMD_GETPOLICY = 10 class GenlCtrlAttrType(Enum): CTRL_ATTR_FAMILY_ID = 1 CTRL_ATTR_FAMILY_NAME = 2 CTRL_ATTR_VERSION = 3 CTRL_ATTR_HDRSIZE = 4 CTRL_ATTR_MAXATTR = 5 CTRL_ATTR_OPS = 6 CTRL_ATTR_MCAST_GROUPS = 7 CTRL_ATTR_POLICY = 8 CTRL_ATTR_OP_POLICY = 9 CTRL_ATTR_OP = 10 genl_ctrl_attrs = prepare_attrs_map( [ AttrDescr(GenlCtrlAttrType.CTRL_ATTR_FAMILY_ID, NlAttrU16), AttrDescr(GenlCtrlAttrType.CTRL_ATTR_FAMILY_NAME, NlAttrStr), AttrDescr(GenlCtrlAttrType.CTRL_ATTR_VERSION, NlAttrU32), AttrDescr(GenlCtrlAttrType.CTRL_ATTR_HDRSIZE, NlAttrU32), AttrDescr(GenlCtrlAttrType.CTRL_ATTR_MAXATTR, NlAttrU32), ] ) class NetlinkGenlCtrlMessage(NetlinkGenlMessage): messages = [ NlMsgProps(GenlCtrlMsgType.CTRL_CMD_NEWFAMILY, NlMsgCategory.NEW), NlMsgProps(GenlCtrlMsgType.CTRL_CMD_GETFAMILY, NlMsgCategory.GET), NlMsgProps(GenlCtrlMsgType.CTRL_CMD_DELFAMILY, NlMsgCategory.DELETE), ] nl_attrs_map = genl_ctrl_attrs family_name = GenlCtrlFamilyName +KtestFamilyName = "ktest" + + +class KtestMsgType(Enum): + KTEST_CMD_UNSPEC = 0 + KTEST_CMD_LIST = 1 + KTEST_CMD_RUN = 2 + KTEST_CMD_NEWTEST = 3 + KTEST_CMD_NEWMESSAGE = 4 + + +class KtestAttrType(Enum): + KTEST_ATTR_MOD_NAME = 1 + KTEST_ATTR_TEST_NAME = 2 + KTEST_ATTR_TEST_DESCR = 3 + KTEST_ATTR_TEST_META = 4 + + +class KtestLogMsgType(Enum): + KTEST_MSG_START = 1 + KTEST_MSG_END = 2 + KTEST_MSG_LOG = 3 + KTEST_MSG_FAIL = 4 + + +class KtestMsgAttrType(Enum): + KTEST_MSG_ATTR_TS = 1 + KTEST_MSG_ATTR_FUNC = 2 + KTEST_MSG_ATTR_FILE = 3 + KTEST_MSG_ATTR_LINE = 4 + KTEST_MSG_ATTR_TEXT = 5 + KTEST_MSG_ATTR_LEVEL = 6 + KTEST_MSG_ATTR_META = 7 + + +class timespec(Structure): + _fields_ = [ + ("tv_sec", c_int64), + ("tv_nsec", c_long), + ] + + +class NlAttrTS(NlAttr): + DATA_LEN = sizeof(timespec) + + def __init__(self, nla_type, val): + self.ts = val + super().__init__(nla_type, b"") + + @property + def nla_len(self): + return NlAttr.HDR_LEN + self.DATA_LEN + + def _print_attr_value(self): + return " tv_sec={} tv_nsec={}".format(self.ts.tv_sec, self.ts.tv_nsec) + + @staticmethod + def _validate(data): + assert len(data) == NlAttr.HDR_LEN + NlAttrTS.DATA_LEN + nla_len, nla_type = struct.unpack("@HH", data[:NlAttr.HDR_LEN]) + assert nla_len == NlAttr.HDR_LEN + NlAttrTS.DATA_LEN + + @classmethod + def _parse(cls, data): + nla_len, nla_type = struct.unpack("@HH", data[:NlAttr.HDR_LEN]) + val = timespec.from_buffer_copy(data[NlAttr.HDR_LEN:]) + return cls(nla_type, val) + + def __bytes__(self): + return self._to_bytes(bytes(self.ts)) + + +ktest_info_attrs = prepare_attrs_map( + [ + AttrDescr(KtestAttrType.KTEST_ATTR_MOD_NAME, NlAttrStr), + AttrDescr(KtestAttrType.KTEST_ATTR_TEST_NAME, NlAttrStr), + AttrDescr(KtestAttrType.KTEST_ATTR_TEST_DESCR, NlAttrStr), + ] +) + + +ktest_msg_attrs = prepare_attrs_map( + [ + AttrDescr(KtestMsgAttrType.KTEST_MSG_ATTR_FUNC, NlAttrStr), + AttrDescr(KtestMsgAttrType.KTEST_MSG_ATTR_FILE, NlAttrStr), + AttrDescr(KtestMsgAttrType.KTEST_MSG_ATTR_LINE, NlAttrU32), + AttrDescr(KtestMsgAttrType.KTEST_MSG_ATTR_TEXT, NlAttrStr), + AttrDescr(KtestMsgAttrType.KTEST_MSG_ATTR_LEVEL, NlAttrU8), + AttrDescr(KtestMsgAttrType.KTEST_MSG_ATTR_TS, NlAttrTS), + ] +) + + +class KtestInfoMessage(NetlinkGenlMessage): + messages = [ + NlMsgProps(KtestMsgType.KTEST_CMD_LIST, NlMsgCategory.GET), + NlMsgProps(KtestMsgType.KTEST_CMD_RUN, NlMsgCategory.NEW), + NlMsgProps(KtestMsgType.KTEST_CMD_NEWTEST, NlMsgCategory.NEW), + ] + nl_attrs_map = ktest_info_attrs + family_name = KtestFamilyName + + +class KtestMsgMessage(NetlinkGenlMessage): + messages = [ + NlMsgProps(KtestMsgType.KTEST_CMD_NEWMESSAGE, NlMsgCategory.NEW), + ] + nl_attrs_map = ktest_msg_attrs + family_name = KtestFamilyName + + handler_classes = { GenlCtrlFamilyName: [NetlinkGenlCtrlMessage], + KtestFamilyName: [KtestInfoMessage, KtestMsgMessage], } diff --git a/tests/atf_python/utils.py b/tests/atf_python/utils.py index 591a532ca476..1c0a68dad383 100644 --- a/tests/atf_python/utils.py +++ b/tests/atf_python/utils.py @@ -1,78 +1,83 @@ #!/usr/bin/env python3 import os import pwd from ctypes import CDLL from ctypes import get_errno from ctypes.util import find_library from typing import Dict from typing import List from typing import Optional import pytest def nodeid_to_method_name(nodeid: str) -> str: """file_name.py::ClassName::method_name[parametrize] -> method_name""" return nodeid.split("::")[-1].split("[")[0] class LibCWrapper(object): def __init__(self): path: Optional[str] = find_library("c") if path is None: raise RuntimeError("libc not found") self._libc = CDLL(path, use_errno=True) def modfind(self, mod_name: str) -> int: if self._libc.modfind(bytes(mod_name, encoding="ascii")) == -1: return get_errno() return 0 + def kldload(self, kld_name: str) -> int: + if self._libc.kldload(bytes(kld_name, encoding="ascii")) == -1: + return get_errno() + return 0 + def jail_attach(self, jid: int) -> int: if self._libc.jail_attach(jid) != 0: return get_errno() return 0 libc = LibCWrapper() class BaseTest(object): NEED_ROOT: bool = False # True if the class needs root privileges for the setup TARGET_USER = None # Set to the target user by the framework REQUIRED_MODULES: List[str] = [] def _check_modules(self): for mod_name in self.REQUIRED_MODULES: error_code = libc.modfind(mod_name) if error_code != 0: err_str = os.strerror(error_code) pytest.skip( "kernel module '{}' not available: {}".format(mod_name, err_str) ) @property def atf_vars(self) -> Dict[str, str]: px = "_ATF_VAR_" return {k[len(px):]: v for k, v in os.environ.items() if k.startswith(px)} def drop_privileges_user(self, user: str): uid = pwd.getpwnam(user)[2] print("Dropping privs to {}/{}".format(user, uid)) os.setuid(uid) def drop_privileges(self): if self.TARGET_USER: if self.TARGET_USER == "unprivileged": user = self.atf_vars["unprivileged-user"] else: user = self.TARGET_USER self.drop_privileges_user(user) @property def test_id(self) -> str: # 'test_ip6_output.py::TestIP6Output::test_output6_pktinfo[ipandif] (setup)' return os.environ.get("PYTEST_CURRENT_TEST").split(" ")[0] def setup_method(self, method): """Run all pre-requisits for the test execution""" self._check_modules() diff --git a/tests/conftest.py b/tests/conftest.py index 5d319863af73..8e3c004b74d6 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,130 +1,136 @@ import pytest from atf_python.atf_pytest import ATFHandler from typing import Dict PLUGIN_ENABLED = False DEFAULT_HANDLER = None def set_handler(config): global DEFAULT_HANDLER, PLUGIN_ENABLED DEFAULT_HANDLER = ATFHandler(report_file_name=config.option.atf_file) PLUGIN_ENABLED = True return DEFAULT_HANDLER def get_handler(): return DEFAULT_HANDLER def pytest_addoption(parser): """Add file output""" # Add meta-values group = parser.getgroup("general", "Running and selection options") group.addoption( "--atf-source-dir", type=str, dest="atf_source_dir", help="Path to the test source directory", ) group.addoption( "--atf-cleanup", default=False, action="store_true", dest="atf_cleanup", help="Call cleanup procedure for a given test", ) group = parser.getgroup("terminal reporting", "reporting", after="general") group.addoption( "--atf", default=False, action="store_true", help="Enable test listing/results output in atf format", ) group.addoption( "--atf-file", type=str, dest="atf_file", help="Path to the status file provided by atf runtime", ) @pytest.fixture(autouse=True, scope="session") def atf_vars() -> Dict[str, str]: return ATFHandler.get_atf_vars() @pytest.hookimpl(trylast=True) def pytest_configure(config): if config.option.help: return # Register markings anyway to avoid warnings config.addinivalue_line("markers", "require_user(name): user to run the test with") config.addinivalue_line( "markers", "require_arch(names): List[str] of support archs" ) # config.addinivalue_line("markers", "require_config(config): List[Tuple[str,Any]] of k=v pairs") config.addinivalue_line( "markers", "require_diskspace(amount): str with required diskspace" ) config.addinivalue_line( "markers", "require_files(space): List[str] with file paths" ) config.addinivalue_line( "markers", "require_machine(names): List[str] of support machine types" ) config.addinivalue_line( "markers", "require_memory(amount): str with required memory" ) config.addinivalue_line( "markers", "require_progs(space): List[str] with file paths" ) config.addinivalue_line( "markers", "timeout(dur): int/float with max duration in sec" ) if not config.option.atf: return handler = set_handler(config) if config.option.collectonly: # Need to output list of tests to stdout, hence override # standard reporter plugin reporter = config.pluginmanager.getplugin("terminalreporter") if reporter: config.pluginmanager.unregister(reporter) else: handler.setup_configure() +def pytest_pycollect_makeitem(collector, name, obj): + if PLUGIN_ENABLED: + handler = get_handler() + return handler.expand_tests(collector, name, obj) + + def pytest_collection_modifyitems(session, config, items): """If cleanup is requested, replace collected tests with their cleanups (if any)""" if PLUGIN_ENABLED: handler = get_handler() handler.modify_tests(items, config) def pytest_collection_finish(session): if PLUGIN_ENABLED and session.config.option.collectonly: handler = get_handler() handler.list_tests(session.items) def pytest_runtest_setup(item): if PLUGIN_ENABLED: handler = get_handler() handler.setup_method_pre(item) def pytest_runtest_logreport(report): if PLUGIN_ENABLED: handler = get_handler() handler.add_report(report) def pytest_unconfigure(config): if PLUGIN_ENABLED and config.option.atf_file: handler = get_handler() handler.write_report() diff --git a/tests/examples/Makefile b/tests/examples/Makefile index 7a5d84a98dfe..6bb87b300ee7 100644 --- a/tests/examples/Makefile +++ b/tests/examples/Makefile @@ -1,10 +1,11 @@ # $FreeBSD$ PACKAGE= tests TESTSDIR= ${TESTSBASE}/examples ATF_TESTS_PYTEST += test_examples.py +ATF_TESTS_PYTEST += test_ktest_example.py .include diff --git a/tests/examples/test_ktest_example.py b/tests/examples/test_ktest_example.py new file mode 100644 index 000000000000..c11f178cb054 --- /dev/null +++ b/tests/examples/test_ktest_example.py @@ -0,0 +1,35 @@ +import pytest + +from atf_python.ktest import BaseKernelTest + +from atf_python.sys.netlink.attrs import NlAttrStr +from atf_python.sys.netlink.attrs import NlAttrU32 + + +class TestExample(BaseKernelTest): + KTEST_MODULE_NAME = "ktest_example" + + @pytest.mark.parametrize( + "numbers", + [ + pytest.param([1, 2], id="1_2_Sum"), + pytest.param([3, 4], id="3_4_Sum"), + ], + ) + def test_with_params(self, numbers): + """override to parametrize""" + + test_meta = [ + NlAttrU32(1, numbers[0]), + NlAttrU32(2, numbers[1]), + NlAttrStr(3, "test string"), + ] + self.runtest(test_meta) + + @pytest.mark.skip(reason="comment me ( or delete the func) to run the test") + def test_failed(self): + pass + + @pytest.mark.skip(reason="comment me ( or delete the func) to run the test") + def test_failed2(self): + pass