from torch.distributions import MultivariateNormal, TransformedDistribution
import torch
from invert_linear import CondInvLinear, CondScale
from matplotlib import pyplot as pl
from mcmc_kernels import IsotropicRWM, IsotropicMALA, Proposal_Mixture, FullRWM,FullRWMFixed,  IsotropicUniform
from mixture_weight_networks import ReluMixtureWeights
import numpy as np
from mcmc_samplers import GibbsStep, GibbsChain
from plot_utils import *
from nice_kernels import NICEResample, NICERelu, NICEDoubleRelu
from nice_transforms import NICE
from logistic_regression_distributions  import HeartLR
import hamiltorch
from arviz.stats.diagnostics import _ess as calc_ess

# Prefer the first CUDA device when one is available; otherwise run on CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else 'cpu')

# Experiment sizes.
num_repeats = 5    # independent repetitions of the whole experiment
num_chains = 5     # chains run per repeat
chain_len= 20000   # samples kept per chain (the sampler is also given 1000 burn-in steps)

# Input: previously saved samples, used below as a pool of chain starting points.
chain_save_file='./Results/Samples/heart_samples.npy'
# Output: text file receiving the "<label>: <mean> <std>" summary lines.
out_file_name='./Results/heart_hmc.txt'

# Target distribution; d is the last dimension of its predictor matrix,
# presumably the number of regression coefficients — TODO confirm.
target = HeartLR(device=device)
d = target.predictors.shape[-1]

# Pool of previously saved samples from which chain starting points are drawn.
# The first 1000 stored samples are dropped (treated as burn-in) so chains are
# not initialised from the transient part of the saved run. The pool is loaded
# once and reshuffled at the start of every repeat; the earlier per-repeat
# reload from disk silently re-included the discarded burn-in samples.
hmc_samples = np.float32(np.load(chain_save_file)[1000:])
if hmc_samples.shape[0] < num_chains:
    raise ValueError(
        "need at least %d saved samples after burn-in to start %d chains, got %d"
        % (num_chains, num_chains, hmc_samples.shape[0])
    )

# Per-repeat summary statistics; their mean/std across repeats is written out below.
min_ess_res = np.zeros(num_repeats)
min_ess_sq_res = np.zeros(num_repeats)
med_ess_res = np.zeros(num_repeats)
med_ess_sq_res = np.zeros(num_repeats)
max_ess_res = np.zeros(num_repeats)
max_ess_sq_res = np.zeros(num_repeats)
rate_res = np.zeros(num_repeats)
msjd_res = np.zeros(num_repeats)

for repeat in range(num_repeats):
    # Fresh random starting points for this repeat (in-place shuffle of the pool).
    np.random.shuffle(hmc_samples)
    chain = np.zeros([num_chains, chain_len, d])
    for i in range(num_chains):
        # torch.tensor copies the row, so the sampler never shares memory with
        # the pool that is shuffled in place on the next repeat (from_numpy +
        # .to('cpu') would alias it).
        start = torch.tensor(hmc_samples[-i - 1], device=device)
        batch = hamiltorch.sample(
            target.log_prob,
            start,
            num_samples=(chain_len + 1000),
            burn=1000,
            sampler=hamiltorch.Sampler.HMC_NUTS,
            desired_accept_rate=0.65,
            store_on_GPU=True,
        )
        chain[i, :, :] = torch.stack(batch).cpu().numpy()

    # Consecutive-sample differences, shape (num_chains, chain_len - 1, d).
    jumps = chain[:, 1:, :] - chain[:, :-1, :]
    # Mean squared jump distance: average squared Euclidean length of one step.
    msjd_res[repeat] = float(d) * np.mean(jumps ** 2)
    # Empirical acceptance rate: fraction of transitions where the state moved.
    # Assumes hamiltorch's HMC_NUTS sampler is Metropolis-corrected, so a
    # rejected proposal repeats the previous state — TODO confirm. Before this
    # fix rate_res was never assigned and the report always showed 0.
    rate_res[repeat] = np.mean(np.any(jumps != 0.0, axis=-1))

    # Per-coordinate ESS of X and of X^2, each computed across all chains.
    ess_sq = [calc_ess(chain[:, :, j] ** 2) for j in range(d)]
    ess = [calc_ess(chain[:, :, j]) for j in range(d)]
    min_ess_res[repeat] = np.min(ess)
    med_ess_res[repeat] = np.median(ess)
    max_ess_res[repeat] = np.max(ess)
    min_ess_sq_res[repeat] = np.min(ess_sq)
    med_ess_sq_res[repeat] = np.median(ess_sq)
    max_ess_sq_res[repeat] = np.max(ess_sq)

# Each statistic is reported as "<label>: <mean> <std>", aggregated over the
# repeats; the table keeps the output order of the original explicit writes.
summary_lines = (
    ("Acceptance Rate: %f %f\n", rate_res),
    ("Min ESS(X): %f %f\n", min_ess_res),
    ("Min ESS(X^2): %f %f\n", min_ess_sq_res),
    ("Median ESS(X): %f %f\n", med_ess_res),
    ("Median ESS(X^2): %f %f\n", med_ess_sq_res),
    ("Max ESS(X): %f %f\n", max_ess_res),
    ("Max ESS(X^2): %f %f\n", max_ess_sq_res),
    ("MSJD: %f %f\n", msjd_res),
)
with open(out_file_name, 'w') as outfile:
    for line_fmt, per_repeat in summary_lines:
        outfile.write(line_fmt % (np.mean(per_repeat), np.std(per_repeat)))

