import torch
import numpy as np
import torch
from scipy.stats import norm

def _resolve_theta(theta, p):
    """Return the index vector as a unit-norm float32 tensor of shape (p,).

    Args:
        theta: None to use the built-in default, otherwise any array-like or
            tensor holding p numbers.
        p: expected number of entries.
    Returns:
        (p,) float32 tensor with Euclidean norm 1, negated if necessary so
        that its first entry is not negative.
    Raises:
        ValueError: if no default exists for this p, if theta does not have
            shape (p,), or if its norm is zero or non-finite.
    """
    if theta is None:
        default = [0.5, 0.1, 0.0, -0.5]
        if p > len(default):
            raise ValueError(
                f"No default theta for p={p}; pass theta explicitly when p > {len(default)}"
            )
        theta = torch.tensor(default[:p], dtype=torch.float32)
    else:
        # as_tensor accepts lists, arrays and tensors alike without the
        # copy-construct warning torch.tensor() emits for tensor inputs;
        # detach() keeps autograd history out of the simulation.
        theta = torch.as_tensor(theta, dtype=torch.float32).detach()
    if theta.dim() != 1 or theta.numel() != p:
        raise ValueError(f"theta must have shape ({p},), got {tuple(theta.shape)}")
    length = torch.norm(theta)
    if not torch.isfinite(length) or length == 0:
        raise ValueError("theta must be finite with non-zero norm")
    theta = theta / length
    if theta[0] < 0:
        theta = -theta
    return theta


def _simulate_latent_parameters(n, p, link, theta, seed, rho):
    """Draw covariates and the per-observation normal parameters.

    Shared by both public generators so that, for equal arguments, they see
    exactly the same X, mu and sigma.

    Side effect: reseeds the global torch and numpy random generators.

    Returns:
        Tuple (X, theta, zeta, eta, mu, sigma) where X is (n, p), theta is
        (p,), and the remaining four are (n,) float32 tensors.
    Raises:
        ValueError: for an unknown link or an invalid theta.
    """
    # Validate before touching any global RNG state.
    if link not in ('linear', 'quadratic', 'exp'):
        raise ValueError("Unknown link function")
    theta = _resolve_theta(theta, p)

    torch.manual_seed(seed)
    np.random.seed(seed)

    # Z ~ N(0, Sigma) with unit variances and constant off-diagonal rho.
    cov = np.full((p, p), rho)
    np.fill_diagonal(cov, 1.0)
    Z = np.random.multivariate_normal(np.zeros(p), cov, size=n)
    # X_j = 2 * Phi(Z_j) - 1 maps every coordinate into [-1, 1].
    X = torch.tensor(2 * norm.cdf(Z) - 1, dtype=torch.float32)

    s = X @ theta  # single-index projection, shape (n,)
    if link == 'linear':
        zeta = s
    elif link == 'quadratic':
        zeta = s ** 2
    else:  # 'exp'
        zeta = torch.exp(s)
    # Logistic sigmoid of s, written out explicitly; always in (0, 1).
    eta = torch.exp(s) / (1 + torch.exp(s))

    # mu_i ~ N(zeta_i, 0.25^2); sigma_i ~ Exponential(rate=1/eta_i), so that
    # E[sigma_i] = eta_i. The order of these draws is part of what a given
    # seed reproduces, so it must not be changed.
    mu = torch.normal(zeta, 0.25)
    sigma = torch.distributions.Exponential(1 / eta).sample()
    return X, theta, zeta, eta, mu, sigma


def generate_simulation_data_torch(n=1000, qf_size=100, p=4, link='linear', theta=None, seed=0, rho=0.25):
    """
    Generate simulated distributional responses as sorted random samples.

    For each observation i a random normal distribution N(mu_i, sigma_i^2) is
    drawn (mu_i ~ N(zeta_i, 0.25^2), sigma_i ~ Exponential with mean eta_i),
    then qf_size samples from it are drawn and sorted to form an empirical
    quantile function.

    Side effect: reseeds the global torch and numpy random generators with
    `seed`.

    Args:
        n: number of observations
        qf_size: number of samples per observation (quantile function size)
        p: number of features
        link: 'linear', 'quadratic', or 'exp'; maps s = X @ theta to zeta
        theta: optional index vector with p entries (list, array or tensor);
            it is normalised to unit length with a non-negative first entry.
            The built-in default covers p <= 4 only.
        seed: random seed
        rho: off-diagonal correlation for Z
    Returns:
        X: (n, p) torch tensor with entries in [-1, 1]
        Y: (n, qf_size) torch tensor, each row sorted ascending
        theta: (p,) torch tensor, the normalised index vector actually used
        mu: (n,) torch tensor holding zeta_i, the conditional mean of the
            random location mu_i (not the realised mu_i)
        sigma: (n,) torch tensor holding eta_i, the conditional mean of the
            random scale sigma_i (not the realised sigma_i)
    Raises:
        ValueError: for an unknown link, a theta whose shape is not (p,), a
            theta with zero norm, or p > 4 without an explicit theta.
    """
    X, theta, zeta, eta, mu, sigma = _simulate_latent_parameters(n, p, link, theta, seed, rho)
    # One draw per observation, in order, so the sequence of RNG calls (and
    # therefore the data produced for a given seed) stays fixed.
    rows = [
        torch.sort(torch.normal(loc, scale, size=(qf_size,)))[0]
        for loc, scale in zip(mu.tolist(), sigma.tolist())
    ]
    # torch.stack rejects an empty list, so n == 0 needs its own result.
    Y = torch.stack(rows) if rows else torch.empty((0, qf_size))
    return X, Y, theta, zeta, eta


def generate_simulation_data_torch_true(n=1000, qf_size=100, p=4, link='linear', theta=None, seed=0, rho=0.25):
    """
    Generate simulated distributional responses as exact quantile functions.

    Same latent model as generate_simulation_data_torch, but instead of
    sampling, each row of Y is the quantile function of N(mu_i, sigma_i^2)
    evaluated on the qf_size interior points of an evenly spaced grid,
    np.linspace(0, 1, qf_size + 2)[1:-1] (levels 0 and 1 are excluded because
    the normal quantile is infinite there).

    Side effect: reseeds the global torch and numpy random generators with
    `seed`.

    Args:
        n: number of observations
        qf_size: number of quantile levels per observation
        p: number of features
        link: 'linear', 'quadratic', or 'exp'; maps s = X @ theta to zeta
        theta: optional index vector with p entries (list, array or tensor);
            it is normalised to unit length with a non-negative first entry.
            The built-in default covers p <= 4 only.
        seed: random seed
        rho: off-diagonal correlation for Z
    Returns:
        X: (n, p) torch tensor with entries in [-1, 1]
        Y: (n, qf_size) float32 torch tensor, row i = mu_i + sigma_i * z
            where z are the standard normal quantiles on the grid
        theta: (p,) torch tensor, the normalised index vector actually used
        mu: (n,) torch tensor holding zeta_i, the conditional mean of mu_i
        sigma: (n,) torch tensor holding eta_i, the conditional mean of sigma_i
    Raises:
        ValueError: for an unknown link, a theta whose shape is not (p,), a
            theta with zero norm, or p > 4 without an explicit theta.
    """
    X, theta, zeta, eta, mu, sigma = _simulate_latent_parameters(n, p, link, theta, seed, rho)
    levels = np.linspace(0, 1, qf_size + 2)[1:-1]  # avoid 0 and 1
    # Standard normal quantiles are shared by every row; the location-scale
    # transform is done in float64 for all rows at once and only then cast
    # down to float32.
    z = torch.as_tensor(np.asarray(norm.ppf(levels), dtype=np.float64))
    Y = (mu.double().unsqueeze(1) + sigma.double().unsqueeze(1) * z).to(torch.float32)
    return X, Y, theta, zeta, eta

# Example usage
if __name__ == "__main__":
    # Demo: for every link function, generate a data set, print a short
    # summary and plot the first few response quantile functions against the
    # normal quantile functions implied by the returned (mu, sigma).
    import matplotlib.pyplot as plt

    for link in ['linear', 'quadratic', 'exp']:
        X, Y, theta, mu, sigma = generate_simulation_data_torch_true(n=1000, qf_size=100, p=4, link=link, seed=10)
        print(f"Link: {link}")
        print("X shape:", X.shape)
        print("Y shape:", Y.shape)
        print("Theta:", theta)
        print("mu shape:", mu.shape)
        print("sigma shape:", sigma.shape)
        print("Y mean/std:", Y.mean().item(), Y.std().item())
        print('-'*40)
        # Plot the first 5 quantile functions
        plt.figure(figsize=(8, 5))
        # Use the same interior grid the generator evaluates Y on. A plain
        # linspace(0, 1, m) would misalign the two curves and would include
        # levels 0 and 1, where norm.ppf returns -inf / +inf.
        quantile_levels = np.linspace(0, 1, Y.shape[1] + 2)[1:-1]
        for i in range(min(5, Y.shape[0])):
            # Response curve for observation i. With the *_true generator
            # used above this is an exact normal quantile function at the
            # realised (mu_i, sigma_i); call generate_simulation_data_torch
            # instead to plot sorted-sample (empirical) curves.
            plt.plot(quantile_levels, Y[i].cpu().numpy(), label=f'Obs {i+1} Empirical', linestyle='--')
            # Reference curve: normal quantile function at the returned
            # mu[i] / sigma[i], which are the conditional means (zeta, eta)
            # of the random location and scale, not the realised draws.
            true_qf = norm.ppf(quantile_levels, loc=mu[i].item(), scale=sigma[i].item())
            plt.plot(quantile_levels, true_qf, label=f'Obs {i+1} True', linestyle='-')
        plt.xlabel('Quantile Level')
        plt.ylabel('Value')
        plt.title('Empirical vs True Quantile Functions (First 5 Observations)')
        plt.legend()
        plt.tight_layout()
        plt.show()
