from torch.distributions import MultivariateNormal, TransformedDistribution, Independent, Laplace
import torch
from invert_linear import CondInvLinear, CondScale
import numpy as np
from mcmc_kernels import IsotropicRWM, IsotropicMALAGaussian, IsotropicUniform, Proposal_Mixture
from mcmc_samplers import MCMCChain
from loss_functions import *

# Run on the first CUDA device when one is visible; otherwise use the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else 'cpu')

# ---------------------------------------------------------------------------
# Experiment configuration
# ---------------------------------------------------------------------------
out_file_name = './Results/ver_rwm_laplace_10000_abinitio.txt'  # summary statistics are written here
num_repeats = 5   # number of fresh models created and optimised
num_chains = 5    # outer size of each evaluation pass
chain_len = 5000  # inner size of each evaluation pass

# ---------------------------------------------------------------------------
# Target: Laplace(loc=0, scale=1) in every coordinate, with the d coordinates
# reinterpreted as a single d-dimensional event.
# ---------------------------------------------------------------------------
d = 10000
req_grad = False
ground_truth_mean, ground_truth_scale = (
    make(d, requires_grad=req_grad).to(device) for make in (torch.zeros, torch.ones)
)
target = Independent(
    Laplace(loc=ground_truth_mean, scale=ground_truth_scale),
    reinterpreted_batch_ndims=1,
)

# Four distinct zero-filled float64 arrays indexed [repeat, chain, step]:
# squared-jump and acceptance statistics before ("start") and after ("end")
# the optimisation phase.
msj_start, rate_start, msj_end, rate_end = (
    np.zeros((num_repeats, num_chains, chain_len)) for _ in range(4)
)
# Optimisation schedule (values identical to the original experiment).
num_train_iters = 20000  # Adam steps per repeat
burn_in_iters = 10000    # acceptance rates are averaged only for steps i > burn_in_iters
log_every = 10           # print the batch acceptance rate every `log_every` steps
num_proposals = 50       # proposals drawn from the conditioned model per step


def _evaluate_chains(model, msj_out, rate_out, repeat):
    """Fill ``msj_out[repeat]`` and ``rate_out[repeat]`` with two-step chain statistics.

    For every (chain, step) slot a fresh start is drawn from ``target`` and
    clamped to [-1e6, 1e6]. ``MCMCChain`` is then run for 2 states from that
    start. The squared difference between the two returned states, averaged
    with ``.mean()``, is stored in ``msj_out``. The acceptance rate returned
    by ``MCMCChain`` is stored in ``rate_out``.
    """
    for j in range(num_chains):
        for i in range(chain_len):
            start = torch.clamp(target.rsample([1]), min=-1e6, max=1e6)
            chain, acceptance_rate = MCMCChain(model, target, start[0], 2)
            # float() accepts Python/NumPy scalars and single-element tensors
            # on any device, so a CUDA result converts without an explicit .cpu().
            msj_out[repeat, j, i] = float(((chain[0] - chain[1]) ** 2.0).mean())
            rate_out[repeat, j, i] = float(acceptance_rate)


for repeat in range(num_repeats):
    model = IsotropicRWM(d, bias=False, device=device)
    # Hand-set initial proposal scale.
    # NOTE(review): 2.39 looks like the usual ~2.38 RWM scaling constant;
    # confirm against how IsotropicRWM initialises `transform.weight`.
    with torch.no_grad():
        model.transform.weight.data *= 2.39

    # Baseline statistics for the hand-set scale.
    _evaluate_chains(model, msj_start, rate_start, repeat)
    print(np.mean(rate_start[repeat, :, :]))

    # Enlarge the scale by 10% before optimisation starts.
    with torch.no_grad():
        model.transform.weight.data *= 1.1

    # NOTE(review): `model.parameters` is passed uncalled. This only works if
    # IsotropicRWM exposes `parameters` as an iterable attribute; the
    # nn.Module *method* would be rejected by Adam. Confirm.
    opt = torch.optim.Adam(model.parameters, lr=0.0003)
    avg_rate = 0
    num_avg = 0

    for i in range(num_train_iters):
        opt.zero_grad()
        start = torch.clamp(target.rsample([1]), min=-1e6, max=1e6)

        model.condition(start)
        samples = model.rsample([num_proposals])
        log_acc_rate = log_acc_rate_loss(model, target, start, samples)
        # Objective: -kl_loss minus a d-scaled, 0.18125-weighted log acceptance rate.
        loss = (-kl_loss(model, target, start, samples) - d * log_acc_rate * 0.18125).mean()
        # retain_graph=True is kept from the original.
        # NOTE(review): it may be needed if the model reuses graph pieces
        # across steps; confirm before removing.
        loss.backward(retain_graph=True)
        opt.step()

        past_burn_in = i > burn_in_iters
        should_log = i % log_every == 0
        if not (past_burn_in or should_log):
            continue

        # Detached CPU copy: no autograd tracking needed for bookkeeping.
        a_rate = torch.exp(log_acc_rate.detach().cpu())
        if past_burn_in:
            num_avg += 1
            avg_rate += a_rate.mean()
        if should_log:
            print(i, a_rate.mean())
            if past_burn_in:
                print("AVG_RATE: ", avg_rate / num_avg)

    # Statistics for the optimised scale.
    _evaluate_chains(model, msj_end, rate_end, repeat)
import os


def _repeat_summary(values):
    """Return ``(overall mean, std across repeats of the per-repeat means)``.

    ``values`` is indexed ``[repeat, chain, step]``, the layout of the
    ``msj_*`` / ``rate_*`` arrays in this script. The per-repeat mean is taken
    over the step axis and then the chain axis (same nesting as before, so
    the floating-point results are unchanged).
    """
    return np.mean(values), np.std(np.mean(np.mean(values, axis=2), axis=1))


_summaries = [
    (label, *_repeat_summary(values))
    for label, values in (
        ("MSJD Analytic:", msj_start),
        ("MSJD Optimized:", msj_end),
        ("Acceptance Rate Analytic:", rate_start),
        ("Acceptance Rate Optimized:", rate_end),
    )
]

for label, mean, std in _summaries:
    print(label, mean, std)

# Create the output directory if needed so a long run does not fail at the
# very last step with FileNotFoundError.
os.makedirs(os.path.dirname(out_file_name) or '.', exist_ok=True)
with open(out_file_name, 'w') as outfile:
    for label, mean, std in _summaries:
        # One result per line; previously the writes had no newline and ran together.
        outfile.write("%s %f %f\n" % (label, mean, std))
       
            
    

