diff --git a/tests/atf_python/sys/netlink/Makefile b/tests/atf_python/sys/netlink/Makefile --- a/tests/atf_python/sys/netlink/Makefile +++ b/tests/atf_python/sys/netlink/Makefile @@ -2,7 +2,8 @@ .PATH: ${.CURDIR} -FILES= __init__.py attrs.py base_headers.py message.py netlink.py netlink_route.py utils.py +FILES= __init__.py attrs.py base_headers.py message.py netlink.py \ + netlink_generic.py netlink_route.py utils.py .include FILESDIR= ${TESTSBASE}/atf_python/sys/netlink diff --git a/tests/atf_python/sys/netlink/message.py b/tests/atf_python/sys/netlink/message.py --- a/tests/atf_python/sys/netlink/message.py +++ b/tests/atf_python/sys/netlink/message.py @@ -1,15 +1,35 @@ #!/usr/local/bin/python3 import struct from ctypes import sizeof +from enum import Enum from typing import List +from typing import NamedTuple from atf_python.sys.netlink.attrs import NlAttr from atf_python.sys.netlink.attrs import NlAttrNested +from atf_python.sys.netlink.base_headers import NlmAckFlags +from atf_python.sys.netlink.base_headers import NlmNewFlags +from atf_python.sys.netlink.base_headers import NlmGetFlags +from atf_python.sys.netlink.base_headers import NlmDeleteFlags from atf_python.sys.netlink.base_headers import NlmBaseFlags from atf_python.sys.netlink.base_headers import Nlmsghdr from atf_python.sys.netlink.base_headers import NlMsgType from atf_python.sys.netlink.utils import align4 from atf_python.sys.netlink.utils import enum_or_int +from atf_python.sys.netlink.utils import get_bitmask_str + + +class NlMsgCategory(Enum): + UNKNOWN = 0 + GET = 1 + NEW = 2 + DELETE = 3 + ACK = 4 + + +class NlMsgProps(NamedTuple): + msg: Enum + category: NlMsgCategory class BaseNetlinkMessage(object): @@ -60,18 +80,40 @@ def is_reply(self, hdr): return hdr.nlmsg_type == NlMsgType.NLMSG_ERROR.value - def print_nl_header(self, hdr, prepend=""): + @property + def msg_name(self): + return "msg#{}".format(self._get_msg_type()) + + def _get_nl_category(self): + if self.is_reply(self.nl_hdr): + return NlMsgCategory.ACK + return NlMsgCategory.UNKNOWN + + def get_nlm_flags_str(self): + category = self._get_nl_category() + flags = self.nl_hdr.nlmsg_flags + + if category == NlMsgCategory.UNKNOWN: + return self.helper.get_bitmask_str(NlmBaseFlags, flags) + elif category == NlMsgCategory.GET: + flags_enum = NlmGetFlags + elif category == NlMsgCategory.NEW: + flags_enum = NlmNewFlags + elif category == NlMsgCategory.DELETE: + flags_enum = NlmDeleteFlags + elif category == NlMsgCategory.ACK: + flags_enum = NlmAckFlags + return get_bitmask_str([NlmBaseFlags, flags_enum], flags) + + def print_nl_header(self, prepend=""): # len=44, type=RTM_DELROUTE, flags=NLM_F_REQUEST|NLM_F_ACK, seq=1641163704, pid=0 # noqa: E501 - is_reply = self.is_reply(hdr) - msg_name = self.helper.get_nlmsg_name(hdr.nlmsg_type) + hdr = self.nl_hdr print( "{}len={}, type={}, flags={}(0x{:X}), seq={}, pid={}".format( prepend, hdr.nlmsg_len, - msg_name, - self.helper.get_nlm_flags_str( - msg_name, is_reply, hdr.nlmsg_flags - ), # noqa: E501 + self.msg_name, + self.get_nlm_flags_str(), hdr.nlmsg_flags, hdr.nlmsg_seq, hdr.nlmsg_pid, @@ -92,7 +134,7 @@ return self def print_message(self): - self.print_nl_header(self.nl_hdr) + self.print_nl_header() @staticmethod def print_as_bytes(data: bytes, descr: str): @@ -191,11 +233,29 @@ self.nl_hdr.nlmsg_len = len(ret) + sizeof(Nlmsghdr) return bytes(self.nl_hdr) + ret + def _get_msg_type(self): + return self.nl_hdr.nlmsg_type + + @property + def msg_props(self): + msg_type = self._get_msg_type() + for msg_props in self.messages: + if msg_props.msg.value == msg_type: + return msg_props + return None + + @property + def msg_name(self): + msg_props = self.msg_props + if msg_props is not None: + return msg_props.msg.name + return super().msg_name + def print_base_header(self, hdr, prepend=""): pass def print_message(self): - self.print_nl_header(self.nl_hdr) + self.print_nl_header() self.print_base_header(self.base_hdr, " ") for nla in self.nla_list: nla.print_attr(" ") diff --git a/tests/atf_python/sys/netlink/netlink.py b/tests/atf_python/sys/netlink/netlink.py --- a/tests/atf_python/sys/netlink/netlink.py +++ b/tests/atf_python/sys/netlink/netlink.py @@ -14,23 +14,22 @@ from atf_python.sys.netlink.attrs import NlAttr from atf_python.sys.netlink.attrs import NlAttrStr from atf_python.sys.netlink.attrs import NlAttrU32 -from atf_python.sys.netlink.base_headers import NlmAckFlags +from atf_python.sys.netlink.base_headers import GenlMsgHdr from atf_python.sys.netlink.base_headers import NlmBaseFlags -from atf_python.sys.netlink.base_headers import NlmDeleteFlags -from atf_python.sys.netlink.base_headers import NlmGetFlags -from atf_python.sys.netlink.base_headers import NlmNewFlags from atf_python.sys.netlink.base_headers import Nlmsghdr from atf_python.sys.netlink.base_headers import NlMsgType from atf_python.sys.netlink.message import BaseNetlinkMessage +from atf_python.sys.netlink.message import NlMsgCategory +from atf_python.sys.netlink.message import NlMsgProps from atf_python.sys.netlink.message import StdNetlinkMessage -from atf_python.sys.netlink.netlink_route import NetlinkIfaMessage -from atf_python.sys.netlink.netlink_route import NetlinkIflaMessage -from atf_python.sys.netlink.netlink_route import NetlinkNdMessage -from atf_python.sys.netlink.netlink_route import NetlinkRtMessage -from atf_python.sys.netlink.netlink_route import NlRtMsgType +from atf_python.sys.netlink.netlink_generic import GenlCtrlMsgType +from atf_python.sys.netlink.netlink_generic import GenlCtrlAttrType +from atf_python.sys.netlink.netlink_generic import handler_classes as genl_classes +from atf_python.sys.netlink.netlink_route import handler_classes as rt_classes from atf_python.sys.netlink.utils import align4 from atf_python.sys.netlink.utils import AttrDescr from atf_python.sys.netlink.utils import build_propmap +from atf_python.sys.netlink.utils import enum_or_int from atf_python.sys.netlink.utils import get_bitmask_map from atf_python.sys.netlink.utils import NlConst from atf_python.sys.netlink.utils import prepare_attrs_map @@ -114,13 +113,6 @@ propmap = self.get_propmap(cls) return propmap.get(attr_val) - def get_nlmsg_name(self, val): - for cls in [NlRtMsgType, NlMsgType]: - v = self.get_attr_byval(cls, val) - if v is not None: - return v - return "msg#{}".format(val) - def get_af_name(self, family): v = self.get_attr_byval(self._af_cls, family) if v is not None: @@ -141,18 +133,6 @@ bmap = NlHelper.get_bitmask_map(pmap, val) return ",".join([v for k, v in bmap.items()]) - def get_nlm_flags_str(self, msg_str: str, reply: bool, val): - if reply: - return self.get_bitmask_str(NlmAckFlags, val) - if msg_str.startswith("RTM_GET"): - return self.get_bitmask_str(NlmGetFlags, val) - elif msg_str.startswith("RTM_DEL"): - return self.get_bitmask_str(NlmDeleteFlags, val) - elif msg_str.startswith("RTM_NEW"): - return self.get_bitmask_str(NlmNewFlags, val) - else: - return self.get_bitmask_str(NlmBaseFlags, val) - nldone_attrs = prepare_attrs_map([]) @@ -166,7 +146,7 @@ class NetlinkDoneMessage(StdNetlinkMessage): - messages = [NlMsgType.NLMSG_DONE.value] + messages = [NlMsgProps(NlMsgType.NLMSG_DONE, NlMsgCategory.ACK)] nl_attrs_map = nldone_attrs @property @@ -185,7 +165,7 @@ class NetlinkErrorMessage(StdNetlinkMessage): - messages = [NlMsgType.NLMSG_ERROR.value] + messages = [NlMsgProps(NlMsgType.NLMSG_ERROR, NlMsgCategory.ACK)] nl_attrs_map = nlerr_attrs @property @@ -221,30 +201,52 @@ def print_base_header(self, errhdr, prepend=""): print("{}error={}, ".format(prepend, errhdr.error), end="") - self.print_nl_header(errhdr.msg, prepend) + hdr = errhdr.msg + print( + "{}len={}, type={}, flags={}(0x{:X}), seq={}, pid={}".format( + prepend, + hdr.nlmsg_len, + "msg#{}".format(hdr.nlmsg_type), + self.helper.get_bitmask_str(NlmBaseFlags, hdr.nlmsg_flags), + hdr.nlmsg_flags, + hdr.nlmsg_seq, + hdr.nlmsg_pid, + ) + ) + + +core_classes = { + "netlink_core": [ + NetlinkDoneMessage, + NetlinkErrorMessage, + ], +} class Nlsock: + HANDLER_CLASSES = [core_classes, rt_classes, genl_classes] + def __init__(self, family, helper): self.helper = helper self.sock_fd = self._setup_netlink(family) + self._sock_family = family self._data = bytes() self.msgmap = self.build_msgmap() - # self.set_groups(NlRtGroup.RTNLGRP_IPV4_ROUTE.value | NlRtGroup.RTNLGRP_IPV6_ROUTE.value) # noqa: E501 + self._family_map = { + NlConst.GENL_ID_CTRL: "nlctrl", + } def build_msgmap(self): - classes = [ - NetlinkRtMessage, - NetlinkIflaMessage, - NetlinkIfaMessage, - NetlinkNdMessage, - NetlinkDoneMessage, - NetlinkErrorMessage, - ] + handler_classes = {} + for d in self.HANDLER_CLASSES: + handler_classes.update(d) xmap = {} - for cls in classes: - for message in cls.messages: - xmap[message] = cls + # 'family_name': [class.messages[MsgProps.msg], ] + for family_id, family_classes in handler_classes.items(): + xmap[family_id] = {} + for cls in family_classes: + for msg_props in cls.messages: + xmap[family_id][enum_or_int(msg_props.msg)] = cls return xmap def _setup_netlink(self, netlink_family) -> int: @@ -263,9 +265,10 @@ # k = struct.pack("@BBHII", 12, 38, 0, self.pid, mask) # self.sock_fd.bind(k) - def write_message(self, msg): - print("vvvvvvvv OUT vvvvvvvv") - msg.print_message() + def write_message(self, msg, verbose=True): + if verbose: + print("vvvvvvvv OUT vvvvvvvv") + msg.print_message() msg_bytes = bytes(msg) try: ret = os.write(self.sock_fd.fileno(), msg_bytes) @@ -277,12 +280,48 @@ if len(data) < sizeof(Nlmsghdr): raise Exception("Short read from nl: {} bytes".format(len(data))) hdr = Nlmsghdr.from_buffer_copy(data) - nlmsg_type = hdr.nlmsg_type - cls = self.msgmap.get(nlmsg_type) + if hdr.nlmsg_type < 16: + family_name = "netlink_core" + nlmsg_type = hdr.nlmsg_type + elif self._sock_family == NlConst.NETLINK_ROUTE: + family_name = "netlink_route" + nlmsg_type = hdr.nlmsg_type + else: + # Genetlink + if len(data) < sizeof(Nlmsghdr) + sizeof(GenlMsgHdr): + raise Exception("Short read from genl: {} bytes".format(len(data))) + family_name = self._family_map.get(hdr.nlmsg_type, "") + ghdr = GenlMsgHdr.from_buffer_copy(data[sizeof(Nlmsghdr):]) + nlmsg_type = ghdr.cmd + cls = self.msgmap.get(family_name, {}).get(nlmsg_type) if not cls: cls = BaseNetlinkMessage return cls.from_bytes(self.helper, data) + def get_genl_family_id(self, family_name): + hdr = Nlmsghdr( + nlmsg_type=NlConst.GENL_ID_CTRL, + nlmsg_flags=NlmBaseFlags.NLM_F_REQUEST.value, + nlmsg_seq = self.helper.get_seq(), + ) + ghdr = GenlMsgHdr(cmd=GenlCtrlMsgType.CTRL_CMD_GETFAMILY.value) + nla = NlAttrStr(GenlCtrlAttrType.CTRL_ATTR_FAMILY_NAME, family_name) + hdr.nlmsg_len = sizeof(Nlmsghdr) + sizeof(GenlMsgHdr) + len(bytes(nla)) + + msg_bytes = bytes(hdr) + bytes(ghdr) + bytes(nla) + self.write_data(msg_bytes) + while True: + rx_msg = self.read_message() + if hdr.nlmsg_seq == rx_msg.nl_hdr.nlmsg_seq: + if rx_msg.is_type(NlMsgType.NLMSG_ERROR): + if rx_msg.error_code != 0: + raise ValueError("unable to get family {}".format(family_name)) + else: + family_id = rx_msg.get_nla(GenlCtrlAttrType.CTRL_ATTR_FAMILY_ID).u16 + self._family_map[family_id] = family_name + return family_id + raise ValueError("unable to get family {}".format(family_name)) + def write_data(self, data: bytes): self.sock_fd.send(data) diff --git a/tests/atf_python/sys/netlink/netlink_generic.py b/tests/atf_python/sys/netlink/netlink_generic.py new file mode 100644 --- /dev/null +++ b/tests/atf_python/sys/netlink/netlink_generic.py @@ -0,0 +1,110 @@ +#!/usr/local/bin/python3 +from ctypes import sizeof +from enum import Enum + +from atf_python.sys.netlink.attrs import NlAttrStr +from atf_python.sys.netlink.attrs import NlAttrU16 +from atf_python.sys.netlink.attrs import NlAttrU32 +from atf_python.sys.netlink.base_headers import GenlMsgHdr +from atf_python.sys.netlink.message import NlMsgCategory +from atf_python.sys.netlink.message import NlMsgProps +from atf_python.sys.netlink.message import StdNetlinkMessage +from atf_python.sys.netlink.utils import AttrDescr +from atf_python.sys.netlink.utils import prepare_attrs_map +from atf_python.sys.netlink.utils import enum_or_int + + +class NetlinkGenlMessage(StdNetlinkMessage): + messages = [] + nl_attrs_map = {} + family_name = None + + def __init__(self, helper, family_id, cmd=0): + super().__init__(helper, family_id) + self.base_hdr = GenlMsgHdr(cmd=enum_or_int(cmd)) + + def parse_base_header(self, data): + if len(data) < sizeof(GenlMsgHdr): + raise ValueError("length less than GenlMsgHdr header") + ghdr = GenlMsgHdr.from_buffer_copy(data) + return (ghdr, sizeof(GenlMsgHdr)) + + def _get_msg_type(self): + return self.base_hdr.cmd + + def print_nl_header(self, prepend=""): + # len=44, type=RTM_DELROUTE, flags=NLM_F_REQUEST|NLM_F_ACK, seq=1641163704, pid=0 # noqa: E501 + hdr = self.nl_hdr + print( + "{}len={}, family={}, flags={}(0x{:X}), seq={}, pid={}".format( + prepend, + hdr.nlmsg_len, + self.family_name, + self.get_nlm_flags_str(), + hdr.nlmsg_flags, + hdr.nlmsg_seq, + hdr.nlmsg_pid, + ) + ) + + def print_base_header(self, hdr, prepend=""): + print( + "{}cmd={} version={} reserved={}".format( + prepend, self.msg_name, hdr.version, hdr.reserved + ) + ) + + +GenlCtrlFamilyName = "nlctrl" + +class GenlCtrlMsgType(Enum): + CTRL_CMD_UNSPEC = 0 + CTRL_CMD_NEWFAMILY = 1 + CTRL_CMD_DELFAMILY = 2 + CTRL_CMD_GETFAMILY = 3 + CTRL_CMD_NEWOPS = 4 + CTRL_CMD_DELOPS = 5 + CTRL_CMD_GETOPS = 6 + CTRL_CMD_NEWMCAST_GRP = 7 + CTRL_CMD_DELMCAST_GRP = 8 + CTRL_CMD_GETMCAST_GRP = 9 + CTRL_CMD_GETPOLICY = 10 + + +class GenlCtrlAttrType(Enum): + CTRL_ATTR_FAMILY_ID = 1 + CTRL_ATTR_FAMILY_NAME = 2 + CTRL_ATTR_VERSION = 3 + CTRL_ATTR_HDRSIZE = 4 + CTRL_ATTR_MAXATTR = 5 + CTRL_ATTR_OPS = 6 + CTRL_ATTR_MCAST_GROUPS = 7 + CTRL_ATTR_POLICY = 8 + CTRL_ATTR_OP_POLICY = 9 + CTRL_ATTR_OP = 10 + + +genl_ctrl_attrs = prepare_attrs_map( + [ + AttrDescr(GenlCtrlAttrType.CTRL_ATTR_FAMILY_ID, NlAttrU16), + AttrDescr(GenlCtrlAttrType.CTRL_ATTR_FAMILY_NAME, NlAttrStr), + AttrDescr(GenlCtrlAttrType.CTRL_ATTR_VERSION, NlAttrU32), + AttrDescr(GenlCtrlAttrType.CTRL_ATTR_HDRSIZE, NlAttrU32), + AttrDescr(GenlCtrlAttrType.CTRL_ATTR_MAXATTR, NlAttrU32), + ] +) + + +class NetlinkGenlCtrlMessage(NetlinkGenlMessage): + messages = [ + NlMsgProps(GenlCtrlMsgType.CTRL_CMD_NEWFAMILY, NlMsgCategory.NEW), + NlMsgProps(GenlCtrlMsgType.CTRL_CMD_GETFAMILY, NlMsgCategory.GET), + NlMsgProps(GenlCtrlMsgType.CTRL_CMD_DELFAMILY, NlMsgCategory.DELETE), + ] + nl_attrs_map = genl_ctrl_attrs + family_name = GenlCtrlFamilyName + + +handler_classes = { + GenlCtrlFamilyName: [NetlinkGenlCtrlMessage], +} diff --git a/tests/atf_python/sys/netlink/netlink_route.py b/tests/atf_python/sys/netlink/netlink_route.py --- a/tests/atf_python/sys/netlink/netlink_route.py +++ b/tests/atf_python/sys/netlink/netlink_route.py @@ -16,6 +16,8 @@ from atf_python.sys.netlink.attrs import NlAttrU32 from atf_python.sys.netlink.attrs import NlAttrU8 from atf_python.sys.netlink.message import StdNetlinkMessage +from atf_python.sys.netlink.message import NlMsgProps +from atf_python.sys.netlink.message import NlMsgCategory from atf_python.sys.netlink.utils import AttrDescr from atf_python.sys.netlink.utils import get_bitmask_str from atf_python.sys.netlink.utils import prepare_attrs_map @@ -594,9 +596,9 @@ class NetlinkRtMessage(BaseNetlinkRtMessage): messages = [ - NlRtMsgType.RTM_NEWROUTE.value, - NlRtMsgType.RTM_DELROUTE.value, - NlRtMsgType.RTM_GETROUTE.value, + NlMsgProps(NlRtMsgType.RTM_NEWROUTE, NlMsgCategory.NEW), + NlMsgProps(NlRtMsgType.RTM_DELROUTE, NlMsgCategory.DELETE), + NlMsgProps(NlRtMsgType.RTM_GETROUTE, NlMsgCategory.GET), ] nl_attrs_map = rtnl_route_attrs @@ -634,9 +636,9 @@ class NetlinkIflaMessage(BaseNetlinkRtMessage): messages = [ - NlRtMsgType.RTM_NEWLINK.value, - NlRtMsgType.RTM_DELLINK.value, - NlRtMsgType.RTM_GETLINK.value, + NlMsgProps(NlRtMsgType.RTM_NEWLINK, NlMsgCategory.NEW), + NlMsgProps(NlRtMsgType.RTM_DELLINK, NlMsgCategory.DELETE), + NlMsgProps(NlRtMsgType.RTM_GETLINK, NlMsgCategory.GET), ] nl_attrs_map = rtnl_ifla_attrs @@ -666,9 +668,9 @@ class NetlinkIfaMessage(BaseNetlinkRtMessage): messages = [ - NlRtMsgType.RTM_NEWADDR.value, - NlRtMsgType.RTM_DELADDR.value, - NlRtMsgType.RTM_GETADDR.value, + NlMsgProps(NlRtMsgType.RTM_NEWADDR, NlMsgCategory.NEW), + NlMsgProps(NlRtMsgType.RTM_DELADDR, NlMsgCategory.DELETE), + NlMsgProps(NlRtMsgType.RTM_GETADDR, NlMsgCategory.GET), ] nl_attrs_map = rtnl_ifa_attrs @@ -698,9 +700,9 @@ class NetlinkNdMessage(BaseNetlinkRtMessage): messages = [ - NlRtMsgType.RTM_NEWNEIGH.value, - NlRtMsgType.RTM_DELNEIGH.value, - NlRtMsgType.RTM_GETNEIGH.value, + NlMsgProps(NlRtMsgType.RTM_NEWNEIGH, NlMsgCategory.NEW), + NlMsgProps(NlRtMsgType.RTM_DELNEIGH, NlMsgCategory.DELETE), + NlMsgProps(NlRtMsgType.RTM_GETNEIGH, NlMsgCategory.GET), ] nl_attrs_map = rtnl_nd_attrs @@ -725,3 +727,13 @@ hdr.ndm_flags, ) ) + + +handler_classes = { + "netlink_route": [ + NetlinkRtMessage, + NetlinkIflaMessage, + NetlinkIfaMessage, + NetlinkNdMessage, + ], +} diff --git a/tests/atf_python/sys/netlink/utils.py b/tests/atf_python/sys/netlink/utils.py --- a/tests/atf_python/sys/netlink/utils.py +++ b/tests/atf_python/sys/netlink/utils.py @@ -12,6 +12,7 @@ AF_NETLINK = 38 NETLINK_ROUTE = 0 NETLINK_GENERIC = 16 + GENL_ID_CTRL = 16 def roundup2(val: int, num: int) -> int: @@ -69,6 +70,11 @@ def get_bitmask_str(cls, val): - pmap = build_propmap(cls) + if isinstance(cls, type): + pmap = build_propmap(cls) + else: + pmap = {} + for _cls in cls: + pmap.update(build_propmap(_cls)) bmap = get_bitmask_map(pmap, val) return ",".join([v for k, v in bmap.items()])