import argparse
import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import wandb
import math
import os
import time
import matplotlib.pyplot as plt
import numpy as np
import ot as pot
import torchdyn
from torchdyn.core import NeuralODE
from torchdyn.datasets import generate_moons
from torchcfm.conditional_flow_matching import *
from torchcfm.utils import *
from torch.distributions import Categorical, MultivariateNormal, MixtureSameFamily, Independent, Normal
from validation_utils import *
import copy
from data import Synthetic
from nn import WMLP, IPMLP
SEED = 42  # one seed shared by every RNG seeded below
np.random.seed(SEED)
torch.manual_seed(SEED)
# Prefer the GPU whenever PyTorch can see one.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


def plot_combined_trajectories(models, args, data):
    """Plot combined trajectories from all reflow iterations using the same x0.

    Every model in ``models`` is integrated with a dopri5 ``NeuralODE`` from
    t=0 to t=1, always starting from the first 100 test samples of the source
    distribution. The trajectories are drawn over a KDE of the target test
    samples, coloured per reflow iteration, and the figure is logged to wandb
    under ``"combined_trajectories"``.

    Args:
        models: sequence of vector-field modules called as ``model(t, x)``,
            one per reflow iteration (index ``k`` is used for colour/labels).
        args: run configuration (not read here; kept for interface parity).
        data: object exposing ``source_distr.test_values`` and
            ``target_distr.test_values`` tensors.
    """
    # Clear memory before plotting
    torch.cuda.empty_cache()

    # Set the matplotlib backend to 'Agg' for non-interactive environments like SLURM
    import matplotlib
    matplotlib.use('Agg')

    max_trajectories = 100  # limit the number of samples to reduce memory usage
    x0 = data.source_distr.test_values[:max_trajectories].detach().cpu().numpy()
    # Built once and shared by all models so every trajectory starts from the same points.
    x0_tensor = torch.tensor(x0).float()
    t_span = torch.linspace(0, 1, 100)

    fig = plt.figure(figsize=(12, 10))

    # Plot target distribution background
    standardized_samples = data.target_distr.test_values.cpu().numpy()
    sns.kdeplot(
        x=standardized_samples[:, 0],
        y=standardized_samples[:, 1],
        cmap='Greys',
        alpha=0.3,
        zorder=0,
        fill=True
    )

    # One colour per reflow iteration
    colors = plt.cm.viridis(np.linspace(0, 1, len(models)))

    for k, model in enumerate(models):
        torch.cuda.empty_cache()
        model_cpu = model.cpu()
        node = NeuralODE(
            model_cpu,
            solver="dopri5",
            sensitivity="adjoint",
            atol=1e-4,
            rtol=1e-4
        )

        # Pure inference: without no_grad the solver would record an autograd
        # graph across every integration step, wasting memory for nothing.
        with torch.no_grad():
            traj = node.trajectory(x0_tensor, t_span=t_span)
        traj = traj.detach().cpu().numpy()  # (time, sample, dim) for plotting

        for i in range(min(traj.shape[1], max_trajectories)):
            plt.plot(
                traj[:, i, 0],
                traj[:, i, 1],
                color=colors[k],
                alpha=0.6,
                linewidth=1,
                label=f'Reflow {k}' if i == 0 else None  # Add label only once per reflow iteration
            )

        # Start points are shared by all models, so draw them only once.
        if k == 0:
            plt.scatter(
                traj[0, :max_trajectories, 0],
                traj[0, :max_trajectories, 1],
                s=20,
                c='black',
                label='Initial Samples'
            )
        plt.scatter(
            traj[-1, :max_trajectories, 0],
            traj[-1, :max_trajectories, 1],
            s=20,
            color=colors[k],
            edgecolor='black',
            label=f'Final (k={k})'
        )

        # Free memory
        del traj, node, model_cpu
        torch.cuda.empty_cache()

    plt.title("Comparison of Trajectories Across Reflow Iterations")
    plt.xlabel("Dimension 1")
    plt.ylabel("Dimension 2")
    plt.legend()
    plt.grid(alpha=0.3)

    # Log the explicit figure object and close exactly that figure.
    wandb.log({"combined_trajectories": wandb.Image(fig)})
    plt.close(fig)


class HybridFlowModel(torch.nn.Module):
    """Two time-conditioned ``WMLP`` networks sharing one ``(t, x)`` input.

    ``forward`` evaluates the velocity network first, then the score network,
    and returns their outputs as the tuple ``(velocity, score)``.
    """

    def __init__(self, dim=2, w=64):
        super().__init__()
        # Same architecture for both heads; construction order is velocity, then score.
        self.velocity_net = WMLP(dim=dim, w=w, time_varying=True)
        self.score_net = WMLP(dim=dim, w=w, time_varying=True)

    def forward(self, t, x, *args, **kwargs):
        velocity = self.velocity_net(t, x)
        score = self.score_net(t, x)
        return velocity, score


class BFromVS(torch.nn.Module):
    """Turn a velocity model ``v`` and a score model ``s`` into a drift model ``b``.

    ``b(t, x) = v(t, x) - gg_dot(t) * s(t, x)`` with
    ``gg_dot(t) = (sigma**2 / 2) * (1 - 2t)``, which is ``gamma(t) * gamma'(t)``
    for ``gamma(t)**2 = sigma**2 * t * (1 - t)``.
    If one-sided interpolation, gg_dot should be replaced with alpha*alpha_dot.

    Args:
        v: callable ``v(t, x)`` returning the velocity.
        s: callable ``s(t, x)`` returning the score (same shape as ``v``'s output).
        sigma: noise scale used in ``gg_dot``.
    """

    def __init__(self, v, s, sigma=1.0) -> None:
        super().__init__()
        self.v = v
        self.s = s
        self.sigma = sigma

    def gg_dot(self, t):
        """Return ``(sigma**2 / 2) * (1 - 2t)``; ``t`` may be a float or tensor."""
        # A method (not a lambda attribute) keeps the module picklable.
        return (self.sigma ** 2 / 2) * (1 - 2 * t)

    def forward(self, t, x, *args, **kwargs):
        # Extra positional/keyword args (e.g. from ODE solvers) are ignored.
        return self.v(t, x) - self.gg_dot(t) * self.s(t, x)

def track_gradient_variance(model, optimizer, FM, device, data, gradient_storage, reflow_iteration):
    """Record the mean per-batch gradient norm of the hybrid loss at fixed times.

    For each ``t`` in ``{0.0, 0.1, ..., 1.0}``, the hybrid loss is evaluated on at
    most 500 test pairs in mini-batches of 100. The hybrid loss is the flow-matching
    MSE plus the score/noise MSE. After each batch's backward pass the gradient
    norm is taken with ``calculate_gradient_norm(model)``. The average over batches
    is appended to ``gradient_storage[t]``.

    Gradients are zeroed before and after every batch and no optimizer step is
    taken, so model parameters are left unchanged.

    Args:
        model: module returning ``(velocity, score)`` from ``model(t, x)``.
        optimizer: optimizer over ``model``'s parameters (only ``zero_grad`` is used).
        FM: conditional flow matcher providing ``sample_location_and_conditional_flow``.
        device: device on which the forward/backward passes run.
        data: object exposing ``source_distr.test_values`` / ``target_distr.test_values``.
        gradient_storage: dict keyed by ``ti / 10`` (ti = 0..10) holding lists.
        reflow_iteration: current reflow index (unused; kept for interface parity).

    Returns:
        ``gradient_storage``, mutated in place.
    """
    # Clear CUDA cache before starting
    torch.cuda.empty_cache()

    max_samples = 500  # limit the number of samples to reduce memory usage
    batch_size = 100   # small enough to avoid OOM

    # Keep the full test sets on CPU; move only each batch to the device.
    all_x0_samples = data.source_distr.test_values.cpu()
    all_x1_samples = data.target_distr.test_values.cpu()

    num_samples = min(all_x0_samples.shape[0], max_samples)
    x0_samples = all_x0_samples[:num_samples]
    x1_samples = all_x1_samples[:num_samples]

    model.to(device)

    for ti in range(0, 11):
        optimizer.zero_grad()
        t_value = ti / 10  # must match the keys built by the caller (ti/10)

        batch_norms = []

        for b, start_idx in enumerate(range(0, num_samples, batch_size)):
            end_idx = min(start_idx + batch_size, num_samples)

            batch_x0 = x0_samples[start_idx:end_idx].to(device)
            batch_x1 = x1_samples[start_idx:end_idx].to(device)

            t = t_value * torch.ones(end_idx - start_idx).to(device)

            # Take the exact Gaussian noise used to build xt from the matcher.
            # Reconstructing it as (xt - (1-t)*x0)/sigma_t drops the t*x1 part
            # of the interpolant mean. It also ignores any re-pairing the matcher
            # may apply (TODO confirm for the OT/SB sampler). Finally, it divides
            # by sigma_t = 0 at t = 0 and t = 1, yielding NaN/inf norms there.
            _, xt, ut, eps = FM.sample_location_and_conditional_flow(
                batch_x0, batch_x1, t, return_noise=True
            )

            vt, st = model(t, xt)

            flow_loss = torch.mean((vt - ut) ** 2)
            # The score head is trained to predict -eps.
            score_loss = torch.mean((st + eps) ** 2)
            total_loss = flow_loss + score_loss

            total_loss.backward()

            batch_norms.append(calculate_gradient_norm(model))

            # Free memory
            del batch_x0, batch_x1, t, xt, ut, vt, st, eps, flow_loss, score_loss, total_loss

            # Don't accumulate gradients between batches
            optimizer.zero_grad()

            if b % 5 == 0:  # every 5 batches
                torch.cuda.empty_cache()

        if batch_norms:
            avg_norm = sum(batch_norms) / len(batch_norms)
            gradient_storage[t_value].append(avg_norm)

        torch.cuda.empty_cache()

    return gradient_storage

def train_epoch(model, optimizer, FM, args, device, epoch, x0, x1, reflow_iteration=0):
    """Run one optimisation step of the hybrid flow + score objective on one batch.

    The loss is ``mean((v - u_t)**2) + mean((s + eps)**2)``. Here ``u_t`` is the
    conditional flow target, ``eps`` is the noise used to build ``x_t`` (both come
    from ``FM``), and ``(v, s) = model(t, x_t)``. The gradient norm, loss, epoch
    and reflow iteration are logged to wandb before ``optimizer.step()``.

    Args:
        model: module returning ``(velocity, score)`` from ``model(t, x)``.
        optimizer: optimizer over ``model``'s parameters.
        FM: conditional flow matcher providing ``sample_location_and_conditional_flow``.
        args: run configuration (not read here).
        device: device the model and batch are moved to.
        epoch: step counter, logged as ``"epoch"``.
        x0: batch of source samples.
        x1: batch of target samples.
        reflow_iteration: reflow index, logged as ``"iteration"``.
    """
    optimizer.zero_grad()
    x0 = x0.to(device)
    x1 = x1.to(device)
    # Honour the caller's device instead of hard-coding .cuda(), which fails
    # on CPU-only machines. Module.to() moves parameters in place.
    model.to(device)

    # return_noise=True yields the exact noise used to form xt. Rebuilding it
    # from x0 would miss the t*x1 term of the interpolant mean and any
    # re-pairing done inside the matcher (TODO confirm for the OT/SB sampler).
    t, xt, ut, eps = FM.sample_location_and_conditional_flow(x0, x1, return_noise=True)

    vt, st = model(t, xt)

    flow_loss = torch.mean((vt - ut) ** 2)
    # The score head is trained to predict -eps.
    score_loss = torch.mean((st + eps) ** 2)

    total_loss = flow_loss + score_loss
    total_loss.backward()
    #torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

    total_norm = calculate_gradient_norm(model)

    wandb.log({
        "epoch": epoch,
        "gradient_norm": total_norm,
        "loss": total_loss.item(),
        "iteration": reflow_iteration
    })

    optimizer.step()


def _snapshot_drift(model, sigma):
    """Wrap a frozen CPU copy of ``model`` as a ``BFromVS`` drift.

    The copy is bound through lambda default arguments. Later training of
    ``model`` therefore cannot change the returned drift, nor can rebinding a
    loop variable (late-binding closures), so each reflow iteration keeps its
    own weights.
    """
    frozen = copy.deepcopy(model).cpu()
    return BFromVS(
        v=lambda t, x, m=frozen: m(t, x)[0],  # t first, x second
        s=lambda t, x, m=frozen: m(t, x)[1],
        sigma=sigma,
    )


def main(args):
    """Train a hybrid velocity/score flow and apply ``args.kreflow`` reflow steps.

    Iteration 0 trains for ``args.num_epochs`` steps. Each later iteration
    trains for ``args.re_epoch`` steps on ``data.all_pairs`` after
    ``data.reflow`` has been called with the previous drift. Every step is one
    mini-batch. Each iteration's drift is snapshotted, evaluated, and used for
    the final combined-trajectory plot.
    """
    data = Synthetic(args)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Use the configured architecture so the model matches the logged run name.
    model = HybridFlowModel(dim=args.dim, w=args.w)

    FM = SchrodingerBridgeConditionalFlowMatcher(sigma=args.sigma)

    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    models = []
    wandb.init(project="albergo_reflow_",
               name=f"{args.method}_dim_{args.dim}_w_{args.w}_sigma_{args.sigma}_kreflow_{args.kreflow}",
               config=args)

    # Reflow iterations
    for k in range(0, args.kreflow + 1):
        dataloader = torch.utils.data.DataLoader(
            data.all_pairs,
            batch_size=args.batch_size,
            shuffle=True,
            drop_last=True
        )
        if len(dataloader) == 0:
            # With drop_last=True an undersized dataset yields no batches and
            # the step loop below would never terminate.
            raise ValueError(
                f"reflow iteration {k}: dataset smaller than batch_size={args.batch_size}"
            )
        print(f"Starting reflow iteration {k}...")
        gradient_storage = {ti/10: [] for ti in range(11)}

        # Train the k-rectified flow
        epochs = args.num_epochs if k == 0 else args.re_epoch
        epoch = 0
        while epoch < epochs:
            for x0_batch, x1_batch in dataloader:
                train_epoch(model, optimizer, FM, args, device, epoch, x0_batch, x1_batch, reflow_iteration=k)

                # Periodically evaluate during training on a frozen snapshot.
                if epoch % 100 == 0:
                    evaluate_model(_snapshot_drift(model, args.sigma), args, data, reflow_iteration=k, epoch=epoch)

                # Track gradient statistics over the last quarter of this iteration's run.
                if epoch % 10 == 0 and 4 * epoch > 3 * epochs:
                    gradient_storage = track_gradient_variance(model, optimizer, FM, device, data, gradient_storage, reflow_iteration=k)
                epoch += 1
                # Stop exactly at the step budget instead of finishing the pass.
                if epoch >= epochs:
                    break

        # Snapshot so this entry is not changed by later reflow training.
        b = _snapshot_drift(model, args.sigma)
        data.reflow(b, args)

        models.append(b)
        # Final evaluation of k-rectified flow
        evaluate_model(b, args, data, reflow_iteration=k)
        gradient_plot(gradient_storage, reflow_iteration=k)

    plot_combined_trajectories(models, args, data)


    


# if __name__ == "__main__":
    # parser = argparse.ArgumentParser(description="Train and log gradient statistics")
    # parser.add_argument("--dim", type=int, default=2, help="Dimension of the data")
    # parser.add_argument("--components", type=int, default=4, help="Number of Gaussian medians")
    # parser.add_argument("--sigma", type=float, default=0.01, help="Sigma value for ConditionalFlowMatcher")
    # parser.add_argument("--batch_size", type=int, default=32, help="Batch size")
    # parser.add_argument("--num_batches", type=int, default=200, help="Number of batches")
    # parser.add_argument("--num_epochs", type=int, default=50000, help="Number of epochs")
    # parser.add_argument("--re_epoch", type=int, default=20000, help="Number of epochs")
    # parser.add_argument("--source_type", type=str, default="standard", help="Number of epochs")
    # parser.add_argument("--target_type", type=str, default="gmm", help="Number of epochs")

    # parser.add_argument("--lr", type=float, default=0.0001, help="Learning rate")
    # parser.add_argument("--w", type=int, default=32)
    # parser.add_argument("--time_var", type=int, default=1)
    # parser.add_argument("--method", type=str, default="ccfm")
    # parser.add_argument("--kreflow", type=int, default=4, help="Number of reflow iterations")

    # args = parser.parse_args()

    # main(args)