from torch.distributions import MultivariateNormal, TransformedDistribution
import torch
from invert_linear import CondInvLinear, CondScale
from matplotlib import pyplot as pl
from mcmc_kernels import IsotropicRWM, IsotropicMALA, Proposal_Mixture, FullRWM,FullRWMFixed,  IsotropicUniform,IsotropicRWMFixed
from mixture_weight_networks import ReluMixtureWeights
import numpy as np
from mcmc_samplers import MCMCChain
from loss_functions import *
from arviz.stats.diagnostics import _ess as calc_ess

# ---- Experiment configuration -----------------------------------------------
# Summary statistics are written to this file after all repeats have run.
out_file_name = './Results/position_dependent_msjd.txt'
num_repeats = 5    # independent train-then-evaluate repetitions
num_chains = 5     # evaluation chains run per repetition
chain_len = 1000   # states recorded per evaluation chain

d = 2              # dimension of the state space (last axis of the chain array)
req_grad = False   # not used by the visible script

# Use the first CUDA device when one is available, otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else 'cpu')
# ---- Target distribution ----------------------------------------------------
# Four FullRWMFixed components. Their bias vectors are zeroed and then shifted
# by +4/-4 along coordinate 0 (components 1, 2) or coordinate 1 (components
# 3, 4). Their mixture is used as the sampling target.
dist_list = [FullRWMFixed(d, bias=True, device=device) for _ in range(4)]
dist1, dist2, dist3, dist4 = dist_list
with torch.no_grad():
    for component in dist_list:
        component.transform.bias.data *= 0
    for component, (axis, shift) in zip(dist_list, [(0, 4), (0, -4), (1, 4), (1, -4)]):
        component.transform.bias[axis] += shift
target = Proposal_Mixture([dist1, dist2, dist3, dist4], device=device)

# Per-repeat metric slots (float64, one entry per repeat), filled in by the
# experiment loop.
min_ess_res, min_ess_sq_res, rate_res, msjd_res = (
    np.zeros(num_repeats) for _ in range(4)
)
# Each repeat trains a fresh learnable mixture proposal against `target`, then
# evaluates it by running `num_chains` MCMC chains of length `chain_len`.
for repeat in range(num_repeats):
    # ---- Build the learnable proposal --------------------------------------
    num_comp = 4
    prop_list = [FullRWMFixed(d, bias=True, device=device) for _ in range(num_comp)]

    # Network producing the mixture weights (hidden width 8*d, 4 layers/units
    # per the ReluMixtureWeights signature -- TODO confirm argument meaning).
    model_weights = ReluMixtureWeights(num_comp, d, 8 * d, 4, device=device)
    model = Proposal_Mixture(prop_list, model_weights, device=device)
    with torch.no_grad():
        # Start proposal component k as a copy of target component k, cycling
        # through dist_list if there are more proposal components.
        for index, prop in enumerate(prop_list):
            ref = dist_list[index % len(dist_list)].transform
            prop.transform.weight *= 0
            prop.transform.weight += ref.weight.data
            prop.transform.bias *= 0
            prop.transform.bias += ref.bias.data

    # Separate optimisers for the mixture-weight parameters and the components.
    weight_opt = torch.optim.Adam(model.mixture_parameters, lr=0.0003)
    comp_opt = torch.optim.Adam(model.component_parameters, lr=0.0003)

    avg_rate = 0
    num_avg = 0
    for i in range(40000):
        weight_opt.zero_grad()
        comp_opt.zero_grad()
        start = target.rsample([1]).detach()
        model.condition(start)
        samples, weights = model.uniform_rsample([50])
        # Acceptance term, computed once and reused for the logged rate below.
        acc = acc_rate_loss(model, target, start, samples)
        sample_w = torch.exp(weights.detach())
        loss = -(acc * msjd_loss(start, samples)) * sample_w
        # NOTE(review): `weights * loss.detach()` looks like a score-function
        # term for the sampling weights, with `loss` carrying the pathwise
        # gradient -- confirm against the derivation.
        full_loss = (weights * loss.detach() + loss).mean()
        # retain_graph kept: model state built outside this iteration may
        # share graph nodes across iterations -- TODO confirm it is needed.
        full_loss.backward(retain_graph=True)
        sample_w_cpu = torch.exp(weights.detach().cpu())
        a_rate = acc.detach().cpu() * sample_w_cpu
        comp_opt.step()
        weight_opt.step()

        if i % 10 == 0:
            model.condition(start)
            w_mat = model.components[0].transform.weight_mat()
            # Mean row norm of component 0's weight matrix: sqrt(diag(W W^T)).
            row_norm = torch.diagonal(torch.matmul(w_mat, w_mat.transpose(0, 1))).sqrt().mean()
            # exp(weights)-weighted mean of the acceptance term for this batch.
            batch_rate = a_rate.sum() / sample_w_cpu.sum()
            print(i, row_norm, batch_rate)
            if i > 10000:
                num_avg += 1
                avg_rate += batch_rate  # built from detached tensors: no graph
                print("AVG_RATE: ", avg_rate / num_avg)

    # ---- Evaluate the trained proposal -------------------------------------
    chain = np.zeros([num_chains, chain_len, d])
    for i in range(num_chains):
        start = target.rsample([1]).detach()
        # Run the learned proposal as an MCMC kernel. Assumes MCMCChain returns
        # (states, acceptance_rate) with chain_len states of dimension d --
        # TODO confirm against mcmc_samplers.MCMCChain.
        batch, acceptance_rate = MCMCChain(model, target, start, chain_len)
        if torch.is_tensor(batch):
            batch = batch.detach().cpu().numpy()
        chain[i] = np.reshape(batch, (chain_len, d))
        rate_res[repeat] += float(acceptance_rate) / num_chains

    # Minimum (over coordinates) ESS of X and of X^2, and the mean squared
    # jump distance summed over all d coordinates.
    min_ess_res[repeat] = min(my_ESS(chain[:, :, k]) for k in range(d))
    min_ess_sq_res[repeat] = min(my_ESS(chain[:, :, k] ** 2) for k in range(d))
    jumps = np.diff(chain, axis=1)
    msjd_res[repeat] = np.mean(np.sum(jumps ** 2, axis=-1))
# Summarise each metric as "<mean> <std>" across repeats, echo the summary to
# stdout and persist it to out_file_name.
import os

summary_lines = [
    "Acceptance Rate: %f %f" % (np.mean(rate_res), np.std(rate_res)),
    "Min ESS(X): %f %f" % (np.mean(min_ess_res), np.std(min_ess_res)),
    "Min ESS(X^2): %f %f" % (np.mean(min_ess_sq_res), np.std(min_ess_sq_res)),
    "MSJD: %f %f" % (np.mean(msjd_res), np.std(msjd_res)),
]
for summary_line in summary_lines:
    print(summary_line)

# Create the output directory if it is missing, so a long run does not fail
# with FileNotFoundError at the very last step.
out_dir = os.path.dirname(out_file_name)
if out_dir:
    os.makedirs(out_dir, exist_ok=True)
with open(out_file_name, 'w') as outfile:
    outfile.writelines(summary_line + "\n" for summary_line in summary_lines)

    

