import argparse
import torch
import matplotlib.pyplot as plt
import numpy as np
from torch import nn
import torch.distributions as D
from torchdiffeq import odeint
import ot
import torch.nn.functional as F
from models import *
from tqdm import tqdm
from fab.utils.plotting import plot_contours, plot_marginal_pair
from fab.target_distributions.gmm import GMM
from torch.utils.tensorboard import SummaryWriter
from eval_utils import *
import os
def parse_args(argv=None):
    """Parse command-line options for the GMM linear-path training script.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, exactly as before; passing a list
            lets tests and notebooks build an options namespace directly.

    Returns:
        ``argparse.Namespace`` holding the options declared below.

    Notes:
        ``--inner_steps`` is the number of points on the training time grid,
        so training integrates over ``inner_steps - 1`` intervals.
        NOTE(review): ``--end`` is parsed but not read by ``main`` or
        ``evaluate_and_plot`` in this file, which both integrate over [0, 1].
    """
    parser = argparse.ArgumentParser(description="GMM Linear Training Script")

    # Target GMM geometry.
    parser.add_argument("--dim", type=int, default=2, help="Dimension of the problem")
    parser.add_argument("--n_mixes", type=int, default=40, help="Number of mixtures in GMM")
    parser.add_argument("--loc_scaling", type=float, default=40.0, help="Scale of the problem")
    parser.add_argument("--log_var_scaling", type=float, default=1.0, help="Variance of each Gaussian")

    # Optimisation schedule.
    parser.add_argument("--batch_size", type=int, default=256, help="Batch size for training")
    parser.add_argument("--lr", type=float, default=1e-3, help="Learning rate")
    parser.add_argument("--decay_rate", type=float, default=0.98, help="Learning rate decay rate")
    parser.add_argument("--ntrain", type=int, default=50000, help="Number of training iterations")
    parser.add_argument("--inner_steps", type=int, default=51, help="Number of inner steps")

    # Base distribution / ODE / evaluation.
    parser.add_argument("--sigma", type=float, default=40.0, help="Sigma for initial distribution")
    parser.add_argument("--end", type=float, default=1.0, help="End time for ODE")
    parser.add_argument("--eval_samples", type=int, default=1000, help="Evaluation Samples")

    return parser.parse_args(argv)


class TorchWrapper(nn.Module):
    """Expose a ``model(cat([x, t], 1))`` network with the ``func(t, x)`` signature.

    ``odeint`` calls its right-hand side as ``func(t, x)``; this adapter
    broadcasts the time ``t`` to one column per row of ``x``, appends it to
    ``x`` and forwards the result to the wrapped model.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, t, x):
        """Evaluate the wrapped model at states ``x`` and time ``t``.

        Gradient tracking is switched on even inside an outer ``no_grad``
        context, and ``x`` is flagged ``requires_grad`` in place. As a result,
        the returned tensor stays differentiable with respect to ``x``.
        """
        with torch.set_grad_enabled(True):
            x.requires_grad_(True)
            n_rows = x.shape[0]
            t_column = t.repeat(n_rows).unsqueeze(-1)
            model_input = torch.cat((x, t_column), dim=1)
            velocity = self.model(model_input)
        return velocity

def _pde_residual(xt, time, velo, Ct, gmm, sigma):
    """Return the per-sample residual of the linear-path PDE at states ``xt``.

    With ``f = time * log p_gmm(x) + (1 - time) * f_0(x)`` and the
    unnormalised Gaussian log-density ``f_0(x) = -|x|^2 / (2 sigma^2)``, the
    residual is::

        df/dt + Ct(time) + <grad_x f, v> + div_x v,    v = velo([x, time])

    NOTE(review): ``Ct`` presumably models the time derivative of the path's
    log-normaliser — confirm against the method's derivation.

    Args:
        xt: ``(batch, dim)`` leaf tensor with ``requires_grad=True``; spatial
            derivatives are taken with respect to it.
        time: ``(batch, 1)`` tensor holding the current time for every sample.
        velo: velocity network called on ``cat([xt, time], 1)``.
        Ct: network called on ``time``; its output is reshaped to ``(batch, 1)``.
        gmm: target distribution exposing ``log_prob``.
        sigma: standard deviation of the Gaussian base distribution.

    Returns:
        ``(batch, 1)`` residual tensor, still attached to the autograd graph.
    """
    batch, dim = xt.shape
    zt = Ct(time).reshape(batch, 1)

    f_0 = (-1) * torch.sum(xt ** 2, dim=1).reshape(batch, 1) * (1 / (2 * sigma ** 2))
    f_1 = gmm.log_prob(xt).reshape(batch, 1)
    f = time * f_1 + (1 - time) * f_0

    # f is linear in time, so its time derivative is available in closed form.
    dfdt = f_1 - f_0
    dfdx = torch.autograd.grad(torch.sum(f), xt, create_graph=True)[0].reshape(batch, dim)

    # Use the actual dimension (previously hard-coded to 2).
    vt = velo(torch.cat([xt, time], 1)).reshape(batch, dim)

    # Exact divergence: one backward pass per output coordinate, keeping the
    # graph so the loss can be differentiated w.r.t. the network parameters.
    div = 0.0
    for i in range(dim):
        dvi_dx, = torch.autograd.grad(vt[:, i].sum(), xt, create_graph=True)
        div = div + dvi_dx[:, i]

    dot = (dfdx * vt).sum(1, keepdim=True)
    return dfdt + zt + dot + div.reshape(batch, 1)


def main(args):
    """Train the velocity field ``velo`` and the scalar network ``Ct``.

    In each outer iteration:
      * draw ``args.batch_size`` samples from N(0, sigma^2 I);
      * walk them across the ``args.inner_steps - 1`` intervals of a uniform
        [0, 1] time grid by integrating ``velo`` without gradients;
      * at the end of each interval, back-propagate the mean of
        ``|r| + r^2`` of the PDE residual ``r`` (see ``_pde_residual``);
      * take one optimizer step.

    The learning rate decays every 50 iterations, and evaluation/plots run
    every 200 iterations. Both networks are saved under
    ``nets_gmm_linear/`` when training finishes.

    Args:
        args: namespace produced by ``parse_args``.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    os.makedirs("nets_gmm_linear", exist_ok=True)
    os.makedirs("imgsGMM_linear", exist_ok=True)
    torch.manual_seed(0)

    # Place the target on the selected device instead of hard-coding CUDA.
    gmm = GMM(dim=args.dim, n_mixes=args.n_mixes,
              loc_scaling=args.loc_scaling, log_var_scaling=args.log_var_scaling,
              use_gpu=device.type == "cuda", true_expectation_estimation_n_samples=int(1e5))
    gmm.to(device.type)

    velo = MLP2(dim=args.dim, out_dim=args.dim, time_varying=True, w=256).to(device)
    Ct = MLP2(dim=1, out_dim=1, time_varying=False, w=256).to(device)

    optimizer = torch.optim.Adam([*velo.parameters(), *Ct.parameters()], lr=args.lr)
    lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=optimizer, gamma=args.decay_rate)

    writer = SummaryWriter(log_dir=f"runs_GMM/gmm_linear_{args.sigma}")

    # Loop invariants: the ODE right-hand side and the time grid.
    ode_func = TorchWrapper(velo)
    times = torch.linspace(0, 1, args.inner_steps).to(device)
    n_intervals = args.inner_steps - 1

    progress_bar = tqdm(range(args.ntrain), total=args.ntrain, position=0, leave=True)

    try:
        for k in progress_bar:
            total_loss = 0.0
            xt = torch.randn((args.batch_size, args.dim), device=device) * args.sigma

            optimizer.zero_grad()
            for j in range(n_intervals):
                # Transport the samples from times[j] to times[j + 1]; the
                # trajectory itself is not differentiated.
                with torch.no_grad():
                    traj = odeint(ode_func,
                                  xt,
                                  times[j:j + 2],
                                  atol=1e-4,
                                  rtol=1e-4,
                                  method='dopri5')
                xt = traj[-1].detach().requires_grad_(True)
                time = times[j + 1].repeat(args.batch_size, 1)

                residual = _pde_residual(xt, time, velo, Ct, gmm, args.sigma)
                loss = (residual.abs() + residual ** 2).mean() / n_intervals
                total_loss += loss.item()

                # Gradients accumulate over all intervals; one optimizer step
                # per outer iteration.
                loss.backward()
            optimizer.step()

            if k % 50 == 0:
                lr_scheduler.step()

            writer.add_scalar('Loss/train', total_loss, k)

            if (k + 1) % 200 == 0:
                evaluate_and_plot(k, gmm, velo, device, writer, args)

            progress_bar.set_description(f"Loss={total_loss:.3f}")

        torch.save(Ct.state_dict(), f"nets_gmm_linear/Ct_{args.sigma}.pt")
        torch.save(velo.state_dict(), f"nets_gmm_linear/velo_{args.sigma}.pt")
    finally:
        # Flush TensorBoard events even if training is interrupted.
        writer.close()

def evaluate_and_plot(k, gmm, velo, device, writer, args):
    """Evaluate the current flow, log ESS/NLL and save a sample plot (best effort).

    Draws ``args.eval_samples`` points from N(0, sigma^2 I) and integrates the
    joint state ``(x, 0)`` from t=0 to t=1 with
    ``cnf_sample(TorchWrapper(velo))``. It then passes the base log-density,
    the final second state component and the GMM log-density of the final
    samples to ``ess``, prints and logs ESS/NLL to ``writer``, and overlays the
    samples on GMM contours. The plot is saved to
    ``imgsGMM_linear/Samples_{sigma}_{k}.png``.

    Any exception raised inside the evaluation is reported (message and
    traceback) and swallowed, so that a failed evaluation does not abort
    training.

    Args:
        k: current training iteration, used as the logging step.
        gmm: target distribution exposing ``log_prob`` and ``to``.
        velo: velocity network being trained.
        device: torch device on which evaluation samples are created.
        writer: TensorBoard ``SummaryWriter``.
        args: namespace produced by ``parse_args``.
    """
    import traceback

    wrapper = cnf_sample(TorchWrapper(velo))
    bounds = (args.loc_scaling * -1.4, args.loc_scaling * 1.4)
    try:
        with torch.no_grad():
            start = torch.randn((args.eval_samples, args.dim), device=device) * args.sigma
            logs_init = torch.zeros(args.eval_samples, 1, dtype=torch.float32, device=device)
            traj, logs = odeint(wrapper,
                                (start, logs_init),
                                torch.linspace(0, 1, 2).to(device),
                                atol=1e-4,
                                rtol=1e-4,
                                method='dopri5')

        samples = traj[-1].detach()
        logs = logs[-1].detach()
        # Unnormalised log-density of the N(0, sigma^2 I) base at the start points.
        base_log_prob = (-1) * torch.sum(start ** 2, dim=1) * (1 / (2 * args.sigma ** 2))
        es, nll = ess(base_log_prob, logs, gmm.log_prob(samples))
        print(f'ESS: {es}')
        print(f'NLL: {nll.item()}')
        writer.add_scalar('Evaluation/ESS', es, k)
        writer.add_scalar('Evaluation/NLL', nll.item(), k)

        # Plotting
        fig, axs = plt.subplots()
        try:
            # The GMM is moved to the CPU for contour plotting. Always move it
            # back: if it stayed on the CPU after a plotting error, every later
            # training step would evaluate it on device tensors and fail.
            gmm.to("cpu")
            try:
                plot_contours(gmm.log_prob, bounds=bounds, ax=axs,
                              n_contour_levels=80, grid_width_n_points=200)
            finally:
                gmm.to(device.type)

            plot_marginal_pair(samples, ax=axs, bounds=bounds)
            fig.savefig(f"imgsGMM_linear/Samples_{args.sigma}_{k}.png")
            writer.add_figure('Evaluation/Samples', fig, k)
        finally:
            # Release the figure even when plotting fails.
            plt.close(fig)

    except Exception as e:
        print(f"An error occurred during evaluation: {e}")
        traceback.print_exc()

if __name__ == "__main__":
    # Script entry point: read the CLI options and launch training.
    main(parse_args())
