import torch
import math
from typing import Tuple, Any, Dict, List,Union, Sequence
from .other_utils import get_grid_uniform

def get_GP_weights(L, D, hyperparameters, label_params):
    """
    Sample D independent RBF-GP functions at L random input locations using
    a truncated Fourier-basis representation.

    Input locations are drawn uniformly on each axis. For every feature d,
    Fourier coefficients are drawn as N(0, S_d(omega)), where S_d is the RBF
    spectral density computed by `krbf_fourier_cdiag`. They are then mapped
    to function values through the basis matrix built by `get_fdprs`.

    Args:
        L : int
            Number of input locations to sample.
        D : int
            Number of independent GP functions (features, e.g. neurons).
        hyperparameters : dict
            - 'len': per-feature RBF length-scales. For 1D inputs, D scalars.
              For 2D inputs, either D pairs [lx, ly] (anisotropic) or D
              scalars (isotropic: the same value is used on both axes).
            - 'rho': D marginal variances.
        label_params : dict
            - 'no_of_outputs': 1 or 2, dimensionality of the input space.
            - 'dimension_1_range': [min, max] for axis y1.
            - 'dimension_2_range': [min, max] for axis y2 (if 2D).
            - further keys consumed by `get_fdprs`
              (e.g. 'min_length_scale', 'is_label_range_circular').

    Returns:
        labels : Tensor [L, dims]
            Randomly sampled input locations.
        ground_truth_weights : Tensor [L, D]
            Function values of each GP at the sampled locations.
        ground_truth_weights_FD : Tensor [n_freqs, D]
            Fourier-domain weights used to generate each GP.

    Raises:
        ValueError: if a feature's spectral diagonal does not contain exactly
            one entry per basis function (e.g. malformed 'len').
    """
    # --- 1) Per-axis [min, max] limits (y1, then y2 if 2D) ---
    if label_params['no_of_outputs'] == 1:
        limits = [label_params['dimension_1_range']]
    else:
        limits = [
            label_params['dimension_1_range'],
            label_params['dimension_2_range'],
        ]
    dims = len(limits)  # 1 or 2

    # --- 2) L uniform draws in [lo, hi) per axis -> labels [L, dims] ---
    labels = torch.cat(
        [lo + (hi - lo) * torch.rand(L, 1) for lo, hi in limits], dim=1
    )

    # --- 3) Fourier basis evaluated at the sampled locations ---
    fdprs_labels = get_fdprs(labels, label_params)
    # Bmat is [L, n_freqs]: rows are input locations, columns basis functions.
    Bmat = fdprs_labels['Bmat']
    nfreq_w = Bmat.shape[1]

    # --- 4) Output buffers ---
    ground_truth_weights = torch.zeros(L, D)          # spatial-domain samples
    ground_truth_weights_FD = torch.zeros(nfreq_w, D)  # Fourier-domain weights

    # --- 5) Hyperparameters as tensors ---
    lengthscales_all_features = torch.tensor(
        hyperparameters['len'], dtype=torch.float32
    )
    if dims == 2:
        if lengthscales_all_features.numel() == D:
            # One isotropic length-scale per feature: use it on both axes.
            lengthscales_all_features = (
                lengthscales_all_features.reshape(D, 1).repeat(1, 2)
            )
        else:
            # One [lx, ly] pair per feature.
            lengthscales_all_features = lengthscales_all_features.reshape(D, 2)
    variances_all_features = torch.tensor(
        hyperparameters['rho'], dtype=torch.float32
    ).reshape(D)

    # --- 6) Sample Fourier weights and project back to the input domain ---
    for d in range(D):
        ddiag = krbf_fourier_cdiag(
            lengthscales_all_features[d],
            variances_all_features[d],
            fdprs_labels['wwsq'],
        )
        # The spectral diagonal broadcasts against wwsq and may carry trailing
        # singleton axes; flatten it to one variance per basis function.
        ddiag = ddiag.reshape(-1)
        if ddiag.numel() != nfreq_w:
            raise ValueError(
                f"Spectral diagonal for feature {d} has {ddiag.numel()} "
                f"entries, expected {nfreq_w}; check hyperparameters['len']."
            )

        # Fourier coefficients ~ N(0, ddiag).
        fwts_w = torch.sqrt(ddiag) * torch.randn(nfreq_w)

        ground_truth_weights_FD[:, d] = fwts_w
        ground_truth_weights[:, d] = Bmat @ fwts_w

    return labels, ground_truth_weights, ground_truth_weights_FD

def get_fdprs_uniform(
    T: int,
    label_params: Dict[str, Any]
) -> Dict[str, Any]:
    """
    Build Fourier-domain parameters on a regular grid of input points.

    The grid is produced by `get_grid_uniform(T, label_params)` and then
    passed unchanged to `get_fdprs`. Typical uses are Riemann-sum style
    numerical integration and evaluating a GP on a regular grid.

    Args:
        T : int
            Requested number of grid points. The exact number and layout of
            points is decided by `get_grid_uniform` (not visible here —
            presumably [T, 1] in 1D and at most T points in 2D; confirm).
        label_params : dict
            Forwarded to both `get_grid_uniform` and `get_fdprs`. It needs the
            keys `get_fdprs` reads: 'no_of_outputs', 'dimension_1_range',
            'dimension_2_range' (2D only), 'min_length_scale', and optionally
            'is_label_range_circular'.

    Returns:
        dict
            The dictionary returned by `get_fdprs` for the grid, with keys
            'minlen', 'circinterval', 'condthresh', 'Bmat' (shape
            [n_grid_points, n_freqs]), 'wwsq' and 'fBdims'.
    """
    grid_points = get_grid_uniform(T, label_params)
    return get_fdprs(grid_points, label_params)


def get_fdprs(
    labels: torch.Tensor,
    label_params: Dict[str, Any]
) -> Dict[str, Any]:
    """
    Build the frequency-domain parameters of an RBF-GP prior on 1D/2D inputs.

    The per-axis label ranges are optionally padded, giving the periodic
    domain `circinterval`. `mkRBFfourierBasis` is then called to obtain the
    non-uniform real Fourier basis at `labels` together with its squared
    frequencies.

    Parameters
    ----------
    labels : torch.Tensor, shape [L, dims]
        Input locations at which the Fourier basis is evaluated.
    label_params : dict
        - 'no_of_outputs'            : 1 or 2 (input dimensionality)
        - 'dimension_1_range'        : [min, max] of the first axis
        - 'dimension_2_range'        : [min, max] of the second axis (2D only)
        - 'min_length_scale'         : float or sequence of floats
        - 'is_label_range_circular'  : bool, default True. When false, every
                                       range is widened by 5% on each side.

    Returns
    -------
    dict with keys (in this order)
        - 'minlen'       : label_params['min_length_scale'], passed through
        - 'circinterval' : float32 Tensor [2, dims]; row 0 lower, row 1 upper
        - 'condthresh'   : 1e8, truncation threshold used by the basis builder
        - 'Bmat'         : basis matrix from mkRBFfourierBasis, [L, n_freqs]
        - 'wwsq'         : squared frequencies from mkRBFfourierBasis (with the
                           current implementation [n_freqs, 1, 1] in 1D and
                           [n_freqs, 2, 1] in 2D)
        - 'fBdims'       : Bmat.shape, i.e. (L, n_freqs)

    A warning is printed when L < n_freqs. In that case there are more basis
    functions than points, so the basis columns are linearly dependent.
    """
    # Per-axis [min, max] limits, first axis first.
    if label_params['no_of_outputs'] == 1:
        limits = [label_params['dimension_1_range']]
    else:
        limits = [
            label_params['dimension_1_range'],
            label_params['dimension_2_range'],
        ]

    minlength = label_params['min_length_scale']

    # Circular ranges are used as-is. Non-circular ones get a 5% margin on each
    # side to soften wrap-around effects of the periodic basis.
    pad_fraction = 0.0 if label_params.get('is_label_range_circular', True) else 0.05

    lower_bounds = []
    upper_bounds = []
    for lo, hi in limits:
        margin = (hi - lo) * pad_fraction
        lower_bounds.append(lo - margin)
        upper_bounds.append(hi + margin)

    fdprs: Dict[str, Any] = {
        'minlen': minlength,
        'circinterval': torch.tensor([lower_bounds, upper_bounds], dtype=torch.float32),
        'condthresh': 1e8,
    }

    Bmat, wwsq = mkRBFfourierBasis(labels, fdprs['minlen'], fdprs)
    fBdims = Bmat.shape  # (L, n_freqs)

    n_points, n_basis = fBdims[0], fBdims[1]
    if n_points < n_basis:
        print(f"Warning: Basis is linearly dependent ({fBdims[0]} x {fBdims[1]})")

    fdprs.update(Bmat=Bmat, wwsq=wwsq, fBdims=fBdims)
    return fdprs


def mkRBFfourierBasis(
    xx: torch.Tensor,
    length: Any,
    fdprs: Dict[str, Any]
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Construct a real Fourier-domain basis approximation for the RBF kernel.

    Supports 1D or 2D input locations. Along each axis, a real (cos/sin)
    non-uniform Fourier basis is built with `realnufftbasis`. Its highest
    frequency is the largest one whose RBF decay exp(-1/2 w^2 l^2) stays
    above 1/condthresh. In 2D, the tensor-product basis is truncated to the
    frequency pairs whose joint decay exceeds 1/condthresh.

    Parameters
    ----------
    xx : torch.Tensor, shape [n] or [n, d]
        Input locations. A 1-D tensor is treated as [n, 1].
    length : float or sequence of floats
        Minimum length-scale(s). A scalar (or one-element sequence) is used
        for every axis; otherwise one value per axis is expected.
    fdprs : dict
        Must contain:
          - 'circinterval' : Tensor [2, d], extended domain bounds per axis.
          - 'condthresh'   : float, threshold for truncating the basis.

    Returns
    -------
    Bfft : torch.Tensor
        Real Fourier basis matrix:
          - shape [n, nfreq] for d == 1
          - shape [n, nbasis] for d == 2
    wwsq : torch.Tensor
        Squared angular frequencies. `realnufftbasis` returns its frequency
        vector as a column [M, 1], so this is:
          - shape [nfreq, 1, 1] for d == 1
          - shape [nbasis, 2, 1] for d == 2

    Raises
    ------
    ValueError
        If the input dimensionality d is not 1 or 2.
    """
    # 0) Work on a private copy with shape [n, d].
    xx = xx.clone()
    if xx.ndim == 1:
        xx = xx.unsqueeze(1)
    _, d = xx.shape
    if d not in (1, 2):
        raise ValueError("Only d ≤ 2 supported by mkRBFfourierBasis")

    # 1) One length-scale per axis, as a flat [d] tensor. reshape(-1) lifts a
    #    plain scalar (0-d tensor) to shape [1]; without it `length[0]` below
    #    fails with "invalid index of a 0-dim tensor" in the 1D case.
    length = torch.as_tensor(length, dtype=xx.dtype, device=xx.device).reshape(-1)
    if length.numel() == 1 and d > 1:
        length = length.repeat(d)

    # 2) Period of the (extended) domain per axis and truncation threshold.
    circint = fdprs['circinterval']             # [2, d]: row 0 lower, row 1 upper
    circrnge = circint[1] - circint[0]          # [d]
    condthresh = float(fdprs['condthresh'])     # plain Python float

    Bmats: List[torch.Tensor] = []
    w2s: List[torch.Tensor] = []
    c2s: List[torch.Tensor] = []

    # 3) Independent 1D real Fourier basis along each axis.
    for j in range(d):
        period = circrnge[j].item()
        ell = length[j]

        # Largest integer k with exp(-1/2 (2*pi*k/period)^2 * ell^2) >= 1/condthresh.
        maxfreq = math.floor(
            (period / (math.pi * ell.item()))
            * math.sqrt(0.5 * math.log(condthresh))
        )

        # realnufftbasis returns (B [nfreq, n], wvec [nfreq, 1], wcos, wsin).
        B_j, wvec, *_ = realnufftbasis(xx[:, j], period, maxfreq * 2 + 1)
        B_j = torch.as_tensor(B_j, dtype=xx.dtype, device=xx.device)
        wvec = torch.as_tensor(wvec, dtype=xx.dtype, device=xx.device)

        # Squared angular frequencies and per-frequency RBF decay.
        w2 = (2 * math.pi / period) ** 2 * (wvec ** 2)  # [nfreq, 1]
        c2 = torch.exp(-0.5 * w2 * (ell ** 2))          # [nfreq, 1]

        Bmats.append(B_j)
        w2s.append(w2)
        c2s.append(c2)

    # 4) Assemble the full basis.
    if d == 1:
        Bfft = Bmats[0].T                # [n, nfreq]
        wwsq = w2s[0].unsqueeze(1)       # [nfreq, 1, 1]
    else:
        # Cmat[i1, i2] = c2s[0][i1] * c2s[1][i2]. An explicit outer product keeps
        # the row index tied to axis-0 frequencies and the column index to
        # axis-1 frequencies. This holds even when the axes have different
        # numbers of frequencies or different length-scales; a MATLAB-style
        # kron + row-major view does not preserve that alignment.
        Cmat = torch.outer(c2s[0].reshape(-1), c2s[1].reshape(-1))  # [n1, n2]
        # Keep only frequency pairs with sufficient prior weight.
        keep_idx = torch.nonzero(Cmat > 1 / condthresh, as_tuple=False)
        i1, i2 = keep_idx[:, 0], keep_idx[:, 1]

        # Tensor-product basis functions, transposed to [n, nbasis].
        Bfft = (Bmats[0][i1, :] * Bmats[1][i2, :]).T
        # Squared frequencies of each kept pair: [nbasis, 2, 1].
        wwsq = torch.stack([w2s[0][i1], w2s[1][i2]], dim=1)

    return Bfft, wwsq


def realnufftbasis(
    tvec: torch.Tensor,
    T: float,
    N: int
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Construct a real-valued Fourier basis for a non-uniform discrete Fourier transform.

    Cosine and sine basis functions with period T are evaluated at arbitrary
    sample points `tvec`. They are scaled so that, on a uniform grid of T
    integer points, the rows are orthonormal.

    Parameters
    ----------
    tvec : torch.Tensor, shape [n] or [n, 1]
        Non-uniform sample points, expected in [0, T). A warning is printed
        if any point exceeds T.
    T : float
        Period (circular boundary) of the domain.
    N : int
        Total number of basis functions (= ncos + nsin), N >= 1.
        - Odd N: (N+1)/2 cosine terms (frequencies 0..(N-1)/2, including DC)
          and (N-1)/2 sine terms.
        - Even N: N/2 + 1 cosine terms (frequencies 0..N/2, the last being
          the Nyquist term) and N/2 - 1 sine terms. If additionally N == T,
          the Nyquist row is rescaled by 1/sqrt(2) to unit norm.

    Returns
    -------
    B : torch.Tensor, shape [N, n]
        Real Fourier basis matrix: first ncos rows are cosines, next nsin
        rows are sines.
    wvec : torch.Tensor, shape [N, 1]
        Frequency index of each row of B.
    wcos : torch.Tensor, shape [ncos, 1]
        Non-negative frequency indices used for the cosine rows.
    wsin : torch.Tensor, shape [nsin, 1]
        Negative frequency indices used for the sine rows.

    Raises
    ------
    ValueError
        If N < 1.
    """
    if N < 1:
        raise ValueError(f"N must be a positive integer, got {N}")

    # Column vector [n, 1]; reshape also copes with non-contiguous slices.
    if tvec.dim() == 1:
        tvec = tvec.reshape(-1, 1)

    # Warn if any samples exceed the boundary (skip for empty input, where
    # .max() would raise).
    if tvec.numel() > 0 and tvec.max() > T:
        print("Warning: max(tvec) greater than circular boundary T!")

    # ceil((N+1)/2) cosine terms incl. DC (and Nyquist when N is even);
    # floor((N-1)/2) sine terms at negative frequencies. Total = N.
    ncos = N // 2 + 1
    nsin = (N - 1) // 2

    # Frequency index vectors.
    wcos = torch.arange(ncos, dtype=tvec.dtype, device=tvec.device).view(-1, 1)
    wsin = torch.arange(-nsin, 0, dtype=tvec.dtype, device=tvec.device).view(-1, 1)
    wvec = torch.cat([wcos, wsin], dim=0)  # [N, 1]

    # Scaling for orthonormality over [0, T).
    scale = torch.sqrt(torch.tensor(T / 2, dtype=tvec.dtype, device=tvec.device))

    cos_part = torch.cos((2 * torch.pi / T) * (wcos @ tvec.T)) / scale  # [ncos, n]
    if nsin > 0:
        sin_part = torch.sin((2 * torch.pi / T) * (wsin @ tvec.T)) / scale
        B = torch.cat([cos_part, sin_part], dim=0)  # [N, n]
    else:
        B = cos_part

    inv_sqrt2 = 1 / torch.sqrt(torch.tensor(2.0, dtype=tvec.dtype, device=tvec.device))

    # DC row has twice the squared norm of the other cosines; rescale.
    B[0, :] = B[0, :] * inv_sqrt2

    # Same for the Nyquist cosine (last cosine row) on an even, full grid.
    if (N % 2 == 0) and (N == T):
        B[ncos - 1, :] = B[ncos - 1, :] * inv_sqrt2

    return B, wvec, wcos, wsin


def krbf_fourier_cdiag(
    length: Union[float, Sequence[float], torch.Tensor],
    rho: Union[float, torch.Tensor],
    wwsq: torch.Tensor
) -> torch.Tensor:
    """
    Evaluate the RBF kernel's spectral density at the given squared frequencies.

    Isotropic / 1D case (single length-scale l):
        S(w) = rho * l * sqrt(2*pi) * exp(-1/2 * w^2 * l^2),
        computed elementwise on `wwsq` and floored at 1e-5.

    Anisotropic 2D case (length-scales l1, l2):
        S(w1, w2) = rho * (l1 sqrt(2*pi)) * (l2 sqrt(2*pi))
                    * exp(-1/2 * (w1^2 l1^2 + w2^2 l2^2)),
        using wwsq[:, 0] and wwsq[:, 1] as w1^2 and w2^2, floored at 1e-3.

    Args:
        length : float, sequence, or tensor
            One element selects the 1D formula. Otherwise the first two
            elements are taken as [l1, l2].
        rho : float or torch.Tensor
            Marginal variance of the GP prior.
        wwsq : torch.Tensor
            Squared frequencies; any shape for the 1D formula (the result has
            the same shape), and at least two columns along dim 1 for 2D.

    Returns:
        torch.Tensor
            Spectral density values on wwsq's dtype and device. The 1D result
            broadcasts with `wwsq`; the 2D result has the shape of `wwsq[:, 0]`.
    """
    ell = torch.as_tensor(length, dtype=wwsq.dtype, device=wwsq.device)
    amp = torch.as_tensor(rho, dtype=wwsq.dtype, device=wwsq.device)
    root_two_pi = math.sqrt(2 * math.pi)

    if ell.numel() != 1:
        # Anisotropic 2D: separable product of two 1D Gaussian spectra.
        ell_a, ell_b = ell[0], ell[1]
        scale = amp * (ell_a * root_two_pi) * (ell_b * root_two_pi)
        decay = torch.exp(
            -0.5 * (wwsq[:, 0] * (ell_a ** 2) + wwsq[:, 1] * (ell_b ** 2))
        )
        return torch.clamp(scale * decay, min=1e-3)

    # Single length-scale: elementwise Gaussian spectrum.
    density = amp * (ell * root_two_pi) * torch.exp(-0.5 * wwsq * (ell ** 2))
    return torch.clamp(density, min=1e-5)


def krbf_fourier_cdiag_multi(
    lengthscale: torch.Tensor,
    rho: torch.Tensor,
    wwsq: torch.Tensor
) -> torch.Tensor:
    """
    Compute the RBF kernel spectral density for multiple outputs in parallel.

    `krbf_fourier_cdiag` is evaluated once per output d with
    (lengthscale[d], rho[d]) on the shared frequencies `wwsq`. The
    per-output results are stacked as columns.

    Parameters
    ----------
    lengthscale : torch.Tensor, shape [D] (or [D, 2] for anisotropic 2D)
        Length-scale(s) for each of the D processes.
    rho : torch.Tensor, shape [D]
        Marginal variance for each of the D processes.
    wwsq : torch.Tensor
        Squared Fourier frequencies shared by all processes, e.g. [Bdim],
        [Bdim, 1] or the [Bdim, 1, 1] / [Bdim, 2, 1] shapes produced by the
        basis builder in this module.

    Returns
    -------
    k_m : torch.Tensor, shape [Bdim, D]
        k_m[b, d] is the spectral density of process d at frequency index b.
    """
    columns = []
    for d in range(lengthscale.shape[0]):
        cdiag_d = krbf_fourier_cdiag(lengthscale[d], rho[d], wwsq)
        # Drop only trailing singleton axes. A bare .squeeze() would also
        # remove the frequency axis when Bdim == 1, giving [D] instead of [1, D].
        while cdiag_d.dim() > 1 and cdiag_d.shape[-1] == 1:
            cdiag_d = cdiag_d.squeeze(-1)
        columns.append(cdiag_d)

    # Stack along the feature axis -> [Bdim, D].
    return torch.stack(columns, dim=1)