diff --git a/sys/dev/hyperv/hvsock/hv_sock.c b/sys/dev/hyperv/hvsock/hv_sock.c --- a/sys/dev/hyperv/hvsock/hv_sock.c +++ b/sys/dev/hyperv/hvsock/hv_sock.c @@ -74,6 +74,8 @@ MALLOC_DEFINE(M_HVSOCK, "hyperv_socket", "hyperv socket control structures"); +static int hvs_dom_probe(void); + /* The MTU is 16KB per host side's design */ #define HVSOCK_MTU_SIZE (1024 * 16) #define HVSOCK_SEND_BUF_SZ (PAGE_SIZE - sizeof(struct vmpipe_proto_header)) @@ -124,6 +126,7 @@ static struct domain hv_socket_domain = { .dom_family = AF_HYPERV, .dom_name = "hyperv", + .dom_probe = hvs_dom_probe, .dom_protosw = hv_socket_protosw, .dom_protoswNPROTOSW = &hv_socket_protosw[nitems(hv_socket_protosw)] }; @@ -323,6 +326,16 @@ sx_xunlock(&hvs_trans_socks_sx); } +static int +hvs_dom_probe(void) +{ + + /* Don't even give us a chance to attach on non-HyperV. */ + if (vm_guest != VM_GUEST_HV) + return (ENXIO); + return (0); +} + void hvs_trans_init(void) { @@ -330,9 +343,6 @@ if (!IS_DEFAULT_VNET(curvnet)) return; - if (vm_guest != VM_GUEST_HV) - return; - HVSOCK_DBG(HVSOCK_DBG_VERBOSE, "%s: HyperV Socket hvs_trans_init called\n", __func__); @@ -355,9 +365,6 @@ { struct hvs_pcb *pcb = so2hvspcb(so); - if (vm_guest != VM_GUEST_HV) - return (ESOCKTNOSUPPORT); - HVSOCK_DBG(HVSOCK_DBG_VERBOSE, "%s: HyperV Socket hvs_trans_attach called\n", __func__); @@ -384,9 +391,6 @@ { struct hvs_pcb *pcb; - if (vm_guest != VM_GUEST_HV) - return; - HVSOCK_DBG(HVSOCK_DBG_VERBOSE, "%s: HyperV Socket hvs_trans_detach called\n", __func__); @@ -604,9 +608,6 @@ { struct hvs_pcb *pcb; - if (vm_guest != VM_GUEST_HV) - return (ESOCKTNOSUPPORT); - HVSOCK_DBG(HVSOCK_DBG_VERBOSE, "%s: HyperV Socket hvs_trans_disconnect called\n", __func__); @@ -934,9 +935,6 @@ { struct hvs_pcb *pcb; - if (vm_guest != VM_GUEST_HV) - return; - HVSOCK_DBG(HVSOCK_DBG_VERBOSE, "%s: HyperV Socket hvs_trans_close called\n", __func__); @@ -978,9 +976,6 @@ { struct hvs_pcb *pcb = so2hvspcb(so); - if (vm_guest != VM_GUEST_HV) - return; - HVSOCK_DBG(HVSOCK_DBG_VERBOSE, "%s: HyperV Socket hvs_trans_abort called\n", __func__);