import argparse
import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import wandb
import math
import copy
from torchdyn.core import NeuralODE
from torchcfm.conditional_flow_matching import ConditionalFlowMatcher
import torch.nn as nn
from torch.distributions import MultivariateNormal

# Seed both RNGs used by this script: torch (distribution sampling, weight
# initialization) and NumPy (trajectory subset selection for plotting).
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# Prefer the GPU when one is visible; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

class WMLP(nn.Module):
    """SELU MLP with three hidden layers that models an (optionally time-dependent) vector field.

    Args:
        dim: Dimensionality of the state ``x`` and of the predicted vector field.
        w: Width of each hidden layer.
        time_varying: If True, the time ``t`` is prepended to ``x`` as one extra input feature.
    """

    def __init__(self, dim, w=64, time_varying=True):
        super().__init__()
        self.time_varying = time_varying
        in_features = dim + 1 if time_varying else dim

        # Linear -> SELU, three times, then a linear read-out back to `dim`.
        layers = [nn.Linear(in_features, w), nn.SELU()]
        for _ in range(2):
            layers += [nn.Linear(w, w), nn.SELU()]
        layers.append(nn.Linear(w, dim))
        self.net = nn.Sequential(*layers)

    def forward(self, t, x, *args, **kwargs):
        """Evaluate the vector field at ``(t, x)``.

        ``t`` is a tensor holding either a single time value or one value per row of
        ``x``. Any extra positional/keyword arguments are accepted and ignored.
        """
        if not self.time_varying:
            return self.net(x)
        # Broadcast t to a (batch, 1) column and prepend it to the state features.
        t_col = t.view(-1, 1).expand(x.size(0), 1)
        return self.net(torch.cat((t_col, x), dim=1))

def calculate_gradient_norm(model):
    """Return the global L2 norm of ``model``'s parameter gradients as a Python float.

    Parameters whose ``.grad`` is ``None`` are skipped, so a model without any
    gradients yields ``0.0``. The result equals the L2 norm of all gradients
    flattened and concatenated into a single vector.
    """
    squared_total = sum(
        (p.grad.data.norm(2).item() ** 2 for p in model.parameters() if p.grad is not None),
        0.0,
    )
    return squared_total ** 0.5


def _track_gradient_variance(model, flow_matcher, x0, x1, gradient_storage):
    """Append, for every time key in ``gradient_storage``, the gradient norm of the CFM loss at that t.

    Gradients are computed with ``torch.autograd.grad`` rather than ``backward`` so the
    parameters' ``.grad`` buffers (the training gradient about to be applied) stay untouched.

    Args:
        model: Vector-field network, called as ``model(t, xt)``.
        flow_matcher: Provides ``sample_location_and_conditional_flow(x0, x1, t=...)``.
        x0: Source batch, already on the training device.
        x1: Target batch paired with ``x0``, already on the training device.
        gradient_storage: Mapping from time value to a list; one gradient norm is appended
            to each list (mutated in place).

    Returns:
        The same ``gradient_storage`` mapping.
    """
    params = [p for p in model.parameters() if p.requires_grad]
    for t_value in gradient_storage:
        # Pin every sample in the batch to the same time step.
        t = torch.full((x0.shape[0],), float(t_value), dtype=x0.dtype, device=x0.device)
        _, xt, ut = flow_matcher.sample_location_and_conditional_flow(x0, x1, t=t)
        loss_t = torch.mean((model(t, xt) - ut) ** 2)
        grads = torch.autograd.grad(loss_t, params, allow_unused=True)
        squared = sum((g.detach().pow(2).sum().item() for g in grads if g is not None), 0.0)
        gradient_storage[t_value].append(squared ** 0.5)
    return gradient_storage


def train_epoch(model, optimizer, flow_matcher, x0, x1, epoch, gradient_storage=None,
                track_after=1900, track_every=10):
    """Run one conditional-flow-matching optimization step on a single (x0, x1) batch.

    Args:
        model: Vector-field network; moved to ``device`` and set to train mode.
        optimizer: Optimizer over ``model``'s parameters.
        flow_matcher: Provides ``sample_location_and_conditional_flow(x0, x1)``.
        x0: Source batch (moved to ``device``).
        x1: Target batch (moved to ``device``).
        epoch: Current epoch index, used for logging and the tracking schedule.
        gradient_storage: Optional mapping ``{t: list}``. When given, per-time-step
            gradient norms are appended to it on tracking epochs.
        track_after: Tracking happens only when ``epoch > track_after``.
        track_every: ... and only when ``epoch % track_every == 0``.

    Returns:
        The training loss as a Python float. Loss, epoch and gradient norm are also
        logged to wandb.
    """
    model.train()
    optimizer.zero_grad()

    x0 = x0.to(device)
    x1 = x1.to(device)
    model = model.to(device)

    # Sample t ~ U(0, 1), the interpolated point x_t and the target velocity u_t.
    t, xt, ut = flow_matcher.sample_location_and_conditional_flow(x0, x1)
    vt = model(t, xt)
    loss = torch.mean((vt - ut) ** 2)
    loss.backward()

    grad_norm = calculate_gradient_norm(model)

    # Measured before optimizer.step() so it reflects the same parameters as `loss`.
    if (
        gradient_storage is not None
        and epoch > track_after
        and epoch % track_every == 0
    ):
        _track_gradient_variance(model, flow_matcher, x0, x1, gradient_storage)

    optimizer.step()

    loss_value = loss.item()
    wandb.log({
        "loss": loss_value,
        "epoch": epoch,
        "gradient_norm": grad_norm
    })

    return loss_value

def plot_gradient_variance(gradient_storage, title="Gradient Variance Across Time Steps"):
    """Plot mean ± standard deviation of the recorded gradient norms per time step.

    Args:
        gradient_storage: Mapping from time value to a sequence of gradient norms.
            Keys are plotted in sorted order. A key with no recorded values is drawn
            at 0 with a zero-length error bar.
        title: Figure title.

    The figure is logged to wandb under ``"gradient_variance_plot"``, written to
    ``gradient_variance.png`` (overwriting any existing file) and then closed.
    """
    plt.figure(figsize=(10, 6))

    steps = sorted(gradient_storage)
    buckets = [gradient_storage[step] for step in steps]
    centers = [np.mean(bucket) if bucket else 0 for bucket in buckets]
    spreads = [np.std(bucket) if bucket else 0 for bucket in buckets]

    plt.errorbar(steps, centers, yerr=spreads, fmt='o-', capsize=5, label='Mean Gradient Norm ± Std')
    plt.xlabel('Time Step (t)')
    plt.ylabel('Gradient Norm')
    plt.title(title)
    plt.grid(True, alpha=0.3)
    plt.legend()

    wandb.log({"gradient_variance_plot": wandb.Image(plt)})
    plt.savefig("gradient_variance.png")
    plt.close()

def plot_trajectories(model, source_dist, target_dist, num_samples=100, num_timesteps=100):
    """Integrate the learned flow from source samples, visualize it and report straightness.

    ``model`` is moved to the CPU and switched to eval mode (in place). ``num_samples``
    starting points are drawn from ``source_dist`` and integrated with an adaptive
    dopri5 ``NeuralODE``, evaluated at ``num_timesteps`` evenly spaced times in [0, 1].

    For 2-D data, a figure is logged to wandb and saved as ``trajectories.png``. It shows
    KDEs of 500 source and 500 target samples, the start and end points, and the paths
    of up to 10 random samples. The mean path straightness is always logged to wandb
    and printed.

    Returns:
        The trajectory tensor from ``NeuralODE.trajectory``, indexed ``[time, sample, dim]``.
    """
    model = model.cpu().eval()
    start_points = source_dist.sample((num_samples,)).cpu()

    ode = NeuralODE(model, solver="dopri5", sensitivity="adjoint", atol=1e-3, rtol=1e-3)
    with torch.no_grad():
        traj = ode.trajectory(start_points, t_span=torch.linspace(0, 1, num_timesteps))

    if start_points.shape[1] == 2:
        plt.figure(figsize=(12, 10))

        # Background density estimates: source first, then target.
        for dist, cmap, dist_label in (
            (source_dist, 'Blues', "Source Distribution"),
            (target_dist, 'Reds', "Target Distribution"),
        ):
            pts = dist.sample((500,)).cpu().numpy()
            sns.kdeplot(x=pts[:, 0], y=pts[:, 1], cmap=cmap, fill=True,
                        alpha=0.3, zorder=0, label=dist_label)

        # Positions of every integrated sample at the first and last time point.
        for step, color, point_label in ((0, 'blue', 'Initial samples'), (-1, 'red', 'Final samples')):
            plt.scatter(traj[step, :, 0].numpy(), traj[step, :, 1].numpy(),
                        s=30, c=color, alpha=0.7, label=point_label)

        # Full paths for a random subset of samples.
        chosen = np.random.choice(num_samples, min(10, num_samples), replace=False)
        for idx in chosen:
            plt.plot(traj[:, idx, 0].numpy(), traj[:, idx, 1].numpy(), 'k-', alpha=0.3)

        plt.title("Flow Trajectories: N(μ, Σ) to N(0, I)")
        plt.xlabel("Dimension 1")
        plt.ylabel("Dimension 2")
        plt.legend()
        plt.grid(alpha=0.3)

        wandb.log({"trajectories": wandb.Image(plt)})
        plt.savefig("trajectories.png")
        plt.close()

    straightness = calculate_straightness(traj)
    wandb.log({"straightness": straightness})
    print(f"Path straightness: {straightness:.4f} (closer to 1 is straighter)")

    return traj

def calculate_straightness(traj):
    """Return the mean path straightness of a batch of trajectories as a Python float.

    For each sample the straightness is ``|x_T - x_0| / (L + 1e-8)``, where ``L`` is the
    polyline length ``sum_t |x_t - x_{t-1}|``. A value of 1 is a straight path, and
    smaller values are more curved. The epsilon avoids division by zero, so a
    stationary sample scores 0.

    Args:
        traj: Tensor indexed ``[time, sample, ...]``. Norms are taken over the last dim;
            any further trailing dims are averaged per sample. A rank-2 tensor is
            treated as scalar states (norm = absolute value).

    Returns:
        Straightness averaged over the sample axis.

    Raises:
        ValueError: If ``traj`` contains no samples.
    """
    if traj.size(1) == 0:
        raise ValueError("calculate_straightness needs at least one trajectory")

    # Detached float64 copy: this is a metric (no gradients needed), and per-step
    # lengths are summed over many steps.
    paths = traj.detach().to(torch.float64)
    if paths.dim() == 2:
        paths = paths.unsqueeze(-1)

    # Length of every step, reduced to one value per (step, sample).
    step_lengths = torch.linalg.vector_norm(paths[1:] - paths[:-1], dim=-1)
    if step_lengths.dim() > 2:
        step_lengths = step_lengths.flatten(2).mean(-1)
    path_length = step_lengths.sum(0)

    # Straight-line distance between the first and last state of every sample.
    chord = torch.linalg.vector_norm(paths[-1] - paths[0], dim=-1)
    if chord.dim() > 1:
        chord = chord.flatten(1).mean(-1)

    return (chord / (path_length + 1e-8)).mean().item()

def generate_covariance_matrix(dim):
    """Build a positive-definite ``dim x dim`` covariance matrix for the source Gaussian.

    * ``dim == 2``: fixed matrix with variances 2.0 and 0.5 and correlation 0.7.
    * ``dim > 2``: ``U diag(lam) U^T`` with ``lam`` evenly spaced in [0.1, 3.0] and ``U``
      the orthogonal Q factor of a random Gaussian matrix (random orientation, fixed spectrum).
    * otherwise (e.g. ``dim == 1``): ``A A^T + 0.1 I`` for a random Gaussian ``A``.

    The first ``dim x dim`` Gaussian draw is made for every ``dim``, even when its result
    is discarded, so it always advances the global torch RNG that later sampling uses.
    """
    # Drawn unconditionally: it advances the torch RNG regardless of the branch taken.
    base = torch.randn(dim, dim)
    fallback = base @ base.T + 0.1 * torch.eye(dim)

    if dim == 2:
        var1, var2, corr = 2.0, 0.5, 0.7
        off_diag = corr * math.sqrt(var1 * var2)
        return torch.tensor([[var1, off_diag], [off_diag, var2]])

    if dim > 2:
        spectrum = torch.linspace(0.1, 3.0, dim)
        rotation, _ = torch.linalg.qr(torch.randn(dim, dim))
        return rotation @ torch.diag(spectrum) @ rotation.T

    return fallback

def integrate_samples(model, x0, num_timesteps=20):
    """Push ``x0`` through the learned flow from t=0 to t=1 and return the end points.

    ``model`` is moved to ``device`` and switched to eval mode (in place), and ``x0``
    is moved to ``device``. Integration uses an adaptive dopri5 ``NeuralODE`` evaluated
    at ``num_timesteps`` evenly spaced times in [0, 1]. Only the final (t=1) state is
    returned, detached.
    """
    model = model.to(device).eval()
    x0 = x0.to(device)

    ode = NeuralODE(model, solver="dopri5", sensitivity="adjoint", atol=1e-3, rtol=1e-3)
    eval_times = torch.linspace(0, 1, num_timesteps)
    with torch.no_grad():
        path = ode.trajectory(x0, t_span=eval_times)

    return path[-1].detach()

def main(args):
    """Train a CFM model from N(μ, Σ) to N(0, I), periodically regenerating x1 by integration.

    Args:
        args: Parsed CLI namespace providing ``dim``, ``w``, ``sigma``, ``mean_scale``,
            ``lr``, ``num_epochs``, ``batch_size`` and (optionally, default 20)
            ``integration_steps``.

    Side effects: starts and finishes a wandb run, prints progress, and calls the
    plotting helpers, which log figures to wandb and save PNG files.
    """
    wandb.init(
        project="simple_cfm",
        name=f"simple_cfm_dim_{args.dim}_w_{args.w}_sigma_{args.sigma}_integrated",
        config=vars(args)
    )

    print("=== Simple CFM: N(μ, Σ) to N(0, I) with Integration ===")
    print(f"Dimensions: {args.dim}, Width: {args.w}, Sigma: {args.sigma}")
    print(f"Using device: {device}")

    # Source: Gaussian with mean (mean_scale, ..., mean_scale) and a structured covariance.
    source_mean = torch.ones(args.dim) * args.mean_scale
    source_cov = generate_covariance_matrix(args.dim)
    source_dist = MultivariateNormal(
        loc=source_mean,
        covariance_matrix=source_cov
    )

    # Target: standard normal.
    target_dist = MultivariateNormal(
        loc=torch.zeros(args.dim),
        covariance_matrix=torch.eye(args.dim)
    )

    model = WMLP(dim=args.dim, w=args.w, time_varying=True)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    flow_matcher = ConditionalFlowMatcher(sigma=args.sigma)

    # Per-time-step gradient norms, tracked at t = 0.0, 0.1, ..., 1.0.
    gradient_storage = {t / 10: [] for t in range(11)}

    # Evaluation points handed to the ODE integrator when regenerating x1
    # (previously parsed from the CLI but never used).
    integration_steps = getattr(args, "integration_steps", 20)

    for epoch in range(args.num_epochs):
        x0 = source_dist.sample((args.batch_size,))

        # Refresh x1 every 100 epochs: true target samples at epoch 0, afterwards the
        # current model's push-forward of this epoch's x0.
        # NOTE(review): x0 is resampled every epoch but x1 only here, so between refreshes
        # (x0, x1) are not integrated pairs — confirm this coupling is intended.
        if epoch % 100 == 0:
            print(f"Epoch {epoch}: Integrating to generate new x1 samples...")
            if epoch == 0:
                x1 = target_dist.sample((args.batch_size,))
            else:
                x1 = integrate_samples(model, x0, num_timesteps=integration_steps)

        loss = train_epoch(model, optimizer, flow_matcher, x0, x1, epoch, gradient_storage)

        if epoch % 100 == 0 or epoch == args.num_epochs - 1:
            print(f"Epoch {epoch}/{args.num_epochs}, Loss: {loss:.6f}")

            if epoch % 500 == 0 or epoch == args.num_epochs - 1:
                plot_gradient_variance(gradient_storage, f"Gradient Variance at Epoch {epoch}")

    print("Generating final trajectories...")
    plot_trajectories(model, source_dist, target_dist, num_samples=100)

    plot_gradient_variance(gradient_storage, "Final Gradient Variance")

    print("Training complete!")
    wandb.finish()

if __name__ == "__main__":
    cli = argparse.ArgumentParser(description="Train CFM: N(μ, Σ) to N(0, I) with Integration")

    # (flag, type, default, help), registered in the listed order.
    cli_options = (
        # Model parameters
        ("--dim", int, 2, "Data dimension"),
        ("--w", int, 64, "Width of network layers"),
        # Distribution parameters
        ("--mean_scale", float, 5.0, "Scale for source distribution mean"),
        # Training parameters
        ("--sigma", float, 0.0, "Sigma value for flow matcher"),
        ("--lr", float, 1e-3, "Learning rate"),
        ("--num_epochs", int, 2000, "Number of epochs"),
        ("--batch_size", int, 128, "Batch size"),
        ("--integration_steps", int, 20, "Number of steps for integration"),
    )
    for flag, value_type, default, help_text in cli_options:
        cli.add_argument(flag, type=value_type, default=default, help=help_text)

    main(cli.parse_args())
