# Reviewer notes. Text before the first "diff --git" line is ignored by
# git apply and patch(1).
#
# sys/netpfil/pf/pf.c
#   - Adds PF_SCTP_MAX_ENDPOINTS (8). Despite its name, it is compared against
#     the number of entries on a single endpoint's ep->sources list. The added
#     comment reads "Limit the number of addresses per endpoint". It does not
#     bound the number of endpoints; consider a name that says so.
#   - The patched function takes PF_SCTP_ENDPOINTS_LOCK() and walks
#     ep->sources. Its name is outside the visible hunks.
#     - The existing duplicate scan now also counts the entries it visits.
#     - If no duplicate is found and count >= PF_SCTP_MAX_ENDPOINTS, the
#       function calls PF_SCTP_ENDPOINTS_UNLOCK() and returns before malloc().
#     - The new address is then simply not recorded.
#   - NOTE(review): nothing in the visible hunk counts or logs that early
#     return. Consider whether dropped addresses should be observable.
#
# tests/sys/netpfil/pf/sctp.py
#   - Adds test_limit_addresses, which:
#     - aliases 192.0.2.4 .. 192.0.2.19 (16 addresses) onto vnet1's if1;
#     - runs "pfctl -e";
#     - loads the rules "block proto sctp", "pass on lo" and
#       "pass inet proto sctp to 192.0.2.0/24";
#     - sends b"hello" from an SCTPClient to 192.0.2.3:1234;
#     - checks that the object read from vnet2's pipe has ppid 0 and
#       data "hello";
#     - asserts that "pfctl -ss | grep 192.0.2.2" prints at most 9 lines.
#   - NOTE(review): the "9 (original + 8 extra)" comment assumes the original
#     address does not occupy one of the 8 ep->sources slots. Confirm this
#     against the code that populates and consumes ep->sources.
#   - NOTE(review): assert(states.count('\n') <= 9) uses the parenthesized
#     form. The idiomatic Python spelling is
#     assert states.count('\n') <= 9
#
# NOTE(review): the whitespace in this copy of the patch has been collapsed:
#   - all diff lines are joined;
#   - indentation and blank context lines are gone (e.g. hunk
#     "@@ -205,6 +205,8 @@" declares 6 old lines but shows only 4 context
#     lines);
#   - the line break falls inside the string "/sbin/pfctl -ss | grep 192.0.2.2".
# It will not apply as-is. Regenerate it from the source tree rather than
# hand-repairing it.
diff --git a/sys/netpfil/pf/pf.c b/sys/netpfil/pf/pf.c --- a/sys/netpfil/pf/pf.c +++ b/sys/netpfil/pf/pf.c @@ -205,6 +205,8 @@ VNET_DEFINE(struct pf_krule *, pf_rulemarker); #endif +#define PF_SCTP_MAX_ENDPOINTS 8 + struct pf_sctp_endpoint; RB_HEAD(pf_sctp_endpoints, pf_sctp_endpoint); struct pf_sctp_source { @@ -7298,6 +7300,7 @@ }; struct pf_sctp_source *i; struct pf_sctp_endpoint *ep; + int count; PF_SCTP_ENDPOINTS_LOCK(); @@ -7316,13 +7319,21 @@ } /* Avoid inserting duplicates. */ + count = 0; TAILQ_FOREACH(i, &ep->sources, entry) { + count++; if (pf_addr_cmp(&i->addr, a, pd->af) == 0) { PF_SCTP_ENDPOINTS_UNLOCK(); return; } } + /* Limit the number of addresses per endpoint. */ + if (count >= PF_SCTP_MAX_ENDPOINTS) { + PF_SCTP_ENDPOINTS_UNLOCK(); + return; + } + i = malloc(sizeof(*i), M_PFTEMP, M_NOWAIT); if (i == NULL) { PF_SCTP_ENDPOINTS_UNLOCK(); diff --git a/tests/sys/netpfil/pf/sctp.py b/tests/sys/netpfil/pf/sctp.py --- a/tests/sys/netpfil/pf/sctp.py +++ b/tests/sys/netpfil/pf/sctp.py @@ -426,6 +426,34 @@ assert re.search(r"all sctp 192.0.2.4:.*192.0.2.3:1234", states) assert re.search(r"all sctp 192.0.2.4:.*192.0.2.2:1234", states) + @pytest.mark.require_user("root") + def test_limit_addresses(self): + srv_vnet = self.vnet_map["vnet2"] + + ifname = self.vnet_map["vnet1"].iface_alias_map["if1"].name + for i in range(0, 16): + ToolsHelper.print_output("/sbin/ifconfig %s inet alias 192.0.2.%d/24" % (ifname, 4 + i)) + + ToolsHelper.print_output("/sbin/pfctl -e") + ToolsHelper.pf_rules([ + "block proto sctp", + "pass on lo", + "pass inet proto sctp to 192.0.2.0/24"]) + + # Set up a connection, which will try to create states for all addresses + # we have assigned + client = SCTPClient("192.0.2.3", 1234) + client.send(b"hello", 0) + rcvd = self.wait_object(srv_vnet.pipe) + print(rcvd) + assert rcvd['ppid'] == 0 + assert rcvd['data'] == "hello" + + # But the number should be limited to 9 (original + 8 extra) + states = ToolsHelper.get_output("/sbin/pfctl -ss | 
grep 192.0.2.2") + print(states) + assert(states.count('\n') <= 9) + @pytest.mark.require_user("root") def test_disallow_related(self): srv_vnet = self.vnet_map["vnet2"]