from torch.distributions import MultivariateNormal, TransformedDistribution
import torch
from invert_linear import CondInvLinear, CondScale
from matplotlib import pyplot as pl
from mcmc_kernels import IsotropicRWM, IsotropicMALA, Proposal_Mixture, FullRWM,FullRWMFixed,  IsotropicUniform
from mixture_weight_networks import ReluMixtureWeights
import numpy as np
from mcmc_samplers import MCMCChain
from plot_utils import *
from nice_kernels import NICEResample, NICERelu, NICEDoubleRelu
from nice_transforms import NICE
from loss_functions import *

device = torch.device('cpu')

# Target density: a Proposal_Mixture over four FullRWM components in d = 2.
# Each component's bias is zeroed and then shifted along one coordinate, giving
# offsets (4, 0), (-4, 0), (0, 4) and (0, -4) in dist_list order.
d = 2
dist1, dist2, dist3, dist4 = (FullRWM(d, bias=True) for _ in range(4))
dist_list = [dist1, dist2, dist3, dist4]

# (coordinate index, shift) applied to each component's zeroed bias.
bias_shifts = [(0, 4), (0, -4), (1, 4), (1, -4)]
with torch.no_grad():
    for component, (coord, shift) in zip(dist_list, bias_shifts):
        component.transform.bias.data *= 0
        component.transform.bias[coord] += shift

target = Proposal_Mixture(list(dist_list))

# Learned proposal: a NICE flow (num_layers=8, alternating masks) used as the
# transform inside a NICEDoubleRelu proposal.
flow = NICE(
    d,
    alternating_mask=True,
    num_layers=8,
    layer_depth=3,
    layer_width=6,
    device=device,
)

with torch.no_grad():
    # Shift the first forward network's weights by +2.0 before training.
    # NOTE(review): initialisation tweak; intent not visible here — confirm.
    flow.forward_network[0].weights += 2.0

# Two separate width/depth config dicts; presumably one per ReLU sub-network
# of NICEDoubleRelu — confirm against its signature.
hidden_width = 8 * d
first_net_cfg = {'width': hidden_width, 'depth': 4}
second_net_cfg = {'width': hidden_width, 'depth': 4}
model = NICEDoubleRelu(d, first_net_cfg, second_net_cfg, transform=flow, device=device, rwm_bias=0.0)

# --- Training of the proposal against the target mixture ---------------------
# NOTE(review): stability_factor is not referenced anywhere in this script's
# visible code — confirm whether it is still needed.
stability_factor = 1e-5
# NOTE(review): `model.parameters` is passed without calling it, so it is
# presumably an attribute (iterable of tensors) on NICEDoubleRelu rather than
# nn.Module.parameters() — confirm.
comp_opt = torch.optim.Adam(model.parameters, lr = 0.003)

avg_rate = 0.0   # running sum of logged acceptance rates (Python float)
num_avg = 0      # number of entries in avg_rate
num_starts = 8   # starting points drawn from the target per iteration
num_samples = 100  # proposal samples drawn per starting point
num_iters = 20000
log_every = 10
# Acceptance rates are only averaged once i exceeds this (second half of training).
avg_after = num_iters // 2
# Weight on the (d-scaled) log-acceptance-rate term of the objective.
# NOTE(review): origin of 0.18125 is not documented here — confirm.
acc_rate_weight = 0.18125

for i in range(num_iters):
    comp_opt.zero_grad()
    # Fresh starting points from the target; detached so no gradient flows
    # into the target's parameters through them.
    start = target.rsample([num_starts]).detach()
    model.condition(start)
    samples = model.rsample([num_samples])
    # Tile each start num_samples times and flatten to (num_starts*num_samples, d)
    # so it lines up row-for-row with `samples`.
    # Assumes samples are laid out start-major — TODO confirm.
    start = start.unsqueeze(1).repeat(1, num_samples, 1)
    start = start.view(num_starts * num_samples, d)
    loss = (
        -kl_loss(model, target, start, samples)
        - d * log_acc_rate_loss(model, target, start, samples) * acc_rate_weight
    ).mean()
    loss.backward()
    comp_opt.step()

    if i % log_every == 0:
        with torch.no_grad():
            # .item() yields a plain float, so the log shows numbers rather
            # than tensor reprs, and the running average stays graph-free.
            a_rate = acc_rate_loss(model, target, start, samples).mean().item()
        print(i, loss.item(), a_rate)
        if i > avg_after:
            num_avg += 1
            avg_rate += a_rate
            print("AVG_RATE: ", avg_rate / num_avg)

# --- Visualise the trained proposal g(x' | x0) -------------------------------
# One panel per mixture component: the proposal is conditioned on that
# component's bias and its density is drawn over dashed target contours.
range_opts = {'x_range' : [-5.0, 5.0], 'y_range':[-5.0, 5.0], 'x_num': 100, 'y_num':100}
fig = pl.figure(figsize=(11, 4))
axes = fig.subplots(1, 4, sharey=True)
for ax, component in zip(axes, dist_list):
    pl.sca(ax)
    starting_point = component.transform.bias.data.clone()
    model.condition(starting_point)
    # Raw strings: mathtext needs literal backslashes. "\ " forces a visible
    # space, since mathtext ignores plain whitespace.
    contour_plot(lambda x : torch.exp(target.log_prob(x)), range_opt=range_opts, show=False, label=False, outline=True, linestyles='dashed', levels = 4, legend=r'$Contours \ of \ Target \ Density, \ \pi(x,y)$')
    density_plot(lambda x : torch.exp(model.log_prob(x)), range_opt=range_opts, cmap='viridis', show=False, vmin=None, vmax=None, cbar=False)
    point = starting_point.cpu().numpy()
    pl.scatter(point[0], point[1], c='red', label=r'$Starting \ Point, \ (x_{0}, y_{0})$')
    ax.xaxis.set_visible(False)
    ax.yaxis.set_visible(False)
pl.suptitle("$g(x',y' | x_{0}, y_{0})$")
# Every panel carries the same labelled artists, so one panel's handles
# suffice for the shared figure legend.
handles, labels = axes[-1].get_legend_handles_labels()

fig.legend(handles, labels, loc='lower center', bbox_to_anchor=(0.06, 0.25, 0.9, 0.05), ncol=2, fontsize=12, frameon=False)
pl.subplots_adjust(bottom=0.35, wspace=0.02)
pl.show()