Index: share/mk/atf.test.mk
===================================================================
--- share/mk/atf.test.mk
+++ share/mk/atf.test.mk
@@ -22,6 +22,7 @@
 ATF_TESTS_CXX?=
 ATF_TESTS_SH?=
 ATF_TESTS_KSH93?=
+ATF_TESTS_PY?=
 
 .if !empty(ATF_TESTS_C)
 PROGS+= ${ATF_TESTS_C}
@@ -109,3 +110,31 @@
 	mv ${.TARGET}.tmp ${.TARGET}
 .endfor
 .endif
+
+.if !empty(ATF_TESTS_PY)
+SCRIPTS+= ${ATF_TESTS_PY}
+_TESTS+= ${ATF_TESTS_PY}
+.for _T in ${ATF_TESTS_PY}
+SCRIPTSDIR_${_T}= ${TESTSDIR}
+TEST_INTERFACE.${_T}= atf
+TEST_METADATA.${_T}+= required_programs="python3"
+CLEANFILES+= ${_T} ${_T}.tmp
+# TODO(jmmv): It seems to me that this SED and SRC functionality should
+# exist in bsd.prog.mk along the support for SCRIPTS. Move it there if
+# this proves to be useful within the tests.
+ATF_TESTS_PY_SED_${_T}?= # empty
+ATF_TESTS_PY_SRC_${_T}?= ${_T}.py
+# Shared helpers (e.g. common/atf.py) are installed under
+# ${TESTSBASE}/sys and imported by the tests as "common.<module>".
+${_T}: ${ATF_TESTS_PY_SRC_${_T}}
+	echo "#!/usr/bin/env -S PYTHONPATH=${TESTSBASE}/sys python3" > ${.TARGET}.tmp
+.if empty(ATF_TESTS_PY_SED_${_T})
+	cat ${.ALLSRC:N*Makefile*} >>${.TARGET}.tmp
+.else
+	cat ${.ALLSRC:N*Makefile*} \
+	    | sed ${ATF_TESTS_PY_SED_${_T}} >>${.TARGET}.tmp
+.endif
+	chmod +x ${.TARGET}.tmp
+	mv ${.TARGET}.tmp ${.TARGET}
+.endfor
+.endif
Index: tests/sys/Makefile
===================================================================
--- tests/sys/Makefile
+++ tests/sys/Makefile
@@ -45,4 +45,7 @@
 
 SUBDIR+= common
 
+PACKAGE= tests
+${PACKAGE}FILES+= __init__.py
+
 .include <bsd.test.mk>
Index: tests/sys/common/Makefile
===================================================================
--- tests/sys/common/Makefile
+++ tests/sys/common/Makefile
@@ -6,6 +6,9 @@
 ${PACKAGE}FILES+= divert.py
 ${PACKAGE}FILES+= sender.py
 ${PACKAGE}FILES+= net_receiver.py
+${PACKAGE}FILES+= atf.py
+${PACKAGE}FILES+= vnet.py
+${PACKAGE}FILES+= __init__.py
 
 ${PACKAGE}FILESMODE_divert.py=0555
 ${PACKAGE}FILESMODE_sender.py=0555
Index: tests/sys/common/atf.py
===================================================================
--- /dev/null
+++ tests/sys/common/atf.py
@@ -0,0 +1,196 @@
+#!/usr/local/bin/python3
+"""Minimal implementation of the ATF test-program interface in Python.
+
+Test programs define ATFTestTemplate subclasses and call
+atf_main(globals()); the classes are then listed ("-l") or run by name.
+"""
+
+import argparse
+import inspect
+import sys
+import traceback
+import unittest
+from enum import Enum, auto
+from typing import Dict, List, Optional
+
+
+class ATFTestResult(Enum):
+    """Test case results reported through the ATF result file."""
+
+    FAILED = auto()
+    PASSED = auto()
+    SKIPPED = auto()
+
+
+class ATFTestTemplate(unittest.TestCase):
+    """Base class for ATF test cases.
+
+    Subclasses implement run() (the test body) and may implement
+    cleanup().  Classes whose name ends in "Template" are treated as
+    abstract and are neither listed nor run.
+    """
+
+    description = "Base ATF test"
+    # Extra ATF metadata printed verbatim by "-l", e.g. "require.user".
+    options = {}
+
+    def __init__(self, params: Dict[str, str]):
+        # Initialise unittest.TestCase so that the assert* helpers work;
+        # it accepts the default "runTest" method name being absent.
+        super().__init__()
+        self._params = params
+
+    def log(self, msg: str):
+        """Print a diagnostic message to stderr."""
+        print(msg, file=sys.stderr)
+
+    @classmethod
+    def test_name(cls) -> str:
+        """Return the ATF test case identifier derived from the class name."""
+        return camel_to_lc(cls.__name__)
+
+
+def camel_to_lc(name: str) -> str:
+    """Convert CamelCase to snake_case, e.g. "MultipathAdd" -> "multipath_add"."""
+    ret = []
+    for c in name:
+        if c.isupper() and ret:
+            ret.append("_")
+        ret.append(c.lower())
+    return "".join(ret)
+
+
+class ATFHandler(object):
+    """Dispatch ATF command-line requests to the discovered test classes."""
+
+    def __init__(self, args, globals_dict):
+        self._args = args
+        self._test_params = ATFHandler._get_params(args.params)
+        self._tests = ATFHandler.get_tests(globals_dict)
+        # ATF result file given with "-r"; None means print to stdout.
+        self._result_path = args.stdout
+
+    def report_result(self, result: ATFTestResult, reason: Optional[str] = None):
+        """Write the result in ATF format and terminate the process.
+
+        ATF expects a bare result ("passed") or "<result>: <reason>".
+        Passed and skipped results exit with 0, failures with 1.
+        """
+        exit_code = 1
+        if result in (ATFTestResult.PASSED, ATFTestResult.SKIPPED):
+            exit_code = 0
+        text = result.name.lower()
+        if reason:
+            text = "{}: {}".format(text, reason)
+        if self._result_path:
+            # Close the file here so the result is flushed before exit.
+            with open(self._result_path, "w") as f:
+                print(text, file=f)
+        else:
+            print(text)
+        sys.exit(exit_code)
+
+    def report_failure(self, reason: str):
+        """Report a failed result with the given reason and exit."""
+        self.report_result(ATFTestResult.FAILED, reason)
+
+    @staticmethod
+    def is_test_class(cls) -> bool:
+        """Return True for classes that are not abstract "*Template" bases."""
+        return inspect.isclass(cls) and not cls.__name__.endswith("Template")
+
+    @staticmethod
+    def get_tests(globals_dict) -> list:
+        """Return the concrete ATFTestTemplate subclasses in globals_dict."""
+        return [cl for cl in globals_dict.values()
+                if ATFHandler.is_test_class(cl) and issubclass(cl, ATFTestTemplate)]
+
+    @staticmethod
+    def _get_params(params_list: List[str]) -> Dict[str, str]:
+        """Parse "-v name=value" arguments; entries without "=" are ignored."""
+        res = {}
+        for param in params_list:
+            nv = param.split("=", 1)
+            if len(nv) == 2:
+                res[nv[0]] = nv[1]
+        return res
+
+    def list_test(self, test):
+        """Print the ATF metadata block for one test class."""
+        print("ident: {}".format(test.test_name()))
+        has_cleanup = "true" if hasattr(test, "cleanup") else "false"
+        print("has.cleanup: {}".format(has_cleanup))
+        print("descr: {}".format(test.description))
+        for option_name, option_val in test.options.items():
+            print("{}: {}".format(option_name, option_val))
+
+    def list_tests(self):
+        """Print the ATF test-program listing requested with "-l"."""
+        print("Content-Type: application/X-atf-tp; version=\"1\"")
+        print()
+        for test in self._tests:
+            self.list_test(test)
+            print()
+
+    def get_test_instance(self, test_name: str):
+        """Instantiate the test class called test_name, or return None."""
+        for test in self._tests:
+            if test.test_name() == test_name:
+                return test(self._test_params)
+        return None
+
+    def run_test(self, test_name: str):
+        """Run a test body, or its cleanup for "<name>:cleanup".
+
+        The body's result is always reported through report_result(),
+        which exits; a successful cleanup simply returns.
+        """
+        cleanup = False
+        if test_name.endswith(":cleanup"):
+            test_name = test_name[:-len(":cleanup")]
+            cleanup = True
+        elif test_name.endswith(":body"):
+            test_name = test_name[:-len(":body")]
+        try:
+            test = self.get_test_instance(test_name)
+        except Exception as e:
+            traceback.print_exc()
+            self.report_failure("unable to init test: {}".format(e))
+        if test is None:
+            self.report_failure("Unknown test case `{}'".format(test_name))
+        if cleanup:
+            try:
+                test.cleanup()
+            except Exception as e:
+                traceback.print_exc()
+                self.report_failure("error: {}".format(e))
+            return
+        try:
+            test.run()
+        except Exception as e:
+            traceback.print_exc()
+            self.report_failure("error: {}".format(e))
+        self.report_result(ATFTestResult.PASSED)
+
+
+def parser():
+    """Parse the ATF test-program command line."""
+    p = argparse.ArgumentParser()
+    p.add_argument("-l", dest="list", action="store_true")
+    p.add_argument("-s", dest="dir")
+    p.add_argument("-r", dest="stdout")
+    p.add_argument("-v", dest="params", action="append", default=[])
+    p.add_argument("test_name", nargs="?")
+    return p.parse_args()
+
+
+def atf_main(globals_dict):
+    """Entry point for test programs; pass the module's globals()."""
+    args = parser()
+    handler = ATFHandler(args, globals_dict)
+    if args.list:
+        handler.list_tests()
+    elif args.test_name:
+        handler.run_test(args.test_name)
+    else:
+        sys.exit("a test case name is required")
Index: tests/sys/common/vnet.py
===================================================================
--- /dev/null
+++ tests/sys/common/vnet.py
@@ -0,0 +1,245 @@
+#!/usr/local/bin/python3
+"""Helpers to create epair interfaces and VNET jails for tests.
+
+Created interfaces and jails are recorded in files in the current
+directory so that a separate cleanup run can remove them.
+"""
+
+import os
+import socket
+import time
+from ctypes import CDLL, get_errno
+from ctypes.util import find_library
+from typing import List
+
+from .atf import ATFTestTemplate
+
+
+def run_cmd(cmd: str) -> str:
+    """Echo and run a shell command, returning its standard output."""
+    print("run: '{}'".format(cmd))
+    return os.popen(cmd).read()
+
+
+class VnetInterface(object):
+    """A network interface that can be moved into a VNET jail."""
+
+    # List of interfaces to destroy, kept in the current directory.
+    INTERFACES_FNAME = "created_interfaces.lst"
+
+    # Interface types from <net/if_types.h>.
+    IFT_LOOP = 0x18
+    IFT_ETHER = 0x06
+
+    def __init__(self, iface_name: str):
+        self.name = iface_name
+        self.vnet_name = ""
+        self.jailed = False
+        if iface_name.startswith("lo"):
+            self.iftype = self.IFT_LOOP
+        else:
+            self.iftype = self.IFT_ETHER
+
+    @property
+    def ifindex(self):
+        """Interface index, looked up by name with if_nametoindex()."""
+        return socket.if_nametoindex(self.name)
+
+    def set_vnet(self, vnet_name: str):
+        self.vnet_name = vnet_name
+
+    def set_jailed(self, jailed: bool):
+        self.jailed = jailed
+
+    def run_cmd(self, cmd):
+        """Run cmd, via "jexec <vnet_name>" if a vnet is set and not jailed."""
+        if self.vnet_name and not self.jailed:
+            cmd = "jexec {} {}".format(self.vnet_name, cmd)
+        run_cmd(cmd)
+
+    @staticmethod
+    def file_append_line(line):
+        with open(VnetInterface.INTERFACES_FNAME, "a") as f:
+            f.write(line + "\n")
+
+    @classmethod
+    def create_iface(cls, iface_name: str):
+        """Clone an interface (e.g. "epair") and record it for cleanup.
+
+        For epairs both ends are recorded and the "a" end is returned.
+        """
+        name = run_cmd("/sbin/ifconfig {} create".format(iface_name)).rstrip()
+        if not name:
+            raise Exception("Unable to create iface {}".format(iface_name))
+        cls.file_append_line(name)
+        if name.startswith("epair"):
+            cls.file_append_line(name[:-1] + "b")
+        return cls(name)
+
+    @staticmethod
+    def cleanup_ifaces():
+        """Destroy all recorded interfaces (best effort)."""
+        try:
+            with open(VnetInterface.INTERFACES_FNAME, "r") as f:
+                for line in f:
+                    run_cmd("/sbin/ifconfig {} destroy".format(line.strip()))
+            os.unlink(VnetInterface.INTERFACES_FNAME)
+        except OSError:
+            # Nothing was created, or the list was already removed.
+            pass
+
+    def setup_addr(self, addr: str):
+        """Add addr, as inet6 if it contains ':' and as inet otherwise."""
+        family = "inet6" if ":" in addr else "inet"
+        cmd = "/sbin/ifconfig {} {} {}".format(self.name, family, addr)
+        self.run_cmd(cmd)
+
+    def delete_addr(self, addr: str):
+        """Remove an address previously added with setup_addr()."""
+        if ":" in addr:
+            cmd = "/sbin/ifconfig {} inet6 {} delete".format(self.name, addr)
+        else:
+            cmd = "/sbin/ifconfig {} -alias {}".format(self.name, addr)
+        self.run_cmd(cmd)
+
+    def turn_up(self):
+        cmd = "/sbin/ifconfig {} up".format(self.name)
+        self.run_cmd(cmd)
+
+    def enable_ipv6(self):
+        cmd = "/usr/sbin/ndp -i {} -disabled".format(self.name)
+        self.run_cmd(cmd)
+
+
+class VnetInstance(object):
+    """A persistent VNET jail and the interfaces moved into it."""
+
+    # List of jails to remove, kept in the current directory.
+    JAILS_FNAME = "created_jails.lst"
+
+    def __init__(self, vnet_name: str, jid: int, ifaces: List[VnetInterface]):
+        self.name = vnet_name
+        self.jid = jid
+        self.ifaces = ifaces
+        for iface in ifaces:
+            iface.set_vnet(vnet_name)
+            iface.set_jailed(True)
+
+    def run_vnet_cmd(self, cmd):
+        """Run cmd inside this jail via jexec and return its output."""
+        if self.name:
+            cmd = "jexec {} {}".format(self.name, cmd)
+        return run_cmd(cmd)
+
+    @staticmethod
+    def wait_interface(vnet_name: str, iface_name: str):
+        """Poll (50 times, 100ms apart) until iface_name is in the jail."""
+        cmd = "jexec {} /sbin/ifconfig -l".format(vnet_name)
+        for _ in range(50):
+            ifaces = run_cmd(cmd).strip().split(" ")
+            if iface_name in ifaces:
+                return True
+            time.sleep(0.1)
+        return False
+
+    @staticmethod
+    def file_append_line(line):
+        with open(VnetInstance.JAILS_FNAME, "a") as f:
+            f.write(line + "\n")
+
+    @staticmethod
+    def cleanup_vnets():
+        """Remove all recorded jails (best effort)."""
+        try:
+            with open(VnetInstance.JAILS_FNAME) as f:
+                for line in f:
+                    run_cmd("/usr/sbin/jail -r {}".format(line.strip()))
+            os.unlink(VnetInstance.JAILS_FNAME)
+        except OSError:
+            # Nothing was created, or the list was already removed.
+            pass
+
+    @classmethod
+    def create_with_interfaces(cls, vnet_name: str, ifaces: List[VnetInterface]):
+        """Create a persistent VNET jail owning ifaces and wait for them."""
+        iface_cmds = " ".join(["vnet.interface={}".format(i.name) for i in ifaces])
+        cmd = "/usr/sbin/jail -i -c name={} persist vnet {}".format(vnet_name, iface_cmds)
+        jid_str = run_cmd(cmd)
+        try:
+            jid = int(jid_str)
+        except ValueError:
+            jid = 0
+        if jid <= 0:
+            raise Exception("Jail creation failed, output: {}".format(jid_str))
+        cls.file_append_line(vnet_name)
+
+        for iface in ifaces:
+            if not cls.wait_interface(vnet_name, iface.name):
+                raise Exception("Interface {} has not appeared in vnet {}".format(
+                    iface.name, vnet_name))
+        return cls(vnet_name, jid, ifaces)
+
+    @staticmethod
+    def attach_jid(jid: int):
+        """Attach the calling process to jail jid."""
+        libc = CDLL(find_library("c"), use_errno=True)
+        if libc.jail_attach(jid) != 0:
+            raise Exception("jail_attach() failed: errno {}".format(get_errno()))
+
+    def attach(self):
+        self.attach_jid(self.jid)
+
+
+class SingleVnetTestTemplate(ATFTestTemplate):
+    """Template for tests whose body runs inside a dedicated VNET jail.
+
+    run() creates num_epairs epair interfaces, moves them into a jail
+    named after the test, attaches this process to it and configures the
+    optional per-interface IPV6_PREFIXES / IPV4_PREFIXES lists.
+    """
+
+    num_epairs = 1
+
+    def run(self):
+        vnet_name = "jail_{}".format(self.test_name())
+        ifaces = [VnetInterface.create_iface("epair") for _ in range(self.num_epairs)]
+        self.vnet = VnetInstance.create_with_interfaces(vnet_name, ifaces)
+        self.vnet.attach()
+        for i, addr in enumerate(getattr(self, "IPV6_PREFIXES", [])):
+            if addr:
+                iface = self.vnet.ifaces[i]
+                iface.turn_up()
+                iface.enable_ipv6()
+                iface.setup_addr(addr)
+        for i, addr in enumerate(getattr(self, "IPV4_PREFIXES", [])):
+            if addr:
+                iface = self.vnet.ifaces[i]
+                iface.turn_up()
+                iface.setup_addr(addr)
+
+    def cleanup(self):
+        """Remove the jails, then the interfaces, recorded by run()."""
+        print("==== vnet cleanup ===")
+        # XXX: sleep 100ms to avoid epair qflush panic
+        time.sleep(0.1)
+        VnetInstance.cleanup_vnets()
+        VnetInterface.cleanup_ifaces()
+
+    def run_cmd(self, cmd: str) -> str:
+        """Run a shell command without echoing it; return its output."""
+        return os.popen(cmd).read()
+
+    def create_iface(self, iface_name: str) -> str:
+        """Clone an interface, record it for cleanup and return its name."""
+        return VnetInterface.create_iface(iface_name).name
+
+    def destroy_iface(self, iface_name: str):
+        self.run_cmd("/sbin/ifconfig {} destroy".format(iface_name))
+
+    def setup_iface_addr(self, addr: str):
+        # Not implemented yet.
+        pass
+
+    def file_append_line(self, line):
+        """Record an interface name for VnetInterface.cleanup_ifaces()."""
+        VnetInterface.file_append_line(line)
Index: tests/sys/net/routing/Makefile
===================================================================
--- tests/sys/net/routing/Makefile
+++ tests/sys/net/routing/Makefile
@@ -8,6 +8,9 @@
 ATF_TESTS_C += test_rtsock_l3
 ATF_TESTS_C += test_rtsock_lladdr
 
+ATF_TESTS_PY += test_rtsock_multipath
+
+${PACKAGE}FILES+= rtsock.py
 ${PACKAGE}FILES+= generic_cleanup.sh
 ${PACKAGE}FILESMODE_generic_cleanup.sh=0555
 
Index: tests/sys/net/routing/rtsock.py
===================================================================
--- /dev/null
+++ tests/sys/net/routing/rtsock.py
@@ -0,0 +1,494 @@
+#!/usr/local/bin/python3
+"""Helpers to build, send and decode routing socket (PF_ROUTE) messages."""
+
+import os
+import socket
+import sys
+import unittest
+from ctypes import (Structure, c_byte, c_char, c_int, c_long, c_uint32,
+                    c_ulong, c_ushort, sizeof)
+from typing import Dict, List, Optional
+
+
+def roundup2(val: int, num: int) -> int:
+    """Round val up to a multiple of num; num must be a power of two."""
+    if val % num:
+        return (val | (num - 1)) + 1
+    else:
+        return val
+
+
+class RtConst():
+    """Routing socket constants and helpers to turn them into names."""
+
+    RTM_VERSION = 5
+    # Sockaddrs in routing messages are padded to a multiple of this.
+    ALIGN = sizeof(c_long)
+
+    AF_INET = socket.AF_INET
+    AF_INET6 = socket.AF_INET6
+    AF_LINK = socket.AF_LINK
+
+    RTA_DST = 0x1
+    RTA_GATEWAY = 0x2
+    RTA_NETMASK = 0x4
+    RTA_GENMASK = 0x8
+    RTA_IFP = 0x10
+    RTA_IFA = 0x20
+    RTA_AUTHOR = 0x40
+    RTA_BRD = 0x80
+
+    RTM_ADD = 1
+    RTM_DELETE = 2
+    RTM_CHANGE = 3
+    RTM_GET = 4
+
+    RTF_UP = 0x1
+    RTF_GATEWAY = 0x2
+    RTF_HOST = 0x4
+    RTF_REJECT = 0x8
+    RTF_DYNAMIC = 0x10
+    RTF_MODIFIED = 0x20
+    RTF_DONE = 0x40
+    RTF_XRESOLVE = 0x200
+    RTF_LLINFO = 0x400
+    RTF_LLDATA = 0x400
+    RTF_STATIC = 0x800
+    RTF_BLACKHOLE = 0x1000
+    RTF_PROTO2 = 0x4000
+    RTF_PROTO1 = 0x8000
+    RTF_PROTO3 = 0x40000
+    RTF_FIXEDMTU = 0x80000
+    RTF_PINNED = 0x100000
+    RTF_LOCAL = 0x200000
+    RTF_BROADCAST = 0x400000
+    RTF_MULTICAST = 0x800000
+    RTF_STICKY = 0x10000000
+    RTF_RNH_LOCKED = 0x40000000
+    RTF_GWFLAG_COMPAT = 0x80000000
+
+    @staticmethod
+    def get_props(prefix: str) -> List[str]:
+        """Return the names of all constants starting with prefix."""
+        return [n for n in dir(RtConst) if n.startswith(prefix)]
+
+    @staticmethod
+    def get_name(prefix: str, value: int) -> str:
+        """Return the constant name for value, or "U:<value>" if unknown."""
+        for prop in RtConst.get_props(prefix):
+            if getattr(RtConst, prop) == value:
+                return prop
+        return "U:" + str(value)
+
+    @staticmethod
+    def get_bitmask_map(prefix: str, value: int) -> Dict[int, str]:
+        """Map each bit set in value to its constant name (or hex value)."""
+        propmap = {getattr(RtConst, prop): prop for prop in RtConst.get_props(prefix)}
+        # Header fields are signed C ints; view negative values as unsigned
+        # 32-bit so that the loop below terminates.
+        if value < 0:
+            value &= 0xFFFFFFFF
+        ret = {}
+        v = 1
+        while value:
+            if v & value:
+                ret[v] = propmap.get(v, hex(v))
+                value -= v
+            v *= 2
+        return ret
+
+    @staticmethod
+    def get_bitmask_str(prefix: str, value: int) -> str:
+        """Return the comma-separated names of the bits set in value."""
+        return ",".join(RtConst.get_bitmask_map(prefix, value).values())
+
+
+class RtMetrics(Structure):
+    """ctypes layout of struct rt_metrics (must match <net/route.h>)."""
+
+    _fields_ = [
+        ("rmx_locks", c_ulong),
+        ("rmx_mtu", c_ulong),
+        ("rmx_hopcount", c_ulong),
+        ("rmx_expire", c_ulong),
+        ("rmx_recvpipe", c_ulong),
+        ("rmx_sendpipe", c_ulong),
+        ("rmx_ssthresh", c_ulong),
+        ("rmx_rtt", c_ulong),
+        ("rmx_rttvar", c_ulong),
+        ("rmx_pksent", c_ulong),
+        ("rmx_weight", c_ulong),
+        ("rmx_nhidx", c_ulong),
+        ("rmx_filler", c_ulong * 2),
+    ]
+
+
+class RtMsgHdr(Structure):
+    """ctypes layout of struct rt_msghdr (must match <net/route.h>)."""
+
+    _fields_ = [
+        ("rtm_msglen", c_ushort),
+        ("rtm_version", c_byte),
+        ("rtm_type", c_byte),
+        ("rtm_index", c_ushort),
+        ("_rtm_spare1", c_ushort),
+        ("rtm_flags", c_int),
+        ("rtm_addrs", c_int),
+        ("rtm_pid", c_int),
+        ("rtm_seq", c_int),
+        ("rtm_errno", c_int),
+        ("rtm_fmask", c_int),
+        ("rtm_inits", c_ulong),
+        ("rtm_rmx", RtMetrics),
+    ]
+
+
+class SockaddrIn(Structure):
+    """BSD struct sockaddr_in (with sin_len)."""
+
+    _fields_ = [
+        ("sin_len", c_byte),
+        ("sin_family", c_byte),
+        ("sin_port", c_ushort),
+        ("sin_addr", c_uint32),
+        ("sin_zero", c_char * 8),
+    ]
+
+
+class SockaddrIn6(Structure):
+    """BSD struct sockaddr_in6 (with sin6_len)."""
+
+    _fields_ = [
+        ("sin6_len", c_byte),
+        ("sin6_family", c_byte),
+        ("sin6_port", c_ushort),
+        ("sin6_flowinfo", c_uint32),
+        ("sin6_addr", c_byte * 16),
+        ("sin6_scope_id", c_uint32),
+    ]
+
+
+class SockaddrDl(Structure):
+    """struct sockaddr_dl with sdl_data shortened to 8 bytes."""
+
+    _fields_ = [
+        ("sdl_len", c_byte),
+        ("sdl_family", c_byte),
+        ("sdl_index", c_ushort),
+        ("sdl_type", c_byte),
+        ("sdl_nlen", c_byte),
+        ("sdl_alen", c_byte),
+        ("sdl_slen", c_byte),
+        ("sdl_data", c_byte * 8),
+    ]
+
+
+class SaHelper(object):
+    """Build raw sockaddr bytes for use as routing message attributes."""
+
+    @staticmethod
+    def ip_sa(ip: str) -> bytes:
+        """Return a sockaddr_in for the IPv4 address ip."""
+        # Read the network-order bytes in host order so that ctypes stores
+        # them back unchanged.
+        addr_int = int.from_bytes(socket.inet_pton(socket.AF_INET, ip), sys.byteorder)
+        sin = SockaddrIn(sizeof(SockaddrIn), socket.AF_INET, 0, addr_int)
+        return bytes(sin)
+
+    @staticmethod
+    def ip6_sa(ip6: str, scopeid: int) -> bytes:
+        """Return a sockaddr_in6 for ip6 with the given scope id."""
+        addr_bytes = socket.inet_pton(socket.AF_INET6, ip6)
+        # ctypes needs an array instance, not bytes, for sin6_addr.
+        addr = (c_byte * 16).from_buffer_copy(addr_bytes)
+        sin6 = SockaddrIn6(sizeof(SockaddrIn6), socket.AF_INET6, 0, 0, addr, scopeid)
+        return bytes(sin6)
+
+    @staticmethod
+    def link_sa(ifindex: Optional[int] = 0, iftype: Optional[int] = 0) -> bytes:
+        """Return a (short) sockaddr_dl for ifindex / iftype."""
+        sa = SockaddrDl(sizeof(SockaddrDl), socket.AF_LINK, c_ushort(ifindex), iftype)
+        return bytes(sa)
+
+
+class BaseRtsockMessage(object):
+    """Base for routing messages; provides unittest-style assertions."""
+
+    def __init__(self, rtm_type):
+        self.rtm_type = rtm_type
+        self.ut = unittest.TestCase()
+        self.sa = SaHelper()
+
+    def assertEqual(self, a, b, msg=None):
+        self.ut.assertEqual(a, b, msg)
+
+    def assertNotEqual(self, a, b, msg=None):
+        self.ut.assertNotEqual(a, b, msg)
+
+
+class RtsockRtMessage(BaseRtsockMessage):
+    """Route message: rt_msghdr followed by RTA_* sockaddr attributes."""
+
+    messages = [RtConst.RTM_ADD, RtConst.RTM_DELETE, RtConst.RTM_CHANGE, RtConst.RTM_GET]
+
+    def __init__(self, rtm_type, rtm_seq=1, dst_sa=None, mask_sa=None):
+        super().__init__(rtm_type)
+        self.rtm_flags = 0
+        self.rtm_seq = rtm_seq
+        # Raw sockaddr bytes keyed by their RTA_* bit.
+        self._attrs = {}
+        self.rtm_errno = 0
+        self.rtm_pid = 0
+        self._orig_data = None
+        if dst_sa:
+            self.add_sa_attr(RtConst.RTA_DST, dst_sa)
+        if mask_sa:
+            self.add_sa_attr(RtConst.RTA_NETMASK, mask_sa)
+
+    def add_sa_attr(self, attr_type, attr_bytes):
+        self._attrs[attr_type] = attr_bytes
+
+    def add_ip_attr(self, attr_type, ip: str):
+        self.add_sa_attr(attr_type, self.sa.ip_sa(ip))
+
+    def add_ip6_attr(self, attr_type, ip6: str, scopeid: int):
+        self.add_sa_attr(attr_type, self.sa.ip6_sa(ip6, scopeid))
+
+    def add_link_attr(self, attr_type, ifindex: Optional[int] = 0):
+        self.add_sa_attr(attr_type, self.sa.link_sa(ifindex))
+
+    def print_sa_inet(self, sa: bytes):
+        if len(sa) < 8:
+            raise Exception("IPv4 sa size too small: {}".format(sa))
+        addr = socket.inet_ntop(socket.AF_INET, sa[4:8])
+        return "{}".format(addr)
+
+    def print_sa_inet6(self, sa: bytes):
+        if len(sa) < sizeof(SockaddrIn6):
+            raise Exception("IPv6 sa size too small: {}".format(sa))
+        addr = socket.inet_ntop(socket.AF_INET6, sa[8:24])
+        # sin6_scope_id is in host byte order.
+        scopeid = SockaddrIn6.from_buffer_copy(sa).sin6_scope_id
+        return "{} scopeid {}".format(addr, scopeid)
+
+    def print_sa_link(self, sa: bytes, hd: Optional[bool] = True):
+        if len(sa) < sizeof(SockaddrDl):
+            raise Exception("LINK sa size too small: {}".format(sa))
+        sdl = SockaddrDl.from_buffer_copy(sa)
+        if sdl.sdl_index:
+            ifindex = "link#{} ".format(sdl.sdl_index)
+        else:
+            ifindex = ""
+        if sdl.sdl_nlen:
+            iface_offset = 8
+            if sdl.sdl_nlen + iface_offset > len(sa):
+                raise Exception("LINK sa sdl_nlen {} > total len {}".format(sdl.sdl_nlen, len(sa)))
+            ifname = "ifname:{} ".format(bytes.decode(sa[iface_offset:iface_offset + sdl.sdl_nlen]))
+        else:
+            ifname = ""
+        return "{}{}".format(ifindex, ifname)
+
+    def print_sa_unknown(self, sa: bytes):
+        return "unknown_type:{}".format(sa[1])
+
+    def print_sa(self, sa: bytes, hd: Optional[bool] = False):
+        """Format a sockaddr; append the raw bytes when hd is set."""
+        if len(sa) < 2:
+            raise Exception("sa too short: {} bytes".format(len(sa)))
+        if sa[0] != len(sa):
+            raise Exception("sa size {} != buffer size {}".format(sa[0], len(sa)))
+
+        if sa[1] == socket.AF_INET:
+            text = self.print_sa_inet(sa)
+        elif sa[1] == socket.AF_INET6:
+            text = self.print_sa_inet6(sa)
+        elif sa[1] == socket.AF_LINK:
+            text = self.print_sa_link(sa)
+        else:
+            text = self.print_sa_unknown(sa)
+        if hd:
+            dump = " [{}]".format(sa)
+        else:
+            dump = ""
+        return "{}{}".format(text, dump)
+
+    def print_message(self):
+        """Print a human-readable dump of the header and attributes."""
+        # RTM_GET: Report Metrics: len 272, pid: 87839, seq 1, errno 0, flags:
+        if self._orig_data:
+            rtm_len = len(self._orig_data)
+        else:
+            rtm_len = len(bytes(self))
+        print("{}: len {}, pid: {}, seq {}, errno {}, flags: <{}>".format(
+            RtConst.get_name("RTM_", self.rtm_type),
+            rtm_len,
+            self.rtm_pid,
+            self.rtm_seq,
+            self.rtm_errno,
+            RtConst.get_bitmask_str("RTF_", self.rtm_flags)
+        ))
+        rtm_addrs = sum(list(self._attrs.keys()))
+        print("Addrs: <{}>".format(RtConst.get_bitmask_str("RTA_", rtm_addrs)))
+        for attr in sorted(self._attrs.keys()):
+            sa_data = self.print_sa(self._attrs[attr])
+            print(" {}: {}".format(RtConst.get_name("RTA_", attr), sa_data))
+
+    def verify_sa_inet(self, sa_data):
+        """Check that sa_data is a sane sockaddr_in (no port, zeroed pad)."""
+        if len(sa_data) < sizeof(SockaddrIn):
+            raise Exception("IPv4 sa size too small: {}".format(sa_data))
+        if sa_data[0] > len(sa_data):
+            raise Exception("IPv4 sin_len too big: {} vs sa size {}: {}".format(sa_data[0], len(sa_data), sa_data))
+        sin = SockaddrIn.from_buffer_copy(sa_data)
+        self.assertEqual(sin.sin_port, 0)
+        # Compare raw bytes: the c_char array field reads back cut at NUL.
+        self.assertEqual(bytes(sa_data[8:16]), bytes(8))
+
+    def compare_sa(self, sa_type, sa_data):
+        if len(sa_data) < 4:
+            raise Exception("sa_len for type {} too short: {}".format(
+                RtConst.get_name("RTA_", sa_type), len(sa_data)))
+        our_sa = self._attrs[sa_type]
+        self.assertEqual(len(sa_data), len(our_sa))
+        self.assertEqual(our_sa, sa_data)
+
+    def verify(self, rtm_type: int, rtm_sa):
+        """Check type, errno, spare field and that rtm_sa matches our attrs."""
+        assert self.rtm_type == rtm_type
+        assert self.rtm_errno == 0
+        hdr = RtMsgHdr.from_buffer_copy(self._orig_data)
+        assert hdr._rtm_spare1 == 0
+        for sa_type, sa_data in rtm_sa.items():
+            if sa_type not in self._attrs:
+                raise Exception("SA type {} not present".format(RtConst.get_name("RTA_", sa_type)))
+            self.compare_sa(sa_type, sa_data)
+
+    @classmethod
+    def from_bytes(cls, data: bytes):
+        """Decode a message read from the routing socket."""
+        if len(data) < sizeof(RtMsgHdr):
+            raise Exception("messages size {} is less than expected {}".format(len(data), sizeof(RtMsgHdr)))
+        hdr = RtMsgHdr.from_buffer_copy(data)
+
+        self = cls(hdr.rtm_type)
+        self.rtm_flags = hdr.rtm_flags
+        self.rtm_seq = hdr.rtm_seq
+        self.rtm_errno = hdr.rtm_errno
+        self.rtm_pid = hdr.rtm_pid
+        self.rtm_len = len(data)
+        self._orig_data = data
+
+        # Attributes follow the header in increasing RTA_* bit order.
+        off = sizeof(RtMsgHdr)
+        v = 1
+        addrs_mask = hdr.rtm_addrs
+        while addrs_mask:
+            if addrs_mask & v:
+                addrs_mask -= v
+                if off >= len(data):
+                    raise Exception("SA {} missing: message length {}".format(
+                        RtConst.get_name("RTA_", v), len(data)))
+                sa_len = data[off]
+                if off + sa_len > len(data):
+                    raise Exception("SA sizeof for {} > total message length: {}+{} > {}".format(
+                        RtConst.get_name("RTA_", v), off, sa_len, len(data)))
+                self._attrs[v] = data[off:off + sa_len]
+                # Zero-length sockaddrs still take ALIGN bytes (kernel SA_SIZE()).
+                off += roundup2(sa_len, RtConst.ALIGN) if sa_len else RtConst.ALIGN
+            v *= 2
+        return self
+
+    def __bytes__(self):
+        """Encode the message as it is written to the routing socket."""
+        sz = sizeof(RtMsgHdr)
+        addrs_mask = 0
+        for k, v in self._attrs.items():
+            sz += roundup2(len(v), RtConst.ALIGN)
+            addrs_mask += k
+        hdr = RtMsgHdr(
+            rtm_msglen=sz,
+            rtm_version=RtConst.RTM_VERSION,
+            rtm_type=self.rtm_type,
+            rtm_flags=self.rtm_flags,
+            rtm_seq=self.rtm_seq,
+            rtm_addrs=addrs_mask,
+        )
+        buf = bytearray(sz)
+        buf[0:sizeof(RtMsgHdr)] = hdr
+        off = sizeof(RtMsgHdr)
+        for attr in sorted(self._attrs.keys()):
+            v = self._attrs[attr]
+            buf[off:off + len(v)] = v
+            off += roundup2(len(v), RtConst.ALIGN)
+        return bytes(buf)
+
+
+class Rtsock():
+    """A raw routing socket with helpers to send and receive messages."""
+
+    def __init__(self):
+        # Despite the name, this is a socket.socket object.
+        self.rtsock_fd = self._setup_rtsock()
+        self.rtm_seq = 1
+        self.msgmap = self.build_msgmap()
+
+    def build_msgmap(self):
+        """Map each supported RTM_* type to the class that decodes it."""
+        classes = [RtsockRtMessage]
+        xmap = {}
+        for cls in classes:
+            for message in cls.messages:
+                xmap[message] = cls
+        return xmap
+
+    def get_seq(self):
+        """Return the next message sequence number."""
+        ret = self.rtm_seq
+        self.rtm_seq += 1
+        return ret
+
+    def _setup_rtsock(self) -> socket.socket:
+        s = socket.socket(socket.AF_ROUTE, socket.SOCK_RAW, socket.AF_UNSPEC)
+        # Ask to receive copies of the messages sent on this socket.
+        s.setsockopt(socket.SOL_SOCKET, socket.SO_USELOOPBACK, 1)
+        return s
+
+    def write_message(self, msg):
+        """Print and send msg; write errors are printed, not raised."""
+        print("vvvvvvvv OUT vvvvvvvv")
+        msg.print_message()
+        msg_bytes = bytes(msg)
+        try:
+            os.write(self.rtsock_fd.fileno(), msg_bytes)
+        except OSError as e:
+            print("write({}) -> {}".format(len(msg_bytes), e))
+
+    def parse_message(self, data: bytes):
+        """Decode data, or return None for unsupported message types."""
+        if len(data) < 4:
+            raise Exception("Short read from rtsock: {} bytes".format(len(data)))
+        # rtm_type follows u_short rtm_msglen and u_char rtm_version.
+        rtm_type = data[3]
+        if rtm_type not in self.msgmap:
+            return None
+        return self.msgmap[rtm_type].from_bytes(data)
+
+    def write_data(self, data: bytes):
+        self.rtsock_fd.send(data)
+
+    def read_data(self, seq: Optional[int] = None) -> bytes:
+        """Read one message, skipping ones whose rtm_seq differs from seq."""
+        while True:
+            data = self.rtsock_fd.recv(4096)
+            if seq is None:
+                break
+            if len(data) >= sizeof(RtMsgHdr):
+                hdr = RtMsgHdr.from_buffer_copy(data)
+                if hdr.rtm_seq == seq:
+                    break
+        return data
+
+    def read_message(self):
+        """Read and decode one message (None for unsupported types)."""
+        data = self.read_data()
+        return self.parse_message(data)
Index: tests/sys/net/routing/test_rtsock_multipath.py
===================================================================
--- /dev/null
+++ tests/sys/net/routing/test_rtsock_multipath.py
@@ -0,0 +1,73 @@
+"""ATF tests for routing socket (rtsock) messages."""
+
+from common.atf import atf_main, ATFTestTemplate
+from common.vnet import SingleVnetTestTemplate
+from rtsock import Rtsock, RtsockRtMessage, RtConst, SaHelper
+
+
+class BaseIPv4RoutingTestTemplate(SingleVnetTestTemplate):
+    """Vnet with 192.0.2.1/24 on the first epair and a routing socket."""
+    options = {"require.user": "root"}
+    num_epairs = 1
+    IPV4_PREFIXES = ["192.0.2.1/24"]
+
+    def run(self):
+        super().run()
+        self.rtsock = Rtsock()
+
+
+class BaseIPv6RoutingTestTemplate(SingleVnetTestTemplate):
+    """Vnet with 2001:DB8::1/64 on the first epair and a routing socket."""
+    options = {"require.user": "root"}
+    num_epairs = 1
+    IPV6_PREFIXES = ["2001:DB8::1/64"]
+
+    def run(self):
+        super().run()
+        self.rtsock = Rtsock()
+
+
+class MultipathAdd(ATFTestTemplate):
+    description = "Multipath routing"
+    options = {"require.user": "root"}
+
+    def run(self):
+        self.log("HERE")
+
+    def cleanup(self):
+        self.log("CLEANUP HERE")
+
+
+class RtmGetv4ExactSuccess(BaseIPv4RoutingTestTemplate):
+    description = "Tests RTM_GET with exact prefix lookup on an interface prefix"
+
+    def run(self):
+        super().run()
+        sa = SaHelper()
+        msg = RtsockRtMessage(RtConst.RTM_GET, self.rtsock.get_seq(),
+                              sa.ip_sa("192.0.2.0"), sa.ip_sa("255.255.255.0"))
+        self.rtsock.write_message(msg)
+
+        iface = self.vnet.ifaces[0]
+        desired_sa = {
+            RtConst.RTA_DST: sa.ip_sa("192.0.2.0"),
+            RtConst.RTA_NETMASK: sa.ip_sa("255.255.255.0"),
+            RtConst.RTA_GATEWAY: sa.link_sa(ifindex=iface.ifindex, iftype=iface.iftype),
+        }
+
+        # Skip unrelated messages: wait for the reply to our sequence number.
+        data = self.rtsock.read_data(msg.rtm_seq)
+        msg = RtsockRtMessage.from_bytes(data)
+        print("vvvvvvvv IN vvvvvvvv")
+        msg.print_message()
+        msg.verify(RtConst.RTM_GET, desired_sa)
+
+
+class AddRouteWithRta(RtmGetv4ExactSuccess):
+    # NOTE(review): presumably meant to test RTM_ADD; for now it inherits
+    # the RTM_GET check of RtmGetv4ExactSuccess instead of duplicating it.
+    description = "Placeholder: repeats the RTM_GET exact prefix lookup test"
+
+
+if __name__ == '__main__':
+    atf_main(globals())