from torch.distributions import MultivariateNormal, TransformedDistribution
import torch
from invert_linear import CondInvLinear, CondScale
from matplotlib import pyplot as pl
from mcmc_kernels import IsotropicRWM, IsotropicMALA, Proposal_Mixture, FullRWM,FullRWMFixed,  IsotropicUniform,IsotropicRWMFixed
from mixture_weight_networks import ReluMixtureWeights
import numpy as np
from mcmc_samplers import MCMCChain
from loss_functions import *
from arviz.stats.diagnostics import _ess as calc_ess

# Experiment configuration: where the summary is written and how much sampling to do.
out_file_name = './Results/multi_scheme_l3_iidresample.txt'
num_repeats = 5     # independent repetitions of the whole experiment
num_chains = 5      # chains drawn per repetition
chain_len = 1000    # samples per chain

d = 2               # dimension of the sample space
req_grad = False    # not read anywhere in the code shown

# Use the first GPU when one is available, otherwise run on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else 'cpu')
# Build the target: four FullRWMFixed components whose transform biases are
# zeroed and then shifted by +/-4 along one axis each, i.e. offsets
# (+4, 0), (-4, 0), (0, +4), (0, -4) when d == 2. Presumably the bias acts as
# each component's location -- TODO confirm against FullRWMFixed.
component_shifts = ((0, 4), (0, -4), (1, 4), (1, -4))  # (axis, offset) per component
dist_list = [FullRWMFixed(d, bias=True, device=device) for _ in component_shifts]
with torch.no_grad():
    # Zero every bias first, then apply the per-component offset.
    for component in dist_list:
        component.transform.bias.data *= 0
    for component, (axis, offset) in zip(dist_list, component_shifts):
        component.transform.bias[axis] += offset
dist1, dist2, dist3, dist4 = dist_list
target = Proposal_Mixture(list(dist_list), device=device)

# Run num_repeats repetitions. Each repetition fills num_chains "chains" of
# chain_len i.i.d. draws from the target (an i.i.d.-resampling baseline, so
# every draw counts as accepted), then records per-repetition diagnostics.
min_ess_res = np.zeros(num_repeats)
min_ess_sq_res = np.zeros(num_repeats)
rate_res = np.zeros(num_repeats)
msjd_res = np.zeros(num_repeats)
for repeat in range(num_repeats):
    chain = np.zeros([num_chains, chain_len, d])
    for i in range(num_chains):
        batch = target.sample([chain_len])
        # The sample may live on the GPU (device is CUDA when available) or
        # carry autograd state; move it to host memory before writing it
        # into the NumPy buffer. torch.as_tensor also accepts array input.
        chain[i, :, :] = torch.as_tensor(batch).detach().cpu().numpy()
        acceptance_rate = 1.0  # i.i.d. draws are never rejected
        rate_res[repeat] += acceptance_rate / float(num_chains)
    # Minimum ESS across coordinates, for X and for X**2 (element-wise).
    # my_ESS presumably takes a (num_chains, chain_len) array -- TODO confirm.
    min_ess_res[repeat] = min(my_ESS(chain[:, :, k]) for k in range(d))
    min_ess_sq_res[repeat] = min(my_ESS(chain[:, :, k] ** 2) for k in range(d))
    # Mean squared jumping distance: squared Euclidean length of each
    # consecutive step, averaged over chains and steps.
    msjd_res[repeat] = np.mean(np.sum(np.diff(chain, axis=1) ** 2, axis=2))
# Summarise each metric as "<label>: <mean> <std>" across repeats. The same
# lines go to stdout and to out_file_name, so the two reports cannot drift.
import os

summary_lines = [
    "%s: %f %f" % (label, np.mean(values), np.std(values))
    for label, values in (
        ("Acceptance Rate", rate_res),
        ("Min ESS(X)", min_ess_res),
        ("Min ESS(X^2)", min_ess_sq_res),
        ("MSJD", msjd_res),
    )
]
for line in summary_lines:
    print(line)
# Create the results directory if it is missing, so a finished run is not
# lost to a FileNotFoundError at the very last step.
out_dir = os.path.dirname(out_file_name)
if out_dir:
    os.makedirs(out_dir, exist_ok=True)
with open(out_file_name, 'w', encoding="utf-8") as outfile:
    outfile.writelines(line + "\n" for line in summary_lines)