import torch
import numpy as np
import os
from torch.distributions import Categorical, MultivariateNormal, MixtureSameFamily, Independent, Normal
import argparse
from tqdm import tqdm

def make_standard_gmm(dim, ncomp, seed=None, var_between=None):
    """
    Build an isotropic GMM with ncomp components in R^dim, then standardize it.

    After standardization:
      - Overall (mixture) mean = 0
      - Total variance (trace of the mixture covariance) = dim, i.e. unit
        variance per axis *on average*. The full covariance is generally not
        the identity, because every component shares one isotropic scale.

    If the centered component means alone already carry a total variance of
    dim or more, the shared component variance is clamped to 1e-6 and the
    total variance exceeds dim.

    Args:
      dim: dimensionality of the space (positive int).
      ncomp: number of mixture components (positive int).
      seed: if not None, passed to torch.manual_seed. Note that this re-seeds
        torch's *global* RNG, which affects all subsequent sampling.
      var_between: nominal total variance of the raw component means before
        centering (each coordinate is drawn with variance var_between / dim).
        Defaults to dim.

    Returns:
      gmm: a torch.distributions.MixtureSameFamily instance
      params: dict with 'logits', 'locs', 'scales' for further reuse

    Raises:
      ValueError: if dim or ncomp is not positive, or var_between is negative.
    """
    if dim <= 0 or ncomp <= 0:
        raise ValueError(f"dim and ncomp must be positive, got dim={dim}, ncomp={ncomp}")
    if var_between is None:
        var_between = dim
    if var_between < 0:
        raise ValueError(f"var_between must be non-negative, got {var_between}")

    if seed is not None:
        torch.manual_seed(seed)

    # 1) Sample component weights (as unnormalized logits)
    logits = torch.randn(ncomp)

    # 2) Sample component means with per-coordinate variance var_between / dim
    locs = torch.randn(ncomp, dim) * (var_between / dim) ** 0.5

    # 3) Center the means so that the mixture mean is exactly zero
    weights = torch.softmax(logits, dim=0)
    mean = (weights.unsqueeze(1) * locs).sum(0)
    locs_centered = locs - mean

    # 4) Total between-component variance = trace of sum_i w_i * mu_i mu_i^T
    #    = sum_i w_i * ||mu_i||^2 (no eigendecomposition needed for the trace).
    between = (weights * locs_centered.pow(2).sum(dim=1)).sum()

    # 5) Shared within-component variance so that between + dim * sigma^2 = dim.
    #    Kept as a tensor: torch.clamp does not accept a Python float input.
    sigma2 = torch.clamp((dim - between) / dim, min=1e-6)
    sigma = sigma2.sqrt()

    # 6) Every component gets the same isotropic scale; clone so that the
    #    returned tensor owns its memory instead of being a stride-0 view.
    scales_standard = sigma.expand(ncomp, dim).clone()

    # 7) Build the standardized GMM
    mix = Categorical(logits=logits)
    comp = Independent(Normal(loc=locs_centered, scale=scales_standard), 1)
    gmm = MixtureSameFamily(mix, comp)

    return gmm, {'logits': logits, 'locs': locs_centered, 'scales': scales_standard}

def generate_dataset(dim, n_samples=10000, n_components=3, seed=42):
    """
    Generate and save source/target datasets for one dimension.

    The source distribution is a standard normal in R^dim; the target is the
    GMM returned by make_standard_gmm. Training and validation sets are drawn
    from the same two distributions (validation has n_samples // 5 points).

    Files written under datasets/dim_{dim}/:
      source_samples.pt, target_samples.pt,
      val_source_samples.pt, val_target_samples.pt, gmm_params.pt

    Args:
      dim: dimensionality of the samples.
      n_samples: number of training samples per distribution.
      n_components: number of GMM components.
      seed: random seed for numpy and torch.

    Returns:
      The target GMM (as returned by make_standard_gmm).
    """
    print(f"Generating dataset for dimension {dim}")

    # Create directory for datasets
    out_dir = f"datasets/dim_{dim}"
    os.makedirs(out_dir, exist_ok=True)

    # Set random seed for reproducibility
    np.random.seed(seed)

    # Build the target GMM first. make_standard_gmm re-seeds torch's global
    # RNG when given a seed, so drawing source samples before this call would
    # make the GMM parameters and target samples replay the same random
    # stream that produced the source samples. Seeding once here and drawing
    # everything afterwards keeps all samples on a single, non-repeating stream.
    gmm, gmm_params = make_standard_gmm(dim, n_components, seed=seed)

    # Training data: standard-normal source and GMM target
    source_samples = torch.randn(n_samples, dim)
    target_samples = gmm.sample((n_samples,))

    # Held-out validation data, drawn from the same distributions
    n_val = n_samples // 5
    val_source_samples = torch.randn(n_val, dim)
    val_target_samples = gmm.sample((n_val,))

    # Save datasets
    torch.save(source_samples, f"{out_dir}/source_samples.pt")
    torch.save(target_samples, f"{out_dir}/target_samples.pt")
    torch.save(val_source_samples, f"{out_dir}/val_source_samples.pt")
    torch.save(val_target_samples, f"{out_dir}/val_target_samples.pt")

    # Save GMM parameters
    torch.save(gmm_params, f"{out_dir}/gmm_params.pt")

    print(f"Dataset for dimension {dim} generated and saved successfully")

    return gmm

def generate_all_datasets(dimensions=None, n_samples=10000, n_components=3, seed=42):
    """
    Generate datasets and a reference score for several dimensions.

    For each dimension this calls generate_dataset, then draws 1000 extra
    samples from the returned GMM, saves them to gmm_samples.pt, and writes
    the mean GMM log-probability of those samples to gmm_score.txt (both
    under datasets/dim_{dim}/).

    Args:
      dimensions: iterable of dimensions to generate; defaults to (3, 10, 50).
      n_samples: number of samples per dataset, forwarded to generate_dataset.
      n_components: number of GMM components, forwarded to generate_dataset.
      seed: random seed, forwarded to generate_dataset.
    """
    if dimensions is None:
        dimensions = (3, 10, 50)

    # Create main directory
    os.makedirs("datasets", exist_ok=True)

    # Generate datasets for each dimension
    for dim in dimensions:
        out_dir = f"datasets/dim_{dim}"
        gmm = generate_dataset(dim, n_samples, n_components, seed)

        # Generate some samples from GMM for reference
        gmm_samples = gmm.sample((1000,))
        torch.save(gmm_samples, f"{out_dir}/gmm_samples.pt")

        # Mean log-probability of the GMM on its own samples
        gmm_log_prob = gmm.log_prob(gmm_samples).mean().item()
        print(f"Dimension {dim} - GMM log probability: {gmm_log_prob:.4f}")

        # Save this reference score
        with open(f"{out_dir}/gmm_score.txt", "w") as f:
            f.write(str(gmm_log_prob))

if __name__ == "__main__":
    # Command-line entry point: generate one dataset folder per requested
    # dimension, plus reference GMM samples and their mean log-probability.
    cli = argparse.ArgumentParser(
        description='Generate datasets for REflow+CFM and REflow+SBM models'
    )
    cli.add_argument(
        '--dimensions', type=int, nargs='+', default=[3, 10, 50],
        help='Dimensions to generate datasets for',
    )
    cli.add_argument(
        '--samples', type=int, default=10000,
        help='Number of samples in each dataset',
    )
    cli.add_argument(
        '--components', type=int, default=3,
        help='Number of components in GMM',
    )
    cli.add_argument(
        '--seed', type=int, default=42,
        help='Random seed for reproducibility',
    )
    opts = cli.parse_args()

    # Make sure the root output directory exists
    os.makedirs("datasets", exist_ok=True)

    for dim in opts.dimensions:
        out_dir = f"datasets/dim_{dim}"
        gmm = generate_dataset(dim, opts.samples, opts.components, opts.seed)

        # Reference samples drawn from the target GMM
        reference = gmm.sample((1000,))
        torch.save(reference, f"{out_dir}/gmm_samples.pt")

        # Mean log-probability of the GMM on its own samples
        mean_log_prob = gmm.log_prob(reference).mean().item()
        print(f"Dimension {dim} - GMM log probability: {mean_log_prob:.4f}")

        # Persist the reference score next to the data
        with open(f"{out_dir}/gmm_score.txt", "w") as score_file:
            score_file.write(str(mean_log_prob))

    print("All datasets generated successfully!")
