# coding=utf-8
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# pylint: skip-file

from torch.nn.functional import embedding
import torch.nn as nn
import torch
import numpy as np
import math

# Name -> activation-module lookup. Each entry is a single module instance
# created once at import time and shared by every caller.
# "swish" is nn.SiLU (x * sigmoid(x)); "lrelu" uses a 0.2 negative slope.
# NOTE(review): "square" and "identity" entries were commented out in the
# original; they relied on a `Lambda` wrapper that is not defined here.
NONLINEARITIES = dict(
    elu=nn.ELU(),
    relu=nn.ReLU(),
    lrelu=nn.LeakyReLU(negative_slope=0.2),
    swish=nn.SiLU(),
    tanh=nn.Tanh(),
    softplus=nn.Softplus(),
)

class IgnoreLinear(nn.Module):
    """Linear layer followed by BatchNorm1d; the time input is ignored.

    forward(t, x) = BN(Linear(x)). `t` is accepted only so this layer shares
    the (t, x) call signature of the other time-conditioned layers.
    """

    def __init__(self, dim_in, dim_out):
        super().__init__()
        self._layer = nn.Linear(dim_in, dim_out)
        self.bn = nn.BatchNorm1d(dim_out)

    def forward(self, t, x):
        del t  # deliberately unused
        projected = self._layer(x)
        return self.bn(projected)


class BlendLinear(nn.Module):
    """Interpolate between two layers by t, then apply BatchNorm1d.

    forward(t, x) = BN(L0(x) + t * (L1(x) - L0(x))), where t is broadcast
    across the feature axis via t[:, None].

    Args (constructor):
      dim_in, dim_out: feature sizes passed to `layer_type`.
      layer_type: factory called as layer_type(dim_in, dim_out) for both
        endpoint layers (default nn.Linear).
      **unused_kwargs: accepted and ignored.
    """

    def __init__(self, dim_in, dim_out, layer_type=nn.Linear, **unused_kwargs):
        super().__init__()
        self._layer0 = layer_type(dim_in, dim_out)
        self._layer1 = layer_type(dim_in, dim_out)
        self.bn = nn.BatchNorm1d(dim_out)

    def forward(self, t, x):
        start = self._layer0(x)
        end = self._layer1(x)
        weight = t[:, None]  # one blend weight per row
        return self.bn(start + weight * (end - start))


class ConcatLinear(nn.Module):
    """Linear layer applied to x with t prepended as an extra input column.

    forward(t, x) = Linear([t, x]), where t is reshaped to a column with
    t[:, None] and cast to x's dtype/device through a ones_like multiply.

    NOTE(review): `bn` is constructed but not applied in forward(). It is
    kept so the module's parameters and state_dict keys stay unchanged.
    """

    def __init__(self, dim_in, dim_out):
        super().__init__()
        self._layer = nn.Linear(dim_in + 1, dim_out)
        self.bn = nn.BatchNorm1d(dim_out)

    def forward(self, t, x):
        time_column = t[:, None] * torch.ones_like(x[:, :1])
        augmented = torch.cat((time_column, x), dim=1)
        return self._layer(augmented)


class ConcatLinear_v2(nn.Module):
    """Linear layer on x plus a learned, bias-free linear function of t.

    forward(t, x) = Linear(x) + W_t @ t_col, where t_col = t.view(-1, 1),
    so t is treated as one scalar per row (or one scalar broadcast to all rows).
    """

    def __init__(self, dim_in, dim_out):
        # super() must name this class: ConcatLinear is not a base class, and
        # super(ConcatLinear, self) raised TypeError on every construction.
        super(ConcatLinear_v2, self).__init__()
        self._layer = nn.Linear(dim_in, dim_out)
        self._hyper_bias = nn.Linear(1, dim_out, bias=False)

    def forward(self, t, x):
        return self._layer(x) + self._hyper_bias(t.view(-1, 1))


class SquashLinear(nn.Module):
    """Linear layer on x, gated elementwise by sigmoid(Linear(t)).

    forward(t, x) = Linear(x) * sigmoid(H(t_col)), with t_col = t.view(-1, 1).
    """

    def __init__(self, dim_in, dim_out):
        super().__init__()
        self._layer = nn.Linear(dim_in, dim_out)
        self._hyper = nn.Linear(1, dim_out)

    def forward(self, t, x):
        features = self._layer(x)
        gate = torch.sigmoid(self._hyper(t.view(-1, 1)))
        return features * gate


class ConcatSquashLinear(nn.Module):
    """Time-gated linear layer with an additional time-dependent shift.

    forward(t, x) = Linear(x) * sigmoid(G(t_col)) + B(t_col), where
    t_col = t.view(-1, 1), G is a biased Linear(1, dim_out) and B is a
    bias-free Linear(1, dim_out).
    """

    def __init__(self, dim_in, dim_out):
        super().__init__()
        self._layer = nn.Linear(dim_in, dim_out)
        self._hyper_bias = nn.Linear(1, dim_out, bias=False)
        self._hyper_gate = nn.Linear(1, dim_out)

    def forward(self, t, x):
        t_col = t.view(-1, 1)
        features = self._layer(x)
        gate = torch.sigmoid(self._hyper_gate(t_col))
        shift = self._hyper_bias(t_col)
        return features * gate + shift

def variance_scaling(scale, mode, distribution,
                     in_axis=1, out_axis=0,
                     dtype=torch.float32,
                     device='cpu'):
  """Build a variance-scaling initializer (ported from JAX).

  The returned callable samples a tensor whose target variance is
  ``scale / denominator``. The denominator is fan-in, fan-out or their mean,
  depending on ``mode``.

  Args:
    scale: numerator of the target variance.
    mode: one of "fan_in", "fan_out", "fan_avg".
    distribution: "normal" (scaled standard normal) or "uniform"
      (symmetric uniform with half-width sqrt(3 * variance)).
    in_axis: index of the shape dimension counted as the input fan.
    out_axis: index of the shape dimension counted as the output fan.
    dtype: default dtype of generated tensors.
    device: default device of generated tensors.

  Returns:
    ``init(shape, dtype=dtype, device=device)``, which returns a tensor of
    ``shape``. It raises ValueError for an unknown mode or distribution; the
    check happens when ``init`` is called, not here.
  """

  def _fans(shape):
    # Every dimension other than in_axis/out_axis forms the receptive field.
    receptive = np.prod(shape) / shape[in_axis] / shape[out_axis]
    return shape[in_axis] * receptive, shape[out_axis] * receptive

  def init(shape, dtype=dtype, device=device):
    fan_in, fan_out = _fans(shape)
    if mode == "fan_avg":
      denominator = (fan_in + fan_out) / 2
    elif mode in ("fan_in", "fan_out"):
      denominator = fan_in if mode == "fan_in" else fan_out
    else:
      raise ValueError(
        "invalid mode for variance scaling initializer: {}".format(mode))
    variance = scale / denominator

    if distribution == "normal":
      samples = torch.randn(*shape, dtype=dtype, device=device)
      return samples * np.sqrt(variance)
    if distribution == "uniform":
      # Map U[0, 1) to U[-1, 1), then stretch to unit-variance-times-`variance`.
      samples = torch.rand(*shape, dtype=dtype, device=device) * 2. - 1.
      return samples * np.sqrt(3 * variance)
    raise ValueError("invalid distribution for variance scaling initializer")

  return init

def default_init(scale=1.):
  """Return the DDPM initializer: fan-averaged, uniform variance scaling.

  A ``scale`` of exactly 0 is replaced by 1e-10 before it is passed on.
  """
  effective_scale = scale if scale != 0 else 1e-10
  return variance_scaling(effective_scale, 'fan_avg', 'uniform')


def get_timestep_embedding(timesteps, embedding_dim, max_positions=10000):
  """Sinusoidal timestep embeddings, as used in DDPM / Transformers.

  Args:
    timesteps: 1-D tensor with one timestep per batch element.
    embedding_dim: width of the returned embedding. When odd, a zero column
      is appended. Must be at least 4, because ``half_dim - 1`` is used as a
      divisor.
    max_positions: base of the geometric frequency progression. The
      frequencies run from 1 down to 1 / max_positions.

  Returns:
    Float tensor of shape (len(timesteps), embedding_dim). The first
    ``embedding_dim // 2`` columns are sines, the next ``embedding_dim // 2``
    are cosines, and there is an optional trailing zero column.
  """
  assert len(timesteps.shape) == 1
  half_dim = embedding_dim // 2
  # magic number 10000 is from transformers
  emb = math.log(max_positions) / (half_dim - 1)
  emb = torch.exp(torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) * -emb)
  emb = timesteps.float()[:, None] * emb[None, :]
  emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
  if embedding_dim % 2 == 1:  # zero pad
    # Fully-qualified call: this module has no `F` alias for
    # torch.nn.functional, so `F.pad` raised NameError for odd widths.
    emb = torch.nn.functional.pad(emb, (0, 1), mode='constant')
  assert emb.shape == (timesteps.shape[0], embedding_dim)
  return emb

class GaussianFourierProjection(nn.Module):
  """Random Fourier features of a scalar noise level.

  A fixed (non-trainable) frequency vector ``W ~ N(0, scale^2)`` of length
  ``embedding_size`` is drawn at construction. ``forward(x)`` returns
  ``[sin(2*pi*x*W), cos(2*pi*x*W)]`` concatenated along the last axis.
  """

  def __init__(self, embedding_size=256, scale=1.0):
    super().__init__()
    frequencies = torch.randn(embedding_size) * scale
    self.W = nn.Parameter(frequencies, requires_grad=False)

  def forward(self, x):
    angles = x.unsqueeze(1) * self.W.unsqueeze(0) * 2 * np.pi
    return torch.cat((angles.sin(), angles.cos()), dim=-1)

def get_sigmas(config):
  """Return a geometric sequence of noise levels from sigma_max to sigma_min.

  Reads ``config.model.sigma_max``, ``config.model.sigma_min`` and
  ``config.model.num_scales``. The result is a NumPy array of length
  ``num_scales``, evenly spaced in log-space and ordered from largest to
  smallest.
  """
  model_cfg = config.model
  log_sigmas = np.linspace(np.log(model_cfg.sigma_max),
                           np.log(model_cfg.sigma_min),
                           model_cfg.num_scales)
  return np.exp(log_sigmas)

class Time_Score(nn.Module):
  """Time-conditioned MLP with dense (concatenating) skip connections.

  For each width in ``hidden_dims``, a SquashLinear layer maps the running
  feature vector to that width. Its output is concatenated in front of the
  running vector and a ReLU is applied. A final ``nn.Linear`` maps the
  accumulated features to 24 outputs.

  NOTE(review): ``act``, ``sigmas``, ``nf``, ``conditional``,
  ``embedding_type`` and ``dataset`` are stored but not read by forward().
  """

  def __init__(self, config):
    super().__init__()
    self.config = config
    self.act = NONLINEARITIES[config.model.nonlinearity.lower()]
    self.register_buffer('sigmas', torch.tensor(get_sigmas(config)))
    self.hidden_dims = [256, 512, 1024, 1024, 512, 256]

    self.nf = 24
    self.conditional = config.model.conditional  # noise-conditional
    self.embedding_type = config.model.embedding_type.lower()
    self.dataset = "default"  # config.data.dataset is not consulted

    relu = NONLINEARITIES["relu"]  # one shared, parameter-free instance
    in_features = 24  # width of x0
    stack = []
    for width in self.hidden_dims:
      stack.append(SquashLinear(in_features, width))
      stack.append(relu)
      # The next stage sees this stage's output concatenated with its input.
      in_features += width
    stack.append(nn.Linear(in_features, 24))
    self.all_modules = nn.ModuleList(stack)

  def forward(self, x0, time_cond):
    """Score for ``x0`` at times ``time_cond``.

    Args:
      x0: input features; the first layer expects 24 columns.
      time_cond: per-sample times, passed as ``t`` to SquashLinear (which
        reshapes it with ``t.view(-1, 1)``).

    Returns:
      Tensor with 24 output columns. No sigma rescaling is applied.
    """
    stack = self.all_modules
    n_stages = len(self.hidden_dims)
    features = x0
    for stage in range(n_stages):
      layer = stack[2 * stage]
      activation = stack[2 * stage + 1]
      fresh = layer(t=time_cond, x=features)
      # Dense skip: newest features first, then everything seen so far.
      features = activation(torch.cat([fresh, features], dim=1))
    return stack[2 * n_stages](features)


class conditional_Time_Score(nn.Module):
  """Conditional time-dependent MLP score network with dense skip connections.

  The inputs ``h`` and ``x0`` are concatenated into 48 features. For each
  width in ``hidden_dims``, a ConcatSquashLinear layer maps the running
  feature vector to that width. Its output is concatenated in front of the
  running vector and a ReLU is applied. A final ``nn.Linear`` maps the
  accumulated features to 24 outputs.

  NOTE(review): ``act``, ``sigmas``, ``nf``, ``conditional``,
  ``embedding_type`` and ``dataset`` are stored but not read by forward().
  """

  def __init__(self, config):
    super().__init__()
    self.config = config
    self.act = NONLINEARITIES[config.model.nonlinearity.lower()]
    self.register_buffer('sigmas', torch.tensor(get_sigmas(config)))
    self.hidden_dims = [256, 512, 1024, 2048, 2048, 1024, 512, 256]

    self.nf = 24
    self.conditional = config.model.conditional  # noise-conditional
    self.embedding_type = config.model.embedding_type.lower()
    self.dataset = "default"  # config.data.dataset is not consulted

    relu = NONLINEARITIES["relu"]  # one shared, parameter-free instance
    in_features = 48  # combined width of torch.cat([h, x0], dim=1)
    stack = []
    for width in self.hidden_dims:
      stack.append(ConcatSquashLinear(in_features, width))
      stack.append(relu)
      # The next stage sees this stage's output concatenated with its input.
      in_features += width
    stack.append(nn.Linear(in_features, 24))
    self.all_modules = nn.ModuleList(stack)

  def forward(self, h, x0, time_cond):
    """Score for ``x0`` conditioned on ``h`` at times ``time_cond``.

    Args:
      h: conditioning features, concatenated before ``x0`` along dim 1.
      x0: input features. Together with ``h`` they must total 48 columns
        (presumably 24 each -- TODO confirm against callers).
      time_cond: per-sample times, passed as ``t`` to ConcatSquashLinear
        (which reshapes it with ``t.view(-1, 1)``).

    Returns:
      Tensor with 24 output columns. No sigma rescaling is applied.
    """
    stack = self.all_modules
    n_stages = len(self.hidden_dims)
    features = torch.cat([h, x0], dim=1)
    for stage in range(n_stages):
      layer = stack[2 * stage]
      activation = stack[2 * stage + 1]
      fresh = layer(t=time_cond, x=features)
      # Dense skip: newest features first, then everything seen so far.
      features = activation(torch.cat([fresh, features], dim=1))
    return stack[2 * n_stages](features)
