diff --git a/tests/atf_python/sys/netlink/Makefile b/tests/atf_python/sys/netlink/Makefile index 057415cf87f6..73ce5ac50261 100644 --- a/tests/atf_python/sys/netlink/Makefile +++ b/tests/atf_python/sys/netlink/Makefile @@ -1,11 +1,12 @@ .include .PATH: ${.CURDIR} -FILES= __init__.py attrs.py base_headers.py message.py netlink.py netlink_route.py utils.py +FILES= __init__.py attrs.py base_headers.py message.py netlink.py \ + netlink_generic.py netlink_route.py utils.py .include FILESDIR= ${TESTSBASE}/atf_python/sys/netlink .include diff --git a/tests/atf_python/sys/netlink/message.py b/tests/atf_python/sys/netlink/message.py index 6bc9f2932868..b6fb2f8e357a 100644 --- a/tests/atf_python/sys/netlink/message.py +++ b/tests/atf_python/sys/netlink/message.py @@ -1,201 +1,261 @@ #!/usr/local/bin/python3 import struct from ctypes import sizeof +from enum import Enum from typing import List +from typing import NamedTuple from atf_python.sys.netlink.attrs import NlAttr from atf_python.sys.netlink.attrs import NlAttrNested +from atf_python.sys.netlink.base_headers import NlmAckFlags +from atf_python.sys.netlink.base_headers import NlmNewFlags +from atf_python.sys.netlink.base_headers import NlmGetFlags +from atf_python.sys.netlink.base_headers import NlmDeleteFlags from atf_python.sys.netlink.base_headers import NlmBaseFlags from atf_python.sys.netlink.base_headers import Nlmsghdr from atf_python.sys.netlink.base_headers import NlMsgType from atf_python.sys.netlink.utils import align4 from atf_python.sys.netlink.utils import enum_or_int +from atf_python.sys.netlink.utils import get_bitmask_str + + +class NlMsgCategory(Enum): + UNKNOWN = 0 + GET = 1 + NEW = 2 + DELETE = 3 + ACK = 4 + + +class NlMsgProps(NamedTuple): + msg: Enum + category: NlMsgCategory class BaseNetlinkMessage(object): def __init__(self, helper, nlmsg_type): self.nlmsg_type = enum_or_int(nlmsg_type) self.nla_list = [] self._orig_data = None self.helper = helper self.nl_hdr = Nlmsghdr( nlmsg_type=self.nlmsg_type, nlmsg_seq=helper.get_seq(), nlmsg_pid=helper.pid ) self.base_hdr = None def set_request(self, need_ack=True): self.add_nlflags([NlmBaseFlags.NLM_F_REQUEST]) if need_ack: self.add_nlflags([NlmBaseFlags.NLM_F_ACK]) def add_nlflags(self, flags: List): int_flags = 0 for flag in flags: int_flags |= enum_or_int(flag) self.nl_hdr.nlmsg_flags |= int_flags def add_nla(self, nla): self.nla_list.append(nla) def _get_nla(self, nla_list, nla_type): nla_type_raw = enum_or_int(nla_type) for nla in nla_list: if nla.nla_type == nla_type_raw: return nla return None def get_nla(self, nla_type): return self._get_nla(self.nla_list, nla_type) @staticmethod def parse_nl_header(data: bytes): if len(data) < sizeof(Nlmsghdr): raise ValueError("length less than netlink message header") return Nlmsghdr.from_buffer_copy(data), sizeof(Nlmsghdr) def is_type(self, nlmsg_type): nlmsg_type_raw = enum_or_int(nlmsg_type) return nlmsg_type_raw == self.nl_hdr.nlmsg_type def is_reply(self, hdr): return hdr.nlmsg_type == NlMsgType.NLMSG_ERROR.value - def print_nl_header(self, hdr, prepend=""): + @property + def msg_name(self): + return "msg#{}".format(self._get_msg_type()) + + def _get_nl_category(self): + if self.is_reply(self.nl_hdr): + return NlMsgCategory.ACK + return NlMsgCategory.UNKNOWN + + def get_nlm_flags_str(self): + category = self._get_nl_category() + flags = self.nl_hdr.nlmsg_flags + + if category == NlMsgCategory.UNKNOWN: + return self.helper.get_bitmask_str(NlmBaseFlags, flags) + elif category == NlMsgCategory.GET: + flags_enum = NlmGetFlags + elif category == NlMsgCategory.NEW: + flags_enum = NlmNewFlags + elif category == NlMsgCategory.DELETE: + flags_enum = NlmDeleteFlags + elif category == NlMsgCategory.ACK: + flags_enum = NlmAckFlags + return get_bitmask_str([NlmBaseFlags, flags_enum], flags) + + def print_nl_header(self, prepend=""): # len=44, type=RTM_DELROUTE, flags=NLM_F_REQUEST|NLM_F_ACK, seq=1641163704, pid=0 # noqa: E501 - is_reply = self.is_reply(hdr) - msg_name = self.helper.get_nlmsg_name(hdr.nlmsg_type) + hdr = self.nl_hdr print( "{}len={}, type={}, flags={}(0x{:X}), seq={}, pid={}".format( prepend, hdr.nlmsg_len, - msg_name, - self.helper.get_nlm_flags_str( - msg_name, is_reply, hdr.nlmsg_flags - ), # noqa: E501 + self.msg_name, + self.get_nlm_flags_str(), hdr.nlmsg_flags, hdr.nlmsg_seq, hdr.nlmsg_pid, ) ) @classmethod def from_bytes(cls, helper, data): try: hdr, hdrlen = BaseNetlinkMessage.parse_nl_header(data) self = cls(helper, hdr.nlmsg_type) self._orig_data = data self.nl_hdr = hdr except ValueError as e: print("Failed to parse nl header: {}".format(e)) cls.print_as_bytes(data) raise return self def print_message(self): - self.print_nl_header(self.nl_hdr) + self.print_nl_header() @staticmethod def print_as_bytes(data: bytes, descr: str): print("===vv {} (len:{:3d}) vv===".format(descr, len(data))) off = 0 step = 16 while off < len(data): for i in range(step): if off + i < len(data): print(" {:02X}".format(data[off + i]), end="") print("") off += step print("--------------------") class StdNetlinkMessage(BaseNetlinkMessage): nl_attrs_map = {} @classmethod def from_bytes(cls, helper, data): try: hdr, hdrlen = BaseNetlinkMessage.parse_nl_header(data) self = cls(helper, hdr.nlmsg_type) self._orig_data = data self.nl_hdr = hdr except ValueError as e: print("Failed to parse nl header: {}".format(e)) cls.print_as_bytes(data) raise offset = align4(hdrlen) try: base_hdr, hdrlen = self.parse_base_header(data[offset:]) self.base_hdr = base_hdr offset += align4(hdrlen) # XXX: CAP_ACK except ValueError as e: print("Failed to parse nl rt header: {}".format(e)) cls.print_as_bytes(data) raise orig_offset = offset try: nla_list, nla_len = self.parse_nla_list(data[offset:]) offset += nla_len if offset != len(data): raise ValueError( "{} bytes left at the end of the packet".format(len(data) - offset) ) # noqa: E501 self.nla_list = nla_list except ValueError as e: print( "Failed to parse nla attributes at offset {}: {}".format(orig_offset, e) ) # noqa: E501 cls.print_as_bytes(data, "msg dump") cls.print_as_bytes(data[orig_offset:], "failed block") raise return self def parse_attrs(self, data: bytes, attr_map): ret = [] off = 0 while len(data) - off >= 4: nla_len, raw_nla_type = struct.unpack("@HH", data[off:off + 4]) if nla_len + off > len(data): raise ValueError( "attr length {} > than the remaining length {}".format( nla_len, len(data) - off ) ) nla_type = raw_nla_type & 0x3F if nla_type in attr_map: v = attr_map[nla_type] val = v["ad"].cls.from_bytes(data[off:off + nla_len], v["ad"].val) if "child" in v: # nested attrs, _ = self.parse_attrs( data[off + 4:off + nla_len], v["child"] ) val = NlAttrNested(v["ad"].val, attrs) else: # unknown attribute val = NlAttr(raw_nla_type, data[off + 4:off + nla_len]) ret.append(val) off += align4(nla_len) return ret, off def parse_nla_list(self, data: bytes) -> List[NlAttr]: return self.parse_attrs(data, self.nl_attrs_map) def __bytes__(self): ret = bytes() for nla in self.nla_list: ret += bytes(nla) ret = bytes(self.base_hdr) + ret self.nl_hdr.nlmsg_len = len(ret) + sizeof(Nlmsghdr) return bytes(self.nl_hdr) + ret + def _get_msg_type(self): + return self.nl_hdr.nlmsg_type + + @property + def msg_props(self): + msg_type = self._get_msg_type() + for msg_props in self.messages: + if msg_props.msg.value == msg_type: + return msg_props + return None + + @property + def msg_name(self): + msg_props = self.msg_props + if msg_props is not None: + return msg_props.msg.name + return super().msg_name + def print_base_header(self, hdr, prepend=""): pass def print_message(self): - self.print_nl_header(self.nl_hdr) + self.print_nl_header() self.print_base_header(self.base_hdr, " ") for nla in self.nla_list: nla.print_attr(" ") diff --git a/tests/atf_python/sys/netlink/netlink.py b/tests/atf_python/sys/netlink/netlink.py index e10efadee656..f813727d55b4 100644 --- a/tests/atf_python/sys/netlink/netlink.py +++ b/tests/atf_python/sys/netlink/netlink.py @@ -1,368 +1,407 @@ #!/usr/local/bin/python3 import os import socket import sys from ctypes import c_int from ctypes import c_ubyte from ctypes import c_uint from ctypes import c_ushort from ctypes import sizeof from ctypes import Structure from enum import auto from enum import Enum from atf_python.sys.netlink.attrs import NlAttr from atf_python.sys.netlink.attrs import NlAttrStr from atf_python.sys.netlink.attrs import NlAttrU32 -from atf_python.sys.netlink.base_headers import NlmAckFlags +from atf_python.sys.netlink.base_headers import GenlMsgHdr from atf_python.sys.netlink.base_headers import NlmBaseFlags -from atf_python.sys.netlink.base_headers import NlmDeleteFlags -from atf_python.sys.netlink.base_headers import NlmGetFlags -from atf_python.sys.netlink.base_headers import NlmNewFlags from atf_python.sys.netlink.base_headers import Nlmsghdr from atf_python.sys.netlink.base_headers import NlMsgType from atf_python.sys.netlink.message import BaseNetlinkMessage +from atf_python.sys.netlink.message import NlMsgCategory +from atf_python.sys.netlink.message import NlMsgProps from atf_python.sys.netlink.message import StdNetlinkMessage -from atf_python.sys.netlink.netlink_route import NetlinkIfaMessage -from atf_python.sys.netlink.netlink_route import NetlinkIflaMessage -from atf_python.sys.netlink.netlink_route import NetlinkNdMessage -from atf_python.sys.netlink.netlink_route import NetlinkRtMessage -from atf_python.sys.netlink.netlink_route import NlRtMsgType +from atf_python.sys.netlink.netlink_generic import GenlCtrlMsgType +from atf_python.sys.netlink.netlink_generic import GenlCtrlAttrType +from atf_python.sys.netlink.netlink_generic import handler_classes as genl_classes +from atf_python.sys.netlink.netlink_route import handler_classes as rt_classes from atf_python.sys.netlink.utils import align4 from atf_python.sys.netlink.utils import AttrDescr from atf_python.sys.netlink.utils import build_propmap +from atf_python.sys.netlink.utils import enum_or_int from atf_python.sys.netlink.utils import get_bitmask_map from atf_python.sys.netlink.utils import NlConst from atf_python.sys.netlink.utils import prepare_attrs_map class SockaddrNl(Structure): _fields_ = [ ("nl_len", c_ubyte), ("nl_family", c_ubyte), ("nl_pad", c_ushort), ("nl_pid", c_uint), ("nl_groups", c_uint), ] class Nlmsgdone(Structure): _fields_ = [ ("error", c_int), ] class Nlmsgerr(Structure): _fields_ = [ ("error", c_int), ("msg", Nlmsghdr), ] class NlErrattrType(Enum): NLMSGERR_ATTR_UNUSED = 0 NLMSGERR_ATTR_MSG = auto() NLMSGERR_ATTR_OFFS = auto() NLMSGERR_ATTR_COOKIE = auto() NLMSGERR_ATTR_POLICY = auto() class AddressFamilyLinux(Enum): AF_INET = socket.AF_INET AF_INET6 = socket.AF_INET6 AF_NETLINK = 16 class AddressFamilyBsd(Enum): AF_INET = socket.AF_INET AF_INET6 = socket.AF_INET6 AF_NETLINK = 38 class NlHelper: def __init__(self): self._pmap = {} self._af_cls = self.get_af_cls() self._seq_counter = 1 self.pid = os.getpid() def get_seq(self): ret = self._seq_counter self._seq_counter += 1 return ret def get_af_cls(self): if sys.platform.startswith("freebsd"): cls = AddressFamilyBsd else: cls = AddressFamilyLinux return cls def get_propmap(self, cls): if cls not in self._pmap: self._pmap[cls] = build_propmap(cls) return self._pmap[cls] def get_name_propmap(self, cls): ret = {} for prop in dir(cls): if not prop.startswith("_"): ret[prop] = getattr(cls, prop).value return ret def get_attr_byval(self, cls, attr_val): propmap = self.get_propmap(cls) return propmap.get(attr_val) - def get_nlmsg_name(self, val): - for cls in [NlRtMsgType, NlMsgType]: - v = self.get_attr_byval(cls, val) - if v is not None: - return v - return "msg#{}".format(val) - def get_af_name(self, family): v = self.get_attr_byval(self._af_cls, family) if v is not None: return v return "af#{}".format(family) def get_af_value(self, family_str: str) -> int: propmap = self.get_name_propmap(self._af_cls) return propmap.get(family_str) def get_bitmask_str(self, cls, val): bmap = get_bitmask_map(self.get_propmap(cls), val) return ",".join([v for k, v in bmap.items()]) @staticmethod def get_bitmask_str_uncached(cls, val): pmap = NlHelper.build_propmap(cls) bmap = NlHelper.get_bitmask_map(pmap, val) return ",".join([v for k, v in bmap.items()]) - def get_nlm_flags_str(self, msg_str: str, reply: bool, val): - if reply: - return self.get_bitmask_str(NlmAckFlags, val) - if msg_str.startswith("RTM_GET"): - return self.get_bitmask_str(NlmGetFlags, val) - elif msg_str.startswith("RTM_DEL"): - return self.get_bitmask_str(NlmDeleteFlags, val) - elif msg_str.startswith("RTM_NEW"): - return self.get_bitmask_str(NlmNewFlags, val) - else: - return self.get_bitmask_str(NlmBaseFlags, val) - nldone_attrs = prepare_attrs_map([]) nlerr_attrs = prepare_attrs_map( [ AttrDescr(NlErrattrType.NLMSGERR_ATTR_MSG, NlAttrStr), AttrDescr(NlErrattrType.NLMSGERR_ATTR_OFFS, NlAttrU32), AttrDescr(NlErrattrType.NLMSGERR_ATTR_COOKIE, NlAttr), ] ) class NetlinkDoneMessage(StdNetlinkMessage): - messages = [NlMsgType.NLMSG_DONE.value] + messages = [NlMsgProps(NlMsgType.NLMSG_DONE, NlMsgCategory.ACK)] nl_attrs_map = nldone_attrs @property def error_code(self): return self.base_hdr.error def parse_base_header(self, data): if len(data) < sizeof(Nlmsgdone): raise ValueError("length less than nlmsgdone header") done_hdr = Nlmsgdone.from_buffer_copy(data) sz = sizeof(Nlmsgdone) return (done_hdr, sz) def print_base_header(self, hdr, prepend=""): print("{}error={}".format(prepend, hdr.error)) class NetlinkErrorMessage(StdNetlinkMessage): - messages = [NlMsgType.NLMSG_ERROR.value] + messages = [NlMsgProps(NlMsgType.NLMSG_ERROR, NlMsgCategory.ACK)] nl_attrs_map = nlerr_attrs @property def error_code(self): return self.base_hdr.error @property def error_str(self): nla = self.get_nla(NlErrattrType.NLMSGERR_ATTR_MSG) if nla: return nla.text return None @property def error_offset(self): nla = self.get_nla(NlErrattrType.NLMSGERR_ATTR_OFFS) if nla: return nla.u32 return None @property def cookie(self): return self.get_nla(NlErrattrType.NLMSGERR_ATTR_COOKIE) def parse_base_header(self, data): if len(data) < sizeof(Nlmsgerr): raise ValueError("length less than nlmsgerr header") err_hdr = Nlmsgerr.from_buffer_copy(data) sz = sizeof(Nlmsgerr) if (self.nl_hdr.nlmsg_flags & 0x100) == 0: sz += align4(err_hdr.msg.nlmsg_len - sizeof(Nlmsghdr)) return (err_hdr, sz) def print_base_header(self, errhdr, prepend=""): print("{}error={}, ".format(prepend, errhdr.error), end="") - self.print_nl_header(errhdr.msg, prepend) + hdr = errhdr.msg + print( + "{}len={}, type={}, flags={}(0x{:X}), seq={}, pid={}".format( + prepend, + hdr.nlmsg_len, + "msg#{}".format(hdr.nlmsg_type), + self.helper.get_bitmask_str(NlmBaseFlags, hdr.nlmsg_flags), + hdr.nlmsg_flags, + hdr.nlmsg_seq, + hdr.nlmsg_pid, + ) + ) + + +core_classes = { + "netlink_core": [ + NetlinkDoneMessage, + NetlinkErrorMessage, + ], +} class Nlsock: + HANDLER_CLASSES = [core_classes, rt_classes, genl_classes] + def __init__(self, family, helper): self.helper = helper self.sock_fd = self._setup_netlink(family) + self._sock_family = family self._data = bytes() self.msgmap = self.build_msgmap() - # self.set_groups(NlRtGroup.RTNLGRP_IPV4_ROUTE.value | NlRtGroup.RTNLGRP_IPV6_ROUTE.value) # noqa: E501 + self._family_map = { + NlConst.GENL_ID_CTRL: "nlctrl", + } def build_msgmap(self): - classes = [ - NetlinkRtMessage, - NetlinkIflaMessage, - NetlinkIfaMessage, - NetlinkNdMessage, - NetlinkDoneMessage, - NetlinkErrorMessage, - ] + handler_classes = {} + for d in self.HANDLER_CLASSES: + handler_classes.update(d) xmap = {} - for cls in classes: - for message in cls.messages: - xmap[message] = cls + # 'family_name': [class.messages[MsgProps.msg], ] + for family_id, family_classes in handler_classes.items(): + xmap[family_id] = {} + for cls in family_classes: + for msg_props in cls.messages: + xmap[family_id][enum_or_int(msg_props.msg)] = cls return xmap def _setup_netlink(self, netlink_family) -> int: family = self.helper.get_af_value("AF_NETLINK") s = socket.socket(family, socket.SOCK_RAW, netlink_family) s.setsockopt(270, 10, 1) # NETLINK_CAP_ACK s.setsockopt(270, 11, 1) # NETLINK_EXT_ACK return s def set_groups(self, mask: int): self.sock_fd.setsockopt(socket.SOL_SOCKET, 1, mask) # snl = SockaddrNl(nl_len = sizeof(SockaddrNl), nl_family=38, # nl_pid=self.pid, nl_groups=mask) # xbuffer = create_string_buffer(sizeof(SockaddrNl)) # memmove(xbuffer, addressof(snl), sizeof(SockaddrNl)) # k = struct.pack("@BBHII", 12, 38, 0, self.pid, mask) # self.sock_fd.bind(k) - def write_message(self, msg): - print("vvvvvvvv OUT vvvvvvvv") - msg.print_message() + def write_message(self, msg, verbose=True): + if verbose: + print("vvvvvvvv OUT vvvvvvvv") + msg.print_message() msg_bytes = bytes(msg) try: ret = os.write(self.sock_fd.fileno(), msg_bytes) assert ret == len(msg_bytes) except Exception as e: print("write({}) -> {}".format(len(msg_bytes), e)) def parse_message(self, data: bytes): if len(data) < sizeof(Nlmsghdr): raise Exception("Short read from nl: {} bytes".format(len(data))) hdr = Nlmsghdr.from_buffer_copy(data) - nlmsg_type = hdr.nlmsg_type - cls = self.msgmap.get(nlmsg_type) + if hdr.nlmsg_type < 16: + family_name = "netlink_core" + nlmsg_type = hdr.nlmsg_type + elif self._sock_family == NlConst.NETLINK_ROUTE: + family_name = "netlink_route" + nlmsg_type = hdr.nlmsg_type + else: + # Genetlink + if len(data) < sizeof(Nlmsghdr) + sizeof(GenlMsgHdr): + raise Exception("Short read from genl: {} bytes".format(len(data))) + family_name = self._family_map.get(hdr.nlmsg_type, "") + ghdr = GenlMsgHdr.from_buffer_copy(data[sizeof(Nlmsghdr):]) + nlmsg_type = ghdr.cmd + cls = self.msgmap.get(family_name, {}).get(nlmsg_type) if not cls: cls = BaseNetlinkMessage return cls.from_bytes(self.helper, data) + def get_genl_family_id(self, family_name): + hdr = Nlmsghdr( + nlmsg_type=NlConst.GENL_ID_CTRL, + nlmsg_flags=NlmBaseFlags.NLM_F_REQUEST.value, + nlmsg_seq = self.helper.get_seq(), + ) + ghdr = GenlMsgHdr(cmd=GenlCtrlMsgType.CTRL_CMD_GETFAMILY.value) + nla = NlAttrStr(GenlCtrlAttrType.CTRL_ATTR_FAMILY_NAME, family_name) + hdr.nlmsg_len = sizeof(Nlmsghdr) + sizeof(GenlMsgHdr) + len(bytes(nla)) + + msg_bytes = bytes(hdr) + bytes(ghdr) + bytes(nla) + self.write_data(msg_bytes) + while True: + rx_msg = self.read_message() + if hdr.nlmsg_seq == rx_msg.nl_hdr.nlmsg_seq: + if rx_msg.is_type(NlMsgType.NLMSG_ERROR): + if rx_msg.error_code != 0: + raise ValueError("unable to get family {}".format(family_name)) + else: + family_id = rx_msg.get_nla(GenlCtrlAttrType.CTRL_ATTR_FAMILY_ID).u16 + self._family_map[family_id] = family_name + return family_id + raise ValueError("unable to get family {}".format(family_name)) + def write_data(self, data: bytes): self.sock_fd.send(data) def read_data(self): while True: data = self.sock_fd.recv(65535) self._data += data if len(self._data) >= sizeof(Nlmsghdr): break def read_message(self) -> bytes: if len(self._data) < sizeof(Nlmsghdr): self.read_data() hdr = Nlmsghdr.from_buffer_copy(self._data) while hdr.nlmsg_len > len(self._data): self.read_data() raw_msg = self._data[: hdr.nlmsg_len] self._data = self._data[hdr.nlmsg_len:] return self.parse_message(raw_msg) class NetlinkMultipartIterator(object): def __init__(self, obj, seq_number: int, msg_type): self._obj = obj self._seq = seq_number self._msg_type = msg_type def __iter__(self): return self def __next__(self): msg = self._obj.read_message() if self._seq != msg.nl_hdr.nlmsg_seq: raise ValueError("bad sequence number") if msg.is_type(NlMsgType.NLMSG_ERROR): raise ValueError( "error while handling multipart msg: {}".format(msg.error_code) ) elif msg.is_type(NlMsgType.NLMSG_DONE): if msg.error_code == 0: raise StopIteration raise ValueError( "error listing some parts of the multipart msg: {}".format( msg.error_code ) ) elif not msg.is_type(self._msg_type): raise ValueError("bad message type: {}".format(msg)) return msg class NetlinkTestTemplate(object): REQUIRED_MODULES = ["netlink"] def setup_netlink(self, netlink_family: NlConst): self.helper = NlHelper() self.nlsock = Nlsock(netlink_family, self.helper) def write_message(self, msg, silent=False): if not silent: print("") print("============= >> TX MESSAGE =============") msg.print_message() msg.print_as_bytes(bytes(msg), "-- DATA --") self.nlsock.write_data(bytes(msg)) def read_message(self, silent=False): msg = self.nlsock.read_message() if not silent: print("") print("============= << RX MESSAGE =============") msg.print_message() return msg def get_reply(self, tx_msg): self.write_message(tx_msg) while True: rx_msg = self.read_message() if tx_msg.nl_hdr.nlmsg_seq == rx_msg.nl_hdr.nlmsg_seq: return rx_msg def read_msg_list(self, seq, msg_type): return list(NetlinkMultipartIterator(self, seq, msg_type)) diff --git a/tests/atf_python/sys/netlink/netlink_generic.py b/tests/atf_python/sys/netlink/netlink_generic.py new file mode 100644 index 000000000000..ee75d5bf37f3 --- /dev/null +++ b/tests/atf_python/sys/netlink/netlink_generic.py @@ -0,0 +1,110 @@ +#!/usr/local/bin/python3 +from ctypes import sizeof +from enum import Enum + +from atf_python.sys.netlink.attrs import NlAttrStr +from atf_python.sys.netlink.attrs import NlAttrU16 +from atf_python.sys.netlink.attrs import NlAttrU32 +from atf_python.sys.netlink.base_headers import GenlMsgHdr +from atf_python.sys.netlink.message import NlMsgCategory +from atf_python.sys.netlink.message import NlMsgProps +from atf_python.sys.netlink.message import StdNetlinkMessage +from atf_python.sys.netlink.utils import AttrDescr +from atf_python.sys.netlink.utils import prepare_attrs_map +from atf_python.sys.netlink.utils import enum_or_int + + +class NetlinkGenlMessage(StdNetlinkMessage): + messages = [] + nl_attrs_map = {} + family_name = None + + def __init__(self, helper, family_id, cmd=0): + super().__init__(helper, family_id) + self.base_hdr = GenlMsgHdr(cmd=enum_or_int(cmd)) + + def parse_base_header(self, data): + if len(data) < sizeof(GenlMsgHdr): + raise ValueError("length less than GenlMsgHdr header") + ghdr = GenlMsgHdr.from_buffer_copy(data) + return (ghdr, sizeof(GenlMsgHdr)) + + def _get_msg_type(self): + return self.base_hdr.cmd + + def print_nl_header(self, prepend=""): + # len=44, type=RTM_DELROUTE, flags=NLM_F_REQUEST|NLM_F_ACK, seq=1641163704, pid=0 # noqa: E501 + hdr = self.nl_hdr + print( + "{}len={}, family={}, flags={}(0x{:X}), seq={}, pid={}".format( + prepend, + hdr.nlmsg_len, + self.family_name, + self.get_nlm_flags_str(), + hdr.nlmsg_flags, + hdr.nlmsg_seq, + hdr.nlmsg_pid, + ) + ) + + def print_base_header(self, hdr, prepend=""): + print( + "{}cmd={} version={} reserved={}".format( + prepend, self.msg_name, hdr.version, hdr.reserved + ) + ) + + +GenlCtrlFamilyName = "nlctrl" + +class GenlCtrlMsgType(Enum): + CTRL_CMD_UNSPEC = 0 + CTRL_CMD_NEWFAMILY = 1 + CTRL_CMD_DELFAMILY = 2 + CTRL_CMD_GETFAMILY = 3 + CTRL_CMD_NEWOPS = 4 + CTRL_CMD_DELOPS = 5 + CTRL_CMD_GETOPS = 6 + CTRL_CMD_NEWMCAST_GRP = 7 + CTRL_CMD_DELMCAST_GRP = 8 + CTRL_CMD_GETMCAST_GRP = 9 + CTRL_CMD_GETPOLICY = 10 + + +class GenlCtrlAttrType(Enum): + CTRL_ATTR_FAMILY_ID = 1 + CTRL_ATTR_FAMILY_NAME = 2 + CTRL_ATTR_VERSION = 3 + CTRL_ATTR_HDRSIZE = 4 + CTRL_ATTR_MAXATTR = 5 + CTRL_ATTR_OPS = 6 + CTRL_ATTR_MCAST_GROUPS = 7 + CTRL_ATTR_POLICY = 8 + CTRL_ATTR_OP_POLICY = 9 + CTRL_ATTR_OP = 10 + + +genl_ctrl_attrs = prepare_attrs_map( + [ + AttrDescr(GenlCtrlAttrType.CTRL_ATTR_FAMILY_ID, NlAttrU16), + AttrDescr(GenlCtrlAttrType.CTRL_ATTR_FAMILY_NAME, NlAttrStr), + AttrDescr(GenlCtrlAttrType.CTRL_ATTR_VERSION, NlAttrU32), + AttrDescr(GenlCtrlAttrType.CTRL_ATTR_HDRSIZE, NlAttrU32), + AttrDescr(GenlCtrlAttrType.CTRL_ATTR_MAXATTR, NlAttrU32), + ] +) + + +class NetlinkGenlCtrlMessage(NetlinkGenlMessage): + messages = [ + NlMsgProps(GenlCtrlMsgType.CTRL_CMD_NEWFAMILY, NlMsgCategory.NEW), + NlMsgProps(GenlCtrlMsgType.CTRL_CMD_GETFAMILY, NlMsgCategory.GET), + NlMsgProps(GenlCtrlMsgType.CTRL_CMD_DELFAMILY, NlMsgCategory.DELETE), + ] + nl_attrs_map = genl_ctrl_attrs + family_name = GenlCtrlFamilyName + + +handler_classes = { + GenlCtrlFamilyName: [NetlinkGenlCtrlMessage], +} diff --git a/tests/atf_python/sys/netlink/netlink_route.py b/tests/atf_python/sys/netlink/netlink_route.py index c6163a0908af..81f4e89d3e57 100644 --- a/tests/atf_python/sys/netlink/netlink_route.py +++ b/tests/atf_python/sys/netlink/netlink_route.py @@ -1,727 +1,739 @@ import socket import struct from ctypes import c_int from ctypes import c_ubyte from ctypes import c_uint from ctypes import c_ushort from ctypes import sizeof from ctypes import Structure from enum import auto from enum import Enum from atf_python.sys.netlink.attrs import NlAttr from atf_python.sys.netlink.attrs import NlAttrIp from atf_python.sys.netlink.attrs import NlAttrNested from atf_python.sys.netlink.attrs import NlAttrStr from atf_python.sys.netlink.attrs import NlAttrU32 from atf_python.sys.netlink.attrs import NlAttrU8 from atf_python.sys.netlink.message import StdNetlinkMessage +from atf_python.sys.netlink.message import NlMsgProps +from atf_python.sys.netlink.message import NlMsgCategory from atf_python.sys.netlink.utils import AttrDescr from atf_python.sys.netlink.utils import get_bitmask_str from atf_python.sys.netlink.utils import prepare_attrs_map class RtattrType(Enum): RTA_UNSPEC = 0 RTA_DST = 1 RTA_SRC = 2 RTA_IIF = 3 RTA_OIF = 4 RTA_GATEWAY = 5 RTA_PRIORITY = 6 RTA_PREFSRC = 7 RTA_METRICS = 8 RTA_MULTIPATH = 9 # RTA_PROTOINFO = 10 RTA_KNH_ID = 10 RTA_FLOW = 11 RTA_CACHEINFO = 12 RTA_SESSION = 13 # RTA_MP_ALGO = 14 RTA_RTFLAGS = 14 RTA_TABLE = 15 RTA_MARK = 16 RTA_MFC_STATS = 17 RTA_VIA = 18 RTA_NEWDST = 19 RTA_PREF = 20 RTA_ENCAP_TYPE = 21 RTA_ENCAP = 22 RTA_EXPIRES = 23 RTA_PAD = 24 RTA_UID = 25 RTA_TTL_PROPAGATE = 26 RTA_IP_PROTO = 27 RTA_SPORT = 28 RTA_DPORT = 29 RTA_NH_ID = 30 class NlRtMsgType(Enum): RTM_NEWLINK = 16 RTM_DELLINK = 17 RTM_GETLINK = 18 RTM_SETLINK = 19 RTM_NEWADDR = 20 RTM_DELADDR = 21 RTM_GETADDR = 22 RTM_NEWROUTE = 24 RTM_DELROUTE = 25 RTM_GETROUTE = 26 RTM_NEWNEIGH = 28 RTM_DELNEIGH = 29 RTM_GETNEIGH = 30 RTM_NEWRULE = 32 RTM_DELRULE = 33 RTM_GETRULE = 34 RTM_NEWQDISC = 36 RTM_DELQDISC = 37 RTM_GETQDISC = 38 RTM_NEWTCLASS = 40 RTM_DELTCLASS = 41 RTM_GETTCLASS = 42 RTM_NEWTFILTER = 44 RTM_DELTFILTER = 45 RTM_GETTFILTER = 46 RTM_NEWACTION = 48 RTM_DELACTION = 49 RTM_GETACTION = 50 RTM_NEWPREFIX = 52 RTM_GETMULTICAST = 58 RTM_GETANYCAST = 62 RTM_NEWNEIGHTBL = 64 RTM_GETNEIGHTBL = 66 RTM_SETNEIGHTBL = 67 RTM_NEWNDUSEROPT = 68 RTM_NEWADDRLABEL = 72 RTM_DELADDRLABEL = 73 RTM_GETADDRLABEL = 74 RTM_GETDCB = 78 RTM_SETDCB = 79 RTM_NEWNETCONF = 80 RTM_GETNETCONF = 82 RTM_NEWMDB = 84 RTM_DELMDB = 85 RTM_GETMDB = 86 RTM_NEWNSID = 88 RTM_DELNSID = 89 RTM_GETNSID = 90 RTM_NEWSTATS = 92 RTM_GETSTATS = 94 class RtAttr(Structure): _fields_ = [ ("rta_len", c_ushort), ("rta_type", c_ushort), ] class RtMsgHdr(Structure): _fields_ = [ ("rtm_family", c_ubyte), ("rtm_dst_len", c_ubyte), ("rtm_src_len", c_ubyte), ("rtm_tos", c_ubyte), ("rtm_table", c_ubyte), ("rtm_protocol", c_ubyte), ("rtm_scope", c_ubyte), ("rtm_type", c_ubyte), ("rtm_flags", c_uint), ] class RtMsgFlags(Enum): RTM_F_NOTIFY = 0x100 RTM_F_CLONED = 0x200 RTM_F_EQUALIZE = 0x400 RTM_F_PREFIX = 0x800 RTM_F_LOOKUP_TABLE = 0x1000 RTM_F_FIB_MATCH = 0x2000 RTM_F_OFFLOAD = 0x4000 RTM_F_TRAP = 0x8000 RTM_F_OFFLOAD_FAILED = 0x20000000 class RtScope(Enum): RT_SCOPE_UNIVERSE = 0 RT_SCOPE_SITE = 200 RT_SCOPE_LINK = 253 RT_SCOPE_HOST = 254 RT_SCOPE_NOWHERE = 255 class RtType(Enum): RTN_UNSPEC = 0 RTN_UNICAST = auto() RTN_LOCAL = auto() RTN_BROADCAST = auto() RTN_ANYCAST = auto() RTN_MULTICAST = auto() RTN_BLACKHOLE = auto() RTN_UNREACHABLE = auto() RTN_PROHIBIT = auto() RTN_THROW = auto() RTN_NAT = auto() RTN_XRESOLVE = auto() class RtProto(Enum): RTPROT_UNSPEC = 0 RTPROT_REDIRECT = 1 RTPROT_KERNEL = 2 RTPROT_BOOT = 3 RTPROT_STATIC = 4 RTPROT_GATED = 8 RTPROT_RA = 9 RTPROT_MRT = 10 RTPROT_ZEBRA = 11 RTPROT_BIRD = 12 RTPROT_DNROUTED = 13 RTPROT_XORP = 14 RTPROT_NTK = 15 RTPROT_DHCP = 16 RTPROT_MROUTED = 17 RTPROT_KEEPALIVED = 18 RTPROT_BABEL = 42 RTPROT_OPENR = 99 RTPROT_BGP = 186 RTPROT_ISIS = 187 RTPROT_OSPF = 188 RTPROT_RIP = 189 RTPROT_EIGRP = 192 class NlRtaxType(Enum): RTAX_UNSPEC = 0 RTAX_LOCK = auto() RTAX_MTU = auto() RTAX_WINDOW = auto() RTAX_RTT = auto() RTAX_RTTVAR = auto() RTAX_SSTHRESH = auto() RTAX_CWND = auto() RTAX_ADVMSS = auto() RTAX_REORDERING = auto() RTAX_HOPLIMIT = auto() RTAX_INITCWND = auto() RTAX_FEATURES = auto() RTAX_RTO_MIN = auto() RTAX_INITRWND = auto() RTAX_QUICKACK = auto() RTAX_CC_ALGO = auto() RTAX_FASTOPEN_NO_COOKIE = auto() class RtFlagsBSD(Enum): RTF_UP = 0x1 RTF_GATEWAY = 0x2 RTF_HOST = 0x4 RTF_REJECT = 0x8 RTF_DYNAMIC = 0x10 RTF_MODIFIED = 0x20 RTF_DONE = 0x40 RTF_XRESOLVE = 0x200 RTF_LLINFO = 0x400 RTF_LLDATA = 0x400 RTF_STATIC = 0x800 RTF_BLACKHOLE = 0x1000 RTF_PROTO2 = 0x4000 RTF_PROTO1 = 0x8000 RTF_PROTO3 = 0x40000 RTF_FIXEDMTU = 0x80000 RTF_PINNED = 0x100000 RTF_LOCAL = 0x200000 RTF_BROADCAST = 0x400000 RTF_MULTICAST = 0x800000 RTF_STICKY = 0x10000000 RTF_RNH_LOCKED = 0x40000000 RTF_GWFLAG_COMPAT = 0x80000000 class NlRtGroup(Enum): RTNLGRP_NONE = 0 RTNLGRP_LINK = auto() RTNLGRP_NOTIFY = auto() RTNLGRP_NEIGH = auto() RTNLGRP_TC = auto() RTNLGRP_IPV4_IFADDR = auto() RTNLGRP_IPV4_MROUTE = auto() RTNLGRP_IPV4_ROUTE = auto() RTNLGRP_IPV4_RULE = auto() RTNLGRP_IPV6_IFADDR = auto() RTNLGRP_IPV6_MROUTE = auto() RTNLGRP_IPV6_ROUTE = auto() RTNLGRP_IPV6_IFINFO = auto() RTNLGRP_DECnet_IFADDR = auto() RTNLGRP_NOP2 = auto() RTNLGRP_DECnet_ROUTE = auto() RTNLGRP_DECnet_RULE = auto() RTNLGRP_NOP4 = auto() RTNLGRP_IPV6_PREFIX = auto() RTNLGRP_IPV6_RULE = auto() RTNLGRP_ND_USEROPT = auto() RTNLGRP_PHONET_IFADDR = auto() RTNLGRP_PHONET_ROUTE = auto() RTNLGRP_DCB = auto() RTNLGRP_IPV4_NETCONF = auto() RTNLGRP_IPV6_NETCONF = auto() RTNLGRP_MDB = auto() RTNLGRP_MPLS_ROUTE = auto() RTNLGRP_NSID = auto() RTNLGRP_MPLS_NETCONF = auto() RTNLGRP_IPV4_MROUTE_R = auto() RTNLGRP_IPV6_MROUTE_R = auto() RTNLGRP_NEXTHOP = auto() RTNLGRP_BRVLAN = auto() class IfinfoMsg(Structure): _fields_ = [ ("ifi_family", c_ubyte), ("__ifi_pad", c_ubyte), ("ifi_type", c_ushort), ("ifi_index", c_int), ("ifi_flags", c_uint), ("ifi_change", c_uint), ] class IflattrType(Enum): IFLA_UNSPEC = 0 IFLA_ADDRESS = auto() IFLA_BROADCAST = auto() IFLA_IFNAME = auto() IFLA_MTU = auto() IFLA_LINK = auto() IFLA_QDISC = auto() IFLA_STATS = auto() IFLA_COST = auto() IFLA_PRIORITY = auto() IFLA_MASTER = auto() IFLA_WIRELESS = auto() IFLA_PROTINFO = auto() IFLA_TXQLEN = auto() IFLA_MAP = auto() IFLA_WEIGHT = auto() IFLA_OPERSTATE = auto() IFLA_LINKMODE = auto() IFLA_LINKINFO = auto() IFLA_NET_NS_PID = auto() IFLA_IFALIAS = auto() IFLA_NUM_VF = auto() IFLA_VFINFO_LIST = auto() IFLA_STATS64 = auto() IFLA_VF_PORTS = auto() IFLA_PORT_SELF = auto() IFLA_AF_SPEC = auto() IFLA_GROUP = auto() IFLA_NET_NS_FD = auto() IFLA_EXT_MASK = auto() IFLA_PROMISCUITY = auto() IFLA_NUM_TX_QUEUES = auto() IFLA_NUM_RX_QUEUES = auto() IFLA_CARRIER = auto() IFLA_PHYS_PORT_ID = auto() IFLA_CARRIER_CHANGES = auto() IFLA_PHYS_SWITCH_ID = auto() IFLA_LINK_NETNSID = auto() IFLA_PHYS_PORT_NAME = auto() IFLA_PROTO_DOWN = auto() IFLA_GSO_MAX_SEGS = auto() IFLA_GSO_MAX_SIZE = auto() IFLA_PAD = auto() IFLA_XDP = auto() IFLA_EVENT = auto() IFLA_NEW_NETNSID = auto() IFLA_IF_NETNSID = auto() IFLA_CARRIER_UP_COUNT = auto() IFLA_CARRIER_DOWN_COUNT = auto() IFLA_NEW_IFINDEX = auto() IFLA_MIN_MTU = auto() IFLA_MAX_MTU = auto() IFLA_PROP_LIST = auto() IFLA_ALT_IFNAME = auto() IFLA_PERM_ADDRESS = auto() IFLA_PROTO_DOWN_REASON = auto() class IflinkInfo(Enum): IFLA_INFO_UNSPEC = 0 IFLA_INFO_KIND = auto() IFLA_INFO_DATA = auto() IFLA_INFO_XSTATS = auto() IFLA_INFO_SLAVE_KIND = auto() IFLA_INFO_SLAVE_DATA = auto() class IfLinkInfoDataVlan(Enum): IFLA_VLAN_UNSPEC = 0 IFLA_VLAN_ID = auto() IFLA_VLAN_FLAGS = auto() IFLA_VLAN_EGRESS_QOS = auto() IFLA_VLAN_INGRESS_QOS = auto() IFLA_VLAN_PROTOCOL = auto() class IfaddrMsg(Structure): _fields_ = [ ("ifa_family", c_ubyte), ("ifa_prefixlen", c_ubyte), ("ifa_flags", c_ubyte), ("ifa_scope", c_ubyte), ("ifa_index", c_uint), ] class IfattrType(Enum): IFA_UNSPEC = 0 IFA_ADDRESS = auto() IFA_LOCAL = auto() IFA_LABEL = auto() IFA_BROADCAST = auto() IFA_ANYCAST = auto() IFA_CACHEINFO = auto() IFA_MULTICAST = auto() IFA_FLAGS = auto() IFA_RT_PRIORITY = auto() IFA_TARGET_NETNSID = auto() class NdMsg(Structure): _fields_ = [ ("ndm_family", c_ubyte), ("ndm_pad1", c_ubyte), ("ndm_pad2", c_ubyte), ("ndm_ifindex", c_uint), ("ndm_state", c_ushort), ("ndm_flags", c_ubyte), ("ndm_type", c_ubyte), ] class NdAttrType(Enum): NDA_UNSPEC = 0 NDA_DST = 1 NDA_LLADDR = 2 NDA_CACHEINFO = 3 NDA_PROBES = 4 NDA_VLAN = 5 NDA_PORT = 6 NDA_VNI = 7 NDA_IFINDEX = 8 NDA_MASTER = 9 NDA_LINK_NETNSID = 10 NDA_SRC_VNI = 11 NDA_PROTOCOL = 12 NDA_NH_ID = 13 NDA_FDB_EXT_ATTRS = 14 NDA_FLAGS_EXT = 15 NDA_NDM_STATE_MASK = 16 NDA_NDM_FLAGS_MASK = 17 class NlAttrRtFlags(NlAttrU32): def _print_attr_value(self): s = get_bitmask_str(RtFlagsBSD, self.u32) return " rtflags={}".format(s) class NlAttrIfindex(NlAttrU32): def _print_attr_value(self): try: ifname = socket.if_indextoname(self.u32) return " iface={}(#{})".format(ifname, self.u32) except OSError: pass return " iface=if#{}".format(self.u32) class NlAttrTable(NlAttrU32): def _print_attr_value(self): return " rtable={}".format(self.u32) class NlAttrNhId(NlAttrU32): def _print_attr_value(self): return " nh_id={}".format(self.u32) class NlAttrKNhId(NlAttrU32): def _print_attr_value(self): return " knh_id={}".format(self.u32) class NlAttrMac(NlAttr): def _print_attr_value(self): return ' mac="' + ":".join(["{:02X}".format(b) for b in self._data]) + '"' class NlAttrIfStats(NlAttr): def _print_attr_value(self): return " stats={...}" class NlAttrVia(NlAttr): def __init__(self, nla_type, family, addr: str): super().__init__(nla_type, b"") self.addr = addr self.family = family @staticmethod def _validate(data): nla_len, nla_type = struct.unpack("@HH", data[:4]) data_len = nla_len - 4 if data_len == 0: raise ValueError( "Error validating attr {}: empty data".format(nla_type) ) # noqa: E501 family = int(data_len[0]) if family not in (socket.AF_INET, socket.AF_INET6): raise ValueError( "Error validating attr {}: unsupported AF {}".format( # noqa: E501 nla_type, family ) ) if family == socket.AF_INET: expected_len = 1 + 4 else: expected_len = 1 + 16 if data_len != expected_len: raise ValueError( "Error validating attr {}: expected len {} got {}".format( # noqa: E501 nla_type, expected_len, data_len ) ) @property def nla_len(self): if self.family == socket.AF_INET6: return 21 else: return 9 @classmethod def _parse(cls, data): nla_len, nla_type, family = struct.unpack("@HHB", data[:5]) off = 5 if family == socket.AF_INET: addr = socket.inet_ntop(family, data[off:off + 4]) else: addr = socket.inet_ntop(family, data[off:off + 16]) return cls(nla_type, family, addr) def __bytes__(self): addr = socket.inet_pton(self.family, self.addr) return self._to_bytes(struct.pack("@B", self.family) + addr) def _print_attr_value(self): return " via={}".format(self.addr) rtnl_route_attrs = prepare_attrs_map( [ AttrDescr(RtattrType.RTA_DST, NlAttrIp), AttrDescr(RtattrType.RTA_SRC, NlAttrIp), AttrDescr(RtattrType.RTA_IIF, NlAttrIfindex), AttrDescr(RtattrType.RTA_OIF, NlAttrIfindex), AttrDescr(RtattrType.RTA_GATEWAY, NlAttrIp), AttrDescr(RtattrType.RTA_TABLE, NlAttrTable), AttrDescr(RtattrType.RTA_PRIORITY, NlAttrU32), AttrDescr(RtattrType.RTA_VIA, NlAttrVia), AttrDescr(RtattrType.RTA_NH_ID, NlAttrNhId), AttrDescr(RtattrType.RTA_KNH_ID, NlAttrKNhId), AttrDescr(RtattrType.RTA_RTFLAGS, NlAttrRtFlags), AttrDescr( RtattrType.RTA_METRICS, NlAttrNested, [ AttrDescr(NlRtaxType.RTAX_MTU, NlAttrU32), ], ), ] ) rtnl_ifla_attrs = prepare_attrs_map( [ AttrDescr(IflattrType.IFLA_ADDRESS, NlAttrMac), AttrDescr(IflattrType.IFLA_BROADCAST, NlAttrMac), AttrDescr(IflattrType.IFLA_IFNAME, NlAttrStr), AttrDescr(IflattrType.IFLA_MTU, NlAttrU32), AttrDescr(IflattrType.IFLA_LINK, NlAttrU32), AttrDescr(IflattrType.IFLA_PROMISCUITY, NlAttrU32), AttrDescr(IflattrType.IFLA_OPERSTATE, NlAttrU8), AttrDescr(IflattrType.IFLA_CARRIER, NlAttrU8), AttrDescr(IflattrType.IFLA_IFALIAS, NlAttrStr), AttrDescr(IflattrType.IFLA_STATS64, NlAttrIfStats), AttrDescr(IflattrType.IFLA_NEW_IFINDEX, NlAttrU32), AttrDescr( IflattrType.IFLA_LINKINFO, NlAttrNested, [ AttrDescr(IflinkInfo.IFLA_INFO_KIND, NlAttrStr), AttrDescr(IflinkInfo.IFLA_INFO_DATA, NlAttr), ], ), ] ) rtnl_ifa_attrs = prepare_attrs_map( [ AttrDescr(IfattrType.IFA_ADDRESS, NlAttrIp), AttrDescr(IfattrType.IFA_LOCAL, NlAttrIp), AttrDescr(IfattrType.IFA_LABEL, NlAttrStr), AttrDescr(IfattrType.IFA_BROADCAST, NlAttrIp), AttrDescr(IfattrType.IFA_ANYCAST, NlAttrIp), AttrDescr(IfattrType.IFA_FLAGS, NlAttrU32), ] ) rtnl_nd_attrs = prepare_attrs_map( [ AttrDescr(NdAttrType.NDA_DST, NlAttrIp), AttrDescr(NdAttrType.NDA_IFINDEX, NlAttrIfindex), AttrDescr(NdAttrType.NDA_FLAGS_EXT, NlAttrU32), AttrDescr(NdAttrType.NDA_LLADDR, NlAttrMac), ] ) class BaseNetlinkRtMessage(StdNetlinkMessage): pass class NetlinkRtMessage(BaseNetlinkRtMessage): messages = [ - NlRtMsgType.RTM_NEWROUTE.value, - NlRtMsgType.RTM_DELROUTE.value, - NlRtMsgType.RTM_GETROUTE.value, + NlMsgProps(NlRtMsgType.RTM_NEWROUTE, NlMsgCategory.NEW), + NlMsgProps(NlRtMsgType.RTM_DELROUTE, NlMsgCategory.DELETE), + NlMsgProps(NlRtMsgType.RTM_GETROUTE, NlMsgCategory.GET), ] nl_attrs_map = rtnl_route_attrs def __init__(self, helper, nlm_type): super().__init__(helper, nlm_type) self.base_hdr = RtMsgHdr() def parse_base_header(self, data): if len(data) < sizeof(RtMsgHdr): raise ValueError("length less than rtmsg header") rtm_hdr = RtMsgHdr.from_buffer_copy(data) return (rtm_hdr, sizeof(RtMsgHdr)) def print_base_header(self, hdr, prepend=""): family = self.helper.get_af_name(hdr.rtm_family) print( "{}family={}, dst_len={}, src_len={}, tos={}, table={}, protocol={}({}), scope={}({}), type={}({}), flags={}({})".format( # noqa: E501 prepend, family, hdr.rtm_dst_len, hdr.rtm_src_len, hdr.rtm_tos, hdr.rtm_table, self.helper.get_attr_byval(RtProto, hdr.rtm_protocol), hdr.rtm_protocol, self.helper.get_attr_byval(RtScope, hdr.rtm_scope), hdr.rtm_scope, self.helper.get_attr_byval(RtType, hdr.rtm_type), hdr.rtm_type, self.helper.get_bitmask_str(RtMsgFlags, hdr.rtm_flags), hdr.rtm_flags, ) ) class NetlinkIflaMessage(BaseNetlinkRtMessage): messages = [ - NlRtMsgType.RTM_NEWLINK.value, - NlRtMsgType.RTM_DELLINK.value, - NlRtMsgType.RTM_GETLINK.value, + NlMsgProps(NlRtMsgType.RTM_NEWLINK, NlMsgCategory.NEW), + NlMsgProps(NlRtMsgType.RTM_DELLINK, NlMsgCategory.DELETE), + NlMsgProps(NlRtMsgType.RTM_GETLINK, NlMsgCategory.GET), ] nl_attrs_map = rtnl_ifla_attrs def __init__(self, helper, nlm_type): super().__init__(helper, nlm_type) self.base_hdr = IfinfoMsg() def parse_base_header(self, data): if len(data) < sizeof(IfinfoMsg): raise ValueError("length less than IfinfoMsg header") rtm_hdr = IfinfoMsg.from_buffer_copy(data) return (rtm_hdr, sizeof(IfinfoMsg)) def print_base_header(self, hdr, prepend=""): family = self.helper.get_af_name(hdr.ifi_family) print( "{}family={}, ifi_type={}, ifi_index={}, ifi_flags={}, ifi_change={}".format( # noqa: E501 prepend, family, hdr.ifi_type, hdr.ifi_index, hdr.ifi_flags, hdr.ifi_change, ) ) class NetlinkIfaMessage(BaseNetlinkRtMessage): messages = [ - NlRtMsgType.RTM_NEWADDR.value, - NlRtMsgType.RTM_DELADDR.value, - NlRtMsgType.RTM_GETADDR.value, + NlMsgProps(NlRtMsgType.RTM_NEWADDR, NlMsgCategory.NEW), + NlMsgProps(NlRtMsgType.RTM_DELADDR, NlMsgCategory.DELETE), + NlMsgProps(NlRtMsgType.RTM_GETADDR, NlMsgCategory.GET), ] nl_attrs_map = rtnl_ifa_attrs def __init__(self, helper, nlm_type): super().__init__(helper, nlm_type) self.base_hdr = IfaddrMsg() def parse_base_header(self, data): if len(data) < sizeof(IfaddrMsg): raise ValueError("length less than IfaddrMsg header") rtm_hdr = IfaddrMsg.from_buffer_copy(data) return (rtm_hdr, sizeof(IfaddrMsg)) def print_base_header(self, hdr, prepend=""): family = self.helper.get_af_name(hdr.ifa_family) print( "{}family={}, ifa_prefixlen={}, ifa_flags={}, ifa_scope={}, ifa_index={}".format( # noqa: E501 prepend, family, hdr.ifa_prefixlen, hdr.ifa_flags, hdr.ifa_scope, hdr.ifa_index, ) ) class NetlinkNdMessage(BaseNetlinkRtMessage): messages = [ - NlRtMsgType.RTM_NEWNEIGH.value, - NlRtMsgType.RTM_DELNEIGH.value, - NlRtMsgType.RTM_GETNEIGH.value, + NlMsgProps(NlRtMsgType.RTM_NEWNEIGH, NlMsgCategory.NEW), + NlMsgProps(NlRtMsgType.RTM_DELNEIGH, NlMsgCategory.DELETE), + NlMsgProps(NlRtMsgType.RTM_GETNEIGH, NlMsgCategory.GET), ] nl_attrs_map = rtnl_nd_attrs def __init__(self, helper, nlm_type): super().__init__(helper, nlm_type) self.base_hdr = NdMsg() def parse_base_header(self, data): if len(data) < sizeof(NdMsg): raise ValueError("length less than NdMsg header") nd_hdr = NdMsg.from_buffer_copy(data) return (nd_hdr, sizeof(NdMsg)) def print_base_header(self, hdr, prepend=""): family = self.helper.get_af_name(hdr.ndm_family) print( "{}family={}, ndm_ifindex={}, ndm_state={}, ndm_flags={}".format( # noqa: E501 prepend, family, hdr.ndm_ifindex, hdr.ndm_state, hdr.ndm_flags, ) ) + + +handler_classes = { + "netlink_route": [ + NetlinkRtMessage, + NetlinkIflaMessage, + NetlinkIfaMessage, + NetlinkNdMessage, + ], +} diff --git a/tests/atf_python/sys/netlink/utils.py b/tests/atf_python/sys/netlink/utils.py index 44758148747d..86a910ce6590 100644 --- a/tests/atf_python/sys/netlink/utils.py +++ b/tests/atf_python/sys/netlink/utils.py @@ -1,74 +1,80 @@ #!/usr/local/bin/python3 from enum import Enum from typing import Any from typing import Dict from typing import List from typing import NamedTuple from atf_python.sys.netlink.attrs import NlAttr class NlConst: AF_NETLINK = 38 NETLINK_ROUTE = 0 NETLINK_GENERIC = 16 + GENL_ID_CTRL = 16 def roundup2(val: int, num: int) -> int: if val % num: return (val | (num - 1)) + 1 else: return val def align4(val: int) -> int: return roundup2(val, 4) def enum_or_int(val) -> int: if isinstance(val, Enum): return val.value return val class AttrDescr(NamedTuple): val: Enum cls: "NlAttr" child_map: Any = None def prepare_attrs_map(attrs: List[AttrDescr]) -> Dict[str, Dict]: ret = {} for ad in attrs: ret[ad.val.value] = {"ad": ad} if ad.child_map: ret[ad.val.value]["child"] = prepare_attrs_map(ad.child_map) return ret def build_propmap(cls): ret = {} for prop in dir(cls): if not prop.startswith("_"): ret[getattr(cls, prop).value] = prop return ret def get_bitmask_map(propmap, val): v = 1 ret = {} while val: if v & val: if v in propmap: ret[v] = propmap[v] else: ret[v] = hex(v) val -= v v *= 2 return ret def get_bitmask_str(cls, val): - pmap = build_propmap(cls) + if isinstance(cls, type): + pmap = build_propmap(cls) + else: + pmap = {} + for _cls in cls: + pmap.update(build_propmap(_cls)) bmap = get_bitmask_map(pmap, val) return ",".join([v for k, v in bmap.items()])