from torch.distributions import MultivariateNormal, TransformedDistribution
import torch
from invert_linear import CondInvLinear, CondScale
from matplotlib import pyplot as pl
from mcmc_kernels import IsotropicRWM, IsotropicMALA, Proposal_Mixture, FullRWM,FullRWMFixed,  IsotropicUniform
from mixture_weight_networks import ReluMixtureWeights
import numpy as np
from mcmc_samplers import MCMCChain
from plot_utils import *
from nice_kernels import NICEResample, NICERelu, NICEDoubleRelu
from nice_transforms import NICE
from loss_functions import *
from arviz.stats.diagnostics import _ess as calc_ess

# Run on the first CUDA device when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else 'cpu')

# Destination of the final "mean std" summary statistics.
out_file_name = './Results/multi_scheme_l8_abinitio.txt'

# Evaluation budget: accepted repeats, MCMC chains per repeat, steps per chain.
num_repeats, num_chains, chain_len = 5, 5, 1000

# Dimension of the sampling space (last axis of the recorded chains).
d = 2
# NOTE(review): not read anywhere else in this script.
req_grad = False

# Build four FullRWM components (constructed in order dist1..dist4).
dist1, dist2, dist3, dist4 = (FullRWM(d, bias=True, device=device) for _ in range(4))

with torch.no_grad():
    # Clear each component's `transform.bias`, then offset one coordinate by +/-4:
    # dist1 -> +4 on axis 0, dist2 -> -4 on axis 0,
    # dist3 -> +4 on axis 1, dist4 -> -4 on axis 1.
    # Presumably this places the mixture components at (+-4, 0) and (0, +-4);
    # confirm against FullRWM's semantics.
    for component in (dist1, dist2, dist3, dist4):
        component.transform.bias.data *= 0
    offsets = ((0, 4), (0, -4), (1, 4), (1, -4))
    for component, (axis, shift) in zip((dist1, dist2, dist3, dist4), offsets):
        component.transform.bias[axis] += shift

dist_list = [dist1, dist2, dist3, dist4]
# The target is an equal-structure mixture of the four shifted components.
target = Proposal_Mixture(list(dist_list), device=device)

# Per-repeat summary statistics; slot `repeat` is (re)filled until that run is accepted.
min_ess_res = np.zeros(num_repeats)
min_ess_sq_res = np.zeros(num_repeats)
rate_res = np.zeros(num_repeats)
msjd_res = np.zeros(num_repeats)

# Each repeat trains a fresh NICE-based proposal from scratch, then evaluates it with
# `num_chains` independent MCMC chains. A repeat whose mean acceptance rate is
# <= 0.55 is discarded and re-run into the same slot.
# NOTE(review): there is no cap on retries, so this loops indefinitely if training
# never clears the acceptance threshold.
repeat = 0
while repeat < num_repeats:
    flow = NICE(d, alternating_mask=True, num_layers=8, layer_depth=3, layer_width=6, device=device)

    with torch.no_grad():
        # Offset the first forward-network layer's initial weights by +2.0.
        flow.forward_network[0].weights += 2.0
    model = NICEDoubleRelu(d, {'width': 8 * d, 'depth': 4}, {'width': 8 * d, 'depth': 4},
                           transform=flow, device=device, rwm_bias=0.0)
    # NOTE(review): `model.parameters` is passed uncalled; this only works if
    # NICEDoubleRelu exposes `parameters` as an iterable attribute rather than the
    # torch.nn.Module method -- confirm.
    comp_opt = torch.optim.Adam(model.parameters, lr=0.0003)
    avg_rate = 0
    num_avg = 0
    num_starts = 8     # start points drawn from the target per optimisation step
    num_samples = 100  # proposals drawn per start point

    for i in range(0, 20000):
        comp_opt.zero_grad()
        start = target.rsample([num_starts]).detach()
        model.condition(start)
        samples = model.rsample([num_samples])
        # Repeat each start point num_samples times, start-major, giving a
        # (num_starts * num_samples, d) tensor to pair with `samples`.
        # NOTE(review): assumes model.rsample returns samples flattened in the same
        # start-major order -- confirm against NICEDoubleRelu.
        start = start.unsqueeze(1).repeat(1, num_samples, 1)
        start = start.view(num_starts * num_samples, d)
        # Objective: negative kl_loss minus a d-scaled, 0.18125-weighted
        # log_acc_rate_loss term (weight used as-is; origin not stated here).
        loss = (-kl_loss(model, target, start, samples)
                - d * log_acc_rate_loss(model, target, start, samples) * 0.18125).mean()
        loss.backward()
        comp_opt.step()

        if i % 10 == 0:
            with torch.no_grad():
                a_rate = acc_rate_loss(model, target, start, samples)
            print(i, loss.mean(), a_rate.mean())
            # Running average of the logged rate over the second half of training.
            if i > 10000:
                num_avg += 1
                with torch.no_grad():
                    avg_rate += a_rate.mean()
                print("AVG_RATE: ", avg_rate / num_avg)

    # Evaluate the trained proposal with independent chains started from target draws.
    chain = np.zeros([num_chains, chain_len, d])
    for i in range(0, num_chains):
        start = target.rsample([1])
        batch, acceptance_rate = MCMCChain(model, target, start, chain_len)
        chain[i, :, :] = batch
        rate_res[repeat] += acceptance_rate / float(num_chains)

    # Each coordinate slice chain[:, :, k] is a (num_chains, chain_len) array, i.e.
    # arviz's (chain, draw) layout; report the worst (minimum) ESS over coordinates.
    min_ess_res[repeat] = min(calc_ess(chain[:, :, k]) for k in range(d))
    # ESS of the squared coordinates: square the draws *before* computing the ESS.
    min_ess_sq_res[repeat] = min(calc_ess(chain[:, :, k] ** 2) for k in range(d))
    # Mean squared jump distance: squared Euclidean step length, averaged over
    # all chains and transitions.
    msjd_res[repeat] = np.mean(np.sum(np.diff(chain, axis=1) ** 2, axis=2))

    if rate_res[repeat] <= 0.55:
        # Reject this run: clear its accumulated rate and retry the same slot.
        rate_res[repeat] = 0.0
    else:
        repeat += 1

from pathlib import Path

# Summarise each metric as "mean std" (np.std, ddof=0) across the accepted repeats,
# echo the lines to stdout, and persist the same lines to `out_file_name`.
summary_lines = [
    "Acceptance Rate: %f %f" % (np.mean(rate_res), np.std(rate_res)),
    "Min ESS(X): %f %f" % (np.mean(min_ess_res), np.std(min_ess_res)),
    "Min ESS(X^2): %f %f" % (np.mean(min_ess_sq_res), np.std(min_ess_sq_res)),
    "MSJD: %f %f" % (np.mean(msjd_res), np.std(msjd_res)),
]
for line in summary_lines:
    print(line)

# Ensure the output directory exists so a long run does not crash at the final write.
Path(out_file_name).parent.mkdir(parents=True, exist_ok=True)
with open(out_file_name, 'w', encoding='utf-8') as outfile:
    outfile.writelines(line + "\n" for line in summary_lines)

