import numpy as np
import wandb
import copy
import os
import torch
from absl import app, flags
from torchdyn.core import NeuralODE
from torchvision import datasets, transforms
from tqdm import trange
from utils_cifar import ema, generate_samples, infiniteloop

from torchcfm.conditional_flow_matching import (
    ConditionalFlowMatcher,
    ExactOptimalTransportConditionalFlowMatcher,
    TargetConditionalFlowMatcher,
    VariancePreservingConditionalFlowMatcher,
)
from torchcfm.models.unet.unet import UNetModelWrapper
from torch.utils.data import Subset
import random

def sample_conditional_pt(x0, x1, t, sigma):
    """Sample a point on the noisy straight-line path between ``x0`` and ``x1``.

    Returns ``t * x1 + (1 - t) * x0 + sigma * eps`` where ``eps`` is drawn
    from a standard normal with the same shape as ``x0``.

    Args:
        x0: Batched source samples (batch dimension first).
        x1: Batched target samples; assumed broadcastable with ``x0``.
        t: Times, one entry per batch element. They are reshaped to
            ``(B, 1, ..., 1)`` so they broadcast over every non-batch
            dimension of ``x0``.
        sigma: Scale applied to the Gaussian perturbation.
    """
    broadcast_shape = (-1,) + (1,) * (x0.dim() - 1)
    t_col = t.reshape(broadcast_shape)
    mean = t_col * x1 + (1 - t_col) * x0
    noise = torch.randn_like(x0)
    return mean + sigma * noise

def compute_conditional_vector_field(x0, x1):
    """Velocity of the straight path ``x_t = t * x1 + (1 - t) * x0``.

    The path is linear in ``t``, so its time derivative is the constant
    displacement ``x1 - x0``, independent of ``t``.
    """
    displacement = x1 - x0
    return displacement

FLAGS = flags.FLAGS
flags.DEFINE_string("wandb_project", "forward_backward_cifar", help="WandB project name")
flags.DEFINE_string("wandb_run_name", "default_run", help="WandB run name")
flags.DEFINE_string("wandb_entity", "your_username", help="WandB entity (username or team)")
flags.DEFINE_string("model", "cfm", help="flow matching model type: one of otcfm, cfm, fm, vpfm")
flags.DEFINE_string("output_dir", "./results/", help="output directory")
# UNet
flags.DEFINE_integer("num_channel", 128, help="base channel of UNet")
flags.DEFINE_string(
    "input_dir",
    "/slurm-storage/teoreu/git/variance_flows/train_cifar10/trained_models",
    help="directory holding trained model checkpoints",
)

# Training
flags.DEFINE_float("lr", 2e-4, help="target learning rate")  # TRY 2e-4
flags.DEFINE_float("sigma", 0.0, help="sigma passed to the flow matcher (noise level of the probability path)")
flags.DEFINE_string("time_embed_type", "lin", help="time embedding type forwarded to UNetModelWrapper")
flags.DEFINE_float("grad_clip", 1.0, help="gradient norm clipping")
flags.DEFINE_integer(
    "total_steps", 200001, help="total training steps"
)  # Lipman et al uses 400k but double batch size
flags.DEFINE_integer("img_size", 32, help="image size")
flags.DEFINE_integer("warmup", 100, help="number of linear learning rate warmup steps")
flags.DEFINE_integer("batch_size", 128, help="batch size")  # Lipman et al uses 128

flags.DEFINE_integer("num_res_block", 1, help="number of residual blocks (num_res_blocks) in the UNet")

flags.DEFINE_integer("num_workers", 4, help="workers of Dataloader")
flags.DEFINE_float("ema_decay", 0.9999, help="ema decay rate")
flags.DEFINE_bool("parallel", False, help="multi gpu training")

# Evaluation
flags.DEFINE_integer("save_step", 50000, help="frequency of saving checkpoints, 0 to disable during training")
flags.DEFINE_integer("eval_step", 0, help="frequency of evaluating model, 0 to disable during training")
flags.DEFINE_integer("num_images", 50000, help="the number of generated images for evaluation")
flags.DEFINE_integer("dataset_size", 50000, help="size of the training subset (also part of the checkpoint filename)")
flags.DEFINE_integer("seed", 42, help="random seed for Python, NumPy and PyTorch")

# Run on the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda") if use_cuda else torch.device("cpu")

def warmup_lr(step, warmup=None):
    """Linear learning-rate warmup multiplier for ``LambdaLR``.

    Ramps linearly from 0 at ``step == 0`` up to 1 at ``step == warmup``,
    then stays at 1.

    Args:
        step: Scheduler step count (what ``LambdaLR`` passes in).
        warmup: Number of warmup steps. Defaults to ``FLAGS.warmup`` so the
            function can still be handed to ``LambdaLR`` directly.

    Returns:
        Multiplier in ``[0, 1]``. A non-positive ``warmup`` disables warmup
        (returns 1.0) instead of raising ``ZeroDivisionError``.
    """
    if warmup is None:
        warmup = FLAGS.warmup
    if warmup <= 0:
        return 1.0
    return min(step, warmup) / warmup


import torch

def compute_divergence_hutchinson(v_theta, x_t, t, create_graph=True):
    """
    Estimate the divergence of ``v_theta`` w.r.t. ``x_t`` with Hutchinson's trace estimator.

    For a standard Gaussian probe ``eps``, ``eps^T J eps`` (with ``J`` the
    Jacobian of the velocity w.r.t. ``x_t``) is an unbiased estimate of
    ``trace(J)``, i.e. of the divergence. One probe is drawn per call.

    Args:
        v_theta: Callable ``v_theta(t, x)`` returning a tensor shaped like ``x``.
        x_t: Batched points with the batch dimension first. Every other
            dimension (e.g. C, H, W for images) is treated as part of the
            state, so the trace runs over all of them. A 0-/1-D tensor is
            treated as a single unbatched state. The tensor is cloned, so the
            caller's ``requires_grad`` flag is left untouched.
        t: Time argument, forwarded unchanged to ``v_theta``.
        create_graph: If True (the original behaviour), keep a differentiable
            graph through the estimate, e.g. for use as a regulariser. Pass
            False when the value is only logged, to save memory.

    Returns:
        divergence: 0-dim tensor, the per-sample estimate of ∇⋅v_θ averaged over the batch.
    """
    x_t = x_t.clone().requires_grad_(True)
    velocity = v_theta(t, x_t)

    # Random Gaussian probe vector with the same shape as x_t.
    eps = torch.randn_like(x_t)

    # Vector-Jacobian product eps^T J (same shape as x_t).
    vjp = torch.autograd.grad(
        outputs=velocity,
        inputs=x_t,
        grad_outputs=eps,  # Hutchinson trick
        retain_graph=create_graph,
        create_graph=create_graph,
    )[0]

    # eps^T J eps: the contraction must run over ALL non-batch dimensions.
    # Summing only over the last axis would, for (B, C, H, W) inputs, yield
    # partial traces that are then averaged instead of summed.
    prod = vjp * eps
    if prod.dim() < 2:
        # Unbatched input: the whole tensor is one state.
        return prod.sum()
    return prod.flatten(start_dim=1).sum(dim=1).mean()



def _build_flow_matcher(model_name, sigma):
    """Instantiate the flow matcher selected by ``model_name`` with noise level ``sigma``.

    Raises:
        NotImplementedError: if ``model_name`` is not otcfm, cfm, fm or vpfm.
    """
    matchers = {
        "otcfm": ExactOptimalTransportConditionalFlowMatcher,
        "cfm": ConditionalFlowMatcher,
        "fm": TargetConditionalFlowMatcher,
        "vpfm": VariancePreservingConditionalFlowMatcher,
    }
    try:
        matcher_cls = matchers[model_name]
    except KeyError:
        raise NotImplementedError(f"Unknown model {model_name}") from None
    return matcher_cls(sigma=sigma)


def _prepare_checkpoint_path():
    """Return the checkpoint file path for the current flags, creating its directory.

    The directory is ``FLAGS.input_dir``. Its default is the directory that
    was previously hard-coded here, so default runs save to the same place.
    """
    ckpt_dir = FLAGS.input_dir
    if ckpt_dir:
        os.makedirs(ckpt_dir, exist_ok=True)
    filename = (
        f"{FLAGS.model}_{FLAGS.sigma}_{FLAGS.dataset_size}_{FLAGS.time_embed_type}"
        f"_{FLAGS.num_res_block}__cifar10_weights.pt"
    )
    return os.path.join(ckpt_dir, filename)


def train(argv):
    """Train a UNet velocity field on CIFAR-10 with (conditional) flow matching.

    Seeds Python/NumPy/PyTorch and logs the loss and learning rate to Weights &
    Biases every 10 steps. Every 100 steps it logs a Hutchinson divergence
    estimate. Every ``FLAGS.save_step`` steps (if > 0) it generates samples and
    writes a checkpoint.

    Args:
        argv: Leftover command-line arguments from ``absl.app.run`` (unused).
    """
    del argv  # unused; required by absl.app.run

    random.seed(FLAGS.seed)
    np.random.seed(FLAGS.seed)
    torch.manual_seed(FLAGS.seed)

    # NOTE: FLAGS.wandb_entity is intentionally not forwarded: its default is a
    # placeholder ("your_username") that would make wandb.init fail.
    wandb.init(
        project=FLAGS.wandb_project,
        name=FLAGS.wandb_run_name,
        config=FLAGS.flag_values_dict(),
    )

    dataset = datasets.CIFAR10(
        root="./data",
        train=True,
        download=True,
        transform=transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]),
    )

    # subset_size = FLAGS.dataset_size  
    # subset_indices = torch.randperm(len(dataset))[:subset_size]  # Random subset
    # dataset = Subset(dataset, subset_indices)

    dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=FLAGS.batch_size, shuffle=True, num_workers=FLAGS.num_workers, drop_last=True
    )
    datalooper = infiniteloop(dataloader)

    net_model = UNetModelWrapper(
        dim=(3, 32, 32), num_res_blocks=FLAGS.num_res_block, num_channels=FLAGS.num_channel, channel_mult=[1, 2, 2, 2],
        num_heads=4, num_head_channels=64, attention_resolutions="16", dropout=0.1, time_embed_type=FLAGS.time_embed_type
    ).to(device)

    # Load the model
    # PATH = f"{FLAGS.input_dir}/{FLAGS.model}_cifar10_weights_step_400000.pt"
    # print("path: ", PATH)
    # checkpoint = torch.load(PATH)
    # state_dict = checkpoint["ema_model"]
    # try:
    #     net_model.load_state_dict(state_dict)
    # except RuntimeError:
    #     from collections import OrderedDict

    #     new_state_dict = OrderedDict()
    #     for k, v in state_dict.items():
    #         new_state_dict[k[7:]] = v
    #     net_model.load_state_dict(new_state_dict)

    optim = torch.optim.Adam(net_model.parameters(), lr=FLAGS.lr)
    sched = torch.optim.lr_scheduler.LambdaLR(optim, lr_lambda=warmup_lr)

    net_node = NeuralODE(net_model, solver="euler", sensitivity="adjoint")

    model_size = sum(p.numel() for p in net_model.parameters())
    print(f"Model params: {model_size / 1e6:.2f}M")

    # Only used by the commented-out gradient-variance analysis below.
    gradient_storage = {ti/10: [] for ti in range(11)}

    FM = _build_flow_matcher(FLAGS.model, FLAGS.sigma)

    with trange(FLAGS.total_steps, dynamic_ncols=True) as pbar:
        for step in pbar:
            optim.zero_grad()
            x1 = next(datalooper).to(device)
            x0 = torch.randn_like(x1)
            t, xt, ut = FM.sample_location_and_conditional_flow(x0, x1)
            vt = net_model(t, xt)
            loss = torch.mean((vt - ut) ** 2)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(net_model.parameters(), FLAGS.grad_clip)
            optim.step()
            sched.step()

            if step % 10 == 0:
                wandb.log({"loss": loss.item(), "lr": optim.param_groups[0]["lr"]}, step=step)

            if step % 100 == 0:
                # Logged diagnostic only: convert to a float so the autograd
                # graph built for the estimate is released immediately.
                # NOTE(review): net_model is presumably still in train mode here
                # (dropout=0.1), which adds noise to the estimate — confirm.
                div = compute_divergence_hutchinson(net_model, xt, t).item()
                wandb.log({"step": step, "div": div}, step=step)

            # save_step == 0 means "disabled" per the flag help; the guard also
            # avoids `step % 0` raising ZeroDivisionError.
            if FLAGS.save_step > 0 and step % FLAGS.save_step == 0:
                generate_samples(net_node, net_model, "wandb_artifacts", step, net_="normal")
                torch.save({
                    "net_model": net_model.state_dict(),
                    "sched": sched.state_dict(),
                    "optim": optim.state_dict(),
                    "step": step,
                }, _prepare_checkpoint_path())

            #     x1 = next(datalooper).to(device)
            #     x0 = torch.randn_like(x1)
            #     if step > FLAGS.save_step * 4 and step % 500 == 0 :
            #         for ti in range(0, 11):
            #             optim.zero_grad()

            #             t = ti/10 * torch.ones(x0.shape[0]).type_as(x0)

            #             t, xt, ut = FM.sample_location_and_conditional_flow(x0, x1, t)

            #             vt = net_model(t, xt)
            #             loss = torch.mean((vt - ut) ** 2)
            #             loss.backward()

            #             total_norm = 0
            #             for p in net_model.parameters():
            #                 if p.grad is not None:
            #                     param_norm = p.grad.data.norm(2)
            #                     total_norm += param_norm.item() ** 2
            #             total_norm = total_norm ** 0.5
            #             gradient_storage[ti/10].append(total_norm)

            #             optim.zero_grad() 


            #         # wandb.save(f"wandb_artifacts/cifar10_weights_step_{step}.pt")
            # gradient_var_per_t = {ti: np.var(gradient_storage[ti]) for ti in gradient_storage}
            # t_values = list(gradient_var_per_t.keys())
            # var_values = list(gradient_var_per_t.values())

            # # Log as a single line plot in W&B
            # wandb.log({
            #     "Gradient Variance vs t": wandb.plot.line_series(
            #         xs=t_values,
            #         ys=[var_values],
            #         keys=["Variance"],
            #         title="Gradient Variance vs t",
            #         xname="t"
            #     )
            # })

    wandb.finish()

if __name__ == "__main__":
    # absl parses the command-line flags into FLAGS, then calls train(argv).
    app.run(main=train)
