import argparse
import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import wandb
import math
import os
import time
import matplotlib.pyplot as plt
import numpy as np
import ot as pot
import torchdyn
from torchdyn.core import NeuralODE
from torchdyn.datasets import generate_moons
from torchcfm.conditional_flow_matching import *
from torchcfm.models.models import *
from torchcfm.utils import *
from torch.distributions import Categorical, MultivariateNormal, MixtureSameFamily, Independent, Normal
from validation_utils import *
import copy
from torch import nn

class IPCCFM(nn.Module):
    """Velocity field defined as the input-gradient of a learned scalar energy.

    An MLP ``s = NN([x, t])`` produces one ``dim``-vector per sample. The energy
    of sample ``b`` is the inner product ``<x_b, s_b>``, and ``forward`` returns
    ``d energy / d x``. The gradient is built with ``create_graph=True`` so the
    returned velocity can itself be backpropagated to the MLP parameters.

    Args:
        dim: data dimension.
        w: hidden width of the MLP (three SELU-activated hidden layers).
    """

    def __init__(self, dim, w=32):
        super().__init__()
        # Input is the concatenation [x, t], hence dim + 1 features.
        self.NN = nn.Sequential(
            torch.nn.Linear(dim + 1, w),
            torch.nn.SELU(),
            torch.nn.Linear(w, w),
            torch.nn.SELU(),
            torch.nn.Linear(w, w),
            torch.nn.SELU(),
            torch.nn.Linear(w, dim),
        )

    def forward(self, t, x, *args, **kwargs):
        """Return the velocity at time ``t`` and state ``x``.

        Args:
            t: a single time value, or one time value per row of ``x``.
            x: states of shape ``(B, dim)``.
            *args, **kwargs: ignored.

        Returns:
            Tensor of shape ``(B, dim)``: gradient of the per-sample energy
            ``<x_b, s_b>`` with respect to ``x``.
        """
        # The velocity is itself a gradient, so autograd must be active even
        # when the caller runs under torch.no_grad() (e.g. inside an ODE solver).
        with torch.enable_grad():
            if not x.requires_grad:
                # autograd.grad needs x to be part of the graph.
                x = x.detach().requires_grad_(True)
            t_ = t.view(-1, 1).expand(x.size(0), 1).to(x.dtype)
            s = self.NN(torch.cat([x, t_], dim=1))
            # Per-sample energy <x_b, s_b>, shape (B, 1). The previous
            # ``(x.T @ s).sum(dim=1, keepdim=True)`` produced a (dim, 1) tensor,
            # i.e. the energy (sum_i x_i) * (sum_j s_j) instead of <x, s>.
            energy = (x * s).sum(dim=1, keepdim=True)
            grad_x = torch.autograd.grad(
                outputs=energy,
                inputs=x,
                grad_outputs=torch.ones_like(energy),
                create_graph=True,
                retain_graph=True,
            )[0]

        return grad_x

# Seed both RNGs used by this script so runs are repeatable, then pick the
# compute device (GPU when available).
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")



def generate_pairs(model, reflow_x1_samples, num_pairs, args, device):
    """Build an ``(x0, x1)`` coupling by pulling ``reflow_x1_samples`` back through ``model``.

    ``num_pairs`` and ``args`` are accepted for interface compatibility but
    are not used.

    Returns:
        Tuple ``(x0, x1)`` of detached tensors.
    """
    pulled_back = generate_x0_from_x1(model, reflow_x1_samples, device).detach()
    targets = reflow_x1_samples.detach()
    return pulled_back, targets


def generate_x0_from_x1(model, x1_samples, device):
    """Map samples at t=1 back to t=0 by integrating ``model`` backwards in time.

    The ODE is solved on the CPU with torchdyn's ``NeuralODE`` (dopri5, adjoint
    sensitivity, atol=rtol=1e-4) over 100 evenly spaced times from 1 to 0.

    Args:
        model: velocity field ``model(t, x)``. It is moved to the CPU in place
            for the solve and moved to ``device`` afterwards (also on failure).
        x1_samples: samples at t=1. The caller's tensor is not modified.
        device: device the model and the result are placed on afterwards.

    Returns:
        Detached tensor holding the state at t=0, on ``device``.
    """
    # nn.Module.cpu() moves parameters in place and returns the same module.
    model_cpu = model.cpu()
    # detach() yields a new tensor handle, so enabling requires_grad does not
    # mutate the caller's tensor. Previously, when x1_samples was already on
    # the CPU, .cpu() returned the same object and its requires_grad flag was
    # flipped. requires_grad is needed because IPCCFM.forward differentiates
    # w.r.t. its input.
    x1_cpu = x1_samples.detach().cpu().requires_grad_(True)

    try:
        node = NeuralODE(model_cpu, solver="dopri5", sensitivity="adjoint", atol=1e-4, rtol=1e-4)
        # Run the ODE backwards, from t=1 to t=0.
        traj = node.trajectory(
            x1_cpu,
            t_span=torch.linspace(1, 0, 100),
        )
        x0_samples = traj[-1]  # state at the last entry of t_span, i.e. t=0
    finally:
        # Return the model to its device even if the solve raises.
        model.to(device)

    return x0_samples.detach().to(device)




def train_epoch(model, optimizer, FM, args, device, gmm, epoch, x0_samples, x1_samples, reflow_iteration=0):
    """Run one flow-matching optimization step and log it to wandb.

    Args:
        model: velocity field ``model(t, x)`` being trained.
        optimizer: optimizer over ``model``'s parameters.
        FM: flow matcher providing ``sample_location_and_conditional_flow``.
        args: parsed CLI args; ``batch_size`` and ``dim`` are read when fresh
            samples are drawn.
        device: device for the model and for freshly drawn samples.
        gmm: target distribution, passed to ``x1(args, gmm)``.
        epoch: step index, used only for logging.
        x0_samples: fixed source samples to train on, or ``False`` (``None`` is
            also accepted) to draw x0 ~ N(0, I) and x1 from ``x1(args, gmm)``.
        x1_samples: target samples paired with ``x0_samples`` (ignored when
            fresh samples are drawn).
        reflow_iteration: logged under the ``"iteration"`` key.
    """
    optimizer.zero_grad()
    # NOTE(review): anomaly detection is a global and expensive debugging
    # switch; it is (re)enabled on every call here.
    torch.autograd.set_detect_anomaly(True)

    # Sample randomly from initial and target distributions
    if x0_samples is False or x0_samples is None:
        x0 = torch.normal(0, 1, (args.batch_size, args.dim)).to(device)
        x1_pairs = x1(args, gmm).to(device)
    else:
        x0 = x0_samples
        x1_pairs = x1_samples

    model.to(device)

    t, xt, ut = FM.sample_location_and_conditional_flow(x0, x1_pairs)

    # The model differentiates w.r.t. its input, so xt must be a leaf that
    # requires grad. Assigning ``xt.requires_grad = True`` raises when xt is a
    # non-leaf (i.e. when x0 or x1 already require grad); detach first.
    xt = xt.detach().requires_grad_(True)
    vt = model(t, xt)

    loss = torch.mean((vt - ut) ** 2)

    loss.backward(retain_graph=False)

    total_norm = calculate_gradient_norm(model)

    wandb.log({
        "epoch": epoch,
        "gradient_norm": total_norm,
        "loss": loss.item(),
        "iteration": reflow_iteration
    })

    optimizer.step()



def main(args):
    """Train a 1-rectified flow, then ``args.kreflow`` reflow iterations, and save all models.

    The target is a randomly generated mixture of ``args.components`` diagonal
    Gaussians in ``args.dim`` dimensions. Progress, evaluations and gradient
    statistics are logged to wandb. All models (index 0 = initial flow,
    index k = k-th reflow) are saved together as a dict of state dicts.

    Args:
        args: parsed command-line arguments (see the ``__main__`` block).
    """
    # Random GMM target. NOTE(review): torch.rand scales lie in [0, 1), so a
    # component scale can be arbitrarily close to 0.
    comp = Independent(Normal(torch.randn(args.components, args.dim), torch.rand(args.components, args.dim)), 1)
    mix = Categorical(torch.ones(args.components))
    gmm = MixtureSameFamily(mix, comp)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Put the model on its device before building the optimizer over it.
    model = IPCCFM(dim=args.dim, w=args.w).to(device)
    models = []

    # Fixed batch used to track gradient variance during training.
    gradient_x0 = torch.normal(0, 1, (args.batch_size, args.dim)).to(device)
    gradient_x1 = gmm.sample((args.batch_size,)).to(device)

    # Unused below, but this draw is kept so the global RNG stream (and thus
    # every later sample) matches earlier runs with the same seed.
    traj_check = torch.normal(0, 1, (100, args.dim))

    print(f"Using device: {device}")

    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    # NOTE(review): args.sigma only appears in the run name and save path; the
    # flow matcher itself is built with sigma=0.
    FM = ConditionalFlowMatcher(sigma=0)

    wandb.init(project="test_ccfm",
               name=f"{args.method}_dim_{args.dim}_w_{args.w}_sigma_{args.sigma}_kreflow_{args.kreflow}",
               config=args)

    # Train initial 1-rectified flow model
    print("Training 1-rectified flow model...")
    gradient_storage = {ti / 10: [] for ti in range(11)}
    for epoch in range(args.num_epochs):
        train_epoch(model, optimizer, FM, args, device, gmm, epoch, False, False, reflow_iteration=0)
        if epoch % 100 == 0:
            compute_lipschitz_via_weights(model)
            evaluate_model(model, args, gmm, reflow_iteration=0, epoch=epoch)
        # Every 10th epoch in the last quarter of training. The condition
        # previously lacked ``== 0`` and so selected the other 9 of every 10.
        if epoch % 10 == 0 and 4 * epoch > 3 * args.num_epochs:
            gradient_storage = track_gradient_variance(model, optimizer, FM, device, gradient_x0, gradient_x1, gradient_storage, reflow_iteration=0)

    gradient_plot(gradient_storage, reflow_iteration=0)
    # Store the initial model
    models.append(model)

    # Reflow iterations
    current_model = model
    for k in range(1, args.kreflow + 1):
        print(f"Starting reflow iteration {k}...")
        gradient_storage = {ti / 10: [] for ti in range(11)}

        new_model = copy.deepcopy(current_model).to(device)
        new_optimizer = torch.optim.Adam(new_model.parameters(), lr=args.lr)
        # x0 for the fixed gradient-tracking batch depends only on
        # current_model, which is not trained in this iteration. Solve that
        # ODE once, on first use, instead of every 10 epochs.
        reflow_gradient_x0 = None

        # Train the k-rectified flow
        for epoch in range(args.re_epoch):
            x1_samples = x1(args, gmm).to(device)  # Sample from target distribution
            x0_samples = generate_x0_from_x1(current_model, x1_samples, device)
            train_epoch(new_model, new_optimizer, FM, args, device, gmm, epoch, x0_samples, x1_samples, reflow_iteration=k)

            # Periodically evaluate during reflow training
            if epoch % 100 == 0:
                evaluate_model(new_model, args, gmm, reflow_iteration=k, epoch=epoch)
                compute_lipschitz_via_weights(new_model)

            if epoch % 10 == 0 and 4 * epoch > 3 * args.re_epoch:
                if reflow_gradient_x0 is None:
                    reflow_gradient_x0 = generate_x0_from_x1(current_model, gradient_x1, device)
                # clone() so each call still receives its own tensor. Log under
                # the current reflow iteration k (was hard-coded to 0).
                gradient_storage = track_gradient_variance(new_model, new_optimizer, FM, device, reflow_gradient_x0.clone(), gradient_x1, gradient_storage, reflow_iteration=k)

        models.append(new_model)
        current_model = new_model

        # Final evaluation of k-rectified flow
        evaluate_model(new_model, args, gmm, reflow_iteration=k)
        gradient_plot(gradient_storage, reflow_iteration=k)

    plot_combined_trajectories(models, args, gmm)
    model_dict = {k: m.state_dict() for k, m in enumerate(models)}
    save_path = f"src/train_synthetic/saved_models/{args.method}_dim_{args.dim}_w_{args.w}_sigma_{args.sigma}_kreflow{args.kreflow}.pt"
    # torch.save does not create missing parent directories.
    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    torch.save(model_dict, save_path)
    print(f"Saved all models to {save_path}")


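# # Reconstruct models from a saved checkpoint, e.g. model_dict = torch.load(save_path).
# # The checkpoint holds IPCCFM state dicts, so rebuild IPCCFM (not MLP).
# loaded_models = []
# for k in range(len(model_dict)):
#     model = IPCCFM(dim=args.dim, w=args.w).to(device)
#     model.load_state_dict(model_dict[k])
#     model.eval()
#     loaded_models.append(model)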
    


if __name__ == "__main__":
    # Command-line interface; every option is forwarded to main() via `args`.
    parser = argparse.ArgumentParser(description="Train and log gradient statistics")
    parser.add_argument("--dim", type=int, default=2, help="Dimension of the data")
    parser.add_argument("--components", type=int, default=4, help="Number of Gaussian mixture components")
    parser.add_argument("--sigma", type=float, default=0.1,
                        help="Sigma label used in the run name and save path (main() builds ConditionalFlowMatcher with sigma=0)")
    parser.add_argument("--batch_size", type=int, default=64, help="Batch size")
    parser.add_argument("--num_batches", type=int, default=200, help="Number of batches")
    parser.add_argument("--num_epochs", type=int, default=50000,
                        help="Number of epochs (optimizer steps) for the initial 1-rectified flow")
    parser.add_argument("--re_epoch", type=int, default=10000,
                        help="Number of epochs (optimizer steps) per reflow iteration")

    parser.add_argument("--lr", type=float, default=0.0001, help="Learning rate")
    parser.add_argument("--w", type=int, default=16, help="Hidden width of the IPCCFM network")
    parser.add_argument("--time_var", type=int, default=1)
    parser.add_argument("--method", type=str, default="ccfm",
                        help="Method label used in the wandb run name and save path")
    parser.add_argument("--kreflow", type=int, default=4, help="Number of reflow iterations")

    args = parser.parse_args()

    main(args)