import numpy as np
import matplotlib.pyplot as plt

from models.energy_model import *
from models.diffusion_model import *

import torch
if torch.cuda.is_available():
  torch.set_default_device('cuda')
import torch.distributions as D

# If True, use an energy-based model (EBM) formulation of the likelihood; otherwise use a score formulation.
use_ebm = True    # wrap networks in EBMDiffusionModel and expose energy helpers
n_steps = 100     # number of diffusion time steps handed to the models
data_dim = 2      # dimensionality of each data point (x_dim of the networks)

def plot_samples(x):
  """Scatter-plot columns 0 and 1 of tensor ``x`` inside the [-1, 1] x [-1, 1] square."""
  points = x.detach().cpu().numpy()
  plt.scatter(points[:, 0], points[:, 1])
  plt.xlim(-1., 1.)
  plt.ylim(-1., 1.)

def dist_show_2d(fn, xr, yr, nticks=100):
  """Draw ``fn`` evaluated on a regular 2-D grid as a heatmap.

  Args:
    fn: callable taking a ``(nticks * nticks, 2)`` float tensor of (x, y)
      points and returning one value per point.
    xr: ``(low, high)`` range of the horizontal axis.
    yr: ``(low, high)`` range of the vertical axis.
    nticks: number of grid points along each axis.
  """
  xs, ys = np.meshgrid(np.linspace(xr[0], xr[1], nticks), np.linspace(yr[0], yr[1], nticks))
  grid = np.stack([xs, ys], axis=-1).reshape((-1, 2))
  # torch.tensor honours torch.set_default_device, so the grid lands on CUDA when
  # it is available and on the CPU otherwise (a hard-coded .to('cuda') crashes on CPU).
  coord = torch.tensor(grid, dtype=torch.float32)
  with torch.no_grad():
    heatmap = fn(coord).reshape((nticks, nticks))
  # Row i of the heatmap corresponds to ys[i] (increasing), so draw it bottom-up and
  # label the axes in data coordinates rather than pixel indices.
  plt.imshow(heatmap.detach().cpu().numpy(), origin='lower',
             extent=(xr[0], xr[1], yr[0], yr[1]))

def forward_fn_product(net_one, net_two):
  """Compose two networks with ProductEBMDiffusionModel and expose diffusion helpers.

  Args:
    net_one: first component network passed to ``ProductEBMDiffusionModel``.
    net_two: second component network passed to ``ProductEBMDiffusionModel``.

  Returns:
    ``(ddpm.loss, ddpm.sample, logpx, logp_unnorm, ddpm.p_gradient)``, with
    ``ddpm.p_energy`` appended when ``use_ebm`` is set.
  """
  composed = ProductEBMDiffusionModel(net_one, net_two)
  ddpm = PortableDiffusionModel(data_dim, n_steps, composed, var_type="beta_forward")

  def logp_unnorm(x, t):
    # The energy scale is looked up with the scalar t, before t is broadcast
    # to one int32 time index per sample.
    scale = ddpm.energy_scale(-2 - t)
    t_batch = torch.ones((x.shape[0],), dtype=torch.int32) * t
    return -composed.neg_logp_unnorm(x, t_batch) * scale

  def _logpx(x):
    return ddpm.logpx(x)["logpx"]

  shared = (ddpm.loss, ddpm.sample, _logpx, logp_unnorm, ddpm.p_gradient)
  if not use_ebm:
    return shared
  return shared + (ddpm.p_energy,)

# NOTE(review): the string literal below is disabled, unported JAX/Haiku code. It uses
# `jnp` and `hk.multi_transform`, which this file does not import by name. It is kept
# only as a reference for porting the mixture and negation compositions
# (MixtureEBMDiffusionModel / NegationEBMDiffusionModel) to PyTorch. As a bare string
# it is never executed.
'''
def forward_fn_mixture():
  net_one = ResnetDiffusionModel(n_steps=n_steps, n_layers=4, x_dim=data_dim, h_dim=128, emb_dim=32)

  if use_ebm:
    net_one = EBMDiffusionModel(net_one)

  net_two = ResnetDiffusionModel(n_steps=n_steps, n_layers=4, x_dim=data_dim, h_dim=128, emb_dim=32)

  if use_ebm:
    net_two = EBMDiffusionModel(net_two)

  dual_net = MixtureEBMDiffusionModel(net_one, net_two)
  ddpm = PortableDiffusionModel(data_dim, n_steps, dual_net, var_type="beta_forward")

  def logp_unnorm(x, t):
    scale_e = ddpm.energy_scale(-2 - t)
    t = jnp.ones((x.shape[0],), dtype=jnp.int32) * t
    return -dual_net.neg_logp_unnorm(x, t) * scale_e

  def _logpx(x):
    return ddpm.logpx(x)["logpx"]

  if use_ebm:
    return ddpm.loss, (ddpm.loss, ddpm.sample, _logpx, logp_unnorm, ddpm.p_gradient, ddpm.p_energy)
  else:
    return ddpm.loss, (ddpm.loss, ddpm.sample, _logpx, logp_unnorm, ddpm.p_gradient)

forward_mixture = hk.multi_transform(forward_fn_mixture)

def forward_fn_negation():
  net_one = ResnetDiffusionModel(n_steps=n_steps, n_layers=4, x_dim=data_dim, h_dim=128, emb_dim=32)

  if use_ebm:
    net_one = EBMDiffusionModel(net_one)

  net_two = ResnetDiffusionModel(n_steps=n_steps, n_layers=4, x_dim=data_dim, h_dim=128, emb_dim=32)

  if use_ebm:
    net_two = EBMDiffusionModel(net_two)

  dual_net = NegationEBMDiffusionModel(net_one, net_two)
  ddpm = PortableDiffusionModel(data_dim, n_steps, dual_net, var_type="beta_forward")

  def logp_unnorm(x, t):
    scale_e = ddpm.energy_scale(-2 - t)
    t = jnp.ones((x.shape[0],), dtype=jnp.int32) * t
    return -dual_net.neg_logp_unnorm(x, t) * scale_e

  def _logpx(x):
    return ddpm.logpx(x)["logpx"]

  if use_ebm:
    return ddpm.loss, (ddpm.loss, ddpm.sample, _logpx, logp_unnorm, ddpm.p_gradient, ddpm.p_energy)
  else:
    return ddpm.loss, (ddpm.loss, ddpm.sample, _logpx, logp_unnorm, ddpm.p_gradient)

forward_negation = hk.multi_transform(forward_fn_negation)
'''

def forward_fn(net_params):
  """Build a single diffusion model together with closures over it.

  Args:
    net_params: dict providing ``'n_layers'``, ``'h_dim'`` and ``'emb_dim'`` for
      the ``ResnetDiffusionModel`` backbone.

  Returns:
    ``(ddpm, ddpm.loss, ddpm.sample, logpx, logp_unnorm, epsilon)``. The backbone
    is wrapped in ``EBMDiffusionModel`` when ``use_ebm`` is set.
  """
  backbone = ResnetDiffusionModel(
    n_steps=n_steps,
    n_layers=net_params['n_layers'],
    x_dim=data_dim,
    h_dim=net_params['h_dim'],
    emb_dim=net_params['emb_dim'],
  )
  net = EBMDiffusionModel(backbone) if use_ebm else backbone
  ddpm = PortableDiffusionModel(data_dim, n_steps, net, var_type="beta_forward")

  def logp_unnorm(x, t):
    # Scale is computed from the scalar t; t is then broadcast per sample.
    scale = ddpm.energy_scale(-2 - t)
    t_batch = torch.ones((x.shape[0],), dtype=torch.int32) * t
    return -net.neg_logp_unnorm(x, t_batch) * scale

  def epsilon(x, t):
    # Raw network output at (x, t).
    return net(x, t)

  def _logpx(x):
    return ddpm.logpx(x)["logpx"]

  return ddpm, ddpm.loss, ddpm.sample, _logpx, logp_unnorm, epsilon

def train_single_model(dataset_energy, dataset_sample, plot_epochs = 1000, name = 'gmm',
                       batch_size = 512, num_steps = 12000):
  """Train one diffusion model (EBM-parameterised when ``use_ebm``) on samples.

  Args:
    dataset_energy: not used by this function; kept for interface compatibility.
    dataset_sample: callable ``n -> tensor`` returning ``n`` training samples.
      Each sample is flattened to a vector before use.
    plot_epochs: every ``plot_epochs`` steps, plot model samples (and, when
      ``use_ebm``, the unnormalised log-density at t=0). The figure is saved
      to ``results/{name}_samples.png``.
    name: filename prefix of the saved figure.
    batch_size: samples per optimisation step and per plotted sample batch.
    num_steps: number of optimisation steps.

  Returns:
    The trained ``PortableDiffusionModel``.
  """
  import os

  net_params = {"n_layers": 4, "h_dim": 128, "emb_dim": 32}
  ddpm, loss_fn, sample_fn, logpx_fn, logp_unnorm_fn, _ = forward_fn(net_params)
  opt = torch.optim.AdamW(ddpm.parameters(), lr = 1e-3, weight_decay=1e-5)
  # Move data to wherever the model lives instead of hard-coding CUDA.
  device = next(ddpm.parameters()).device
  losses = []; test_logpx = []

  for itr in tqdm(range(1, num_steps + 1)):
    x = dataset_sample(batch_size)
    x = x.reshape(x.shape[0], -1).to(device)

    # Clear gradients before backward, so nothing left behind by the periodic
    # evaluation below can leak into this update.
    opt.zero_grad()
    loss = loss_fn(x).mean()
    loss.backward()
    opt.step()
    losses.append(loss.item())

    if itr % plot_epochs == 0:
      f = plt.figure(figsize=(9, 4))
      f.add_subplot(1, 2, 1); plot_samples(sample_fn(batch_size))
      # NOTE: evaluated on the current training batch, not on held-out data.
      test_logpx.append(logpx_fn(x).mean().item())

      if use_ebm:
        f.add_subplot(1, 2, 2)
        dist_show_2d(lambda pts: logp_unnorm_fn(pts, 0), xr=[-.75, .75], yr=[-.75, .75])

      # Save in both modes; saving only under use_ebm dropped the samples plot.
      os.makedirs('results', exist_ok=True)
      plt.savefig(f'results/{name}_samples.png')
      plt.close(f)

  return ddpm

def forward_score_fn(net_params):
  """Build an EBM-wrapped network (time embedding disabled) and closures over it.

  Args:
    net_params: dict providing ``"data_dim"``, ``'h_dim'``, ``'n_layers'`` and
      ``'emb_dim'`` for the ``ResnetDiffusionModel`` backbone, which is built
      with ``emb_use=False``.

  Returns:
    ``(net, logp_unnorm, grad_logp_unnorm, detach_grad_logp_unnorm)``. Each
    closure accepts ``t`` but ignores it and always passes time index 0.
  """
  net = EBMDiffusionModel(
    ResnetDiffusionModel(n_steps=n_steps, x_dim=net_params["data_dim"], h_dim=net_params['h_dim'],
                         n_layers=net_params['n_layers'], emb_dim=net_params['emb_dim'], emb_use=False)
  )

  def _t0():
    # Fresh one-element time-index tensor holding 0, built on every call.
    return torch.tensor([0])

  def logp_unnorm(x, t=0):
    return -net.neg_logp_unnorm(x, _t0())

  def grad_logp_unnorm(x, t=0):
    return -net(x, _t0())

  def detach_grad_logp_unnorm(x, t=0):
    return -net.forward_no_grad(x, _t0())

  return net, logp_unnorm, grad_logp_unnorm, detach_grad_logp_unnorm

def fit_gradient_function(target_distribution, x, i):
  """Return the gradient of ``target_distribution(x).sum()`` with respect to ``x``.

  Args:
    target_distribution: callable mapping the batch ``x`` to per-sample values
      (for example an unnormalised log-density).
    x: input tensor. It is not modified. The gradient is taken with respect to
      a detached copy, so the caller's ``requires_grad`` flag and autograd graph
      are left untouched.
    i: unused; kept for call-signature compatibility.

  Returns:
    A detached tensor with the same shape as ``x``.
  """
  # Use a detached leaf instead of toggling x.requires_grad in place. Toggling
  # cleared the flag on inputs that already required grad and raised on non-leaf
  # inputs. enable_grad() also lets this work inside a torch.no_grad() region.
  with torch.enable_grad():
    x_leaf = x.detach().requires_grad_(True)
    total = target_distribution(x_leaf).sum()
    grad, = torch.autograd.grad([total], [x_leaf])
  return grad.detach()

def train_score_model(dataset_sample_mean, dataset_sample_std, num_steps = 10000, batch_size = 256, target_dist = None):
  """Regress an EBM's input gradient onto the score of a Gaussian mixture.

  The target is an equal-weight mixture of diagonal Gaussians with component
  means ``dataset_sample_mean`` and standard deviations ``dataset_sample_std``.
  The loss is the summed squared error between ``grad_logp_unnorm`` and
  d/dx log p_gmm(x), divided by ``batch_size * n_modes``.

  Args:
    dataset_sample_mean: component means, indexed by mode along dim 0.
    dataset_sample_std: component stds; assumed shape ``[n_modes, *dim]``
      (TODO confirm against callers).
    num_steps: number of optimisation steps.
    batch_size: noise draws per mode per step.
    target_dist: unused; kept for interface compatibility.

  Returns:
    ``(net, logp_fn, grad_fn)``. Both functions take ``(x, t)``, ignore ``t``,
    and return detached tensors.
  """
  n_modes = len(dataset_sample_mean)
  comp_dists = D.Independent(
    D.Normal(dataset_sample_mean, dataset_sample_std),
    reinterpreted_batch_ndims = 1
  )
  mix = D.Categorical(torch.ones(n_modes))
  gmm = D.MixtureSameFamily(mix, comp_dists)

  net_params = {"n_layers": 6, "h_dim": 256, "emb_dim": 64, "data_dim": dataset_sample_mean.shape[-1]}
  net, logp_unnorm, grad_logp_unnorm, detach_grad_logp_unnorm = forward_score_fn(net_params)
  opt = torch.optim.AdamW(net.parameters(), lr = 1e-4, weight_decay=1e-5)

  # Loop invariants: per-mode noise scale and the shape of one standard-normal draw.
  noise_std = 1.0 * dataset_sample_std.unsqueeze(0)
  draw_shape = [batch_size, 1] + list(dataset_sample_std.shape[1:])

  for _ in tqdm(range(1, num_steps + 1)):
    # NOTE(review): these are zero-mean draws scaled by each mode's std, with one
    # standard-normal draw per row broadcast across modes. They are NOT centred on
    # dataset_sample_mean; confirm this is the intended training distribution.
    noise = noise_std * torch.randn(draw_shape)  # [batch_size, n_modes, *dim]
    noise = noise.view(-1, noise.size(-1)).requires_grad_()
    target_score = torch.autograd.grad([gmm.log_prob(noise).sum()], [noise])[0].detach()

    # noise is already 2-D. Squeezing it collapsed the feature axis for 1-D data
    # (or a single row), feeding the network a wrongly shaped input.
    residual = grad_logp_unnorm(noise) - target_score
    loss = residual.square().sum() / (batch_size * n_modes)
    loss.backward(); opt.step(); opt.zero_grad()

  return net, lambda x, t: logp_unnorm(x).detach(), lambda x, t: detach_grad_logp_unnorm(x).detach()
  # return net, lambda x, t: gmm.log_prob(x).detach(), lambda x, t: fit_gradient_function(gmm.log_prob, x, t).detach()