from torch.distributions import MultivariateNormal, TransformedDistribution
import torch
from invert_linear import CondInvLinear, CondScale
from matplotlib import pyplot as pl
from mcmc_kernels import IsotropicRWM, IsotropicMALA, Proposal_Mixture, FullRWM,FullRWMFixed,  IsotropicUniform
from mixture_weight_networks import ReluMixtureWeights
import numpy as np
from mcmc_samplers import MCMCChain
from plot_utils import *
from loss_functions import *

# Dimension of the state space, plus a zero vector and an identity matrix of that size
# (the latter two are not referenced again in the visible script).
d = 2
req_grad = False
ground_truth_mean = torch.zeros(d, requires_grad=req_grad)
ground_truth_covariance = torch.eye(d, requires_grad=req_grad)

# Target distribution: a mixture of four FullRWM components. Each component's bias is
# reset to zero and then moved by +/-4 along a single coordinate, giving offsets
# (+4, 0), (-4, 0), (0, +4) and (0, -4).
dist_list = [FullRWM(d, bias=True) for _ in range(4)]
dist1, dist2, dist3, dist4 = dist_list
mode_offsets = [(0, 4), (0, -4), (1, 4), (1, -4)]
with torch.no_grad():
    for comp_dist, (coord, offset) in zip(dist_list, mode_offsets):
        comp_dist.transform.bias.data *= 0
        comp_dist.transform.bias[coord] += offset

# A fresh list is passed so that `dist_list` itself is not shared with the mixture.
target = Proposal_Mixture([dist1, dist2, dist3, dist4])

# Learnable proposal: `num_comp` FullRWMFixed components combined through a
# ReluMixtureWeights network.
# NOTE(review): the meaning of the trailing arguments (8 * d, 4) is defined in
# mixture_weight_networks -- presumably hidden width and depth; confirm there.
num_comp = 4
prop_list = [FullRWMFixed(d, bias=True) for _ in range(num_comp)]

model_weights = ReluMixtureWeights(num_comp, d, 8 * d, 4)
model = Proposal_Mixture(prop_list, model_weights)

# Initialise every proposal component from a target component (cycling through
# `dist_list`): its transform weight and bias are zeroed in place and the target
# component's values are then added on top.
with torch.no_grad():
    for index, prop in enumerate(prop_list):
        source_transform = dist_list[index % len(dist_list)].transform
        prop.transform.weight *= 0
        prop.transform.weight += source_transform.weight.data
        prop.transform.bias *= 0
        prop.transform.bias += source_transform.bias.data
    
# Evaluate the target density on a regular 101 x 101 grid covering [-10, 10] in both
# coordinates (step 0.2) and display it as an image.
grid_axis = torch.arange(-50, 51, dtype=torch.float32) / 5.0
# Relies on torch.meshgrid's default 'ij' indexing: plotx varies along dim 0 and
# ploty along dim 1.
plotx, ploty = torch.meshgrid(grid_axis, grid_axis)
min_x = float(plotx.min())
max_x = float(plotx.max())
min_y = float(ploty.min())
max_y = float(ploty.max())
plot_grid = torch.stack((plotx, ploty), -1).float()

# Flatten the grid to a batch of 2-D points for log_prob.
target_prob = torch.exp(target.log_prob(plot_grid.view(len(plotx) * len(ploty), 2)))
density_image = target_prob.view(len(plotx), len(ploty)).detach().numpy()
# In `density_image` the row index follows x and the column index follows y, while
# imshow draws rows along the vertical axis starting at the top. Transposing and
# using a lower-left origin makes the picture agree with `extent` and the axis labels.
pl.imshow(density_image.T, origin='lower', extent=[min_x, max_x, min_y, max_y])
pl.colorbar()
pl.xlabel('$x$')
pl.ylabel('$y$')
# Raw string: '\p' is not a valid escape sequence in an ordinary string literal.
pl.title(r'$\pi(x, y)$')
pl.show()

def component_prob(x, weight_network, index):
    """Return the softmax mixture probability of one component at each point of ``x``.

    The network is conditioned on ``x`` and then called with no arguments; its
    output is normalised with a softmax over dimension 1 and column ``index`` is
    returned. The network evaluation runs without gradient tracking.

    Args:
        x: Batch of points passed to ``weight_network.condition``.
        weight_network: Object exposing ``condition(x)`` and a no-argument
            ``__call__`` that returns a tensor with at least two dimensions.
        index: Column (component) to select after the softmax.

    Returns:
        Tensor holding column ``index`` of the row-wise softmax.
    """
    with torch.no_grad():
        weight_network.condition(x)
        raw_scores = weight_network()
        # Diagnostic output: shape of the un-normalised network output.
        print(raw_scores.shape)
    probabilities = torch.softmax(raw_scores, dim=1)
    return probabilities[:, index]
    
# Four-panel figure drawn before the optimisation loop below: one panel per proposal
# component. Each panel overlays the target-density contours (dashed), the
# per-point mixture probability of that component from `component_prob`, and the
# component's own density contours (red).
# NOTE(review): `contour_plot` / `density_plot` come from plot_utils; this assumes
# they draw on the current axes and evaluate the callback immediately -- confirm.
range_opts = {'x_range': [-5.0, 5.0], 'y_range': [-5.0, 5.0], 'x_num': 100, 'y_num': 100}
fig = pl.figure(figsize=(11, 4))
axes = fig.subplots(1, 4, sharey=True)
for index, ax in enumerate(axes):
    pl.sca(ax)
    # Legend texts are raw strings: '\ ' and '\p' are not valid escape sequences
    # in ordinary string literals.
    contour_plot(
        lambda x: torch.exp(target.log_prob(x)),
        range_opt=range_opts,
        show=False,
        label=False,
        outline=True,
        linestyles='dashed',
        levels=4,
        legend=r'$Contours \ of \ Target \ Density, \ \pi(x,y)$',
    )
    pcm = density_plot(
        lambda x: component_prob(x, model_weights, index),
        range_opt=range_opts,
        cmap='viridis',
        show=False,
        vmin=0.0,
        vmax=1.0,
        cbar=False,
    )
    contour_plot(
        lambda x: torch.exp(prop_list[index].log_prob(x)),
        range_opt=range_opts,
        color_opt='red',
        label=False,
        outline=False,
        show=False,
        levels=4,
        legend=r"$Contours \ of \ Component \ Proposal \ Densities, \ g_{i}(x',y')$",
    )
    ax.axes.xaxis.set_visible(False)
    ax.axes.yaxis.set_visible(False)
    pl.title("$g_{%i}(x',y')$" % (index + 1))

# Legend entries are taken from the last panel; the colour bar uses the last
# density_plot result and is shared across all panels.
handles, labels = ax.get_legend_handles_labels()
fig.colorbar(pcm, orientation='horizontal', ax=axes, aspect=60, pad=0.01, anchor=(0.0, 0.5)).set_label(label='$P(g_{i} | x,y)$', size=12)
fig.legend(handles, labels, loc='lower center', bbox_to_anchor=(0.06, 0.17, 0.9, 0.05), ncol=2, fontsize=12, frameon=False)
pl.subplots_adjust(bottom=0.3, wspace=0.02)
pl.show()

# `stability_factor`, `avg_rate` and `num_avg` are assigned here but not read
# anywhere in the visible script.
stability_factor = 1e-5
# Two Adam optimisers: one over the model's mixture parameters and one over its
# component parameters.
weight_opt = torch.optim.Adam(model.mixture_parameters, lr=0.003)
comp_opt = torch.optim.Adam(model.component_parameters, lr=0.003)

avg_rate = 0
num_avg = 0
for i in range(1000):
    weight_opt.zero_grad()
    comp_opt.zero_grad()

    # One detached draw from the target is the current state; the proposal is
    # conditioned on it and 50 proposals are drawn together with `weights`.
    start = target.rsample([1]).detach()
    model.condition(start)
    samples, weights = model.uniform_rsample([50])

    # exp(weights) is used purely as a constant factor (it is detached).
    importance = torch.exp(weights.detach())
    objective = acc_rate_loss(model, target, start, samples) * msjd_loss(start, samples)
    loss = (-objective) * importance
    # `weights * loss.detach()` routes gradient into `weights` with the detached loss
    # as coefficient; the `+ loss` term carries the gradient of the loss itself.
    full_loss = (weights * loss.detach() + loss).mean()
    full_loss.backward(retain_graph=True)
    # Acceptance-rate term re-evaluated for logging, before any parameter update.
    a_rate = acc_rate_loss(model, target, start, samples).detach() * importance

    comp_opt.step()
    # The mixture-weight optimiser is not stepped on the very first iteration.
    if i >= 1:
        weight_opt.step()

    # Progress report every tenth iteration: iteration, mean loss, mean square root of
    # the diagonal of W W^T for the first component, and the weighted mean of `a_rate`.
    if i % 10 == 0:
        model.condition(start)
        first_transform = model.components[0].transform
        gram = torch.matmul(first_transform.weight_mat(), first_transform.weight_mat().transpose(0, 1))
        mean_scale = torch.diagonal(gram).sqrt().mean()
        normaliser = torch.exp(weights.detach()).sum()
        print(i, loss.mean(), mean_scale, a_rate.sum() / normaliser)
    
# Four-panel figure drawn after the optimisation loop above: one panel per proposal
# component. Each panel overlays the target-density contours (dashed), the
# per-point mixture probability of that component from `component_prob`, and the
# component's own density contours (red).
# NOTE(review): `contour_plot` / `density_plot` come from plot_utils; this assumes
# they draw on the current axes and evaluate the callback immediately -- confirm.
range_opts = {'x_range': [-5.0, 5.0], 'y_range': [-5.0, 5.0], 'x_num': 100, 'y_num': 100}
fig = pl.figure(figsize=(11, 4))
axes = fig.subplots(1, 4, sharey=True)
for index, ax in enumerate(axes):
    pl.sca(ax)
    # Legend texts are raw strings: '\ ' and '\p' are not valid escape sequences
    # in ordinary string literals.
    contour_plot(
        lambda x: torch.exp(target.log_prob(x)),
        range_opt=range_opts,
        show=False,
        label=False,
        outline=True,
        linestyles='dashed',
        levels=4,
        legend=r'$Contours \ of \ Target \ Density, \ \pi(x,y)$',
    )
    pcm = density_plot(
        lambda x: component_prob(x, model_weights, index),
        range_opt=range_opts,
        cmap='viridis',
        show=False,
        vmin=0.0,
        vmax=1.0,
        cbar=False,
    )
    contour_plot(
        lambda x: torch.exp(prop_list[index].log_prob(x)),
        range_opt=range_opts,
        color_opt='red',
        label=False,
        outline=False,
        show=False,
        levels=4,
        legend=r"$Contours \ of \ Component \ Proposal \ Densities, \ g_{i}(x',y')$",
    )
    ax.axes.xaxis.set_visible(False)
    ax.axes.yaxis.set_visible(False)
    pl.title("$g_{%i}(x',y')$" % (index + 1))

# Legend entries are taken from the last panel; the colour bar uses the last
# density_plot result and is shared across all panels.
handles, labels = ax.get_legend_handles_labels()
fig.colorbar(pcm, orientation='horizontal', ax=axes, aspect=60, pad=0.01, anchor=(0.0, 0.5)).set_label(label='$P(g_{i} | x,y)$', size=12)
fig.legend(handles, labels, loc='lower center', bbox_to_anchor=(0.06, 0.17, 0.9, 0.05), ncol=2, fontsize=12, frameon=False)
pl.subplots_adjust(bottom=0.3, wspace=0.02)
pl.show()

