diff --git a/tests/atf_python/sys/netlink/attrs.py b/tests/atf_python/sys/netlink/attrs.py index 58fbab7fc8db..9769ef0fc76f 100644 --- a/tests/atf_python/sys/netlink/attrs.py +++ b/tests/atf_python/sys/netlink/attrs.py @@ -1,289 +1,289 @@ import socket import struct from enum import Enum from atf_python.sys.netlink.utils import align4 from atf_python.sys.netlink.utils import enum_or_int class NlAttr(object): HDR_LEN = 4 # sizeof(struct nlattr) def __init__(self, nla_type, data): if isinstance(nla_type, Enum): self._nla_type = nla_type.value self._enum = nla_type else: self._nla_type = nla_type self._enum = None self.nla_list = [] self._data = data @property def nla_type(self): - return self._nla_type & 0x3F + return self._nla_type & 0x3FFF @property def nla_len(self): return len(self._data) + 4 def add_nla(self, nla): self.nla_list.append(nla) def print_attr(self, prepend=""): if self._enum is not None: type_str = self._enum.name else: type_str = "nla#{}".format(self.nla_type) print( "{}len={} type={}({}){}".format( prepend, self.nla_len, type_str, self.nla_type, self._print_attr_value() ) ) @staticmethod def _validate(data): if len(data) < 4: raise ValueError("attribute too short") nla_len, nla_type = struct.unpack("@HH", data[:4]) if nla_len > len(data): raise ValueError("attribute length too big") if nla_len < 4: raise ValueError("attribute length too short") @classmethod def _parse(cls, data): nla_len, nla_type = struct.unpack("@HH", data[:4]) return cls(nla_type, data[4:]) @classmethod def from_bytes(cls, data, attr_type_enum=None): cls._validate(data) attr = cls._parse(data) attr._enum = attr_type_enum return attr def _to_bytes(self, data: bytes): ret = data if align4(len(ret)) != len(ret): ret = data + bytes(align4(len(ret)) - len(ret)) return struct.pack("@HH", len(data) + 4, self._nla_type) + ret def __bytes__(self): return self._to_bytes(self._data) def _print_attr_value(self): return " " + " ".join(["x{:02X}".format(b) for b in self._data]) class NlAttrNested(NlAttr): def __init__(self, nla_type, val): super().__init__(nla_type, b"") self.nla_list = val @property def nla_len(self): return align4(len(b"".join([bytes(nla) for nla in self.nla_list]))) + 4 def print_attr(self, prepend=""): if self._enum is not None: type_str = self._enum.name else: type_str = "nla#{}".format(self.nla_type) print( "{}len={} type={}({}) {{".format( prepend, self.nla_len, type_str, self.nla_type ) ) for nla in self.nla_list: nla.print_attr(prepend + " ") print("{}}}".format(prepend)) def __bytes__(self): return self._to_bytes(b"".join([bytes(nla) for nla in self.nla_list])) class NlAttrU32(NlAttr): def __init__(self, nla_type, val): self.u32 = enum_or_int(val) super().__init__(nla_type, b"") @property def nla_len(self): return 8 def _print_attr_value(self): return " val={}".format(self.u32) @staticmethod def _validate(data): assert len(data) == 8 nla_len, nla_type = struct.unpack("@HH", data[:4]) assert nla_len == 8 @classmethod def _parse(cls, data): nla_len, nla_type, val = struct.unpack("@HHI", data) return cls(nla_type, val) def __bytes__(self): return self._to_bytes(struct.pack("@I", self.u32)) class NlAttrU16(NlAttr): def __init__(self, nla_type, val): self.u16 = enum_or_int(val) super().__init__(nla_type, b"") @property def nla_len(self): return 6 def _print_attr_value(self): return " val={}".format(self.u16) @staticmethod def _validate(data): assert len(data) == 6 nla_len, nla_type = struct.unpack("@HH", data[:4]) assert nla_len == 6 @classmethod def _parse(cls, data): nla_len, nla_type, val = struct.unpack("@HHH", data) return cls(nla_type, val) def __bytes__(self): return self._to_bytes(struct.pack("@H", self.u16)) class NlAttrU8(NlAttr): def __init__(self, nla_type, val): self.u8 = enum_or_int(val) super().__init__(nla_type, b"") @property def nla_len(self): return 5 def _print_attr_value(self): return " val={}".format(self.u8) @staticmethod def _validate(data): assert len(data) == 5 nla_len, nla_type = struct.unpack("@HH", data[:4]) assert nla_len == 5 @classmethod def _parse(cls, data): nla_len, nla_type, val = struct.unpack("@HHB", data) return cls(nla_type, val) def __bytes__(self): return self._to_bytes(struct.pack("@B", self.u8)) class NlAttrIp(NlAttr): def __init__(self, nla_type, addr: str): super().__init__(nla_type, b"") self.addr = addr if ":" in self.addr: self.family = socket.AF_INET6 else: self.family = socket.AF_INET @staticmethod def _validate(data): nla_len, nla_type = struct.unpack("@HH", data[:4]) data_len = nla_len - 4 if data_len != 4 and data_len != 16: raise ValueError( "Error validating attr {}: nla_len is not valid".format( # noqa: E501 nla_type ) ) @property def nla_len(self): if self.family == socket.AF_INET6: return 20 else: return 8 return align4(len(self._data)) + 4 @classmethod def _parse(cls, data): nla_len, nla_type = struct.unpack("@HH", data[:4]) data_len = len(data) - 4 if data_len == 4: addr = socket.inet_ntop(socket.AF_INET, data[4:8]) else: addr = socket.inet_ntop(socket.AF_INET6, data[4:20]) return cls(nla_type, addr) def __bytes__(self): return self._to_bytes(socket.inet_pton(self.family, self.addr)) def _print_attr_value(self): return " addr={}".format(self.addr) class NlAttrStr(NlAttr): def __init__(self, nla_type, text): super().__init__(nla_type, b"") self.text = text @staticmethod def _validate(data): NlAttr._validate(data) try: data[4:].decode("utf-8") except Exception as e: raise ValueError("wrong utf-8 string: {}".format(e)) @property def nla_len(self): return len(self.text) + 5 @classmethod def _parse(cls, data): text = data[4:-1].decode("utf-8") nla_len, nla_type = struct.unpack("@HH", data[:4]) return cls(nla_type, text) def __bytes__(self): return self._to_bytes(bytes(self.text, encoding="utf-8") + bytes(1)) def _print_attr_value(self): return ' val="{}"'.format(self.text) class NlAttrStrn(NlAttr): def __init__(self, nla_type, text): super().__init__(nla_type, b"") self.text = text @staticmethod def _validate(data): NlAttr._validate(data) try: data[4:].decode("utf-8") except Exception as e: raise ValueError("wrong utf-8 string: {}".format(e)) @property def nla_len(self): return len(self.text) + 4 @classmethod def _parse(cls, data): text = data[4:].decode("utf-8") nla_len, nla_type = struct.unpack("@HH", data[:4]) return cls(nla_type, text) def __bytes__(self): return self._to_bytes(bytes(self.text, encoding="utf-8")) def _print_attr_value(self): return ' val="{}"'.format(self.text) diff --git a/tests/atf_python/sys/netlink/message.py b/tests/atf_python/sys/netlink/message.py index b6fb2f8e357a..1e2b71775102 100644 --- a/tests/atf_python/sys/netlink/message.py +++ b/tests/atf_python/sys/netlink/message.py @@ -1,261 +1,261 @@ #!/usr/local/bin/python3 import struct from ctypes import sizeof from enum import Enum from typing import List from typing import NamedTuple from atf_python.sys.netlink.attrs import NlAttr from atf_python.sys.netlink.attrs import NlAttrNested from atf_python.sys.netlink.base_headers import NlmAckFlags from atf_python.sys.netlink.base_headers import NlmNewFlags from atf_python.sys.netlink.base_headers import NlmGetFlags from atf_python.sys.netlink.base_headers import NlmDeleteFlags from atf_python.sys.netlink.base_headers import NlmBaseFlags from atf_python.sys.netlink.base_headers import Nlmsghdr from atf_python.sys.netlink.base_headers import NlMsgType from atf_python.sys.netlink.utils import align4 from atf_python.sys.netlink.utils import enum_or_int from atf_python.sys.netlink.utils import get_bitmask_str class NlMsgCategory(Enum): UNKNOWN = 0 GET = 1 NEW = 2 DELETE = 3 ACK = 4 class NlMsgProps(NamedTuple): msg: Enum category: NlMsgCategory class BaseNetlinkMessage(object): def __init__(self, helper, nlmsg_type): self.nlmsg_type = enum_or_int(nlmsg_type) self.nla_list = [] self._orig_data = None self.helper = helper self.nl_hdr = Nlmsghdr( nlmsg_type=self.nlmsg_type, nlmsg_seq=helper.get_seq(), nlmsg_pid=helper.pid ) self.base_hdr = None def set_request(self, need_ack=True): self.add_nlflags([NlmBaseFlags.NLM_F_REQUEST]) if need_ack: self.add_nlflags([NlmBaseFlags.NLM_F_ACK]) def add_nlflags(self, flags: List): int_flags = 0 for flag in flags: int_flags |= enum_or_int(flag) self.nl_hdr.nlmsg_flags |= int_flags def add_nla(self, nla): self.nla_list.append(nla) def _get_nla(self, nla_list, nla_type): nla_type_raw = enum_or_int(nla_type) for nla in nla_list: if nla.nla_type == nla_type_raw: return nla return None def get_nla(self, nla_type): return self._get_nla(self.nla_list, nla_type) @staticmethod def parse_nl_header(data: bytes): if len(data) < sizeof(Nlmsghdr): raise ValueError("length less than netlink message header") return Nlmsghdr.from_buffer_copy(data), sizeof(Nlmsghdr) def is_type(self, nlmsg_type): nlmsg_type_raw = enum_or_int(nlmsg_type) return nlmsg_type_raw == self.nl_hdr.nlmsg_type def is_reply(self, hdr): return hdr.nlmsg_type == NlMsgType.NLMSG_ERROR.value @property def msg_name(self): return "msg#{}".format(self._get_msg_type()) def _get_nl_category(self): if self.is_reply(self.nl_hdr): return NlMsgCategory.ACK return NlMsgCategory.UNKNOWN def get_nlm_flags_str(self): category = self._get_nl_category() flags = self.nl_hdr.nlmsg_flags if category == NlMsgCategory.UNKNOWN: return self.helper.get_bitmask_str(NlmBaseFlags, flags) elif category == NlMsgCategory.GET: flags_enum = NlmGetFlags elif category == NlMsgCategory.NEW: flags_enum = NlmNewFlags elif category == NlMsgCategory.DELETE: flags_enum = NlmDeleteFlags elif category == NlMsgCategory.ACK: flags_enum = NlmAckFlags return get_bitmask_str([NlmBaseFlags, flags_enum], flags) def print_nl_header(self, prepend=""): # len=44, type=RTM_DELROUTE, flags=NLM_F_REQUEST|NLM_F_ACK, seq=1641163704, pid=0 # noqa: E501 hdr = self.nl_hdr print( "{}len={}, type={}, flags={}(0x{:X}), seq={}, pid={}".format( prepend, hdr.nlmsg_len, self.msg_name, self.get_nlm_flags_str(), hdr.nlmsg_flags, hdr.nlmsg_seq, hdr.nlmsg_pid, ) ) @classmethod def from_bytes(cls, helper, data): try: hdr, hdrlen = BaseNetlinkMessage.parse_nl_header(data) self = cls(helper, hdr.nlmsg_type) self._orig_data = data self.nl_hdr = hdr except ValueError as e: print("Failed to parse nl header: {}".format(e)) cls.print_as_bytes(data) raise return self def print_message(self): self.print_nl_header() @staticmethod def print_as_bytes(data: bytes, descr: str): print("===vv {} (len:{:3d}) vv===".format(descr, len(data))) off = 0 step = 16 while off < len(data): for i in range(step): if off + i < len(data): print(" {:02X}".format(data[off + i]), end="") print("") off += step print("--------------------") class StdNetlinkMessage(BaseNetlinkMessage): nl_attrs_map = {} @classmethod def from_bytes(cls, helper, data): try: hdr, hdrlen = BaseNetlinkMessage.parse_nl_header(data) self = cls(helper, hdr.nlmsg_type) self._orig_data = data self.nl_hdr = hdr except ValueError as e: print("Failed to parse nl header: {}".format(e)) cls.print_as_bytes(data) raise offset = align4(hdrlen) try: base_hdr, hdrlen = self.parse_base_header(data[offset:]) self.base_hdr = base_hdr offset += align4(hdrlen) # XXX: CAP_ACK except ValueError as e: print("Failed to parse nl rt header: {}".format(e)) cls.print_as_bytes(data) raise orig_offset = offset try: nla_list, nla_len = self.parse_nla_list(data[offset:]) offset += nla_len if offset != len(data): raise ValueError( "{} bytes left at the end of the packet".format(len(data) - offset) ) # noqa: E501 self.nla_list = nla_list except ValueError as e: print( "Failed to parse nla attributes at offset {}: {}".format(orig_offset, e) ) # noqa: E501 cls.print_as_bytes(data, "msg dump") cls.print_as_bytes(data[orig_offset:], "failed block") raise return self def parse_attrs(self, data: bytes, attr_map): ret = [] off = 0 while len(data) - off >= 4: nla_len, raw_nla_type = struct.unpack("@HH", data[off:off + 4]) if nla_len + off > len(data): raise ValueError( "attr length {} > than the remaining length {}".format( nla_len, len(data) - off ) ) - nla_type = raw_nla_type & 0x3F + nla_type = raw_nla_type & 0x3FFF if nla_type in attr_map: v = attr_map[nla_type] val = v["ad"].cls.from_bytes(data[off:off + nla_len], v["ad"].val) if "child" in v: # nested attrs, _ = self.parse_attrs( data[off + 4:off + nla_len], v["child"] ) val = NlAttrNested(v["ad"].val, attrs) else: # unknown attribute val = NlAttr(raw_nla_type, data[off + 4:off + nla_len]) ret.append(val) off += align4(nla_len) return ret, off def parse_nla_list(self, data: bytes) -> List[NlAttr]: return self.parse_attrs(data, self.nl_attrs_map) def __bytes__(self): ret = bytes() for nla in self.nla_list: ret += bytes(nla) ret = bytes(self.base_hdr) + ret self.nl_hdr.nlmsg_len = len(ret) + sizeof(Nlmsghdr) return bytes(self.nl_hdr) + ret def _get_msg_type(self): return self.nl_hdr.nlmsg_type @property def msg_props(self): msg_type = self._get_msg_type() for msg_props in self.messages: if msg_props.msg.value == msg_type: return msg_props return None @property def msg_name(self): msg_props = self.msg_props if msg_props is not None: return msg_props.msg.name return super().msg_name def print_base_header(self, hdr, prepend=""): pass def print_message(self): self.print_nl_header() self.print_base_header(self.base_hdr, " ") for nla in self.nla_list: nla.print_attr(" ")