import argparse
import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import wandb
import math
import os
import time
import matplotlib.pyplot as plt
import numpy as np
import ot as pot
import torchdyn
from torchdyn.core import NeuralODE
from torchdyn.datasets import generate_moons
from torchcfm.conditional_flow_matching import *
from torchcfm.utils import *
from torch.distributions import Categorical, MultivariateNormal, MixtureSameFamily, Independent, Normal
from validation_utils import *
import copy
from data import Synthetic
from nn import WMLP, IPMLP, GMLP, EMLP  
import torch.nn as nn
from albergo import HybridFlowModel, BFromVS
# Seed both RNG sources used by this script (NumPy and PyTorch) with the same
# fixed value, then select the accelerator when one is visible to PyTorch.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


def train_epoch(model, optimizer, FM, args, device, epoch, x0, x1, reflow_iteration=0, use_hybrid=False):
    """Run one optimisation step on a single batch and log the result to wandb.

    Despite its name, each call performs exactly one forward/backward pass and
    one ``optimizer.step()`` on the batch ``(x0, x1)``.

    Args:
        model: Callable as ``model(t, xt)``. Returns the predicted velocity, or
            a ``(velocity, score)`` pair when ``use_hybrid`` is true.
        optimizer: Optimizer whose gradients are zeroed and then stepped here.
        FM: Flow matcher exposing ``sample_location_and_conditional_flow``.
        args: Unused here; kept so existing callers keep working.
        device: Device the batch and the model are moved to.
        epoch: Step counter, logged under the ``"epoch"`` key.
        x0: Batch of source samples.
        x1: Batch of target samples.
        reflow_iteration: Reflow round, logged under the ``"iteration"`` key.
        use_hybrid: When true, train velocity and score heads jointly.
    """
    optimizer.zero_grad()
    x0 = x0.to(device)
    x1 = x1.to(device)
    model = model.to(device)

    if use_hybrid:
        # Ask the matcher for the exact noise it injected into xt.  The former
        # reconstruction `(xt - (1 - t) * x0) / sigma_t` omitted the `t * x1`
        # part of the interpolant mean and divided by sigma_t, which vanishes
        # at t == 0.
        # NOTE(review): torchcfm's Schrodinger-bridge matcher is understood to
        # re-pair x0/x1 internally, which would also invalidate any target
        # rebuilt from the caller's x0 -- confirm against the installed version.
        t, xt, ut, eps = FM.sample_location_and_conditional_flow(x0, x1, return_noise=True)

        vt, st = model(t, xt)
        velocity_loss = torch.mean((vt - ut) ** 2)
        # The score head is regressed onto the negated noise (unchanged).
        score_loss = torch.mean((st + eps) ** 2)
        loss = velocity_loss + score_loss
    else:
        # Single-network model: plain flow-matching regression on the velocity.
        t, xt, ut = FM.sample_location_and_conditional_flow(x0, x1)
        vt = model(t, xt)
        loss = torch.mean((vt - ut) ** 2)

    loss.backward()
    total_norm = calculate_gradient_norm(model)

    log_dict = {
        "epoch": epoch,
        "gradient_norm": total_norm,
        "loss": loss.item(),
        "iteration": reflow_iteration,
    }
    if use_hybrid:
        log_dict.update({
            "velocity_loss": velocity_loss.item(),
            "score_loss": score_loss.item(),
        })
    wandb.log(log_dict)

    optimizer.step()


def _build_model_and_matcher(args):
    """Return the ``(model, FM)`` pair selected by ``args.method``."""
    if args.method == "hybrid":
        model = HybridFlowModel(dim=args.dim, w=args.w)
        FM = SchrodingerBridgeConditionalFlowMatcher(sigma=args.sigma)
    elif args.method == "sb":
        model = WMLP(dim=args.dim, w=args.w, time_varying=True)
        FM = SchrodingerBridgeConditionalFlowMatcher(sigma=args.sigma)
    elif args.method in ("gmlp", "emlp"):
        # Gradient-based (GMLP) or energy-based (EMLP) network; the matcher is
        # stochastic only when a positive sigma is requested.
        network_cls = GMLP if args.method == "gmlp" else EMLP
        model = network_cls(dim=args.dim, w=args.w)
        if args.sigma > 0:
            FM = SchrodingerBridgeConditionalFlowMatcher(sigma=args.sigma)
        else:
            FM = ConditionalFlowMatcher(sigma=0.0)
    else:
        # "cfm" and any unrecognised method share the deterministic default.
        model = WMLP(dim=args.dim, w=args.w, time_varying=True)
        FM = ConditionalFlowMatcher(sigma=0.0)
    return model, FM


def _make_drift_model(hybrid_model, sigma):
    """Wrap a hybrid ``(velocity, score)`` model in a ``BFromVS`` drift model.

    Building the lambdas inside this helper binds them to this call's
    ``hybrid_model`` argument, so each wrapper keeps pointing at the exact
    object it was created from.
    """
    return BFromVS(
        v=lambda t, x: hybrid_model(t, x)[0],  # velocity component
        s=lambda t, x: hybrid_model(t, x)[1],  # score component
        sigma=sigma,
    )


def main(args):
    """Train the selected flow model over ``args.kreflow + 1`` reflow rounds.

    Round 0 is given ``args.num_epochs`` optimisation steps and every later
    round ``args.re_epoch`` steps.  After each round the data object is
    advanced with the trained model, a snapshot of that model is stored, and
    the model is evaluated; the stored snapshots are plotted together at the
    end.

    Args:
        args: Parsed command-line namespace (see the ``__main__`` block).

    Raises:
        ValueError: If the dataloader yields no batches, which would otherwise
            leave the training loop spinning forever.
    """
    data = Synthetic(args)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    use_hybrid = args.method == "hybrid"
    model, FM = _build_model_and_matcher(args)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    models = []
    wandb.init(project=f"test_flow_reflow_{args.dim}",
               name=f"{args.method}_dim_{args.dim}_w_{args.w}_sigma_{args.sigma}_kreflow_{args.kreflow}",
               config=args)

    # Reflow iterations
    for k in range(args.kreflow + 1):
        dataloader = torch.utils.data.DataLoader(
            data.all_pairs,
            batch_size=args.batch_size,
            shuffle=True,
            drop_last=True,
        )
        print(f"Starting reflow iteration {k}...")
        gradient_storage = {ti / 10: [] for ti in range(11)}

        # `epoch` counts optimisation steps (one per batch), not data passes.
        epochs = args.num_epochs if k == 0 else args.re_epoch
        epoch = 0
        while epoch < epochs:
            data.update_pairs(reflow_iteration=k)

            saw_batch = False
            for x0_batch, x1_batch in dataloader:
                saw_batch = True
                train_epoch(model, optimizer, FM, args, device, epoch, x0_batch, x1_batch,
                            reflow_iteration=k, use_hybrid=use_hybrid)

                # Periodically evaluate during training
                if epoch % 100 == 0:
                    if use_hybrid:
                        # Evaluate on CPU through the combined drift, then
                        # return the network to the training device.
                        drift_eval_model = _make_drift_model(model.cpu(), args.sigma)
                        evaluate_model(drift_eval_model, args, data, reflow_iteration=k, epoch=epoch)
                        model = model.to(device)
                    else:
                        evaluate_model(model, args, data, reflow_iteration=k, epoch=epoch)
                        compute_lipschitz_via_weights(model)

                # NOTE(review): the threshold uses args.re_epoch even in round
                # 0, where the budget is args.num_epochs -- confirm intended.
                if epoch % 10 == 0 and 4 * epoch > 3 * args.re_epoch:
                    gradient_storage = track_gradient_variance(model, optimizer, FM, device, data,
                                                               gradient_storage, reflow_iteration=k)
                epoch += 1
                # Stop at the step budget instead of running to the end of the
                # current data pass.
                if epoch >= epochs:
                    break

            if not saw_batch:
                raise ValueError(
                    "The dataloader produced no batches; with drop_last=True the "
                    "dataset must contain at least batch_size pairs."
                )

        # Update data for next reflow iteration
        if use_hybrid:
            drift_model = _make_drift_model(model.cpu(), args.sigma)
            data.forward(drift_model, args)
            # Freeze the weights of this round: deep-copy the network itself
            # and wrap the copy.  Deep-copying the wrapper is not enough because
            # its lambdas are copied by reference and would keep calling the
            # live network, which continues to train in later rounds.
            models.append(_make_drift_model(copy.deepcopy(model), args.sigma))
            # Final evaluation
            evaluate_model(drift_model, args, data, reflow_iteration=k)
        else:
            data.forward(model, args)
            models.append(copy.deepcopy(model))
            # Final evaluation
            evaluate_model(model, args, data, reflow_iteration=k)

        gradient_plot(gradient_storage, reflow_iteration=k)

    plot_combined_trajectories(models, args, data)


    


if __name__ == "__main__":
    # Command-line entry point: parse the experiment configuration and hand
    # the resulting namespace to main().
    parser = argparse.ArgumentParser(description="Train and log gradient statistics")
    parser.add_argument("--dim", type=int, default=2, help="Dimension of the data")
    parser.add_argument("--components", type=int, default=4, help="Number of Gaussian medians")
    parser.add_argument("--sigma", type=float, default=0.01, help="Sigma value for SchrodingerBridge or Hybrid model")
    parser.add_argument("--batch_size", type=int, default=32, help="Batch size")
    parser.add_argument("--num_batches", type=int, default=200, help="Number of batches")
    parser.add_argument("--num_epochs", type=int, default=500, help="Number of epochs")
    parser.add_argument("--re_epoch", type=int, default=100, help="Number of epochs for reflow iterations")
    parser.add_argument("--source_type", type=str, default="standard", help="Source distribution type")
    parser.add_argument("--target_type", type=str, default="gmm", help="Target distribution type")
    parser.add_argument("--lr", type=float, default=0.0001, help="Learning rate")
    parser.add_argument("--w", type=int, default=32, help="Width of neural networks")
    parser.add_argument("--time_var", type=int, default=1, help="Time varying parameter")
    # The help text names the accepted choice "cfm" (it previously advertised
    # "ccfm", which argparse rejects).
    parser.add_argument("--method", type=str, default="hybrid",
                        choices=["cfm", "sb", "hybrid", "gmlp", "emlp"],
                        help="Flow matching method (cfm: Conditional Flow Matching, sb: Schrodinger Bridge, hybrid: Hybrid Flow, gmlp: Gradient MLP, emlp: Energy MLP)")
    parser.add_argument("--kreflow", type=int, default=4, help="Number of reflow iterations")

    args = parser.parse_args()

    main(args)