diff --git a/sys/netpfil/pf/pf.c b/sys/netpfil/pf/pf.c
--- a/sys/netpfil/pf/pf.c
+++ b/sys/netpfil/pf/pf.c
@@ -6232,6 +6232,21 @@
 		}
 		break;
 
+	case IPPROTO_SCTP: {	/* SCTP: rewrite source and destination address/port via pf_change_ap() */
+		uint16_t checksum = 0;	/* scratch word passed to pf_change_ap(); never read back in this case */
+		if (afto || *pd->sport != sport) {	/* source side: rewrite when afto is set or the port differs */
+			pf_change_ap(pd->m, pd->src, pd->sport, pd->ip_sum, &checksum,
+			    saddr, sport, 1, pd->af, pd->naf);
+			rewrite = 1;	/* record that the packet was modified */
+		}
+		if (afto || *pd->dport != dport) {	/* destination side: same test against dport */
+			pf_change_ap(pd->m, pd->dst, pd->dport, pd->ip_sum, &checksum,
+			    daddr, dport, 1, pd->af, pd->naf);
+			rewrite = 1;
+		}
+		break;
+	}
+
 #ifdef INET
 	case IPPROTO_ICMP:
 		/* pf_translate() is also used when logging invalid packets */
@@ -7014,6 +7029,33 @@
 	return (action);
 }
 
+static int	/* Compare pd's SCTP v_tag with the tag stored for the selected state peer: PF_DROP on mismatch, PF_PASS otherwise. */
+pf_sctp_track(struct pf_kstate *state, struct pf_pdesc *pd,
+    u_short *reason)	/* NOTE(review): reason is accepted but never written in this function - confirm that is intended */
+{
+	struct pf_state_peer *src;	/* peer whose scrub data holds the recorded verification tag */
+	if (pd->dir == state->direction) {	/* packet direction equals the state's direction */
+		if (PF_REVERSED_KEY(state->key, pd->af))	/* reversed key: use the dst peer instead */
+			src = &state->dst;
+		else
+			src = &state->src;
+	} else {	/* other direction: the selection is mirrored */
+		if (PF_REVERSED_KEY(state->key, pd->af))
+			src = &state->src;
+		else
+			src = &state->dst;
+	}
+
+	if (src->scrub != NULL) {	/* the tag is only checked when the peer has scrub data */
+		if (src->scrub->pfss_v_tag == 0)	/* a stored tag of 0 is treated as "nothing recorded yet" */
+			src->scrub->pfss_v_tag = pd->hdr.sctp.v_tag;	/* record the tag carried by this packet */
+		else if (src->scrub->pfss_v_tag != pd->hdr.sctp.v_tag)
+			return (PF_DROP);	/* packet's tag differs from the recorded one */
+	}
+
+	return (PF_PASS);	/* no scrub data, tag just recorded, or tag matches */
+}
+
 static int
 pf_test_state_sctp(struct pf_kstate **state, struct pf_pdesc *pd,
     u_short *reason)
@@ -7090,37 +7132,45 @@
 		(*state)->timeout = PFTM_SCTP_CLOSED;
 	}
 
-	if (src->scrub != NULL) {
-		if (src->scrub->pfss_v_tag == 0) {
-			src->scrub->pfss_v_tag = pd->hdr.sctp.v_tag;
-		} else if (src->scrub->pfss_v_tag != pd->hdr.sctp.v_tag)
-			return (PF_DROP);
-	}
+	if (pf_sctp_track(*state, pd, reason) != PF_PASS)	/* v_tag check now lives in pf_sctp_track() */
+		return (PF_DROP);
 
 	(*state)->expire = pf_get_uptime();
 
 	/* translate source/destination address, if necessary */
 	if ((*state)->key[PF_SK_WIRE] != (*state)->key[PF_SK_STACK]) {
 		uint16_t checksum = 0;
-		struct pf_state_key *nk = (*state)->key[pd->didx];
+		struct pf_state_key *nk;	/* state key supplying the replacement addresses and ports */
+		int afto, sidx, didx;
 
-		if (pd->af != nk->af) {
-			/* XXX No nat64 for SCTP for now. */
-			return (PF_DROP);
-		}
+		if (PF_REVERSED_KEY((*state)->key, pd->af))	/* reversed key: take the key at pd->sidx rather than pd->didx */
+			nk = (*state)->key[pd->sidx];
+		else
+			nk = (*state)->key[pd->didx];
 
-		if (PF_ANEQ(pd->src, &nk->addr[pd->sidx], pd->af) ||
-		    nk->port[pd->sidx] != pd->hdr.sctp.src_port) {
+		afto = pd->af != nk->af;	/* nonzero when the packet's af differs from the chosen key's af */
+		sidx = afto ? pd->didx : pd->sidx;	/* with afto the key's address/port indices are swapped */
+		didx = afto ? pd->sidx : pd->didx;
+
+		if (afto || PF_ANEQ(pd->src, &nk->addr[sidx], pd->af) ||	/* afto always takes the rewrite path */
+		    nk->port[sidx] != pd->hdr.sctp.src_port) {
 			pf_change_ap(pd->m, pd->src, &pd->hdr.sctp.src_port,
-			    pd->ip_sum, &checksum, &nk->addr[pd->sidx],
-			    nk->port[pd->sidx], 1, pd->af, pd->naf);
+			    pd->ip_sum, &checksum, &nk->addr[sidx],
+			    nk->port[sidx], 1, pd->af, pd->naf);
 		}
 
-		if (PF_ANEQ(pd->dst, &nk->addr[pd->didx], pd->af) ||
-		    nk->port[pd->didx] != pd->hdr.sctp.dest_port) {
+		if (afto || PF_ANEQ(pd->dst, &nk->addr[didx], pd->af) ||
+		    nk->port[didx] != pd->hdr.sctp.dest_port) {
 			pf_change_ap(pd->m, pd->dst, &pd->hdr.sctp.dest_port,
-			    pd->ip_sum, &checksum, &nk->addr[pd->didx],
-			    nk->port[pd->didx], 1, pd->af, pd->naf);
+			    pd->ip_sum, &checksum, &nk->addr[didx],
+			    nk->port[didx], 1, pd->af, pd->naf);
+		}
+
+		if (afto) {	/* address family changed: publish the new addresses and af in pd */
+			PF_ACPY(&pd->nsaddr, &nk->addr[sidx], nk->af);
+			PF_ACPY(&pd->ndaddr, &nk->addr[didx], nk->af);
+			pd->naf = nk->af;
+			return (PF_AFRT);	/* returned instead of falling through to the code after this block */
 		}
 	}