import torch
import matplotlib.pyplot as plt
import numpy as np
from torch import nn
from torchdiffeq import odeint
import torch.nn.functional as F
from models import *
import time
from tqdm import tqdm
import copy
import pandas as pd
from geomloss import SamplesLoss
from eval_utils import *
from fab.utils.plotting import plot_contours, plot_marginal_pair
from fab.target_distributions.gmm import GMM


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
dim = 2
dimension = dim
n_mixes = 40
loc_scaling = 40.0
log_var_scaling = 1.0

# Target distribution (fab GMM benchmark). GPU use follows the detected
# device instead of being hard-coded, so the script also runs on CPU-only
# machines.
gmm = GMM(dim=dim, n_mixes=n_mixes,
          loc_scaling=loc_scaling, log_var_scaling=log_var_scaling,
          use_gpu=(device.type == 'cuda'),
          true_expectation_estimation_n_samples=int(1e5))
eval_samples = 50000

gmm.to(device)
sigma = 1.0  # scale of the Gaussian initial states fed to the flow

# psi: scalar-output network taking (x, t) (time_varying=True);
# schedule: scalar-output network taking t alone. Both are restored from
# saved checkpoints below.
psi = MLP2(dim=dim, out_dim=1, time_varying=True, w=256).to(device)
schedule = MLP3(dim=1, out_dim=1, time_varying=False, w=256).to(device)

# map_location lets checkpoints saved on a GPU load on whichever device is
# actually in use.
psi.load_state_dict(torch.load("nets_gmm_GF/psi.pt", map_location=device))
schedule.load_state_dict(torch.load("nets_gmm_GF/schedule.pt", map_location=device))
beta_min = 0.1
beta_max = 20


def beta(t):
    """Return half of the linear schedule ``beta_min + (beta_max - beta_min) * t``.

    ``t`` may be a Python number or a tensor; the arithmetic is elementwise,
    so a tensor input yields a tensor of the same shape.
    """
    slope = beta_max - beta_min
    return 0.5 * (beta_min + slope * t)





class torch_wrapper(torch.nn.Module):
    """Turn the learned network into an ODE vector field ``f(t, x)`` for odeint.

    The field is ``-beta(t) * (grad_x Phi(x, t) + x)`` where

        Phi(x, t) = t * model(x, t) + (t * schedule(t) + 1 - t) * gmm.log_prob(x)

    It relies on the module-level ``schedule``, ``gmm``, ``dim`` and ``beta``.
    (Looks like a VP-style probability-flow drift with a linear beta
    schedule — unverified.)
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, t, x, *args, **kwargs):
        """Evaluate the drift at time ``t`` for a batch ``x`` of shape (N, dim).

        ``t`` is expected to be a scalar tensor, as supplied by odeint.
        Extra positional/keyword arguments are accepted and ignored.
        Note: ``x.requires_grad_(True)`` mutates the input tensor in place.
        """
        n_points = x.shape[0]
        # Gradients wrt x are needed even when the caller runs under no_grad.
        with torch.set_grad_enabled(True):
            x.requires_grad_(True)
            # Broadcast the scalar time to a (N, 1) column.
            t_col = t.repeat(n_points)[:, None]

            net_term = self.model(torch.cat([x, t_col], 1)).reshape(n_points, 1)
            mix_coeff = t_col * schedule(t_col).reshape(n_points, 1) + (1 - t_col)
            target_term = (gmm.log_prob(x)).reshape(n_points, 1)
            potential = t_col * net_term + mix_coeff * target_term
            # create_graph keeps the gradient differentiable for callers that
            # backpropagate through the drift.
            (grad_potential,) = torch.autograd.grad(torch.sum(potential), x, create_graph=True)

        drift = grad_potential.reshape(n_points, dim) + x.reshape(n_points, dim)
        return -beta(t_col) * drift
# Evaluation: draw samples with the learned flow and compare them to the GMM.
# For each of `n_runs` repetitions this records
#   * the energy distance between flow samples and fresh GMM samples,
#   * the normalised ESS (ESS / N) of self-normalised importance weights,
#   * "NLL": the mean negative GMM log-density of the flow samples.
energy_dist = SamplesLoss(loss="energy")
flow_class = cnf_sample(torch_wrapper(psi))
samples_true = gmm.sample((eval_samples,))
print(samples_true.shape)

n_runs = 10        # independent repetitions used for mean / std
batch_size = 5000  # samples integrated per odeint call (bounds solver memory)
# Integration grid from t=1 down to t=0.001; only the end state is used.
# (Presumably stops short of 0 for numerical stability — confirm.)
t_span = torch.linspace(1, 0.001, 2).to(device)

e_values = []
es_values = []
nl_values = []

for _ in range(n_runs):
    samples = torch.zeros_like(samples_true, device=device)
    weights = torch.zeros((eval_samples, 1), device=device)
    start = torch.randn((eval_samples, dim), device=device) * sigma

    # Integrate in batches. The last batch is shorter when eval_samples is not
    # a multiple of batch_size, so every row of `samples` is always filled.
    for lo in range(0, eval_samples, batch_size):
        hi = min(lo + batch_size, eval_samples)
        logp_init = torch.zeros(hi - lo, 1, dtype=torch.float32, device=device)
        with torch.no_grad():
            z_t, logs = odeint(flow_class,
                               (start[lo:hi, :], logp_init),
                               t_span,
                               atol=1e-4,
                               rtol=1e-4,
                               method='dopri5')
        # Keep only the state at the final time t_span[-1].
        weights[lo:hi, :] = logs[-1].detach()
        samples[lo:hi, :] = z_t[-1].detach()

    samples_true = gmm.sample((eval_samples,))  # fresh reference samples from the true GMM

    # Energy distance between flow samples and true GMM samples.
    e_value = energy_dist(samples.reshape(eval_samples, dim),
                          samples_true.reshape(eval_samples, dim))
    e_values.append(e_value.cpu().item())

    # Gaussian base log-density of the initial states, up to an additive
    # constant (the constant cancels in the self-normalised weights below).
    logq = -torch.sum(start ** 2, dim=1) / (2 * sigma ** 2)
    # Flatten to 1-D so the subtraction below can never broadcast to (N, N)
    # if log_prob returns a column vector.
    log_p_target = gmm.log_prob(samples).reshape(-1)
    # NOTE(review): assumes cnf_sample's second state accumulates the log-det
    # term such that `logq - weights` is the flow's log-density of `samples`
    # — confirm the sign convention in models.cnf_sample.
    log_q_flow = logq - weights.squeeze(1)

    # "NLL": mean negative target log-density of the generated samples.
    nll_value = -torch.mean(log_p_target)
    nl_values.append(nll_value.cpu().item())

    # Normalised ESS (ESS / N, in [1/N, 1]) as in Stimper et al.
    log_weights = log_p_target - log_q_flow
    norm_weights = F.softmax(log_weights, dim=0)
    ess_value = 1 / torch.sum(norm_weights ** 2) / norm_weights.shape[0]
    es_values.append(ess_value.cpu().item())

e_values = np.array(e_values)
es_values = np.array(es_values)
nl_values = np.array(nl_values)

e_mean, e_std = np.mean(e_values), np.std(e_values)
es_mean, es_std = np.mean(es_values), np.std(es_values)
nl_mean, nl_std = np.mean(nl_values), np.std(nl_values)

print(f"Energy: Mean = {e_mean}, Std = {e_std}")
print(f"ESS: Mean = {es_mean}, Std = {es_std}")
print(f"NLL: Mean = {nl_mean}, Std = {nl_std}")


