Page Menu · Home · FreeBSD

D31084.id91913.diff
No One · Temporary

D31084.id91913.diff

Index: share/mk/atf.test.mk
===================================================================
--- share/mk/atf.test.mk
+++ share/mk/atf.test.mk
@@ -22,6 +22,7 @@
ATF_TESTS_CXX?=
ATF_TESTS_SH?=
ATF_TESTS_KSH93?=
+ATF_TESTS_PY?=
.if !empty(ATF_TESTS_C)
PROGS+= ${ATF_TESTS_C}
@@ -109,3 +110,29 @@
mv ${.TARGET}.tmp ${.TARGET}
.endfor
.endif
+
+.if !empty(ATF_TESTS_PY)
+SCRIPTS+= ${ATF_TESTS_PY}
+_TESTS+= ${ATF_TESTS_PY}
+.for _T in ${ATF_TESTS_PY}
+SCRIPTSDIR_${_T}= ${TESTSDIR}
+TEST_INTERFACE.${_T}= atf
+TEST_METADATA.${_T}+= required_programs="python3"
+CLEANFILES+= ${_T} ${_T}.tmp
+# TODO(jmmv): It seems to me that this SED and SRC functionality should
+# exist in bsd.prog.mk along the support for SCRIPTS. Move it there if
+# this proves to be useful within the tests.
+ATF_TESTS_PY_SED_${_T}?= # empty
+ATF_TESTS_PY_SRC_${_T}?= ${_T}.py
+${_T}: ${ATF_TESTS_PY_SRC_${_T}}
+	echo "#!/usr/bin/env -S PYTHONPATH=/usr/tests/sys python3" > ${.TARGET}.tmp
+.if empty(ATF_TESTS_PY_SED_${_T})
+	cat ${.ALLSRC:N*Makefile*} >>${.TARGET}.tmp
+.else
+	cat ${.ALLSRC:N*Makefile*} \
+	    | sed ${ATF_TESTS_PY_SED_${_T}} >>${.TARGET}.tmp
+.endif
+	chmod +x ${.TARGET}.tmp
+	mv ${.TARGET}.tmp ${.TARGET}
+.endfor
+.endif
Index: tests/sys/Makefile
===================================================================
--- tests/sys/Makefile
+++ tests/sys/Makefile
@@ -31,6 +31,7 @@
TESTS_SUBDIRS+= sys
TESTS_SUBDIRS+= vfs
TESTS_SUBDIRS+= vm
+TESTS_SUBDIRS+= net
.if ${MK_AUDIT} != "no"
_audit= audit
@@ -45,4 +46,7 @@
SUBDIR+= common
+PACKAGE= tests
+${PACKAGE}FILES+= __init__.py
+
.include <bsd.test.mk>
Index: tests/sys/common/Makefile
===================================================================
--- tests/sys/common/Makefile
+++ tests/sys/common/Makefile
@@ -6,6 +6,10 @@
${PACKAGE}FILES+= divert.py
${PACKAGE}FILES+= sender.py
${PACKAGE}FILES+= net_receiver.py
+# Python helpers imported by the Python ATF tests as the "common" package.
+${PACKAGE}FILES+= atf.py
+${PACKAGE}FILES+= vnet.py
+${PACKAGE}FILES+= __init__.py
${PACKAGE}FILESMODE_divert.py=0555
${PACKAGE}FILESMODE_sender.py=0555
Index: tests/sys/common/atf.py
===================================================================
--- /dev/null
+++ tests/sys/common/atf.py
@@ -0,0 +1,155 @@
+#!/usr/local/bin/python3
+
+import argparse
+import unittest
+import inspect
+import sys
+import traceback
+from enum import auto, Enum
+
+from typing import Dict, List, Optional
+
+
class ATFTestResult(Enum):
    """Outcome of one ATF test case.

    The lower-cased member name is the word written to the result file.
    """

    FAILED = 1
    PASSED = 2
    SKIPPED = 3
+
+
+class ATFTestTemplate(unittest.TestCase):
+ description = "Base ATF test"
+ options = {}
+
+ def __init__(self, params: Dict):
+ self._params = params
+
+ def log(self, msg: str):
+ print(msg, file=sys.stderr)
+
+ @classmethod
+ def test_name(cls):
+ return camel_to_lc(cls.__name__)
+
+
def camel_to_lc(name: str):
    """Convert ``CamelCase`` to ``lower_case_with_underscores``.

    An underscore is inserted before every upper-case character except the
    first character, e.g. ``MultipathAdd`` -> ``multipath_add``.
    """
    return "".join(
        "_" + char.lower() if index and char.isupper() else char.lower()
        for index, char in enumerate(name)
    )
+
+
class ATFHandler(object):
    """Drive Python test classes through the ATF test-program interface.

    Lists the discovered test cases (``-l``), or runs one test case body or
    its cleanup routine (test names ending in ``:cleanup``).  The result of a
    body run is written to the file named by ``-r`` (stdout when absent).
    """

    def __init__(self, args, globals_dict):
        """Set up from parsed command-line ``args`` and a module's globals."""
        self._args = args
        self._test_params = ATFHandler._get_params(args.params)
        self._tests = ATFHandler.get_tests(globals_dict)
        if args.stdout:
            self._status_fd = open(args.stdout, "w")
        else:
            self._status_fd = None

    def report_result(self, result: "ATFTestResult", reason: Optional[str] = None):
        """Write ``result`` (plus optional ``reason``) and exit; never returns.

        PASSED and SKIPPED exit with status 0, everything else with 1.
        """
        if result in (ATFTestResult.PASSED, ATFTestResult.SKIPPED):
            exit_code = 0
        else:
            exit_code = 1
        text = result.name.lower()
        if reason:
            text = "{} {}".format(text, reason)
        print(text, file=self._status_fd)
        if self._status_fd is not None:
            # Flush and close the result file explicitly instead of relying on
            # interpreter shutdown after sys.exit().
            self._status_fd.close()
            self._status_fd = None
        sys.exit(exit_code)

    def report_failure(self, reason: str):
        """Report a failed result with ``reason`` and exit."""
        self.report_result(ATFTestResult.FAILED, reason)

    @staticmethod
    def is_test_class(cls):
        """Return True for classes whose name does not end in "Template"."""
        return inspect.isclass(cls) and not cls.__name__.endswith("Template")

    @staticmethod
    def get_tests(globals_dict):
        """Return the non-template ATFTestTemplate subclasses in ``globals_dict``."""
        return [
            cl for cl in globals_dict.values()
            if ATFHandler.is_test_class(cl) and issubclass(cl, ATFTestTemplate)
        ]

    @staticmethod
    def _get_params(params_list: List[str]) -> Dict[str, str]:
        """Turn ``NAME=VALUE`` strings into a dict; entries without '=' are ignored."""
        res = {}
        for param in params_list:
            name, sep, value = param.partition("=")
            if sep:
                res[name] = value
        return res

    def list_test(self, test):
        """Print the ATF metadata block for one test class."""
        print("ident: {}".format(test.test_name()))
        has_cleanup = "true" if hasattr(test, "cleanup") else "false"
        print("has.cleanup: {}".format(has_cleanup))
        print("descr: {}".format(test.description))
        for option_name, option_val in test.options.items():
            print("{}: {}".format(option_name, option_val))

    def list_tests(self):
        """Print the ATF test-program header followed by every test case."""
        print("Content-Type: application/X-atf-tp; version=\"1\"")
        print()
        for test in self._tests:
            self.list_test(test)
            print()

    def get_test_instance(self, test_name: str):
        """Instantiate the test class called ``test_name``; None if unknown."""
        for test in self._tests:
            if test.test_name() == test_name:
                return test(self._test_params)
        return None

    def run_test(self, test_name: str) -> None:
        """Run the body of ``test_name`` or, for ``name:cleanup``, its cleanup.

        A body run always ends the process through report_result().  A
        successful cleanup returns normally without writing a result.
        """
        cleanup_suffix = ":cleanup"
        cleanup = test_name.endswith(cleanup_suffix)
        if cleanup:
            test_name = test_name[:-len(cleanup_suffix)]
        test = None
        try:
            # Constructing the test instance can itself fail.
            test = self.get_test_instance(test_name)
        except Exception as e:
            traceback.print_exc()
            self.report_failure("unable to init test: {}".format(e))
        if test is None:
            self.report_failure("Unknown test case `{}'".format(test_name))
        if cleanup:
            try:
                test.cleanup()
            except Exception as e:
                traceback.print_exc()
                self.report_failure("error: {}".format(e))
            return
        try:
            test.run()
        except Exception as e:
            traceback.print_exc()
            self.report_failure("error: {}".format(e))
        self.report_result(ATFTestResult.PASSED)
+
+
def parser():
    """Parse the ATF test-program command line from ``sys.argv``.

    Destinations: ``-l`` -> ``list`` (flag), ``-s`` -> ``dir``,
    ``-r`` -> ``stdout``, repeated ``-v`` -> ``params`` (list), plus an
    optional positional ``test_name``.
    """
    arg_parser = argparse.ArgumentParser()
    add = arg_parser.add_argument
    add("-l", dest="list", action="store_true")
    add("-s", dest="dir")
    add("-r", dest="stdout")
    add("-v", dest="params", action="append", default=[])
    add("test_name", nargs="?")
    parsed = arg_parser.parse_args()
    return parsed
+
+
def atf_main(globals_dict):
    """Entry point of a Python ATF test program.

    Call as ``atf_main(globals())``; test classes are discovered in
    ``globals_dict``.  ``-l`` lists the test cases, otherwise the named test
    case (or ``name:cleanup``) is run.
    """
    args = parser()
    if not args.list and not args.test_name:
        # Silently returning here would exit 0 and look like success.
        sys.exit("usage: {} [-r resfile] [-s srcdir] [-v var=value ...] "
                 "-l | test_case[:cleanup]".format(sys.argv[0]))
    handler = ATFHandler(args, globals_dict)
    if args.list:
        handler.list_tests()
    else:
        handler.run_test(args.test_name)
Index: tests/sys/common/vnet.py
===================================================================
--- /dev/null
+++ tests/sys/common/vnet.py
@@ -0,0 +1,220 @@
+#!/usr/local/bin/python3
+
+from atf import atf_main, ATFTestTemplate
+import time
+
+from ctypes.util import find_library
+from ctypes import cdll
+
+from typing import Dict, List
+
+import os
+import socket
+
+
def run_cmd(cmd: str) -> str:
    """Log ``cmd``, run it through the shell and return its captured stdout."""
    message = "run: '{}'".format(cmd)
    print(message)
    pipe = os.popen(cmd)
    output = pipe.read()
    return output
+
+
class VnetInterface(object):
    """A network interface created for a test, optionally owned by a vnet."""

    # Created interfaces are recorded in this file (relative to the working
    # directory) so that cleanup_ifaces() can destroy them later.
    INTERFACES_FNAME = "created_interfaces.lst"

    # Interface type codes exposed through ``iftype``.
    IFT_LOOP = 0x18
    IFT_ETHER = 0x06

    def __init__(self, iface_name: str):
        self.name = iface_name
        self.vnet_name = ""
        self.jailed = False
        self.iftype = self.IFT_LOOP if iface_name.startswith("lo") else self.IFT_ETHER

    @property
    def ifindex(self):
        """Interface index as seen by the calling process."""
        return socket.if_nametoindex(self.name)

    def set_vnet(self, vnet_name: str):
        self.vnet_name = vnet_name

    def set_jailed(self, jailed: bool):
        self.jailed = jailed

    def run_cmd(self, cmd):
        """Run ``cmd`` for this interface and return its output.

        When the interface belongs to a vnet and ``jailed`` is False the
        command is prefixed with ``jexec <vnet>``; otherwise it runs as-is
        (presumably because the process is already attached to the jail).
        """
        if self.vnet_name and not self.jailed:
            cmd = "jexec {} {}".format(self.vnet_name, cmd)
        return run_cmd(cmd)

    @staticmethod
    def file_append_line(line):
        """Append ``line`` to INTERFACES_FNAME."""
        with open(VnetInterface.INTERFACES_FNAME, "a") as f:
            f.write(line + "\n")

    @classmethod
    def create_iface(cls, iface_name: str):
        """Create an interface with ``ifconfig <iface_name> create``.

        The name printed by ifconfig is recorded for cleanup; for epairs the
        peer ("b") side is recorded as well.

        Raises:
            Exception: if ifconfig printed no interface name.
        """
        created = run_cmd("/sbin/ifconfig {} create".format(iface_name)).rstrip()
        if not created:
            # Report the requested name; ``created`` is empty here.
            raise Exception("Unable to create iface {}".format(iface_name))
        cls.file_append_line(created)
        if created.startswith("epair"):
            cls.file_append_line(created[:-1] + "b")
        return cls(created)

    @staticmethod
    def cleanup_ifaces():
        """Destroy every interface listed in INTERFACES_FNAME (best effort)."""
        try:
            with open(VnetInterface.INTERFACES_FNAME, "r") as f:
                for line in f:
                    run_cmd("/sbin/ifconfig {} destroy".format(line.strip()))
            os.unlink(VnetInterface.INTERFACES_FNAME)
        except FileNotFoundError:
            # Nothing was created, or cleanup already ran.
            pass
        except Exception as e:
            # Stay best-effort, but leave a trace in the test output.
            print("interface cleanup failed: {}".format(e))

    def setup_addr(self, addr: str):
        """Assign ``addr`` (IPv6 if it contains ':', otherwise IPv4)."""
        family = "inet6" if ":" in addr else "inet"
        cmd = "/sbin/ifconfig {} {} {}".format(self.name, family, addr)
        self.run_cmd(cmd)

    def delete_addr(self, addr: str):
        """Remove ``addr`` from the interface."""
        if ":" in addr:
            cmd = "/sbin/ifconfig {} inet6 {} delete".format(self.name, addr)
        else:
            cmd = "/sbin/ifconfig {} -alias {}".format(self.name, addr)
        self.run_cmd(cmd)

    def turn_up(self):
        cmd = "/sbin/ifconfig {} up".format(self.name)
        self.run_cmd(cmd)

    def enable_ipv6(self):
        """Clear the ND6 "disabled" flag on the interface."""
        cmd = "/usr/sbin/ndp -i {} -disabled".format(self.name)
        self.run_cmd(cmd)
+
+
class VnetInstance(object):
    """A persistent vnet jail plus the interfaces moved into it."""

    # Created jails are recorded in this file (relative to the working
    # directory) so that cleanup_vnets() can remove them.
    JAILS_FNAME = "created_jails.lst"

    def __init__(self, vnet_name: str, jid: int, ifaces: List["VnetInterface"]):
        self.name = vnet_name
        self.jid = jid
        self.ifaces = ifaces
        for iface in ifaces:
            iface.set_vnet(vnet_name)
            iface.set_jailed(True)

    def run_vnet_cmd(self, cmd):
        """Run ``cmd`` inside this jail via ``jexec`` and return its output."""
        # The jail name lives in ``self.name``; this class has no
        # ``vnet_name`` attribute.
        if self.name:
            cmd = "jexec {} {}".format(self.name, cmd)
        return run_cmd(cmd)

    @staticmethod
    def wait_interface(vnet_name: str, iface_name: str):
        """Poll (up to 50 x 100 ms) for ``iface_name`` inside ``vnet_name``.

        Returns True once the interface is listed, False on timeout.
        """
        cmd = "jexec {} /sbin/ifconfig -l".format(vnet_name)
        for _ in range(50):
            ifaces = run_cmd(cmd).strip().split(" ")
            if iface_name in ifaces:
                return True
            time.sleep(0.1)
        return False

    @staticmethod
    def file_append_line(line):
        """Append ``line`` to JAILS_FNAME."""
        with open(VnetInstance.JAILS_FNAME, "a") as f:
            f.write(line + "\n")

    @staticmethod
    def cleanup_vnets():
        """Remove every jail listed in JAILS_FNAME (best effort)."""
        try:
            with open(VnetInstance.JAILS_FNAME) as f:
                for line in f:
                    run_cmd("/usr/sbin/jail -r {}".format(line.strip()))
            os.unlink(VnetInstance.JAILS_FNAME)
        except FileNotFoundError:
            # No jail was created, or cleanup already ran.
            pass
        except Exception as e:
            # Stay best-effort, but leave a trace in the test output.
            print("vnet cleanup failed: {}".format(e))

    @classmethod
    def create_with_interfaces(cls, vnet_name: str, ifaces: List["VnetInterface"]):
        """Create persistent vnet jail ``vnet_name`` owning ``ifaces``.

        Waits until every interface is visible inside the jail.

        Raises:
            Exception: if the jail is not created or an interface never appears.
        """
        iface_cmds = " ".join("vnet.interface={}".format(i.name) for i in ifaces)
        cmd = "/usr/sbin/jail -i -c name={} persist vnet {}".format(vnet_name, iface_cmds)
        jid_str = run_cmd(cmd).strip()
        # ``jail -i`` prints the new jid; empty or non-numeric output is a failure.
        try:
            jid = int(jid_str)
        except ValueError:
            jid = 0
        if jid <= 0:
            raise Exception("Jail creation failed, output: {}".format(jid_str))
        cls.file_append_line(vnet_name)

        for iface in ifaces:
            if not cls.wait_interface(vnet_name, iface.name):
                raise Exception("Interface {} has not appeared in vnet {}".format(iface.name, vnet_name))
        return cls(vnet_name, jid, ifaces)

    @staticmethod
    def attach_jid(jid: int):
        """Attach the calling process to jail ``jid`` using jail_attach(2)."""
        import ctypes  # local import: get_errno() is not imported at module level
        libc = ctypes.CDLL(find_library("c"), use_errno=True)
        if libc.jail_attach(jid) != 0:
            err = ctypes.get_errno()
            raise Exception("jail_attach() failed: errno {} ({})".format(err, os.strerror(err)))

    def attach(self):
        """Attach the calling process to this jail."""
        self.attach_jid(self.jid)
+
+
class SingleVnetTestTemplate(ATFTestTemplate):
    """Base for tests that run inside one freshly created vnet jail.

    ``run()`` creates ``num_epairs`` epair interfaces, moves them into a jail
    named ``jail_<test name>``, attaches the process to it and configures the
    optional per-interface ``IPV6_PREFIXES`` / ``IPV4_PREFIXES`` class
    attributes (by position; falsy entries are skipped).  ``cleanup()``
    removes the recorded jails and interfaces.

    NOTE(review): this module imports ``ATFTestTemplate`` from a top-level
    ``atf`` module while test programs import ``common.atf``; confirm both
    resolve to the same module, otherwise ``issubclass`` checks against
    ``common.atf.ATFTestTemplate`` will not match subclasses of this class.
    """

    num_epairs = 1

    def __init__(self, params: Dict[str, str]):
        super().__init__(params)

    def run(self):
        """Create the epairs and vnet, attach to it and configure addresses."""
        vnet_name = "jail_{}".format(self.test_name())
        ifaces = []
        for _ in range(self.num_epairs):
            ifaces.append(VnetInterface.create_iface("epair"))
        self.vnet = VnetInstance.create_with_interfaces(vnet_name, ifaces)
        self.vnet.attach()
        self._setup_prefixes(getattr(self, "IPV6_PREFIXES", []), ipv6=True)
        self._setup_prefixes(getattr(self, "IPV4_PREFIXES", []), ipv6=False)

    def _setup_prefixes(self, prefixes, ipv6: bool):
        """Bring up vnet interface i and assign ``prefixes[i]`` to it."""
        for i, addr in enumerate(prefixes):
            if not addr:
                continue
            iface = self.vnet.ifaces[i]
            iface.turn_up()
            if ipv6:
                iface.enable_ipv6()
            iface.setup_addr(addr)

    def cleanup(self):
        """Remove the jails and interfaces recorded during ``run()``."""
        print("==== vnet cleanup ===")
        # XXX: sleep 100ms to avoid epair qflush panic
        time.sleep(0.1)
        VnetInstance.cleanup_vnets()
        VnetInterface.cleanup_ifaces()

    def run_cmd(self, cmd: str) -> str:
        """Run ``cmd`` through the shell and return its stdout (no logging)."""
        return os.popen(cmd).read()

    def create_iface(self, iface_name: str):
        """Create an interface, record it for cleanup and return its name.

        NOTE(review): run() uses VnetInterface.create_iface() instead.
        """
        created = self.run_cmd("/sbin/ifconfig {} create".format(iface_name)).rstrip()
        if not created:
            # Report the requested name; ``created`` is empty here.
            raise Exception("Unable to create iface {}".format(iface_name))
        self.file_append_line(created)
        if created.startswith("epair"):
            self.file_append_line(created[:-1] + "b")
        return created

    def destroy_iface(self, iface_name: str):
        self.run_cmd("/sbin/ifconfig {} destroy".format(iface_name))

    def setup_iface_addr(self, addr: str):
        # Not implemented.
        pass

    def file_append_line(self, line):
        """Record ``line`` (an interface name) for ``cleanup()``."""
        # Write to the file VnetInterface.cleanup_ifaces() reads; this class
        # has no IFACES_FNAME attribute of its own.
        with open(VnetInterface.INTERFACES_FNAME, "a") as f:
            f.write(line + "\n")
Index: tests/sys/net/routing/Makefile
===================================================================
--- tests/sys/net/routing/Makefile
+++ tests/sys/net/routing/Makefile
@@ -8,6 +8,10 @@
ATF_TESTS_C += test_rtsock_l3
ATF_TESTS_C += test_rtsock_lladdr
+ATF_TESTS_PY += test_rtsock_multipath
+# Helper module imported by test_rtsock_multipath ("from rtsock import ...").
+${PACKAGE}FILES+= rtsock.py
+
${PACKAGE}FILES+= generic_cleanup.sh
${PACKAGE}FILESMODE_generic_cleanup.sh=0555
Index: tests/sys/net/routing/rtsock.py
===================================================================
--- /dev/null
+++ tests/sys/net/routing/rtsock.py
@@ -0,0 +1,440 @@
+#!/usr/local/bin/python3
+
+from ctypes import *
+import socket
+import os
+import sys
+import unittest
+
+from typing import List, Dict, Optional
+
+
def roundup2(val: int, num: int) -> int:
    """Round ``val`` up to a multiple of ``num`` (``num`` must be a power of 2)."""
    return val if val % num == 0 else (val | (num - 1)) + 1
+
+
class RtConst():
    """Routing-socket constants plus helpers to render them by name."""

    RTM_VERSION = 5

    AF_INET = socket.AF_INET
    AF_INET6 = socket.AF_INET6
    # socket.AF_LINK only exists on BSD-derived systems; 18 is its FreeBSD
    # value.  The fallback keeps this module importable elsewhere.
    AF_LINK = getattr(socket, "AF_LINK", 18)

    RTA_DST = 0x1
    RTA_GATEWAY = 0x2
    RTA_NETMASK = 0x4
    RTA_GENMASK = 0x8
    RTA_IFP = 0x10
    RTA_IFA = 0x20
    RTA_AUTHOR = 0x40
    RTA_BRD = 0x80

    RTM_ADD = 1
    RTM_DELETE = 2
    RTM_CHANGE = 3
    RTM_GET = 4

    RTF_UP = 0x1
    RTF_GATEWAY = 0x2
    RTF_HOST = 0x4
    RTF_REJECT = 0x8
    RTF_DYNAMIC = 0x10
    RTF_MODIFIED = 0x20
    RTF_DONE = 0x40
    RTF_XRESOLVE = 0x200
    RTF_LLINFO = 0x400
    RTF_LLDATA = 0x400
    RTF_STATIC = 0x800
    RTF_BLACKHOLE = 0x1000
    RTF_PROTO2 = 0x4000
    RTF_PROTO1 = 0x8000
    RTF_PROTO3 = 0x40000
    RTF_FIXEDMTU = 0x80000
    RTF_PINNED = 0x100000
    RTF_LOCAL = 0x200000
    RTF_BROADCAST = 0x400000
    RTF_MULTICAST = 0x800000
    RTF_STICKY = 0x10000000
    RTF_RNH_LOCKED = 0x40000000
    RTF_GWFLAG_COMPAT = 0x80000000

    @staticmethod
    def get_props(prefix: str) -> List[str]:
        """Return the names of all constants starting with ``prefix``."""
        return [n for n in dir(RtConst) if n.startswith(prefix)]

    @staticmethod
    def get_name(prefix: str, value: int) -> str:
        """Return the ``prefix`` constant equal to ``value``, or ``"U:<value>"``."""
        for prop in RtConst.get_props(prefix):
            if getattr(RtConst, prop) == value:
                return prop
        return "U:" + str(value)

    @staticmethod
    def get_bitmask_map(prefix: str, value: int) -> Dict[int, str]:
        """Map each set bit of ``value`` to its ``prefix`` constant name.

        Bits without a matching constant map to their hex string.  Negative
        values (a C ``int`` flags field with bit 31 set) are treated as
        32-bit unsigned, which also keeps the loop below finite.
        """
        if value < 0:
            value &= 0xFFFFFFFF
        propmap = {getattr(RtConst, prop): prop for prop in RtConst.get_props(prefix)}
        ret = {}
        bit = 1
        while value:
            if value & bit:
                ret[bit] = propmap.get(bit, hex(bit))
                value &= ~bit
            bit <<= 1
        return ret

    @staticmethod
    def get_bitmask_str(prefix: str, value: int) -> str:
        """Comma-separated names of the bits set in ``value``."""
        bmap = RtConst.get_bitmask_map(prefix, value)
        return ",".join(bmap.values())
+
+
class RtMetrics(Structure):
    """ctypes layout of the ``rtm_rmx`` metrics block in RtMsgHdr.

    Twelve ``c_ulong`` members followed by a two-element ``c_ulong`` filler.
    """

    _fields_ = [
        (member, c_ulong)
        for member in (
            "rmx_locks", "rmx_mtu", "rmx_hopcount", "rmx_expire",
            "rmx_recvpipe", "rmx_sendpipe", "rmx_ssthresh", "rmx_rtt",
            "rmx_rttvar", "rmx_pksent", "rmx_weight", "rmx_nhidx",
        )
    ] + [("rmx_filler", c_ulong * 2)]
+
+
class RtMsgHdr(Structure):
    """ctypes layout of the routing-message header (``struct rt_msghdr``)."""

    _fields_ = [
        ("rtm_msglen", c_ushort),
        # u_char in C: unsigned, so values above 127 never read back negative.
        ("rtm_version", c_ubyte),
        ("rtm_type", c_ubyte),
        ("rtm_index", c_ushort),
        ("_rtm_spare1", c_ushort),
        ("rtm_flags", c_int),
        ("rtm_addrs", c_int),
        ("rtm_pid", c_int),
        ("rtm_seq", c_int),
        ("rtm_errno", c_int),
        ("rtm_fmask", c_int),
        ("rtm_inits", c_ulong),
        ("rtm_rmx", RtMetrics),
    ]
+
+
class SockaddrIn(Structure):
    """ctypes layout of ``struct sockaddr_in`` (16 bytes)."""

    _fields_ = [
        # sin_len / sin_family are u_char in C, hence unsigned here.
        ("sin_len", c_ubyte),
        ("sin_family", c_ubyte),
        ("sin_port", c_ushort),
        # Holds the address bytes in network order (see SaHelper.ip_sa).
        ("sin_addr", c_uint32),
        ("sin_zero", c_char * 8),
    ]
+
+
class SockaddrIn6(Structure):
    """ctypes layout of ``struct sockaddr_in6`` (28 bytes)."""

    _fields_ = [
        # u_char in C, hence unsigned.
        ("sin6_len", c_ubyte),
        ("sin6_family", c_ubyte),
        ("sin6_port", c_ushort),
        ("sin6_flowinfo", c_uint32),
        # Unsigned so each element reads back as the raw address byte.
        ("sin6_addr", c_ubyte * 16),
        ("sin6_scope_id", c_uint32),
    ]
+
+
class SockaddrDl(Structure):
    """ctypes layout of a 16-byte link-level sockaddr (``sockaddr_dl``).

    NOTE(review): ``sdl_data`` is only 8 bytes here, i.e. the short form of
    sockaddr_dl; presumably this matches what rtsock returns for link-level
    gateways — confirm against <net/if_dl.h>.
    """

    _fields_ = [
        # All single-byte header fields are u_char in C; interface types
        # such as 0xd1 must not read back negative.
        ("sdl_len", c_ubyte),
        ("sdl_family", c_ubyte),
        ("sdl_index", c_ushort),
        ("sdl_type", c_ubyte),
        ("sdl_nlen", c_ubyte),
        ("sdl_alen", c_ubyte),
        ("sdl_slen", c_ubyte),
        ("sdl_data", c_byte * 8),
    ]
+
+
class SaHelper(object):
    """Builders for the raw sockaddr byte strings used in rtsock messages."""

    @staticmethod
    def ip_sa(ip: str) -> bytes:
        """Return a ``sockaddr_in`` for dotted-quad ``ip`` with port 0."""
        # Reading the network-order bytes with the host byte order makes
        # ctypes store them back in network order.
        addr_int = int.from_bytes(socket.inet_pton(socket.AF_INET, ip), sys.byteorder)
        sin = SockaddrIn(sizeof(SockaddrIn), socket.AF_INET, 0, addr_int)
        return bytes(sin)

    @staticmethod
    def ip6_sa(ip6: str, scopeid: int = 0) -> bytes:
        """Return a ``sockaddr_in6`` for ``ip6`` carrying ``scopeid``."""
        sin6 = SockaddrIn6(sizeof(SockaddrIn6), socket.AF_INET6, 0, 0)
        sin6.sin6_scope_id = scopeid
        raw = bytearray(bytes(sin6))
        # ctypes rejects a bytes object for a non-char array field, so the
        # 16 address bytes are copied into place directly.
        off = SockaddrIn6.sin6_addr.offset
        raw[off:off + 16] = socket.inet_pton(socket.AF_INET6, ip6)
        return bytes(raw)

    @staticmethod
    def link_sa(ifindex: Optional[int] = 0, iftype: Optional[int] = 0) -> bytes:
        """Return a short ``sockaddr_dl`` carrying only an index and a type."""
        sa = SockaddrDl(sizeof(SockaddrDl), socket.AF_LINK, ifindex or 0, iftype or 0)
        return bytes(sa)
+
+
class BaseRtsockMessage(object):
    """Shared base for routing-socket message classes.

    Keeps the message type and a SaHelper, and borrows the assertion helpers
    of a private ``unittest.TestCase`` instance.
    """

    def __init__(self, rtm_type):
        self.rtm_type = rtm_type
        # A bare TestCase is only used for its assert*() methods.
        self.ut = unittest.TestCase()
        self.sa = SaHelper()

    def assertEqual(self, a, b, msg=None):
        """Delegate to ``unittest.TestCase.assertEqual``."""
        checker = self.ut
        checker.assertEqual(a, b, msg)

    def assertNotEqual(self, a, b, msg=None):
        """Delegate to ``unittest.TestCase.assertNotEqual``."""
        checker = self.ut
        checker.assertNotEqual(a, b, msg)
+
+
class RtsockRtMessage(BaseRtsockMessage):
    """Route message (RTM_ADD / RTM_DELETE / RTM_CHANGE / RTM_GET).

    Sockaddr attributes are kept in ``self._attrs`` as ``{RTA_* bit: raw
    sockaddr bytes}``.  ``bytes(msg)`` builds the wire format and
    ``from_bytes()`` parses a message read from the routing socket.
    """

    messages = [RtConst.RTM_ADD, RtConst.RTM_DELETE, RtConst.RTM_CHANGE, RtConst.RTM_GET]

    # The kernel pads each sockaddr in a routing message to a multiple of
    # sizeof(long) (SA_SIZE() in <net/route.h>), not 4: on LP64 a 28-byte
    # sockaddr_in6 occupies 32 bytes.
    SA_ALIGN = sizeof(c_long)

    def __init__(self, rtm_type, rtm_seq=1, dst_sa=None, mask_sa=None):
        super().__init__(rtm_type)
        self.rtm_flags = 0
        self.rtm_seq = rtm_seq
        self._attrs = {}
        self.rtm_errno = 0
        self.rtm_pid = 0
        self._orig_data = None
        if dst_sa:
            self.add_sa_attr(RtConst.RTA_DST, dst_sa)
        if mask_sa:
            self.add_sa_attr(RtConst.RTA_NETMASK, mask_sa)

    @classmethod
    def _sa_size(cls, sa_len: int) -> int:
        """Return the padded space a sockaddr with ``sa_len`` occupies."""
        # As in SA_SIZE(), a zero-length sockaddr still takes one unit.
        if sa_len == 0:
            return cls.SA_ALIGN
        return ((sa_len - 1) | (cls.SA_ALIGN - 1)) + 1

    def add_sa_attr(self, attr_type, attr_bytes):
        """Set attribute ``attr_type`` (an RTA_* bit) to raw sockaddr bytes."""
        self._attrs[attr_type] = attr_bytes

    def add_ip_attr(self, attr_type, ip: str):
        self.add_sa_attr(attr_type, self.sa.ip_sa(ip))

    def add_ip6_attr(self, attr_type, ip6: str, scopeid: int = 0):
        self.add_sa_attr(attr_type, self.sa.ip6_sa(ip6, scopeid))

    def add_link_attr(self, attr_type, ifindex: Optional[int] = 0):
        self.add_sa_attr(attr_type, self.sa.link_sa(ifindex))

    def print_sa_inet(self, sa: bytes):
        if len(sa) < 8:
            raise Exception("IPv4 sa size too small: {}".format(sa))
        addr = socket.inet_ntop(socket.AF_INET, sa[4:8])
        return "{}".format(addr)

    def print_sa_inet6(self, sa: bytes):
        if len(sa) < sizeof(SockaddrIn6):
            raise Exception("IPv6 sa size too small: {}".format(sa))
        addr = socket.inet_ntop(socket.AF_INET6, sa[8:24])
        # sin6_scope_id is a host-order uint32 at offset 24.
        scopeid = int.from_bytes(sa[24:28], sys.byteorder)
        return "{} scopeid {}".format(addr, scopeid)

    def print_sa_link(self, sa: bytes, hd: Optional[bool] = True):
        if len(sa) < sizeof(SockaddrDl):
            raise Exception("LINK sa size too small: {}".format(sa))
        sdl = SockaddrDl.from_buffer_copy(sa)
        if sdl.sdl_index:
            ifindex = "link#{} ".format(sdl.sdl_index)
        else:
            ifindex = ""
        if sdl.sdl_nlen:
            iface_offset = 8
            if sdl.sdl_nlen + iface_offset > len(sa):
                raise Exception("LINK sa sdl_nlen {} > total len {}".format(sdl.sdl_nlen, len(sa)))
            ifname = "ifname:{} ".format(bytes.decode(sa[iface_offset:iface_offset + sdl.sdl_nlen]))
        else:
            ifname = ""
        return "{}{}".format(ifindex, ifname)

    def print_sa_unknown(self, sa: bytes):
        return "unknown_type:{}".format(sa[1])

    def print_sa(self, sa: bytes, hd: Optional[bool] = False):
        """Render sockaddr ``sa``; append a raw dump when ``hd`` is true."""
        # Check the length before reading sa[0] / sa[1].
        if len(sa) < 2:
            raise Exception("sa too short: {} bytes".format(len(sa)))
        if sa[0] != len(sa):
            raise Exception("sa size {} != buffer size {}".format(sa[0], len(sa)))

        if sa[1] == socket.AF_INET:
            text = self.print_sa_inet(sa)
        elif sa[1] == socket.AF_INET6:
            text = self.print_sa_inet6(sa)
        elif sa[1] == RtConst.AF_LINK:
            text = self.print_sa_link(sa)
        else:
            text = self.print_sa_unknown(sa)
        dump = " [{}]".format(sa) if hd else ""
        return "{}{}".format(text, dump)

    def print_message(self):
        """Print the header summary and every attribute, route(8)-style."""
        # RTM_GET: Report Metrics: len 272, pid: 87839, seq 1, errno 0, flags:<UP,GATEWAY,DONE,STATIC>
        if self._orig_data:
            rtm_len = len(self._orig_data)
        else:
            rtm_len = len(bytes(self))
        print("{}: len {}, pid: {}, seq {}, errno {}, flags: <{}>".format(
            RtConst.get_name("RTM_", self.rtm_type),
            rtm_len,
            self.rtm_pid,
            self.rtm_seq,
            self.rtm_errno,
            RtConst.get_bitmask_str("RTF_", self.rtm_flags)
        ))
        rtm_addrs = sum(self._attrs.keys())
        print("Addrs: <{}>".format(RtConst.get_bitmask_str("RTA_", rtm_addrs)))
        for attr in sorted(self._attrs.keys()):
            sa_data = self.print_sa(self._attrs[attr])
            print("  {}: {}".format(RtConst.get_name("RTA_", attr), sa_data))

    @staticmethod
    def verify_sa_inet(sa_data):
        """Check a raw sockaddr_in: sane length, port 0 and zeroed sin_zero.

        Raises:
            Exception: on the first check that fails.
        """
        if len(sa_data) < sizeof(SockaddrIn):
            raise Exception("IPv4 sa size too small: {}".format(sa_data))
        if sa_data[0] > len(sa_data):
            raise Exception("IPv4 sin_len too big: {} vs sa size {}: {}".format(sa_data[0], len(sa_data), sa_data))
        sin = SockaddrIn.from_buffer_copy(sa_data)
        if sin.sin_port != 0:
            raise Exception("IPv4 sin_port is not 0: {}".format(sin.sin_port))
        # Inspect sin_zero via the raw bytes: the c_char array field yields a
        # NUL-terminated bytes value, never a list of ints.
        if any(sa_data[8:16]):
            raise Exception("IPv4 sin_zero is not zeroed: {}".format(sa_data))

    def compare_sa(self, sa_type, sa_data):
        """Assert that attribute ``sa_type`` equals the raw sockaddr ``sa_data``."""
        if len(sa_data) < 4:
            raise Exception("sa_len for type {} too short: {}".format(
                RtConst.get_name("RTA_", sa_type), len(sa_data)))
        our_sa = self._attrs[sa_type]
        self.assertEqual(len(sa_data), len(our_sa))
        self.assertEqual(our_sa, sa_data)

    def verify(self, rtm_type: int, rtm_sa):
        """Check a parsed message's type, errno, header and attributes."""
        if self._orig_data is None:
            raise Exception("verify() needs a message created by from_bytes()")
        assert self.rtm_type == rtm_type
        assert self.rtm_errno == 0
        hdr = RtMsgHdr.from_buffer_copy(self._orig_data)
        assert hdr._rtm_spare1 == 0
        for sa_type, sa_data in rtm_sa.items():
            if sa_type not in self._attrs:
                raise Exception("SA type {} not present".format(RtConst.get_name("RTA_", sa_type)))
            self.compare_sa(sa_type, sa_data)

    @classmethod
    def from_bytes(cls, data: bytes):
        """Parse a routing message (header plus sockaddrs) from ``data``."""
        if len(data) < sizeof(RtMsgHdr):
            raise Exception("messages size {} is less than expected {}".format(len(data), sizeof(RtMsgHdr)))
        hdr = RtMsgHdr.from_buffer_copy(data)

        msg = cls(hdr.rtm_type)
        msg.rtm_flags = hdr.rtm_flags
        msg.rtm_seq = hdr.rtm_seq
        msg.rtm_errno = hdr.rtm_errno
        msg.rtm_pid = hdr.rtm_pid
        msg.rtm_len = len(data)
        msg._orig_data = data

        off = sizeof(RtMsgHdr)
        bit = 1
        # rtm_addrs is a C int; mask to 32 bits so the loop always ends.
        addrs_mask = hdr.rtm_addrs & 0xFFFFFFFF
        while addrs_mask:
            if addrs_mask & bit:
                addrs_mask &= ~bit
                if off >= len(data):
                    raise Exception("SA {} starts past the end of the message: {} >= {}".format(
                        RtConst.get_name("RTA_", bit), off, len(data)))
                sa_len = data[off]
                if off + sa_len > len(data):
                    raise Exception("SA sizeof for {} > total message length: {}+{} > {}".format(
                        RtConst.get_name("RTA_", bit), off, sa_len, len(data)))
                msg._attrs[bit] = data[off:off + sa_len]
                off += cls._sa_size(sa_len)
            bit <<= 1
        return msg

    def __bytes__(self):
        """Serialise header and attributes (in RTA_* bit order)."""
        hdr_len = sizeof(RtMsgHdr)
        sz = hdr_len
        addrs_mask = 0
        for attr_type, sa in self._attrs.items():
            sz += self._sa_size(len(sa))
            addrs_mask |= attr_type
        hdr = RtMsgHdr(
            rtm_msglen=sz,
            rtm_version=RtConst.RTM_VERSION,
            rtm_type=self.rtm_type,
            rtm_flags=self.rtm_flags,
            rtm_seq=self.rtm_seq,
            rtm_addrs=addrs_mask,
        )
        buf = bytearray(sz)
        buf[0:hdr_len] = bytes(hdr)
        off = hdr_len
        for attr_type in sorted(self._attrs.keys()):
            sa = self._attrs[attr_type]
            buf[off:off + len(sa)] = sa
            off += self._sa_size(len(sa))
        return bytes(buf)
+
+
class Rtsock():
    """Thin wrapper around a raw PF_ROUTE socket."""

    def __init__(self):
        self.rtsock_fd = self._setup_rtsock()
        self.rtm_seq = 1
        self.msgmap = self.build_msgmap()

    def build_msgmap(self):
        """Map every supported RTM_* type to the class that parses it."""
        return {message: cls for cls in (RtsockRtMessage,) for message in cls.messages}

    def get_seq(self):
        """Return the next sequence number."""
        seq = self.rtm_seq
        self.rtm_seq += 1
        return seq

    def _setup_rtsock(self) -> socket.socket:
        s = socket.socket(socket.AF_ROUTE, socket.SOCK_RAW, socket.AF_UNSPEC)
        # Keep receiving copies of messages written on this socket (route(4));
        # read_data() relies on reading the replies to our own requests.
        s.setsockopt(socket.SOL_SOCKET, socket.SO_USELOOPBACK, 1)
        return s

    def write_message(self, msg):
        """Print ``msg`` and write it to the routing socket.

        Write errors are printed rather than raised.
        """
        print("vvvvvvvv  OUT vvvvvvvv")
        msg.print_message()
        msg_bytes = bytes(msg)
        try:
            os.write(self.rtsock_fd.fileno(), msg_bytes)
        except OSError as e:
            print("write({}) -> {}".format(len(msg_bytes), e))

    def parse_message(self, data: bytes):
        """Parse ``data`` into a message object; None for unsupported types."""
        if len(data) < 4:
            raise Exception("Short read from rtsock: {} bytes".format(len(data)))
        # rtm_type follows rtm_msglen (u_short) and rtm_version (u_char).
        rtm_type = data[3]
        msg_class = self.msgmap.get(rtm_type)
        if msg_class is None:
            return None
        return msg_class.from_bytes(data)

    def write_data(self, data: bytes):
        self.rtsock_fd.send(data)

    def read_data(self, seq: Optional[int] = None) -> bytes:
        """Read one message, or skip messages until one carries ``seq``."""
        while True:
            data = self.rtsock_fd.recv(4096)
            if seq is None:
                break
            if len(data) >= sizeof(RtMsgHdr):
                hdr = RtMsgHdr.from_buffer_copy(data)
                if hdr.rtm_seq == seq:
                    break
        return data

    def read_message(self):
        """Read one message and parse it (None for unsupported types)."""
        data = self.read_data()
        return self.parse_message(data)
Index: tests/sys/net/routing/test_rtsock_multipath.py
===================================================================
--- /dev/null
+++ tests/sys/net/routing/test_rtsock_multipath.py
@@ -0,0 +1,85 @@
+
+from common.atf import atf_main, ATFTestTemplate
+from common.vnet import SingleVnetTestTemplate, VnetInstance
+from rtsock import Rtsock, RtsockRtMessage, RtConst, SaHelper
+
+
class BaseIPv4RoutingTestTemplate(SingleVnetTestTemplate):
    """Vnet test base: one epair with 192.0.2.1/24 plus a routing socket."""

    num_epairs = 1
    IPV4_PREFIXES = ["192.0.2.1/24"]
    options = {"require.user": "root"}

    def run(self):
        super().run()
        # Opened only after the vnet has been set up by the parent run().
        self.rtsock = Rtsock()
+
+
class BaseIPv6RoutingTestTemplate(SingleVnetTestTemplate):
    """Vnet test base: one epair with 2001:DB8::1/64 plus a routing socket."""

    num_epairs = 1
    IPV6_PREFIXES = ["2001:DB8::1/64"]
    options = {"require.user": "root"}

    def run(self):
        super().run()
        # Opened only after the vnet has been set up by the parent run().
        self.rtsock = Rtsock()
+
+
class MultipathAdd(ATFTestTemplate):
    """Placeholder multipath test: its body and cleanup only log a line."""

    description = "Multipath routing"
    options = {"require.user": "root"}

    def run(self):
        self.log("HERE")

    # A regular method: the ATF handler calls ``test.cleanup()`` on an
    # instance, which a @staticmethod declaring ``self`` cannot satisfy.
    def cleanup(self):
        self.log("CLEANUP HERE")
+
+
class RtmGetv4ExactSuccess(BaseIPv4RoutingTestTemplate):
    """RTM_GET for exactly 192.0.2.0/24 must return the interface route."""

    description = "Tests RTM_GET with exact prefix lookup on an interface prefix"

    def run(self):
        super().run()
        helper = SaHelper()
        dst = helper.ip_sa("192.0.2.0")
        netmask = helper.ip_sa("255.255.255.0")
        request = RtsockRtMessage(RtConst.RTM_GET, self.rtsock.get_seq(), dst, netmask)
        self.rtsock.write_message(request)

        iface = self.vnet.ifaces[0]
        gateway = helper.link_sa(ifindex=iface.ifindex, iftype=iface.iftype)
        expected = {
            RtConst.RTA_DST: dst,
            RtConst.RTA_NETMASK: netmask,
            RtConst.RTA_GATEWAY: gateway,
        }

        reply = RtsockRtMessage.from_bytes(self.rtsock.read_data(request.rtm_seq))
        print("vvvvvvvv  IN vvvvvvvv")
        reply.print_message()
        reply.verify(RtConst.RTM_GET, expected)
+
+
class AddRouteWithRta(BaseIPv4RoutingTestTemplate):
    """Placeholder; currently repeats the RTM_GET exact-prefix lookup."""

    # Without its own description the listing showed the inherited
    # "Base ATF test".
    description = ("RTM_GET exact-prefix lookup on an interface prefix "
                   "(placeholder for an RTM_ADD with RTA attributes test)")

    def run(self):
        # TODO(review): the name suggests an RTM_ADD test, but the body only
        # duplicates RtmGetv4ExactSuccess.
        super().run()
        sa = SaHelper()
        msg = RtsockRtMessage(RtConst.RTM_GET, self.rtsock.get_seq(), sa.ip_sa("192.0.2.0"), sa.ip_sa("255.255.255.0"))
        self.rtsock.write_message(msg)

        iface = self.vnet.ifaces[0]
        desired_sa = {
            RtConst.RTA_DST: sa.ip_sa("192.0.2.0"),
            RtConst.RTA_NETMASK: sa.ip_sa("255.255.255.0"),
            RtConst.RTA_GATEWAY: sa.link_sa(ifindex=iface.ifindex, iftype=iface.iftype),
        }

        data = self.rtsock.read_data(msg.rtm_seq)
        msg = RtsockRtMessage.from_bytes(data)
        print("vvvvvvvv  IN vvvvvvvv")
        msg.print_message()
        msg.verify(RtConst.RTM_GET, desired_sa)
+
+
if __name__ == "__main__":
    # Act as an ATF test program; test classes are discovered in globals().
    atf_main(globals())

File Metadata

Mime Type
text/plain
Expires
Tue, May 26, 5:37 AM (14 h, 40 m)
Storage Engine
blob
Storage Format
Raw Data
Storage Handle
33525104
Default Alt Text
D31084.id91913.diff (31 KB)

Event Timeline