import numpy as np
from scipy.stats import wishart
from numpy.random import uniform, choice, default_rng
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset
from torch.distributions.multivariate_normal import MultivariateNormal

# MLP Mixer
class MLPMixer(nn.Module):
    """Small randomly-initialised MLP used as a nonlinear mixing function.

    Architecture:
        Linear(input_dim -> hidden_dim) -> Tanh
        -> Linear(hidden_dim -> hidden_dim) -> Tanh
        -> Linear(hidden_dim -> output_dim)

    Args:
        input_dim: Size of the last dimension of the input.
        output_dim: Size of the last dimension of the output.
        hidden_dim: Width of the two hidden layers.
    """

    def __init__(self, input_dim, output_dim, hidden_dim=8):
        super().__init__()
        layers = [
            nn.Linear(input_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, output_dim),
        ]
        # Kept under the attribute name `mixer` so state_dict keys are stable.
        self.mixer = nn.Sequential(*layers)

    def forward(self, s):
        """Apply the MLP to ``s`` (last dimension must equal ``input_dim``)."""
        mixed = self.mixer(s)
        return mixed

# Fast (sort-based) Euclidean projection onto the weighted L1 ball
def fast_l1_projection(v, w, tau=1.0):
    """Euclidean projection of ``v`` onto the weighted L1 ball.

    Solves ``min_x ||x - v||_2^2  s.t.  sum_i w_i |x_i| <= tau`` with the
    sort-based thresholding algorithm. The solution has the form
    ``x_i = sign(v_i) * max(|v_i| - lam * w_i, 0)`` for a scalar ``lam >= 0``.

    Args:
        v: Point to project (1-D array-like).
        w: Non-negative weights, same length as ``v``. A zero weight leaves
            the corresponding coordinate unconstrained, so it is returned
            unchanged.
        tau: Radius of the ball; must be non-negative.

    Returns:
        The projected point as a NumPy array. If ``v`` is already inside the
        ball, a copy of ``v`` is returned.

    Raises:
        ValueError: If any weight is negative or ``tau`` is negative.
    """
    v = np.asarray(v)
    w = np.asarray(w, dtype=float)
    if np.any(w < 0):
        raise ValueError("Weights must be non-negative")
    if tau < 0:
        raise ValueError("tau must be non-negative")

    abs_v = np.abs(v)
    if np.dot(w, abs_v) <= tau:
        return v.copy()

    # Zero-weight coordinates do not enter the constraint. Exclude them from
    # the threshold search to avoid x/0 and 0/0 in the ratios below. At least
    # one positive-weight coordinate exists here, because the early return
    # above did not fire.
    pos = w > 0
    w_pos = w[pos]
    abs_v_pos = abs_v[pos]

    # Sort coordinates by |v_i| / w_i (descending). The active set of the
    # solution is a prefix of this order.
    ratios = abs_v_pos / w_pos
    order = np.argsort(ratios)[::-1]
    u_sorted = ratios[order]
    w_sorted = w_pos[order]
    abs_v_sorted = abs_v_pos[order]

    # Candidate threshold assuming the first k+1 sorted coordinates are active.
    cs_num = np.cumsum(w_sorted * abs_v_sorted)
    cs_den = np.cumsum(w_sorted**2)
    lambdas = (cs_num - tau) / cs_den

    valid = np.nonzero(u_sorted > lambdas)[0]
    if valid.size:
        lam = lambdas[valid[-1]]
    else:
        # Only reachable when tau == 0. The feasible set forces every
        # positive-weight coordinate to zero, so use the largest ratio as the
        # threshold, which zeroes all of them.
        lam = u_sorted[0]

    return np.sign(v) * np.maximum(abs_v - lam * w, 0.0)

def mvee_weighted_l1(w, tau=1.0):
    """Return the enclosing ellipsoid of the weighted L1 ball ``sum_i w_i |x_i| <= tau``.

    The ball's vertices are ``+/- (tau / w_i) e_i``. The returned ellipsoid is
    axis-aligned with semi-axes ``tau / w_i``, i.e. ``{x : x^T A x <= 1}``
    with ``A = diag((w / tau)^2)``, centred at the origin.

    Returns:
        Tuple ``(A, center, axes)``: the diagonal shape matrix, a zero
        center vector, and the semi-axis lengths.

    Raises:
        ValueError: If any weight is zero or negative.
    """
    weights = np.asarray(w, dtype=float)
    if np.any(weights <= 0):
        raise ValueError("All weights w_i must be positive.")
    origin = np.zeros_like(weights)
    semi_axes = tau / weights
    scaled = weights / tau
    shape_matrix = np.diag(scaled**2)
    return shape_matrix, origin, semi_axes

def gen_ssc(m, d, rho_coeff, weights, seed=None):
    """Sample an ``m x d`` matrix whose rows lie in the weighted L1 ball.

    Procedure:
    1. Draw each row uniformly from the unit d-ball.
    2. Scale the rows into the axis-aligned ellipsoid with semi-axes
       ``rho / weights``, where ``rho = rho_coeff * sqrt(d)``.
    3. Project each row onto the weighted L1 ball
       ``{x : sum_i weights_i |x_i| <= 1}`` with ``fast_l1_projection``.

    Args:
        m: Number of rows; must exceed ``d``.
        d: Dimension of each row.
        rho_coeff: Ellipsoid scale relative to ``sqrt(d)``; the resulting
            ``rho`` must exceed 1.
        weights: Strictly positive weights, length ``d``.
        seed: Optional seed (or Generator) forwarded to
            ``numpy.random.default_rng``. ``None`` (the default) uses fresh
            entropy, as before.

    Returns:
        Tuple ``(W, A)``:
        - ``W`` is the ``(m, d)`` sampled matrix.
        - ``A`` is the ``(d, d)`` matrix ``diag((weights / rho)^2)``
          describing the ellipsoid ``{x : x^T A x <= 1}``.

    Raises:
        AssertionError: If ``m <= d`` or ``rho <= 1``. These stay asserts for
            backward compatibility, so they are skipped under ``python -O``.
        ValueError: If ``weights`` does not have shape ``(d,)`` or contains a
            non-positive entry.
    """
    rng = default_rng(seed)
    weights = np.asarray(weights, dtype=float)
    rho = rho_coeff * np.sqrt(d)
    assert m > d, f"Must be m > d (got m={m}, d={d})"
    assert rho > 1, f"Must be rho > 1 (got rho={rho:.4g} from rho_coeff={rho_coeff}, d={d})"
    if weights.shape != (d,):
        raise ValueError(f"weights must have shape ({d},), got {weights.shape}")
    if np.any(weights <= 0):
        # Zero weights would give infinite semi-axes (inf/nan rows); negative
        # weights are rejected by the projection anyway.
        raise ValueError("All weights must be positive")

    # Uniform samples in the unit d-ball: uniform direction times radius U^(1/d).
    directions = rng.standard_normal((m, d))
    directions /= np.linalg.norm(directions, axis=1, keepdims=True)
    radii = rng.random((m, 1)) ** (1.0 / d)
    Y = directions * radii

    # Scale into the ellipsoid with semi-axes rho / weights.
    axes = rho / weights
    A = np.diag(1.0 / axes**2)
    W = Y * axes  # broadcasts the per-column semi-axes over the rows

    # Project each row onto the weighted L1 ball of radius 1.
    W = np.vstack([fast_l1_projection(row, weights) for row in W])
    return W, A

class SynthDataset(Dataset):
    """Map-style dataset pairing observations with their latent factors.

    Args:
        X: Indexable collection of observations.
        S: Indexable collection of latent factors, aligned index-wise with ``X``.
        transform: Optional callable applied to each observation on access.
            Defaults to ``None`` (no transform).
    """

    def __init__(self, X, S, transform=None):
        self.obs = X
        self.factors = S
        self.transform = transform

    def __len__(self):
        """Return the number of observations."""
        return len(self.obs)

    def __getitem__(self, idx):
        """Return ``(observation, factors)`` at ``idx``, applying ``transform`` if set."""
        x = self.obs[idx]
        if self.transform is not None:
            x = self.transform(x)
        return x, self.factors[idx]

def _sample_gaussian_latents(nobs, latent_dim):
    """Draw ``nobs`` zero-mean Gaussian samples with a random Wishart covariance.

    Returns a float32 array of shape ``(nobs, latent_dim)``.
    """
    cov = wishart.rvs(df=latent_dim, scale=np.eye(latent_dim), size=1)
    # wishart.rvs returns a scalar when latent_dim == 1; force a 2-D matrix.
    cov = np.reshape(cov, (latent_dim, latent_dim))
    # Build the distribution in float64. A Wishart draw with df == dim can be
    # ill-conditioned, and a float32 Cholesky of it may fail.
    dist = MultivariateNormal(
        torch.zeros(latent_dim, dtype=torch.float64),
        torch.as_tensor(cov, dtype=torch.float64),
    )
    samples = dist.sample((nobs,)).numpy().reshape(nobs, latent_dim)
    return samples.astype(np.float32)


def _mlp_mix(S, input_dim, latent_dim):
    """Mix latents ``S`` into ``input_dim`` observed columns using one random MLP per column.

    A random half of the output columns see "suppressed" inputs: all but one
    randomly chosen latent is multiplied by a random factor of magnitude in
    [0.1, 0.2) with random sign. Returns a float32 array of shape
    ``(len(S), input_dim)``.
    """
    suppressed_cols = set(
        choice(range(input_dim), size=(input_dim // 2), replace=False).tolist()
    )
    X = np.empty((S.shape[0], input_dim), dtype=np.float32)
    for i in range(input_dim):
        input_S = S.copy()
        if i in suppressed_cols:
            which_s = choice(range(latent_dim), size=latent_dim - 1, replace=False)
            alpha = uniform(0.1, 0.2, size=latent_dim - 1)
            sign = choice([-1, 1], size=latent_dim - 1)
            input_S[:, which_s] *= sign * alpha
        mixer = MLPMixer(input_dim=latent_dim, output_dim=1)
        # Pure data generation: no autograd graph needed.
        with torch.no_grad():
            out = mixer(torch.tensor(input_S, dtype=torch.float32))
        X[:, i] = out.numpy().reshape(-1)
    return X


def get_data(exp, nobs, input_dim, latent_dim, batch_size, num_workers=2):
    """Build train/validation DataLoaders for a synthetic mixing experiment.

    Experiments:
        'a': Gaussian latents (random Wishart covariance), mixed linearly by
             a ``gen_ssc`` matrix built with weights drawn from [1, 2).
        'b': As 'a', but with weights drawn from [3, 4), followed by the
             elementwise map ``c * cos(Z) + Z`` with a random c in [0.5, 1.0).
        'c': Uniform latents in [-1, 1), mixed by random per-column MLPs.

    The first 90% of samples form the training set (shuffled loader). The
    remainder forms the validation set (unshuffled loader).

    Returns:
        Tuple ``(train_loader, val_loader)`` yielding
        ``(observations, factors)`` batches as float32.

    Raises:
        ValueError: If ``exp`` is not one of 'a', 'b', 'c'.
    """
    if exp in ('a', 'b'):
        S = _sample_gaussian_latents(nobs, latent_dim)
        low, high = (1, 2) if exp == 'a' else (3, 4)
        weights = uniform(low, high, size=latent_dim)
        # First return value of gen_ssc is the (input_dim, latent_dim) mixing matrix.
        mixing, _ = gen_ssc(input_dim, latent_dim, rho_coeff=0.9, weights=weights)
        X = S @ mixing.T
        if exp == 'b':
            X = uniform(0.5, 1.0) * np.cos(X) + X
    elif exp == 'c':
        S = uniform(-1, 1, (nobs, latent_dim))
        X = _mlp_mix(S, input_dim, latent_dim)
    else:
        raise ValueError(f"Unknown experiment {exp!r}; expected 'a', 'b' or 'c'.")

    X = X.astype(np.float32)
    S = S.astype(np.float32)

    # Train and validation splits (first 90% / last 10%).
    n_train = int(0.9 * len(S))
    S_train, S_val = np.split(S, [n_train])
    X_train, X_val = np.split(X, [n_train])

    train_loader = DataLoader(
        SynthDataset(X_train, S_train, transform=None),
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=True,
    )
    val_loader = DataLoader(
        SynthDataset(X_val, S_val, transform=None),
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=False,
    )

    return train_loader, val_loader