import numpy as np
import torch
from sklearn.kernel_ridge import KernelRidge
from sklearn.preprocessing import StandardScaler
from sklearn.base import BaseEstimator, RegressorMixin


def simplex_projection(v):
    """
    Project each row of ``v`` onto the probability simplex.

    Sort-based algorithm: with ``u`` the row sorted in descending order,
    find the last (0-based) index ``rho`` with
    ``u[rho] * (rho + 1) > cumsum(u)[rho] - 1``. Then shift the row by
    ``theta = (cumsum(u)[rho] - 1) / (rho + 1)`` and clip at zero.

    v: (B, n) floating-point tensor.
    Returns a tensor of the same shape whose rows are non-negative and sum to 1.
    """
    n_rows, n_cols = v.shape
    sorted_v, _ = torch.sort(v, dim=1, descending=True)
    cumsum = torch.cumsum(sorted_v, dim=1)

    # 1-based positions, created on v's device and in v's dtype.
    # This avoids int->float promotion and host-to-device copies.
    positions = torch.arange(1, n_cols + 1, device=v.device, dtype=v.dtype)

    # The condition holds on a prefix of the sorted row, so counting the
    # positions that satisfy it gives rho + 1.
    in_support = sorted_v * positions > (cumsum - 1)
    rho = in_support.sum(dim=1) - 1

    # Row indices live on the same device as the indexed tensor.
    rows = torch.arange(n_rows, device=v.device)
    theta = (cumsum[rows, rho] - 1) / (rho + 1)
    return torch.clamp(v - theta.unsqueeze(1), min=0.0)

def solve_simplex_ridge_batch(
    Phi,      # (B, K) complex
    E,        # (N, K) complex
    lam=1e-3,
    n_iters=25,
    step=None,
    P_init=None
):
    """
    Solve a batch of simplex-constrained ridge problems by projected gradient descent.

    Solves:
        min_P || P E - Phi ||^2 + lam ||P||^2
        s.t.   P >= 0, P 1 = 1

    Here the squared norm of the complex residual is the sum of squared
    moduli. Because P is real, half of the gradient is
    ``P Q - Re(Phi E^H)`` with ``Q = Re(E E^H) + lam I``.

    Args:
        Phi: (B, K) target rows (complex or real tensor).
        E: (N, K) basis rows; each row of P is a set of weights over them.
        lam: ridge penalty.
        n_iters: number of projected-gradient steps.
        step: step size. Defaults to ``1 / ||Q||_2``, the inverse Lipschitz
            constant of the half-gradient above.
        P_init: optional (B, N) starting point. It is copied, not modified,
            and cast to the solver's real dtype and device. Defaults to
            uniform weights ``1/N``.

    Returns:
        (B, N) real tensor in the real dtype of ``E @ E^H``. After at least
        one iteration, every row lies on the probability simplex.
    """
    device = Phi.device
    B = Phi.shape[0]
    N = E.shape[0]

    # Quadratic form. Built in E's real dtype so that float64 inputs stay
    # float64 instead of clashing with a default-dtype identity matrix.
    EE = torch.real(E @ E.conj().T)     # (N, N)
    Q = EE + lam * torch.eye(N, device=EE.device, dtype=EE.dtype)

    # Lipschitz constant of the half-gradient P Q - Bmat.
    if step is None:
        L = torch.linalg.norm(Q, ord=2)
        step = 1.0 / L

    # Linear term Re(Phi E^H).
    Bmat = torch.real(Phi @ E.conj().T)  # (B, N)

    # Initialization, in the same dtype as Q so that `P @ Q` is well-typed.
    if P_init is None:
        P = torch.full((B, N), 1.0 / N, device=device, dtype=Q.dtype)
    else:
        P = P_init.to(device=device, dtype=Q.dtype).clone()

    # PGD iterations: take a gradient step, then project back onto the simplex.
    for _ in range(n_iters):
        grad = P @ Q - Bmat
        P = simplex_projection(P - step * grad)

    return P

class MultivariateCFEstimator(BaseEstimator, RegressorMixin):
    """
    Conditional distribution estimator based on characteristic functions.

    ``fit`` evaluates ``exp(i w^T y)`` for the standardized targets at
    ``n_freqs`` random frequency vectors ``w``. It then regresses the real
    and imaginary parts on the standardized covariates with two Kernel Ridge
    models.

    ``get_conditional_distribution_matrix`` predicts these
    characteristic-function values for query points. For each query it then
    finds simplex weights over a given set of atoms whose characteristic
    function best matches the prediction.
    """

    def __init__(self, n_freqs=1000, bandwidth=1.0, alpha=0.1, gamma=None, device="cuda"):
        """
        n_freqs:   Number of frequency vectors (Random Fourier Features).
                   Higher = better approximation, but slower.
        bandwidth: Scale (sigma) of the zero-mean Gaussian from which the
                   frequency vectors are drawn, in standardized-Y units.
                   - Low bandwidth = smooths the distribution (low-pass filter).
                   - High bandwidth = captures sharp peaks but adds noise.
        alpha:     Regularization strength for Kernel Ridge Regression.
        gamma:     RBF kernel coefficient for the X-variable regression. It is
                   passed unchanged to ``KernelRidge(gamma=...)``, so None
                   defers to scikit-learn's default.
        device:    Torch device (string or ``torch.device``) used for the
                   simplex solve in ``get_conditional_distribution_matrix``.
        """
        self.n_freqs = n_freqs
        self.bandwidth = bandwidth
        self.alpha = alpha
        self.gamma = gamma
        # Stored exactly as given. scikit-learn's get_params/clone require
        # constructor arguments to be kept unmodified. The value is converted
        # with torch.device(...) where it is used.
        self.device = device

        # Internal state
        self.scaler_x = StandardScaler()
        self.scaler_y = StandardScaler()

        self.krr_real = None
        self.krr_imag = None
        self.Omega = None  # NumPy frequency matrix (n_freqs, y_dim), set by fit

    @staticmethod
    def _as_2d(Y):
        """Return Y as a 2-D array, treating a 1-D array as a single column."""
        Y = np.asarray(Y)
        return Y.reshape(-1, 1) if Y.ndim == 1 else Y

    def fit(self, X, Y):
        """
        Fit the conditional characteristic-function model.

        X: (n_samples, x_dim) - Covariates
        Y: (n_samples, y_dim) - Targets. A 1-D array is treated as one column.

        Returns self.
        """
        Y = self._as_2d(Y)

        # 1. Standardize Data
        # Crucial: this ensures the fixed frequency bandwidth works for any Y scale.
        X_scaled = self.scaler_x.fit_transform(X)
        Y_scaled = self.scaler_y.fit_transform(Y)

        y_dim = Y_scaled.shape[1]

        # 2. Generate Multivariate Frequencies (Gaussian Sampling)
        self.Omega = np.random.normal(
            loc=0,
            scale=self.bandwidth,
            size=(self.n_freqs, y_dim)
        )

        # 3. Compute Basis Functions (Characteristic Function Targets)
        # Project Y onto frequencies: (n, y_dim) . (freqs, y_dim).T -> (n, freqs)
        # arg = w^T * y
        args = Y_scaled @ self.Omega.T
        targets_real = np.cos(args)
        targets_imag = np.sin(args)

        # 4. Fit Kernel Ridge Regression (one model each for the real and imaginary parts)
        self.krr_real = KernelRidge(alpha=self.alpha, kernel='rbf', gamma=self.gamma)
        self.krr_imag = KernelRidge(alpha=self.alpha, kernel='rbf', gamma=self.gamma)

        self.krr_real.fit(X_scaled, targets_real)
        self.krr_imag.fit(X_scaled, targets_imag)

        return self

    def get_conditional_distribution_matrix(self, X_query, Y_atoms, batch_size=256, n_iters=25, lam=1e-3, verbose=False):
        """
        Returns the estimated probability matrix P where:
        P[i, j] = Prob(Y = Y_atoms[j] | X = X_query[i])

        X_query:    (n_query, x_dim)
        Y_atoms:    (n_atoms, y_dim) - The support to evaluate densities on.
                    A 1-D array is treated as one column.
        batch_size: Number of query rows per simplex solve (must be positive).
        n_iters:    Projected-gradient iterations per batch.
        lam:        Ridge penalty of the simplex solve.
        verbose:    If True, print the condition number of Re(E E^H). This
                    costs an SVD of an (n_atoms, n_atoms) matrix.

        Returns a float32 NumPy array of shape (n_query, n_atoms). Fitted
        state (including ``self.Omega``) is not modified.
        """
        device = torch.device(self.device)

        # Transform inputs using the scalers fitted on training data.
        # An unfitted estimator fails here with the scaler's NotFittedError.
        X_q_scaled = self.scaler_x.transform(X_query)
        Y_a_scaled = self.scaler_y.transform(self._as_2d(Y_atoms))

        # Predict Characteristic Function for Query X
        phi_real = self.krr_real.predict(X_q_scaled)
        phi_imag = self.krr_imag.predict(X_q_scaled)
        Phi_hat = torch.tensor(phi_real + 1j * phi_imag, dtype=torch.cfloat, device=device)

        # Move to the configured device. A local copy is used so that
        # self.Omega stays the NumPy array produced by fit.
        Omega = torch.tensor(self.Omega, device=device, dtype=torch.float32)

        # Construct Basis Matrix E on Atoms: E[j, k] = exp(i * w_k^T y_j)
        Y_atoms_t = torch.tensor(Y_a_scaled, device=device, dtype=torch.float32)
        args_atoms = Y_atoms_t @ Omega.T
        E = torch.cos(args_atoms) + 1j * torch.sin(args_atoms)  # (n_atoms, n_freqs)

        if verbose:
            cond = torch.linalg.cond(torch.real(E @ E.conj().T))
            print(cond)

        # Fit Phi_hat ~= P E with each row of P constrained to the simplex.
        n_query = X_q_scaled.shape[0]
        n_atoms = Y_atoms_t.shape[0]

        P_out = torch.zeros((n_query, n_atoms), device=device, dtype=torch.float32)

        # Batched solve
        for start in range(0, n_query, batch_size):
            stop = start + batch_size
            P_out[start:stop] = solve_simplex_ridge_batch(
                Phi_hat[start:stop],
                E,
                lam=lam,
                n_iters=n_iters
            )

        return P_out.cpu().numpy()