import torch
import math
from torch.utils.data import TensorDataset

def make_virtual_dataset_ablation(x: torch.Tensor, virt_bins: int = 16) -> TensorDataset:
    """
    Build the ablation virtual dataset: a regular grid over the bounding box of ``x``,
    with the all-ones direction attached to every grid site.

    Args:
        x: 2-D tensor; unpacked as (N, D). Its per-column min and max define the grid extent.
        virt_bins: number of equally spaced grid points per dimension.

    Returns:
        TensorDataset(Xvirt, Vvirt), both of shape (virt_bins**D, D), on the
        device and with the dtype of ``x``.
    """
    _, n_dims = x.shape  # raises if x is not 2-D
    tensor_opts = {"device": x.device, "dtype": x.dtype}

    # Per-dimension bounds of the data, pulled out as Python scalars.
    lower = x.min(dim=0).values.tolist()
    upper = x.max(dim=0).values.tolist()

    # One 1-D axis per dimension, then the full Cartesian grid ("ij" ordering).
    axes = [torch.linspace(lo, hi, virt_bins, **tensor_opts) for lo, hi in zip(lower, upper)]
    grid = torch.meshgrid(*axes, indexing="ij")
    Xvirt = torch.stack(grid, dim=-1).reshape(-1, n_dims)  # (J, D), J = virt_bins**D

    # Every site gets the same direction [1, 1, ..., 1].
    Vvirt = torch.ones((Xvirt.shape[0], n_dims), **tensor_opts)  # (J, D)

    return TensorDataset(Xvirt, Vvirt)

def make_virtual_dataset_all(train_x: torch.Tensor, virt_bins: int = 16, qcap_vec: torch.Tensor = None) -> TensorDataset:
    """
    Create a virtual dataset where each grid point has a single direction [1, 1, ..., 1].
    Returns TensorDataset(Xvirt, Vvirt).
    """
    device, dtype = train_x.device, train_x.dtype
    N, D = train_x.shape
    
    # Define uniform grid in each dimension
    mins, _ = torch.min(train_x, dim=0)
    maxs = qcap_vec.clone()

    grids = [
        torch.linspace(mins[d].item(), maxs[d].item(), virt_bins, device=device, dtype=dtype)
        for d in range(D)
    ]
    mesh = torch.meshgrid(*grids, indexing="ij")
    Xvirt = torch.stack([g.reshape(-1) for g in mesh], dim=-1)  # (J, D), J=bins**D

    # Single direction: [1, 1, ..., 1] for each site
    v_all = torch.ones(D, device=device, dtype=dtype)
    Vvirt = v_all.repeat(Xvirt.shape[0], 1)  # (J, D)

    return TensorDataset(Xvirt, Vvirt)

import math
import torch

# -------- stable utilities --------

def _gh_nodes(n: int, device=None, dtype=None):
    """
    Return the n-point Gauss–Hermite rule (nodes, weights) as torch tensors, for
    ∫ e^{-x^2} f(x) dx ≈ Σ w_k f(x_k).

    The rule is computed by NumPy in float64 and placed on ``device``; it is
    cast to ``dtype`` afterwards only when a different dtype is requested.
    """
    from numpy.polynomial.hermite import hermgauss

    # Always materialise in float64 first so the rule itself is as accurate as possible.
    nodes, weights = (
        torch.from_numpy(arr).to(device=device, dtype=torch.float64)
        for arr in hermgauss(n)
    )

    wants_cast = dtype is not None and dtype != torch.float64
    if not wants_cast:
        return nodes, weights
    return nodes.to(dtype), weights.to(dtype)

# numerically-stable log Φ(z)
def _log_phi_stable(z: torch.Tensor) -> torch.Tensor:
    """
    Stable log CDF of the standard normal, log Φ(z), valid for very negative z.

    Uses ``torch.special.log_ndtr`` rather than ``log(cdf(z))``: the latter
    underflows to log(0) = -inf once Φ(z) is no longer representable
    (roughly z < -38 even in float64).

    Args:
        z: tensor of standard-normal arguments.

    Returns:
        Tensor of log Φ(z) with the same shape and dtype as ``z``; the
        computation itself is carried out in float64.
    """
    z_dtype = z.dtype
    z64 = z.to(torch.float64)  # improve stability
    return torch.special.log_ndtr(z64).to(z_dtype)

# -------- expectation E[log Φ(g/ν)] for g ~ N(mu, var) --------

def _elogphi_expectation_gaussian(mu: torch.Tensor,
                                  var: torch.Tensor,
                                  nu: float = 1,
                                  gh_nodes: int | None = 12,
                                  mc_samples: int | None = None) -> torch.Tensor:
    """
    Estimate E[log Φ(g/nu)] element-wise for g ~ N(mu, var); the result has
    the shape and dtype of ``mu``.

    Two estimators are available:
    - Monte Carlo (chosen whenever ``mc_samples`` is a positive integer):
      averages log Φ over ``mc_samples`` reparameterised draws.
    - Gauss–Hermite quadrature (otherwise): deterministic, evaluated in
      float64 with ``gh_nodes`` nodes and cast back at the end.

    ``nu`` is floored at 1e-12 and negative variances are clamped to zero
    before taking the square root.
    """
    scale = float(max(nu, 1e-12))
    sigma = var.clamp_min(0.0).sqrt()

    draw_samples = mc_samples is not None and mc_samples > 0
    if draw_samples:
        # Reparameterisation: g = mu + sigma * eps, eps ~ N(0, 1), sample axis first.
        noise = torch.randn((mc_samples, *mu.shape), device=mu.device, dtype=mu.dtype)
        samples = mu.unsqueeze(0) + sigma.unsqueeze(0) * noise
        return _log_phi_stable(samples / scale).mean(dim=0)

    assert gh_nodes is not None and gh_nodes > 0, "Provide gh_nodes>0 or mc_samples>0"

    # Quadrature rule in float64 to keep the quadrature error small.
    nodes, weights = _gh_nodes(gh_nodes, device=mu.device, dtype=torch.float64)  # (Q,)
    mu_hi = mu.to(torch.float64)
    sigma_hi = sigma.to(torch.float64)

    # Change of variables for the e^{-x^2} weight: g = mu + sqrt(2) * sigma * x.
    g_at_nodes = mu_hi.unsqueeze(-1) + (math.sqrt(2.0) * sigma_hi).unsqueeze(-1) * nodes  # (..., Q)
    weighted = weights * _log_phi_stable(g_at_nodes / scale)                              # (..., Q)
    expectation = weighted.sum(dim=-1) / math.sqrt(math.pi)                               # (...,)

    return expectation.to(mu.dtype)

def compute_L_virt(model, Xv, Vv, derivative_directions, nu=1, gh_nodes=12, mc_samples=15):
    """
    Compute virtual probit monotonic term:
        L_virt = sum_{r=1}^J sum_{k=1}^K E_q[ log Φ(g_{r,k} / ν) ]
    where g_{r,k} = D_{v_{r,k}} f(x_r)

    Args:
        model: callable invoked as ``model(Xv, derivative_directions=...)``; it
            must return a pair whose second element exposes ``.mean`` and
            ``.variance`` (the derivative predictive distribution).
        Xv: virtual input locations passed to ``model``.
        Vv: accepted for interface compatibility; not used by this function.
        derivative_directions: forwarded unchanged to ``model``.
        nu: probit scale ν.
        gh_nodes: number of Gauss–Hermite nodes, used only when Monte Carlo is
            disabled.
        mc_samples: number of Monte Carlo draws. ``None`` or a value <= 0
            disables Monte Carlo and selects Gauss–Hermite quadrature.

    Returns:
        Scalar tensor: the sum of the expected log-probit terms over all entries.
    """
    # Forward: only the derivative predictive distribution is needed here.
    _, mvn_D = model(Xv, derivative_directions=derivative_directions)
    mean = mvn_D.mean
    var = mvn_D.variance

    # Guard against None before comparing, so mc_samples=None cleanly selects GH.
    use_mc = mc_samples is not None and mc_samples > 0

    # Expected log Φ(g/nu)
    term = _elogphi_expectation_gaussian(
        mean, var, nu=nu,
        gh_nodes=(None if use_mc else gh_nodes),
        mc_samples=(mc_samples if use_mc else None),
    )
    return term.sum()