if_ovpn: index peers by VPN address

Replace the linear RB_FOREACH() scans in ovpn_find_peer_by_ip() and
ovpn_find_peer_by_ip6() with lookups in two additional red-black trees
keyed on the peer's VPN IPv4 / IPv6 address.

A peer can be linked into the peer-ID tree and both address trees at
the same time, so struct ovpn_kpeer gets a dedicated RB_ENTRY link per
tree; reusing 'tree' for all three would corrupt every one of them.

Address collisions are tolerated: RB_INSERT() leaves a tree untouched
when the address is already indexed, and ovpn_peer_unindex() only
unlinks a peer that is actually the indexed entry.  Removing one peer
therefore can never evict another peer or unlink a node that was never
linked.  Because removal uses RB_FIND()/RB_REMOVE() directly,
_ovpn_del_peer() does not depend on ovpn_find_peer_by_ip{,6}(), which
live inside conditional-compilation blocks.

While here, make ovpn_peer_compare() compare instead of subtract, so the
result has the correct sign for every pair of peer IDs.

NOTE(review): the struct ovpn_kpeer hunk's context lines were
reconstructed (the original patch did not touch that struct) and all
hunk line numbers are approximate; confirm against the tree before
applying.  The new trees also rely on struct ovpn_softc starting out
zeroed (or on RB_INIT() being called) -- confirm in the clone-create
path.

diff --git a/sys/net/if_ovpn.c b/sys/net/if_ovpn.c
--- a/sys/net/if_ovpn.c
+++ b/sys/net/if_ovpn.c
@@ -140,2 +140,4 @@
 	RB_ENTRY(ovpn_kpeer)	 tree;
+	RB_ENTRY(ovpn_kpeer)	 tree_by_ip;	/* peers_by_ip linkage */
+	RB_ENTRY(ovpn_kpeer)	 tree_by_ip6;	/* peers_by_ip6 linkage */
 	int			 refcount;
@@ -174,6 +176,8 @@
 #define OVPN_COUNTER_SIZE (sizeof(struct ovpn_counters)/sizeof(uint64_t))
 
 RB_HEAD(ovpn_kpeers, ovpn_kpeer);
+RB_HEAD(ovpn_kpeers_by_ip, ovpn_kpeer);
+RB_HEAD(ovpn_kpeers_by_ip6, ovpn_kpeer);
 
 struct ovpn_softc {
 	int			 refcount;
@@ -182,6 +186,8 @@
 	struct socket		*so;
 	int			 peercount;
 	struct ovpn_kpeers	 peers;
+	struct ovpn_kpeers_by_ip	 peers_by_ip;	/* keyed on vpn4 */
+	struct ovpn_kpeers_by_ip6	 peers_by_ip6;	/* keyed on vpn6 */
 
 	/* Pending notification */
 	struct buf_ring		*notifring;
@@ -200,9 +206,27 @@
 static int ovpn_get_af(struct mbuf *);
 static void ovpn_free_kkey_dir(struct ovpn_kkey_dir *);
 static bool ovpn_check_replay(struct ovpn_kkey_dir *, uint32_t);
-static int ovpn_peer_compare(struct ovpn_kpeer *, struct ovpn_kpeer *);
+static int ovpn_peer_compare(const struct ovpn_kpeer *,
+    const struct ovpn_kpeer *);
+static int ovpn_peer_compare_by_ip(const struct ovpn_kpeer *,
+    const struct ovpn_kpeer *);
+static int ovpn_peer_compare_by_ip6(const struct ovpn_kpeer *,
+    const struct ovpn_kpeer *);
 static RB_PROTOTYPE(ovpn_kpeers, ovpn_kpeer, tree, ovpn_peer_compare);
 static RB_GENERATE(ovpn_kpeers, ovpn_kpeer, tree, ovpn_peer_compare);
+/*
+ * Secondary indexes used to look peers up by VPN address.  A peer may be
+ * linked into all three trees at once, so each tree uses its own RB_ENTRY
+ * in struct ovpn_kpeer; sharing 'tree' would corrupt every tree involved.
+ */
+static RB_PROTOTYPE(ovpn_kpeers_by_ip, ovpn_kpeer, tree_by_ip,
+    ovpn_peer_compare_by_ip);
+static RB_GENERATE(ovpn_kpeers_by_ip, ovpn_kpeer, tree_by_ip,
+    ovpn_peer_compare_by_ip);
+static RB_PROTOTYPE(ovpn_kpeers_by_ip6, ovpn_kpeer, tree_by_ip6,
+    ovpn_peer_compare_by_ip6);
+static RB_GENERATE(ovpn_kpeers_by_ip6, ovpn_kpeer, tree_by_ip6,
+    ovpn_peer_compare_by_ip6);
 
 #define OVPN_MTU_MIN		576
 #define OVPN_MTU_MAX		(IP_MAXPACKET - sizeof(struct ip) - \
@@ -267,11 +291,49 @@
     "Use netisr_queue() rather than netisr_dispatch().");
 
 static int
-ovpn_peer_compare(struct ovpn_kpeer *a, struct ovpn_kpeer *b)
+ovpn_peer_compare(const struct ovpn_kpeer *a, const struct ovpn_kpeer *b)
 {
-	return (a->peerid - b->peerid);
+	/*
+	 * Compare rather than subtract: the difference of two peer IDs does not
+	 * always fit in an int with the correct sign.
+	 */
+	return ((a->peerid > b->peerid) - (a->peerid < b->peerid));
 }
 
+/*
+ * Order peers by VPN IPv4 address.  The tree only needs a consistent total
+ * order, so a byte-wise comparison is sufficient.
+ */
+static int
+ovpn_peer_compare_by_ip(const struct ovpn_kpeer *a, const struct ovpn_kpeer *b)
+{
+	return (memcmp(&a->vpn4, &b->vpn4, sizeof(a->vpn4)));
+}
+
+/* Order peers by VPN IPv6 address, byte-wise. */
+static int
+ovpn_peer_compare_by_ip6(const struct ovpn_kpeer *a,
+    const struct ovpn_kpeer *b)
+{
+	return (memcmp(&a->vpn6, &b->vpn6, sizeof(a->vpn6)));
+}
+
+/*
+ * Unlink 'peer' from the by-address trees.  A peer is removed from a tree
+ * only if it is the entry indexed there: if its address collided with an
+ * existing peer when it was added, RB_INSERT() did not link it, and removing
+ * it anyway would corrupt the tree (while removing whatever an address lookup
+ * returns would evict the other peer).
+ */
+static void
+ovpn_peer_unindex(struct ovpn_softc *sc, struct ovpn_kpeer *peer)
+{
+	if (RB_FIND(ovpn_kpeers_by_ip, &sc->peers_by_ip, peer) == peer)
+		RB_REMOVE(ovpn_kpeers_by_ip, &sc->peers_by_ip, peer);
+	if (RB_FIND(ovpn_kpeers_by_ip6, &sc->peers_by_ip6, peer) == peer)
+		RB_REMOVE(ovpn_kpeers_by_ip6, &sc->peers_by_ip6, peer);
+}
+
 static struct ovpn_kpeer *
 ovpn_find_peer(struct ovpn_softc *sc, uint32_t peerid)
 {
@@ -596,8 +658,17 @@
 	if (sc->so == NULL)
 		sc->so = so;
 
-	/* Insert the peer into the list. */
+	/*
+	 * Insert the peer into the peer-ID tree and, for each VPN address it was
+	 * given, into the matching by-address tree.  If another peer already
+	 * holds that address, RB_INSERT() leaves the tree unchanged and this peer
+	 * is simply not reachable by address lookup.
+	 */
 	RB_INSERT(ovpn_kpeers, &sc->peers, peer);
+	if (nvlist_exists_binary(nvl, "vpn_ipv4"))
+		(void)RB_INSERT(ovpn_kpeers_by_ip, &sc->peers_by_ip, peer);
+	if (nvlist_exists_binary(nvl, "vpn_ipv6"))
+		(void)RB_INSERT(ovpn_kpeers_by_ip6, &sc->peers_by_ip6, peer);
 	sc->peercount++;
 	soref(sc->so);
 
@@ -608,6 +679,7 @@
 	}
 	if (ret != 0) {
 		RB_REMOVE(ovpn_kpeers, &sc->peers, peer);
+		ovpn_peer_unindex(sc, peer);
 		sc->peercount--;
 		goto error_locked;
 	}
@@ -644,4 +716,5 @@
 	tmp = RB_REMOVE(ovpn_kpeers, &sc->peers, peer);
 	MPASS(tmp != NULL);
+	ovpn_peer_unindex(sc, peer);
 
 	sc->peercount--;
@@ -1628,17 +1701,14 @@
 static struct ovpn_kpeer *
 ovpn_find_peer_by_ip(struct ovpn_softc *sc, const struct in_addr addr)
 {
-	struct ovpn_kpeer *peer = NULL;
+	struct ovpn_kpeer peer;
 
 	OVPN_ASSERT(sc);
 
-	/* TODO: Add a second RB so we can look up by IP. */
-	RB_FOREACH(peer, ovpn_kpeers, &sc->peers) {
-		if (addr.s_addr == peer->vpn4.s_addr)
-			return (peer);
-	}
+	/* Only the key is read by ovpn_peer_compare_by_ip(). */
+	peer.vpn4 = addr;
 
-	return (peer);
+	return (RB_FIND(ovpn_kpeers_by_ip, &sc->peers_by_ip, &peer));
 }
 #endif
 
@@ -1646,17 +1716,14 @@
 static struct ovpn_kpeer *
 ovpn_find_peer_by_ip6(struct ovpn_softc *sc, const struct in6_addr *addr)
 {
-	struct ovpn_kpeer *peer = NULL;
+	struct ovpn_kpeer peer;
 
 	OVPN_ASSERT(sc);
 
-	/* TODO: Add a third RB so we can look up by IPv6 address. */
-	RB_FOREACH(peer, ovpn_kpeers, &sc->peers) {
-		if (memcmp(addr, &peer->vpn6, sizeof(*addr)) == 0)
-			return (peer);
-	}
+	/* Only the key is read by ovpn_peer_compare_by_ip6(). */
+	peer.vpn6 = *addr;
 
-	return (peer);
+	return (RB_FIND(ovpn_kpeers_by_ip6, &sc->peers_by_ip6, &peer));
 }
 #endif
 