from torch.distributions import MultivariateNormal, TransformedDistribution
import torch
from invert_linear import CondInvLinear, CondScale
from matplotlib import pyplot as pl
from mcmc_kernels import IsotropicRWM, IsotropicMALA, Proposal_Mixture, FullRWM,FullRWMFixed,  IsotropicUniform
from mixture_weight_networks import ReluMixtureWeights
import numpy as np
from mcmc_samplers import MCMCChain, MCMCStep, AuxMCMCChain
from plot_utils import *
from nice_kernels import NICEResample, NICERelu, NICEDoubleRelu
from nice_transforms import NICE
from logistic_regression_distributions  import HeartLR
from loss_functions import *
from arviz.stats.diagnostics import _ess as calc_ess

# Run on the first CUDA device when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else 'cpu')

# Input: saved samples later loaded into `hmc_samples`; output: text summary of the metrics.
chain_save_file, out_file_name = (
    './Results/Samples/heart_samples.npy',
    './Results/heart_abinitio.txt',
)
# Number of independent train/evaluate repeats, evaluation chains per repeat,
# and length of each evaluation chain.
num_repeats = num_chains = 5
chain_len = 20000

# Target distribution and total state dimensionality: the target's predictor
# columns plus `num_aux` auxiliary coordinates, which the script appends after
# them (see the concatenate / `[..., d - num_aux:]` usage in the main loop).
target = HeartLR(device=device)
num_aux = 14
d = num_aux + target.predictors.shape[-1]

# One slot per repeat for each summary statistic; filled in by the main loop and
# averaged across repeats when the results file is written. Each name gets its
# own independent zero array.
(
    min_ess_res, min_ess_sq_res,
    med_ess_res, med_ess_sq_res,
    max_ess_res, max_ess_sq_res,
    rate_res, msjd_res,
) = (np.zeros(num_repeats) for _ in range(8))
def _expand_starts(base_starts, n_samples, dim, n_aux):
    """Replicate each start point and redraw its auxiliary coordinates.

    Every row of ``base_starts`` is repeated ``n_samples`` times. In each copy,
    the last ``n_aux`` coordinates are overwritten in place with fresh
    standard-normal draws. The result is flattened to
    ``(base_starts.shape[0] * n_samples, dim)``, so consecutive rows
    ``k * n_samples ... (k + 1) * n_samples - 1`` all come from start point ``k``.
    """
    expanded = base_starts.unsqueeze(1).repeat(1, n_samples, 1)
    expanded[:, :, dim - n_aux:].normal_()
    return expanded.view(base_starts.shape[0] * n_samples, dim)


for repeat in range(0, num_repeats):
    # Fresh flow transform, proposal model and optimiser for every repeat.
    flow = NICE(d, alternating_mask=True, num_layers=5, layer_depth=4, layer_width=3*d, device=device)

    model = NICEDoubleRelu(d, {'width': 3*d, 'depth': 4}, {'width': 3*d, 'depth': 4},
                           transform=flow, device=device, rwm_bias=0.0)

    comp_opt = torch.optim.Adam(model.parameters, lr=0.0003)

    avg_rate = 0
    num_avg = 0
    num_starts = 8      # distinct start points per training batch
    num_samples = 50    # copies of each start point per batch
    num_resample = 0    # leading start points advanced with MCMCStep instead of redrawn

    # Saved samples with `num_aux` N(0, 1) auxiliary columns appended, shuffled
    # independently for every repeat.
    hmc_samples = np.float32(np.load(chain_save_file))
    hmc_samples = np.float32(np.concatenate(
        (hmc_samples, np.random.normal(size=[len(hmc_samples), num_aux])), axis=1))
    np.random.shuffle(hmc_samples)

    # First training batch uses rows 0 .. num_starts-1.
    start = torch.from_numpy(hmc_samples[0:num_starts]).to(device)
    start = _expand_starts(start, num_samples, d, num_aux)

    for i in range(0, 40000):
        comp_opt.zero_grad()
        model.condition(start)
        samples = model.rsample([1])
        # Fresh copy of `start`, so the in-place edits further down never
        # modify the tensor that was handed to model.condition above.
        start = start.clone()
        loss = (-kl_loss(model, target, start, samples)
                - d * log_acc_rate_loss(model, target, start, samples) * 0.18125).mean()
        loss.backward()
        comp_opt.step()

        if i % 10 == 0:
            with torch.no_grad():
                a_rate = acc_rate_loss(model, target, start, samples)
            print(repeat, i, loss.mean(), a_rate.mean())
            if i > 10000:
                # Running mean of the acceptance rate after the first 10000 steps.
                num_avg += 1
                with torch.no_grad():
                    avg_rate += a_rate.mean()
                print("AVG_RATE: ", avg_rate/num_avg)

        # Keep one row per start point (copy 0 of each group) before building
        # the next batch.
        start = start.view(num_starts, num_samples, d)[:, 0, :].squeeze(1)
        for j in range(0, num_resample):
            new_start, acc_rate = MCMCStep(model, target, start[j])
            start[j] = new_start.detach()
        # The next batch continues after the rows already used. The initial batch
        # was rows 0..num_starts-1, so step i takes rows starting at
        # (i + 1) * num_starts. The offset wraps around the array so that a saved
        # chain with fewer than 40000 * num_starts rows is cycled instead of
        # producing a short slice that cannot be assigned into `start`.
        next_rows = np.arange((i + 1) * num_starts, (i + 2) * num_starts)
        fresh = np.take(hmc_samples, next_rows, axis=0, mode='wrap')
        start[num_resample:] = torch.from_numpy(fresh[num_resample:]).to(device)
        start = _expand_starts(start, num_samples, d, num_aux)

    # Evaluation: one chain per start point taken from the end of the shuffled samples.
    chain = np.zeros([num_chains, chain_len, d])
    for i in range(0, num_chains):
        start = torch.from_numpy(hmc_samples[-i - 1]).to(device)
        batch, acceptance_rate = AuxMCMCChain(model, target, start, num_aux, chain_len)
        print(batch)
        chain[i, :, :] = batch
        # Mean acceptance rate over this repeat's chains.
        rate_res[repeat] += acceptance_rate / float(num_chains)

    print("Acceptance Rate: ", rate_res[repeat])
    # Metrics use only the first d - num_aux coordinates (auxiliaries excluded).
    x = chain[:, :, :d - num_aux]
    msjd_res[repeat] = float(d - num_aux) * np.mean((x[:, 1:] - x[:, :-1])**2)
    ess_sq = [calc_ess(x[:, :, k]**2) for k in range(0, d - num_aux)]
    ess = [calc_ess(x[:, :, k]) for k in range(0, d - num_aux)]
    min_ess_res[repeat] = np.min(ess)
    med_ess_res[repeat] = np.median(ess)
    max_ess_res[repeat] = np.max(ess)
    min_ess_sq_res[repeat] = np.min(ess_sq)
    med_ess_sq_res[repeat] = np.median(ess_sq)
    max_ess_sq_res[repeat] = np.max(ess_sq)
# Write each metric's mean and standard deviation across repeats, one line per metric.
with open(out_file_name, 'w') as outfile:
    outfile.writelines(
        fmt % (np.mean(values), np.std(values))
        for fmt, values in (
            ("Acceptance Rate: %f %f\n", rate_res),
            ("Min ESS(X): %f %f\n", min_ess_res),
            ("Min ESS(X^2): %f %f\n", min_ess_sq_res),
            ("Median ESS(X): %f %f\n", med_ess_res),
            ("Median ESS(X^2): %f %f\n", med_ess_sq_res),
            ("Max ESS(X): %f %f\n", max_ess_res),
            ("Max ESS(X^2): %f %f\n", max_ess_sq_res),
            ("MSJD: %f %f\n", msjd_res),
        )
    )
