Page Menu
Home
FreeBSD
Search
Configure Global Search
Log In
Files
F157790410
D31084.id91913.diff
No One
Temporary
Actions
View File
Edit File
Delete File
View Transforms
Subscribe
Mute Notifications
Flag For Later
Award Token
Size
31 KB
Referenced Files
None
Subscribers
None
D31084.id91913.diff
View Options
Index: share/mk/atf.test.mk
===================================================================
--- share/mk/atf.test.mk
+++ share/mk/atf.test.mk
@@ -22,6 +22,7 @@
ATF_TESTS_CXX?=
ATF_TESTS_SH?=
ATF_TESTS_KSH93?=
+ATF_TESTS_PY?=
.if !empty(ATF_TESTS_C)
PROGS+= ${ATF_TESTS_C}
@@ -109,3 +110,29 @@
mv ${.TARGET}.tmp ${.TARGET}
.endfor
.endif
+
+.if !empty(ATF_TESTS_PY)
+SCRIPTS+= ${ATF_TESTS_PY}
+_TESTS+= ${ATF_TESTS_PY}
+.for _T in ${ATF_TESTS_PY}
+SCRIPTSDIR_${_T}= ${TESTSDIR}
+TEST_INTERFACE.${_T}= atf
+TEST_METADATA.${_T}+= required_programs="python3"
+CLEANFILES+= ${_T} ${_T}.tmp
+# TODO(jmmv): It seems to me that this SED and SRC functionality should
+# exist in bsd.prog.mk along the support for SCRIPTS. Move it there if
+# this proves to be useful within the tests.
+ATF_TESTS_PY_SED_${_T}?= # empty
+ATF_TESTS_PY_SRC_${_T}?= ${_T}.py
+# The installed test is its .py source with a python3 shebang prepended.
+# PYTHONPATH points at the installed tests/sys tree (derived from TESTSBASE
+# rather than a hard-coded /usr/tests) so tests can import the shared helpers
+# installed from tests/sys/common as the `common` package.
+${_T}: ${ATF_TESTS_PY_SRC_${_T}}
+	echo "#!/usr/bin/env -S PYTHONPATH=${TESTSBASE}/sys python3" > ${.TARGET}.tmp
+.if empty(ATF_TESTS_PY_SED_${_T})
+	cat ${.ALLSRC:N*Makefile*} >>${.TARGET}.tmp
+.else
+	cat ${.ALLSRC:N*Makefile*} \
+	    | sed ${ATF_TESTS_PY_SED_${_T}} >>${.TARGET}.tmp
+.endif
+	chmod +x ${.TARGET}.tmp
+	mv ${.TARGET}.tmp ${.TARGET}
+.endfor
+.endif
Index: tests/sys/Makefile
===================================================================
--- tests/sys/Makefile
+++ tests/sys/Makefile
@@ -31,6 +31,7 @@
TESTS_SUBDIRS+= sys
TESTS_SUBDIRS+= vfs
TESTS_SUBDIRS+= vm
+# Append rather than assign: `=` would discard the subdirectories listed
+# above. NOTE(review): check that `net` is not already listed earlier.
+TESTS_SUBDIRS+= net
.if ${MK_AUDIT} != "no"
_audit= audit
@@ -45,4 +46,7 @@
SUBDIR+= common
+# Ship __init__.py with the tests in this directory so python tests can
+# import the shared helpers as a package (e.g. `from common.atf import ...`).
+PACKAGE= tests
+${PACKAGE}FILES+= __init__.py
+
.include <bsd.test.mk>
Index: tests/sys/common/Makefile
===================================================================
--- tests/sys/common/Makefile
+++ tests/sys/common/Makefile
@@ -6,6 +6,9 @@
${PACKAGE}FILES+= divert.py
${PACKAGE}FILES+= sender.py
${PACKAGE}FILES+= net_receiver.py
+# Shared python ATF support, imported by tests as common.atf / common.vnet;
+# __init__.py makes this directory an importable package.
+${PACKAGE}FILES+= atf.py
+${PACKAGE}FILES+= vnet.py
+${PACKAGE}FILES+= __init__.py
${PACKAGE}FILESMODE_divert.py=0555
${PACKAGE}FILESMODE_sender.py=0555
Index: tests/sys/common/atf.py
===================================================================
--- /dev/null
+++ tests/sys/common/atf.py
@@ -0,0 +1,155 @@
+#!/usr/local/bin/python3
+
+import argparse
+import unittest
+import inspect
+import sys
+import traceback
+from enum import auto, Enum
+
+from typing import Dict, List, Optional
+
+
+class ATFTestResult(Enum):
+    """Final status of an ATF test case.
+
+    The lowercased member name is the status word written to the results
+    output by ATFHandler.report_result().
+    """
+
+    FAILED = 1
+    PASSED = 2
+    SKIPPED = 3
+
+
+class ATFTestTemplate(unittest.TestCase):
+    """Base class for python ATF test cases.
+
+    Subclasses implement run() and optionally cleanup(). The ATF test case
+    name is derived from the class name via camel_to_lc(). Classes whose
+    name ends in "Template" are treated as abstract and never listed.
+    """
+
+    description = "Base ATF test"
+    # Extra ATF metadata printed as "<key>: <value>" when listing tests,
+    # e.g. {"require.user": "root"}.
+    options = {}
+
+    def __init__(self, params: Dict):
+        # Initialise unittest.TestCase so the inherited assert*() helpers
+        # work; they depend on state created in TestCase.__init__. With the
+        # default "runTest" method name no such method needs to exist.
+        super().__init__()
+        # Test configuration variables from "-v name=value" arguments.
+        self._params = params
+
+    def log(self, msg: str):
+        """Write *msg* to stderr."""
+        print(msg, file=sys.stderr)
+
+    @classmethod
+    def test_name(cls):
+        """Return the ATF test case name (class name in lower_snake_case)."""
+        return camel_to_lc(cls.__name__)
+
+
+def camel_to_lc(name: str):
+    """Convert a CamelCase identifier to lower_snake_case.
+
+    Every uppercase character except the first one is preceded by "_",
+    e.g. "RtmGetv4ExactSuccess" -> "rtm_getv4_exact_success".
+    """
+    return "".join(
+        "_" + ch.lower() if pos and ch.isupper() else ch.lower()
+        for pos, ch in enumerate(name)
+    )
+
+
+class ATFHandler(object):
+    """Implement the ATF test-program interface for ATFTestTemplate classes.
+
+    Supports listing test cases (-l), running a test case body and running
+    its cleanup routine ("<name>:cleanup").
+    """
+
+    def __init__(self, args, globals_dict):
+        self._args = args
+        self._test_params = ATFHandler._get_params(args.params)
+        self._tests = ATFHandler.get_tests(globals_dict)
+        # "-r FILE" names the file receiving the final status line; without
+        # it the status goes to stdout (print(file=None)).
+        if args.stdout:
+            self._status_fd = open(args.stdout, "w")
+        else:
+            self._status_fd = None
+
+    def report_result(self, result: ATFTestResult, reason: Optional[str]=None):
+        """Write the status line ("<status>" or "<status>: <reason>") and exit.
+
+        PASSED and SKIPPED exit with status 0, FAILED with status 1.
+        """
+        exit_code = 1
+        if result in {ATFTestResult.PASSED, ATFTestResult.SKIPPED}:
+            exit_code = 0
+        text = result.name.lower()
+        if reason:
+            # ATF separates status and reason with ": "; keep the status on
+            # a single line since the results output is line-oriented.
+            text = "{}: {}".format(text, str(reason).replace("\n", " "))
+        print(text, file=self._status_fd)
+        if self._status_fd is not None:
+            # Flush the status to disk explicitly before the process exits.
+            self._status_fd.close()
+            self._status_fd = None
+        sys.exit(exit_code)
+
+    def report_failure(self, reason: str):
+        self.report_result(ATFTestResult.FAILED, reason)
+
+    @staticmethod
+    def is_test_class(cls):
+        """True for classes not named "*Template" (those are abstract bases)."""
+        return inspect.isclass(cls) and not cls.__name__.endswith("Template")
+
+    @staticmethod
+    def get_tests(globals_dict):
+        """Return the concrete ATFTestTemplate subclasses in *globals_dict*."""
+        return [
+            cl for cl in globals_dict.values()
+            if ATFHandler.is_test_class(cl) and issubclass(cl, ATFTestTemplate)
+        ]
+
+    @staticmethod
+    def _get_params(params_list: List[str]) -> Dict[str, str]:
+        """Turn ["name=value", ...] into a dict; entries without "=" are ignored."""
+        res = {}
+        for param in params_list:
+            name, sep, value = param.partition("=")
+            if sep:
+                res[name] = value
+        return res
+
+    def list_test(self, test):
+        """Print the ATF metadata block of a single test class."""
+        print("ident: {}".format(test.test_name()))
+        has_cleanup = "true" if hasattr(test, "cleanup") else "false"
+        print("has.cleanup: {}".format(has_cleanup))
+        print("descr: {}".format(test.description))
+        for option_name, option_val in test.options.items():
+            print("{}: {}".format(option_name, option_val))
+
+    def list_tests(self):
+        """Print the ATF test-program listing for all discovered tests."""
+        print("Content-Type: application/X-atf-tp; version=\"1\"")
+        print()
+        for test in self._tests:
+            self.list_test(test)
+            print()
+
+    def get_test_instance(self, test_name: str):
+        """Instantiate the test whose ATF name is *test_name*, or return None."""
+        for test in self._tests:
+            if test.test_name() == test_name:
+                return test(self._test_params)
+        return None
+
+    def run_test(self, test_name: str) -> None:
+        """Run a test body, or its cleanup when *test_name* ends in ":cleanup".
+
+        The body path always ends the process through report_result(). The
+        cleanup path only exits explicitly (as FAILED) when cleanup() raises.
+        """
+        cleanup = False
+        if test_name.endswith(":cleanup"):
+            test_name = test_name[:-len(":cleanup")]
+            cleanup = True
+        try:
+            test = self.get_test_instance(test_name)
+        except Exception as e:
+            traceback.print_exc()
+            self.report_failure("unable to init test: {}".format(e))
+        if test is None:
+            self.report_failure("Unknown test case `{}'".format(test_name))
+
+        if cleanup:
+            try:
+                test.cleanup()
+            except Exception as e:
+                traceback.print_exc()
+                self.report_failure("error: {}".format(e))
+            return
+
+        try:
+            test.run()
+        except unittest.SkipTest as e:
+            # Raised by TestCase.skipTest(); map it to an ATF skip.
+            self.report_result(ATFTestResult.SKIPPED, str(e))
+        except Exception as e:
+            traceback.print_exc()
+            self.report_failure("error: {}".format(e))
+        self.report_result(ATFTestResult.PASSED)
+
+
+def parser():
+    """Parse the ATF test-program command line from sys.argv.
+
+    -l lists test cases; -s DIR is stored as ``dir``; -r FILE names the
+    results file (stored as ``stdout``); -v name=value may repeat and is
+    collected in ``params``; an optional positional test case name follows.
+    """
+    ap = argparse.ArgumentParser()
+    ap.add_argument("-l", dest="list", action="store_true")
+    ap.add_argument("-s", dest="dir")
+    ap.add_argument("-r", dest="stdout")
+    ap.add_argument("-v", dest="params", action="append", default=[])
+    ap.add_argument("test_name", nargs="?")
+    return ap.parse_args()
+
+
+def atf_main(globals_dict):
+    """Entry point for python ATF test programs.
+
+    Collects test classes from *globals_dict* (normally the caller's
+    globals()) and either lists them (-l) or runs the named test case.
+    Exits with an error when neither is requested.
+    """
+    args = parser()
+    if not args.list and not args.test_name:
+        # Nothing to do: fail loudly instead of exiting 0 silently.
+        sys.exit("usage: pass -l to list test cases or a test case name to run")
+    handler = ATFHandler(args, globals_dict)
+    if args.list:
+        handler.list_tests()
+    else:
+        handler.run_test(args.test_name)
Index: tests/sys/common/vnet.py
===================================================================
--- /dev/null
+++ tests/sys/common/vnet.py
@@ -0,0 +1,220 @@
+#!/usr/local/bin/python3
+
+import os
+import socket
+import time
+
+from ctypes import cdll
+from ctypes.util import find_library
+from typing import Dict, List
+
+try:
+    # Imported as part of the tests/sys "common" package (common.vnet):
+    # share the same ATFTestTemplate class as the test modules.
+    from .atf import atf_main, ATFTestTemplate
+except ImportError:
+    # Loaded as a top-level module with common/ itself on sys.path.
+    from atf import atf_main, ATFTestTemplate
+
+
+def run_cmd(cmd: str) -> str:
+    """Echo *cmd*, run it through the shell and return its stdout.
+
+    The command's exit status is not checked.
+    """
+    print("run: '{}'".format(cmd))
+    # Close the pipe (and reap the child) deterministically instead of
+    # leaving it to the garbage collector.
+    with os.popen(cmd) as pipe:
+        return pipe.read()
+
+
+class VnetInterface(object):
+    """A network interface used by vnet tests, optionally assigned to a vnet.
+
+    Interfaces created via create_iface() are recorded in INTERFACES_FNAME
+    (relative to the current directory) so cleanup_ifaces() can destroy them.
+    """
+
+    INTERFACES_FNAME = "created_interfaces.lst"
+
+    # Interface type codes (presumably <net/if_types.h> IFT_* values).
+    IFT_LOOP = 0x18
+    IFT_ETHER = 0x06
+
+    def __init__(self, iface_name: str):
+        self.name = iface_name
+        self.vnet_name = ""
+        self.jailed = False
+        # Only loopback vs. everything else (treated as Ethernet).
+        if iface_name.startswith("lo"):
+            self.iftype = self.IFT_LOOP
+        else:
+            self.iftype = self.IFT_ETHER
+
+    @property
+    def ifindex(self):
+        """Interface index, looked up by name via socket.if_nametoindex()."""
+        return socket.if_nametoindex(self.name)
+
+    def set_vnet(self, vnet_name: str):
+        self.vnet_name = vnet_name
+
+    def set_jailed(self, jailed: bool):
+        self.jailed = jailed
+
+    def run_cmd(self, cmd):
+        """Run *cmd* for this interface and return its output.
+
+        If a vnet is set but the interface is not flagged as jailed, the
+        command is wrapped in "jexec <vnet>".
+        """
+        if self.vnet_name and not self.jailed:
+            cmd = "jexec {} {}".format(self.vnet_name, cmd)
+        return run_cmd(cmd)
+
+    @staticmethod
+    def file_append_line(line):
+        with open(VnetInterface.INTERFACES_FNAME, "a") as f:
+            f.write(line + "\n")
+
+    @classmethod
+    def create_iface(cls, iface_name: str):
+        """Clone an interface of type *iface_name* (e.g. "epair").
+
+        Records the created interface(s) for cleanup and returns a
+        VnetInterface for the name printed by ifconfig.
+
+        Raises:
+            Exception: if ifconfig printed no interface name.
+        """
+        name = run_cmd("/sbin/ifconfig {} create".format(iface_name)).rstrip()
+        if not name:
+            # Report what was requested; the (empty) output says nothing.
+            raise Exception("Unable to create iface {}".format(iface_name))
+        cls.file_append_line(name)
+        if name.startswith("epair"):
+            # The printed name is the "a" end; record its "b" peer as well.
+            cls.file_append_line(name[:-1] + "b")
+        return cls(name)
+
+    @staticmethod
+    def cleanup_ifaces():
+        """Destroy all recorded interfaces and remove the record file.
+
+        Best effort: a missing or unreadable record file is ignored.
+        """
+        try:
+            with open(VnetInterface.INTERFACES_FNAME, "r") as f:
+                for line in f:
+                    run_cmd("/sbin/ifconfig {} destroy".format(line.strip()))
+            os.unlink(VnetInterface.INTERFACES_FNAME)
+        except OSError:
+            pass
+
+    def setup_addr(self, addr: str):
+        """Add *addr*; it is treated as IPv6 when it contains ':'."""
+        family = "inet6" if ":" in addr else "inet"
+        cmd = "/sbin/ifconfig {} {} {}".format(self.name, family, addr)
+        self.run_cmd(cmd)
+
+    def delete_addr(self, addr: str):
+        """Remove *addr*; it is treated as IPv6 when it contains ':'."""
+        if ":" in addr:
+            cmd = "/sbin/ifconfig {} inet6 {} delete".format(self.name, addr)
+        else:
+            cmd = "/sbin/ifconfig {} -alias {}".format(self.name, addr)
+        self.run_cmd(cmd)
+
+    def turn_up(self):
+        cmd = "/sbin/ifconfig {} up".format(self.name)
+        self.run_cmd(cmd)
+
+    def enable_ipv6(self):
+        # Clear the ND6 "disabled" flag on the interface.
+        cmd = "/usr/sbin/ndp -i {} -disabled".format(self.name)
+        self.run_cmd(cmd)
+
+
+class VnetInstance(object):
+    """A persistent vnet jail owning a set of VnetInterfaces.
+
+    Jails created via create_with_interfaces() are recorded in JAILS_FNAME
+    (relative to the current directory) so cleanup_vnets() can remove them.
+    """
+
+    JAILS_FNAME = "created_jails.lst"
+
+    def __init__(self, vnet_name: str, jid: int, ifaces: List[VnetInterface]):
+        self.name = vnet_name
+        self.jid = jid
+        self.ifaces = ifaces
+        for iface in ifaces:
+            iface.set_vnet(vnet_name)
+            iface.set_jailed(True)
+
+    def run_vnet_cmd(self, cmd):
+        """Run *cmd* inside this vnet via jexec and return its output."""
+        if self.name:
+            cmd = "jexec {} {}".format(self.name, cmd)
+        return run_cmd(cmd)
+
+    @staticmethod
+    def wait_interface(vnet_name: str, iface_name: str):
+        """Poll (50 x 100ms) until *iface_name* is listed inside *vnet_name*."""
+        cmd = "jexec {} /sbin/ifconfig -l".format(vnet_name)
+        for _ in range(50):
+            ifaces = run_cmd(cmd).strip().split(" ")
+            if iface_name in ifaces:
+                return True
+            time.sleep(0.1)
+        return False
+
+    @staticmethod
+    def file_append_line(line):
+        with open(VnetInstance.JAILS_FNAME, "a") as f:
+            f.write(line + "\n")
+
+    @staticmethod
+    def cleanup_vnets():
+        """Remove all recorded jails and the record file.
+
+        Best effort: a missing or unreadable record file is ignored.
+        """
+        try:
+            with open(VnetInstance.JAILS_FNAME) as f:
+                for line in f:
+                    run_cmd("/usr/sbin/jail -r {}".format(line.strip()))
+            os.unlink(VnetInstance.JAILS_FNAME)
+        except OSError:
+            pass
+
+    @classmethod
+    def create_with_interfaces(cls, vnet_name: str, ifaces: List[VnetInterface]):
+        """Create persistent vnet jail *vnet_name* and move *ifaces* into it.
+
+        Raises:
+            Exception: if the jail is not created or an interface does not
+                appear inside it in time.
+        """
+        iface_cmds = " ".join(["vnet.interface={}".format(i.name) for i in ifaces])
+        cmd = "/usr/sbin/jail -i -c name={} persist vnet {}".format(vnet_name, iface_cmds)
+        jid_str = run_cmd(cmd)
+        # "jail -i" prints the new jid; anything unparsable (e.g. empty
+        # output on failure) means the jail was not created.
+        try:
+            jid = int(jid_str)
+        except ValueError:
+            jid = 0
+        if jid <= 0:
+            raise Exception("Jail creation failed, output: {!r}".format(jid_str))
+        cls.file_append_line(vnet_name)
+
+        for iface in ifaces:
+            if not cls.wait_interface(vnet_name, iface.name):
+                raise Exception("Interface {} has not appeared in vnet {}".format(iface.name, vnet_name))
+        return cls(vnet_name, jid, ifaces)
+
+    @staticmethod
+    def attach_jid(jid: int):
+        """Move the calling process into jail *jid* via jail_attach().
+
+        Raises:
+            Exception: if jail_attach() fails; the errno is included.
+        """
+        import ctypes
+
+        # use_errno=True lets ctypes capture errno for get_errno().
+        libc = ctypes.CDLL(find_library("c"), use_errno=True)
+        if libc.jail_attach(jid) != 0:
+            err = ctypes.get_errno()
+            raise Exception("jail_attach() failed: errno {} ({})".format(err, os.strerror(err)))
+
+    def attach(self):
+        self.attach_jid(self.jid)
+
+
+class SingleVnetTestTemplate(ATFTestTemplate):
+    """Base for tests that run inside one freshly created vnet jail.
+
+    run() creates ``num_epairs`` epairs and moves them into a vnet jail named
+    "jail_<test name>", then attaches this process to the jail. It then
+    configures the optional IPV6_PREFIXES / IPV4_PREFIXES class attributes;
+    entry i applies to the i-th interface and falsy entries are skipped.
+    """
+
+    num_epairs = 1
+    # Record file used by create_iface()/file_append_line(). It is shared with
+    # VnetInterface so that cleanup() also destroys interfaces created here.
+    IFACES_FNAME = VnetInterface.INTERFACES_FNAME
+
+    def __init__(self, params: Dict[str, str]):
+        super().__init__(params)
+
+    def run(self):
+        vnet_name = "jail_{}".format(self.test_name())
+        ifaces = [VnetInterface.create_iface("epair") for _ in range(self.num_epairs)]
+        self.vnet = VnetInstance.create_with_interfaces(vnet_name, ifaces)
+        self.vnet.attach()
+        for i, addr in enumerate(getattr(self, "IPV6_PREFIXES", [])):
+            if addr:
+                iface = self.vnet.ifaces[i]
+                iface.turn_up()
+                iface.enable_ipv6()
+                iface.setup_addr(addr)
+        for i, addr in enumerate(getattr(self, "IPV4_PREFIXES", [])):
+            if addr:
+                iface = self.vnet.ifaces[i]
+                iface.turn_up()
+                iface.setup_addr(addr)
+
+    def cleanup(self):
+        print("==== vnet cleanup ===")
+        # XXX: sleep 100ms to avoid epair qflush panic
+        time.sleep(0.1)
+        VnetInstance.cleanup_vnets()
+        VnetInterface.cleanup_ifaces()
+
+    def run_cmd(self, cmd: str) -> str:
+        """Run *cmd* through the shell (no echo) and return its stdout."""
+        with os.popen(cmd) as pipe:
+            return pipe.read()
+
+    def create_iface(self, iface_name: str):
+        """Clone an interface of type *iface_name* and return its name.
+
+        Raises:
+            Exception: if ifconfig printed no interface name.
+        """
+        name = self.run_cmd("/sbin/ifconfig {} create".format(iface_name)).rstrip()
+        if not name:
+            raise Exception("Unable to create iface {}".format(iface_name))
+        self.file_append_line(name)
+        if name.startswith("epair"):
+            # Also record the "b" peer of the created "a" end.
+            self.file_append_line(name[:-1] + "b")
+        return name
+
+    def destroy_iface(self, iface_name: str):
+        self.run_cmd("/sbin/ifconfig {} destroy".format(iface_name))
+
+    def setup_iface_addr(self, addr: str):
+        # Placeholder: not implemented.
+        pass
+
+    def file_append_line(self, line):
+        """Append *line* to the interface record file."""
+        with open(self.IFACES_FNAME, "a") as f:
+            f.write(line + "\n")
Index: tests/sys/net/routing/Makefile
===================================================================
--- tests/sys/net/routing/Makefile
+++ tests/sys/net/routing/Makefile
@@ -8,6 +8,8 @@
ATF_TESTS_C += test_rtsock_l3
ATF_TESTS_C += test_rtsock_lladdr
+ATF_TESTS_PY += test_rtsock_multipath
+
+# Helper module imported by test_rtsock_multipath ("from rtsock import ...").
+# Install it alongside the tests, as with generic_cleanup.sh, so it is found
+# via the test script's own directory.
+${PACKAGE}FILES+= rtsock.py
+
${PACKAGE}FILES+= generic_cleanup.sh
${PACKAGE}FILESMODE_generic_cleanup.sh=0555
Index: tests/sys/net/routing/rtsock.py
===================================================================
--- /dev/null
+++ tests/sys/net/routing/rtsock.py
@@ -0,0 +1,440 @@
+#!/usr/local/bin/python3
+
+from ctypes import *
+import socket
+import os
+import sys
+import unittest
+
+from typing import List, Dict, Optional
+
+
+def roundup2(val: int, num: int) -> int:
+    """Round *val* up to the next multiple of *num*.
+
+    Uses the ``(val | (num - 1)) + 1`` bit trick, so *num* must be a power
+    of two for the result to be correct.
+    """
+    return (val | (num - 1)) + 1 if val % num else val
+
+
+class RtConst():
+    """Routing-socket constants plus helpers that map values back to names."""
+
+    RTM_VERSION = 5
+
+    AF_INET = socket.AF_INET
+    AF_INET6 = socket.AF_INET6
+    AF_LINK = socket.AF_LINK
+
+    RTA_DST = 0x1
+    RTA_GATEWAY = 0x2
+    RTA_NETMASK = 0x4
+    RTA_GENMASK = 0x8
+    RTA_IFP = 0x10
+    RTA_IFA = 0x20
+    RTA_AUTHOR = 0x40
+    RTA_BRD = 0x80
+
+    RTM_ADD = 1
+    RTM_DELETE = 2
+    RTM_CHANGE = 3
+    RTM_GET = 4
+
+    RTF_UP = 0x1
+    RTF_GATEWAY = 0x2
+    RTF_HOST = 0x4
+    RTF_REJECT = 0x8
+    RTF_DYNAMIC = 0x10
+    RTF_MODIFIED = 0x20
+    RTF_DONE = 0x40
+    RTF_XRESOLVE = 0x200
+    RTF_LLINFO = 0x400
+    RTF_LLDATA = 0x400
+    RTF_STATIC = 0x800
+    RTF_BLACKHOLE = 0x1000
+    RTF_PROTO2 = 0x4000
+    RTF_PROTO1 = 0x8000
+    RTF_PROTO3 = 0x40000
+    RTF_FIXEDMTU = 0x80000
+    RTF_PINNED = 0x100000
+    RTF_LOCAL = 0x200000
+    RTF_BROADCAST = 0x400000
+    RTF_MULTICAST = 0x800000
+    RTF_STICKY = 0x10000000
+    RTF_RNH_LOCKED = 0x40000000
+    RTF_GWFLAG_COMPAT = 0x80000000
+
+    @staticmethod
+    def get_props(prefix: str) -> List[str]:
+        """Names of RtConst attributes starting with *prefix* (dir() order)."""
+        return [n for n in dir(RtConst) if n.startswith(prefix)]
+
+    @staticmethod
+    def get_name(prefix: str, value: int) -> str:
+        """Return the first *prefix* constant equal to *value*, else "U:<value>"."""
+        for prop in RtConst.get_props(prefix):
+            if getattr(RtConst, prop) == value:
+                return prop
+        return "U:" + str(value)
+
+    @staticmethod
+    def get_bitmask_map(prefix: str, value: int) -> Dict[int, str]:
+        """Map each set bit of *value* to its *prefix* constant name, or hex()."""
+        if value < 0:
+            # Header fields are signed c_ints, so bit 31 (e.g.
+            # RTF_GWFLAG_COMPAT) reads back negative; use the unsigned 32-bit
+            # view, otherwise the loop below never terminates.
+            value &= 0xFFFFFFFF
+        propmap = {getattr(RtConst, prop): prop for prop in RtConst.get_props(prefix)}
+        ret = {}
+        bit = 1
+        while value:
+            if value & bit:
+                ret[bit] = propmap.get(bit, hex(bit))
+                value &= ~bit
+            bit <<= 1
+        return ret
+
+    @staticmethod
+    def get_bitmask_str(prefix: str, value: int) -> str:
+        """Comma-separated names of the bits set in *value*."""
+        return ",".join(RtConst.get_bitmask_map(prefix, value).values())
+
+
+class RtMetrics(Structure):
+ # ctypes view of the route metrics embedded in RtMsgHdr (field rtm_rmx).
+ # NOTE(review): presumably mirrors FreeBSD's struct rt_metrics from
+ # <net/route.h>; field order and widths must match the kernel ABI.
+ _fields_ = [
+ ("rmx_locks", c_ulong),
+ ("rmx_mtu", c_ulong),
+ ("rmx_hopcount", c_ulong),
+ ("rmx_expire", c_ulong),
+ ("rmx_recvpipe", c_ulong),
+ ("rmx_sendpipe", c_ulong),
+ ("rmx_ssthresh", c_ulong),
+ ("rmx_rtt", c_ulong),
+ ("rmx_rttvar", c_ulong),
+ ("rmx_pksent", c_ulong),
+ ("rmx_weight", c_ulong),
+ ("rmx_nhidx", c_ulong),
+ ("rmx_filler", c_ulong * 2),
+ ]
+
+
+class RtMsgHdr(Structure):
+ # Fixed header of every routing-socket message. The layout must match the
+ # kernel's: bytes(RtMsgHdr(...)) is written to the AF_ROUTE socket and
+ # replies are decoded with RtMsgHdr.from_buffer_copy().
+ # rtm_addrs is a bitmask of RTA_* values telling which sockaddrs follow.
+ # NOTE(review): rtm_version/rtm_type are signed c_byte; fine for the small
+ # values used here, but confirm against the C types if that changes.
+ _fields_ = [
+ ("rtm_msglen", c_ushort),
+ ("rtm_version", c_byte),
+ ("rtm_type", c_byte),
+ ("rtm_index", c_ushort),
+ ("_rtm_spare1", c_ushort),
+ ("rtm_flags", c_int),
+ ("rtm_addrs", c_int),
+ ("rtm_pid", c_int),
+ ("rtm_seq", c_int),
+ ("rtm_errno", c_int),
+ ("rtm_fmask", c_int),
+ ("rtm_inits", c_ulong),
+ ("rtm_rmx", RtMetrics),
+ ]
+
+
+class SockaddrIn(Structure):
+ # sockaddr_in layout (1+1+2+4+8 = 16 bytes). sin_addr is a native-endian
+ # c_uint32; SaHelper.ip_sa() compensates so the stored bytes are in
+ # network order.
+ # NOTE(review): sin_zero is a c_char array, so reading the field returns
+ # bytes truncated at the first NUL (b"" when zeroed), not a list.
+ _fields_ = [
+ ("sin_len", c_byte),
+ ("sin_family", c_byte),
+ ("sin_port", c_ushort),
+ ("sin_addr", c_uint32),
+ ("sin_zero", c_char * 8),
+ ]
+
+
+class SockaddrIn6(Structure):
+ # sockaddr_in6 layout (1+1+2+4+16+4 = 28 bytes). sin6_addr is a c_byte
+ # array, so ctypes will not accept a bytes object for it directly; build a
+ # (c_byte * 16) instance first. sin6_scope_id is host byte order.
+ _fields_ = [
+ ("sin6_len", c_byte),
+ ("sin6_family", c_byte),
+ ("sin6_port", c_ushort),
+ ("sin6_flowinfo", c_uint32),
+ ("sin6_addr", c_byte * 16),
+ ("sin6_scope_id", c_uint32),
+ ]
+
+
+class SockaddrDl(Structure):
+ # Link-level sockaddr with only 8 bytes of sdl_data (16 bytes in total);
+ # the interface name, if any, starts at offset 8 (see print_sa_link()).
+ # NOTE(review): presumably matches the short sockaddr_dl form the kernel
+ # reports for interface gateways, since tests compare raw bytes against
+ # SaHelper.link_sa(); confirm against <net/if_dl.h>.
+ _fields_ = [
+ ("sdl_len", c_byte),
+ ("sdl_family", c_byte),
+ ("sdl_index", c_ushort),
+ ("sdl_type", c_byte),
+ ("sdl_nlen", c_byte),
+ ("sdl_alen", c_byte),
+ ("sdl_slen", c_byte),
+ ("sdl_data", c_byte * 8),
+ ]
+
+
+class SaHelper(object):
+    """Builders returning raw sockaddr bytes for routing-socket messages."""
+
+    @staticmethod
+    def ip_sa(ip: str) -> bytes:
+        """Return a sockaddr_in (port 0) for IPv4 address *ip*."""
+        # sin_addr is a native-endian c_uint32. Decoding the network-order
+        # bytes with sys.byteorder makes the stored bytes network order again.
+        addr_int = int.from_bytes(socket.inet_pton(socket.AF_INET, ip), sys.byteorder)
+        sin = SockaddrIn(sizeof(SockaddrIn), socket.AF_INET, 0, addr_int)
+        return bytes(sin)
+
+    @staticmethod
+    def ip6_sa(ip6: str, scopeid: int) -> bytes:
+        """Return a sockaddr_in6 (port and flowinfo 0) for *ip6* / *scopeid*."""
+        addr_bytes = socket.inet_pton(socket.AF_INET6, ip6)
+        # A c_byte array field rejects a bytes object, so copy the raw
+        # address into an array instance first.
+        sin6_addr = (c_byte * 16).from_buffer_copy(addr_bytes)
+        sin6 = SockaddrIn6(sizeof(SockaddrIn6), socket.AF_INET6, 0, 0, sin6_addr, scopeid)
+        return bytes(sin6)
+
+    @staticmethod
+    def link_sa(ifindex: Optional[int]=0, iftype: Optional[int]=0):
+        """Return a short sockaddr_dl carrying only an interface index/type."""
+        sa = SockaddrDl(sizeof(SockaddrDl), socket.AF_LINK, c_ushort(ifindex), iftype)
+        return bytes(sa)
+
+
+class BaseRtsockMessage(object):
+    """Common base for routing-socket message classes.
+
+    Borrows assertEqual()/assertNotEqual() from a throwaway unittest.TestCase
+    and keeps a SaHelper in ``self.sa`` for building sockaddrs.
+    """
+
+    def __init__(self, rtm_type):
+        self.rtm_type = rtm_type
+        self.ut = unittest.TestCase()
+        self.sa = SaHelper()
+
+    def assertEqual(self, a, b, msg=None):
+        return self.ut.assertEqual(a, b, msg)
+
+    def assertNotEqual(self, a, b, msg=None):
+        return self.ut.assertNotEqual(a, b, msg)
+
+
+class RtsockRtMessage(BaseRtsockMessage):
+    """Route message (RTM_ADD / RTM_DELETE / RTM_CHANGE / RTM_GET).
+
+    On the wire: an rt_msghdr followed by the sockaddrs selected by the
+    rtm_addrs bitmask, in increasing RTA_* bit order. Sockaddrs are kept as
+    raw bytes in ``self._attrs``, keyed by their RTA_* bit.
+    """
+
+    messages = [RtConst.RTM_ADD, RtConst.RTM_DELETE, RtConst.RTM_CHANGE, RtConst.RTM_GET]
+
+    # Each sockaddr is padded to a multiple of sizeof(long), and a zero-length
+    # one still occupies sizeof(long) bytes (FreeBSD's SA_SIZE()). A fixed
+    # 4-byte alignment misplaces everything after a 28-byte sockaddr_in6 on
+    # LP64 platforms.
+    SA_ALIGN = sizeof(c_long)
+
+    def __init__(self, rtm_type, rtm_seq=1, dst_sa=None, mask_sa=None):
+        super().__init__(rtm_type)
+        self.rtm_flags = 0
+        self.rtm_seq = rtm_seq
+        self._attrs = {}
+        self.rtm_errno = 0
+        self.rtm_pid = 0
+        # Raw bytes this message was decoded from (set by from_bytes()).
+        self._orig_data = None
+        if dst_sa:
+            self.add_sa_attr(RtConst.RTA_DST, dst_sa)
+        if mask_sa:
+            self.add_sa_attr(RtConst.RTA_NETMASK, mask_sa)
+
+    @classmethod
+    def _sa_size(cls, sa_len: int) -> int:
+        """Bytes a sockaddr of length *sa_len* occupies inside a message."""
+        if sa_len == 0:
+            return cls.SA_ALIGN
+        return roundup2(sa_len, cls.SA_ALIGN)
+
+    def add_sa_attr(self, attr_type, attr_bytes):
+        self._attrs[attr_type] = attr_bytes
+
+    def add_ip_attr(self, attr_type, ip: str):
+        self.add_sa_attr(attr_type, self.sa.ip_sa(ip))
+
+    def add_ip6_attr(self, attr_type, ip6: str, scopeid: int):
+        self.add_sa_attr(attr_type, self.sa.ip6_sa(ip6, scopeid))
+
+    def add_link_attr(self, attr_type, ifindex: Optional[int]=0):
+        self.add_sa_attr(attr_type, self.sa.link_sa(ifindex))
+
+    def print_sa_inet(self, sa: bytes):
+        if len(sa) < 8:
+            raise Exception("IPv4 sa size too small: {}".format(sa))
+        addr = socket.inet_ntop(socket.AF_INET, sa[4:8])
+        return "{}".format(addr)
+
+    def print_sa_inet6(self, sa: bytes):
+        if len(sa) < sizeof(SockaddrIn6):
+            raise Exception("IPv6 sa size too small: {}".format(sa))
+        addr = socket.inet_ntop(socket.AF_INET6, sa[8:24])
+        # sin6_scope_id is stored in host byte order; let ctypes decode it.
+        scopeid = SockaddrIn6.from_buffer_copy(sa).sin6_scope_id
+        return "{} scopeid {}".format(addr, scopeid)
+
+    def print_sa_link(self, sa: bytes, hd: Optional[bool] = True):
+        if len(sa) < sizeof(SockaddrDl):
+            raise Exception("LINK sa size too small: {}".format(sa))
+        sdl = SockaddrDl.from_buffer_copy(sa)
+        if sdl.sdl_index:
+            ifindex = "link#{} ".format(sdl.sdl_index)
+        else:
+            ifindex = ""
+        if sdl.sdl_nlen:
+            # The interface name starts right after the fixed 8-byte header.
+            iface_offset = 8
+            if sdl.sdl_nlen + iface_offset > len(sa):
+                raise Exception("LINK sa sdl_nlen {} > total len {}".format(sdl.sdl_nlen, len(sa)))
+            ifname = "ifname:{} ".format(bytes.decode(sa[iface_offset:iface_offset + sdl.sdl_nlen]))
+        else:
+            ifname = ""
+        return "{}{}".format(ifindex, ifname)
+
+    def print_sa_unknown(self, sa: bytes):
+        return "unknown_type:{}".format(sa[1])
+
+    def print_sa(self, sa: bytes, hd: Optional[bool] = False):
+        """Describe sockaddr *sa* on one line (plus a raw dump if *hd*)."""
+        # Check the length first: sa[0] and sa[1] must exist before use.
+        if len(sa) < 2:
+            raise Exception("sa too short: {} bytes".format(len(sa)))
+        if sa[0] != len(sa):
+            raise Exception("sa size {} != buffer size {}".format(sa[0], len(sa)))
+
+        if sa[1] == socket.AF_INET:
+            text = self.print_sa_inet(sa)
+        elif sa[1] == socket.AF_INET6:
+            text = self.print_sa_inet6(sa)
+        elif sa[1] == socket.AF_LINK:
+            text = self.print_sa_link(sa)
+        else:
+            text = self.print_sa_unknown(sa)
+        if hd:
+            dump = " [{}]".format(sa)
+        else:
+            dump = ""
+        return "{}{}".format(text, dump)
+
+    def print_message(self):
+        """Print the header summary and every sockaddr of this message."""
+        # RTM_GET: Report Metrics: len 272, pid: 87839, seq 1, errno 0, flags:<UP,GATEWAY,DONE,STATIC>
+        if self._orig_data:
+            rtm_len = len(self._orig_data)
+        else:
+            rtm_len = len(bytes(self))
+        print("{}: len {}, pid: {}, seq {}, errno {}, flags: <{}>".format(
+            RtConst.get_name("RTM_", self.rtm_type),
+            rtm_len,
+            self.rtm_pid,
+            self.rtm_seq,
+            self.rtm_errno,
+            RtConst.get_bitmask_str("RTF_", self.rtm_flags)
+        ))
+        rtm_addrs = sum(self._attrs)
+        print("Addrs: <{}>".format(RtConst.get_bitmask_str("RTA_", rtm_addrs)))
+        for attr in sorted(self._attrs):
+            sa_data = self.print_sa(self._attrs[attr])
+            print(" {}: {}".format(RtConst.get_name("RTA_", attr), sa_data))
+
+    def verify_sa_inet(self, sa_data):
+        """Assert *sa_data* is a sockaddr_in with zero port and zero padding."""
+        if len(sa_data) < sizeof(SockaddrIn):
+            raise Exception("IPv4 sa size too small: {}".format(sa_data))
+        if sa_data[0] > len(sa_data):
+            raise Exception("IPv4 sin_len too big: {} vs sa size {}: {}".format(sa_data[0], len(sa_data), sa_data))
+        sin = SockaddrIn.from_buffer_copy(sa_data)
+        self.assertEqual(sin.sin_port, 0)
+        # The c_char sin_zero field reads back NUL-truncated, so compare the
+        # raw padding bytes instead.
+        self.assertEqual(sa_data[8:16], bytes(8))
+
+    def compare_sa(self, sa_type, sa_data):
+        """Assert the stored sockaddr for *sa_type* equals *sa_data* exactly."""
+        if len(sa_data) < 4:
+            raise Exception("sa_len for type {} too short: {}".format(
+                RtConst.get_name("RTA_", sa_type), len(sa_data)))
+        our_sa = self._attrs[sa_type]
+        self.assertEqual(len(sa_data), len(our_sa))
+        self.assertEqual(our_sa, sa_data)
+
+    def verify(self, rtm_type: int, rtm_sa):
+        """Check type, errno and that every sockaddr in *rtm_sa* (a dict
+        {RTA_* bit: bytes}) is present with identical bytes."""
+        assert self.rtm_type == rtm_type
+        assert self.rtm_errno == 0
+        raw = self._orig_data if self._orig_data is not None else bytes(self)
+        hdr = RtMsgHdr.from_buffer_copy(raw)
+        assert hdr._rtm_spare1 == 0
+        for sa_type, sa_data in rtm_sa.items():
+            if sa_type not in self._attrs:
+                raise Exception("SA type {} not present".format(RtConst.get_name("RTA_", sa_type)))
+            self.compare_sa(sa_type, sa_data)
+
+    @classmethod
+    def from_bytes(cls, data: bytes):
+        """Decode a complete routing message from *data*.
+
+        Raises:
+            Exception: if the header or any announced sockaddr is truncated.
+        """
+        if len(data) < sizeof(RtMsgHdr):
+            raise Exception("messages size {} is less than expected {}".format(len(data), sizeof(RtMsgHdr)))
+        hdr = RtMsgHdr.from_buffer_copy(data)
+
+        msg = cls(hdr.rtm_type)
+        msg.rtm_flags = hdr.rtm_flags
+        msg.rtm_seq = hdr.rtm_seq
+        msg.rtm_errno = hdr.rtm_errno
+        msg.rtm_pid = hdr.rtm_pid
+        msg.rtm_len = len(data)
+        msg._orig_data = data
+
+        off = sizeof(RtMsgHdr)
+        v = 1
+        # rtm_addrs is a signed c_int; only the low 32 bits are meaningful.
+        addrs_mask = hdr.rtm_addrs & 0xFFFFFFFF
+        while addrs_mask:
+            if addrs_mask & v:
+                addrs_mask -= v
+                if off >= len(data):
+                    raise Exception("SA {} missing: message ends at offset {}".format(
+                        RtConst.get_name("RTA_", v), off))
+                sa_len = data[off]
+                if off + sa_len > len(data):
+                    raise Exception("SA sizeof for {} > total message length: {}+{} > {}".format(
+                        RtConst.get_name("RTA_", v), off, sa_len, len(data)))
+                msg._attrs[v] = data[off:off + sa_len]
+                off += cls._sa_size(sa_len)
+            v *= 2
+        return msg
+
+    def __bytes__(self):
+        """Serialize header plus sockaddrs (sorted by RTA_* bit, padded)."""
+        attrs = sorted(self._attrs.items())
+        sz = sizeof(RtMsgHdr) + sum(self._sa_size(len(sa)) for _, sa in attrs)
+        hdr = RtMsgHdr(
+            rtm_msglen=sz,
+            rtm_version=RtConst.RTM_VERSION,
+            rtm_type=self.rtm_type,
+            rtm_flags=self.rtm_flags,
+            rtm_seq=self.rtm_seq,
+            rtm_addrs=sum(attr for attr, _ in attrs),
+        )
+        buf = bytearray(sz)
+        buf[0:sizeof(RtMsgHdr)] = bytes(hdr)
+        off = sizeof(RtMsgHdr)
+        for _, sa in attrs:
+            buf[off:off + len(sa)] = sa
+            off += self._sa_size(len(sa))
+        return bytes(buf)
+
+class Rtsock():
+    """Thin wrapper around a raw AF_ROUTE socket."""
+
+    def __init__(self):
+        self.rtsock_fd = self._setup_rtsock()
+        self.rtm_seq = 1
+        self.msgmap = self.build_msgmap()
+
+    def build_msgmap(self):
+        """Map each supported RTM_* type to the class that decodes it."""
+        classes = [RtsockRtMessage]
+        return {message: cls for cls in classes for message in cls.messages}
+
+    def get_seq(self):
+        """Return the next rtm_seq value to use."""
+        ret = self.rtm_seq
+        self.rtm_seq += 1
+        return ret
+
+    def _setup_rtsock(self) -> socket.socket:
+        s = socket.socket(socket.AF_ROUTE, socket.SOCK_RAW, socket.AF_UNSPEC)
+        # Ask to also receive our own messages, so replies to requests sent
+        # here can be read back (see read_data(seq)).
+        s.setsockopt(socket.SOL_SOCKET, socket.SO_USELOOPBACK, 1)
+        return s
+
+    def write_message(self, msg):
+        """Print and send *msg*; write errors are logged, not raised."""
+        print("vvvvvvvv OUT vvvvvvvv")
+        msg.print_message()
+        msg_bytes = bytes(msg)
+        try:
+            os.write(self.rtsock_fd.fileno(), msg_bytes)
+        except OSError as e:
+            print("write({}) -> {}".format(len(msg_bytes), e))
+
+    def parse_message(self, data: bytes):
+        """Decode *data* with the class registered for its rtm_type.
+
+        Returns None for unsupported message types.
+        """
+        if len(data) < 4:
+            raise Exception("Short read from rtsock: {} bytes".format(len(data)))
+        # Header starts with u_short rtm_msglen, then rtm_version, then
+        # rtm_type (see RtMsgHdr), so the type is byte 3.
+        rtm_type = data[3]
+        cls = self.msgmap.get(rtm_type)
+        if cls is None:
+            return None
+        return cls.from_bytes(data)
+
+    def write_data(self, data: bytes):
+        self.rtsock_fd.send(data)
+
+    def read_data(self, seq: Optional[int] = None) -> bytes:
+        """Receive one message; with *seq*, keep receiving (blocking) until a
+        message carrying that rtm_seq arrives."""
+        while True:
+            data = self.rtsock_fd.recv(4096)
+            if seq is None:
+                return data
+            # A header with no sockaddrs is still a complete message.
+            if len(data) >= sizeof(RtMsgHdr):
+                hdr = RtMsgHdr.from_buffer_copy(data)
+                if hdr.rtm_seq == seq:
+                    return data
+
+    def read_message(self):
+        """Receive one message and decode it (None if unsupported)."""
+        data = self.read_data()
+        return self.parse_message(data)
Index: tests/sys/net/routing/test_rtsock_multipath.py
===================================================================
--- /dev/null
+++ tests/sys/net/routing/test_rtsock_multipath.py
@@ -0,0 +1,85 @@
+
+from common.atf import atf_main, ATFTestTemplate
+from common.vnet import SingleVnetTestTemplate, VnetInstance
+from rtsock import Rtsock, RtsockRtMessage, RtConst, SaHelper
+
+
+class BaseIPv4RoutingTestTemplate(SingleVnetTestTemplate):
+    """Root-only vnet with one epair whose first interface has 192.0.2.1/24.
+
+    After setup an open routing socket is available as ``self.rtsock``.
+    """
+
+    num_epairs = 1
+    IPV4_PREFIXES = ["192.0.2.1/24"]
+    options = {"require.user": "root"}
+
+    def run(self):
+        super(BaseIPv4RoutingTestTemplate, self).run()
+        self.rtsock = Rtsock()
+
+
+class BaseIPv6RoutingTestTemplate(SingleVnetTestTemplate):
+    """Root-only vnet with one epair whose first interface has 2001:DB8::1/64.
+
+    After setup an open routing socket is available as ``self.rtsock``.
+    """
+
+    num_epairs = 1
+    IPV6_PREFIXES = ["2001:DB8::1/64"]
+    options = {"require.user": "root"}
+
+    def run(self):
+        super(BaseIPv6RoutingTestTemplate, self).run()
+        self.rtsock = Rtsock()
+
+
+class MultipathAdd(ATFTestTemplate):
+    """Placeholder multipath test: body and cleanup only log a marker."""
+
+    description = "Multipath routing"
+    options = {"require.user": "root"}
+
+    def run(self):
+        self.log("HERE")
+
+    def cleanup(self):
+        # Must be an instance method: ATFHandler calls test.cleanup() on an
+        # instance and self.log() needs self.
+        self.log("CLEANUP HERE")
+
+
+class RtmGetv4ExactSuccess(BaseIPv4RoutingTestTemplate):
+    """RTM_GET for 192.0.2.0/24 must return the connected route via the
+    epair interface (link-level gateway)."""
+
+    description = "Tests RTM_GET with exact prefix lookup on an interface prefix"
+
+    def run(self):
+        super().run()
+        helper = SaHelper()
+        net_sa = helper.ip_sa("192.0.2.0")
+        mask_sa = helper.ip_sa("255.255.255.0")
+        req = RtsockRtMessage(RtConst.RTM_GET, self.rtsock.get_seq(), net_sa, mask_sa)
+        self.rtsock.write_message(req)
+
+        iface = self.vnet.ifaces[0]
+        expected = {
+            RtConst.RTA_DST: net_sa,
+            RtConst.RTA_NETMASK: mask_sa,
+            RtConst.RTA_GATEWAY: helper.link_sa(ifindex=iface.ifindex, iftype=iface.iftype),
+        }
+
+        reply = RtsockRtMessage.from_bytes(self.rtsock.read_data(req.rtm_seq))
+        print("vvvvvvvv IN vvvvvvvv")
+        reply.print_message()
+        reply.verify(RtConst.RTM_GET, expected)
+
+
+class AddRouteWithRta(BaseIPv4RoutingTestTemplate):
+ # NOTE(review): despite its name, run() sends RTM_GET and is a verbatim
+ # copy of RtmGetv4ExactSuccess.run(); no route is added. It also sets no
+ # `description`, so ATF lists the inherited "Base ATF test" text.
+ # TODO: implement the RTM_ADD-with-RTA case or drop this test.
+
+ def run(self):
+ super().run()
+ sa = SaHelper()
+ # Ask for the exact 192.0.2.0/24 prefix configured on the epair.
+ msg = RtsockRtMessage(RtConst.RTM_GET, self.rtsock.get_seq(), sa.ip_sa("192.0.2.0"), sa.ip_sa("255.255.255.0"))
+ self.rtsock.write_message(msg)
+
+ iface = self.vnet.ifaces[0]
+ # Expected reply: same dst/mask, gateway is the interface itself.
+ desired_sa = {
+ RtConst.RTA_DST: sa.ip_sa("192.0.2.0"),
+ RtConst.RTA_NETMASK: sa.ip_sa("255.255.255.0"),
+ RtConst.RTA_GATEWAY: sa.link_sa(ifindex=iface.ifindex, iftype=iface.iftype),
+ }
+
+ # Skip unrelated routing messages until the reply with our seq arrives.
+ data = self.rtsock.read_data(msg.rtm_seq)
+ msg = RtsockRtMessage.from_bytes(data)
+ print("vvvvvvvv IN vvvvvvvv")
+ msg.print_message()
+ msg.verify(RtConst.RTM_GET, desired_sa)
+
+
+if __name__ == '__main__':
+ # Hand this module's globals to the ATF driver, which picks out the
+ # concrete ATFTestTemplate subclasses (names not ending in "Template").
+ atf_main(globals())
File Metadata
Details
Attached
Mime Type
text/plain
Expires
Tue, May 26, 5:37 AM (14 h, 40 m)
Storage Engine
blob
Storage Format
Raw Data
Storage Handle
33525104
Default Alt Text
D31084.id91913.diff (31 KB)
Attached To
Mode
D31084: Add basic python atf support
Attached
Detach File
Event Timeline
Log In to Comment