diff --git a/tests/atf_python/sys/netlink/message.py b/tests/atf_python/sys/netlink/message.py index 1e2b71775102..98a1e3bb21c5 100644 --- a/tests/atf_python/sys/netlink/message.py +++ b/tests/atf_python/sys/netlink/message.py @@ -1,261 +1,286 @@ #!/usr/local/bin/python3 import struct from ctypes import sizeof from enum import Enum from typing import List from typing import NamedTuple from atf_python.sys.netlink.attrs import NlAttr from atf_python.sys.netlink.attrs import NlAttrNested from atf_python.sys.netlink.base_headers import NlmAckFlags from atf_python.sys.netlink.base_headers import NlmNewFlags from atf_python.sys.netlink.base_headers import NlmGetFlags from atf_python.sys.netlink.base_headers import NlmDeleteFlags from atf_python.sys.netlink.base_headers import NlmBaseFlags from atf_python.sys.netlink.base_headers import Nlmsghdr from atf_python.sys.netlink.base_headers import NlMsgType from atf_python.sys.netlink.utils import align4 from atf_python.sys.netlink.utils import enum_or_int from atf_python.sys.netlink.utils import get_bitmask_str class NlMsgCategory(Enum): UNKNOWN = 0 GET = 1 NEW = 2 DELETE = 3 ACK = 4 class NlMsgProps(NamedTuple): msg: Enum category: NlMsgCategory class BaseNetlinkMessage(object): def __init__(self, helper, nlmsg_type): self.nlmsg_type = enum_or_int(nlmsg_type) self.nla_list = [] self._orig_data = None self.helper = helper self.nl_hdr = Nlmsghdr( nlmsg_type=self.nlmsg_type, nlmsg_seq=helper.get_seq(), nlmsg_pid=helper.pid ) self.base_hdr = None def set_request(self, need_ack=True): self.add_nlflags([NlmBaseFlags.NLM_F_REQUEST]) if need_ack: self.add_nlflags([NlmBaseFlags.NLM_F_ACK]) def add_nlflags(self, flags: List): int_flags = 0 for flag in flags: int_flags |= enum_or_int(flag) self.nl_hdr.nlmsg_flags |= int_flags def add_nla(self, nla): self.nla_list.append(nla) def _get_nla(self, nla_list, nla_type): nla_type_raw = enum_or_int(nla_type) for nla in nla_list: if nla.nla_type == nla_type_raw: return nla return None def get_nla(self, nla_type): return self._get_nla(self.nla_list, nla_type) @staticmethod def parse_nl_header(data: bytes): if len(data) < sizeof(Nlmsghdr): raise ValueError("length less than netlink message header") return Nlmsghdr.from_buffer_copy(data), sizeof(Nlmsghdr) def is_type(self, nlmsg_type): nlmsg_type_raw = enum_or_int(nlmsg_type) return nlmsg_type_raw == self.nl_hdr.nlmsg_type def is_reply(self, hdr): return hdr.nlmsg_type == NlMsgType.NLMSG_ERROR.value @property def msg_name(self): return "msg#{}".format(self._get_msg_type()) def _get_nl_category(self): if self.is_reply(self.nl_hdr): return NlMsgCategory.ACK return NlMsgCategory.UNKNOWN def get_nlm_flags_str(self): category = self._get_nl_category() flags = self.nl_hdr.nlmsg_flags if category == NlMsgCategory.UNKNOWN: return self.helper.get_bitmask_str(NlmBaseFlags, flags) elif category == NlMsgCategory.GET: flags_enum = NlmGetFlags elif category == NlMsgCategory.NEW: flags_enum = NlmNewFlags elif category == NlMsgCategory.DELETE: flags_enum = NlmDeleteFlags elif category == NlMsgCategory.ACK: flags_enum = NlmAckFlags return get_bitmask_str([NlmBaseFlags, flags_enum], flags) def print_nl_header(self, prepend=""): # len=44, type=RTM_DELROUTE, flags=NLM_F_REQUEST|NLM_F_ACK, seq=1641163704, pid=0 # noqa: E501 hdr = self.nl_hdr print( "{}len={}, type={}, flags={}(0x{:X}), seq={}, pid={}".format( prepend, hdr.nlmsg_len, self.msg_name, self.get_nlm_flags_str(), hdr.nlmsg_flags, hdr.nlmsg_seq, hdr.nlmsg_pid, ) ) @classmethod def from_bytes(cls, helper, data): try: hdr, hdrlen = BaseNetlinkMessage.parse_nl_header(data) self = cls(helper, hdr.nlmsg_type) self._orig_data = data self.nl_hdr = hdr except ValueError as e: print("Failed to parse nl header: {}".format(e)) cls.print_as_bytes(data) raise return self def print_message(self): self.print_nl_header() @staticmethod def print_as_bytes(data: bytes, descr: str): print("===vv {} (len:{:3d}) vv===".format(descr, len(data))) off = 0 step = 16 while off < len(data): for i in range(step): if off + i < len(data): print(" {:02X}".format(data[off + i]), end="") print("") off += step print("--------------------") class StdNetlinkMessage(BaseNetlinkMessage): nl_attrs_map = {} @classmethod def from_bytes(cls, helper, data): try: hdr, hdrlen = BaseNetlinkMessage.parse_nl_header(data) self = cls(helper, hdr.nlmsg_type) self._orig_data = data self.nl_hdr = hdr except ValueError as e: print("Failed to parse nl header: {}".format(e)) cls.print_as_bytes(data) raise offset = align4(hdrlen) try: base_hdr, hdrlen = self.parse_base_header(data[offset:]) self.base_hdr = base_hdr offset += align4(hdrlen) # XXX: CAP_ACK except ValueError as e: print("Failed to parse nl rt header: {}".format(e)) cls.print_as_bytes(data) raise orig_offset = offset try: nla_list, nla_len = self.parse_nla_list(data[offset:]) offset += nla_len if offset != len(data): raise ValueError( "{} bytes left at the end of the packet".format(len(data) - offset) ) # noqa: E501 self.nla_list = nla_list except ValueError as e: print( "Failed to parse nla attributes at offset {}: {}".format(orig_offset, e) ) # noqa: E501 cls.print_as_bytes(data, "msg dump") cls.print_as_bytes(data[orig_offset:], "failed block") raise return self + def parse_child(self, data: bytes, attr_key, attr_map): + attrs, _ = self.parse_attrs(data, attr_map) + return NlAttrNested(attr_key, attrs) + + def parse_child_array(self, data: bytes, attr_key, attr_map): + ret = [] + off = 0 + while len(data) - off >= 4: + nla_len, raw_nla_type = struct.unpack("@HH", data[off : off + 4]) + if nla_len + off > len(data): + raise ValueError( + "attr length {} > than the remaining length {}".format( + nla_len, len(data) - off + ) + ) + nla_type = raw_nla_type & 0x3FFF + val = self.parse_child(data[off + 4 : off + nla_len], nla_type, attr_map) + ret.append(val) + off += align4(nla_len) + return NlAttrNested(attr_key, ret) + def parse_attrs(self, data: bytes, attr_map): ret = [] off = 0 while len(data) - off >= 4: - nla_len, raw_nla_type = struct.unpack("@HH", data[off:off + 4]) + nla_len, raw_nla_type = struct.unpack("@HH", data[off : off + 4]) if nla_len + off > len(data): raise ValueError( "attr length {} > than the remaining length {}".format( nla_len, len(data) - off ) ) nla_type = raw_nla_type & 0x3FFF if nla_type in attr_map: v = attr_map[nla_type] - val = v["ad"].cls.from_bytes(data[off:off + nla_len], v["ad"].val) + val = v["ad"].cls.from_bytes(data[off : off + nla_len], v["ad"].val) if "child" in v: # nested - attrs, _ = self.parse_attrs( - data[off + 4:off + nla_len], v["child"] - ) - val = NlAttrNested(v["ad"].val, attrs) + child_data = data[off + 4 : off + nla_len] + if v.get("is_array", False): + # Array of nested attributes + val = self.parse_child_array( + child_data, v["ad"].val, v["child"] + ) + else: + val = self.parse_child(child_data, v["ad"].val, v["child"]) else: # unknown attribute - val = NlAttr(raw_nla_type, data[off + 4:off + nla_len]) + val = NlAttr(raw_nla_type, data[off + 4 : off + nla_len]) ret.append(val) off += align4(nla_len) return ret, off def parse_nla_list(self, data: bytes) -> List[NlAttr]: return self.parse_attrs(data, self.nl_attrs_map) def __bytes__(self): ret = bytes() for nla in self.nla_list: ret += bytes(nla) ret = bytes(self.base_hdr) + ret self.nl_hdr.nlmsg_len = len(ret) + sizeof(Nlmsghdr) return bytes(self.nl_hdr) + ret def _get_msg_type(self): return self.nl_hdr.nlmsg_type @property def msg_props(self): msg_type = self._get_msg_type() for msg_props in self.messages: if msg_props.msg.value == msg_type: return msg_props return None @property def msg_name(self): msg_props = self.msg_props if msg_props is not None: return msg_props.msg.name return super().msg_name def print_base_header(self, hdr, prepend=""): pass def print_message(self): self.print_nl_header() self.print_base_header(self.base_hdr, " ") for nla in self.nla_list: nla.print_attr(" ") diff --git a/tests/atf_python/sys/netlink/netlink.py b/tests/atf_python/sys/netlink/netlink.py index 9b5906815489..f8f886b09b24 100644 --- a/tests/atf_python/sys/netlink/netlink.py +++ b/tests/atf_python/sys/netlink/netlink.py @@ -1,414 +1,417 @@ #!/usr/local/bin/python3 import os import socket import sys from ctypes import c_int from ctypes import c_ubyte from ctypes import c_uint from ctypes import c_ushort from ctypes import sizeof from ctypes import Structure from enum import auto from enum import Enum from atf_python.sys.netlink.attrs import NlAttr from atf_python.sys.netlink.attrs import NlAttrStr from atf_python.sys.netlink.attrs import NlAttrU32 from atf_python.sys.netlink.base_headers import GenlMsgHdr from atf_python.sys.netlink.base_headers import NlmBaseFlags from atf_python.sys.netlink.base_headers import Nlmsghdr from atf_python.sys.netlink.base_headers import NlMsgType from atf_python.sys.netlink.message import BaseNetlinkMessage from atf_python.sys.netlink.message import NlMsgCategory from atf_python.sys.netlink.message import NlMsgProps from atf_python.sys.netlink.message import StdNetlinkMessage from atf_python.sys.netlink.netlink_generic import GenlCtrlAttrType from atf_python.sys.netlink.netlink_generic import GenlCtrlMsgType from atf_python.sys.netlink.netlink_generic import handler_classes as genl_classes from atf_python.sys.netlink.netlink_route import handler_classes as rt_classes from atf_python.sys.netlink.utils import align4 from atf_python.sys.netlink.utils import AttrDescr from atf_python.sys.netlink.utils import build_propmap from atf_python.sys.netlink.utils import enum_or_int from atf_python.sys.netlink.utils import get_bitmask_map from atf_python.sys.netlink.utils import NlConst from atf_python.sys.netlink.utils import prepare_attrs_map class SockaddrNl(Structure): _fields_ = [ ("nl_len", c_ubyte), ("nl_family", c_ubyte), ("nl_pad", c_ushort), ("nl_pid", c_uint), ("nl_groups", c_uint), ] class Nlmsgdone(Structure): _fields_ = [ ("error", c_int), ] class Nlmsgerr(Structure): _fields_ = [ ("error", c_int), ("msg", Nlmsghdr), ] class NlErrattrType(Enum): NLMSGERR_ATTR_UNUSED = 0 NLMSGERR_ATTR_MSG = auto() NLMSGERR_ATTR_OFFS = auto() NLMSGERR_ATTR_COOKIE = auto() NLMSGERR_ATTR_POLICY = auto() class AddressFamilyLinux(Enum): AF_INET = socket.AF_INET AF_INET6 = socket.AF_INET6 AF_NETLINK = 16 class AddressFamilyBsd(Enum): AF_INET = socket.AF_INET AF_INET6 = socket.AF_INET6 AF_NETLINK = 38 class NlHelper: def __init__(self): self._pmap = {} self._af_cls = self.get_af_cls() self._seq_counter = 1 self.pid = os.getpid() def get_seq(self): ret = self._seq_counter self._seq_counter += 1 return ret def get_af_cls(self): if sys.platform.startswith("freebsd"): cls = AddressFamilyBsd else: cls = AddressFamilyLinux return cls def get_propmap(self, cls): if cls not in self._pmap: self._pmap[cls] = build_propmap(cls) return self._pmap[cls] def get_name_propmap(self, cls): ret = {} for prop in dir(cls): if not prop.startswith("_"): ret[prop] = getattr(cls, prop).value return ret def get_attr_byval(self, cls, attr_val): propmap = self.get_propmap(cls) return propmap.get(attr_val) def get_af_name(self, family): v = self.get_attr_byval(self._af_cls, family) if v is not None: return v return "af#{}".format(family) def get_af_value(self, family_str: str) -> int: propmap = self.get_name_propmap(self._af_cls) return propmap.get(family_str) def get_bitmask_str(self, cls, val): bmap = get_bitmask_map(self.get_propmap(cls), val) return ",".join([v for k, v in bmap.items()]) @staticmethod def get_bitmask_str_uncached(cls, val): pmap = NlHelper.build_propmap(cls) bmap = NlHelper.get_bitmask_map(pmap, val) return ",".join([v for k, v in bmap.items()]) nldone_attrs = prepare_attrs_map([]) nlerr_attrs = prepare_attrs_map( [ AttrDescr(NlErrattrType.NLMSGERR_ATTR_MSG, NlAttrStr), AttrDescr(NlErrattrType.NLMSGERR_ATTR_OFFS, NlAttrU32), AttrDescr(NlErrattrType.NLMSGERR_ATTR_COOKIE, NlAttr), ] ) class NetlinkDoneMessage(StdNetlinkMessage): messages = [NlMsgProps(NlMsgType.NLMSG_DONE, NlMsgCategory.ACK)] nl_attrs_map = nldone_attrs @property def error_code(self): return self.base_hdr.error def parse_base_header(self, data): if len(data) < sizeof(Nlmsgdone): raise ValueError("length less than nlmsgdone header") done_hdr = Nlmsgdone.from_buffer_copy(data) sz = sizeof(Nlmsgdone) return (done_hdr, sz) def print_base_header(self, hdr, prepend=""): print("{}error={}".format(prepend, hdr.error)) class NetlinkErrorMessage(StdNetlinkMessage): messages = [NlMsgProps(NlMsgType.NLMSG_ERROR, NlMsgCategory.ACK)] nl_attrs_map = nlerr_attrs @property def error_code(self): return self.base_hdr.error @property def error_str(self): nla = self.get_nla(NlErrattrType.NLMSGERR_ATTR_MSG) if nla: return nla.text return None @property def error_offset(self): nla = self.get_nla(NlErrattrType.NLMSGERR_ATTR_OFFS) if nla: return nla.u32 return None @property def cookie(self): return self.get_nla(NlErrattrType.NLMSGERR_ATTR_COOKIE) def parse_base_header(self, data): if len(data) < sizeof(Nlmsgerr): raise ValueError("length less than nlmsgerr header") err_hdr = Nlmsgerr.from_buffer_copy(data) sz = sizeof(Nlmsgerr) if (self.nl_hdr.nlmsg_flags & 0x100) == 0: sz += align4(err_hdr.msg.nlmsg_len - sizeof(Nlmsghdr)) return (err_hdr, sz) def print_base_header(self, errhdr, prepend=""): print("{}error={}, ".format(prepend, errhdr.error), end="") hdr = errhdr.msg print( "{}len={}, type={}, flags={}(0x{:X}), seq={}, pid={}".format( prepend, hdr.nlmsg_len, "msg#{}".format(hdr.nlmsg_type), self.helper.get_bitmask_str(NlmBaseFlags, hdr.nlmsg_flags), hdr.nlmsg_flags, hdr.nlmsg_seq, hdr.nlmsg_pid, ) ) core_classes = { "netlink_core": [ NetlinkDoneMessage, NetlinkErrorMessage, ], } class Nlsock: HANDLER_CLASSES = [core_classes, rt_classes, genl_classes] def __init__(self, family, helper): self.helper = helper self.sock_fd = self._setup_netlink(family) self._sock_family = family self._data = bytes() self.msgmap = self.build_msgmap() self._family_map = { NlConst.GENL_ID_CTRL: "nlctrl", } def build_msgmap(self): handler_classes = {} for d in self.HANDLER_CLASSES: handler_classes.update(d) xmap = {} # 'family_name': [class.messages[MsgProps.msg], ] for family_id, family_classes in handler_classes.items(): xmap[family_id] = {} for cls in family_classes: for msg_props in cls.messages: xmap[family_id][enum_or_int(msg_props.msg)] = cls return xmap def _setup_netlink(self, netlink_family) -> int: family = self.helper.get_af_value("AF_NETLINK") s = socket.socket(family, socket.SOCK_RAW, netlink_family) s.setsockopt(270, 10, 1) # NETLINK_CAP_ACK s.setsockopt(270, 11, 1) # NETLINK_EXT_ACK return s def set_groups(self, mask: int): self.sock_fd.setsockopt(socket.SOL_SOCKET, 1, mask) # snl = SockaddrNl(nl_len = sizeof(SockaddrNl), nl_family=38, # nl_pid=self.pid, nl_groups=mask) # xbuffer = create_string_buffer(sizeof(SockaddrNl)) # memmove(xbuffer, addressof(snl), sizeof(SockaddrNl)) # k = struct.pack("@BBHII", 12, 38, 0, self.pid, mask) # self.sock_fd.bind(k) + def join_group(self, group_id: int): + self.sock_fd.setsockopt(270, 1, group_id) + def write_message(self, msg, verbose=True): if verbose: print("vvvvvvvv OUT vvvvvvvv") msg.print_message() msg_bytes = bytes(msg) try: ret = os.write(self.sock_fd.fileno(), msg_bytes) assert ret == len(msg_bytes) except Exception as e: print("write({}) -> {}".format(len(msg_bytes), e)) def parse_message(self, data: bytes): if len(data) < sizeof(Nlmsghdr): raise Exception("Short read from nl: {} bytes".format(len(data))) hdr = Nlmsghdr.from_buffer_copy(data) if hdr.nlmsg_type < 16: family_name = "netlink_core" nlmsg_type = hdr.nlmsg_type elif self._sock_family == NlConst.NETLINK_ROUTE: family_name = "netlink_route" nlmsg_type = hdr.nlmsg_type else: # Genetlink if len(data) < sizeof(Nlmsghdr) + sizeof(GenlMsgHdr): raise Exception("Short read from genl: {} bytes".format(len(data))) family_name = self._family_map.get(hdr.nlmsg_type, "") ghdr = GenlMsgHdr.from_buffer_copy(data[sizeof(Nlmsghdr):]) nlmsg_type = ghdr.cmd cls = self.msgmap.get(family_name, {}).get(nlmsg_type) if not cls: cls = BaseNetlinkMessage return cls.from_bytes(self.helper, data) def get_genl_family_id(self, family_name): hdr = Nlmsghdr( nlmsg_type=NlConst.GENL_ID_CTRL, nlmsg_flags=NlmBaseFlags.NLM_F_REQUEST.value, nlmsg_seq=self.helper.get_seq(), ) ghdr = GenlMsgHdr(cmd=GenlCtrlMsgType.CTRL_CMD_GETFAMILY.value) nla = NlAttrStr(GenlCtrlAttrType.CTRL_ATTR_FAMILY_NAME, family_name) hdr.nlmsg_len = sizeof(Nlmsghdr) + sizeof(GenlMsgHdr) + len(bytes(nla)) msg_bytes = bytes(hdr) + bytes(ghdr) + bytes(nla) self.write_data(msg_bytes) while True: rx_msg = self.read_message() if hdr.nlmsg_seq == rx_msg.nl_hdr.nlmsg_seq: if rx_msg.is_type(NlMsgType.NLMSG_ERROR): if rx_msg.error_code != 0: raise ValueError("unable to get family {}".format(family_name)) else: family_id = rx_msg.get_nla(GenlCtrlAttrType.CTRL_ATTR_FAMILY_ID).u16 self._family_map[family_id] = family_name return family_id raise ValueError("unable to get family {}".format(family_name)) def write_data(self, data: bytes): self.sock_fd.send(data) def read_data(self): while True: data = self.sock_fd.recv(65535) self._data += data if len(self._data) >= sizeof(Nlmsghdr): break def read_message(self) -> bytes: if len(self._data) < sizeof(Nlmsghdr): self.read_data() hdr = Nlmsghdr.from_buffer_copy(self._data) while hdr.nlmsg_len > len(self._data): self.read_data() raw_msg = self._data[: hdr.nlmsg_len] self._data = self._data[hdr.nlmsg_len:] return self.parse_message(raw_msg) def get_reply(self, tx_msg): self.write_message(tx_msg) while True: rx_msg = self.read_message() if tx_msg.nl_hdr.nlmsg_seq == rx_msg.nl_hdr.nlmsg_seq: return rx_msg class NetlinkMultipartIterator(object): def __init__(self, obj, seq_number: int, msg_type): self._obj = obj self._seq = seq_number self._msg_type = msg_type def __iter__(self): return self def __next__(self): msg = self._obj.read_message() if self._seq != msg.nl_hdr.nlmsg_seq: raise ValueError("bad sequence number") if msg.is_type(NlMsgType.NLMSG_ERROR): raise ValueError( "error while handling multipart msg: {}".format(msg.error_code) ) elif msg.is_type(NlMsgType.NLMSG_DONE): if msg.error_code == 0: raise StopIteration raise ValueError( "error listing some parts of the multipart msg: {}".format( msg.error_code ) ) elif not msg.is_type(self._msg_type): raise ValueError("bad message type: {}".format(msg)) return msg class NetlinkTestTemplate(object): REQUIRED_MODULES = ["netlink"] def setup_netlink(self, netlink_family: NlConst): self.helper = NlHelper() self.nlsock = Nlsock(netlink_family, self.helper) def write_message(self, msg, silent=False): if not silent: print("") print("============= >> TX MESSAGE =============") msg.print_message() msg.print_as_bytes(bytes(msg), "-- DATA --") self.nlsock.write_data(bytes(msg)) def read_message(self, silent=False): msg = self.nlsock.read_message() if not silent: print("") print("============= << RX MESSAGE =============") msg.print_message() return msg def get_reply(self, tx_msg): self.write_message(tx_msg) while True: rx_msg = self.read_message() if tx_msg.nl_hdr.nlmsg_seq == rx_msg.nl_hdr.nlmsg_seq: return rx_msg def read_msg_list(self, seq, msg_type): return list(NetlinkMultipartIterator(self, seq, msg_type)) diff --git a/tests/atf_python/sys/netlink/netlink_generic.py b/tests/atf_python/sys/netlink/netlink_generic.py index b49a30c1e8e7..80c6eea72a93 100644 --- a/tests/atf_python/sys/netlink/netlink_generic.py +++ b/tests/atf_python/sys/netlink/netlink_generic.py @@ -1,279 +1,312 @@ #!/usr/local/bin/python3 import struct from ctypes import c_int64 from ctypes import c_long from ctypes import sizeof from ctypes import Structure from enum import Enum from atf_python.sys.netlink.attrs import NlAttr from atf_python.sys.netlink.attrs import NlAttrIp4 from atf_python.sys.netlink.attrs import NlAttrIp6 +from atf_python.sys.netlink.attrs import NlAttrNested from atf_python.sys.netlink.attrs import NlAttrS32 from atf_python.sys.netlink.attrs import NlAttrStr from atf_python.sys.netlink.attrs import NlAttrU16 from atf_python.sys.netlink.attrs import NlAttrU32 from atf_python.sys.netlink.attrs import NlAttrU8 from atf_python.sys.netlink.base_headers import GenlMsgHdr from atf_python.sys.netlink.message import NlMsgCategory from atf_python.sys.netlink.message import NlMsgProps from atf_python.sys.netlink.message import StdNetlinkMessage from atf_python.sys.netlink.utils import AttrDescr from atf_python.sys.netlink.utils import enum_or_int from atf_python.sys.netlink.utils import prepare_attrs_map class NetlinkGenlMessage(StdNetlinkMessage): messages = [] nl_attrs_map = {} family_name = None def __init__(self, helper, family_id, cmd=0): super().__init__(helper, family_id) self.base_hdr = GenlMsgHdr(cmd=enum_or_int(cmd)) def parse_base_header(self, data): if len(data) < sizeof(GenlMsgHdr): raise ValueError("length less than GenlMsgHdr header") ghdr = GenlMsgHdr.from_buffer_copy(data) return (ghdr, sizeof(GenlMsgHdr)) def _get_msg_type(self): return self.base_hdr.cmd def print_nl_header(self, prepend=""): # len=44, type=RTM_DELROUTE, flags=NLM_F_REQUEST|NLM_F_ACK, seq=1641163704, pid=0 # noqa: E501 hdr = self.nl_hdr print( "{}len={}, family={}, flags={}(0x{:X}), seq={}, pid={}".format( prepend, hdr.nlmsg_len, self.family_name, self.get_nlm_flags_str(), hdr.nlmsg_flags, hdr.nlmsg_seq, hdr.nlmsg_pid, ) ) def print_base_header(self, hdr, prepend=""): print( "{}cmd={} version={} reserved={}".format( prepend, self.msg_name, hdr.version, hdr.reserved ) ) GenlCtrlFamilyName = "nlctrl" class GenlCtrlMsgType(Enum): CTRL_CMD_UNSPEC = 0 CTRL_CMD_NEWFAMILY = 1 CTRL_CMD_DELFAMILY = 2 CTRL_CMD_GETFAMILY = 3 CTRL_CMD_NEWOPS = 4 CTRL_CMD_DELOPS = 5 CTRL_CMD_GETOPS = 6 CTRL_CMD_NEWMCAST_GRP = 7 CTRL_CMD_DELMCAST_GRP = 8 CTRL_CMD_GETMCAST_GRP = 9 CTRL_CMD_GETPOLICY = 10 class GenlCtrlAttrType(Enum): CTRL_ATTR_FAMILY_ID = 1 CTRL_ATTR_FAMILY_NAME = 2 CTRL_ATTR_VERSION = 3 CTRL_ATTR_HDRSIZE = 4 CTRL_ATTR_MAXATTR = 5 CTRL_ATTR_OPS = 6 CTRL_ATTR_MCAST_GROUPS = 7 CTRL_ATTR_POLICY = 8 CTRL_ATTR_OP_POLICY = 9 CTRL_ATTR_OP = 10 +class GenlCtrlAttrOpType(Enum): + CTRL_ATTR_OP_ID = 1 + CTRL_ATTR_OP_FLAGS = 2 + + +class GenlCtrlAttrMcastGroupsType(Enum): + CTRL_ATTR_MCAST_GRP_NAME = 1 + CTRL_ATTR_MCAST_GRP_ID = 2 + + genl_ctrl_attrs = prepare_attrs_map( [ AttrDescr(GenlCtrlAttrType.CTRL_ATTR_FAMILY_ID, NlAttrU16), AttrDescr(GenlCtrlAttrType.CTRL_ATTR_FAMILY_NAME, NlAttrStr), AttrDescr(GenlCtrlAttrType.CTRL_ATTR_VERSION, NlAttrU32), AttrDescr(GenlCtrlAttrType.CTRL_ATTR_HDRSIZE, NlAttrU32), AttrDescr(GenlCtrlAttrType.CTRL_ATTR_MAXATTR, NlAttrU32), + AttrDescr( + GenlCtrlAttrType.CTRL_ATTR_OPS, + NlAttrNested, + [ + AttrDescr(GenlCtrlAttrOpType.CTRL_ATTR_OP_ID, NlAttrU32), + AttrDescr(GenlCtrlAttrOpType.CTRL_ATTR_OP_FLAGS, NlAttrU32), + ], + True, + ), + AttrDescr( + GenlCtrlAttrType.CTRL_ATTR_MCAST_GROUPS, + NlAttrNested, + [ + AttrDescr( + GenlCtrlAttrMcastGroupsType.CTRL_ATTR_MCAST_GRP_NAME, NlAttrStr + ), + AttrDescr( + GenlCtrlAttrMcastGroupsType.CTRL_ATTR_MCAST_GRP_ID, NlAttrU32 + ), + ], + True, + ), ] ) class NetlinkGenlCtrlMessage(NetlinkGenlMessage): messages = [ NlMsgProps(GenlCtrlMsgType.CTRL_CMD_NEWFAMILY, NlMsgCategory.NEW), NlMsgProps(GenlCtrlMsgType.CTRL_CMD_GETFAMILY, NlMsgCategory.GET), NlMsgProps(GenlCtrlMsgType.CTRL_CMD_DELFAMILY, NlMsgCategory.DELETE), ] nl_attrs_map = genl_ctrl_attrs family_name = GenlCtrlFamilyName CarpFamilyName = "carp" class CarpMsgType(Enum): CARP_NL_CMD_UNSPEC = 0 CARP_NL_CMD_GET = 1 CARP_NL_CMD_SET = 2 class CarpAttrType(Enum): CARP_NL_UNSPEC = 0 CARP_NL_VHID = 1 CARP_NL_STATE = 2 CARP_NL_ADVBASE = 3 CARP_NL_ADVSKEW = 4 CARP_NL_KEY = 5 CARP_NL_IFINDEX = 6 CARP_NL_ADDR = 7 CARP_NL_ADDR6 = 8 CARP_NL_IFNAME = 9 carp_gen_attrs = prepare_attrs_map( [ AttrDescr(CarpAttrType.CARP_NL_VHID, NlAttrU32), AttrDescr(CarpAttrType.CARP_NL_STATE, NlAttrU32), AttrDescr(CarpAttrType.CARP_NL_ADVBASE, NlAttrS32), AttrDescr(CarpAttrType.CARP_NL_ADVSKEW, NlAttrS32), AttrDescr(CarpAttrType.CARP_NL_KEY, NlAttr), AttrDescr(CarpAttrType.CARP_NL_IFINDEX, NlAttrU32), AttrDescr(CarpAttrType.CARP_NL_ADDR, NlAttrIp4), AttrDescr(CarpAttrType.CARP_NL_ADDR6, NlAttrIp6), AttrDescr(CarpAttrType.CARP_NL_IFNAME, NlAttrStr), ] ) class CarpGenMessage(NetlinkGenlMessage): messages = [ NlMsgProps(CarpMsgType.CARP_NL_CMD_GET, NlMsgCategory.GET), NlMsgProps(CarpMsgType.CARP_NL_CMD_SET, NlMsgCategory.NEW), ] nl_attrs_map = carp_gen_attrs family_name = CarpFamilyName KtestFamilyName = "ktest" class KtestMsgType(Enum): KTEST_CMD_UNSPEC = 0 KTEST_CMD_LIST = 1 KTEST_CMD_RUN = 2 KTEST_CMD_NEWTEST = 3 KTEST_CMD_NEWMESSAGE = 4 class KtestAttrType(Enum): KTEST_ATTR_MOD_NAME = 1 KTEST_ATTR_TEST_NAME = 2 KTEST_ATTR_TEST_DESCR = 3 KTEST_ATTR_TEST_META = 4 class KtestLogMsgType(Enum): KTEST_MSG_START = 1 KTEST_MSG_END = 2 KTEST_MSG_LOG = 3 KTEST_MSG_FAIL = 4 class KtestMsgAttrType(Enum): KTEST_MSG_ATTR_TS = 1 KTEST_MSG_ATTR_FUNC = 2 KTEST_MSG_ATTR_FILE = 3 KTEST_MSG_ATTR_LINE = 4 KTEST_MSG_ATTR_TEXT = 5 KTEST_MSG_ATTR_LEVEL = 6 KTEST_MSG_ATTR_META = 7 class timespec(Structure): _fields_ = [ ("tv_sec", c_int64), ("tv_nsec", c_long), ] class NlAttrTS(NlAttr): DATA_LEN = sizeof(timespec) def __init__(self, nla_type, val): self.ts = val super().__init__(nla_type, b"") @property def nla_len(self): return NlAttr.HDR_LEN + self.DATA_LEN def _print_attr_value(self): return " tv_sec={} tv_nsec={}".format(self.ts.tv_sec, self.ts.tv_nsec) @staticmethod def _validate(data): assert len(data) == NlAttr.HDR_LEN + NlAttrTS.DATA_LEN - nla_len, nla_type = struct.unpack("@HH", data[:NlAttr.HDR_LEN]) + nla_len, nla_type = struct.unpack("@HH", data[: NlAttr.HDR_LEN]) assert nla_len == NlAttr.HDR_LEN + NlAttrTS.DATA_LEN @classmethod def _parse(cls, data): - nla_len, nla_type = struct.unpack("@HH", data[:NlAttr.HDR_LEN]) - val = timespec.from_buffer_copy(data[NlAttr.HDR_LEN:]) + nla_len, nla_type = struct.unpack("@HH", data[: NlAttr.HDR_LEN]) + val = timespec.from_buffer_copy(data[NlAttr.HDR_LEN :]) return cls(nla_type, val) def __bytes__(self): return self._to_bytes(bytes(self.ts)) ktest_info_attrs = prepare_attrs_map( [ AttrDescr(KtestAttrType.KTEST_ATTR_MOD_NAME, NlAttrStr), AttrDescr(KtestAttrType.KTEST_ATTR_TEST_NAME, NlAttrStr), AttrDescr(KtestAttrType.KTEST_ATTR_TEST_DESCR, NlAttrStr), ] ) ktest_msg_attrs = prepare_attrs_map( [ AttrDescr(KtestMsgAttrType.KTEST_MSG_ATTR_FUNC, NlAttrStr), AttrDescr(KtestMsgAttrType.KTEST_MSG_ATTR_FILE, NlAttrStr), AttrDescr(KtestMsgAttrType.KTEST_MSG_ATTR_LINE, NlAttrU32), AttrDescr(KtestMsgAttrType.KTEST_MSG_ATTR_TEXT, NlAttrStr), AttrDescr(KtestMsgAttrType.KTEST_MSG_ATTR_LEVEL, NlAttrU8), AttrDescr(KtestMsgAttrType.KTEST_MSG_ATTR_TS, NlAttrTS), ] ) class KtestInfoMessage(NetlinkGenlMessage): messages = [ NlMsgProps(KtestMsgType.KTEST_CMD_LIST, NlMsgCategory.GET), NlMsgProps(KtestMsgType.KTEST_CMD_RUN, NlMsgCategory.NEW), NlMsgProps(KtestMsgType.KTEST_CMD_NEWTEST, NlMsgCategory.NEW), ] nl_attrs_map = ktest_info_attrs family_name = KtestFamilyName class KtestMsgMessage(NetlinkGenlMessage): messages = [ NlMsgProps(KtestMsgType.KTEST_CMD_NEWMESSAGE, NlMsgCategory.NEW), ] nl_attrs_map = ktest_msg_attrs family_name = KtestFamilyName handler_classes = { CarpFamilyName: [CarpGenMessage], GenlCtrlFamilyName: [NetlinkGenlCtrlMessage], KtestFamilyName: [KtestInfoMessage, KtestMsgMessage], } diff --git a/tests/atf_python/sys/netlink/utils.py b/tests/atf_python/sys/netlink/utils.py index 7a41791b5318..f1d0ba3321ed 100644 --- a/tests/atf_python/sys/netlink/utils.py +++ b/tests/atf_python/sys/netlink/utils.py @@ -1,78 +1,80 @@ #!/usr/local/bin/python3 from enum import Enum from typing import Any from typing import Dict from typing import List from typing import NamedTuple class NlConst: AF_NETLINK = 38 NETLINK_ROUTE = 0 NETLINK_GENERIC = 16 GENL_ID_CTRL = 16 def roundup2(val: int, num: int) -> int: if val % num: return (val | (num - 1)) + 1 else: return val def align4(val: int) -> int: return roundup2(val, 4) def enum_or_int(val) -> int: if isinstance(val, Enum): return val.value return val class AttrDescr(NamedTuple): val: Enum cls: "NlAttr" child_map: Any = None + is_array: bool = False def prepare_attrs_map(attrs: List[AttrDescr]) -> Dict[str, Dict]: ret = {} for ad in attrs: ret[ad.val.value] = {"ad": ad} if ad.child_map: ret[ad.val.value]["child"] = prepare_attrs_map(ad.child_map) + ret[ad.val.value]["is_array"] = ad.is_array return ret def build_propmap(cls): ret = {} for prop in dir(cls): if not prop.startswith("_"): ret[getattr(cls, prop).value] = prop return ret def get_bitmask_map(propmap, val): v = 1 ret = {} while val: if v & val: if v in propmap: ret[v] = propmap[v] else: ret[v] = hex(v) val -= v v *= 2 return ret def get_bitmask_str(cls, val): if isinstance(cls, type): pmap = build_propmap(cls) else: pmap = {} for _cls in cls: pmap.update(build_propmap(_cls)) bmap = get_bitmask_map(pmap, val) return ",".join([v for k, v in bmap.items()])