import torch
import matplotlib.pyplot as plt
import numpy as np
from torch import nn
from torchdiffeq import odeint
import torch.nn.functional as F
from models import *
import time
from tqdm import tqdm
import copy
import pandas as pd
from geomloss import SamplesLoss
from eval_utils import *
from fab.utils.plotting import plot_contours, plot_marginal_pair
from fab.target_distributions.gmm import GMM


# ---------------------------------------------------------------------------
# Experiment setup: GMM target distribution and the trained velocity network.
# ---------------------------------------------------------------------------
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
dim = 2
dimension = dim  # second name for `dim`, kept in case other code refers to it
n_mixes = 40
loc_scaling = 40.0  # scale of the problem (changes how far apart the modes of each Gaussian component will be)
log_var_scaling = 1.0  # variance of each Gaussian
torch.manual_seed(0)  # seed of 0 for GMM problem
# Only ask the target for GPU placement when a GPU is actually present, so the
# target lives on the same device as `device` above.
gmm = GMM(dim=dim, n_mixes=n_mixes,
          loc_scaling=loc_scaling, log_var_scaling=log_var_scaling,
          use_gpu=torch.cuda.is_available(),
          true_expectation_estimation_n_samples=int(1e5))
eval_samples = 50000  # number of samples drawn per evaluation repeat

# Pass the device as a plain string ("cuda" / "cpu"), the same kind of value
# the original call used, instead of hard-coding "cuda".
gmm.to(device.type)

# Std of the zero-mean Gaussian the start points are drawn from
# (`torch.randn(...) * sigma` in the evaluation loop); it also selects the
# checkpoint file below.
sigma = 20.0
velo = MLP2(dim=dim, out_dim=dim, time_varying=True, w=256).to(device)
# Checkpoint under evaluation (original note: "eval for linear and learned").
# map_location lets a checkpoint that was saved on GPU load on a CPU-only host.
velo.load_state_dict(
    torch.load(f"nets_gmm_linear/velo_{sigma}.pt", map_location=device))



class TorchWrapper(nn.Module):
    """Expose a network that consumes ``[x, t]`` rows through an ``f(t, x)`` call.

    The wrapped ``model`` receives a single tensor built by appending the time
    value as one extra trailing column to ``x``.
    """

    def __init__(self, model):
        """Store ``model``, the network that is called on the ``[x, t]`` input."""
        super().__init__()
        self.model = model

    def forward(self, t, x):
        """Evaluate the wrapped model at time ``t`` for every row of ``x``.

        Args:
            t: time tensor; it is tiled ``x.shape[0]`` times, so a scalar
                (or one-element) tensor yields one time entry per row.
            x: batch of inputs, one row per sample.

        Returns:
            The output of ``self.model`` on ``x`` with the time column appended.

        Note:
            ``x`` is flagged ``requires_grad`` in place, and gradient tracking
            is switched on for the model call even when the caller is inside
            ``torch.no_grad()`` -- presumably so that downstream code can
            differentiate the output with respect to ``x``.
        """
        with torch.enable_grad():
            x.requires_grad_(True)
            n_rows = x.shape[0]
            # One time value per row, shaped (n_rows, 1) for concatenation.
            t_column = t.repeat(n_rows).unsqueeze(1)
            net_input = torch.cat((x, t_column), dim=1)
            v = self.model(net_input)
        return v
# Sample-based energy distance used to compare generated and true samples.
energy_dist = SamplesLoss(loss="energy")

# Initial draw from the target. The evaluation loop uses this tensor as the
# template for its output buffer and re-draws it on every repeat, so the draw
# is kept here (it also advances the random number generator).
samples_true = gmm.sample((eval_samples,))  # .cpu().data.numpy()

# Right-hand side handed to `odeint` in the evaluation loop, built around the
# trained velocity network.
flow_class = cnf_sample(TorchWrapper(velo))

# NOTE: the `samples` / `weights` buffers are intentionally not allocated here;
# the evaluation loop creates fresh ones at the start of every repeat, so a
# pre-loop allocation would only be overwritten.

# Per-repeat metric values, filled by the evaluation loop.
e_values = []   # energy distance
es_values = []  # effective sample size
nl_values = []  # negative mean target log-probability of generated samples



# ---------------------------------------------------------------------------
# Evaluation: repeat sampling several times and record three metrics per
# repeat (energy distance, negative mean target log-prob, effective sample
# size).
# ---------------------------------------------------------------------------
n_repeats = 10     # independent repeats; mean/std over them are reported
batch_size = 5000  # start points integrated per `odeint` call
# Integration interval [0, 1]; only the two end points are requested.
t_span = torch.linspace(0, 1, 2).to(device)

for rep in range(n_repeats):
    samples = torch.zeros_like(samples_true, device=device)
    weights = torch.zeros((eval_samples, 1), device=device)
    start = torch.randn((eval_samples, dim), device=device) * sigma

    # Integrate in chunks. The bounds are derived from `eval_samples`, so any
    # sample count works, including one that is not a multiple of
    # `batch_size`.
    for lo in range(0, eval_samples, batch_size):
        hi = min(lo + batch_size, eval_samples)
        with torch.no_grad():
            z_t, logs = odeint(
                flow_class,
                (start[lo:hi, :],
                 torch.zeros(hi - lo, 1, dtype=torch.float32, device=device)),
                t_span,
                atol=1e-4,
                rtol=1e-4,
                method='dopri5')

            # Keep only the state at the final time point.
            weights[lo:hi, :] = logs[-1].detach()
            samples[lo:hi, :] = z_t[-1].detach()

    samples_true = gmm.sample((eval_samples,))  # Sampling from true GMM distribution

    # Compute energy distance
    e_value = energy_dist(samples.reshape(eval_samples, dim),
                          samples_true.reshape(eval_samples, dim))
    e_values.append(e_value.cpu().item())

    # Log-density of the start points under N(0, sigma^2 I), without the
    # additive normalising constant. The constant is the same for every sample
    # and cancels in the softmax below, which is the only place it would enter.
    log_q_start = -1 * torch.sum(start ** 2, dim=1) * (1 / (2 * sigma ** 2))
    log_p_target = gmm.log_prob(samples)
    # Start log-density combined with the second state integrated by `odeint`.
    # NOTE(review): the sign convention of that term is set by `cnf_sample`,
    # which is not visible here -- confirm.
    log_q_flow = log_q_start - weights.squeeze(1)

    # Compute NLL
    nll_value = -torch.mean(log_p_target)
    nl_values.append(nll_value.cpu().item())

    # Compute ESS as in stimper et al:
    # self-normalised importance weights w_i, ESS = 1 / (N * sum_i w_i^2).
    log_w = log_p_target - log_q_flow
    w_normalised = F.softmax(log_w, dim=0)
    ess_value = 1 / torch.sum(w_normalised ** 2) / w_normalised.shape[0]
    es_values.append(ess_value.cpu().item())

# Turn the per-repeat metric lists into arrays.
e_values, es_values, nl_values = map(np.array, (e_values, es_values, nl_values))

# Mean and (population, ddof=0) standard deviation of each metric.
(e_mean, e_std), (es_mean, es_std), (nl_mean, nl_std) = (
    (metric.mean(), metric.std())
    for metric in (e_values, es_values, nl_values)
)

print(f"Energy: Mean = {e_mean}, Std = {e_std}")
print(f"ESS: Mean = {es_mean}, Std = {es_std}")
print(f"NLL: Mean = {nl_mean}, Std = {nl_std}")
