
import torch.nn as nn
import functools
import torch
import numpy as np
import torch.nn.functional as F
import math
import string
from functools import partial

def _einsum(a, b, c, x, y):
  """Call ``torch.einsum`` with subscripts supplied as character sequences.

  ``a``, ``b`` and ``c`` are iterables of single-letter subscripts for the
  first operand, second operand and output respectively.
  """
  spec = ''.join(a) + ',' + ''.join(b) + '->' + ''.join(c)
  return torch.einsum(spec, x, y)

def contract_inner(x, y):
  """Contract the last axis of ``x`` with the first axis of ``y``.

  Equivalent to ``tensordot(x, y, 1)``: the result has shape
  ``x.shape[:-1] + y.shape[1:]``.
  """
  rank_x = len(x.shape)
  rank_y = len(y.shape)
  letters = string.ascii_lowercase
  x_idx = list(letters[:rank_x])
  # y's first axis reuses x's last letter so einsum sums over that pair;
  # the rest of y gets fresh letters that follow x's.
  y_idx = [x_idx[-1]] + list(letters[rank_x + 1:rank_x + rank_y])
  out_idx = x_idx[:-1] + y_idx[1:]
  return _einsum(x_idx, y_idx, out_idx, x, y)

class NIN(nn.Module):
  """Network-in-network layer: a per-position linear map over channels.

  For an input of shape (B, in_dim, L) it returns (B, num_units, L) with
  ``y[:, u, l] = sum_c x[:, c, l] * W[c, u] + b[u]``.
  """
  def __init__(self, in_dim, num_units, init_scale=0.1):
    super().__init__()
    initial_weight = default_init(scale=init_scale)((in_dim, num_units))
    self.W = nn.Parameter(initial_weight, requires_grad=True)
    # Bias starts at zero.
    self.b = nn.Parameter(torch.zeros(num_units), requires_grad=True)

  def forward(self, x):
    # Move channels last so they line up with the first axis of W.
    channels_last = x.permute(0, 2, 1)
    projected = contract_inner(channels_last, self.W) + self.b
    return projected.permute(0, 2, 1)

class AttnBlockpp(nn.Module):
  """Channel-wise self-attention over the length axis. Modified from DDPM.

  Input and output have shape (B, C, L). Every position attends to every
  position; the attended values are projected by ``NIN_3`` and added back
  to the input as a residual.

  Args:
    channels: number of channels C.
    skip_rescale: if True, the residual sum is divided by sqrt(2).
    init_scale: initializer scale of the output projection ``NIN_3``
      (0 makes it start with tiny weights, so the block begins close to
      the identity / scaled identity).
  """

  def __init__(self, channels, skip_rescale=False, init_scale=0.):
    super().__init__()
    # NOTE(review): nn.GroupNorm requires channels to be divisible by
    # min(channels // 4, 24).
    self.GroupNorm_0 = nn.GroupNorm(num_groups=min(channels // 4, 24), num_channels=channels,
                                  eps=1e-6)
    self.NIN_0 = NIN(channels, channels)  # query projection
    self.NIN_1 = NIN(channels, channels)  # key projection
    self.NIN_2 = NIN(channels, channels)  # value projection
    self.NIN_3 = NIN(channels, channels, init_scale=init_scale)  # output projection
    self.skip_rescale = skip_rescale

  def forward(self, x):
    _, C, _ = x.shape
    h = self.GroupNorm_0(x)
    q = self.NIN_0(h)
    k = self.NIN_1(h)
    v = self.NIN_2(h)

    # w[b, l, m]: scaled dot product of query position l with key position m.
    w = torch.einsum('bcl,bcm->blm', q, k) * (int(C) ** (-0.5))
    # Normalize over key positions separately for each query position (the
    # 1-D analogue of DDPM's softmax over the flattened key axes). The
    # previous version flattened to (B, L*L), which normalized over all
    # query/key pairs at once.
    w = F.softmax(w, dim=-1)
    h = torch.einsum('blm,bcm->bcl', w, v)
    h = self.NIN_3(h)
    if not self.skip_rescale:
      return x + h
    return (x + h) / np.sqrt(2.)

def naive_upsample(x, factor=2):
  """Nearest-neighbour upsampling of a (N, C, L) tensor along its last axis.

  Every sample is repeated ``factor`` times in place, giving shape
  (N, C, L * factor).
  """
  _, channels, length = x.shape
  # Add a unit trailing axis, tile it `factor` times, then fold it back into L.
  tiled = torch.reshape(x, (-1, channels, length, 1)).repeat(1, 1, 1, factor)
  return tiled.reshape(-1, channels, length * factor)


def naive_downsample(x, factor=2):
  """Average-pool a (N, C, L) tensor along its last axis.

  Consecutive, non-overlapping windows of ``factor`` samples are averaged,
  giving shape (N, C, L // factor).

  Raises:
    ValueError: if L is not a multiple of ``factor``. Reshaping with an
      inferred (-1) batch size in that case could silently succeed with a
      different batch size and mix samples from different batch items.
  """
  n_batch, channels, length = x.shape
  if length % factor != 0:
    raise ValueError(
        f'naive_downsample: length {length} is not divisible by factor {factor}')
  # Explicit batch size: never let reshape re-infer N.
  windows = x.reshape(n_batch, channels, length // factor, factor)
  return windows.mean(dim=3)


class ResnetBlockBigGANpp(nn.Module):
  """BigGAN-style residual block for 1-D signals with optional resampling.

  Residual branch: GroupNorm -> act -> (resample) -> Conv_0 -> (+ time bias)
  -> GroupNorm -> act -> dropout -> Conv_1. Shortcut: the (resampled) input,
  passed through a 1x1 conv when channels or length change.

  Args:
    act: activation applied after each GroupNorm and to ``temb``.
    in_ch: input channels.
    out_ch: output channels; falls back to ``in_ch`` when falsy.
    temb_dim: time-embedding size. When None, no ``Dense_0`` is created and
      ``forward`` must be called without ``temb``.
    up: double the length with nearest-neighbour repetition.
    down: halve the length with window averaging (ignored if ``up``).
    dropout: dropout probability before ``Conv_1``.
    fir_kernel: stored only; not used by this block.
    skip_rescale: if True, the residual sum is divided by sqrt(2).
    init_scale: initializer scale for ``Conv_1`` (0 gives tiny weights).
  """

  def __init__(self, act, in_ch, out_ch=None, temb_dim=None, up=False, down=False,
               dropout=0.1, fir_kernel=(1, 3, 3, 1),
               skip_rescale=True, init_scale=0.):
    super().__init__()

    out_ch = out_ch if out_ch else in_ch
    # NOTE(review): nn.GroupNorm requires in_ch / out_ch to be divisible by
    # min(ch // 4, 24).
    self.GroupNorm_0 = nn.GroupNorm(num_groups=min(in_ch // 4, 24), num_channels=in_ch, eps=1e-6)
    self.up = up
    self.down = down
    self.fir_kernel = fir_kernel

    self.Conv_0 = ddpm_conv3x3(in_ch, out_ch)
    if temb_dim is not None:
      self.Dense_0 = nn.Linear(temb_dim, out_ch)
      self.Dense_0.weight.data = default_init()(self.Dense_0.weight.shape)
      nn.init.zeros_(self.Dense_0.bias)

    self.GroupNorm_1 = nn.GroupNorm(num_groups=min(out_ch // 4, 24), num_channels=out_ch, eps=1e-6)
    self.Dropout_0 = nn.Dropout(dropout)
    self.Conv_1 = ddpm_conv3x3(out_ch, out_ch, init_scale=init_scale)
    # Shortcut projection, needed whenever the shortcut's shape changes.
    if in_ch != out_ch or up or down:
      self.Conv_2 = ddpm_conv1x1(in_ch, out_ch)

    self.skip_rescale = skip_rescale
    self.act = act
    self.in_ch = in_ch
    self.out_ch = out_ch

  def forward(self, x, temb=None):
    h = self.act(self.GroupNorm_0(x))

    # Resample both branches so the residual sum lines up.
    if self.up:
      h = naive_upsample(h, factor=2)
      x = naive_upsample(x, factor=2)
    elif self.down:
      h = naive_downsample(h, factor=2)
      x = naive_downsample(x, factor=2)

    h = self.Conv_0(h)
    # Add bias to each feature map conditioned on the time embedding
    # (presumably temb is (B, temb_dim); the result broadcasts over length).
    if temb is not None:
      h += self.Dense_0(self.act(temb))[:, :, None]
    h = self.act(self.GroupNorm_1(h))
    h = self.Dropout_0(h)
    h = self.Conv_1(h)

    if self.in_ch != self.out_ch or self.up or self.down:
      x = self.Conv_2(x)

    if not self.skip_rescale:
      return x + h
    return (x + h) / np.sqrt(2.)

def ddpm_conv3x3(in_planes, out_planes, stride=1, bias=True, dilation=1, init_scale=1., padding=1):
  """Kernel-size-3 ``nn.Conv1d`` with DDPM initialization.

  Despite the "3x3" name this is a 1-D convolution with kernel size 3.
  Weights come from ``default_init(init_scale)``; the bias, if present,
  starts at zero.
  """
  conv = nn.Conv1d(in_planes, out_planes, kernel_size=3, stride=stride, padding=padding,
                   dilation=dilation, bias=bias)
  conv.weight.data = default_init(init_scale)(conv.weight.data.shape)
  # With bias=False, conv.bias is None and there is nothing to initialize.
  if conv.bias is not None:
    nn.init.zeros_(conv.bias)
  return conv

def ddpm_conv1x1(in_planes, out_planes, stride=1, bias=True, init_scale=1., padding=0):
  """Kernel-size-1 ``nn.Conv1d`` (per-position channel mixing) with DDPM init.

  Weights come from ``default_init(init_scale)``; the bias, if present,
  starts at zero.
  """
  conv = nn.Conv1d(in_planes, out_planes, kernel_size=1, stride=stride, padding=padding, bias=bias)
  conv.weight.data = default_init(init_scale)(conv.weight.data.shape)
  # With bias=False, conv.bias is None and there is nothing to initialize.
  if conv.bias is not None:
    nn.init.zeros_(conv.bias)
  return conv

def get_act():
  """Return a new SiLU (x * sigmoid(x)) activation module."""
  activation = nn.SiLU()
  return activation

def default_init(scale=1.):
  """Return the default (fan-average, uniform) weight initializer.

  A ``scale`` of exactly zero is replaced by 1e-10, so the weights come out
  tiny rather than identically zero.
  """
  effective_scale = scale if scale != 0 else 1e-10
  return variance_scaling(effective_scale)

def variance_scaling(scale=1., in_axis=1, out_axis=0, dtype=torch.float32, device='cpu'):
  """Build a fan-average, uniform variance-scaling initializer.

  The returned ``init(shape, dtype=..., device=...)`` samples a tensor from
  U(-bound, bound) with ``bound = sqrt(3 * scale / fan_avg)``, so its
  variance is ``scale / fan_avg``. Definitions:

  - ``receptive = prod(shape) / shape[in_axis] / shape[out_axis]``
  - ``fan_in = shape[in_axis] * receptive``
  - ``fan_out = shape[out_axis] * receptive``
  - ``fan_avg = (fan_in + fan_out) / 2``
  """
  def init(shape, dtype=dtype, device=device):
    n_in = shape[in_axis]
    n_out = shape[out_axis]
    receptive = np.prod(shape) / n_in / n_out
    fan_avg = (n_in * receptive + n_out * receptive) / 2
    bound = np.sqrt(3 * (scale / fan_avg))
    unit = torch.rand(*shape, dtype=dtype, device=device)
    # Map U[0, 1) to U[-1, 1), then scale to the target bound.
    return (unit * 2. - 1.) * bound
  return init

def get_timestep_embedding(timesteps, embedding_dim, max_positions=10000):
  """Sinusoidal embedding of a batch of (possibly fractional) timesteps.

  Args:
    timesteps: 1-D tensor of shape (N,); values are cast to float32.
    embedding_dim: width of the embedding. Odd widths get one trailing
      zero column.
    max_positions: sets the lowest frequency. For embedding_dim >= 4 the
      frequencies decay geometrically from 1 down to 1 / max_positions.

  Returns:
    Tensor of shape (N, embedding_dim): sines in the first half and cosines
    in the second half (plus a zero column when embedding_dim is odd).
  """
  assert len(timesteps.shape) == 1
  half_dim = embedding_dim // 2
  # With half_dim == 1 the geometric spacing is undefined (division by
  # zero); clamp the denominator so a single frequency of 1 is used.
  emb = math.log(max_positions) / max(half_dim - 1, 1)
  emb = torch.exp(torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) * -emb)
  emb = timesteps.float()[:, None] * emb[None, :]
  emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
  if embedding_dim % 2 == 1:  # zero pad
    emb = F.pad(emb, (0, 1), mode='constant')
  assert emb.shape == (timesteps.shape[0], embedding_dim)
  return emb

def get_sigmas(config):
  """Return the SMLD noise levels described by a config.

  Args:
    config: object whose ``model`` attribute provides ``sigma_max``,
      ``sigma_min`` and ``num_scales``.

  Returns:
    sigmas: NumPy array of ``num_scales`` noise levels spaced evenly in
      log-space, from ``sigma_max`` down to ``sigma_min`` (a geometric
      sequence).
  """
  model_cfg = config.model
  log_levels = np.linspace(np.log(model_cfg.sigma_max),
                           np.log(model_cfg.sigma_min),
                           model_cfg.num_scales)
  return np.exp(log_levels)

class NCSNpp(nn.Module):
  """1-D NCSN++-style U-Net over (condition, x, time).

  ``condition`` and ``x`` are stacked as two input channels. The network
  produces a single output channel, which is squeezed away, so the output
  has the same (B, L) layout as ``x``.

  Constructor argument ``arg`` must expose ``nf`` (base channel count) and
  ``hidden_dim`` (signal length used to place the attention blocks).
  Presumably the runtime length L equals ``arg.hidden_dim``. L must be
  divisible by 8 for the skip connections to line up across the three
  down/up-sampling steps.
  """

  def __init__(self, arg):
    super().__init__()
    self.act = act = get_act()

    self.nf = nf = arg.nf
    ch_mult = (1, 2, 2, 2)
    self.num_res_blocks = num_res_blocks = 4
    # Self-attention is used only at the level whose length is hidden_dim / 2.
    self.attn_resolutions = attn_resolutions = (int(arg.hidden_dim/2),)
    dropout = 0.1
    self.num_resolutions = num_resolutions = len(ch_mult)
    # Expected signal length at each level; each level halves the length.
    self.all_resolutions = all_resolutions = [arg.hidden_dim // (2 ** i) for i in range(num_resolutions)]

    fir_kernel = [1, 3, 3, 1]
    self.skip_rescale = skip_rescale = True
    init_scale = 0.

    # NOTE(review): every GroupNorm here uses min(ch // 4, 24) groups, and
    # nn.GroupNorm raises unless ch is divisible by that. For example, nf=64
    # yields 128-channel blocks with 24 groups and fails at construction.

    # All submodules live in one flat list. forward() consumes them in
    # exactly this order, so the two methods must stay in sync.
    modules = []

    embed_dim = nf

    # Time-embedding MLP: nf -> 4*nf -> 4*nf.
    modules.append(nn.Linear(embed_dim, nf * 4))
    modules[-1].weight.data = default_init()(modules[-1].weight.shape)
    nn.init.zeros_(modules[-1].bias)
    modules.append(nn.Linear(nf * 4, nf * 4))
    modules[-1].weight.data = default_init()(modules[-1].weight.shape)
    nn.init.zeros_(modules[-1].bias)

    AttnBlock = functools.partial(AttnBlockpp,
                                  init_scale=init_scale,
                                  skip_rescale=skip_rescale)

    ResnetBlock = functools.partial(ResnetBlockBigGANpp,
                                    act=act,
                                    dropout=dropout,
                                    fir_kernel=fir_kernel,
                                    init_scale=init_scale,
                                    skip_rescale=skip_rescale,
                                    temb_dim=nf * 4)

    # Downsampling block

    # Input channels: condition and x stacked.
    channels = 2

    modules.append(ddpm_conv3x3(channels, nf))
    # Channel counts of every tensor pushed onto the skip stack in forward().
    hs_c = [nf]

    in_ch = nf
    for i_level in range(num_resolutions):
      # Residual blocks for this resolution
      for i_block in range(num_res_blocks):
        out_ch = nf * ch_mult[i_level]
        modules.append(ResnetBlock(in_ch=in_ch, out_ch=out_ch))
        in_ch = out_ch

        if all_resolutions[i_level] in attn_resolutions:
          modules.append(AttnBlock(channels=in_ch))
        hs_c.append(in_ch)

      if i_level != num_resolutions - 1:
        modules.append(ResnetBlock(down=True, in_ch=in_ch))

        hs_c.append(in_ch)

    # Bottleneck: ResNet -> attention -> ResNet.
    in_ch = hs_c[-1]
    modules.append(ResnetBlock(in_ch=in_ch))
    modules.append(AttnBlock(channels=in_ch))
    modules.append(ResnetBlock(in_ch=in_ch))

    # Upsampling block: each ResNet block consumes one skip tensor.
    for i_level in reversed(range(num_resolutions)):
      for i_block in range(num_res_blocks + 1):
        out_ch = nf * ch_mult[i_level]
        modules.append(ResnetBlock(in_ch=in_ch + hs_c.pop(),
                                   out_ch=out_ch))
        in_ch = out_ch

      if all_resolutions[i_level] in attn_resolutions:
        modules.append(AttnBlock(channels=in_ch))

      if i_level != 0:
        modules.append(ResnetBlock(in_ch=in_ch, up=True))

    # Output head: GroupNorm -> act -> conv down to a single channel.
    modules.append(nn.GroupNorm(num_groups=min(in_ch // 4, 24),
                                num_channels=in_ch, eps=1e-6))
    modules.append(ddpm_conv3x3(in_ch, int(channels/2), init_scale=init_scale))

    self.all_modules = nn.ModuleList(modules)

  def forward(self, condition, x, time_cond):
    """Run the network.

    Args:
      condition: (B, L) conditioning signal, used as input channel 0.
      x: (B, L) signal, used as input channel 1.
      time_cond: 1-D tensor of shape (B,) of timesteps / noise levels.

    Returns:
      Tensor of shape (B, L).
    """
    modules = self.all_modules
    m_idx = 0

    # Sinusoidal time embedding followed by the two-layer MLP.
    temb = get_timestep_embedding(time_cond, self.nf)
    temb = modules[m_idx](temb)
    m_idx += 1
    temb = modules[m_idx](self.act(temb))
    m_idx += 1

    # Stack condition and x as channels: (B, 2, L).
    x = torch.cat([condition.unsqueeze(dim=1), x.unsqueeze(dim=1)], dim=1)

    # Downsampling block; every intermediate output is kept for skips.
    hs = [modules[m_idx](x)]
    m_idx += 1
    for i_level in range(self.num_resolutions):
      # Use the same per-level rule as __init__, so the flat module index
      # stays aligned even if the runtime length differs from hidden_dim.
      use_attn = self.all_resolutions[i_level] in self.attn_resolutions
      for _ in range(self.num_res_blocks):
        h = modules[m_idx](hs[-1], temb)
        m_idx += 1
        if use_attn:
          h = modules[m_idx](h)
          m_idx += 1
        hs.append(h)

      if i_level != self.num_resolutions - 1:
        h = modules[m_idx](hs[-1], temb)
        m_idx += 1
        hs.append(h)

    # Bottleneck.
    h = hs[-1]
    h = modules[m_idx](h, temb)
    m_idx += 1
    h = modules[m_idx](h)
    m_idx += 1
    h = modules[m_idx](h, temb)
    m_idx += 1

    # Upsampling block.
    for i_level in reversed(range(self.num_resolutions)):
      for _ in range(self.num_res_blocks + 1):
        h = modules[m_idx](torch.cat([h, hs.pop()], dim=1), temb)
        m_idx += 1

      if self.all_resolutions[i_level] in self.attn_resolutions:
        h = modules[m_idx](h)
        m_idx += 1

      if i_level != 0:
        h = modules[m_idx](h, temb)
        m_idx += 1

    # Output head.
    h = self.act(modules[m_idx](h))
    m_idx += 1
    h = modules[m_idx](h)
    m_idx += 1
    # Every constructed module must have been consumed exactly once.
    assert m_idx == len(modules)

    return h.squeeze(dim=1)
