import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from operator import itemgetter

# helper functions

def variance_scaling(scale, mode, distribution,
                     in_axis=1, out_axis=0,
                     dtype=torch.float32,
                     device='cpu'):
    """Build a variance-scaling weight initializer (ported from JAX).

    Args:
        scale: numerator of the target variance (variance = scale / fan).
        mode: which fan to divide by: "fan_in", "fan_out" or "fan_avg".
        distribution: "normal" or "uniform".
        in_axis: index of the input-feature axis in the weight shape.
        out_axis: index of the output-feature axis in the weight shape.
        dtype: default dtype of the tensors produced by the initializer.
        device: default device of the tensors produced by the initializer.

    Returns:
        A function ``init(shape, dtype=..., device=...)`` that returns a freshly
        sampled tensor of the requested shape.

    Raises:
        ValueError: raised by the returned ``init`` (not by this factory) when
            ``mode`` or ``distribution`` is not one of the supported values.
    """

    def _compute_fans(shape, in_axis=1, out_axis=0):
        # Every axis other than in/out counts towards the receptive field.
        n_in, n_out = shape[in_axis], shape[out_axis]
        receptive_field_size = np.prod(shape) / n_in / n_out
        return n_in * receptive_field_size, n_out * receptive_field_size

    def init(shape, dtype=dtype, device=device):
        fan_in, fan_out = _compute_fans(shape, in_axis, out_axis)

        if mode == "fan_avg":
            denominator = (fan_in + fan_out) / 2
        elif mode == "fan_out":
            denominator = fan_out
        elif mode == "fan_in":
            denominator = fan_in
        else:
            raise ValueError(
                "invalid mode for variance scaling initializer: {}".format(mode))

        variance = scale / denominator

        if distribution == "uniform":
            # U(-1, 1) has variance 1/3, hence the sqrt(3 * variance) factor.
            centered = torch.rand(*shape, dtype=dtype, device=device) * 2. - 1.
            return centered * np.sqrt(3 * variance)
        if distribution == "normal":
            return torch.randn(*shape, dtype=dtype, device=device) * np.sqrt(variance)
        raise ValueError("invalid distribution for variance scaling initializer")

    return init

def default_init(scale=1.):
    """Return the DDPM-style initializer: uniform variance scaling over fan_avg.

    A ``scale`` of exactly 0 is replaced by 1e-10 before being handed to
    ``variance_scaling``.
    """
    if scale == 0:
        scale = 1e-10
    return variance_scaling(scale, 'fan_avg', 'uniform')

class GaussianFourierProjection(nn.Module):
    """Gaussian Fourier embeddings for noise levels.

    Holds a fixed (non-trainable) vector ``W`` of ``embedding_size`` random
    frequencies drawn from N(0, scale^2). ``forward`` returns the sines followed
    by the cosines of ``2 * pi * x * W``, so the last axis has size
    ``2 * embedding_size``.
    """

    def __init__(self, embedding_size=256, scale=1.0):
        super().__init__()
        frequencies = torch.randn(embedding_size) * scale
        # Stored as a Parameter so it follows the module, but never updated.
        self.W = nn.Parameter(frequencies, requires_grad=False)

    def forward(self, x):
        # A new axis is inserted at position 1 of x and broadcast against W.
        phase = x.unsqueeze(1) * self.W.unsqueeze(0) * 2 * np.pi
        return torch.cat((torch.sin(phase), torch.cos(phase)), dim=-1)

class Linear(nn.Linear):
    """``nn.Linear`` with weights re-drawn from ``default_init`` and a zero bias.

    Args:
        dim_in: number of input features.
        dim_out: number of output features.
        init_scale: scale forwarded to ``default_init`` for the weight draw.
    """

    def __init__(self, dim_in, dim_out, init_scale=0.1):
        super().__init__(dim_in, dim_out)
        initializer = default_init(scale=init_scale)
        # Replace the weights nn.Linear sampled with the variance-scaled draw.
        self.weight.data = initializer(self.weight.shape)
        nn.init.zeros_(self.bias)
        self.dim_in = dim_in
        self.dim_out = dim_out

class ChannelLinear(Linear):
    """Linear layer applied along axis 1 (the channel axis) of a 4-D tensor.

    ``forward`` takes a tensor of shape (n, dim_in, a, b) and returns a tensor of
    shape (n, dim_out, a, b).
    """

    def __init__(self, dim_in, dim_out, init_scale=0.1):
        super().__init__(dim_in, dim_out, init_scale)

    def forward(self, x):
        assert len(x.shape) == 4
        dims = [int(size) for size in x.shape]
        assert dims[1] == self.dim_in, 'xs[1]: {}, self.dim_in: {}'.format(dims[1], self.dim_in)

        # A singleton axis is inserted after the channels and dropped again below.
        expanded = x.reshape(dims[0], dims[1], 1, dims[2], dims[3])
        mixed = torch.einsum('ij,njkml->nikml', self.weight, expanded)
        bias = self.bias.reshape(-1, self.dim_out, 1, 1, 1)
        y = mixed + bias

        return y[:, :, 0, :, :]

def exists(val):
    """Return True when ``val`` is anything other than ``None``."""
    if val is None:
        return False
    return True

def map_el_ind(arr, ind):
    """Return a list holding element ``ind`` of every item in ``arr``."""
    return [item[ind] for item in arr]

def sort_and_return_indices(arr):
    """Sort ``arr`` and report where each sorted element came from.

    Returns:
        A pair ``(sorted_values, source_indices)`` of lists, where
        ``source_indices[j]`` is the position in ``arr`` of ``sorted_values[j]``.
        Equal values are ordered by their original position.
    """
    # Pair each value with its position; tuples sort by value, then by position.
    tagged = sorted(zip(arr, range(len(arr))))
    sorted_values = [value for value, _ in tagged]
    source_indices = [position for _, position in tagged]
    return sorted_values, source_indices

# calculates, for each axial dimension, the permutation that moves that axis and
# the channel axis to the last two positions, making the tensor attend-able
# (the inverse permutation that restores the original layout is derived in PermuteToFrom)

# num_dimensions = number of dimensions that are neither the batch dim nor the channel dim
# channel_dim = index of the channel (embedding) dimension; negative values count from the end
# for example, if data shape = (n, c, h, w)
# num_dimensions = 2, channel_dim = 1
def calculate_permutations(num_dimensions, channel_dim):
    """Return one axis permutation per axial dimension.

    Each permutation keeps the remaining axes in ascending order and places
    ``[axial_dim, channel_dim]`` last. Axial dimensions are visited from the
    highest index down to the lowest.

    Args:
        num_dimensions: number of axes that are neither batch nor channel.
        channel_dim: index of the channel axis; negative values are wrapped
            modulo ``num_dimensions + 2``.

    Returns:
        A list of permutations, each a list of ``num_dimensions + 2`` axis indices.
    """
    total_dimensions = num_dimensions + 2
    channel_dim %= total_dimensions

    # Every non-batch, non-channel axis, highest index first.
    axial_dims = [ind for ind in range(total_dimensions - 1, 0, -1) if ind != channel_dim]

    permutations = []
    for axial_dim in axial_dims:
        last_two_dims = (axial_dim, channel_dim)
        # Built as an ordered list rather than a set difference, so the order of
        # the leading axes does not depend on set iteration order.
        dims_rest = [ind for ind in range(total_dimensions) if ind not in last_two_dims]
        permutations.append([*dims_rest, *last_two_dims])

    return permutations

# helper classes

class Sequential(nn.Module):
    """Apply a series of blocks in order, forwarding the same kwargs to each.

    Args:
        blocks: iterable of callables. A list or tuple consisting solely of
            ``nn.Module`` instances is wrapped in ``nn.ModuleList`` so the
            blocks' parameters are registered on this module; anything else
            (including an ``nn.ModuleList``) is stored unchanged.
    """

    def __init__(self, blocks):
        super().__init__()
        # Modules held in a plain Python list are not registered as submodules,
        # so they would be missing from parameters(), state_dict() and .to().
        if isinstance(blocks, (list, tuple)) and all(isinstance(b, nn.Module) for b in blocks):
            blocks = nn.ModuleList(blocks)
        self.blocks = blocks

    def forward(self, x, **kwargs):
        for block in self.blocks:
            x = block(x, **kwargs)
        return x

class PermuteToFrom(nn.Module):
    """Run ``fn`` on a permuted, flattened view of the input and undo the permutation.

    Args:
        permutation: axis order applied to the input before calling ``fn``.
        fn: callable invoked on a 3-D tensor; its output is reshaped back to the
            permuted input's shape, so it must have the same number of elements.
    """

    def __init__(self, permutation, fn):
        super().__init__()
        self.fn = fn
        self.permutation = permutation
        # The positions of the sorted permutation entries form its inverse.
        _, self.inv_permutation = sort_and_return_indices(permutation)

    def forward(self, x, **kwargs):
        permuted = x.permute(*self.permutation).contiguous()
        permuted_shape = permuted.shape
        *_, length, width = permuted_shape

        # Collapse every leading axis into one so fn sees a 3-D tensor.
        flat = permuted.reshape(-1, length, width)
        flat = self.fn(flat, **kwargs)

        # Undo the flattening, then the permutation.
        restored = flat.reshape(*permuted_shape)
        return restored.permute(*self.inv_permutation).contiguous()

# axial pos emb

class AxialPositionalEmbedding(nn.Module):
    """Additive sinusoidal embeddings, one table per axial dimension.

    For every axial dimension a tensor ``param_{i}`` is created with size
    ``dim`` on the channel axis, the axial length on that dimension's axis and
    1 everywhere else, so it broadcasts when added to the input.

    Args:
        dim: number of channels.
        shape: lengths of the axial dimensions (input shape without batch and
            channel axes).
        channel_dim: index of the channel axis in the input; negative values
            are wrapped modulo ``len(shape) + 2``.
        learnable: when True the tables are registered as ``nn.Parameter``;
            otherwise as buffers.
        const: base of the frequency term in ``get_encoding_matrix``.
    """

    def __init__(self, dim, shape, channel_dim=1, learnable=False, const=10000):
        super().__init__()
        self.const = const
        total_dimensions = len(shape) + 2
        # Wrap negative indices so the comparisons with positive axis indices
        # below behave the same as for the equivalent positive index.
        channel_dim %= total_dimensions
        ax_dim_indices = [i for i in range(1, total_dimensions) if i != channel_dim]

        self.num_axials = len(shape)

        for i, (axial_dim, axial_dim_index) in enumerate(zip(shape, ax_dim_indices)):
            param_shape = [1] * total_dimensions
            param_shape[channel_dim] = dim
            param_shape[axial_dim_index] = axial_dim

            encoding = self.get_encoding_matrix(dim, axial_dim)
            if channel_dim > axial_dim_index:
                # The table is laid out (channels, positions); when the channel
                # axis comes after the axial axis it has to be transposed, since
                # a bare reshape would scramble the entries.
                encoding = encoding.t()
            encoding = encoding.reshape(param_shape)

            if learnable:
                setattr(self, f'param_{i}', nn.Parameter(encoding))
            else:
                self.register_buffer(f'param_{i}', encoding)

    def get_encoding_matrix(self, dim, axial_dim):
        """Return a ``(dim, axial_dim)`` table of interleaved sines and cosines.

        With ``k`` the channel index and ``i`` the pair index, columns ``2i``
        and ``2i + 1`` of row ``k`` hold ``sin`` and ``cos`` of
        ``k / const ** (2 * i / dim)``. For an odd ``axial_dim`` the final
        cosine column is dropped.
        """
        n = axial_dim  # sequence length
        d = dim  # number of channels
        pairs = (n + 1) // 2  # number of (sin, cos) column pairs, rounded up

        i = torch.arange(pairs).reshape(1, pairs)
        k = torch.arange(d).reshape(d, 1)
        angles = k / (self.const ** (2 * i / d))  # broadcasts to (d, pairs)

        interleaved = torch.stack([torch.sin(angles), torch.cos(angles)], dim=2)
        mat = interleaved.reshape(d, 2 * pairs)
        # 2 * pairs == n + 1 when n is odd; trim the surplus column.
        return mat[:, :n].contiguous()

    def forward(self, x):
        for i in range(self.num_axials):
            x = x + getattr(self, f'param_{i}')
        return x

# attention

class Attentionpp(nn.Module):
    """Multi-head self-attention block with optional causal masking. Modified from DDPM.

    ``forward`` takes a 3-D tensor ``(b, t, nf)``, attends over axis 1 and
    returns a tensor of the same shape with a residual connection.

    Args:
        nf: feature width; must be divisible by ``heads``.
        heads: number of attention heads.
        causal: when True, position ``i`` only attends to positions ``j <= i``.
        skip_rescale: when True, the residual sum is divided by sqrt(2).
        init_scale: initialisation scale of the output projection ``NIN_3``.
    """

    def __init__(self, nf, heads, causal=True, skip_rescale=True, init_scale=0.):
        super().__init__()
        self.heads = heads
        self.causal = causal
        self.nf_heads = nf // heads
        # The features must split evenly across the heads.
        assert self.nf_heads * heads == nf

        self.LN = nn.LayerNorm(nf)
        self.NIN_0 = Linear(nf, nf)  # query projection
        self.NIN_1 = Linear(nf, nf)  # key projection
        self.NIN_2 = Linear(nf, nf)  # value projection
        self.NIN_3 = Linear(nf, nf, init_scale=init_scale)  # output projection
        self.skip_rescale = skip_rescale

    def maybe_apply_causal_mask(self, x):
        """Zero everything above the diagonal of the last two axes when causal."""
        if not self.causal:
            return x
        return torch.tril(x, diagonal=0)

    def causal_softmax(self, x, dim=-1, eps=1e-7):
        """Softmax over ``dim``; in causal mode restricted to the lower triangle.

        Masked logits are set to 0 (not -inf) before the softmax, so they still
        contribute to its denominator. The second mask removes their weights and
        the remaining ones are rescaled by their sum plus ``eps``.
        NOTE(review): when the unmasked logits are very negative the retained
        mass can be comparable to ``eps``, so rows may sum to less than 1.
        """
        if not self.causal:
            return x.softmax(dim=dim)

        weights = self.maybe_apply_causal_mask(x).softmax(dim=dim)
        weights = self.maybe_apply_causal_mask(weights)
        # renormalize
        return weights / (weights.sum(dim=dim).unsqueeze(dim) + eps)

    def forward(self, x):
        normed = self.LN(x)
        q = self.NIN_0(normed)
        k = self.NIN_1(normed)
        v = self.NIN_2(normed)

        b, t, d = q.shape
        m, e = self.heads, self.nf_heads

        def split_heads(z):
            # (b, t, m * e) -> (b * m, t, e): fold the heads into the batch axis.
            return z.reshape(b, t, m, e).transpose(1, 2).reshape(b * m, t, e)

        q, k, v = split_heads(q), split_heads(k), split_heads(v)

        scores = torch.einsum('bie,bje->bij', q, k) * (e ** -0.5)
        attn = self.causal_softmax(scores)
        out = torch.einsum('bij,bje->bie', attn, v)

        # (b * m, t, e) -> (b, t, m * e): unfold the heads back into the features.
        out = out.reshape(b, m, t, e).transpose(1, 2).reshape(b, t, d)
        out = self.NIN_3(out)

        residual = x + out
        if self.skip_rescale:
            return residual / np.sqrt(2.)
        return residual

