diff --git a/sys/net/if_ovpn.c b/sys/net/if_ovpn.c
--- a/sys/net/if_ovpn.c
+++ b/sys/net/if_ovpn.c
@@ -122,6 +122,7 @@
 struct ovpn_softc;
 
 struct ovpn_kpeer {
+	RB_ENTRY(ovpn_kpeer)	 tree;
 	int			 refcount;
 	uint32_t		 peerid;
 
@@ -141,8 +142,6 @@
 	struct callout		 ping_rcv;
 };
 
-#define OVPN_MAX_PEERS	128
-
 struct ovpn_counters {
 	uint64_t	lost_ctrl_pkts_in;
 	uint64_t	lost_ctrl_pkts_out;
@@ -162,13 +161,15 @@
 };
 #define OVPN_COUNTER_SIZE (sizeof(struct ovpn_counters)/sizeof(uint64_t))
 
+RB_HEAD(ovpn_kpeers, ovpn_kpeer);
+
 struct ovpn_softc {
 	int			 refcount;
 	struct rmlock		 lock;
 	struct ifnet		*ifp;
 	struct socket		*so;
 	int			 peercount;
-	struct ovpn_kpeer	*peers[OVPN_MAX_PEERS]; /* XXX Hard limit for now? */
+	struct ovpn_kpeers	 peers;
 
 	/* Pending notification */
 	struct buf_ring		*notifring;
@@ -187,6 +188,10 @@
 static int ovpn_get_af(struct mbuf *);
 static void ovpn_free_kkey_dir(struct ovpn_kkey_dir *);
 static bool ovpn_check_replay(struct ovpn_kkey_dir *, uint32_t);
+static int ovpn_peer_compare(struct ovpn_kpeer *, struct ovpn_kpeer *);
+
+RB_PROTOTYPE_STATIC(ovpn_kpeers, ovpn_kpeer, tree, ovpn_peer_compare);
+RB_GENERATE_STATIC(ovpn_kpeers, ovpn_kpeer, tree, ovpn_peer_compare);
 
 #define OVPN_MTU_MIN		576
 #define OVPN_MTU_MAX		(IP_MAXPACKET - sizeof(struct ip) - \
@@ -246,25 +251,32 @@
     CTLFLAG_VNET | CTLFLAG_RW, &VNET_NAME(async_netisr_queue), 0,
     "Use netisr_queue() rather than netisr_dispatch().");
 
+/*
+ * RB tree ordering for peers: sort by peerid.  Compare explicitly instead
+ * of returning a->peerid - b->peerid: that uint32_t difference, converted
+ * to int, has the wrong sign once two IDs are more than INT_MAX apart,
+ * which makes the ordering non-transitive and corrupts the tree.
+ */
+static int
+ovpn_peer_compare(struct ovpn_kpeer *a, struct ovpn_kpeer *b)
+{
+	if (a->peerid < b->peerid)
+		return (-1);
+	if (a->peerid > b->peerid)
+		return (1);
+	return (0);
+}
+
 static struct ovpn_kpeer *
 ovpn_find_peer(struct ovpn_softc *sc, uint32_t peerid)
 {
-	struct ovpn_kpeer *p = NULL;
+	struct ovpn_kpeer p;
 
 	OVPN_ASSERT(sc);
 
-	for (int i = 0; i < OVPN_MAX_PEERS; i++) {
-		p = sc->peers[i];
-		if (p == NULL)
-			continue;
-
-		if (p->peerid == peerid) {
-			MPASS(p->sc == sc);
-			break;
-		}
-	}
+	p.peerid = peerid;
 
-	return (p);
+	return (RB_FIND(ovpn_kpeers, &sc->peers, &p));
 }
 
 static struct ovpn_kpeer *
@@ -272,15 +284,7 @@
 {
 	OVPN_ASSERT(sc);
 
-	for (int i = 0; i < OVPN_MAX_PEERS; i++) {
-		if (sc->peers[i] == NULL)
-			continue;
-		return (sc->peers[i]);
-	}
-
-	MPASS(false);
-
-	return (NULL);
+	return (RB_ROOT(&sc->peers));
 }
 
 static uint16_t
@@ -466,7 +470,7 @@
 	struct socket *so = NULL;
 	int fd;
 	uint32_t peerid;
-	int ret = 0, i;
+	int ret = 0;
 
 	if (nvl == NULL)
 		return (EINVAL);
@@ -586,20 +590,13 @@
 	sc->so = so;
 
 	/* Insert the peer into the list. */
-	for (i = 0; i < OVPN_MAX_PEERS; i++) {
-		if (sc->peers[i] != NULL)
-			continue;
-
-		MPASS(sc->peers[i] == NULL);
-		sc->peers[i] = peer;
-		sc->peercount++;
-		soref(sc->so);
-		break;
-	}
-	if (i == OVPN_MAX_PEERS) {
-		ret = ENOSPC;
-		goto error_locked;
-	}
+	/* RB_INSERT() returns the existing peer on a duplicate peerid. */
+	if (RB_INSERT(ovpn_kpeers, &sc->peers, peer) != NULL) {
+		ret = EEXIST;
+		goto error_locked;
+	}
+	sc->peercount++;
+	soref(sc->so);
 
 	ret = udp_set_kernel_tunneling(sc->so, ovpn_udp_input, NULL, sc);
 	if (ret == EBUSY) {
@@ -607,7 +604,7 @@
 		ret = 0;
 	}
 	if (ret != 0) {
-		sc->peers[i] = NULL;
+		RB_REMOVE(ovpn_kpeers, &sc->peers, peer);
 		sc->peercount--;
 		goto error_locked;
 	}
@@ -633,24 +630,16 @@
 _ovpn_del_peer(struct ovpn_softc *sc, uint32_t peerid)
 {
 	struct ovpn_kpeer *peer;
-	int i;
 
 	OVPN_WASSERT(sc);
 	CURVNET_ASSERT_SET();
 
-	for (i = 0; i < OVPN_MAX_PEERS; i++) {
-		if (sc->peers[i] == NULL)
-			continue;
-		if (sc->peers[i]->peerid != peerid)
-			continue;
-
-		peer = sc->peers[i];
-		break;
-	}
-	if (i == OVPN_MAX_PEERS)
+	peer = ovpn_find_peer(sc, peerid);
+	if (peer == NULL)
 		return (ENOENT);
+	peer = RB_REMOVE(ovpn_kpeers, &sc->peers, peer);
+	MPASS(peer != NULL);
 
-	sc->peers[i] = NULL;
 	sc->peercount--;
 
 	ovpn_peer_release_ref(peer, true);
@@ -1362,6 +1351,8 @@
 	struct ifdrv *ifd;
 	int error;
 
+	CURVNET_ASSERT_SET();
+
 	switch (cmd) {
 	case SIOCSDRVSPEC:
 	case SIOCGDRVSPEC:
@@ -1622,13 +1613,10 @@
 
 	OVPN_ASSERT(sc);
 
-	for (int i = 0; i < OVPN_MAX_PEERS; i++) {
-		if (sc->peers[i] == NULL)
-			continue;
-		if (addr.s_addr == sc->peers[i]->vpn4.s_addr) {
-			peer = sc->peers[i];
-			break;
-		}
+	/* TODO: Add a second RB tree so we can look up by IPv4 address. */
+	RB_FOREACH(peer, ovpn_kpeers, &sc->peers) {
+		if (addr.s_addr == peer->vpn4.s_addr)
+			return (peer);
 	}
 
 	return (peer);
@@ -1643,13 +1631,10 @@
 
 	OVPN_ASSERT(sc);
 
-	for (int i = 0; i < OVPN_MAX_PEERS; i++) {
-		if (sc->peers[i] == NULL)
-			continue;
-		if (memcmp(addr, &sc->peers[i]->vpn6, sizeof(*addr)) == 0) {
-			peer = sc->peers[i];
-			break;
-		}
+	/* TODO: Add a third RB tree so we can look up by IPv6 address. */
+	RB_FOREACH(peer, ovpn_kpeers, &sc->peers) {
+		if (memcmp(addr, &peer->vpn6, sizeof(*addr)) == 0)
+			return (peer);
 	}
 
 	return (peer);
@@ -2281,21 +2266,17 @@
     char *unused __unused)
 {
 	struct ovpn_softc *sc = ifp->if_softc;
-	int i;
+	struct ovpn_kpeer *peer, *tmppeer;
 	int ret __diagused;
 
-	i = 0;
-
 	OVPN_WLOCK(sc);
 
 	/* Flush keys & configuration. */
-	do {
-		if (sc->peers[i] != NULL) {
-			ret = _ovpn_del_peer(sc, sc->peers[i]->peerid);
-			MPASS(ret == 0);
-		}
-		i++;
-	} while (i < OVPN_MAX_PEERS);
+	/* _SAFE: _ovpn_del_peer() unlinks peer and releases a reference. */
+	RB_FOREACH_SAFE(peer, ovpn_kpeers, &sc->peers, tmppeer) {
+		ret = _ovpn_del_peer(sc, peer->peerid);
+		MPASS(ret == 0);
+	}
 
 	ovpn_flush_rxring(sc);
 
@@ -2393,9 +2374,7 @@
 	sc = __containerof(ctx, struct ovpn_softc, epoch_ctx);
 
 	MPASS(sc->peercount == 0);
-	for (int i = 0; i < OVPN_MAX_PEERS; i++) {
-		MPASS(sc->peers[i] == NULL);
-	}
+	MPASS(RB_EMPTY(&sc->peers));
 
 	COUNTER_ARRAY_FREE(sc->counters, OVPN_COUNTER_SIZE);
 
@@ -2407,8 +2386,8 @@
 ovpn_clone_destroy(struct if_clone *ifc, struct ifnet *ifp, uint32_t flags)
 {
 	struct ovpn_softc *sc;
+	struct ovpn_kpeer *peer, *tmppeer;
 	int unit;
-	int i;
 	int ret __diagused;
 
 	sc = ifp->if_softc;
@@ -2421,14 +2400,11 @@
 		return (EBUSY);
 	}
 
-	i = 0;
-	do {
-		if (sc->peers[i] != NULL) {
-			ret = _ovpn_del_peer(sc, sc->peers[i]->peerid);
-			MPASS(ret == 0);
-		}
-		i++;
-	} while (i < OVPN_MAX_PEERS);
+	/* _SAFE: _ovpn_del_peer() unlinks peer and releases a reference. */
+	RB_FOREACH_SAFE(peer, ovpn_kpeers, &sc->peers, tmppeer) {
+		ret = _ovpn_del_peer(sc, peer->peerid);
+		MPASS(ret == 0);
+	}
 
 	ovpn_flush_rxring(sc);
 	buf_ring_free(sc->notifring, M_OVPN);