import os

import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.distributions as D

from codes.training import *
from codes.datasets import *
from codes.samplers import *

import matplotlib as mpl
# This file is an experiment script (it trains models and writes figures at import
# time), so refuse to run it as a library import — and do so before touching any
# global matplotlib state.
if __name__ != '__main__':
  raise ValueError('this experiment script must be run directly, not imported')

# thicker axes frames for all figures produced below
mpl.rcParams['axes.linewidth'] = 3.0

# dataset 1: Gaussian-mixture toy distribution (std=.02) and a DDPM trained on it.
# The trained model is cached at gmm_ckpt; it is retrained (and re-cached) when the
# checkpoint is missing or cannot be loaded.
dataset_gmm_energy, dataset_gmm_sample = toy_gmm(std=.02)
gmm_ckpt = 'models/ddpm_gmm.pt'
try:
  ddpm_gmm = torch.load(gmm_ckpt)
except Exception as err:
  # A missing checkpoint is the expected first-run case; anything else (corrupt file,
  # version mismatch, or torch>=2.6 refusing a pickled module under weights_only=True)
  # is reported instead of silently triggering a retrain.
  if not isinstance(err, FileNotFoundError):
    print('could not load %s (%s: %s); retraining' % (gmm_ckpt, type(err).__name__, err))
  ddpm_gmm = train_single_model(dataset_gmm_energy, dataset_gmm_sample, name = 'gmm')
  os.makedirs(os.path.dirname(gmm_ckpt), exist_ok=True)
  torch.save(ddpm_gmm, gmm_ckpt)

# dataset 2: "bar" toy distribution with width parameter `scale`, and a DDPM trained on it.
# The checkpoint name embeds `scale` so different widths are cached separately; the model
# is retrained (and re-cached) when the checkpoint is missing or cannot be loaded.
scale = 0.2
dataset_bar_energy, dataset_bar_sample = toy_bar(scale=scale)
bar_ckpt = 'models/ddpm_bar_%.1f.pt'%scale
try:
  ddpm_bar = torch.load(bar_ckpt)
except Exception as err:
  # A missing checkpoint is the expected first-run case; other load failures are
  # reported so an unreadable cache does not silently cost a full retrain.
  if not isinstance(err, FileNotFoundError):
    print('could not load %s (%s: %s); retraining' % (bar_ckpt, type(err).__name__, err))
  ddpm_bar = train_single_model(dataset_bar_energy, dataset_bar_sample, name = 'bar')
  os.makedirs(os.path.dirname(bar_ckpt), exist_ok=True)
  torch.save(ddpm_bar, bar_ckpt)

# Save a scatter of `bs` samples from each ground-truth toy dataset (p_1 = gmm,
# p_2 = bar) on the [-1, 1]^2 window. Datasets are sampled in the same order as
# before, so the global RNG stream is unchanged.
bs = 512
xr = [-1., 1.]; yr = [-1., 1.]
for sample_fn, label, out_path in ((dataset_gmm_sample, r'$p_1(x)$', 'models/gmm.png'),
                                   (dataset_bar_sample, r'$p_2(x)$', 'models/bar.png')):
  fig = plt.figure(figsize = (5, 5))
  samples = sample_fn(bs)
  plt.scatter(samples[:, 0].cpu(), samples[:, 1].cpu(), alpha=.5)
  plt.xlim(*xr); plt.ylim(*yr); plt.xticks([]); plt.yticks([])
  plt.title(label, fontsize=18, fontweight='bold', pad=5); plt.tight_layout()
  plt.savefig(out_path)
  # the figure is only written to disk, never shown, so release it right away
  plt.close(fig)

# training hyperparameters
torch.manual_seed(1)
# evaluation-grid resolution and plot window
nticks = 200; xr = [-1., 1.]; yr = [-1., 1.]
# target_dist is None, so the ground-truth panel is built from the dataset densities
target_dist = None; dim = 2; init_std = .01; init_mu = 0.0
# sampler settings; each is assigned exactly once, with the value the samplers use
num_leapfrog = 8; samples_per_step = 8; batch_size = 512; elastic_variance = 0.0
lambda_L = 50.; lambda_G = 1.; lambda_V = 0.5; elastic_force = lambda_L / num_leapfrog
ula_step_size = .002; uha_step_size = .005; damping = .5; mass_diag_sqrt = 1.0

# *** initial dist ***
# diagonal Gaussian centred at (-1, 1) + init_mu with per-axis std init_std
initial_dist = D.Independent(
  D.Normal(torch.tensor([-1., 1.]) + init_mu, torch.ones(dim) * init_std),
  reinterpreted_batch_ndims = 1
)
# numbers of annealing steps swept by the sampling loop
m_sample_steps = [50, 100]
# NOTE(review): x_init is not used in the visible code; kept because drawing it
# advances the torch RNG and it may be referenced elsewhere — confirm before removing.
x_init = initial_dist.sample(sample_shape=(batch_size,))

# Compose the two trained networks into a single "product" model. forward_fn_product
# returns one extra callable (an energy function) when use_ebm is set.
# NOTE(review): use_ebm is not defined in this file — presumably it comes from one of
# the star imports; confirm. In the non-EBM branch dual_product_energy_fn stays
# undefined, yet energy_function in the sampling loop calls it.
if use_ebm:
  (_,
   dual_product_sample_fn,
   dual_product_nll,
   dual_product_logp_unorm_fn,
   dual_product_gradient_fn,
   dual_product_energy_fn) = forward_fn_product(ddpm_gmm.net, ddpm_bar.net)
else:
  (_,
   dual_product_sample_fn,
   dual_product_nll,
   dual_product_logp_unorm_fn,
   dual_product_gradient_fn) = forward_fn_product(ddpm_gmm.net, ddpm_bar.net)

def _sample_panel(fig, position, title, sampler, color):
  """Scatter `batch_size` samples from `sampler` into cell `position` of a 2x3 grid.

  Returns the 3-tuple produced by `sampler.sample(batch_size)`.
  """
  fig.add_subplot(2, 3, position)
  plt.title(title, fontsize=18, fontweight='bold', pad=5)
  result = sampler.sample(batch_size)
  x_samp = result[0]
  plt.scatter(x_samp[:, 0].cpu(), x_samp[:, 1].cpu(), color=color, alpha=.5)
  plt.xlim(*xr); plt.ylim(*yr); plt.xticks([]); plt.yticks([])
  return result


# For each number of annealing steps, draw the ground-truth product density and the
# samples produced by five annealed MCMC samplers, then save one 2x3 figure.
for sample_steps in m_sample_steps:
  torch.manual_seed(1)

  def gradient_function(x, t):
    """Negated product-model gradient at sampler step t for a batch x."""
    # map sampler step t to model time index sample_steps - 1 - t, one per batch row
    t = sample_steps - torch.ones((x.shape[0],), dtype=torch.int32) * t - 1
    return -1 * dual_product_gradient_fn(x, t)

  def energy_function(x, t):
    """Negated product-model energy at sampler step t for a batch x (no autograd)."""
    with torch.no_grad():
      t = sample_steps - torch.ones((x.shape[0],), dtype=torch.int32) * t - 1
      return -1 * dual_product_energy_fn(x, t)

  f = plt.figure(figsize=(16, 11))

  # Ground truth: product of the two sum-normalised dataset values on an nticks x nticks grid.
  # NOTE(review): dataset_*_energy appear to return (unnormalised) densities, given that
  # they are normalised and multiplied here — confirm.
  f.add_subplot(2, 3, 1); plt.title('Ground Truth', fontsize=18, fontweight='bold', pad=5)
  x, y = np.meshgrid(np.linspace(xr[0], xr[1], nticks), np.linspace(yr[0], yr[1], nticks))
  coord = torch.from_numpy(np.stack([x, y], axis=-1).reshape((-1, 2))).float().cuda()
  heatmap_gmm = dataset_gmm_energy(coord)
  heatmap_gmm = heatmap_gmm / heatmap_gmm.sum()
  heatmap_bar = dataset_bar_energy(coord)
  heatmap_bar = heatmap_bar / heatmap_bar.sum()
  heatmap = heatmap_gmm * heatmap_bar
  heatmap = heatmap / heatmap.sum()
  if target_dist is not None:
    heatmap = target_dist(coord, 0).exp()
  plt.contourf(x, y, heatmap.cpu().reshape((nticks, nticks)).numpy(), levels=40, cmap="Blues")
  plt.xticks([]); plt.yticks([])

  # Constant per-step step sizes for the Langevin (ULA/MALA) and Hamiltonian samplers.
  ula_step_sizes = torch.ones((sample_steps,)) * ula_step_size
  uha_step_sizes = torch.ones((sample_steps,)) * uha_step_size

  # Samplers run in a fixed order so they consume the seeded RNG stream identically.
  sampler_ula = AnnealedULASampler(sample_steps, samples_per_step, ula_step_sizes, initial_dist, target_distribution=target_dist, gradient_function=gradient_function)
  _sample_panel(f, 2, 'U-LMC', sampler_ula, 'green')

  sampler_mala = AnnealedMALASampler(sample_steps, samples_per_step, ula_step_sizes, initial_dist, target_distribution=target_dist, gradient_function=gradient_function, energy_function=energy_function)
  _sample_panel(f, 3, 'LMC', sampler_mala, 'blue')

  sampler_uhmc = AnnealedUHMCSampler(sample_steps, samples_per_step, uha_step_sizes, damping, mass_diag_sqrt, num_leapfrog, initial_dist, target_distribution=target_dist, gradient_function=gradient_function)
  _sample_panel(f, 5, 'U-HMC', sampler_uhmc, 'green')

  sampler_mahmc = AnnealedMAHMCSampler(sample_steps, samples_per_step, uha_step_sizes, damping, mass_diag_sqrt, num_leapfrog, initial_dist, target_distribution=target_dist, gradient_function=gradient_function, energy_function=energy_function)
  _sample_panel(f, 6, 'HMC', sampler_mahmc, 'blue')

  sampler_fthl = AnnealedLHMCSampler(sample_steps, samples_per_step, uha_step_sizes,
                                     lambda_L, lambda_G, lambda_V,
                                     elastic_variance, elastic_force, False,
                                     damping, mass_diag_sqrt, num_leapfrog,
                                     initial_dist, target_distribution=target_dist, gradient_function=gradient_function, energy_function=energy_function)
  _sample_panel(f, 4, 'FTH', sampler_fthl, 'red')

  f.tight_layout()
  out_path = 'results/results_compose/compose_results(%d-%d-%d).png'%(sample_steps, samples_per_step, num_leapfrog)
  os.makedirs(os.path.dirname(out_path), exist_ok=True)
  plt.savefig(out_path)
  # one figure per sweep iteration; close it so figures do not accumulate
  plt.close(f)