import torch
import numpy as np
import os
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Categorical, MultivariateNormal, MixtureSameFamily, Independent, Normal
import argparse
from tqdm import tqdm
from torchdyn.core import NeuralODE
import matplotlib.pyplot as plt
from torchcfm.conditional_flow_matching import ConditionalFlowMatcher


class WMLP(nn.Module):
    """Four-layer SELU perceptron used as an (optionally time-aware) vector field.

    Args:
        dim: number of state features; the output has the same size.
        w: width of every hidden layer.
        time_varying: when True the network receives one extra input
            feature (the time), placed in front of the state features.
    """

    def __init__(self, dim, w=64, time_varying=True):
        super().__init__()
        self.time_varying = time_varying
        in_features = dim + 1 if time_varying else dim

        # Input layer, two hidden layers, then a linear read-out back to
        # `dim`. Layers are created in this order so the Sequential indices
        # (and therefore the state_dict keys) are 0, 2, 4, 6.
        layers = [nn.Linear(in_features, w), nn.SELU()]
        for _ in range(2):
            layers += [nn.Linear(w, w), nn.SELU()]
        layers.append(nn.Linear(w, dim))
        self.net = nn.Sequential(*layers)

    def forward(self, t, x, *args, **kwargs):
        """Evaluate the network at time `t` and state `x`.

        `t` is flattened to a column and expanded to one row per row of `x`,
        so it may hold a single value or one value per batch element. Extra
        positional/keyword arguments are accepted and ignored.
        """
        if not self.time_varying:
            return self.net(x)

        # Time goes first, followed by the state features.
        t_col = t.view(-1, 1).expand(x.shape[0], -1)
        return self.net(torch.cat((t_col, x), dim=1))


def make_standard_gmm(dim: int, seed: int = None, ncomp: int = 3):
    """
    Build an equally weighted, isotropic Gaussian mixture in R^dim that is
    normalized to:
      - overall mean = 0
      - total variance (trace of the overall covariance) = dim

    Component means are drawn from a standard normal and centered on the
    mixture mean. Every component then shares one scalar standard deviation
    σ, chosen so that

        trace(Σ) = dim·σ² + ∑ π_k ‖μ_k‖² = dim.

    Only the trace is normalized; the individual axes are not whitened, so
    the overall covariance is in general not the identity.

    If the centered means alone already carry a variance >= dim, no valid σ
    exists. σ² is then floored at 1e-6, the total variance ends up larger
    than dim, and a RuntimeWarning is emitted.

    Args:
      dim: dimensionality of the sample space.
      seed: if not None, passed to torch.manual_seed (this reseeds the
        global torch RNG as a side effect).
      ncomp: number of mixture components.

    Returns:
      gmm: a torch.distributions.MixtureSameFamily instance
    """
    import warnings  # local import keeps this block self-contained

    min_sigma2 = 1e-6

    if seed is not None:
        torch.manual_seed(seed)

    # 1) Uniform mixture logits → equal weights after softmax
    logits = torch.zeros(ncomp)

    # 2) Random initial locs
    locs = torch.randn(ncomp, dim)

    # Earlier versions drew a second tensor here for per-component scales
    # that was never used. The draw itself is kept so the global RNG stream
    # (and everything sampled after this call) stays the same for a seed.
    torch.randn(ncomp, dim)

    # 3) Mixture weights π_k
    weights = torch.softmax(logits, dim=0)  # shape [ncomp]

    # 4) Mixture mean μ̄ = ∑ π_k μ_k, then center: μ_k ← μ_k – μ̄
    mu_bar = (weights.unsqueeze(1) * locs).sum(dim=0)  # [dim]
    locs_centered = locs - mu_bar

    # 5) Between-component variance: ∑ π_k ‖μ_k−μ̄‖²
    var_between = (weights * locs_centered.pow(2).sum(dim=1)).sum()

    # 6) Solve for σ² so the total variance equals dim:
    #      dim·σ² + var_between = dim  ⇒  σ² = (dim - var_between) / dim
    sigma2_raw = (dim - var_between) / dim
    if sigma2_raw < min_sigma2:
        warnings.warn(
            "make_standard_gmm: spread of component means "
            f"({var_between.item():.4f}) leaves no room for unit total "
            f"variance in dim={dim}; sigma^2 floored at {min_sigma2}, so "
            "the mixture's total variance exceeds dim",
            RuntimeWarning,
            stacklevel=2,
        )
    sigma = torch.clamp(sigma2_raw, min=min_sigma2).sqrt()

    # 7) All components share the same isotropic scale σ
    scales_standard = sigma.expand(ncomp, dim)

    # 8) Build the normalized GMM
    mix = Categorical(logits=logits)
    comp = Independent(Normal(loc=locs_centered, scale=scales_standard), 1)
    return MixtureSameFamily(mix, comp)

def train_cfm_model(dim, gmm, num_epochs=50000, batch_size=128, lr=1e-3, device="cpu"):
    """
    Fit a WMLP vector field with conditional flow matching, using standard
    normal noise as the source and samples of `gmm` as the target.

    Each of the `num_epochs` iterations is a single optimizer step on a
    freshly sampled batch (there is no fixed dataset being swept).

    Side effects: prints progress, and writes the loss curve (cfm_loss.png)
    and the model state_dict (cfm_model.pt) under
    folder_generate_table/models/dim_{dim}/.

    Args:
      dim: dimensionality of the samples.
      gmm: object whose .sample((batch_size,)) yields target samples.
      num_epochs: number of optimizer steps.
      batch_size: samples drawn per step.
      lr: Adam learning rate.
      device: device the model and batches are moved to.

    Returns:
      (model, node): the trained WMLP and a NeuralODE wrapping it.
    """
    print(f"Training CFM model for dimension {dim}")

    # Wider hidden layers once the dimension reaches 50.
    hidden_width = 256 if dim >= 50 else 128
    model = WMLP(dim, w=hidden_width).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    flow_matcher = ConditionalFlowMatcher()

    step_losses = []
    progress = tqdm(range(num_epochs), desc=f"Training CFM model for dim={dim}")
    for step in progress:
        # Source noise is drawn first, then the target batch, so the RNG
        # consumption order stays fixed.
        noise = torch.randn(batch_size, dim).to(device)
        data = gmm.sample((batch_size,)).to(device)

        # Time, interpolated location and the regression target at that point.
        t, xt, ut = flow_matcher.sample_location_and_conditional_flow(noise, data)

        # Mean squared error between predicted and target vector field.
        residual = model(t, xt) - ut
        loss = residual.pow(2).mean()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        loss_value = loss.item()
        step_losses.append(loss_value)
        if step % 1000 == 0:
            print(f"Step {step}, Loss: {loss_value:.6f}")

    # Wrap the trained field so it can be integrated as an ODE.
    node = NeuralODE(
        model,
        solver="dopri5",
        sensitivity="adjoint",
        atol=1e-3,
        rtol=1e-3,
    )

    # Loss curve on a log axis.
    fig, ax = plt.subplots(figsize=(10, 6))
    ax.plot(step_losses)
    ax.set_title(f"Training Loss for CFM Model (dim={dim})")
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Loss")
    ax.set_yscale("log")
    ax.grid(True)

    # Persist the plot and the weights side by side.
    out_dir = f"folder_generate_table/models/dim_{dim}"
    os.makedirs(out_dir, exist_ok=True)
    fig.savefig(f"{out_dir}/cfm_loss.png")
    plt.close(fig)

    torch.save(model.state_dict(), f"{out_dir}/cfm_model.pt")

    return model, node

def generate_dataset(dim, n_samples=10000, n_components=3, seed=42, device="cpu"):
    """
    Build the target GMM for `dim`, train a CFM model on it, and save the
    resulting artifacts.

    Files written under folder_generate_table/datasets/dim_{dim}/:
      - gmm.pt         the pickled GMM distribution object
      - traj_start.pt  first state of the integrated trajectories
      - traj_end.pt    last state of the integrated trajectories
      - samples.png    scatter of the first two coordinates (only if dim >= 2)
    train_cfm_model additionally writes into
    folder_generate_table/models/dim_{dim}/.

    Args:
      dim: dimensionality of the data.
      n_samples: number of source samples drawn; at most the first 1000
        are pushed through the trained flow.
      n_components: number of mixture components of the target GMM.
      seed: seed for numpy and torch (the global RNGs are reseeded).
      device: device used for training and trajectory integration.

    Returns:
      The GMM distribution the model was trained on.
    """
    print(f"Generating dataset for dimension {dim}")

    # Set random seeds for reproducibility
    np.random.seed(seed)
    torch.manual_seed(seed)

    dataset_dir = f"folder_generate_table/datasets/dim_{dim}"
    model_dir = f"folder_generate_table/models/dim_{dim}"
    os.makedirs(dataset_dir, exist_ok=True)
    os.makedirs(model_dir, exist_ok=True)

    # Source distribution (standard normal)
    source_samples = torch.randn(n_samples, dim)

    # Target distribution (GMM). n_components is forwarded here; it used to
    # be accepted by this function but silently ignored.
    gmm = make_standard_gmm(dim, seed=seed, ncomp=n_components)

    # NOTE(review): these target samples are drawn but never stored or used.
    # The draw is kept so the RNG stream, and therefore the trained model,
    # matches earlier runs with the same seed. Confirm whether they were
    # meant to be saved.
    gmm.sample((n_samples,))

    # Save GMM object
    torch.save(gmm, f"{dataset_dir}/gmm.pt")

    # Train CFM model; the step budget depends on the dimension
    print(f"Training CFM model for dimension {dim}")
    if dim == 3:
        num_epochs = 50000
    elif dim == 10:
        num_epochs = 75000
    else:  # dim == 50 and any other dimension
        num_epochs = 100000

    cfm_model, cfm_node = train_cfm_model(
        dim,
        gmm,
        num_epochs=num_epochs,
        device=device,
    )

    # Integrate the first 1000 source samples from t=0 to t=1
    print(f"Generating trajectories for dimension {dim}")
    with torch.no_grad():
        traj = cfm_node.trajectory(
            source_samples[:1000].to(device),
            t_span=torch.linspace(0, 1, 100, device=device),
        )

    # Save only the initial and final states of the trajectory
    # (traj is indexed as [time, sample, feature] below)
    torch.save(traj[0].cpu(), f"{dataset_dir}/traj_start.pt")
    torch.save(traj[-1].cpu(), f"{dataset_dir}/traj_end.pt")

    # Scatter the first two coordinates of generated vs. reference samples
    if dim >= 2:
        plt.figure(figsize=(10, 6))
        plt.scatter(traj[-1, :, 0].cpu(), traj[-1, :, 1].cpu(), alpha=0.5, label='Generated')

        # Fresh samples from the GMM for reference
        gmm_samples = gmm.sample((1000,))
        plt.scatter(gmm_samples[:, 0], gmm_samples[:, 1], alpha=0.5, label='GMM')

        plt.title(f'Generated Samples (dim={dim})')
        plt.legend()
        plt.savefig(f"{dataset_dir}/samples.png")
        plt.close()

    return gmm

def _parse_args():
    """Parse the command-line options (--dimensions, --device)."""
    default_device = "cuda" if torch.cuda.is_available() else "cpu"

    parser = argparse.ArgumentParser(
        description="Generate datasets for dimensions 3, 10, and 50"
    )
    parser.add_argument(
        "--dimensions",
        type=int,
        nargs="+",
        default=[3, 10, 50],
        help="Dimensions to generate datasets for",
    )
    parser.add_argument(
        "--device",
        type=str,
        default=default_device,
        help="Device to use for training (cuda or cpu)",
    )
    return parser.parse_args()


def main():
    """
    Command-line entry point: generate a dataset and trained model for every
    requested dimension (3, 10 and 50 unless --dimensions says otherwise).
    """
    args = _parse_args()

    # Make sure the top-level output folders exist before any work starts.
    for subdir in ("datasets", "models"):
        os.makedirs(f"folder_generate_table/{subdir}", exist_ok=True)

    # One full generate-and-train run per dimension.
    for dim in args.dimensions:
        generate_dataset(dim, device=args.device)
        print(f"Dataset for dimension {dim} generated successfully")

    print("All datasets generated successfully!")


if __name__ == "__main__":
    main()
