import argparse
import torch
import matplotlib.pyplot as plt
import numpy as np
from torch import nn
import torch.distributions as D
from torchdiffeq import odeint
import ot
import torch.nn.functional as F
from models import *
from tqdm import tqdm
from fab.utils.plotting import plot_contours, plot_marginal_pair
from fab.target_distributions.many_well import ManyWellEnergy
from torch.utils.tensorboard import SummaryWriter
from eval_utils import *
import os 

def parse_args(argv=None):
    """Parse the command-line options of the training script.

    Args:
        argv: Optional list of argument strings to parse instead of the
            process command line. ``None`` (the default) makes argparse read
            ``sys.argv[1:]``, exactly as before; passing a list makes the
            parser usable from code and tests.

    Returns:
        argparse.Namespace with the fields ``dim``, ``batch_size``, ``lr``,
        ``decay_rate``, ``ntrain``, ``inner_steps``, ``sigma``, ``end`` and
        ``eval_samples``.
    """
    parser = argparse.ArgumentParser(description="GMM Mate Training Script")
    parser.add_argument("--dim", type=int, default=8, help="Dimension of the problem")
    parser.add_argument("--batch_size", type=int, default=256, help="Batch size for training")
    parser.add_argument("--lr", type=float, default=1e-3, help="Learning rate")
    parser.add_argument("--decay_rate", type=float, default=0.98, help="Learning rate decay rate")
    parser.add_argument("--ntrain", type=int, default=50000, help="Number of training iterations")
    # Number of time-grid points; training uses inner_steps - 1 intervals.
    parser.add_argument("--inner_steps", type=int, default=51, help="Number of inner steps")
    parser.add_argument("--sigma", type=float, default=1.0, help="Sigma for initial distribution")
    # NOTE(review): --end is not read by main/evaluate_and_plot in this file;
    # both integrate over t in [0, 1].
    parser.add_argument("--end", type=float, default=1.0, help="End time for ODE")
    parser.add_argument("--eval_samples", type=int, default=1000, help="Evaluation Samples")

    return parser.parse_args(argv)


class TorchWrapper(nn.Module):
    """Expose a ``model(cat([x, t]))`` network through the ``f(t, x)`` ODE signature.

    The (scalar) time ``t`` is tiled to one entry per row of ``x`` and
    appended as an extra last column before the wrapped model is called.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, t, x):
        # Gradient tracking is forced on, even under an outer ``no_grad``,
        # and ``x`` is flagged as requiring grad in place. Presumably this is so
        # callers can differentiate the output w.r.t. x -- TODO confirm
        # against cnf_sample.
        with torch.set_grad_enabled(True):
            x.requires_grad_(True)
            n_rows = x.shape[0]
            time_column = t.repeat(n_rows).unsqueeze(-1)
            features = torch.cat((x, time_column), dim=1)
            return self.model(features)

def _pde_residual_loss(velo, Ct, target, xt, time, sigma):
    """Return the mean ``|r| + r**2`` penalty of the PDE residual ``r`` at (xt, time).

    The unnormalised log-density path is the linear interpolation
    ``f = t * log_target(x) + (1 - t) * f_0(x)`` with
    ``f_0(x) = -|x|^2 / (2 sigma^2)``. The residual is
    ``df/dt + Ct(t) + v . grad_x f + div_x v``. The divergence is computed
    exactly, with one autograd call per dimension.

    Args:
        velo: velocity network, called on ``cat([x, t], 1)``.
        Ct: network of time alone (presumably modelling the time derivative
            of the log normaliser -- TODO confirm).
        target: object providing ``log_prob(x)``.
        xt: (batch, dim) positions. Marked ``requires_grad`` in place.
        time: (batch, 1) time column.
        sigma: standard deviation of the base Gaussian.
    """
    batch, dim = xt.shape
    xt.requires_grad_(True)
    Zt = Ct(time).reshape(batch, 1)

    f_0 = (-1) * torch.sum(xt**2, dim=1).reshape(batch, 1) * (1 / (2 * sigma**2))
    f_1 = target.log_prob(xt).reshape(batch, 1)
    f = time * f_1 + (1 - time) * f_0
    # f is linear in t, so df/dt is available in closed form.
    dfdt = f_1 - f_0

    dfdx = torch.autograd.grad(torch.sum(f), xt, create_graph=True)[0].reshape(batch, dim)
    vt = velo(torch.cat([xt, time], 1)).reshape(batch, dim)

    # Exact divergence sum_i d v_i / d x_i. create_graph keeps it differentiable
    # so the loss can be backpropagated into velo.
    div = 0.0
    for i in range(dim):
        (grad_vi,) = torch.autograd.grad(vt[:, i].sum(), xt, create_graph=True)
        div = div + grad_vi[:, i]

    dot = (dfdx * vt).sum(1, keepdim=True)
    residual = dfdt + Zt + dot + div.reshape(batch, 1)
    return (residual.abs() + residual**2).mean()


def main(args):
    """Train the velocity field ``velo`` and the time network ``Ct`` on ManyWell.

    Each outer iteration does the following:

    * draws a fresh batch from N(0, sigma^2 I);
    * transports the batch through the current flow (no gradients) along a
      uniform grid of ``args.inner_steps`` times in [0, 1];
    * at each grid point, back-propagates the PDE residual penalty computed
      by ``_pde_residual_loss``, scaled by the number of intervals;
    * takes one clipped optimizer step.

    Checkpoints are written to ``nets_MW_linear/`` and TensorBoard logs to
    ``runs_MW/linear``. ``args.end`` is not used; the grid always spans [0, 1].

    Raises:
        ValueError: if ``args.inner_steps < 2`` (no interval to train on).
    """
    if args.inner_steps < 2:
        raise ValueError(f"--inner_steps must be >= 2, got {args.inner_steps}")

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    os.makedirs("nets_MW_linear", exist_ok=True)

    torch.manual_seed(0)
    # use_gpu follows the selected device instead of being hard-coded to True,
    # which broke CPU-only machines.
    target = ManyWellEnergy(args.dim, a=-0.5, b=-6, use_gpu=device.type == "cuda", normalised=False)
    target.to(device)

    velo = MLP4(dim=args.dim, out_dim=args.dim, time_varying=True, w=512).to(device)
    Ct = MLP4(dim=1, out_dim=1, time_varying=False, w=512).to(device)

    optimizer = torch.optim.Adam([*velo.parameters(), *Ct.parameters()], lr=args.lr)
    lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=optimizer, gamma=args.decay_rate)

    # The time grid is identical every iteration, so build it and move it to
    # the device once. Slices of it are passed straight to odeint.
    times = torch.linspace(0, 1, args.inner_steps).to(device)
    n_intervals = args.inner_steps - 1

    writer = SummaryWriter(log_dir="runs_MW/linear")
    try:
        progress_bar = tqdm(range(args.ntrain), total=args.ntrain, position=0, leave=True)
        for k in progress_bar:
            total_loss = 0.0
            xt = torch.randn((args.batch_size, args.dim), device=device) * args.sigma

            optimizer.zero_grad()
            for j in range(n_intervals):
                # Advance the batch from times[j] to times[j + 1] without
                # tracking gradients through the solver.
                with torch.no_grad():
                    traj = odeint(
                        TorchWrapper(velo),
                        xt,
                        times[j:j + 2],
                        atol=1e-4,
                        rtol=1e-4,
                        method='dopri5',
                    )
                    xt = traj[-1].detach()

                time = times[j + 1].repeat(args.batch_size, 1)
                loss = _pde_residual_loss(velo, Ct, target, xt, time, args.sigma) / n_intervals
                total_loss += loss.item()
                # Gradients accumulate over the inner steps; a single optimizer
                # step is taken per outer iteration.
                loss.backward()

            torch.nn.utils.clip_grad_norm_(Ct.parameters(), 100.0)
            torch.nn.utils.clip_grad_norm_(velo.parameters(), 100.0)
            optimizer.step()

            # Decay every 50 optimizer steps. The first decay happens after
            # step 50 rather than right after step 1, matching the (k + 1)
            # cadence used for evaluation.
            if (k + 1) % 50 == 0:
                lr_scheduler.step()

            writer.add_scalar('Loss/train', total_loss, k)

            if (k + 1) % 500 == 0:
                evaluate_and_plot(k, target, velo, device, writer, args)
            progress_bar.set_description(f"Loss={total_loss:.3f}")

        torch.save(velo.state_dict(), "nets_MW_linear/velo.pt")
        torch.save(Ct.state_dict(), "nets_MW_linear/CT.pt")
    finally:
        # Flush and close TensorBoard logs even if training raises.
        writer.close()

def evaluate_and_plot(k, target, velo, device, writer, args):
    """Sample the learned flow and log ESS / NLL to TensorBoard (best effort).

    The function proceeds in three steps:

    1. Draws ``args.eval_samples`` points from N(0, sigma^2 I).
    2. Integrates them from t=0 to t=1 through ``cnf_sample(TorchWrapper(velo))``.
       The solver state is ``(x, logs)`` with ``logs`` starting at zero;
       presumably ``logs`` accumulates the log-density change -- confirm in
       ``cnf_sample``.
    3. Passes three quantities to ``ess``: the unnormalised base log-density
       of the start points, the final ``logs``, and the target log-density
       of the end points.

    ESS and NLL are printed and written to ``writer`` at step ``k``. Despite
    the name, nothing is plotted. Any exception raised during evaluation is
    reported with its traceback and swallowed, so a failed evaluation does
    not abort training.
    """
    import traceback

    wrapper = cnf_sample(TorchWrapper(velo))
    n = args.eval_samples
    try:
        # Everything below is pure inference, so keep it out of autograd.
        with torch.no_grad():
            start = torch.randn((n, args.dim), device=device) * args.sigma
            logs0 = torch.zeros(n, 1, dtype=torch.float32, device=device)
            traj, logs = odeint(
                wrapper,
                (start, logs0),
                torch.linspace(0, 1, 2).to(device),
                atol=1e-4,
                rtol=1e-4,
                method='dopri5',
            )
            final_logs = logs[-1].detach()
            base_logp = (-1) * torch.sum(start**2, dim=1) * (1 / (2 * args.sigma**2))
            target_logp = target.log_prob(traj[-1].detach())

        es, nll = ess(base_logp, final_logs, target_logp)
        print(f'ESS: {es}')
        print(f'NLL: {nll.item()}')
        writer.add_scalar('Evaluation/ESS', es, k)
        writer.add_scalar('Evaluation/NLL', nll.item(), k)

    except Exception as e:
        # Deliberately best-effort, but keep the traceback so failures are debuggable.
        print(f"An error occurred during evaluation: {e}")
        traceback.print_exc()

if __name__ == "__main__":
    # Script entry point: parse the CLI options and run training.
    main(parse_args())
