diff --git a/sys/dev/if_wg/module/if_wg_session.c b/sys/dev/if_wg/module/if_wg_session.c --- a/sys/dev/if_wg/module/if_wg_session.c +++ b/sys/dev/if_wg/module/if_wg_session.c @@ -1394,8 +1394,8 @@ CURVNET_SET(inp->inp_vnet); ip_input(m); CURVNET_RESTORE(); - } else if (version == 6) { - af = AF_INET; + } else if (version == 6) { + af = AF_INET6; /* af is the family value handed to BPF_MTAP2() in this version == 6 branch */ BPF_MTAP2(sc->sc_ifp, &af, sizeof(af), m); inp = sotoinpcb(so->so_so6); CURVNET_SET(inp->inp_vnet); @@ -1531,6 +1531,7 @@ peer = CONTAINER_OF(remote, struct wg_peer, p_remote); DPRINTF(sc, "Receiving handshake initiation from peer %llu\n", (unsigned long long)peer->p_id); + wg_peer_set_endpoint_from_tag(peer, t); /* NOTE(review): presumably refreshes the peer endpoint from the packet tag t before the response is sent - confirm against the helper */ res = wg_send_response(peer); if (res == 0 && noise_remote_begin_session(&peer->p_remote) == 0) wg_timers_event_session_derived(&peer->p_timers); @@ -1851,6 +1852,40 @@ SLIST_INSERT_HEAD(&peer->p_unused_index, iter, i_unused_entry); } +static int +wg_update_endpoint_addrs(struct wg_endpoint *e, const struct sockaddr *srcsa, + struct ifnet *rcvif) +{ + const struct sockaddr_in *sa4; + const struct sockaddr_in6 *sa6; + int ret = 0; + + /* + * srcsa is the 2-element sockaddr array UDP passes in: element 0 is the source (remote) addr/port and, for AF_INET/AF_INET6, is copied into e_remote; + * element 1 is the destination addr/port and is copied into e_local only when it passes the multicast/broadcast checks below. Returns 0, or EAFNOSUPPORT for any other family. 
+ */ + if (srcsa->sa_family == AF_INET) { + sa4 = (const struct sockaddr_in *)srcsa; + e->e_remote.r_sin = sa4[0]; + /* Store the destination in e_local only if it is neither multicast nor broadcast (limited broadcast or a broadcast address on rcvif) */ + if (!(IN_MULTICAST(ntohl(sa4[1].sin_addr.s_addr)) || + sa4[1].sin_addr.s_addr == INADDR_BROADCAST || + in_broadcast(sa4[1].sin_addr, rcvif))) { + e->e_local.l_in = sa4[1].sin_addr; + } + } else if (srcsa->sa_family == AF_INET6) { + sa6 = (const struct sockaddr_in6 *)srcsa; + e->e_remote.r_sin6 = sa6[0]; + /* Store the destination in e_local only if it is not multicast */ + if (!IN6_IS_ADDR_MULTICAST(&sa6[1].sin6_addr)) + e->e_local.l_in6 = sa6[1].sin6_addr; + } else { + ret = EAFNOSUPPORT; + } + + return (ret); +} + static void wg_input(struct mbuf *m0, int offset, struct inpcb *inpcb, const struct sockaddr *srcsa, void *_sc) @@ -1882,7 +1917,11 @@ goto free; } e = wg_mbuf_endpoint_get(m); - e->e_remote.r_sa = *srcsa; + + if (wg_update_endpoint_addrs(e, srcsa, m->m_pkthdr.rcvif)) { /* non-zero means the address family was neither AF_INET nor AF_INET6; bail out via the free label */ + DPRINTF(sc, "unknown family\n"); + goto free; + } verify_endpoint(m); if_inc_counter(sc->sc_ifp, IFCOUNTER_IPACKETS, 1); diff --git a/sys/dev/if_wg/module/module.c b/sys/dev/if_wg/module/module.c --- a/sys/dev/if_wg/module/module.c +++ b/sys/dev/if_wg/module/module.c @@ -255,7 +255,6 @@ peer = wg_route_lookup(&sc->sc_routes, m, OUT); if (__predict_false(peer == NULL)) { rc = ENOKEY; - printf("peer not found - dropping %p\n", m); /* XXX log */ goto err; } @@ -360,8 +359,15 @@ struct wg_softc *sc; int rc; + if (iflib_in_detach(ctx)) /* bail out early when iflib_in_detach() is true */ + return; + sc = iflib_get_softc(ctx); ifp = iflib_get_ifp(ctx); + if (sc->sc_socket.so_so4 != NULL) + printf("XXX wg_init, socket non-NULL %p\n", + sc->sc_socket.so_so4); /* NOTE(review): looks like leftover debug output (XXX) - the socket pointer is only reported here */ + wg_socket_reinit(sc, NULL, NULL); /* NOTE(review): presumably releases any existing sockets before wg_socket_init() runs - confirm against wg_socket_reinit() */ rc = wg_socket_init(sc); if (rc) return; @@ -377,6 +383,7 @@ sc = iflib_get_softc(ctx); ifp = iflib_get_ifp(ctx); if_link_state_change(ifp, LINK_STATE_DOWN); + wg_socket_reinit(sc, NULL, NULL); /* same NULL/NULL reinit as the init path, issued after the link is marked down */ } static nvlist_t * @@ -386,13 +393,20 @@ int i, count; nvlist_t *nvl; caddr_t key; + size_t sa_sz; struct wg_allowedip *aip; 
+ struct wg_endpoint *ep; if ((nvl = nvlist_create(0)) == NULL) return (NULL); key = peer->p_remote.r_public; nvlist_add_binary(nvl, "public-key", key, WG_KEY_SIZE); - nvlist_add_binary(nvl, "endpoint", &peer->p_endpoint.e_remote, sizeof(struct sockaddr)); + ep = &peer->p_endpoint; + if (ep->e_remote.r_sa.sa_family != 0) { /* export the endpoint only when sa_family is non-zero */ + sa_sz = (ep->e_remote.r_sa.sa_family == AF_INET) ? + sizeof(struct sockaddr_in) : sizeof(struct sockaddr_in6); /* NOTE(review): every family other than AF_INET is exported with sockaddr_in6 length - confirm no other family can be stored */ + nvlist_add_binary(nvl, "endpoint", &ep->e_remote, sa_sz); + } i = count = 0; CK_LIST_FOREACH(rt, &peer->p_routes, r_entry) { count++; @@ -587,13 +601,12 @@ } if (nvlist_exists_binary(nvl, "endpoint")) { endpoint = nvlist_get_binary(nvl, "endpoint", &size); - if (size != sizeof(*endpoint)) { + if (size > sizeof(peer->p_endpoint.e_remote)) { /* only an upper bound on the blob length is enforced */ device_printf(dev, "%s bad length for endpoint %zu\n", __func__, size); err = EBADMSG; goto out; } - memcpy(&peer->p_endpoint.e_remote, endpoint, - sizeof(peer->p_endpoint.e_remote)); + memcpy(&peer->p_endpoint.e_remote, endpoint, size); /* NOTE(review): copies exactly size bytes, so a short or empty blob leaves the rest of e_remote untouched - confirm callers always send a full sockaddr matching sa_family */ } if (nvlist_exists_binary(nvl, "pre-shared-key")) { const void *key;