"""
Adapted from https://github.com/cfifty/rotation_trick/blob/main/src/modules/blocks.py
which originates in taming transformers: https://github.com/CompVis/taming-transformers
"""

import torch
import torch.nn as nn



def nonlinearity(x):
  """Apply the swish activation, ``x * sigmoid(x)``, element-wise."""
  gate = torch.sigmoid(x)
  return x * gate


def Normalize(in_channels, num_groups=32):
  """Build the GroupNorm layer used as the normalization in these blocks.

  Args:
    in_channels: number of channels of the tensor to be normalized.
    num_groups: number of groups passed to GroupNorm. Defaults to 32, the
      value that was previously hard-coded. torch's GroupNorm requires
      ``in_channels`` to be divisible by ``num_groups``, so narrower layers
      need a smaller value here.

  Returns:
    A ``torch.nn.GroupNorm`` with ``eps=1e-6`` and learnable affine parameters.
  """
  return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)

class DepthwiseSeparableConv(nn.Module):
  """Depthwise convolution followed by a 1x1 pointwise convolution.

  The depthwise stage keeps the channel count and uses ``groups=in_channels``;
  it carries the requested ``kernel_size``, ``stride`` and ``padding``. The
  pointwise stage is a 1x1 convolution mapping ``in_channels`` to
  ``out_channels``. ``bias`` is forwarded to both stages.
  """

  def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, bias=True):
    super().__init__()
    # Per-channel spatial filtering: one group per input channel.
    self.depthwise = nn.Conv2d(
      in_channels,
      in_channels,
      kernel_size,
      stride=stride,
      padding=padding,
      groups=in_channels,
      bias=bias,
    )
    # Channel mixing: a plain 1x1 convolution.
    self.pointwise = nn.Conv2d(
      in_channels,
      out_channels,
      1,
      stride=1,
      padding=0,
      groups=1,
      bias=bias,
    )

  def forward(self, x):
    """Run the depthwise stage, then the pointwise stage."""
    return self.pointwise(self.depthwise(x))

class Upsample(nn.Module):
  """Nearest-neighbour 2x upsampling, optionally followed by a 3x3 convolution.

  Args:
    in_channels: channel count of the input; the optional convolution keeps it.
    with_conv: when truthy, a 3x3 / stride 1 / padding 1 convolution is applied
      after the interpolation.
    depthwise: when truthy, that convolution is a DepthwiseSeparableConv
      instead of a regular ``torch.nn.Conv2d``.
  """

  def __init__(self, in_channels, with_conv, depthwise=False):
    super().__init__()
    self.with_conv = with_conv
    if not with_conv:
      # No learnable layer in this configuration.
      return
    conv_cls = DepthwiseSeparableConv if depthwise else torch.nn.Conv2d
    self.conv = conv_cls(in_channels, in_channels, kernel_size=3, stride=1, padding=1)

  def forward(self, x):
    """Upsample ``x`` by a factor of 2 and apply the convolution if enabled."""
    upsampled = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
    return self.conv(upsampled) if self.with_conv else upsampled


class Downsample(nn.Module):
  """Halve the spatial resolution with a strided convolution or average pooling.

  Args:
    in_channels: channel count of the input; the optional convolution keeps it.
    with_conv: when truthy, downsample with a 3x3 / stride 2 convolution;
      otherwise use 2x2 average pooling with stride 2.
    depthwise: when truthy, the convolution is a DepthwiseSeparableConv
      instead of a regular ``torch.nn.Conv2d``.
  """

  def __init__(self, in_channels, with_conv, depthwise=False):
    super().__init__()
    self.with_conv = with_conv
    if not with_conv:
      # Average pooling is parameter-free; nothing to build.
      return
    conv_cls = DepthwiseSeparableConv if depthwise else torch.nn.Conv2d
    # The conv itself is unpadded: torch convs cannot pad asymmetrically,
    # so the padding is applied by hand in forward().
    self.conv = conv_cls(in_channels, in_channels, kernel_size=3, stride=2, padding=0)

  def forward(self, x):
    """Downsample ``x`` by the configured method."""
    if not self.with_conv:
      return torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
    # Zero-pad one element at the end of each of the last two dimensions.
    padded = torch.nn.functional.pad(x, (0, 1, 0, 1), mode="constant", value=0)
    return self.conv(padded)


class ResnetBlock(nn.Module):
  """Residual block: two (norm -> swish -> 3x3 conv) stages plus a skip path.

  Args (keyword-only):
    in_channels: channel count of the input.
    out_channels: channel count of the output; defaults to ``in_channels``.
    conv_shortcut: when the channel counts differ, selects a 3x3 convolution
      for the skip path instead of a 1x1 convolution.
    dropout: probability handed to ``torch.nn.Dropout`` before the second conv.
    temb_channels: input width of the embedding projection; the projection is
      only created when this is greater than 0.
    depthwise: when truthy, the 3x3 convolutions are DepthwiseSeparableConv
      instead of ``torch.nn.Conv2d`` (the 1x1 skip conv is always Conv2d).
  """

  def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
               dropout, temb_channels=512, depthwise=False):
    super().__init__()
    if out_channels is None:
      out_channels = in_channels
    self.in_channels = in_channels
    self.out_channels = out_channels
    self.use_conv_shortcut = conv_shortcut

    conv_cls = DepthwiseSeparableConv if depthwise else torch.nn.Conv2d
    same_size_3x3 = dict(kernel_size=3, stride=1, padding=1)

    # First stage.
    self.norm1 = Normalize(in_channels)
    self.conv1 = conv_cls(in_channels, out_channels, **same_size_3x3)

    # Optional projection of the embedding onto the feature channels.
    if temb_channels > 0:
      self.temb_proj = torch.nn.Linear(temb_channels, out_channels)

    # Second stage.
    self.norm2 = Normalize(out_channels)
    self.dropout = torch.nn.Dropout(dropout)
    self.conv2 = conv_cls(out_channels, out_channels, **same_size_3x3)

    # The skip path only needs a layer when the channel count changes.
    if in_channels == out_channels:
      return
    if conv_shortcut:
      self.conv_shortcut = conv_cls(in_channels, out_channels, **same_size_3x3)
    else:
      self.nin_shortcut = torch.nn.Conv2d(
        in_channels, out_channels, kernel_size=1, stride=1, padding=0)

  def forward(self, x, temb):
    """Return ``skip(x) + residual(x, temb)``.

    ``temb`` may be None, in which case the embedding projection is skipped.
    Otherwise it is passed through swish and ``temb_proj`` and added to the
    features with two trailing singleton axes appended (presumably temb is
    (batch, temb_channels) -- confirm against the caller).
    """
    h = self.conv1(nonlinearity(self.norm1(x)))

    if temb is not None:
      temb_bias = self.temb_proj(nonlinearity(temb))
      h = h + temb_bias[:, :, None, None]

    h = self.conv2(self.dropout(nonlinearity(self.norm2(h))))

    if self.in_channels == self.out_channels:
      return x + h
    shortcut = self.conv_shortcut if self.use_conv_shortcut else self.nin_shortcut
    return shortcut(x) + h


class AttnBlock(nn.Module):
  """Single-head self-attention over the h*w positions of a feature map.

  The input is normalized, projected to queries, keys and values with 1x1
  convolutions, attended with scaled dot-product attention, projected back
  with another 1x1 convolution and added to the input.
  """

  def __init__(self, in_channels):
    super().__init__()
    self.in_channels = in_channels
    self.norm = Normalize(in_channels)

    def conv1x1():
      # Channel-preserving 1x1 convolution shared by all four projections.
      return torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)

    self.q = conv1x1()
    self.k = conv1x1()
    self.v = conv1x1()
    self.proj_out = conv1x1()

  def forward(self, x):
    """Return ``x`` plus the attention output; the result has the shape of ``x``."""
    normed = self.norm(x)
    query, key, value = self.q(normed), self.k(normed), self.v(normed)

    b, c, h, w = query.shape
    n_pos = h * w
    query = query.reshape(b, c, n_pos).permute(0, 2, 1)  # b, hw, c
    key = key.reshape(b, c, n_pos)  # b, c, hw
    value = value.reshape(b, c, n_pos)  # b, c, hw

    # scores[b, i, j] = sum_c query[b, i, c] * key[b, c, j], scaled by c^-0.5
    scores = torch.bmm(query, key) * (int(c) ** (-0.5))
    # Normalize over the key positions (last axis).
    attn = torch.nn.functional.softmax(scores, dim=2)

    # Transpose to (b, hw_key, hw_query) so that
    # out[b, c, j] = sum_i value[b, c, i] * attn_t[b, i, j]
    attn_t = attn.permute(0, 2, 1)
    out = torch.bmm(value, attn_t)  # b, c, hw
    out = self.proj_out(out.reshape(b, c, h, w))

    return x + out