import torch
import numpy as np
import os
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import MixtureSameFamily
import argparse
from tqdm import tqdm
from torchdyn.core import NeuralODE
import matplotlib.pyplot as plt
from torchcfm.conditional_flow_matching import ConditionalFlowMatcher, SchrodingerBridgeConditionalFlowMatcher
from torch.utils.data import TensorDataset, DataLoader
from geomloss import SamplesLoss
from generate_datasets import WMLP
import math

# Function to compute Maximum Mean Discrepancy (MMD)
def compute_mmd(x, y, sigma=1.0):
    """Return the (biased) squared MMD estimate between two sample sets.

    A Gaussian kernel ``exp(-||a - b||^2 / (2 * sigma^2))`` is evaluated for
    every pair of rows and averaged; the within-set averages include the
    diagonal (self) pairs.

    Args:
        x: 2-D tensor whose rows are samples.
        y: 2-D tensor whose rows are samples, with as many columns as ``x``.
        sigma: Bandwidth of the Gaussian kernel.

    Returns:
        Scalar tensor ``mean(K(x, x)) + mean(K(y, y)) - 2 * mean(K(x, y))``.
    """
    two_sigma_sq = 2 * sigma ** 2

    def mean_kernel(p, q):
        # Pairwise squared distances via ||p||^2 + ||q||^2 - 2 <p, q>.
        p_sq = p.pow(2).sum(dim=1, keepdim=True)
        q_sq = q.pow(2).sum(dim=1, keepdim=True)
        sq_dist = p_sq + q_sq.T - 2 * p @ q.T
        return torch.exp(-sq_dist / two_sigma_sq).mean()

    within_x = mean_kernel(x, x)
    within_y = mean_kernel(y, y)
    across = mean_kernel(x, y)
    return within_x + within_y - 2 * across

# Function to compute Sinkhorn Wasserstein distance
def compute_sinkhorn_wasserstein(x, y, p=2, blur=0.05, scaling=0.9):
    """
    Evaluate geomloss' debiased Sinkhorn loss between two point clouds.

    - x, y: [N, d] tensors, passed unchanged to the loss
    - p: order of the distance (2 = classic Wasserstein)
    - blur: regularization strength (lower = closer to real Wasserstein)
    - scaling: forwarded to ``SamplesLoss`` as its ``scaling`` argument

    Returns whatever the ``SamplesLoss`` instance returns for ``(x, y)``.
    """
    sinkhorn_options = {
        "p": p,
        "blur": blur,
        "scaling": scaling,
        "debias": True,
    }
    sinkhorn = SamplesLoss("sinkhorn", **sinkhorn_options)
    distance = sinkhorn(x, y)
    return distance

# The vector-field network (WMLP) is imported from generate_datasets and instantiated in train_model.


def train_model(dim, traj_start, traj_end, flow_matcher_type="cfm", num_epochs=50000, batch_size=128, lr=1e-3, device="cpu", sigma=0.0):
    """
    Fit a WMLP vector field by regressing onto conditional flow targets.

    Args:
        dim: Data dimension; passed to ``WMLP`` and used in output paths.
        traj_start: Source samples, paired row-by-row with ``traj_end``.
        traj_end: Target samples.
        flow_matcher_type: "cfm" or "sbm" (matched case-insensitively).
        num_epochs: Number of passes over the paired dataset.
        batch_size: Mini-batch size for the DataLoader.
        lr: Adam learning rate.
        device: Device on which the model and data are placed.
        sigma: ``sigma`` argument of the selected flow matcher.

    Returns:
        ``(model, node)``: the trained network and a ``NeuralODE`` wrapping it.

    Raises:
        ValueError: If ``flow_matcher_type`` is neither "cfm" nor "sbm".

    Side effects:
        Writes the loss curve (PNG) and the model ``state_dict`` under
        ``folder_generate_table/models/dim_{dim}`` relative to the current
        working directory.
    """
    kind = flow_matcher_type.lower()
    label = flow_matcher_type.upper()
    print(f"Training {label} model for dimension {dim}")

    # WMLP's `w` argument (presumably the layer width) is doubled from dim 50 upwards.
    net_w = 256 if dim >= 50 else 128
    model = WMLP(dim, w=net_w).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    # Resolve the matcher class only after the model exists, as before.
    matcher_classes = {
        "cfm": ConditionalFlowMatcher,
        "sbm": SchrodingerBridgeConditionalFlowMatcher,
    }
    if kind not in matcher_classes:
        raise ValueError(f"Unknown flow matcher type: {flow_matcher_type}")
    flow_matcher = matcher_classes[kind](sigma=sigma)

    # Paired (start, end) samples, both resident on the training device.
    traj_start = traj_start.to(device)
    traj_end = traj_end.to(device)
    loader = DataLoader(
        TensorDataset(traj_start, traj_end),
        batch_size=batch_size,
        shuffle=True,
    )

    losses = []
    progress = tqdm(range(num_epochs), desc=f"Training {label} model for dim={dim}")
    for epoch in progress:
        batch_losses = []
        for x0_batch, x1_batch in loader:
            # The matcher draws the time, the interpolated point and the target velocity.
            t, xt, ut = flow_matcher.sample_location_and_conditional_flow(x0_batch, x1_batch)

            residual = model(t, xt) - ut
            loss = (residual ** 2).mean()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            batch_losses.append(loss.item())

        # One entry per epoch: the mean of that epoch's batch losses.
        avg_loss = np.mean(batch_losses)
        losses.append(avg_loss)

        if epoch % 1000 == 0:
            print(f"Epoch {epoch}, Loss: {avg_loss:.6f}")

    # Wrap the trained field so callers can integrate it.
    node = NeuralODE(
        model,
        solver="dopri5",
        sensitivity="adjoint",
        atol=1e-3,
        rtol=1e-3,
    )

    # Loss curve on a logarithmic y-axis.
    plt.figure(figsize=(10, 6))
    plt.plot(losses)
    plt.yscale("log")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.title(f"Training Loss for {label} Model (dim={dim})")
    plt.grid(True)

    out_dir = f"folder_generate_table/models/dim_{dim}"
    os.makedirs(out_dir, exist_ok=True)
    plt.savefig(f"{out_dir}/{kind}_loss.png")
    plt.close()

    # Persist the weights next to the loss plot.
    torch.save(model.state_dict(), f"{out_dir}/{kind}_model.pt")

    return model, node

def evaluate_model(dim, model, gmm, traj_start, traj_end, device="cpu"):
    """
    Evaluate a trained vector-field model against the target distribution ``gmm``.

    Three groups of metrics are computed, all on CPU:

    - "traj_end": the stored end points ``traj_end`` compared with fresh
      samples from ``gmm``.
    - "outside": samples obtained by integrating the model from standard-normal
      noise, compared with fresh ``gmm`` samples over 5 trials (mean and std).
    - "inside": samples obtained by integrating the model from ``traj_start``,
      compared with ``traj_end``.

    Args:
        dim: Data dimension (number of columns of the noise samples).
        model: Trained vector field; wrapped in a ``NeuralODE`` for integration.
        gmm: Target distribution exposing ``sample`` and ``log_prob``. It is
            evaluated on CPU tensors, so it is assumed to live on the CPU.
        traj_start: Source samples used for the in-sample evaluation.
        traj_end: Target samples paired with ``traj_start``.
        device: Device on which the ODE is integrated.

    Returns:
        dict mapping metric names to floats, plus ``"last_trajectory"``, the
        CPU tensor of samples generated from ``traj_start``.
    """
    num_samples = 1000
    num_trials = 5

    model = model.to(device)
    traj_start = traj_start.detach().to(device)
    # Every metric is computed on CPU, so keep a single CPU copy of the targets.
    traj_end_cpu = traj_end.detach().cpu()

    # ---- Stored trajectory end points vs. the GMM ----
    gmm_samples = gmm.sample((num_samples,)).cpu()

    log_prob_traj_end = gmm.log_prob(traj_end_cpu).mean().item()
    # Reference value: log-likelihood of true GMM samples. Previously this
    # re-evaluated traj_end and therefore always equalled log_prob_traj_end.
    gmm_log_prob_traj_end = gmm.log_prob(gmm_samples).mean().item()
    mmd_traj_end = compute_mmd(traj_end_cpu, gmm_samples).item()
    sinkhorn_traj_end = compute_sinkhorn_wasserstein(traj_end_cpu, gmm_samples).item()

    # Create Neural ODE
    node = NeuralODE(
        model,
        solver="dopri5",
        sensitivity="adjoint",
        atol=1e-3,
        rtol=1e-3
    )
    t_span = torch.linspace(0, 1, 100, device=device)

    # ---- Out-of-sample: integrate from fresh Gaussian noise ----
    print(f"Evaluating model for dimension {dim} - outside of data")
    source_samples = torch.randn(num_samples, dim)

    with torch.no_grad():
        traj = node.trajectory(source_samples.to(device), t_span=t_span)

    # The generated samples are fixed across trials; only the GMM reference
    # batch is redrawn.
    # NOTE: consequently log_prob_outside_std is always 0.
    generated_outside = traj[-1].cpu()
    generated_log_prob = gmm.log_prob(generated_outside).mean().item()

    mmd_outside = []
    sinkhorn_outside = []
    log_prob_outside = []
    gmm_log_prob_outside = []
    for _ in range(num_trials):
        gmm_sampled = gmm.sample((num_samples,)).cpu()
        mmd_outside.append(compute_mmd(generated_outside, gmm_sampled).item())
        sinkhorn_outside.append(compute_sinkhorn_wasserstein(generated_outside, gmm_sampled).item())
        log_prob_outside.append(generated_log_prob)
        gmm_log_prob_outside.append(gmm.log_prob(gmm_sampled).mean().item())

    log_prob_outside_mean = np.mean(log_prob_outside)
    gmm_log_prob_outside_mean = np.mean(gmm_log_prob_outside)
    mmd_outside_mean = np.mean(mmd_outside)
    sinkhorn_outside_mean = np.mean(sinkhorn_outside)

    log_prob_outside_std = np.std(log_prob_outside)
    gmm_log_prob_outside_std = np.std(gmm_log_prob_outside)
    mmd_outside_std = np.std(mmd_outside)
    sinkhorn_outside_std = np.std(sinkhorn_outside)

    # GMM-vs-GMM distances between two independent batches, as a reference.
    gmm_samples = gmm.sample((num_samples,)).cpu()
    gmm_samples2 = gmm.sample((num_samples,)).cpu()
    mmd_self_outside = compute_mmd(gmm_samples, gmm_samples2).item()
    sinkhorn_self_outside = compute_sinkhorn_wasserstein(gmm_samples, gmm_samples2).item()

    # ---- In-sample: integrate from the stored start points ----
    print(f"Evaluating model for dimension {dim} - on data data")
    with torch.no_grad():
        traj = node.trajectory(traj_start, t_span=t_span)

    generated_samples = traj[-1].detach().cpu()

    # Here the reference is the stored target data rather than fresh GMM samples.
    log_prob_inside = gmm.log_prob(generated_samples).mean().item()
    gmm_log_prob_inside = gmm.log_prob(traj_end_cpu).mean().item()
    mmd_inside = compute_mmd(generated_samples, traj_end_cpu).item()
    sinkhorn_inside = compute_sinkhorn_wasserstein(generated_samples, traj_end_cpu).item()

    # A second GMM-vs-GMM reference, reusing gmm_samples with a fresh batch.
    gmm_samples2 = gmm.sample((num_samples,)).cpu()
    mmd_self_inside = compute_mmd(gmm_samples, gmm_samples2).item()
    sinkhorn_self_inside = compute_sinkhorn_wasserstein(gmm_samples, gmm_samples2).item()

    # Return all metrics
    return {
        "log_prob_traj_end": log_prob_traj_end,
        "gmm_log_prob_traj_end": gmm_log_prob_traj_end,
        "mmd_traj_end": mmd_traj_end,
        "sinkhorn_traj_end": sinkhorn_traj_end,
        "log_prob_outside_mean": log_prob_outside_mean,
        "gmm_log_prob_outside_mean": gmm_log_prob_outside_mean,
        "log_prob_outside_std": log_prob_outside_std,
        "gmm_log_prob_outside_std": gmm_log_prob_outside_std,
        "mmd_outside_mean": mmd_outside_mean,
        "sinkhorn_outside_mean": sinkhorn_outside_mean,
        "mmd_outside_std": mmd_outside_std,
        "mmd_self_outside": mmd_self_outside,
        "sinkhorn_outside_std": sinkhorn_outside_std,
        "sinkhorn_self_outside": sinkhorn_self_outside,
        "log_prob_inside": log_prob_inside,
        "gmm_log_prob_inside": gmm_log_prob_inside,
        "mmd_inside": mmd_inside,
        "sinkhorn_inside": sinkhorn_inside,
        "mmd_self_inside": mmd_self_inside,
        "sinkhorn_self_inside": sinkhorn_self_inside,
        "last_trajectory": generated_samples
    }

def _safe_ratio(numerator, denominator):
    """Return numerator / denominator, or NaN when the denominator is zero."""
    if denominator == 0:
        return float("nan")
    return numerator / denominator


def _format_metrics_report(flow_matcher_type, dim, metrics):
    """Build the text of the per-model metrics file from an evaluate_model dict."""
    m = metrics
    rule = "-" * 40 + "\n"
    lines = [
        f"Metrics for {flow_matcher_type.upper()} model on dimension {dim}\n",
        "=" * 50 + "\n\n",
        # OUTSIDE METRICS (out-of-sample evaluation)
        "OUTSIDE METRICS (Out-of-sample evaluation)\n",
        rule,
        f"Trajectory End Log Probability (mean): {m['log_prob_traj_end']}\n",
        f"Trajectory End GMM Log Probability (mean): {m['gmm_log_prob_traj_end']}\n",
        f"Traj End MMD: {m['mmd_traj_end']}\n",
        f"Traj End Sinkhorn: {m['sinkhorn_traj_end']}\n",
        rule,
        f"Log Probability (mean): {m['log_prob_outside_mean']}\n",
        f"Log Probability (std): {m['log_prob_outside_std']}\n",
        f"GMM Log Probability (mean): {m['gmm_log_prob_outside_mean']}\n",
        f"GMM Log Probability (std): {m['gmm_log_prob_outside_std']}\n",
        f"Log Probability Ratio: {_safe_ratio(m['log_prob_outside_mean'], m['gmm_log_prob_outside_mean'])}\n\n",
        f"MMD (mean): {m['mmd_outside_mean']}\n",
        f"MMD (std): {m['mmd_outside_std']}\n",
        f"MMD Self (GMM-GMM): {m['mmd_self_outside']}\n",
        f"MMD Ratio: {_safe_ratio(m['mmd_outside_mean'], m['mmd_self_outside'])}\n\n",
        f"Sinkhorn (mean): {m['sinkhorn_outside_mean']}\n",
        f"Sinkhorn (std): {m['sinkhorn_outside_std']}\n",
        f"Sinkhorn Self (GMM-GMM): {m['sinkhorn_self_outside']}\n",
        f"Sinkhorn Ratio: {_safe_ratio(m['sinkhorn_outside_mean'], m['sinkhorn_self_outside'])}\n\n",
        # INSIDE METRICS (in-sample evaluation)
        "INSIDE METRICS (In-sample evaluation)\n",
        rule,
        f"Log Probability: {m['log_prob_inside']}\n",
        f"GMM Log Probability: {m['gmm_log_prob_inside']}\n",
        f"Log Probability Ratio: {_safe_ratio(m['log_prob_inside'], m['gmm_log_prob_inside'])}\n\n",
        f"MMD: {m['mmd_inside']}\n",
        f"MMD Self (GMM-GMM): {m['mmd_self_inside']}\n",
        f"MMD Ratio: {_safe_ratio(m['mmd_inside'], m['mmd_self_inside'])}\n\n",
        f"Sinkhorn: {m['sinkhorn_inside']}\n",
        f"Sinkhorn Self (GMM-GMM): {m['sinkhorn_self_inside']}\n",
        f"Sinkhorn Ratio: {_safe_ratio(m['sinkhorn_inside'], m['sinkhorn_self_inside'])}\n",
    ]
    return "".join(lines)


def _format_comparison_report(dim, flow_matcher_types, results):
    """Build the text of comparison.txt, contrasting all trained flow matchers."""
    reference = results[flow_matcher_types[0]]

    def section(title, metric_key, reference_key):
        # One line per model followed by the GMM reference value.
        rows = [f"{title} Comparison:\n"]
        rows.extend(
            f"{fm_type.upper()}: {results[fm_type][metric_key]:.6f}\n"
            for fm_type in flow_matcher_types
        )
        rows.append("GMM Self: {:.6f}\n".format(reference[reference_key]))
        return "".join(rows)

    rule = "-" * 40 + "\n\n"
    outside = [
        section("Log Probability", "log_prob_outside_mean", "gmm_log_prob_outside_mean"),
        section("MMD", "mmd_outside_mean", "mmd_self_outside"),
        section("Sinkhorn", "sinkhorn_outside_mean", "sinkhorn_self_outside"),
    ]
    inside = [
        section("Log Probability", "log_prob_inside", "gmm_log_prob_inside"),
        section("MMD", "mmd_inside", "mmd_self_inside"),
        section("Sinkhorn", "sinkhorn_inside", "sinkhorn_self_inside"),
    ]
    return "".join([
        "Model Comparison for Dimension {}\n".format(dim),
        "=" * 40 + "\n\n",
        "OUTSIDE METRICS COMPARISON (Out-of-sample)\n",
        rule,
        # Sections are separated (and the outside group is closed) by a blank line.
        "\n".join(outside) + "\n",
        "INSIDE METRICS COMPARISON (In-sample)\n",
        rule,
        "\n".join(inside),
    ])


def train_and_evaluate_dimension(dim, args, flow_matcher_types=("cfm", "sbm"), device="cpu"):
    """
    Train and evaluate one model per flow matcher type for the given dimension.

    Args:
        dim: Data dimension; selects the dataset folder and the epoch budget
            (5000 for 3, 10000 for 10, 20000 otherwise).
        args: Namespace providing ``sigma`` (used for every type except "cfm",
            which always trains with sigma 0.0) and, optionally, ``data_dir``,
            the root holding ``datasets/dim_{dim}`` and ``results/dim_{dim}``.
        flow_matcher_types: Sequence of flow matcher names to train.
        device: Device used for training and ODE integration.

    Returns:
        dict mapping each flow matcher type to its ``evaluate_model`` metrics.

    Side effects:
        Writes a metrics text file and a sample scatter plot per flow matcher
        type, and ``comparison.txt`` when more than one type is given, all
        under the results directory.
    """
    print(f"Processing dimension {dim}")

    # Honour --data_dir (previously ignored here) and fall back to the
    # historical hard-coded root when the namespace does not carry it.
    default_root = "/slurm-storage/teoreu/git/variance_flows/src/train_synthetic/generate_table"
    root = getattr(args, "data_dir", None) or default_root
    data_dir = f"{root}/datasets/dim_{dim}"
    results_dir = f"{root}/results/dim_{dim}"

    # NOTE(security): torch.load unpickles arbitrary objects (gmm.pt holds a
    # pickled object, not plain tensors) - only load files from a trusted source.
    traj_start = torch.load(f"{data_dir}/traj_start.pt")
    traj_end = torch.load(f"{data_dir}/traj_end.pt")
    gmm = torch.load(f"{data_dir}/gmm.pt")

    os.makedirs(results_dir, exist_ok=True)

    # The epoch budget depends only on the dimension, so pick it once.
    num_epochs = {3: 5000, 10: 10000}.get(dim, 20000)

    results = {}
    for flow_matcher_type in flow_matcher_types:
        print(f"Training {flow_matcher_type.upper()} model for dimension {dim}")
        # Case-insensitive, matching how train_model resolves the type.
        sigma = 0.0 if flow_matcher_type.lower() == "cfm" else args.sigma

        model, _ = train_model(
            dim,
            traj_start,
            traj_end,
            flow_matcher_type=flow_matcher_type,
            num_epochs=num_epochs,
            device=device,
            sigma=sigma
        )

        metrics = evaluate_model(dim, model, gmm, traj_start.detach(), traj_end.detach(), device=device)
        results[flow_matcher_type] = metrics

        with open(f"{results_dir}/{flow_matcher_type}_sigma_{sigma}_metrics.txt", "w") as f:
            f.write(_format_metrics_report(flow_matcher_type, dim, metrics))

        # Scatter the first two coordinates when the data has at least two.
        if dim >= 2:
            generated = metrics['last_trajectory']
            # Draw ONE batch so the x and y coordinates belong to the same points;
            # sampling separately per axis would plot unrelated coordinates.
            gmm_points = gmm.sample((1000,)).cpu()
            plt.figure(figsize=(10, 6))
            plt.scatter(generated[:, 0], generated[:, 1], alpha=0.5, label=f'{flow_matcher_type.upper()} Generated')
            plt.scatter(gmm_points[:, 0], gmm_points[:, 1], alpha=0.5, label='GMM')
            plt.title(f'{flow_matcher_type.upper()} Generated Samples (dim={dim})')
            plt.legend()
            plt.savefig(f"{results_dir}/{flow_matcher_type}_samples.png")
            plt.close()

    # Compare models if we have multiple flow matcher types
    if len(flow_matcher_types) > 1:
        with open(f"{results_dir}/comparison.txt", "w") as f:
            f.write(_format_comparison_report(dim, flow_matcher_types, results))

    return results
    
def main():
    """
    Parse the command line, then train and evaluate models for each dimension.

    A ``summary.txt`` file is created under ``{data_dir}/results`` and extended
    after every dimension, so finished dimensions stay recorded even if a later
    one fails.
    """
    parser = argparse.ArgumentParser(description='Train and evaluate models for dimensions 3, 10, and 50')
    parser.add_argument('--dimensions', type=int, nargs='+', default=[3],
                        help='Dimensions to process')
    parser.add_argument('--flow_matcher_types', type=str, nargs='+', default=["cfm", "sbm"],
                        help='Flow matcher types to use (cfm, sbm)')
    parser.add_argument('--device', type=str, default="cuda" if torch.cuda.is_available() else "cpu",
                        help='Device to use for training (cuda or cpu)')
    parser.add_argument('--data_dir', type=str,
                        default="/slurm-storage/teoreu/git/variance_flows/src/train_synthetic/generate_table",
                        help='Directory containing the datasets')
    # The per-dimension driver trains "cfm" with sigma 0.0 and passes this value
    # to every other flow matcher type, so the help text says so.
    parser.add_argument('--sigma', type=float, default=0.1,
                        help='Sigma value for the SBM flow matcher (CFM always uses sigma=0.0)')
    args = parser.parse_args()

    # Create main results directory
    results_dir = f"{args.data_dir}/results"
    os.makedirs(results_dir, exist_ok=True)

    # Start the summary from scratch; per-dimension sections are appended below.
    summary_file = f"{results_dir}/summary.txt"
    with open(summary_file, "w") as f:
        f.write("Summary of Results Across All Dimensions\n")
        f.write("=====================================\n")
        f.write("This summary contains both outside (out-of-sample) and inside (in-sample) metrics\n\n")

    for dim in args.dimensions:
        results = train_and_evaluate_dimension(dim, args, args.flow_matcher_types, args.device)
        print(f"Dimension {dim} processed successfully")

        # Append this dimension's headline metrics for every flow matcher type.
        section = [f"Dimension {dim}:\n", "------------\n"]
        for fm_type in args.flow_matcher_types:
            metrics = results[fm_type]
            section.extend([
                f"{fm_type.upper()} Results:\n",
                "  OUTSIDE METRICS:\n",
                f"    Log Probability: {metrics['log_prob_outside_mean']}\n",
                f"    MMD: {metrics['mmd_outside_mean']}\n",
                f"    Sinkhorn: {metrics['sinkhorn_outside_mean']}\n",
                "  INSIDE METRICS:\n",
                f"    Log Probability: {metrics['log_prob_inside']}\n",
                f"    MMD: {metrics['mmd_inside']}\n",
                f"    Sinkhorn: {metrics['sinkhorn_inside']}\n\n",
            ])
        with open(summary_file, "a") as f:
            f.writelines(section)

    print("All models trained and evaluated successfully!")
    print(f"Results saved to {results_dir}")

# Run the full train/evaluate pipeline only when this file is executed as a script.
if __name__ == "__main__":
    main()
