import argparse
import torch
import matplotlib.pyplot as plt
import numpy as np
from torch import nn
import torch.distributions as D
from torchdiffeq import odeint
import ot
import torch.nn.functional as F
from models import *
from tqdm import tqdm
from fab.utils.plotting import plot_contours, plot_marginal_pair
from fab.target_distributions.many_well import ManyWellEnergy
from torch.utils.tensorboard import SummaryWriter
from eval_utils import *
import os 
def parse_args():
    """Parse the training script's command-line options.

    Every option is optional and falls back to the default listed in the
    table below.

    Returns:
        argparse.Namespace: with attributes ``dim``, ``batch_size``, ``lr``,
        ``decay_rate``, ``ntrain``, ``inner_steps``, ``sigma``, ``beta_min``,
        ``beta_max`` and ``eval_samples``.
    """
    # (flag, type, default, help) -- one row per command-line option.
    option_table = (
        ("--dim", int, 8, "Dimension of the problem"),
        ("--batch_size", int, 4096, "Batch size for training"),
        ("--lr", float, 1e-3, "Learning rate"),
        ("--decay_rate", float, 0.99, "Learning rate decay rate"),
        ("--ntrain", int, 400000, "Number of training iterations"),
        ("--inner_steps", int, 2, "Number of inner steps"),
        ("--sigma", float, 1.0, "Sigma for initial distribution"),
        ("--beta_min", float, 0.1, "Minimum beta value"),
        ("--beta_max", float, 20.0, "Maximum beta value"),
        ("--eval_samples", int, 1000, "Evaluation Samples"),
    )

    cli = argparse.ArgumentParser(description="MW GF Training Script")
    for flag, kind, fallback, blurb in option_table:
        cli.add_argument(flag, type=kind, default=fallback, help=blurb)
    return cli.parse_args()

class TorchWrapper(nn.Module):
    """Velocity field ``(t, x) -> dx/dt`` in the call form expected by ``odeint``.

    The field is ``-beta(t) * (grad_x f(x, t) + x)`` where
    ``f = t * psi([x, t]) + (t * schedule(t) + (1 - t)) * target.log_prob(x)``.

    Args:
        psi: callable taking ``[x, t]`` concatenated along dim 1, one value per row.
        schedule: callable taking the per-row time column, one value per row.
        target: object exposing ``log_prob(x)`` with one value per row.
        beta: callable mapping the per-row time column to the noise rate.
    """

    def __init__(self, psi, schedule, target, beta):
        super().__init__()
        self.psi = psi
        self.schedule = schedule
        self.target = target
        self.beta = beta

    def forward(self, t, x):
        """Evaluate the velocity at scalar time ``t`` for the batch ``x``.

        Note: ``x`` is flagged ``requires_grad`` in place so the potential can
        be differentiated with respect to it, even if the caller runs under
        ``torch.no_grad()``.
        """
        n_rows = x.shape[0]

        # Gradients are force-enabled only for building the potential and
        # differentiating it; the final combination below runs in whatever
        # grad mode the caller has set.
        with torch.set_grad_enabled(True):
            x.requires_grad_(True)
            # Broadcast the scalar solver time to one column entry per sample.
            time = t.repeat(n_rows)[:, None]

            net_term = time * self.psi(torch.cat([x, time], 1)).reshape(n_rows, 1)
            target_weight = (time) * self.schedule(time).reshape(n_rows, 1) + (1 - time)
            target_term = target_weight * (self.target.log_prob(x)).reshape(n_rows, 1)
            learned = net_term + target_term

            # create_graph=True keeps the gradient itself differentiable.
            grad = torch.autograd.grad(torch.sum(learned), x, create_graph=True)[0]

        out_shape = (n_rows, x.shape[1])
        return -self.beta(time) * (grad.reshape(out_shape) + x.reshape(out_shape))

def beta(t, beta_min, beta_max):
    """Return half of the linear ramp from ``beta_min`` (t=0) to ``beta_max`` (t=1).

    Works elementwise when ``t`` is a tensor.
    """
    ramp = beta_min + (beta_max - beta_min) * (t)
    return 0.5 * ramp

def beta_int(t, beta_min, beta_max):
    """Return ``beta_min * t + (beta_max - beta_min) / 2 * t**2``.

    This is the integral from 0 to ``t`` of the linear ramp
    ``beta_min + (beta_max - beta_min) * s``. Works elementwise when ``t`` is
    a tensor.
    """
    linear_part = beta_min * t
    half_slope = (beta_max - beta_min) / 2
    return linear_part + half_slope * t**2

def main(args):
    """Train the ``psi``, ``schedule`` and ``Ct`` networks on the many-well target.

    Each outer iteration draws one batch of endpoints and runs
    ``args.inner_steps - 1`` optimizer steps on it. Every inner step samples
    fresh times, interpolates the endpoints to ``xt``, builds the potential
    ``f = t * psi([xt, t]) + (t * schedule(t) + (1 - t)) * target.log_prob(xt)``
    and minimises ``mean(|r| + r**2)`` of the residual
    ``r = df/dt + Ct(t) - beta * df/dx . (xt + df/dx) - beta * (lap f + dim)``.

    Side effects: creates ``nets_MW_GF/``, logs to TensorBoard under
    ``runs_MW/GF``, periodically calls ``evaluate_and_plot`` and finally writes
    ``nets_MW_GF/psi.pt`` and ``nets_MW_GF/schedule.pt``.

    Args:
        args: namespace as produced by ``parse_args`` (reads ``dim``,
            ``batch_size``, ``lr``, ``decay_rate``, ``ntrain``, ``inner_steps``,
            ``sigma``, ``beta_min``, ``beta_max``; ``eval_samples`` is consumed
            by ``evaluate_and_plot``).
    """
    device = "cuda"
    os.makedirs("nets_MW_GF", exist_ok=True)
    torch.manual_seed(0)
    target = ManyWellEnergy(args.dim, a=-0.5, b=-6, use_gpu=True, normalised=False)
    target.to(device)

    psi = MLP4(dim=args.dim, out_dim=1, time_varying=True, w=512).to(device)
    schedule = MLP3(dim=1, out_dim=1, time_varying=False, w=128).to(device)
    Ct = MLP4(dim=1, out_dim=1, time_varying=False, w=256).to(device)

    optimizer = torch.optim.Adam(
        [*psi.parameters(), *schedule.parameters(), *Ct.parameters()], lr=args.lr
    )
    lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=optimizer, gamma=args.decay_rate)

    writer = SummaryWriter(log_dir="runs_MW/GF")

    # try/finally guarantees the event file is flushed and closed even when
    # training is interrupted or raises part-way through.
    try:
        progress_bar = tqdm(range(args.ntrain), total=args.ntrain, position=0, leave=True)

        for k in progress_bar:
            total_loss = 0
            # Endpoints are drawn once per outer iteration and reused by all
            # inner steps: x0 is Gaussian scaled by sigma, x1 is uniform on
            # [-3, 3) in every coordinate.
            x0 = torch.randn((args.batch_size, args.dim), device=device) * args.sigma
            x1 = torch.rand((args.batch_size, args.dim), device=device) * (6) - (3) * torch.ones((args.batch_size, args.dim), device=device)

            # NOTE(review): this runs inner_steps - 1 updates (one with the
            # default --inner_steps 2, none with 1) -- confirm the -1 is intended.
            for j in range(args.inner_steps - 1):
                optimizer.zero_grad()

                # Times are uniform on [0.001, 1.0).
                time = torch.rand((args.batch_size, 1), device=device) * (0.999) + 0.001
                int_beta = beta_int(time, args.beta_min, args.beta_max)
                xt = torch.sqrt((1 - torch.exp(-int_beta))) * x0 + \
                     torch.exp(0.5 * (-int_beta)) * x1

                # xt is built BEFORE time starts tracking gradients, so df/dt
                # below is the partial derivative with xt held fixed.
                xt.requires_grad_(True)
                time.requires_grad_(True)

                beta_t = beta(time, args.beta_min, args.beta_max)

                f = (time) * psi(torch.cat([xt, time], 1)).reshape(args.batch_size, 1) + \
                    ((time) * schedule(time).reshape(xt.shape[0], 1) + (1 - time)) * (target.log_prob(xt)).reshape(xt.shape[0], 1)

                dfdt = torch.autograd.grad(torch.sum(f), time, create_graph=True)[0].reshape(args.batch_size, 1)
                dfdx = torch.autograd.grad(torch.sum(f), xt, create_graph=True)[0].reshape(args.batch_size, args.dim)
                Zt = Ct(time).reshape(args.batch_size, 1)

                # Laplacian of f w.r.t. xt: one extra autograd pass per
                # coordinate, keeping only the diagonal second derivative.
                lap_f = 0.0
                for i in range(xt.shape[1]):
                    d2fdx2_i, = torch.autograd.grad((dfdx[:, i]).sum(), xt, create_graph=True)
                    lap_f += d2fdx2_i[:, i]

                dot = (dfdx * -1 * (beta_t * (xt + dfdx))).sum(1, keepdims=True)
                residual = (dfdt + Zt + dot.reshape(args.batch_size, 1)
                            - beta_t * (lap_f.reshape(args.batch_size, 1) + args.dim))
                loss = (abs(residual) + residual**2).mean()
                total_loss += loss.item()

                loss.backward()
                # NOTE(review): Ct is optimised too but its gradients are not
                # clipped -- confirm that is deliberate.
                torch.nn.utils.clip_grad_norm_(psi.parameters(), 100.0)
                torch.nn.utils.clip_grad_norm_(schedule.parameters(), 100.0)
                optimizer.step()

            # The learning rate decays once every 1000 outer iterations
            # (including k == 0), not once per iteration.
            if k % 1000 == 0:
                lr_scheduler.step()

            writer.add_scalar('Loss/train', total_loss, k)

            if (k + 1) % 10000 == 0 or k == 0:
                evaluate_and_plot(k, target, psi, schedule, device, writer, args)

            descr = f"Loss={total_loss:.3f}"
            progress_bar.set_description(descr)

        # Only psi and schedule are persisted; Ct is not written to disk.
        torch.save(psi.state_dict(), "nets_MW_GF/psi.pt")
        torch.save(schedule.state_dict(), "nets_MW_GF/schedule.pt")
    finally:
        writer.close()

def evaluate_and_plot(k, target, psi, schedule, device, writer, args):
    """Sample with the current networks and log ESS / NLL to TensorBoard.

    Despite the name, no figure is produced: the function integrates the
    wrapped velocity field from t=0.999 to t=0.001 starting at Gaussian
    samples scaled by ``args.sigma``, passes the result to ``ess`` and writes
    the two returned metrics as ``Evaluation/ESS`` and ``Evaluation/NLL``.

    Evaluation is best-effort: any exception is printed and swallowed so that
    a failed ODE solve does not abort training.

    Args:
        k: training iteration used as the TensorBoard global step.
        target: target distribution exposing ``log_prob``.
        psi: potential network passed to ``TorchWrapper``.
        schedule: schedule network passed to ``TorchWrapper``.
        device: device on which the start samples and time grid are created.
        writer: TensorBoard writer receiving the scalars.
        args: namespace providing ``beta_min``, ``beta_max``, ``eval_samples``,
            ``dim`` and ``sigma``.
    """
    wrapper_v = TorchWrapper(psi, schedule, target, lambda t: beta(t, args.beta_min, args.beta_max))
    # presumably augments the velocity field with a log-density state, since
    # the solve below carries a (samples, zeros) tuple -- verify in cnf_sample.
    wrapper = cnf_sample(wrapper_v)

    try:
        with torch.no_grad():
            # The former forward solve from target samples was removed: its
            # result was never read, so it only cost an extra ODE integration.
            start = torch.randn((args.eval_samples, args.dim), device=device) * args.sigma
            log_state0 = torch.zeros(args.eval_samples, 1, dtype=torch.float32, device=device)
            traj_backward, logs = odeint(wrapper,
                                         (start, log_state0),
                                         torch.linspace(0.999, 0.001, 2).to(device),
                                         atol=1e-4,
                                         rtol=1e-4,
                                         method='dopri5')

        # Keep only the state at the final time point of the solve.
        logs = logs[-1].detach()
        # First argument: unnormalised Gaussian log-density of the start samples.
        start_log_density = (-1) * torch.sum(start**2, dim=1) * (1 / (2 * args.sigma**2))
        es, nll = ess(start_log_density, logs, target.log_prob(traj_backward[-1].detach()))
        print(f'ESS: {es}')
        print(f'NLL: {nll.item()}')
        writer.add_scalar('Evaluation/ESS', es, k)
        writer.add_scalar('Evaluation/NLL', nll.item(), k)

    except Exception as e:
        # Deliberately broad: report the failure and let training continue.
        print(f"An error occurred during evaluation: {e}")

if __name__ == "__main__":
    # Script entry point: parse the command-line options, then run training.
    args = parse_args()
    main(args)
