import torch
from simulators.oup import oup
import sbibm
from networks.summary_nets import OUPSummary, GL
from utils.get_nn_models import *
from inference.snpe.snpe_c import SNPE_C as SNPE
from inference.base import *
from utils.torchutils import *
from utils.metrics import RMSE
import pickle
import os
import argparse
import utils.metrics as metrics
import random
from sbibm.metrics import c2st


# Module-level default compute device: the first CUDA GPU when one is
# available, otherwise the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

def set_seed(seed: int = 42) -> None:
    """Seed the Python, NumPy and PyTorch RNGs and make cuDNN deterministic.

    Args:
        seed: Seed applied to every generator. Defaults to 42.

    Side effects:
        * seeds ``numpy.random``, ``random``, ``torch`` and ``torch.cuda``
          (the latter seeds the current GPU only);
        * sets ``torch.backends.cudnn.deterministic = True`` and
          ``torch.backends.cudnn.benchmark = False``;
        * writes ``PYTHONHASHSEED`` into ``os.environ``;
        * prints a confirmation line.
    """
    # Order matches the historical behaviour: NumPy, stdlib, torch, torch.cuda.
    for seed_fn in (np.random.seed, random.seed, torch.manual_seed, torch.cuda.manual_seed):
        seed_fn(seed)

    # Without these flags cuDNN may choose autotuned / non-deterministic kernels.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

    # NOTE: the interpreter reads PYTHONHASHSEED only at start-up, so this only
    # influences child processes launched afterwards, not hashing in this one.
    os.environ["PYTHONHASHSEED"] = str(seed)
    print(f"Random seed set as {seed}")

def MMD_unweighted(x, y, lengthscale):
    """Estimate the squared MMD between samples ``x ~ P`` and ``y ~ Q``.

    Uses the biased (V-statistic) estimator with the Gaussian kernel from
    ``kernel_matrix``: diagonal self-similarity terms are included and each
    Gram block is averaged over all of its entries.

    Args:
        x: Samples from P, stacked along dim 0 (m rows).
        y: Samples from Q, stacked along dim 0 (n rows); trailing dims must
            match ``x`` so the two can be concatenated.
        lengthscale: Kernel bandwidth forwarded to ``kernel_matrix``.

    Returns:
        0-d tensor: mean(K_xx) - 2 * mean(K_xy) + mean(K_yy).
    """
    n_x = x.shape[0]
    n_y = y.shape[0]

    # One Gram matrix over the pooled samples, then slice out the blocks.
    pooled = torch.cat((x, y), dim=0)
    gram = kernel_matrix(pooled, pooled, lengthscale)
    k_xx = gram[:n_x, :n_x]
    k_yy = gram[n_x:, n_x:]
    k_xy = gram[:n_x, n_x:]

    term_xx = (1 / n_x ** 2) * k_xx.sum()
    term_xy = (2 / (n_x * n_y)) * k_xy.sum()
    term_yy = (1 / n_y ** 2) * k_yy.sum()
    return term_xx - term_xy + term_yy

def median_heuristic(y):
    """Return a kernel lengthscale via the median heuristic.

    Computes ``sqrt(median(||y_i - y_j||^2 / 2))`` over *all* ordered pairs,
    including the zero-distance pairs ``i == j``, so the estimate is pulled
    slightly toward zero. ``torch.median`` without ``dim`` returns the lower
    of the two middle values when the element count is even.

    Args:
        y: Samples stacked along dim 0 (presumably an (n, d) tensor; a batched
            input would take the median across all batches — confirm intended).

    Returns:
        0-d tensor holding the lengthscale.
    """
    half_sq_dists = torch.cdist(y, y).pow(2).div(2)
    return half_sq_dists.median().sqrt()


def kernel_matrix(x, y, l):
    """Gaussian (RBF) kernel matrix between the rows of ``x`` and ``y``.

    ``K[i, j] = exp(-||x_i - y_j||^2 / (2 * l^2))``

    Args:
        x: First set of points, stacked along dim 0.
        y: Second set of points, stacked along dim 0.
        l: Lengthscale (Python number or 0-d tensor).

    Returns:
        Tensor of kernel values with shape ``(len(x), len(y))``.
    """
    sq_dist = torch.cdist(x, y) ** 2
    gamma = 1 / (2 * l ** 2)
    return torch.exp(-gamma * sq_dist)

def sample_posteriors(posterior, obs, num):
    """Draw ``num`` samples from ``posterior`` conditioned on ``obs``.

    Thin wrapper around ``posterior.sample`` with progress bars turned off,
    which keeps repeated calls inside the evaluation loop quiet.

    Args:
        posterior: Object exposing
            ``sample(sample_shape, x=..., show_progress_bars=...)``.
        obs: Observation to condition on; forwarded as ``x``.
        num: Number of samples to draw.

    Returns:
        Whatever ``posterior.sample`` returns for sample shape ``(num,)``.
    """
    shape = (num,)
    return posterior.sample(shape, x=obs, show_progress_bars=False)
 
if __name__ == "__main__":
    # np.load / np.zeros / np.mean are used below; import NumPy explicitly
    # instead of relying on it leaking out of the star imports at the top.
    import numpy as np

    parser = argparse.ArgumentParser()
    parser.add_argument("--degree", type=float, default=0.1, help="degree of missingness")
    parser.add_argument("--type", type=str, default='mcar')
    args = parser.parse_args()

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # set_seed(42)  # uncomment for reproducible runs
    task = sbibm.get_task('bernoulli_glm_raw')
    # Reference posterior draws for observation 1 (first 1000), used as the
    # ground truth in the MMD comparison below.
    post_samples_final = task.get_reference_posterior_samples(num_observation=1)[:1000]

    # Rebuild the task prior on `device`. MultivariateNormal's second
    # positional parameter is `covariance_matrix`, so the precision matrix
    # MUST be passed by keyword; passing it positionally silently used it as
    # a covariance and gave a different prior from the task's.
    prior = torch.distributions.MultivariateNormal(
        task.prior_params['loc'].to(device),
        precision_matrix=task.prior_params['precision_matrix'].to(device),
    )

    sum_net = GL(input_size=1, hidden_dim=4).to(device)
    neural_posterior = posterior_nn(
        model="maf",
        embedding_net=sum_net,
        hidden_features=20,
        num_transforms=3,
    )

    # Use the device selected above rather than a hard-coded 'cuda', so the
    # script also runs on CPU-only machines.
    inference = SNPE(
        prior=prior,
        density_estimator=neural_posterior,
        types='glm',
        degree=args.degree,
        missing=args.type,
        device=device.type,
    )
    theta = torch.tensor(np.load("missing_data/glm_theta_1000.npy")).to(device)
    x = torch.tensor(np.load("missing_data/glm_x_1000.npy")).to(device)

    # train() returns both the density estimator and the missing-data model
    # that is used for imputation further down.
    density_estimator, missing_model = inference.append_simulations(theta, x.unsqueeze(1)).train(
        distance='none', x_obs=None, beta=0)

    # Ensure the output directory exists before writing artifacts.
    os.makedirs("test", exist_ok=True)
    torch.save(density_estimator, "test/density_estimator_glm.pkl")
    torch.save(sum_net, "test/sum_net_glm.pkl")
    posterior = inference.build_posterior(density_estimator)
    with open("test/posterior_glm.pkl", "wb") as handle:
        pickle.dump(posterior, handle)

    n_samples = 1000
    # Data files are suffixed with the missingness percentage. round() avoids
    # float truncation, e.g. int(0.29 * 100) == 28 instead of 29.
    pct = round(args.degree * 100)
    theta_gt = torch.tensor(np.load("missing_data/glm_theta_obs.npy"))
    obs_sample = torch.tensor(np.load(f"missing_data/glm_obs_zero_{pct}.npy")).to(device)
    mask_sample = torch.tensor(np.load(f"missing_data/glm_obs_mask_{pct}.npy")).to(device)

    # Min-max normalisation constants from the training simulations.
    max_val = x.max()
    min_val = x.min()

    # Median-heuristic lengthscale for the MMD kernel, from the reference draws.
    lengthscale = median_heuristic(post_samples_final.cpu())

    # Imputation is inference-only: no autograd graph is needed.
    with torch.no_grad():
        x_obs_norm = (obs_sample - min_val) / (max_val - min_val)
        pred_mean, _, _ = missing_model(x_obs_norm.squeeze(1), mask_sample.squeeze(1))
        # Map predictions back to data scale, then keep the observed entries
        # where the mask is set (presumably 1 = observed, 0 = missing).
        x_unnorm = pred_mean * (max_val - min_val) + min_val
        x_new = mask_sample * obs_sample + (1 - mask_sample) * x_unnorm

    n_sim = 100
    rmse_zero_npe = np.zeros(n_sim)
    mmd_zero_npe = np.zeros(n_sim)
    reference_samples = post_samples_final.cpu()
    for i in range(n_sim):
        post_samples = sample_posteriors(posterior, x_new.unsqueeze(1), n_samples)
        post_mean = post_samples.mean(dim=0).detach().cpu()
        rmse_zero_npe[i] = torch.sqrt(((post_mean - theta_gt.cpu()) ** 2).mean()).item()
        mmd_zero_npe[i] = MMD_unweighted(post_samples.detach().cpu(), reference_samples, lengthscale).item()

    print(f" RMSE mean={np.mean(rmse_zero_npe)}, MMD mean={np.mean(mmd_zero_npe)}")