Add neighbour-table (RTM_*NEIGH) support to the atf_python netlink helpers.

tests/atf_python/sys/net/netlink.py
  * enum_or_int(): returns val.value for Enum members, anything else as-is;
    used by the NlAttrU8/U16/U32 constructors and BaseNetlinkMessage.
  * RTM_DELNEIGH / RTM_GETNEIGH are renumbered to 29 / 30.
  * New NdMsg header, NdAttrType enum, NlAttrMac attribute, rtnl_nd_attrs
    map and NetlinkNdMessage class (registered in Nlsock).
  * BaseNetlinkMessage gains set_request() and add_nlflags().
tests/sys/netlink
  * New test_rtnl_neigh.py (listed in the Makefile) and a new
    test_add_route6_ll_gw case in test_rtnl_route.py.

diff --git a/tests/atf_python/sys/net/netlink.py b/tests/atf_python/sys/net/netlink.py
--- a/tests/atf_python/sys/net/netlink.py
+++ b/tests/atf_python/sys/net/netlink.py
@@ -29,6 +29,12 @@
     return roundup2(val, 4)
 
 
+def enum_or_int(val) -> int:
+    if isinstance(val, Enum):
+        return val.value
+    return val
+
+
 class SockaddrNl(Structure):
     _fields_ = [
         ("nl_len", c_ubyte),
@@ -125,8 +131,8 @@
     RTM_DELROUTE = 25
     RTM_GETROUTE = 26
     RTM_NEWNEIGH = 28
-    RTM_DELNEIGH = 27
-    RTM_GETNEIGH = 28
+    RTM_DELNEIGH = 29
+    RTM_GETNEIGH = 30
     RTM_NEWRULE = 32
     RTM_DELRULE = 33
     RTM_GETRULE = 34
@@ -491,6 +497,39 @@
     IFA_TARGET_NETNSID = auto()
 
 
+class NdMsg(Structure):
+    _fields_ = [
+        ("ndm_family", c_ubyte),
+        ("ndm_pad1", c_ubyte),
+        ("ndm_pad2", c_ubyte),
+        ("ndm_ifindex", c_uint),
+        ("ndm_state", c_ushort),
+        ("ndm_flags", c_ubyte),
+        ("ndm_type", c_ubyte),
+    ]
+
+
+class NdAttrType(Enum):
+    NDA_UNSPEC = 0
+    NDA_DST = 1
+    NDA_LLADDR = 2
+    NDA_CACHEINFO = 3
+    NDA_PROBES = 4
+    NDA_VLAN = 5
+    NDA_PORT = 6
+    NDA_VNI = 7
+    NDA_IFINDEX = 8
+    NDA_MASTER = 9
+    NDA_LINK_NETNSID = 10
+    NDA_SRC_VNI = 11
+    NDA_PROTOCOL = 12
+    NDA_NH_ID = 13
+    NDA_FDB_EXT_ATTRS = 14
+    NDA_FLAGS_EXT = 15
+    NDA_NDM_STATE_MASK = 16
+    NDA_NDM_FLAGS_MASK = 17
+
+
 class GenlMsgHdr(Structure):
     _fields_ = [
         ("cmd", c_ubyte),
@@ -702,7 +741,7 @@
 
 class NlAttrU32(NlAttr):
     def __init__(self, nla_type, val):
-        self.u32 = val
+        self.u32 = enum_or_int(val)
         super().__init__(nla_type, b"")
 
     @property
@@ -729,7 +768,7 @@
 
 class NlAttrU16(NlAttr):
     def __init__(self, nla_type, val):
-        self.u16 = val
+        self.u16 = enum_or_int(val)
         super().__init__(nla_type, b"")
 
     @property
@@ -756,7 +795,7 @@
 
 class NlAttrU8(NlAttr):
     def __init__(self, nla_type, val):
-        self.u8 = val
+        self.u8 = enum_or_int(val)
         super().__init__(nla_type, b"")
 
     @property
@@ -842,6 +881,11 @@
         return " iface=if#{}".format(self.u32)
 
 
+class NlAttrMac(NlAttr):
+    def _print_attr_value(self):
+        return " mac=" + ":".join("{:02x}".format(b) for b in self._data)  # TODO confirm: assumes NlAttr keeps its header-less payload in self._data
+
+
 class NlAttrTable(NlAttrU32):
     def _print_attr_value(self):
         return " rtable={}".format(self.u32)
@@ -1067,26 +1111,44 @@
 )
 
 
+rtnl_nd_attrs = prepare_attrs_map(
+    [
+        AttrDescr(NdAttrType.NDA_DST, NlAttrIp),
+        AttrDescr(NdAttrType.NDA_IFINDEX, NlAttrIfindex),
+        AttrDescr(NdAttrType.NDA_FLAGS_EXT, NlAttrU32),
+        AttrDescr(NdAttrType.NDA_LLADDR, NlAttrMac),
+    ]
+)
+
+
 class BaseNetlinkMessage(object):
     def __init__(self, helper, nlmsg_type):
-        self.nlmsg_type = nlmsg_type
+        self.nlmsg_type = enum_or_int(nlmsg_type)
         self.ut = unittest.TestCase()
         self.nla_list = []
         self._orig_data = None
         self.helper = helper
         self.nl_hdr = Nlmsghdr(
-            nlmsg_type=nlmsg_type, nlmsg_seq=helper.get_seq(), nlmsg_pid=helper.pid
+            nlmsg_type=self.nlmsg_type, nlmsg_seq=helper.get_seq(), nlmsg_pid=helper.pid
         )
         self.base_hdr = None
 
+    def set_request(self, need_ack=True):
+        self.add_nlflags([NlmBaseFlags.NLM_F_REQUEST])
+        if need_ack:
+            self.add_nlflags([NlmBaseFlags.NLM_F_ACK])
+
+    def add_nlflags(self, flags: List):
+        int_flags = 0
+        for flag in flags:
+            int_flags |= enum_or_int(flag)
+        self.nl_hdr.nlmsg_flags |= int_flags
+
     def add_nla(self, nla):
         self.nla_list.append(nla)
 
     def _get_nla(self, nla_list, nla_type):
-        if isinstance(nla_type, Enum):
-            nla_type_raw = nla_type.value
-        else:
-            nla_type_raw = nla_type
+        nla_type_raw = enum_or_int(nla_type)
         for nla in nla_list:
             if nla.nla_type == nla_type_raw:
                 return nla
@@ -1102,10 +1164,7 @@
         return Nlmsghdr.from_buffer_copy(data), sizeof(Nlmsghdr)
 
     def is_type(self, nlmsg_type):
-        if isinstance(nlmsg_type, Enum):
-            nlmsg_type_raw = nlmsg_type.value
-        else:
-            nlmsg_type_raw = nlmsg_type
+        nlmsg_type_raw = enum_or_int(nlmsg_type)
         return nlmsg_type_raw == self.nl_hdr.nlmsg_type
 
     def is_reply(self, hdr):
@@ -1422,6 +1481,37 @@
         )
 
 
+class NetlinkNdMessage(BaseNetlinkRtMessage):
+    messages = [
+        NlRtMsgType.RTM_NEWNEIGH.value,
+        NlRtMsgType.RTM_DELNEIGH.value,
+        NlRtMsgType.RTM_GETNEIGH.value,
+    ]
+    nl_attrs_map = rtnl_nd_attrs
+
+    def __init__(self, helper, nlm_type):
+        super().__init__(helper, nlm_type)
+        self.base_hdr = NdMsg()
+
+    def parse_base_header(self, data):
+        if len(data) < sizeof(NdMsg):
+            raise ValueError("length less than NdMsg header")
+        nd_hdr = NdMsg.from_buffer_copy(data)
+        return (nd_hdr, sizeof(NdMsg))
+
+    def print_base_header(self, hdr, prepend=""):
+        family = self.helper.get_af_name(hdr.ndm_family)
+        print(
+            "{}family={}, ndm_ifindex={}, ndm_state={}, ndm_flags={}".format(  # noqa: E501
+                prepend,
+                family,
+                hdr.ndm_ifindex,
+                hdr.ndm_state,
+                hdr.ndm_flags,
+            )
+        )
+
+
 class Nlsock:
     def __init__(self, family, helper):
         self.helper = helper
@@ -1435,6 +1525,7 @@
             NetlinkRtMessage,
             NetlinkIflaMessage,
             NetlinkIfaMessage,
+            NetlinkNdMessage,
             NetlinkDoneMessage,
             NetlinkErrorMessage,
         ]
diff --git a/tests/sys/netlink/Makefile b/tests/sys/netlink/Makefile
--- a/tests/sys/netlink/Makefile
+++ b/tests/sys/netlink/Makefile
@@ -9,6 +9,7 @@
 ATF_TESTS_PYTEST += test_nl_core.py
 ATF_TESTS_PYTEST += test_rtnl_iface.py
 ATF_TESTS_PYTEST += test_rtnl_ifaddr.py
+ATF_TESTS_PYTEST += test_rtnl_neigh.py
 ATF_TESTS_PYTEST += test_rtnl_route.py
 
 CFLAGS+= -I${.CURDIR:H:H:H}
diff --git a/tests/sys/netlink/test_rtnl_neigh.py b/tests/sys/netlink/test_rtnl_neigh.py
new file mode 100644
--- /dev/null
+++ b/tests/sys/netlink/test_rtnl_neigh.py
@@ -0,0 +1,53 @@
+import socket
+import pytest
+
+from atf_python.sys.net.netlink import NdAttrType
+from atf_python.sys.net.netlink import NetlinkNdMessage
+from atf_python.sys.net.netlink import NetlinkTestTemplate
+from atf_python.sys.net.netlink import NlConst
+from atf_python.sys.net.netlink import NlRtMsgType
+from atf_python.sys.net.vnet import SingleVnetTestTemplate
+
+
+class TestRtNlNeigh(NetlinkTestTemplate, SingleVnetTestTemplate):
+    def setup_method(self, method):
+        method_name = method.__name__
+        if "4" in method_name:
+            self.IPV4_PREFIXES = ["192.0.2.1/24"]
+        if "6" in method_name:
+            self.IPV6_PREFIXES = ["2001:db8::1/64"]
+        super().setup_method(method)
+        self.setup_netlink(NlConst.NETLINK_ROUTE)
+
+    def filter_iface(self, family, num_items):
+        epair_ifname = self.vnet.iface_alias_map["if1"].name
+        epair_ifindex = socket.if_nametoindex(epair_ifname)
+
+        msg = NetlinkNdMessage(self.helper, NlRtMsgType.RTM_GETNEIGH)
+        msg.set_request()
+        msg.base_hdr.ndm_family = family
+        msg.base_hdr.ndm_ifindex = epair_ifindex
+        self.write_message(msg)
+
+        ret = []
+        for rx_msg in self.read_msg_list(
+            msg.nl_hdr.nlmsg_seq, NlRtMsgType.RTM_NEWNEIGH
+        ):
+            ifname = socket.if_indextoname(rx_msg.base_hdr.ndm_ifindex)
+            rx_family = rx_msg.base_hdr.ndm_family
+            assert ifname == epair_ifname
+            assert rx_family == family  # reply must match the requested family
+            assert rx_msg.get_nla(NdAttrType.NDA_DST) is not None
+            assert rx_msg.get_nla(NdAttrType.NDA_LLADDR) is not None
+            ret.append(rx_msg)
+        assert len(ret) == num_items
+
+    @pytest.mark.timeout(5)
+    def test_6_filter_iface(self):
+        """Tests that listing outputs all nd6 records"""
+        return self.filter_iface(socket.AF_INET6, 2)
+
+    @pytest.mark.timeout(5)
+    def test_4_filter_iface(self):
+        """Tests that listing outputs all arp records"""
+        return self.filter_iface(socket.AF_INET, 1)
diff --git a/tests/sys/netlink/test_rtnl_route.py b/tests/sys/netlink/test_rtnl_route.py
--- a/tests/sys/netlink/test_rtnl_route.py
+++ b/tests/sys/netlink/test_rtnl_route.py
@@ -2,9 +2,11 @@
 import socket
 
 import pytest
+from atf_python.sys.net.tools import ToolsHelper
 from atf_python.sys.net.netlink import NetlinkRtMessage
 from atf_python.sys.net.netlink import NetlinkTestTemplate
 from atf_python.sys.net.netlink import NlAttrIp
+from atf_python.sys.net.netlink import NlAttrU32
 from atf_python.sys.net.netlink import NlConst
 from atf_python.sys.net.netlink import NlmBaseFlags
 from atf_python.sys.net.netlink import NlmGetFlags
@@ -22,6 +24,27 @@
         super().setup_method(method)
         self.setup_netlink(NlConst.NETLINK_ROUTE)
 
+    @pytest.mark.timeout(5)
+    def test_add_route6_ll_gw(self):
+        epair_ifname = self.vnet.iface_alias_map["if1"].name
+        epair_ifindex = socket.if_nametoindex(epair_ifname)
+
+        msg = NetlinkRtMessage(self.helper, NlRtMsgType.RTM_NEWROUTE)
+        msg.set_request()
+        msg.add_nlflags([NlmNewFlags.NLM_F_CREATE])
+        msg.base_hdr.rtm_family = socket.AF_INET6
+        msg.base_hdr.rtm_dst_len = 64
+        msg.add_nla(NlAttrIp(RtattrType.RTA_DST, "2001:db8:2::"))
+        msg.add_nla(NlAttrIp(RtattrType.RTA_GATEWAY, "fe80::1"))
+        msg.add_nla(NlAttrU32(RtattrType.RTA_OIF, epair_ifindex))
+
+        rx_msg = self.get_reply(msg)
+        assert rx_msg.is_type(NlMsgType.NLMSG_ERROR)
+        assert rx_msg.error_code == 0
+
+        ToolsHelper.print_net_debug()
+        ToolsHelper.print_output("netstat -6onW")
+
     @pytest.mark.timeout(20)
     def test_buffer_override(self):
         msg_flags = (