import torch
import matplotlib.pyplot as plt
import numpy as np
from torch import nn
import torch.distributions as D
from torchdiffeq import odeint
import ot
import torch.nn.functional as F
from models import *
import time
import copy
import pandas as pd
from geomloss import SamplesLoss
from eval_utils import *
from fab.utils.plotting import plot_contours, plot_marginal_pair
from fab.target_distributions.many_well import ManyWellEnergy


# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
dim = 8
dimension = dim
target = ManyWellEnergy(dim, a=-0.5, b=-6, use_gpu=True, normalised=True)
target.to(device)
# Number of samples generated (and drawn from the target) per evaluation round.
eval_samples = 50000

# Scale used in the Gaussian base log-density of the evaluation loop below.
sigma = 1.0

# psi takes the state concatenated with time; schedule takes time only
# (see how both are called in torch_wrapper.forward).
psi = MLP4(dim=dim, out_dim=1, time_varying=True, w=512).to(device)
schedule = MLP3(dim=1, out_dim=1, time_varying=False, w=128).to(device)

# map_location remaps the stored tensors onto the selected device, so a
# checkpoint written on a GPU machine still loads when running on CPU.
psi.load_state_dict(torch.load("nets_MW_GF/psi.pt", map_location=device))
schedule.load_state_dict(
    torch.load("nets_MW_GF/schedule.pt", map_location=device)
)

# Endpoints of the linear schedule evaluated by ``beta``.
beta_min = 0.1
beta_max = 20


def beta(t):
    """Return half of the linear ramp from ``beta_min`` to ``beta_max`` at ``t``.

    Args:
        t: Time value; anything supporting ``*`` and ``+`` with floats
            (the script passes a tensor).

    Returns:
        ``0.5 * (beta_min + (beta_max - beta_min) * t)``, with the same
        type/shape as ``t``.
    """
    slope = beta_max - beta_min
    return 0.5 * (beta_min + slope * t)






class torch_wrapper(torch.nn.Module):
    """Expose a scalar network as a ``(t, x) -> dx/dt`` vector field.

    The field returned by ``forward`` is ``-beta(t) * (grad_x(learned) + x)``,
    where ``learned`` mixes the wrapped network's output with
    ``target.log_prob(x)``.  Besides ``self.model`` the computation reads the
    module-level names ``schedule``, ``target``, ``beta`` and ``dim``.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, t, x, *args, **kwargs):
        """Evaluate the vector field at time ``t`` for the batch ``x``.

        Args:
            t: Current time; repeated once per row of ``x``.
            x: Batch of states, reshaped to ``(x.shape[0], dim)`` below.
                Its ``requires_grad`` flag is switched on in place.
            *args, **kwargs: Accepted and ignored.

        Returns:
            Tensor of shape ``(x.shape[0], dim)``.
        """
        n = x.shape[0]
        # Gradient tracking is forced on so autograd.grad works even when the
        # caller has disabled it (the evaluation loop runs under no_grad).
        with torch.set_grad_enabled(True):
            x.requires_grad_(True)
            t_col = t.repeat(n)[:, None]

            # Network contribution, weighted by t.
            net_term = t_col * self.model(torch.cat([x, t_col], 1)).reshape(n, 1)
            # Weight on the target log-density: t * schedule(t) + (1 - t).
            target_coeff = t_col * schedule(t_col).reshape(n, 1) + (1 - t_col)
            target_term = target_coeff * target.log_prob(x).reshape(n, 1)
            learned = net_term + target_term

            # Summing first gives per-sample gradients in a single call;
            # create_graph keeps the result differentiable.
            grad_x = torch.autograd.grad(
                torch.sum(learned), x, create_graph=True
            )[0]

            out = -beta(t_col) * (grad_x.reshape(n, dim) + x.reshape(n, dim))

        return out
energy_dist = SamplesLoss(loss="energy")
flow_class = cnf_sample(torch_wrapper(psi))
samples_true = target.sample((eval_samples,))
e_values = []
es_values = []
nl_values = []

# The ODE is integrated in chunks; the batch count is derived from
# eval_samples so the two cannot drift apart (the last chunk may be shorter).
batch_size = 1000
n_batches = (eval_samples + batch_size - 1) // batch_size
# Upper bound on back-to-back failures of one chunk, so a persistent error
# (as opposed to an occasional NaN) aborts instead of retrying forever.
max_consecutive_failures = 20
# Integration times, from t=1 down to t=0.001; identical for every chunk.
t_grid = torch.linspace(1, 0.001, 100).to(device)

for p in range(10):
    samples = torch.zeros_like(samples_true, device=device)
    weights = torch.zeros((eval_samples, 1), device=device)
    start = torch.zeros((eval_samples, dim), device=device)
    k = 0
    failures = 0
    while k < n_batches:
        lo = k * batch_size
        hi = min(lo + batch_size, eval_samples)
        # Fresh base noise on every attempt; scaled by sigma so it matches
        # the Gaussian log-density computed for `logq` below.
        start[lo:hi, :] = sigma * torch.randn((hi - lo, dim), device=device)
        try:
            with torch.no_grad():
                # State is a pair: the points and a per-sample scalar that
                # starts at zero (presumably the accumulated log-density
                # change -- cnf_sample is defined elsewhere, confirm there).
                z_t, logs = odeint(
                    flow_class,
                    (
                        start[lo:hi, :],
                        torch.zeros(hi - lo, 1, dtype=torch.float32, device=device),
                    ),
                    t_grid,
                    atol=1e-4,
                    rtol=1e-4,
                    method='euler',
                )

                # Keep only the state at the final integration time.
                logs = logs[-1].detach()
                traj = z_t[-1].detach()
                if torch.isnan(logs).any() or torch.isnan(traj).any():
                    raise ValueError("NaN values detected in logs or traj")
                weights[lo:hi, :] = logs
                samples[lo:hi, :] = traj
                k += 1
                failures = 0
                print(k)  # low error prob

        except Exception as e:
            # Best-effort retry of the same chunk with new noise.
            print(f"An error occurred during evaluation: {e}")
            failures += 1
            if failures >= max_consecutive_failures:
                raise RuntimeError(
                    f"Batch {k} failed {failures} times in a row; aborting evaluation"
                ) from e

    samples_true = target.sample((eval_samples,))

    # Compute energy distance
    e_value = energy_dist(
        samples.reshape(eval_samples, dim),
        samples_true.reshape(eval_samples, dim),
    )
    e_values.append(e_value.cpu().item())

    # Log-density of the Gaussian base at the starting points, up to an
    # additive constant (which cancels in the softmax below).
    logq = -1 * torch.sum(start**2, dim=1) * (1 / (2 * sigma**2))
    log_p_target = target.log_prob(samples)
    log_weights_flow = logq - weights.squeeze()

    # Compute NLL
    nll_value = -torch.mean(log_p_target)
    nl_values.append(nll_value.cpu().item())

    # Compute ESS: self-normalised importance weights, then 1 / (N * sum w^2).
    log_weights = log_p_target - log_weights_flow
    norm_weights = F.softmax(log_weights, dim=0)
    ess_value = 1 / torch.sum(norm_weights ** 2) / norm_weights.shape[0]
    es_values.append(ess_value.cpu().item())
    print(e_value, nll_value, ess_value)
# Turn the per-round metric lists into arrays for aggregation.
e_values, es_values, nl_values = (
    np.array(metric) for metric in (e_values, es_values, nl_values)
)

# Mean and (population) standard deviation across the evaluation rounds.
e_mean, e_std = e_values.mean(), e_values.std()
es_mean, es_std = es_values.mean(), es_values.std()
nl_mean, nl_std = nl_values.mean(), nl_values.std()

print(f"Energy: Mean = {e_mean}, Std = {e_std}")
print(f"ESS: Mean = {es_mean}, Std = {es_std}")
print(f"NLL: Mean = {nl_mean}, Std = {nl_std}")
