import argparse
import torch
import matplotlib.pyplot as plt
import numpy as np
from torch import nn
import torch.distributions as D
from torchdiffeq import odeint
import ot
import torch.nn.functional as F
from models import *
from tqdm import tqdm
from fab.utils.plotting import plot_contours, plot_marginal_pair
from fab.target_distributions.gmm import GMM
from torch.utils.tensorboard import SummaryWriter
from eval_utils import *
import os
def parse_args(argv=None):
    """Parse the command-line options of the GMM gradient-flow training script.

    Args:
        argv: Optional sequence of argument strings to parse. When ``None``
            (the default) the process command line (``sys.argv[1:]``) is
            used, so existing callers are unaffected.

    Returns:
        argparse.Namespace: Parsed options with the attributes ``dim``,
        ``n_mixes``, ``loc_scaling``, ``log_var_scaling``, ``batch_size``,
        ``lr``, ``decay_rate``, ``ntrain``, ``inner_steps``, ``sigma``,
        ``beta_min``, ``beta_max`` and ``eval_samples``.
    """
    parser = argparse.ArgumentParser(description="GMM Gradient Flow Training Script")

    # Target distribution.
    parser.add_argument("--dim", type=int, default=2, help="Dimension of the problem")
    parser.add_argument("--n_mixes", type=int, default=40, help="Number of mixtures in GMM")
    parser.add_argument("--loc_scaling", type=float, default=40.0, help="Scale of the problem")
    parser.add_argument("--log_var_scaling", type=float, default=1.0, help="Variance of each Gaussian")

    # Optimisation.
    parser.add_argument("--batch_size", type=int, default=4096, help="Batch size for training")
    parser.add_argument("--lr", type=float, default=1e-3, help="Learning rate")
    parser.add_argument("--decay_rate", type=float, default=0.98, help="Learning rate decay rate")
    parser.add_argument("--ntrain", type=int, default=400000, help="Number of training iterations")
    parser.add_argument("--inner_steps", type=int, default=2, help="Number of inner steps")

    # Noise schedule and initial distribution.
    parser.add_argument("--sigma", type=float, default=1.0, help="Sigma for initial distribution")
    parser.add_argument("--beta_min", type=float, default=0.1, help="Minimum beta value")
    parser.add_argument("--beta_max", type=float, default=20.0, help="Maximum beta value")

    # Evaluation.
    parser.add_argument("--eval_samples", type=int, default=1000, help="Evaluation Samples")

    return parser.parse_args(argv)

class TorchWrapper(nn.Module):
    """Vector field ``(t, x) -> dx/dt`` derived from a learned scalar function.

    The scalar function is
    ``time * action([x, time]) + (time * schedule(time) + (1 - time)) * gmm.log_prob(x)``
    and the returned field is ``-beta(time) * (grad_x(scalar) + x)``.

    Args:
        action: Callable applied to ``x`` concatenated with the time column;
            its output is reshaped to one value per row of ``x``.
        schedule: Callable applied to the time column; its output is reshaped
            to one value per row of ``x``.
        gmm: Object exposing ``log_prob(x)``.
        beta: Callable mapping the time column to the multiplier of the field.
    """

    def __init__(self, action, schedule, gmm, beta):
        super().__init__()
        self.action = action
        self.schedule = schedule
        self.gmm = gmm
        self.beta = beta

    def forward(self, t, x):
        """Evaluate the field at time ``t`` for the batch ``x``.

        ``x`` is flagged as requiring grad in place, and gradients are enabled
        locally so the spatial gradient can be taken even when the caller runs
        under ``torch.no_grad()``.
        """
        n_rows = x.shape[0]

        with torch.set_grad_enabled(True):
            x.requires_grad_(True)
            # Broadcast the solver time to one column entry per sample.
            time = t.repeat(n_rows)[:, None]

            net_term = time * self.action(torch.cat([x, time], 1)).reshape(n_rows, 1)
            target_weight = time * self.schedule(time).reshape(n_rows, 1) + (1 - time)
            target_term = target_weight * self.gmm.log_prob(x).reshape(n_rows, 1)
            learned = net_term + target_term

            # create_graph=True keeps the gradient itself differentiable.
            (grad,) = torch.autograd.grad(torch.sum(learned), x, create_graph=True)

        n_cols = x.shape[1]
        drift = grad.reshape(n_rows, n_cols) + x.reshape(n_rows, n_cols)
        return -self.beta(time) * drift

def beta(t, beta_min, beta_max):
    """Return half of the linear ramp from ``beta_min`` (t=0) to ``beta_max`` (t=1).

    Works elementwise on anything supporting ``*`` and ``+`` with floats
    (Python numbers or tensors).
    """
    ramp = beta_min + (beta_max - beta_min) * t
    return 0.5 * ramp

def beta_int(t, beta_min, beta_max):
    """Return ``beta_min * t + (beta_max - beta_min) / 2 * t**2``.

    This equals twice the integral of ``beta`` over ``[0, t]`` (``beta``
    carries a factor ``0.5`` that this function does not).
    """
    linear_term = beta_min * t
    quadratic_term = (beta_max - beta_min) / 2 * t**2
    return linear_term + quadratic_term

def main(args):
    """Train the ``psi`` and ``schedule`` networks against the GMM target.

    Each outer iteration draws a batch ``x1`` uniformly from ``[-50, 50)`` per
    coordinate, noises it at random times, and minimises the mean squared
    residual built from the time derivative, gradient and Laplacian of

        f = time * psi([xt, time])
            + (time * schedule(time) + (1 - time)) * gmm.log_prob(xt)

    Side effects: creates ``nets_gmm_GF/`` and ``imgsGMM_GF/``, writes
    TensorBoard logs to ``runs_GMM/GF``, periodically calls
    ``evaluate_and_plot``, and saves both state dicts when training finishes.

    Args:
        args: Namespace with the attributes produced by ``parse_args``
            (``dim``, ``n_mixes``, ``loc_scaling``, ``log_var_scaling``,
            ``batch_size``, ``lr``, ``decay_rate``, ``ntrain``,
            ``inner_steps``, ``sigma``, ``beta_min``, ``beta_max``,
            ``eval_samples``).
    """
    # Use the GPU when one is present, otherwise fall back to the CPU instead
    # of failing on the first ``.to("cuda")``.
    use_cuda = torch.cuda.is_available()
    device = "cuda" if use_cuda else "cpu"

    os.makedirs("nets_gmm_GF", exist_ok=True)
    os.makedirs("imgsGMM_GF", exist_ok=True)

    torch.manual_seed(0)
    gmm = GMM(dim=args.dim, n_mixes=args.n_mixes,
              loc_scaling=args.loc_scaling, log_var_scaling=args.log_var_scaling,
              use_gpu=use_cuda, true_expectation_estimation_n_samples=int(1e5))
    gmm.to(device)

    psi = MLP2(dim=args.dim, out_dim=1, time_varying=True, w=256).to(device)
    schedule = MLP3(dim=1, out_dim=1, time_varying=False, w=256).to(device)

    optimizer = torch.optim.Adam([*psi.parameters(), *schedule.parameters()], lr=args.lr)
    lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=optimizer, gamma=args.decay_rate)

    batch_shape = (args.batch_size, args.dim)
    writer = SummaryWriter(log_dir="runs_GMM/GF")

    # Close the writer even if training is interrupted so buffered events are
    # flushed to disk.
    try:
        progress_bar = tqdm(range(args.ntrain), total=args.ntrain, position=0, leave=True)

        for k in progress_bar:
            total_loss = 0
            # Uniform proposal on [-50, 50) per coordinate, shared by all
            # inner steps of this iteration.
            x1 = torch.rand(batch_shape, device=device) * 100 - 50

            # NOTE(review): this runs ``inner_steps - 1`` updates, i.e. a single
            # update with the default of 2 and none at all with 1 -- confirm
            # whether the ``- 1`` is intended.
            for _ in range(args.inner_steps - 1):
                optimizer.zero_grad()

                # Times in [0.001, 1).
                time = torch.rand((args.batch_size, 1), device=device) * 0.999 + 0.001
                b_int = beta_int(time, args.beta_min, args.beta_max)

                # Noised sample: x1 scaled by exp(-b_int / 2) plus Gaussian
                # noise scaled by sigma * sqrt(1 - exp(-b_int)).
                noise = torch.randn(batch_shape, device=device)
                xt = torch.sqrt(1 - torch.exp(-b_int)) * noise * args.sigma + \
                     torch.exp(0.5 * (-b_int)) * x1

                xt.requires_grad_(True)
                time.requires_grad_(True)
                beta_t = beta(time, args.beta_min, args.beta_max)

                f = time * psi(torch.cat([xt, time], 1)).reshape(args.batch_size, 1) + \
                    (time * schedule(time).reshape(xt.shape[0], 1) + (1 - time)) * gmm.log_prob(xt).reshape(xt.shape[0], 1)

                f_sum = torch.sum(f)
                dfdt = torch.autograd.grad(f_sum, time, create_graph=True)[0].reshape(args.batch_size, 1)
                dfdx = torch.autograd.grad(f_sum, xt, create_graph=True)[0].reshape(batch_shape)

                # Laplacian of f: sum of the diagonal second derivatives,
                # obtained one coordinate at a time.
                lap_f = 0.0
                for i in range(xt.shape[1]):
                    d2fdx2_i, = torch.autograd.grad(dfdx[:, i].sum(), xt, create_graph=True)
                    lap_f += d2fdx2_i[:, i]

                dot = (dfdx * -1 * (beta_t * (xt + dfdx))).sum(1, keepdim=True)
                residual = dfdt + dot.reshape(args.batch_size, 1) - \
                    beta_t * (lap_f.reshape(args.batch_size, 1) + args.dim)
                loss = (residual ** 2).mean()
                total_loss += loss.item()

                loss.backward()
                optimizer.step()

            # Exponential learning-rate decay once every 1000 outer iterations.
            if k % 1000 == 0:
                lr_scheduler.step()

            writer.add_scalar('Loss/train', total_loss, k)

            if (k + 1) % 10000 == 0 or k == 0:
                evaluate_and_plot(k, gmm, psi, schedule, device, writer, args)

            progress_bar.set_description(f"Loss={total_loss:.3f}")

        torch.save(psi.state_dict(), "nets_gmm_GF/psi.pt")
        torch.save(schedule.state_dict(), "nets_gmm_GF/schedule.pt")
    finally:
        writer.close()

def evaluate_and_plot(k, gmm, psi, schedule, device, writer, args):
    """Sample from the learned flow, log ESS/NLL and save a sample plot.

    Gaussian noise of scale ``args.sigma`` is integrated from ``t = 1`` to
    ``t = 0.001`` through ``cnf_sample(TorchWrapper(...))``. The final states
    and the accumulated second ODE component are passed to ``ess`` together
    with the target log-probability. The samples are then drawn over the
    target contours, saved to ``imgsGMM_GF/Samples_{k}.png`` and logged to
    TensorBoard.

    Evaluation is best-effort: any exception is reported on stdout and
    swallowed so that training continues. ``gmm`` is always moved back to
    ``device`` and the figure is always closed, even when plotting fails.

    Args:
        k: Training iteration, used as the TensorBoard step and in the image
            file name.
        gmm: Target distribution exposing ``log_prob`` and ``to``.
        psi: Network passed to ``TorchWrapper`` as ``action``.
        schedule: Network passed to ``TorchWrapper`` as ``schedule``.
        device: Device on which sampling runs.
        writer: TensorBoard ``SummaryWriter``.
        args: Namespace providing ``beta_min``, ``beta_max``, ``sigma``,
            ``eval_samples``, ``dim`` and ``loc_scaling``.
    """
    wrapper_v = TorchWrapper(psi, schedule, gmm, lambda t: beta(t, args.beta_min, args.beta_max))
    # NOTE(review): cnf_sample is defined outside this file; it is integrated
    # with a (state, log-term) tuple below -- confirm its exact contract.
    wrapper = cnf_sample(wrapper_v)

    fig = None
    try:
        with torch.no_grad():
            start = torch.randn((args.eval_samples, args.dim), device=device) * args.sigma
            log_init = torch.zeros(args.eval_samples, 1, dtype=torch.float32, device=device)
            traj_backward, logs = odeint(wrapper,
                                         (start, log_init),
                                         torch.linspace(1, 0.001, 2).to(device),
                                         atol=1e-4,
                                         rtol=1e-4,
                                         method='dopri5')

        logs = logs[-1].detach()
        samples = traj_backward[-1].detach()

        # Unnormalised Gaussian log-density of the starting noise.
        base_log_prob = (-1) * torch.sum(start**2, dim=1) * (1 / (2 * args.sigma**2))
        es, nll = ess(base_log_prob, logs, gmm.log_prob(samples))
        print(f'ESS: {es}')
        print(f'NLL: {nll.item()}')
        writer.add_scalar('Evaluation/ESS', es, k)
        writer.add_scalar('Evaluation/NLL', nll.item(), k)

        # Plotting
        bounds = (args.loc_scaling * -1.4, args.loc_scaling * 1.4)
        fig, axs = plt.subplots()

        # Contours are evaluated on the CPU; always restore the training
        # device afterwards, otherwise a plotting failure would leave the
        # target on the wrong device for the next training step.
        gmm.to("cpu")
        try:
            plot_contours(gmm.log_prob, bounds=bounds, ax=axs, n_contour_levels=80, grid_width_n_points=200)
        finally:
            gmm.to(device)

        plot_marginal_pair(samples, ax=axs, bounds=bounds)

        fig.savefig(f"imgsGMM_GF/Samples_{k}.png")
        # Log the figure while it is still open; it is closed below.
        writer.add_figure('Evaluation/Samples', fig, k)

    except Exception as e:
        print(f"An error occurred during evaluation: {e}")
    finally:
        # Release the figure on both the success and the failure path so
        # repeated evaluations do not accumulate open figures.
        if fig is not None:
            plt.close(fig)

if __name__ == "__main__":
    # Script entry point: parse the command line and start training.
    main(parse_args())
