import torch
import torch.nn as nn
import torch.nn.functional as F

class ResnetDiffusionBlock(nn.Module):
  """ One residual branch of the Resnet score model.

      forward(x, emb) computes
        LayerNorm(x) -> SiLU -> Linear(h_dim -> h_dim*widen)
          + Linear(emb_dim -> h_dim*widen)(emb)
        -> SiLU -> Linear(h_dim*widen -> h_dim*widen)
        -> SiLU -> Linear(h_dim*widen -> h_dim)
      and returns only this branch; no skip connection is added here.

      Only ``h_dim``, ``emb_dim`` and ``widen`` shape the layers. The other
      constructor arguments (``n_steps``, ``n_layers``, ``x_dim``,
      ``emb_use``, ``name``) are stored or accepted but not used by this class.
  """
  def __init__(self, n_steps, n_layers, x_dim, h_dim, emb_dim, widen=2, emb_use=True, name=None):
    super().__init__()
    # Hyper-parameters kept as attributes for introspection.
    self._n_layers = n_layers
    self._n_steps = n_steps
    self._emb_dim = emb_dim
    self._emb_use = emb_use
    self._x_dim = x_dim
    self._h_dim = h_dim
    self._widen = widen

    wide = h_dim * widen
    # Layer creation order is deliberate: it fixes which RNG draws each
    # layer's initial weights consume, and the attribute names fix the
    # state_dict keys.
    self.layer_norm = nn.LayerNorm(normalized_shape=h_dim)
    self.layer_hmb = nn.Linear(h_dim, wide)
    self.layer_emb = nn.Linear(emb_dim, wide)
    self.layer_int = nn.Linear(wide, wide)
    self.layer_out = nn.Linear(wide, h_dim)

  def forward(self, x, emb):
    """Return the residual-branch output (last dim ``h_dim``) for ``x`` and ``emb``."""
    projected = self.layer_hmb(F.silu(self.layer_norm(x)))
    # Inject the step embedding additively in the widened space.
    conditioned = F.silu(projected + self.layer_emb(emb))
    inner = F.silu(self.layer_int(conditioned))
    return self.layer_out(inner)
  
# Define a simple residual-MLP diffusion (score) model
class ResnetDiffusionModel(nn.Module):
  """ Resnet score model.

      ``x`` is projected to ``h_dim`` by ``encode``, passed through
      ``n_layers`` residual ``ResnetDiffusionBlock``s and projected back to
      ``x_dim`` by ``decode``. Each block is conditioned on a learned
      ``nn.Embedding(n_steps, emb_dim)`` lookup of the step ``t``, or on an
      all-zero embedding when ``emb_use`` is False. ``name`` is accepted for
      signature compatibility and forwarded to the blocks.
  """
  def __init__(self, n_steps, n_layers, x_dim, h_dim, emb_dim, widen=2, emb_use=True, name=None):
    super().__init__()
    self._n_layers = n_layers
    self._n_steps  = n_steps
    self._emb_dim  = emb_dim
    self._emb_use  = emb_use
    self._x_dim    = x_dim
    self._h_dim    = h_dim
    self._widen    = widen

    self.encode = nn.Linear(x_dim, h_dim)
    self.decode = nn.Linear(h_dim, x_dim)
    self.embedding = nn.Embedding(n_steps, emb_dim)
    self.resnet_blocks = nn.ModuleList([
        ResnetDiffusionBlock(n_steps, n_layers, x_dim, h_dim, emb_dim, widen, emb_use, name)
        for _ in range(n_layers)
    ])

  # Defined as ``forward`` (not ``__call__``) so that nn.Module.__call__
  # still runs registered hooks before/after this computation.
  def forward(self, x, t):
    """Evaluate the model.

    Args:
      x: tensor whose last dimension is ``x_dim``; a 1-D input is promoted
        to a batch of one. It is cast to float32 before encoding.
      t: integer step index (or indices) used to look up the embedding; a
        scalar is promoted to shape ``(1,)``. Must be 1-D after promotion.

    Returns:
      Tensor whose last dimension is ``x_dim``.
    """
    x = torch.atleast_2d(x)
    t = torch.atleast_1d(t)
    assert x.shape[-1] == self._x_dim and t.dim() == 1

    if self._emb_use:
      emb = self.embedding(t)
    else:
      # Allocate the zero conditioning on the embedding's device/dtype;
      # a bare torch.zeros(...) would land on the default (CPU) device and
      # break a model that has been moved to a GPU.
      emb = self.embedding.weight.new_zeros((*t.shape, self.embedding.embedding_dim))

    h = self.encode(x.float())
    for block in self.resnet_blocks:
      h = block(h, emb) + h  # residual connection
    return self.decode(h)

# Define an energy-based diffusion model (a wrapper around a standard score model)
class EBMDiffusionModel(nn.Module):
  """ EBM parameterization on top of score model.

      The unnormalized negative log-density (energy) of ``x`` at step ``t`` is
      ``sum((net(x, t) - x) ** 2, -1)``; calling the module returns the
      gradient of the summed energy with respect to ``x``. ``net(x, t)`` must
      broadcast against ``x``. ``name`` is unused and kept for signature
      compatibility.
  """
  def __init__(self, net, name=None):
    super().__init__()
    self.net = net

  def neg_logp_unnorm(self, x, t):
    """Return the per-sample energy ``sum((net(x, t) - x) ** 2, -1)``."""
    score = self.net(x, t)
    return ((score - x) ** 2).sum(-1)

  # Defined as ``forward`` (not ``__call__``) so nn.Module hooks still run.
  def forward(self, x, t):
    """Return d(sum of energies)/dx, keeping the graph (``create_graph=True``).

    The result can itself be differentiated, e.g. by a training loss that
    backpropagates into ``net``'s parameters. Works inside ``torch.no_grad()``.
    The caller's ``x`` is not mutated: if it does not already require grad,
    a detached alias is differentiated instead.
    """
    with torch.enable_grad():
      if not x.requires_grad:
        # Differentiate an alias rather than flipping the caller's flag.
        x = x.detach().requires_grad_(True)
      neg_logp_unnorm = self.neg_logp_unnorm(x, t).sum()
      return torch.autograd.grad([neg_logp_unnorm], [x], retain_graph=True, create_graph=True)[0]

  def forward_no_grad(self, x, t):
    """Return a detached d(sum of energies)/dx without building a higher-order graph.

    Always differentiates a detached alias of ``x``. This works for non-leaf
    inputs and inside ``torch.no_grad()``, and it leaves the caller's
    ``x.requires_grad`` untouched.
    """
    with torch.enable_grad():
      x_var = x.detach().requires_grad_(True)
      neg_logp_unnorm = self.neg_logp_unnorm(x_var, t).sum()
      grad = torch.autograd.grad([neg_logp_unnorm], [x_var], retain_graph=False, create_graph=False)[0]
    return grad.detach()
  
# Define how to multiply two different EBM distributions together
class ProductEBMDiffusionModel(nn.Module):
  """ EBM where we compose two distributions together by multiplying them.

      Multiplying two unnormalized densities adds their energies, so
      ``neg_logp_unnorm`` is ``E1 + E2``. Calling the module returns the sum
      of the two nets' outputs, which for energy-gradient wrappers such as
      ``EBMDiffusionModel`` is the gradient of the summed energy.

      Args:
        net1, net2: modules exposing ``neg_logp_unnorm(x, t)``, a call
          ``net(x, t)`` and ``forward_no_grad(x, t)``.
        name: unused; kept for signature compatibility.
  """
  def __init__(self, net1, net2, name=None):
    super().__init__()
    self.net1 = net1
    self.net2 = net2

  def neg_logp_unnorm(self, x, t):
    """Return the per-sample summed energy of both components."""
    unorm_1 = self.net1.neg_logp_unnorm(x, t)
    unorm_2 = self.net2.neg_logp_unnorm(x, t)
    return unorm_1 + unorm_2

  # Defined as ``forward`` (not ``__call__``) so nn.Module hooks still run.
  def forward(self, x, t):
    """Return ``net1(x, t) + net2(x, t)``."""
    return self.net1(x, t) + self.net2(x, t)

  def forward_no_grad(self, x, t):
    """Return ``net1.forward_no_grad(x, t) + net2.forward_no_grad(x, t)``."""
    return self.net1.forward_no_grad(x, t) + self.net2.forward_no_grad(x, t)

def _energy_gradient(energy_fn, x, t, create_graph):
  """Return the gradient of ``energy_fn(x, t).sum()`` with respect to ``x``.

  Evaluated under ``torch.enable_grad()`` so it also works inside
  ``torch.no_grad()``. The caller's ``x`` is never mutated. When
  ``create_graph`` is True and ``x`` already requires grad, ``x`` itself is
  differentiated so its upstream graph is kept. Otherwise a detached alias
  with ``requires_grad=True`` is used.

  Args:
    energy_fn: callable ``(x, t) -> per-sample energies``.
    x: input tensor.
    t: step argument forwarded unchanged to ``energy_fn``.
    create_graph: if True the returned gradient stays differentiable;
      otherwise it is detached.
  """
  with torch.enable_grad():
    if create_graph and x.requires_grad:
      x_var = x
    else:
      x_var = x.detach().requires_grad_(True)
    energy = energy_fn(x_var, t).sum()
    grad = torch.autograd.grad(
        [energy], [x_var], retain_graph=create_graph, create_graph=create_graph)[0]
  return grad if create_graph else grad.detach()


# Define how to add two different EBM distributions (mixture)
class MixtureEBMDiffusionModel(nn.Module):
  """ EBM where we compose two distributions together as a mixture.

      Energy: ``-logsumexp(-scale * [E1, E2], -1)``, a smooth minimum of the
      scaled component energies (``scale`` defaults to 3.5). Calling the
      module returns the gradient of the summed energy with respect to ``x``.

      Args:
        net1, net2: modules exposing ``neg_logp_unnorm(x, t)``.
        name: unused; kept for signature compatibility.
        scale: multiplier applied to the energies inside the logsumexp.
  """
  def __init__(self, net1, net2, name=None, *, scale=3.5):
    super().__init__()
    self.net1 = net1
    self.net2 = net2
    self.scale = scale

  def neg_logp_unnorm(self, x, t):
    """Return the per-sample mixture energy."""
    unorm_1 = self.net1.neg_logp_unnorm(x, t)
    unorm_2 = self.net2.neg_logp_unnorm(x, t)
    concat_energy = torch.stack([unorm_1, unorm_2], dim=-1)
    return -torch.logsumexp(-concat_energy * self.scale, dim=-1)

  def forward(self, x, t):
    """Return d(sum of mixture energies)/dx, keeping the graph."""
    return _energy_gradient(self.neg_logp_unnorm, x, t, create_graph=True)

  def forward_no_grad(self, x, t):
    """Return a detached d(sum of mixture energies)/dx."""
    return _energy_gradient(self.neg_logp_unnorm, x, t, create_graph=False)


# Define how to subtract one EBM distribution from another (negation)
class NegationEBMDiffusionModel(nn.Module):
  """ EBM where we compose two distributions together by negating one.

      Energy: ``pos_weight * E1 - neg_weight * E2`` (defaults 1.3 and 0.3).
      Calling the module returns the gradient of the summed energy with
      respect to ``x``.

      Args:
        net1, net2: modules exposing ``neg_logp_unnorm(x, t)``.
        name: unused; kept for signature compatibility.
        pos_weight: weight on the first energy.
        neg_weight: weight subtracted for the second energy.
  """

  def __init__(self, net1, net2, name=None, *, pos_weight=1.3, neg_weight=0.3):
    super().__init__()
    self.net1 = net1
    self.net2 = net2
    self.pos_weight = pos_weight
    self.neg_weight = neg_weight

  def neg_logp_unnorm(self, x, t):
    """Return the per-sample negation energy."""
    unorm_1 = self.net1.neg_logp_unnorm(x, t)
    unorm_2 = self.net2.neg_logp_unnorm(x, t)
    return self.pos_weight * unorm_1 - self.neg_weight * unorm_2

  def forward(self, x, t):
    """Return d(sum of negation energies)/dx, keeping the graph."""
    return _energy_gradient(self.neg_logp_unnorm, x, t, create_graph=True)

  def forward_no_grad(self, x, t):
    """Return a detached d(sum of negation energies)/dx."""
    return _energy_gradient(self.neg_logp_unnorm, x, t, create_graph=False)