from inspect import isfunction
import math
import torch
import torch.nn.functional as F
from torch import nn, einsum
from einops import rearrange, repeat
import numpy as np

from ldm.modules.diffusionmodules.util import checkpoint, MPCGroupNorm, MPCGeLU, MPCSoftmax, MPCInvsqrt
from ldm.modules.diffusionmodules.basic_util import he_erf

from torch import Tensor, Size
from typing import Union, List, Optional, Tuple
from torch.nn import init
import numbers

from typing import Any

def exists(val):
    """Return True when ``val`` is anything other than ``None``."""
    is_missing = val is None
    return not is_missing


def uniq(arr):
    """Return a dict-keys view of the distinct elements of ``arr``.

    Elements keep the order in which they first appear; they must be hashable.
    """
    return dict.fromkeys(arr, True).keys()


def default(val, d):
    """Return ``val`` unless it is ``None``; otherwise fall back to ``d``.

    When ``d`` is a plain Python function (per ``inspect.isfunction``) it is
    called with no arguments and its result is returned; any other object is
    returned as-is.
    """
    if not exists(val):
        return d() if isfunction(d) else d
    return val


def max_neg_value(t):
    """Return the negated largest finite value of ``t``'s floating dtype."""
    limits = torch.finfo(t.dtype)
    return -limits.max


def init_(tensor):
    """Fill ``tensor`` in place with uniform noise and return it.

    The sampling bounds are ``-1/sqrt(n)`` and ``1/sqrt(n)``, where ``n`` is
    the size of the tensor's last dimension.
    """
    bound = 1 / math.sqrt(tensor.shape[-1])
    # uniform_ mutates the tensor and returns that same tensor object.
    return tensor.uniform_(-bound, bound)


# feedforward
class GEGLU(nn.Module):
    """Gated GELU unit: one linear projection split into a value and a gate.

    The projection produces ``2 * dim_out`` features along the last axis; the
    first half is the value and the second half is passed through GELU and
    used as a multiplicative gate.
    """

    def __init__(self, dim_in, dim_out):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out * 2)
        # When True, forward() uses the project's MPCGeLU module instead of F.gelu.
        self.mpc_mode = True
        self.gelu = MPCGeLU()

    def forward(self, x):
        # Move the feature axis to the front so the halves can be taken by
        # slicing dimension 0, then move it back to the last position.
        projected = self.proj(x).transpose(0, -1)
        half = projected.shape[0] // 2
        value = projected[:half].transpose(0, -1)
        gate = projected[half:].transpose(0, -1)

        activation = self.gelu if self.mpc_mode else F.gelu
        return value * activation(gate)


class MPCFeedForward(nn.Module):
    """Two-layer feed-forward network.

    Args:
        dim: input feature size.
        dim_out: output feature size; falls back to ``dim`` when ``None``.
        mult: the hidden width is ``int(dim * mult)``.
        glu: when True the input projection is a ``GEGLU``; otherwise it is a
            linear layer followed by the project's ``MPCGeLU``.
        dropout: accepted but not used by this implementation (an
            ``nn.Identity`` sits between the two projections).
    """

    def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
        super().__init__()
        hidden = int(dim * mult)
        dim_out = default(dim_out, dim)

        if glu:
            stem = GEGLU(dim, hidden)
        else:
            stem = nn.Sequential(nn.Linear(dim, hidden), MPCGeLU())

        self.net = nn.Sequential(stem, nn.Identity(), nn.Linear(hidden, dim_out))
        self.mpc_mode = True

    def forward(self, x):
        """Apply the stacked projections to ``x``."""
        return self.net(x)


def zero_module(module):
    """
    Set every parameter of ``module`` to zero in place and return the module.
    """
    for param in module.parameters():
        # Work on a detached alias so the in-place write bypasses autograd
        # while still modifying the parameter's own storage.
        detached = param.detach()
        detached.zero_()
    return module


def Normalize(in_channels):
    """Build the project's ``MPCGroupNorm`` for ``in_channels`` channels.

    The layer is configured with 32 groups, ``eps=1e-6`` and affine
    parameters enabled.
    """
    return MPCGroupNorm(
        num_channels=in_channels,
        num_groups=32,
        affine=True,
        eps=1e-6,
    )


class MPCLinearAttention(nn.Module):
    """Linear attention over the spatial positions of a 4-D feature map.

    Queries, keys and values come from one 1x1 convolution. Keys are
    normalised with a softmax over the flattened spatial axis, a per-head
    ``(c, c)`` context matrix is built from keys and values, and that context
    is applied to the queries.

    Args:
        dim: number of input (and output) channels.
        heads: number of attention heads.
        dim_head: channels per head; the hidden width is ``heads * dim_head``.
    """

    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.heads = heads
        hidden_dim = dim_head * heads
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
        self.to_out = nn.Conv2d(hidden_dim, dim, 1)
        self.softmax = MPCSoftmax(dim=-1)

    def forward(self, x):
        """Apply linear attention to ``x`` of shape ``(b, c, h, w)``.

        Returns a tensor with the same shape as ``x``.
        """
        _, _, h, w = x.shape
        qkv = self.to_qkv(x)

        # Equivalent of:
        #   rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)')
        batch, channels, rows, cols = qkv.shape
        per_head = channels // (3 * self.heads)
        qkv = qkv.reshape(batch, 3, self.heads, per_head, rows * cols)  # b qkv heads c (h w)
        # Bring the q/k/v axis to the front. Swapping axes 1 and 2 instead
        # would leave the batch axis first, so the indexing below would slice
        # samples rather than q, k and v.
        qkv = qkv.transpose(0, 1)  # qkv b heads c (h w)
        q, k, v = qkv[0], qkv[1], qkv[2]

        k = self.softmax(k)

        # context[b, h, d, e] = sum_n k[b, h, d, n] * v[b, h, e, n]
        context = k.matmul(v.transpose(2, 3))
        # out[b, h, e, n] = sum_d context[b, h, d, e] * q[b, h, d, n]
        out = context.transpose(2, 3).matmul(q)

        # Merge heads back into the channel axis and restore the spatial grid.
        out = out.reshape(out.shape[0], out.shape[1] * out.shape[2], h, w)
        return self.to_out(out)


class MPCCrossAttention(nn.Module):
    """Multi-head scaled dot-product attention between ``x`` and a context.

    When no context is supplied the layer attends over ``x`` itself.

    Args:
        query_dim: feature size of the query input (and of the output).
        context_dim: feature size of the context; falls back to ``query_dim``.
        heads: number of attention heads.
        dim_head: feature size per head; logits are scaled by
            ``dim_head ** -0.5``.
        dropout: accepted but not used by this implementation (the output
            projection is followed by ``nn.Identity``).
    """

    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
        super().__init__()
        # When True, forward() uses reshape/matmul and the project's
        # MPCSoftmax; otherwise it uses einops/einsum and Tensor.softmax.
        self.mpc_mode = True
        inner_dim = dim_head * heads
        context_dim = default(context_dim, query_dim)

        self.scale = dim_head ** -0.5
        self.heads = heads

        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, query_dim),
            nn.Identity(),
        )

        self.softmax = MPCSoftmax(dim=-1)

    def _split_heads(self, t):
        """Reshape ``(b, n, h*d)`` into ``(b*h, n, d)`` with batch outermost."""
        h = self.heads
        t = t.reshape(t.shape[0], t.shape[1], h, t.shape[2] // h).transpose(1, 2)
        return t.reshape(t.shape[0] * t.shape[1], t.shape[2], t.shape[3])

    def _merge_heads(self, t):
        """Reshape ``(b*h, n, d)`` back into ``(b, n, h*d)``."""
        h = self.heads
        t = t.reshape(t.shape[0] // h, h, t.shape[1], t.shape[2]).transpose(1, 2)
        return t.reshape(t.shape[0], t.shape[1], t.shape[2] * t.shape[3])

    def forward(self, x, context=None, mask=None):
        """Attend from ``x`` to ``context``.

        Args:
            x: query input of shape ``(b, n, query_dim)``.
            context: key/value input of shape ``(b, m, context_dim)``;
                ``x`` is used when this is ``None``.
            mask: optional tensor whose first axis is the batch; it is
                flattened to ``(b, m)`` and inverted with ``~``, so it is
                expected to be boolean. Positions where it is False are
                filled with the most negative finite value before softmax.

        Returns:
            Tensor of shape ``(b, n, query_dim)``.
        """
        h = self.heads

        q = self.to_q(x)
        if context is None:
            context = x
        k = self.to_k(context)
        v = self.to_v(context)

        if self.mpc_mode:
            q = self._split_heads(q)
            k = self._split_heads(k)
            v = self._split_heads(v)
            sim = torch.matmul(q, k.transpose(1, 2)) * self.scale
        else:
            q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
            sim = einsum('b i d, b j d -> b i j', q, k) * self.scale

        if mask is not None:
            # Equivalent of:
            #   mask = rearrange(mask, 'b ... -> b (...)')
            #   mask = repeat(mask, 'b j -> (b h) () j', h=h)
            mask = mask.reshape(mask.shape[0], 1, -1)
            mask = mask.expand(mask.shape[0], h, mask.shape[2])
            # The expanded tensor has stride 0 along the head axis, so it
            # cannot be merged with the batch axis by .view(); reshape copies
            # when it has to.
            mask = mask.reshape(mask.shape[0] * mask.shape[1], 1, mask.shape[2])
            fill_value = -torch.finfo(sim.dtype).max
            sim.masked_fill_(~mask, fill_value)

        if self.mpc_mode:
            attn = self.softmax(sim)
            out = self._merge_heads(torch.matmul(attn, v))
        else:
            attn = sim.softmax(dim=-1)
            out = einsum('b i j, b j d -> b i d', attn, v)
            out = rearrange(out, '(b h) n d -> b n (h d)', h=h)

        return self.to_out(out)


class MPCBasicTransformerBlock(nn.Module):
    """Transformer block: self-attention, cross-attention, then feed-forward.

    Each sub-layer is preceded by its own ``MPCLayerNorm`` and wrapped in a
    residual connection.

    Args:
        dim: feature size of the token sequence.
        n_heads: number of attention heads.
        d_head: feature size per head.
        dropout: forwarded to the attention and feed-forward sub-layers.
        context_dim: feature size of the cross-attention context.
        gated_ff: forwarded to ``MPCFeedForward`` as ``glu``.
        checkpoint: stored and passed as the last argument to the project's
            ``checkpoint`` helper in ``forward``.
    """

    def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True):
        super().__init__()
        self.mpc_mode = True
        shared_attn_args = dict(heads=n_heads, dim_head=d_head, dropout=dropout)
        # attn1 receives no context_dim, so it attends over its own input.
        self.attn1 = MPCCrossAttention(query_dim=dim, **shared_attn_args)
        self.ff = MPCFeedForward(dim, dropout=dropout, glu=gated_ff)
        # attn2 also becomes self-attention when called without a context.
        self.attn2 = MPCCrossAttention(query_dim=dim, context_dim=context_dim, **shared_attn_args)
        self.norm1 = MPCLayerNorm(dim)
        self.norm2 = MPCLayerNorm(dim)
        self.norm3 = MPCLayerNorm(dim)
        self.checkpoint = checkpoint

    def forward(self, x, context=None):
        """Run ``_forward`` through the project's ``checkpoint`` helper."""
        inputs = (x, context)
        return checkpoint(self._forward, inputs, self.parameters(), self.checkpoint)

    def _forward(self, x, context=None):
        """Apply the three pre-normalised residual sub-layers in order."""
        self_attended = self.attn1(self.norm1(x))
        x = self_attended + x

        cross_attended = self.attn2(self.norm2(x), context=context)
        x = cross_attended + x

        fed_forward = self.ff(self.norm3(x))
        x = fed_forward + x
        return x


class MPCSpatialTransformer(nn.Module):
    """
    Transformer wrapper for image-like ``(b, c, h, w)`` inputs.

    The input is group-normalised and projected with a 1x1 convolution, the
    spatial grid is flattened into a token sequence of shape ``(b, h*w, d)``,
    the transformer blocks are applied, and the sequence is folded back into
    an image, projected to ``in_channels`` and added to the original input.
    """
    def __init__(self, in_channels, n_heads, d_head,
                 depth=1, dropout=0., context_dim=None):
        super().__init__()
        self.mpc_mode = True
        self.in_channels = in_channels
        inner_dim = n_heads * d_head
        self.norm = Normalize(in_channels)

        self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)

        blocks = []
        for _ in range(depth):
            blocks.append(
                MPCBasicTransformerBlock(inner_dim, n_heads, d_head,
                                         dropout=dropout, context_dim=context_dim)
            )
        self.transformer_blocks = nn.ModuleList(blocks)

        self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x, context=None):
        # note: if no context is given, cross-attention defaults to self-attention
        _, _, h, w = x.shape
        residual = x

        tokens = self.proj_in(self.norm(x))
        batch, width, rows, cols = tokens.shape
        # (b, d, h, w) -> (b, d, h*w) -> (b, h*w, d)
        tokens = tokens.reshape(batch, width, rows * cols)
        tokens = tokens.transpose(1, 2)

        for block in self.transformer_blocks:
            tokens = block(tokens, context=context)

        # (b, h*w, d) -> (b, d, h*w) -> (b, d, h, w)
        tokens = tokens.transpose(1, 2)
        image = tokens.reshape(tokens.shape[0], tokens.shape[1], h, w)
        return self.proj_out(image) + residual
    
_shape_t = Union[int, List[int], Size]
class MPCLayerNorm(nn.LayerNorm):
    """Layer normalisation with an optional hand-written forward path.

    With ``mpc_mode`` set (the default) the statistics are computed with
    elementwise operations and the reciprocal square root is delegated to the
    project's ``MPCInvsqrt`` module; otherwise the call is forwarded to
    ``torch.nn.functional.layer_norm``.

    Args:
        normalized_shape: trailing shape to normalise over.
        eps: value added to the variance before the inverse square root.
        elementwise_affine: whether to create learnable weight and bias.
        bias: when False, no bias parameter is kept.
        device, dtype: forwarded to ``nn.LayerNorm`` when not ``None``.
    """
    __constants__ = ['normalized_shape', 'eps', 'elementwise_affine']
    normalized_shape: Tuple[int, ...]
    eps: float
    elementwise_affine: bool

    def __init__(self, normalized_shape: _shape_t, eps: float = 1e-5, elementwise_affine: bool = True,
                 bias: bool = True, device=None, dtype=None) -> None:
        # Only forward device/dtype when they were actually requested.
        factory_kwargs = {
            key: value
            for key, value in (('device', device), ('dtype', dtype))
            if value is not None
        }
        # The parent stores normalized_shape (as a tuple), eps and
        # elementwise_affine, and creates the affine parameters only when
        # elementwise_affine is True.
        super().__init__(normalized_shape, eps=eps, elementwise_affine=elementwise_affine,
                         **factory_kwargs)
        if not bias:
            # Drop the bias explicitly rather than via the parent's `bias`
            # keyword, which not every torch release accepts.
            self.register_parameter('bias', None)
        self.invsqrt = MPCInvsqrt(iters=8)
        self.mpc_mode = True

    def forward(self, x):
        """Normalise ``x`` over its trailing ``normalized_shape`` dimensions."""
        if not self.mpc_mode:
            return torch.nn.functional.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)

        # Reduce over as many trailing axes as normalized_shape has entries.
        dims = tuple(range(-len(self.normalized_shape), 0))
        centered = x - x.mean(dim=dims, keepdim=True)
        var = centered.square().mean(dim=dims, keepdim=True)
        x = centered * self.invsqrt(var + self.eps)

        # weight and bias have shape normalized_shape, so they broadcast
        # against the trailing axes of x without changing its rank. Either
        # may be None when the affine transform or the bias is disabled.
        if self.weight is not None:
            x = x * self.weight.to(torch.float32)
        if self.bias is not None:
            x = x + self.bias.to(torch.float32)
        return x
        