import argparse
import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import wandb
import math
import os
import time
import matplotlib.pyplot as plt
import numpy as np
import ot as pot
import torchdyn
from torchdyn.core import NeuralODE
from torchdyn.datasets import generate_moons
from torchcfm.conditional_flow_matching import *
from torchcfm.models.models import *
from torchcfm.utils import *
from torch.distributions import Categorical, MultivariateNormal, MixtureSameFamily, Independent, Normal



def calculate_gradient_norm(model):
    """
    Return the global L2 norm of all parameter gradients of a model.

    Parameters whose ``.grad`` is ``None`` (not yet populated, or cleared with
    ``set_to_none``) are skipped, so a model without gradients yields 0.0.

    Args:
        model (torch.nn.Module): Model whose accumulated gradients are measured.

    Returns:
        float: sqrt(sum over parameters of ||grad||_2 ** 2).
    """
    squared_norms = [
        param.grad.data.norm(2).item() ** 2
        for param in model.parameters()
        if param.grad is not None
    ]
    # Start the sum from 0.0 so the result is always a float, even with no gradients.
    return sum(squared_norms, 0.0) ** 0.5


def compute_lipschitz_via_weights(model):
    """Estimate a Lipschitz constant as the product of layer spectral norms and log it.

    Every ``torch.nn.Linear`` and ``torch.nn.Conv2d`` module found by
    ``model.modules()`` contributes the largest singular value of its weight
    matrix. Conv2d kernels are flattened to ``(out_channels, -1)`` first, which
    is a heuristic rather than the exact operator norm of the convolution. All
    other module types (activations, normalization, etc.) are ignored.

    The product is logged to wandb under the key ``"lip"`` and returned.

    Args:
        model (torch.nn.Module): Model to inspect.

    Returns:
        float: Product of the per-layer spectral norms (1.0 if there are none).
    """
    estimate = 1.0
    for layer in model.modules():
        if not isinstance(layer, (torch.nn.Linear, torch.nn.Conv2d)):
            # Other layer types could be added here if needed.
            continue
        matrix = layer.weight.data
        if isinstance(layer, torch.nn.Conv2d):
            # Collapse (out, in, kh, kw) into a 2-D matrix for the spectral norm.
            matrix = matrix.view(matrix.size(0), -1)
        estimate *= torch.linalg.norm(matrix, ord=2).item()
    wandb.log({
        "lip": estimate
    })
    return estimate


def evaluate_model(model, args, data, extra="", direction="forward", reflow_iteration=0, epoch=None):
    """Integrate the learned flow on held-out source samples and log evaluation metrics to wandb.

    Logged metrics:
      * ``straightness`` (from ``calculate_straightness``), together with
        ``iteration`` and ``epoch``.
      * ``val_{n}{extra}``: ``val`` of the Euler endpoints under
        ``data.target_distr.distribution``, for n = 100 and n = 2 grid points on [0, 1].
      * ``w2_{n}_{extra}_{direction}``: ``data.w2_distance`` of those endpoints,
        only when ``args.target_type == "normal"``.

    Args:
        model: Vector field handed to ``NeuralODE``. ``model.cpu()`` is called on it
            (for an ``nn.Module`` this moves the parameters in place).
        args: Must provide ``target_type``; also forwarded to ``calculate_straightness``.
        data: Exposes ``source_distr.test_values``, ``target_distr.distribution``
            and ``w2_distance``.
        extra: Suffix inserted into the metric keys.
        direction: Label used in the W2 metric key.
        reflow_iteration: Logged as ``iteration``.
        epoch: Logged as ``epoch`` (0 when None).
    """
    cpu_model = model.cpu()

    # These model classes are integrated with autograd left as-is; all others run under no_grad.
    needs_autograd = model.__class__.__name__ in ['GMLP', 'EMLP', "IPMLP"]

    wandb.log({
        "straightness": calculate_straightness(cpu_model, args, data),
        "iteration": reflow_iteration,
        "epoch": 0 if epoch is None else epoch
    })

    node = NeuralODE(cpu_model, solver="euler", sensitivity="adjoint")
    start_points = data.source_distr.test_values

    def integrate(num_points):
        # Fixed-step Euler integration from t=0 to t=1 on num_points grid points.
        return node.trajectory(start_points, t_span=torch.linspace(0, 1, num_points))

    for num_points in (100, 2):
        if needs_autograd:
            traj = integrate(num_points)
        else:
            with torch.no_grad():
                traj = integrate(num_points)

        endpoints = traj[-1, :, :]
        wandb.log({
            f"val_{num_points}{extra}": val(endpoints, data.target_distr.distribution),
            "iteration": reflow_iteration
        })

        if args.target_type != "normal":
            continue
        wandb.log({
            f"w2_{num_points}_{extra}_{direction}": data.w2_distance(endpoints),
            "iteration": reflow_iteration
        })

def evaluate_forward(model, FM, args, data, direction="forward", reflow_iteration=0, epoch=None):
    """Evaluate the model in the forward (source -> target) direction.

    Thin wrapper around ``evaluate_model``. ``FM`` is accepted for signature
    parity with ``evaluate_backward`` and is not used.

    Args:
        model: Vector field to evaluate.
        FM: Unused.
        args: Experiment arguments forwarded to ``evaluate_model``.
        data: Data object forwarded to ``evaluate_model``.
        direction: Label used in the W2 metric key.
        reflow_iteration: Logged as ``iteration``.
        epoch: Logged as ``epoch``.

    Returns:
        Whatever ``evaluate_model`` returns.
    """
    # Keyword arguments are required here: evaluate_model's positional order is
    # (model, args, data, extra, direction, ...), which does not line up with
    # this wrapper's (model, FM, args, data, direction, ...).
    return evaluate_model(
        model,
        args,
        data,
        direction=direction,
        reflow_iteration=reflow_iteration,
        epoch=epoch,
    )

def evaluate_backward(model, FM, args, data, direction="backward", reflow_iteration=0, epoch=None):
    """Evaluate the model in the backward (target -> source) direction.

    Starts at the held-out target samples and integrates the learned ODE from
    t=1 down to t=0 with dopri5 (atol=rtol=1e-3). It then logs ``val`` of the
    endpoints under ``data.source_distr.distribution`` as
    ``val_{n}_{direction}``, for n = 100 and n = 2 points in ``t_span``.

    Args:
        model: Vector field handed to ``NeuralODE``. ``model.cpu()`` is called on it
            (for an ``nn.Module`` this moves the parameters in place).
        FM: Unused; kept for signature parity with ``evaluate_forward``.
        args: Unused.
        data: Exposes ``target_distr.test_values`` and ``source_distr.distribution``.
        direction: Suffix used in the metric key.
        reflow_iteration: Logged as ``iteration``.
        epoch: Unused.
    """
    model = model.cpu()

    # Same gradient-model list as evaluate_model / calculate_straightness: these
    # are integrated with autograd left as-is, everything else under no_grad.
    is_gradient_model = model.__class__.__name__ in ['GMLP', 'EMLP', "IPMLP"]

    # Backward direction: the trajectory starts from the target samples.
    x1 = data.target_distr.test_values

    node = NeuralODE(
        model,
        solver="dopri5",
        sensitivity="adjoint",
        atol=1e-3,
        rtol=1e-3
    )

    for t_max in [100, 2]:
        # Decreasing time span: integrate from t=1 back to t=0.
        t_span = torch.linspace(1, 0, t_max)
        if is_gradient_model:
            traj = node.trajectory(x1, t_span=t_span)
        else:
            with torch.no_grad():
                traj = node.trajectory(x1, t_span=t_span)

        # Negative mean log-likelihood of the endpoints under the source distribution.
        val_metric = val(traj[-1, :, :], data.source_distr.distribution)
        wandb.log({
            f"val_{t_max}_{direction}": val_metric,
            "iteration": reflow_iteration
        })

   

def val(samples, gmm):
    """Return the negative mean log-likelihood of ``samples`` under ``gmm``.

    ``gmm`` can be any object with a ``log_prob`` method (e.g. a
    ``torch.distributions`` distribution). Lower is better.
    """
    log_likelihood = gmm.log_prob(samples)
    return -log_likelihood.mean()


def calculate_straightness(model, args, data):
    """Compute the mean straightness of the flow's trajectories on held-out source samples.

    For every sample, straightness is the endpoint distance divided by the
    discretized path length::

        ||x(1) - x(0)|| / sum_k ||x(t_k) - x(t_{k-1})||

    evaluated on a 20-point grid over [0, 1] with dopri5 (atol=rtol=1e-4).
    A value of 1 means a straight path and smaller values mean more curved
    paths. The per-sample values are averaged. If a state has more than one
    trailing dimension, the norm is taken over the last dimension and then
    averaged over the remaining trailing dimensions.

    Args:
        model: Vector field handed to ``NeuralODE``. ``model.cpu()`` is called on it
            (for an ``nn.Module`` this moves the parameters in place).
        args: Unused; kept for interface compatibility.
        data: Exposes ``source_distr.test_values`` (initial states, sample axis first).

    Returns:
        torch.Tensor: 0-d tensor with the mean straightness, detached from the
        autograd graph. A sample with zero path length contributes NaN.
    """
    x0_samples = data.source_distr.test_values.cpu()
    model = model.cpu()

    node = NeuralODE(
        model,
        solver="dopri5",
        sensitivity="adjoint",
        atol=1e-4,
        rtol=1e-4
    )
    t_span = torch.linspace(0, 1, 20).cpu()

    # Grad mode is left as the caller set it: gradient-based models may need
    # autograd during integration. The metric itself is only a diagnostic, so
    # the trajectory is detached before reducing.
    traj = node.trajectory(x=x0_samples, t_span=t_span).detach()
    n_steps, n_samples = traj.shape[0], traj.shape[1]

    # Per-step displacement norms, shape (T-1, N, *rest) -> (T-1, N) by averaging
    # any trailing dims, then summed over time -> per-sample path length (N,).
    step_norms = torch.norm(traj[1:] - traj[:-1], dim=-1)
    step_norms = step_norms.reshape(n_steps - 1, n_samples, -1).mean(dim=-1)
    path_length = step_norms.sum(dim=0)

    # Straight-line distance between start and end of each trajectory, shape (N,).
    chord = torch.norm(traj[-1] - traj[0], dim=-1).reshape(n_samples, -1).mean(dim=-1)

    return (chord / path_length).mean()


def track_gradient_variance(model, optimizer, FM, device, data, gradient_storage, reflow_iteration=0):
    """Record per-time gradient norms of the flow-matching loss for later variance analysis.

    For each t in {0.0, 0.1, ..., 1.0}, the loss is evaluated on up to 64 paired
    held-out (source, target) samples in chunks of 16. Each chunk is
    back-propagated on its own, its total gradient L2 norm is measured, and the
    gradients are cleared again. The mean of those norms is appended to
    ``gradient_storage[t]``. No optimizer step is taken.

    Args:
        model: Vector field called as ``model(t, xt)``. A ``HybridFlowModel`` is
            expected to return ``(velocity, score)``; any other model returns a
            single tensor. Moved to ``device``.
        optimizer: Used only for ``zero_grad()``.
        FM: Flow matcher providing ``sample_location_and_conditional_flow`` and,
            for hybrid models, ``compute_sigma_t``.
        device: Device on which the loss is evaluated.
        data: Exposes ``source_distr.test_values`` and ``target_distr.test_values``.
        gradient_storage: Mapping from float t to a list. ``gradient_storage[t].append``
            must work for every t (presumably a ``defaultdict(list)``; confirm with the caller).
        reflow_iteration: Unused.

    Returns:
        The same ``gradient_storage`` object, updated in place.
    """
    torch.cuda.empty_cache()

    returns_score = model.__class__.__name__ in ['HybridFlowModel']

    # Deliberately small sizes to keep peak GPU memory low.
    sample_cap = 64
    chunk = 16

    # Slice on CPU; each chunk is moved to the device on its own.
    source_all = data.source_distr.test_values.cpu()
    target_all = data.target_distr.test_values.cpu()
    n = min(source_all.shape[0], target_all.shape[0], sample_cap)
    source, target = source_all[:n], target_all[:n]

    model.to(device)

    def chunk_loss(x0_chunk, x1_chunk, t):
        # Flow-matching regression loss for one chunk (plus the score term for hybrid models).
        _, xt, ut = FM.sample_location_and_conditional_flow(x0_chunk, x1_chunk, t)
        if not returns_score:
            return torch.mean((model(t, xt) - ut) ** 2)
        vt, st = model(t, xt)
        t_col = t.view(-1, 1)  # column vector so t broadcasts over feature dims
        eps = (xt - (1 - t_col) * x0_chunk) / FM.compute_sigma_t(t_col)
        return torch.mean((vt - ut) ** 2) + torch.mean((st + eps) ** 2)

    for step in range(11):
        optimizer.zero_grad()
        t_value = step / 10

        norms = []
        for lo in range(0, n, chunk):
            hi = min(lo + chunk, n)
            x0_chunk = source[lo:hi].to(device)
            x1_chunk = target[lo:hi].to(device)
            t = t_value * torch.ones(hi - lo).to(device)

            loss = chunk_loss(x0_chunk, x1_chunk, t)
            loss.backward()
            norms.append(calculate_gradient_norm(model))

            # Drop references and gradients before the next chunk so memory can be reclaimed.
            del x0_chunk, x1_chunk, t, loss
            optimizer.zero_grad()
            torch.cuda.empty_cache()

        if norms:
            gradient_storage[t_value].append(sum(norms) / len(norms))

        torch.cuda.empty_cache()

    return gradient_storage


def gradient_plot(gradient_storage, reflow_iteration):
    """Log a wandb line chart of gradient-norm variance versus t.

    For each key t of ``gradient_storage`` (in iteration order), the variance
    of the recorded gradient norms is computed with ``np.var``. The chart is
    logged under a key that includes ``reflow_iteration``.

    Args:
        gradient_storage: Mapping t -> sequence of gradient norms.
        reflow_iteration: Used in the chart key and title.
    """
    t_values = []
    var_values = []
    for ti in gradient_storage:
        t_values.append(ti)
        var_values.append(np.var(gradient_storage[ti]))

    label = f"Gradient Variance vs t at {reflow_iteration}"
    chart = wandb.plot.line_series(
        xs=t_values,
        ys=[var_values],
        keys=["Variance"],
        title=label,
        xname="t"
    )
    wandb.log({label: chart})

def lip_log(model):
    """Compute the weight-product Lipschitz estimate, print it, and log it to wandb.

    Note that ``compute_lipschitz_via_weights`` itself already logs the value
    under ``"lip"``; this function logs it a second time under ``"lipshits"``.
    """
    estimate = compute_lipschitz_via_weights(model)
    message = f"Lipschitz constant (weight product): {estimate:.4f}"
    print(message)
    wandb.log({"lipshits": estimate})

    

def plot_combined_trajectories(models, args, data):
    """Overlay the trajectories of every reflow model from the same start points and log the figure.

    Each model k is integrated with dopri5 (atol=rtol=1e-4) over 100 points in
    [0, 1], starting from ``data.source_distr.test_values``. Its paths and final
    points are drawn in a viridis colour. The initial points (black) and a KDE
    of ``data.target_distr.test_values`` (grey) are also drawn. Only the first
    two coordinates are plotted. The figure is logged to wandb as
    ``combined_trajectories`` and then closed.

    Args:
        models: Sequence of vector-field models, one per reflow iteration.
            ``.cpu()`` is called on each.
        args: Unused.
        data: Exposes ``source_distr.test_values`` and ``target_distr.test_values``.
    """
    start_np = data.source_distr.test_values.detach().cpu().numpy()

    plt.figure(figsize=(12, 10))

    # Grey density of the target samples as the background layer.
    background = data.target_distr.test_values.cpu().numpy()
    sns.kdeplot(
        x=background[:, 0],
        y=background[:, 1],
        cmap='Greys',
        fill=True,
        alpha=0.3,
        zorder=0
    )

    palette = plt.cm.viridis(np.linspace(0, 1, len(models)))

    for k, (model, color) in enumerate(zip(models, palette)):
        node = NeuralODE(
            model.cpu(),
            solver="dopri5",
            sensitivity="adjoint",
            atol=1e-4,
            rtol=1e-4
        )
        paths = node.trajectory(
            torch.tensor(start_np).float(),
            t_span=torch.linspace(0, 1, 100)
        ).detach().cpu().numpy()

        # One line per sample; only the first line carries the legend label.
        for i in range(paths.shape[1]):
            plt.plot(
                paths[:, i, 0],
                paths[:, i, 1],
                color=color,
                alpha=0.6,
                linewidth=1,
                label=None if i else f'Reflow {k}'
            )

        if k == 0:
            plt.scatter(paths[0, :, 0], paths[0, :, 1], s=20, c='black', label='Initial Samples')
        plt.scatter(
            paths[-1, :, 0],
            paths[-1, :, 1],
            s=20,
            color=color,
            edgecolor='black',
            label=f'Final (k={k})'
        )

    plt.title("Comparison of Trajectories Across Reflow Iterations")
    plt.xlabel("Dimension 1")
    plt.ylabel("Dimension 2")
    plt.legend()
    plt.grid(alpha=0.3)

    wandb.log({"combined_trajectories": wandb.Image(plt)})
    plt.close()

def compare_reflow_iterations(models, args, gmm):
    """Compare straightness and validation loss across reflow iterations and log them to wandb.

    For each model in ``models`` (index k = reflow iteration), this computes:

    * the straightness metric, and
    * ``val`` under ``gmm`` of the endpoints of 100 standard-normal samples
      pushed through the flow (dopri5, atol=rtol=1e-4, 10 points on [0, 1]).

    Results are logged as a ``wandb.Table`` (``reflow_comparison``) and as a
    two-panel figure (``reflow_comparison_plots``).

    Args:
        models: Sequence of trained vector-field models. ``.cpu()`` is called on each.
        args: Must provide ``dim`` (dimension of the validation samples).
        gmm: Passed both to ``calculate_straightness``, which reads
            ``.source_distr.test_values``, and to ``val``, which calls ``.log_prob``.
            NOTE(review): a plain torch distribution has no ``source_distr``;
            confirm what callers pass here.
    """
    metrics = {
        'iteration': [],
        'straightness': [],
        'val_loss': []
    }

    for k, model in enumerate(models):
        model_cpu = model.cpu()

        # Same gradient-model list as evaluate_model / calculate_straightness.
        is_gradient_model = model_cpu.__class__.__name__ in ['GMLP', 'EMLP', "IPMLP"]

        straightness = calculate_straightness(model_cpu, args, gmm)

        # NOTE(review): the other evaluators in this file hand the model to
        # NeuralODE directly; confirm these models expect the torch_wrapper convention.
        node = NeuralODE(
            torch_wrapper(model_cpu), solver="dopri5", sensitivity="adjoint", atol=1e-4, rtol=1e-4
        )

        # Validation start points drawn from a standard normal.
        val_x0 = torch.normal(0, 1, (100, args.dim))

        if is_gradient_model:
            traj = node.trajectory(val_x0, t_span=torch.linspace(0, 1, 10))
        else:
            with torch.no_grad():
                traj = node.trajectory(val_x0, t_span=torch.linspace(0, 1, 10))

        val_loss = val(traj[-1, :, :], gmm)

        # Store plain floats: tensors that still require grad (from gradient models)
        # cannot be converted by matplotlib, and floats serialize cleanly in wandb.Table.
        metrics['iteration'].append(k)
        metrics['straightness'].append(float(straightness))
        metrics['val_loss'].append(float(val_loss))

    wandb.log({"reflow_comparison": wandb.Table(
        data=[[i, s, v] for i, s, v in zip(
            metrics['iteration'],
            metrics['straightness'],
            metrics['val_loss']
        )],
        columns=["Reflow Iteration", "Straightness", "Validation Loss"]
    )})

    fig, ax = plt.subplots(1, 2, figsize=(15, 6))

    ax[0].plot(metrics['iteration'], metrics['straightness'], 'o-')
    ax[0].set_title("Straightness vs Reflow Iteration")
    ax[0].set_xlabel("Reflow Iteration")
    ax[0].set_ylabel("Straightness Metric")

    ax[1].plot(metrics['iteration'], metrics['val_loss'], 'o-')
    ax[1].set_title("Validation Loss vs Reflow Iteration")
    ax[1].set_xlabel("Reflow Iteration")
    ax[1].set_ylabel("Validation Loss")

    wandb.log({"reflow_comparison_plots": wandb.Image(fig)})
    # Close this specific figure rather than whatever happens to be current.
    plt.close(fig)



def old_plot_the_mixture_density(x0, standardized_samples, t_max, node, gmm, reflow_iteration=0, is_gradient_model=False):
    """Integrate ``x0`` with ``node`` and log the validation metric to wandb (legacy helper).

    Args:
        x0: Initial states passed to ``node.trajectory``.
        standardized_samples: Unused (previously used for a KDE background plot).
        t_max: Number of points in the time span over [0, 1].
        node: Object with a ``trajectory(x, t_span=...)`` method, e.g. a torchdyn NeuralODE.
        gmm: Object with ``log_prob``; the metric is ``val(final states, gmm)``.
        reflow_iteration: Used in the image key and logged as ``iteration``.
        is_gradient_model: If True, integrate with autograd left as-is;
            otherwise integrate under ``torch.no_grad()``.
    """
    t_span = torch.linspace(0, 1, t_max)
    if is_gradient_model:
        traj = node.trajectory(x0, t_span=t_span)
    else:
        with torch.no_grad():
            traj = node.trajectory(x0, t_span=t_span)

    # Metrics are logged for both model types, not only the no_grad path.
    # NOTE(review): nothing is drawn before wandb.Image(plt) (the plotting code
    # was disabled), so the logged image is whatever the current matplotlib
    # figure is. Restore the plot or drop the image key.
    val_metric = val(traj[-1, :, :], gmm)
    wandb.log({
        f"trajectory_density_k{reflow_iteration}": wandb.Image(plt),
        "val": val_metric,
        "t_max": t_max,
        "iteration": reflow_iteration
    })

def x1(args, gmm):
    """Draw one batch of ``args.batch_size`` samples from ``gmm`` via ``gmm.sample``."""
    sample_shape = (args.batch_size,)
    return gmm.sample(sample_shape)