# coding=utf-8
"""
Code adapted from
https://github.com/NVlabs/denoising-diffusion-gan/blob/main/score_sde/models/discriminator.py
"""


import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

from . import up_or_down_sampling
from . import dense_layer

# Short module-level aliases for the project-local layer factories in
# `dense_layer`; every class below builds its layers through these names.
dense = dense_layer.dense
conv2d = dense_layer.conv2d


class PositionalTimeEmbedding(nn.Module):
  """Positional embeddings for time t, specific for discriminator.

  A sinusoidal table with one row per integer timestep in ``[0, T)`` is built
  once in the constructor, looked up by index, and then passed through
  ``dense -> act -> dense``.

  Args:
    T: number of rows in the table; valid timestep indices are ``0 .. T-1``.
    act: activation module placed between the two dense layers.
    embedding_dim: width of the sinusoidal table (odd widths are zero padded).
    out_dim: width of the embedding returned by ``forward``.
    max_positions: base of the geometric progression of frequencies.
  """
  def __init__(self, T, act, embedding_dim, out_dim, max_positions=10000):
    super().__init__()
    half_dim = embedding_dim // 2
    # Guard the denominator: for embedding_dim in {2, 3} half_dim is 1 and the
    # plain `half_dim - 1` would raise ZeroDivisionError. With a single
    # frequency the exponent is 0 anyway, so the scale value is irrelevant.
    log_step = math.log(max_positions) / max(half_dim - 1, 1)
    freqs = torch.exp(torch.arange(half_dim) * -log_step)
    positions = torch.arange(T).float()
    # Outer product: row = timestep, column = frequency.
    angles = positions[:, None] * freqs[None, :]
    assert list(angles.shape) == [T, half_dim], "unexpected angle table shape"
    table = torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)
    if embedding_dim % 2 == 1:  # zero pad
      table = F.pad(table, (0, 1), mode='constant')
    assert list(table.shape) == [T, embedding_dim], "unexpected embedding table shape"

    # `from_pretrained` is called with its defaults, so the table is stored as
    # the embedding weight (frozen by default per the PyTorch documentation).
    self.timembedding = nn.Sequential(
      nn.Embedding.from_pretrained(table),
      dense(embedding_dim, out_dim),
      act,
      dense(out_dim, out_dim),
    )

  def forward(self, t):
    """Embed a 1-D tensor of timestep indices.

    Args:
      t: 1-D tensor used to index the table (presumably integer-typed values
        in ``[0, T)`` -- verify against callers).

    Returns:
      The output of the embedding/dense stack for each entry of ``t``.
    """
    assert len(t.shape) == 1, "t must be a 1-D tensor of timestep indices"
    return self.timembedding(t)


class DownConvBlock(nn.Module):
    """Residual convolution block conditioned on a time embedding.

    The main branch is ``act -> conv1 -> (+ projected temb) -> act -> conv2``;
    the shortcut branch is a bias-free 1x1 convolution of the block input.
    When ``downsample`` is set, both branches go through
    ``up_or_down_sampling.downsample_2d(..., factor=2)`` before their last
    convolution. The two branches are summed and divided by ``sqrt(2)``.

    Args:
        act: activation module applied before each main-branch convolution.
        in_ch: channel count handed to ``conv1`` and ``skip`` as their input size.
        out_ch: channel count produced by every layer of the block.
        temb_dim: width of the time embedding fed to ``dense_t1``.
        kernel_size: kernel size of ``conv1`` and ``conv2``.
        padding: padding of ``conv1`` and ``conv2``.
        downsample: whether to downsample both branches by a factor of 2.
        fir_kernel: kernel forwarded to ``downsample_2d``.
    """
    def __init__(self, act, in_ch, out_ch, temb_dim, kernel_size=3, padding=1,
                 downsample=False, fir_kernel=(1, 3, 3, 1)):
        super().__init__()
        self.downsample = downsample
        self.fir_kernel = fir_kernel

        # Layers are created in a fixed order so that seeded initialisation and
        # module registration stay stable.
        self.conv1 = conv2d(in_ch, out_ch, kernel_size, padding=padding)
        # conv2 is built with init_scale=0. (presumably a zero-initialised
        # output so the block starts close to its shortcut -- TODO confirm).
        self.conv2 = conv2d(out_ch, out_ch, kernel_size, padding=padding, init_scale=0.)
        self.dense_t1 = dense(temb_dim, out_ch)
        self.act = act
        self.skip = conv2d(in_ch, out_ch, 1, padding=0, bias=False)

    def forward(self, input, temb):
        """Run the block on ``input`` with time embedding ``temb``."""
        hidden = self.conv1(self.act(input))
        # Two trailing singleton axes are appended to the projected embedding
        # so it broadcasts against the convolution output.
        hidden += self.dense_t1(temb)[..., None, None]
        hidden = self.act(hidden)

        shortcut = input
        if self.downsample:
            hidden = up_or_down_sampling.downsample_2d(hidden, self.fir_kernel, factor=2)
            shortcut = up_or_down_sampling.downsample_2d(shortcut, self.fir_kernel, factor=2)

        residual = self.conv2(hidden)
        projected = self.skip(shortcut)
        return (residual + projected) / np.sqrt(2)


class SmallEBM(nn.Module):
  """
  A time-dependent conditional energy-based model for small images (CIFAR10, StackMNIST).
  E(x_{t-1} | x_t, t): X x X --> R

  ``x`` and ``x_t`` are concatenated along dim 1 before the first convolution,
  so ``image_channels`` has to match the channel count of that concatenation
  (assuming ``conv2d`` takes the input channel count first -- TODO confirm).

  Args:
    image_channels: channel count handed to the first convolution.
    ndf: base feature width; layers use multiples of it (2x, 4x, 8x).
    reduced_T: number of timesteps covered by the time-embedding table.
    temb_dim: width of the time embedding.
    act: activation module shared by all layers.
  """
  def __init__(self, image_channels, ndf, reduced_T, temb_dim, act=nn.LeakyReLU(0.2)):
    super().__init__()
    self.act = act
    self.time_embedding = PositionalTimeEmbedding(T=reduced_T, act=act, embedding_dim=temb_dim, out_dim=temb_dim)

    # Encoding layers where the resolution decreases
    self.start_conv = conv2d(image_channels, ndf * 2, 1, padding=0)
    self.conv1 = DownConvBlock(act=act, in_ch=ndf * 2, out_ch=ndf * 2, temb_dim=temb_dim)
    self.conv2 = DownConvBlock(act=act, in_ch=ndf * 2, out_ch=ndf * 4, temb_dim=temb_dim, downsample=True)
    self.conv3 = DownConvBlock(act=act, in_ch=ndf * 4, out_ch=ndf * 8, temb_dim=temb_dim, downsample=True)
    self.conv4 = DownConvBlock(act=act, in_ch=ndf * 8, out_ch=ndf * 8, temb_dim=temb_dim, downsample=True)
    # The "+ 1" input channel is the minibatch-stddev map appended in forward().
    self.final_conv = conv2d(ndf * 8 + 1, ndf * 8, 3, padding=1, init_scale=0.)
    self.end_linear = dense(ndf * 8, 1)
    self.stddev_group = 4
    self.stddev_feat = 1

  def _append_minibatch_stddev(self, feat):
    """Return `feat` with one extra channel holding a per-group stddev statistic."""
    batch, channel, height, width = feat.shape
    # Start from the configured group size and shrink it until it divides the
    # batch; the plain min(batch, stddev_group) makes the view below raise for
    # batch sizes such as 5 or 6. Divisible batches keep the same group size.
    group = max(1, min(batch, self.stddev_group))
    while batch % group != 0:
      group -= 1
    stddev = feat.view(group, -1, self.stddev_feat, channel // self.stddev_feat, height, width)
    # Standard deviation across the members of each group (axis 0) ...
    stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8)
    # ... averaged over channels and spatial positions into a single value.
    stddev = stddev.mean([2, 3, 4], keepdim=True).squeeze(2)
    # Tile back to one (stddev_feat, height, width) map per sample.
    stddev = stddev.repeat(group, 1, height, width)
    return torch.cat([feat, stddev], 1)

  def forward(self, x, t, x_t):
    """Return one scalar per sample for the pair ``(x, x_t)`` at timestep ``t``.

    Args:
      x: image batch (presumably the x_{t-1} of the class docstring -- verify).
      t: 1-D tensor of timestep indices.
      x_t: image batch concatenated with ``x`` along dim 1.

    Returns:
      Tensor produced by the final ``dense(ndf * 8, 1)`` layer.
    """
    temb = self.act(self.time_embedding(t))

    h = self.start_conv(torch.cat((x, x_t), dim=1))
    for block in (self.conv1, self.conv2, self.conv3, self.conv4):
      h = block(h, temb)

    out = self._append_minibatch_stddev(h)
    out = self.act(self.final_conv(out))
    # Sum over all spatial positions, leaving (batch, channels).
    out = out.view(out.shape[0], out.shape[1], -1).sum(2)
    return self.end_linear(out)


class MiddleEBM(nn.Module):
  """
  A time-dependent conditional energy-based model for middle images (Celeba64, Celeba128, Imagenet64).
  E(x_{t-1} | x_t, t): X x X --> R

  ``x`` and ``x_t`` are concatenated along dim 1 before the first convolution,
  so ``image_channels`` has to match the channel count of that concatenation
  (assuming ``conv2d`` takes the input channel count first -- TODO confirm).

  Args:
    image_channels: channel count handed to the first convolution.
    ndf: base feature width; layers use multiples of it (2x, 4x, 8x).
    reduced_T: number of timesteps covered by the time-embedding table.
    temb_dim: width of the time embedding.
    img_size: 64 or 128; selects the block configuration.
    act: activation module shared by all layers.

  Raises:
    NotImplementedError: if ``img_size`` is neither 64 nor 128.
  """
  def __init__(self, image_channels, ndf, reduced_T, temb_dim, img_size, act=nn.LeakyReLU(0.2)):
    super().__init__()
    self.act = act
    self.time_embedding = PositionalTimeEmbedding(T=reduced_T, act=act, embedding_dim=temb_dim, out_dim=temb_dim)

    # Encoding layers where the resolution decreases
    self.start_conv = conv2d(image_channels, ndf * 2, 1, padding=0)
    if img_size == 64:
      # First block keeps the resolution; the remaining four downsample.
      self.conv1 = DownConvBlock(act=act, in_ch=ndf * 2, out_ch=ndf * 2, temb_dim=temb_dim)
      self.conv2 = DownConvBlock(act=act, in_ch=ndf * 2, out_ch=ndf * 4, temb_dim=temb_dim, downsample=True)
      self.conv3 = DownConvBlock(act=act, in_ch=ndf * 4, out_ch=ndf * 8, temb_dim=temb_dim, downsample=True)
      self.conv4 = DownConvBlock(act=act, in_ch=ndf * 8, out_ch=ndf * 8, temb_dim=temb_dim, downsample=True)
      self.conv5 = DownConvBlock(act=act, in_ch=ndf * 8, out_ch=ndf * 8, temb_dim=temb_dim, downsample=True)
    elif img_size == 128:
      # All five blocks downsample.
      self.conv1 = DownConvBlock(act=act, in_ch=ndf * 2, out_ch=ndf * 4, temb_dim=temb_dim, downsample=True)
      self.conv2 = DownConvBlock(act=act, in_ch=ndf * 4, out_ch=ndf * 8, temb_dim=temb_dim, downsample=True)
      self.conv3 = DownConvBlock(act=act, in_ch=ndf * 8, out_ch=ndf * 8, temb_dim=temb_dim, downsample=True)
      self.conv4 = DownConvBlock(act=act, in_ch=ndf * 8, out_ch=ndf * 8, temb_dim=temb_dim, downsample=True)
      self.conv5 = DownConvBlock(act=act, in_ch=ndf * 8, out_ch=ndf * 8, temb_dim=temb_dim, downsample=True)
    else:
      raise NotImplementedError(
        "MiddleEBM supports img_size 64 or 128, got {!r}".format(img_size))
    # The "+ 1" input channel is the minibatch-stddev map appended in forward().
    self.final_conv = conv2d(ndf * 8 + 1, ndf * 8, 3, padding=1, init_scale=0.)
    self.end_linear = dense(ndf * 8, 1)
    self.stddev_group = 4
    self.stddev_feat = 1

  def _append_minibatch_stddev(self, feat):
    """Return `feat` with one extra channel holding a per-group stddev statistic."""
    batch, channel, height, width = feat.shape
    # Start from the configured group size and shrink it until it divides the
    # batch; the plain min(batch, stddev_group) makes the view below raise for
    # batch sizes such as 5 or 6. Divisible batches keep the same group size.
    group = max(1, min(batch, self.stddev_group))
    while batch % group != 0:
      group -= 1
    stddev = feat.view(group, -1, self.stddev_feat, channel // self.stddev_feat, height, width)
    # Standard deviation across the members of each group (axis 0) ...
    stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8)
    # ... averaged over channels and spatial positions into a single value.
    stddev = stddev.mean([2, 3, 4], keepdim=True).squeeze(2)
    # Tile back to one (stddev_feat, height, width) map per sample.
    stddev = stddev.repeat(group, 1, height, width)
    return torch.cat([feat, stddev], 1)

  def forward(self, x, t, x_t):
    """Return one scalar per sample for the pair ``(x, x_t)`` at timestep ``t``.

    Args:
      x: image batch (presumably the x_{t-1} of the class docstring -- verify).
      t: 1-D tensor of timestep indices.
      x_t: image batch concatenated with ``x`` along dim 1.

    Returns:
      Tensor produced by the final ``dense(ndf * 8, 1)`` layer.
    """
    temb = self.act(self.time_embedding(t))

    h = self.start_conv(torch.cat((x, x_t), dim=1))
    for block in (self.conv1, self.conv2, self.conv3, self.conv4, self.conv5):
      h = block(h, temb)

    out = self._append_minibatch_stddev(h)
    out = self.act(self.final_conv(out))
    # Sum over all spatial positions, leaving (batch, channels).
    out = out.view(out.shape[0], out.shape[1], -1).sum(2)
    return self.end_linear(out)


class LargeEBM(nn.Module):
  """
  A time-dependent conditional energy-based model for large images (CelebA, LSUN).
  E(x_{t-1} | x_t, t): X x X --> R

  ``x`` and ``x_t`` are concatenated along dim 1 before the first convolution,
  so ``image_channels`` has to match the channel count of that concatenation
  (assuming ``conv2d`` takes the input channel count first -- TODO confirm).

  Args:
    image_channels: channel count handed to the first convolution.
    ndf: base feature width; layers use multiples of it (2x, 4x, 8x).
    reduced_T: number of timesteps covered by the time-embedding table.
    temb_dim: width of the time embedding.
    act: activation module shared by all layers.
  """
  def __init__(self, image_channels, ndf, reduced_T, temb_dim, act=nn.LeakyReLU(0.2)):
    super().__init__()
    self.act = act
    self.time_embedding = PositionalTimeEmbedding(T=reduced_T, act=act, embedding_dim=temb_dim, out_dim=temb_dim)

    # Encoding layers where the resolution decreases
    self.start_conv = conv2d(image_channels, ndf * 2, 1, padding=0)
    self.conv1 = DownConvBlock(act=act, in_ch=ndf * 2, out_ch=ndf * 4, temb_dim=temb_dim, downsample=True)
    self.conv2 = DownConvBlock(act=act, in_ch=ndf * 4, out_ch=ndf * 8, temb_dim=temb_dim, downsample=True)
    self.conv3 = DownConvBlock(act=act, in_ch=ndf * 8, out_ch=ndf * 8, temb_dim=temb_dim, downsample=True)
    self.conv4 = DownConvBlock(act=act, in_ch=ndf * 8, out_ch=ndf * 8, temb_dim=temb_dim, downsample=True)
    self.conv5 = DownConvBlock(act=act, in_ch=ndf * 8, out_ch=ndf * 8, temb_dim=temb_dim, downsample=True)
    self.conv6 = DownConvBlock(act=act, in_ch=ndf * 8, out_ch=ndf * 8, temb_dim=temb_dim, downsample=True)
    # The "+ 1" input channel is the minibatch-stddev map appended in forward().
    # NOTE(review): built without init_scale=0., unlike the final_conv of the
    # other EBM variants in this file -- confirm this is intentional.
    self.final_conv = conv2d(ndf * 8 + 1, ndf * 8, 3, padding=1)
    self.end_linear = dense(ndf * 8, 1)
    self.stddev_group = 4
    self.stddev_feat = 1

  def _append_minibatch_stddev(self, feat):
    """Return `feat` with one extra channel holding a per-group stddev statistic."""
    batch, channel, height, width = feat.shape
    # Start from the configured group size and shrink it until it divides the
    # batch; the plain min(batch, stddev_group) makes the view below raise for
    # batch sizes such as 5 or 6. Divisible batches keep the same group size.
    group = max(1, min(batch, self.stddev_group))
    while batch % group != 0:
      group -= 1
    stddev = feat.view(group, -1, self.stddev_feat, channel // self.stddev_feat, height, width)
    # Standard deviation across the members of each group (axis 0) ...
    stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8)
    # ... averaged over channels and spatial positions into a single value.
    stddev = stddev.mean([2, 3, 4], keepdim=True).squeeze(2)
    # Tile back to one (stddev_feat, height, width) map per sample.
    stddev = stddev.repeat(group, 1, height, width)
    return torch.cat([feat, stddev], 1)

  def forward(self, x, t, x_t):
    """Return one scalar per sample for the pair ``(x, x_t)`` at timestep ``t``.

    Args:
      x: image batch (presumably the x_{t-1} of the class docstring -- verify).
      t: 1-D tensor of timestep indices.
      x_t: image batch concatenated with ``x`` along dim 1.

    Returns:
      Tensor produced by the final ``dense(ndf * 8, 1)`` layer.
    """
    temb = self.act(self.time_embedding(t))

    h = self.start_conv(torch.cat((x, x_t), dim=1))
    for block in (self.conv1, self.conv2, self.conv3,
                  self.conv4, self.conv5, self.conv6):
      h = block(h, temb)

    out = self._append_minibatch_stddev(h)
    out = self.act(self.final_conv(out))
    # Sum over all spatial positions, leaving (batch, channels).
    out = out.view(out.shape[0], out.shape[1], -1).sum(2)
    return self.end_linear(out)

