diff --git a/sys/net/pfvar.h b/sys/net/pfvar.h
--- a/sys/net/pfvar.h
+++ b/sys/net/pfvar.h
@@ -375,6 +375,47 @@
 #define PF_STATE_LOCK_ASSERT(s)		do {} while (0)
 #endif /* INVARIANTS */
 
+/*
+ * A source node is protected by the mutex of the pf_srchash row it hashes
+ * to; pf_insert_src_node() caches a pointer to that mutex in sn->lock.
+ * These macros dereference the node, so they must not be used on a node
+ * that has already been freed.
+ */
+#ifdef INVARIANTS
+#define	PF_SRC_NODE_LOCK(sn)						\
+	do {								\
+		struct pf_ksrc_node *_sn = (sn);			\
+		struct pf_srchash *_sh = &V_pf_srchash[			\
+		    pf_hashsrc(&_sn->addr, _sn->af)];			\
+		MPASS(_sn->lock == &_sh->lock);				\
+		mtx_lock(_sn->lock);					\
+	} while (0)
+#define	PF_SRC_NODE_UNLOCK(sn)						\
+	do {								\
+		struct pf_ksrc_node *_sn = (sn);			\
+		struct pf_srchash *_sh = &V_pf_srchash[			\
+		    pf_hashsrc(&_sn->addr, _sn->af)];			\
+		MPASS(_sn->lock == &_sh->lock);				\
+		mtx_unlock(_sn->lock);					\
+	} while (0)
+#else
+#define	PF_SRC_NODE_LOCK(sn)	mtx_lock((sn)->lock)
+#define	PF_SRC_NODE_UNLOCK(sn)	mtx_unlock((sn)->lock)
+#endif
+
+#ifdef INVARIANTS
+#define	PF_SRC_NODE_LOCK_ASSERT(sn)					\
+	do {								\
+		struct pf_ksrc_node *_sn = (sn);			\
+		struct pf_srchash *_sh = &V_pf_srchash[			\
+		    pf_hashsrc(&_sn->addr, _sn->af)];			\
+		MPASS(_sn->lock == &_sh->lock);				\
+		PF_HASHROW_ASSERT(_sh);					\
+	} while (0)
+#else /* !INVARIANTS */
+#define	PF_SRC_NODE_LOCK_ASSERT(sn)	do {} while (0)
+#endif /* INVARIANTS */
+
 extern struct mtx_padalign pf_unlnkdrules_mtx;
 #define	PF_UNLNKDRULES_LOCK()	mtx_lock(&pf_unlnkdrules_mtx)
 #define	PF_UNLNKDRULES_UNLOCK()	mtx_unlock(&pf_unlnkdrules_mtx)
@@ -845,6 +886,7 @@
 	u_int32_t	 expire;
 	sa_family_t	 af;
 	u_int8_t	 ruletype;
+	struct mtx	*lock;
 };
 #endif
 
@@ -2120,7 +2162,8 @@
 extern bool			 pf_find_state_all_exists(struct pf_state_key_cmp *,
 				    u_int);
 extern struct pf_ksrc_node	*pf_find_src_node(struct pf_addr *,
-				    struct pf_krule *, sa_family_t, int);
+				    struct pf_krule *, sa_family_t,
+				    struct pf_srchash **, bool);
 extern void			 pf_unlink_src_node(struct pf_ksrc_node *);
 extern u_int			 pf_free_src_nodes(struct pf_ksrc_node_list *);
 extern void			 pf_print_state(struct pf_kstate *);
diff --git a/sys/netpfil/pf/pf.c b/sys/netpfil/pf/pf.c
--- a/sys/netpfil/pf/pf.c
+++ b/sys/netpfil/pf/pf.c
@@ -683,6 +683,10 @@
 	int			 bad = 0;
 
 	PF_STATE_LOCK_ASSERT(*state);
+	/*
+	 * XXXKS: The src node is accessed unlocked!
+	 * PF_SRC_NODE_LOCK_ASSERT((*state)->src_node);
+	 */
 
 	(*state)->src_node->conn++;
 	(*state)->src.tcp_est = 1;
@@ -827,25 +831,25 @@
  */
 struct pf_ksrc_node *
 pf_find_src_node(struct pf_addr *src, struct pf_krule *rule, sa_family_t af,
-	int returnlocked)
+	struct pf_srchash **sh, bool returnlocked)
 {
-	struct pf_srchash *sh;
 	struct pf_ksrc_node *n;
 
 	counter_u64_add(V_pf_status.scounters[SCNT_SRC_NODE_SEARCH], 1);
 
-	sh = &V_pf_srchash[pf_hashsrc(src, af)];
-	PF_HASHROW_LOCK(sh);
-	LIST_FOREACH(n, &sh->nodes, entry)
+	*sh = &V_pf_srchash[pf_hashsrc(src, af)];
+	PF_HASHROW_LOCK(*sh);
+	LIST_FOREACH(n, &(*sh)->nodes, entry)
 		if (n->rule.ptr == rule && n->af == af &&
 		    ((af == AF_INET && n->addr.v4.s_addr == src->v4.s_addr) ||
 		    (af == AF_INET6 && bcmp(&n->addr, src, sizeof(*src)) == 0)))
 			break;
+
 	if (n != NULL) {
 		n->states++;
-		PF_HASHROW_UNLOCK(sh);
-	} else if (returnlocked == 0)
-		PF_HASHROW_UNLOCK(sh);
+		PF_HASHROW_UNLOCK(*sh);
+	} else if (returnlocked == false)
+		PF_HASHROW_UNLOCK(*sh);
 
 	return (n);
 }
@@ -865,17 +869,16 @@
 pf_insert_src_node(struct pf_ksrc_node **sn, struct pf_krule *rule,
     struct pf_addr *src, sa_family_t af)
 {
+	struct pf_srchash *sh = NULL;
 
 	KASSERT((rule->rule_flag & PFRULE_SRCTRACK ||
 	    rule->rpool.opts & PF_POOL_STICKYADDR),
 	    ("%s for non-tracking rule %p", __func__, rule));
 
 	if (*sn == NULL)
-		*sn = pf_find_src_node(src, rule, af, 1);
+		*sn = pf_find_src_node(src, rule, af, &sh, true);
 
 	if (*sn == NULL) {
-		struct pf_srchash *sh = &V_pf_srchash[pf_hashsrc(src, af)];
-
 		PF_HASHROW_ASSERT(sh);
 
 		if (!rule->max_src_nodes ||
@@ -904,6 +907,9 @@
 		    rule->max_src_conn_rate.limit,
 		    rule->max_src_conn_rate.seconds);
 
+		MPASS((*sn)->lock == NULL);
+		(*sn)->lock = &sh->lock;
+
 		(*sn)->af = af;
 		(*sn)->rule.ptr = rule;
 		PF_ACPY(&(*sn)->addr, src, af);
@@ -929,8 +935,8 @@
 void
 pf_unlink_src_node(struct pf_ksrc_node *src)
 {
+	PF_SRC_NODE_LOCK_ASSERT(src);
 
-	PF_HASHROW_ASSERT(&V_pf_srchash[pf_hashsrc(&src->addr, src->af)]);
 	LIST_REMOVE(src, entry);
 	if (src->rule.ptr)
 		counter_u64_add(src->rule.ptr->src_nodes, -1);
@@ -1982,7 +1988,6 @@
 pf_src_tree_remove_state(struct pf_kstate *s)
 {
 	struct pf_ksrc_node *sn;
-	struct pf_srchash *sh;
 	uint32_t timeout;
 
 	timeout = s->rule.ptr->timeout[PFTM_SRC_NODE] ?
@@ -1991,21 +1996,19 @@
 
 	if (s->src_node != NULL) {
 		sn = s->src_node;
-		sh = &V_pf_srchash[pf_hashsrc(&sn->addr, sn->af)];
-		PF_HASHROW_LOCK(sh);
+		PF_SRC_NODE_LOCK(sn);
 		if (s->src.tcp_est)
 			--sn->conn;
 		if (--sn->states == 0)
 			sn->expire = time_uptime + timeout;
-		PF_HASHROW_UNLOCK(sh);
+		PF_SRC_NODE_UNLOCK(sn);
 	}
 	if (s->nat_src_node != s->src_node && s->nat_src_node != NULL) {
 		sn = s->nat_src_node;
-		sh = &V_pf_srchash[pf_hashsrc(&sn->addr, sn->af)];
-		PF_HASHROW_LOCK(sh);
+		PF_SRC_NODE_LOCK(sn);
 		if (--sn->states == 0)
 			sn->expire = time_uptime + timeout;
-		PF_HASHROW_UNLOCK(sh);
+		PF_SRC_NODE_UNLOCK(sn);
 	}
 	s->src_node = s->nat_src_node = NULL;
 }
@@ -4805,31 +4808,31 @@
 		uma_zfree(V_pf_state_key_z, nk);
 
 	if (sn != NULL) {
-		struct pf_srchash *sh;
-
-		sh = &V_pf_srchash[pf_hashsrc(&sn->addr, sn->af)];
-		PF_HASHROW_LOCK(sh);
+		/* sn may be freed below; save its row mutex for the unlock. */
+		struct mtx *lock = sn->lock;
+
+		PF_SRC_NODE_LOCK(sn);
 		if (--sn->states == 0 && sn->expire == 0) {
 			pf_unlink_src_node(sn);
 			uma_zfree(V_pf_sources_z, sn);
 			counter_u64_add(
 			    V_pf_status.scounters[SCNT_SRC_NODE_REMOVALS], 1);
 		}
-		PF_HASHROW_UNLOCK(sh);
+		mtx_unlock(lock);
 	}
 
 	if (nsn != sn && nsn != NULL) {
-		struct pf_srchash *sh;
-
-		sh = &V_pf_srchash[pf_hashsrc(&nsn->addr, nsn->af)];
-		PF_HASHROW_LOCK(sh);
+		/* nsn may be freed below; save its row mutex for the unlock. */
+		struct mtx *lock = nsn->lock;
+
+		PF_SRC_NODE_LOCK(nsn);
 		if (--nsn->states == 0 && nsn->expire == 0) {
 			pf_unlink_src_node(nsn);
 			uma_zfree(V_pf_sources_z, nsn);
 			counter_u64_add(
 			    V_pf_status.scounters[SCNT_SRC_NODE_REMOVALS], 1);
 		}
-		PF_HASHROW_UNLOCK(sh);
+		mtx_unlock(lock);
 	}
 
 	return (PF_DROP);
diff --git a/sys/netpfil/pf/pf_lb.c b/sys/netpfil/pf/pf_lb.c
--- a/sys/netpfil/pf/pf_lb.c
+++ b/sys/netpfil/pf/pf_lb.c
@@ -348,12 +348,13 @@
 {
 	struct pf_kpool		*rpool = &r->rpool;
 	struct pf_addr		*raddr = NULL, *rmask = NULL;
+	struct pf_srchash	*sh = NULL;
 
 	/* Try to find a src_node if none was given and this
 	   is a sticky-address rule. */
 	if (*sn == NULL && r->rpool.opts & PF_POOL_STICKYADDR &&
 	    (r->rpool.opts & PF_POOL_TYPEMASK) != PF_POOL_NONE)
-		*sn = pf_find_src_node(saddr, r, af, 0);
+		*sn = pf_find_src_node(saddr, r, af, &sh, false);
 
 	/* If a src_node was found or explicitly given and it has a non-zero
 	   route address, use this address. A zeroed address is found if the