import torch
import matplotlib.pyplot as plt
import numpy as np
from torch import nn
import torch.distributions as D
from torchdiffeq import odeint
import ot
import torch.nn.functional as F
from models import *
import time
import copy
import pandas as pd
from geomloss import SamplesLoss
from eval_utils import *
from fab.utils.plotting import plot_contours, plot_marginal_pair
from fab.target_distributions.many_well import ManyWellEnergy



device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
dim = 8
dimension = dim
# Many-well target.  use_gpu follows the selected `device` instead of being
# hard-coded to True, so the target's placement agrees with the rest of the
# script when CUDA is unavailable.
target = ManyWellEnergy(dim, a=-0.5, b=-6, use_gpu=(device.type == 'cuda'), normalised=True)
target.to(device)
# Number of samples drawn per evaluation run.
eval_samples = 50000


# Time-conditioned velocity network (time_varying=True; TorchWrapper below
# feeds it the state with one extra time column).
velo = MLP4(dim=dim, out_dim=dim, time_varying=True, w=512).to(device)


# Std of the Gaussian the flow's start points are drawn from.
sigma = 1.0
# map_location lets a checkpoint saved from GPU tensors load on a CPU-only host.
velo.load_state_dict(torch.load("nets_MW_linear/velo.pt", map_location=device))
# Endpoints of the linear schedule evaluated by beta().
beta_min = 0.1
beta_max = 20


def beta(t):
    """Return ``0.5 * (beta_min + (beta_max - beta_min) * t)``.

    Uses the module-level ``beta_min`` / ``beta_max`` (looked up at call
    time). ``t`` may be a Python number or a tensor/array, in which case the
    result is computed elementwise. Not called anywhere in the code shown.
    """
    span = beta_max - beta_min
    return 0.5 * (beta_min + span * t)



class TorchWrapper(nn.Module):
    """Expose ``model`` through the ``forward(t, x)`` signature of ODE solvers.

    The wrapped model is called on the state with the current time appended
    as one extra trailing column: ``model(cat([x, t * ones(N, 1)], dim=1))``.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, t, x):
        # Gradient tracking is forced on (even inside an outer no_grad) and x
        # is flagged as requiring grad, presumably so the caller can
        # differentiate the output w.r.t. x — TODO confirm in cnf_sample.
        with torch.set_grad_enabled(True):
            x.requires_grad_(True)
            # Broadcast the scalar time to one column per row of x.
            t_col = t.repeat(x.shape[0]).unsqueeze(1)
            velocity = self.model(torch.cat((x, t_col), dim=1))
        return velocity
# Energy-distance loss between two point clouds (geomloss).
energy_dist = SamplesLoss(loss="energy")
# Initial target draw; its shape/dtype/device also seed the buffers below.
samples_true = target.sample((eval_samples,))

# ODE right-hand side built from the wrapped velocity net via cnf_sample
# (comes from the star imports; its exact semantics are not shown here).
flow_class = cnf_sample(TorchWrapper(velo))
samples = torch.zeros_like(samples_true, device=device)
weights = torch.zeros(eval_samples, 1, device=device)

# Per-run histories: energy distance, ESS, NLL.
e_values, es_values, nl_values = [], [], []


# n_runs independent repetitions: each run pushes eval_samples fresh start
# points through the flow in chunks of batch_size, then scores the result
# against a fresh draw from the target.
n_runs = 10
# Start points per odeint call.  The last chunk absorbs any remainder, so
# eval_samples is no longer required to equal 10 * 5000.
batch_size = 5000
# Only the endpoints t=0 and t=1 of each trajectory are requested.
t_span = torch.linspace(0, 1, 2).to(device)

for p in range(n_runs):
    samples = torch.zeros_like(samples_true, device=device)
    weights = torch.zeros((eval_samples, 1), device=device)
    # Start points x0 ~ N(0, sigma^2 I).
    start = torch.randn((eval_samples, dim), device=device) * sigma
    for lo in range(0, eval_samples, batch_size):
        hi = min(lo + batch_size, eval_samples)
        with torch.no_grad():
            # Integrated state is (x, log term); the log term starts at zero.
            z_t, logs = odeint(
                flow_class,
                (start[lo:hi, :], torch.zeros(hi - lo, 1, dtype=torch.float32, device=device)),
                t_span,
                atol=1e-4,
                rtol=1e-4,
                method='dopri5',
            )
            # Index -1 is the state at the final time t=1.
            weights[lo:hi, :] = logs[-1].detach()
            samples[lo:hi, :] = z_t[-1].detach()
    samples_true = target.sample((eval_samples,))

    # Energy distance between flow samples and target samples.
    e_value = energy_dist(samples.reshape(eval_samples, dim), samples_true.reshape(eval_samples, dim))
    # .item() for consistency with the other metrics below.
    e_values.append(e_value.item())

    # Gaussian log-density of x0 without its normalising constant; that
    # constant is a uniform shift and cancels in the softmax-normalised ESS.
    logq = -1 * torch.sum(start ** 2, dim=1) * (1 / (2 * sigma ** 2))
    log_p_target = target.log_prob(samples)
    # Flow log-density at the generated samples, assuming the integrated log
    # term satisfies log q1 = log q0 - logs — TODO confirm cnf_sample's sign.
    log_q_flow = logq - weights.squeeze(-1)

    # Reported "NLL": the Monte-Carlo cross-entropy -E_{x~flow}[log p_target(x)]
    # over the flow samples (not the flow's likelihood of target data).
    nll_value = -torch.mean(log_p_target)
    nl_values.append(nll_value.cpu().item())

    # Normalised effective sample size (in [0, 1]) from self-normalised
    # importance weights p_target / q_flow.
    norm_weights = F.softmax(log_p_target - log_q_flow, dim=0)
    ess_value = 1 / torch.sum(norm_weights ** 2) / norm_weights.shape[0]
    es_values.append(ess_value.cpu().item())

# Summarise the per-run metrics as mean and population std (ddof=0).
e_values, es_values, nl_values = (np.array(hist) for hist in (e_values, es_values, nl_values))

e_mean, e_std = e_values.mean(), e_values.std()
es_mean, es_std = es_values.mean(), es_values.std()
nl_mean, nl_std = nl_values.mean(), nl_values.std()

print(f"Energy: Mean = {e_mean}, Std = {e_std}")
print(f"ESS: Mean = {es_mean}, Std = {es_std}")
print(f"NLL: Mean = {nl_mean}, Std = {nl_std}")
