diff --git a/sys/net/if_ovpn.c b/sys/net/if_ovpn.c --- a/sys/net/if_ovpn.c +++ b/sys/net/if_ovpn.c @@ -130,6 +130,7 @@ #define OVPN_PEER_COUNTER_SIZE (sizeof(struct ovpn_peer_counters)/sizeof(uint64_t)) struct ovpn_kpeer { + RB_ENTRY(ovpn_kpeer) tree; int refcount; uint32_t peerid; @@ -172,13 +173,15 @@ }; #define OVPN_COUNTER_SIZE (sizeof(struct ovpn_counters)/sizeof(uint64_t)) +RB_HEAD(ovpn_kpeers, ovpn_kpeer); + struct ovpn_softc { int refcount; struct rmlock lock; struct ifnet *ifp; struct socket *so; int peercount; - struct ovpn_kpeer *peers[OVPN_MAX_PEERS]; /* XXX Hard limit for now? */ + struct ovpn_kpeers peers; /* Pending notification */ struct buf_ring *notifring; @@ -197,6 +200,10 @@ static int ovpn_get_af(struct mbuf *); static void ovpn_free_kkey_dir(struct ovpn_kkey_dir *); static bool ovpn_check_replay(struct ovpn_kkey_dir *, uint32_t); +static int ovpn_peer_compare(struct ovpn_kpeer *, struct ovpn_kpeer *); + +RB_PROTOTYPE_STATIC(ovpn_kpeers, ovpn_kpeer, tree, ovpn_peer_compare); +RB_GENERATE_STATIC(ovpn_kpeers, ovpn_kpeer, tree, ovpn_peer_compare); #define OVPN_MTU_MIN 576 #define OVPN_MTU_MAX (IP_MAXPACKET - sizeof(struct ip) - \ @@ -259,25 +266,22 @@ CTLFLAG_VNET | CTLFLAG_RW, &VNET_NAME(async_netisr_queue), 0, "Use netisr_queue() rather than netisr_dispatch()."); +static int +ovpn_peer_compare(struct ovpn_kpeer *a, struct ovpn_kpeer *b) +{ + return ((a->peerid > b->peerid) - (a->peerid < b->peerid)); +} + static struct ovpn_kpeer * ovpn_find_peer(struct ovpn_softc *sc, uint32_t peerid) { - struct ovpn_kpeer *p = NULL; + struct ovpn_kpeer p; OVPN_ASSERT(sc); - for (int i = 0; i < OVPN_MAX_PEERS; i++) { - p = sc->peers[i]; - if (p == NULL) - continue; - - if (p->peerid == peerid) { - MPASS(p->sc == sc); - break; - } - } + p.peerid = peerid; - return (p); + return (RB_FIND(ovpn_kpeers, &sc->peers, &p)); } static struct ovpn_kpeer * @@ -285,15 +289,7 @@ { OVPN_ASSERT(sc); - for (int i = 0; i < OVPN_MAX_PEERS; i++) { - if (sc->peers[i] == NULL) - continue; 
- return (sc->peers[i]); - } - - MPASS(false); - - return (NULL); + return (RB_ROOT(&sc->peers)); } static uint16_t @@ -479,7 +475,7 @@ struct socket *so = NULL; int fd; uint32_t peerid; - int ret = 0, i; + int ret = 0; if (nvl == NULL) return (EINVAL); @@ -600,20 +596,9 @@ sc->so = so; /* Insert the peer into the list. */ - for (i = 0; i < OVPN_MAX_PEERS; i++) { - if (sc->peers[i] != NULL) - continue; - - MPASS(sc->peers[i] == NULL); - sc->peers[i] = peer; - sc->peercount++; - soref(sc->so); - break; - } - if (i == OVPN_MAX_PEERS) { - ret = ENOSPC; - goto error_locked; - } + RB_INSERT(ovpn_kpeers, &sc->peers, peer); + sc->peercount++; + soref(sc->so); ret = udp_set_kernel_tunneling(sc->so, ovpn_udp_input, NULL, sc); if (ret == EBUSY) { @@ -621,7 +606,7 @@ ret = 0; } if (ret != 0) { - sc->peers[i] = NULL; + RB_REMOVE(ovpn_kpeers, &sc->peers, peer); sc->peercount--; goto error_locked; } @@ -648,24 +633,16 @@ _ovpn_del_peer(struct ovpn_softc *sc, uint32_t peerid) { struct ovpn_kpeer *peer; - int i; OVPN_WASSERT(sc); CURVNET_ASSERT_SET(); - for (i = 0; i < OVPN_MAX_PEERS; i++) { - if (sc->peers[i] == NULL) - continue; - if (sc->peers[i]->peerid != peerid) - continue; - - peer = sc->peers[i]; - break; - } - if (i == OVPN_MAX_PEERS) + peer = ovpn_find_peer(sc, peerid); + if (peer == NULL) return (ENOENT); + peer = RB_REMOVE(ovpn_kpeers, &sc->peers, peer); + MPASS(peer != NULL); - sc->peers[i] = NULL; sc->peercount--; ovpn_peer_release_ref(peer, true); @@ -1231,6 +1208,7 @@ static int ovpn_get_peer_stats(struct ovpn_softc *sc, nvlist_t **nvl) { + struct ovpn_kpeer *peer; nvlist_t *nvpeer = NULL; int ret; @@ -1249,12 +1227,7 @@ goto error; \ } while(0) - for (int i = 0; i < OVPN_MAX_PEERS; i++) { - struct ovpn_kpeer *peer = sc->peers[i]; - - if (peer == NULL) - continue; - + RB_FOREACH(peer, ovpn_kpeers, &sc->peers) { nvpeer = nvlist_create(0); if (nvpeer == NULL) { nvlist_destroy(*nvl); @@ -1381,6 +1354,8 @@ struct ifdrv *ifd; int error; + CURVNET_ASSERT_SET(); + switch 
(cmd) { case SIOCSDRVSPEC: case SIOCGDRVSPEC: @@ -1643,13 +1618,10 @@ OVPN_ASSERT(sc); - for (int i = 0; i < OVPN_MAX_PEERS; i++) { - if (sc->peers[i] == NULL) - continue; - if (addr.s_addr == sc->peers[i]->vpn4.s_addr) { - peer = sc->peers[i]; - break; - } + /* TODO: Add a second RB tree so we can look up by IPv4 address. */ + RB_FOREACH(peer, ovpn_kpeers, &sc->peers) { + if (addr.s_addr == peer->vpn4.s_addr) + return (peer); } return (peer); @@ -1664,13 +1636,10 @@ OVPN_ASSERT(sc); - for (int i = 0; i < OVPN_MAX_PEERS; i++) { - if (sc->peers[i] == NULL) - continue; - if (memcmp(addr, &sc->peers[i]->vpn6, sizeof(*addr)) == 0) { - peer = sc->peers[i]; - break; - } + /* TODO: Add a third RB tree so we can look up by IPv6 address. */ + RB_FOREACH(peer, ovpn_kpeers, &sc->peers) { + if (memcmp(addr, &peer->vpn6, sizeof(*addr)) == 0) + return (peer); } return (peer); @@ -2305,21 +2274,16 @@ char *unused __unused) { struct ovpn_softc *sc = ifp->if_softc; - int i; + struct ovpn_kpeer *peer, *tmppeer; int ret __diagused; - i = 0; - OVPN_WLOCK(sc); /* Flush keys & configuration. 
*/ - do { - if (sc->peers[i] != NULL) { - ret = _ovpn_del_peer(sc, sc->peers[i]->peerid); - MPASS(ret == 0); - } - i++; - } while (i < OVPN_MAX_PEERS); + RB_FOREACH_SAFE(peer, ovpn_kpeers, &sc->peers, tmppeer) { + ret = _ovpn_del_peer(sc, peer->peerid); + MPASS(ret == 0); + } ovpn_flush_rxring(sc); @@ -2417,9 +2381,7 @@ sc = __containerof(ctx, struct ovpn_softc, epoch_ctx); MPASS(sc->peercount == 0); - for (int i = 0; i < OVPN_MAX_PEERS; i++) { - MPASS(sc->peers[i] == NULL); - } + MPASS(RB_EMPTY(&sc->peers)); COUNTER_ARRAY_FREE(sc->counters, OVPN_COUNTER_SIZE); @@ -2431,8 +2393,8 @@ ovpn_clone_destroy(struct if_clone *ifc, struct ifnet *ifp, uint32_t flags) { struct ovpn_softc *sc; + struct ovpn_kpeer *peer, *tmppeer; int unit; - int i; int ret __diagused; sc = ifp->if_softc; @@ -2445,14 +2407,10 @@ return (EBUSY); } - i = 0; - do { - if (sc->peers[i] != NULL) { - ret = _ovpn_del_peer(sc, sc->peers[i]->peerid); - MPASS(ret == 0); - } - i++; - } while (i < OVPN_MAX_PEERS); + RB_FOREACH_SAFE(peer, ovpn_kpeers, &sc->peers, tmppeer) { + ret = _ovpn_del_peer(sc, peer->peerid); + MPASS(ret == 0); + } ovpn_flush_rxring(sc); buf_ring_free(sc->notifring, M_OVPN);