diff --git a/sys/dev/wg/if_wg.c b/sys/dev/wg/if_wg.c
--- a/sys/dev/wg/if_wg.c
+++ b/sys/dev/wg/if_wg.c
@@ -526,46 +526,62 @@
 	rw_runlock(&peer->p_endpoint_lock);
 }
 
-/* Allowed IP */
 static int
-wg_aip_add(struct wg_softc *sc, struct wg_peer *peer, sa_family_t af, const void *addr, uint8_t cidr)
+wg_aip_addrinfo(struct wg_softc *sc, sa_family_t af, const void *addr,	/* Normalize addr/cidr for address family af. */
+    uint8_t cidr, struct aip_addr *a_addr, struct aip_addr *a_mask,	/* Out: *a_addr gets the masked address, *a_mask the netmask. */
+    struct radix_node_head **root)	/* Out: sc_aip4 or sc_aip6. Returns 0, or EAFNOSUPPORT with *root left unwritten. */
 {
-	struct radix_node_head	*root;
-	struct radix_node	*node;
-	struct wg_aip		*aip;
-	int			 ret = 0;
-
-	aip = malloc(sizeof(*aip), M_WG, M_WAITOK | M_ZERO);
-	aip->a_peer = peer;
-	aip->a_af = af;
 
 	switch (af) {
 #ifdef INET
 	case AF_INET:
 		if (cidr > 32) cidr = 32;
-		root = sc->sc_aip4;
-		aip->a_addr.in = *(const struct in_addr *)addr;
-		aip->a_mask.ip = htonl(~((1LL << (32 - cidr)) - 1) & 0xffffffff);
-		aip->a_addr.ip &= aip->a_mask.ip;
-		aip->a_addr.length = aip->a_mask.length = offsetof(struct aip_addr, in) + sizeof(struct in_addr);
+		*root = sc->sc_aip4;	/* radix head used for the AF_INET case */
+		a_addr->in = *(const struct in_addr *)addr;	/* addr is read as a struct in_addr */
+		a_mask->ip = htonl(~((1LL << (32 - cidr)) - 1) & 0xffffffff);	/* top cidr bits set; 64-bit shift keeps cidr == 0 well defined */
+		a_addr->ip &= a_mask->ip;	/* clear the bits outside the prefix */
+		a_addr->length = a_mask->length = offsetof(struct aip_addr, in) + sizeof(struct in_addr);	/* byte count through the end of the in member */
 		break;
 #endif
 #ifdef INET6
 	case AF_INET6:
 		if (cidr > 128) cidr = 128;
-		root = sc->sc_aip6;
-		aip->a_addr.in6 = *(const struct in6_addr *)addr;
-		in6_prefixlen2mask(&aip->a_mask.in6, cidr);
+		*root = sc->sc_aip6;	/* radix head used for the AF_INET6 case */
+		a_addr->in6 = *(const struct in6_addr *)addr;	/* addr is read as a struct in6_addr */
+		in6_prefixlen2mask(&a_mask->in6, cidr);	/* presumably fills in the cidr-bit netmask; helper is defined elsewhere */
 		for (int i = 0; i < 4; i++)
-			aip->a_addr.ip6[i] &= aip->a_mask.ip6[i];
-		aip->a_addr.length = aip->a_mask.length = offsetof(struct aip_addr, in6) + sizeof(struct in6_addr);
+			a_addr->ip6[i] &= a_mask->ip6[i];	/* mask the address one ip6[] element at a time */
+		a_addr->length = a_mask->length = offsetof(struct aip_addr, in6) + sizeof(struct in6_addr);	/* byte count through the end of the in6 member */
 		break;
 #endif
 	default:
-		free(aip, M_WG);
 		return (EAFNOSUPPORT);
 	}
 
+	return (0);	/* success: *a_addr, *a_mask and *root have been written */
+}
+
+/* Allowed IP */
+static int
+wg_aip_add(struct wg_softc *sc, struct wg_peer *peer, sa_family_t af, const void *addr, uint8_t cidr)
+{
+	struct radix_node_head	*root = NULL;	/* written by wg_aip_addrinfo() when it returns 0 */
+	struct radix_node	*node;
+	struct wg_aip		*aip;
+	int			 ret = 0;
+
+	aip = malloc(sizeof(*aip), M_WG, M_WAITOK | M_ZERO);
+	aip->a_peer = peer;
+	aip->a_af = af;
+
+	ret = wg_aip_addrinfo(sc, af, addr, cidr, &aip->a_addr, &aip->a_mask,	/* fills the new entry's address and mask in place */
+	    &root);
+	if (ret != 0) {
+		free(aip, M_WG);	/* helper failed: release the entry allocated above */
+		return (ret);
+	}
+
+	MPASS(root != NULL);	/* on a 0 return the helper stored sc_aip4 or sc_aip6 in root */
 	RADIX_NODE_HEAD_LOCK(root);
 	node = root->rnh_addaddr(&aip->a_addr, &aip->a_mask, &root->rh, aip->a_nodes);
 	if (node == aip->a_nodes) {