import functools
import inspect
import math
from inspect import isfunction
from typing import Any, List, Optional

import torch
import torch.nn.functional as F
from torch import nn, einsum
from einops import rearrange, repeat

# from celldiff.modules.diffusionmodules.util import checkpoint
from celldiff.modules.layers.performer import PerformerAttention


def exists(val):
    """Return True for any value other than ``None`` (falsy values count as existing)."""
    is_missing = val is None
    return not is_missing


def uniq(arr):
    """Return the distinct elements of ``arr`` as a dict keys view, in first-seen order."""
    seen = dict.fromkeys(arr, True)
    return seen.keys()


def default(val, d):
    """Return ``val`` unless it is ``None``; otherwise fall back to ``d``.

    When ``d`` is a plain function (as reported by ``inspect.isfunction``,
    e.g. a lambda), it is treated as a lazy factory and called with no
    arguments. Any other ``d`` (including classes and builtins) is returned
    as-is.
    """
    if val is not None:
        return val
    if isfunction(d):
        return d()
    return d


def max_neg_value(t):
    """Return the most negative finite value representable in ``t``'s floating dtype."""
    dtype_info = torch.finfo(t.dtype)
    return -dtype_info.max


def init_(tensor):
    """Fill ``tensor`` in place from U(-b, b) with b = 1/sqrt(tensor.shape[-1]).

    Returns the same tensor object.
    """
    bound = 1 / math.sqrt(tensor.shape[-1])
    # uniform_ works in place and returns its receiver.
    return tensor.uniform_(-bound, bound)


def all_reduce_add(src: torch.Tensor, idx: torch.Tensor, dim: int):
    """Sum ``src`` within groups along ``dim`` and broadcast each group sum back.

    ``idx`` gets one trailing axis and is broadcast to ``src``'s shape. It
    then gives, for every position along ``dim``, the id of the group that
    position belongs to. Each output entry is the sum of all ``src`` entries
    sharing its group id. The scratch buffer has the same shape as ``src``,
    so group ids must be smaller than ``src.shape[dim]``.
    """
    index = idx.unsqueeze(-1).expand_as(src)
    group_sums = torch.zeros_like(src)
    group_sums.scatter_add_(dim, index, src)
    return torch.gather(group_sums, dim, index)


# feedforward
class GEGLU(nn.Module):
    """Gated-GELU projection: ``value * gelu(gate)``.

    A single linear layer produces ``2 * dim_out`` features. The first half
    is the value and the second half is the gate.
    """

    def __init__(self, dim_in, dim_out):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out * 2)

    def forward(self, x):
        projected = self.proj(x)
        value, gate = torch.chunk(projected, 2, dim=-1)
        return value * F.gelu(gate)


class FeedForward(nn.Module):
    """Position-wise MLP: input projection, dropout, output projection.

    Args:
        dim: Input feature size.
        dim_out: Output feature size; defaults to ``dim``.
        mult: Hidden-size multiplier; the hidden size is ``int(dim * mult)``.
        glu: If True, use a ``GEGLU`` input projection instead of
            ``Linear`` followed by ``GELU``.
        dropout: Dropout probability applied to the hidden activations.
    """

    def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
        super().__init__()
        hidden_dim = int(dim * mult)
        if dim_out is None:
            dim_out = dim

        if glu:
            first_layer = GEGLU(dim, hidden_dim)
        else:
            first_layer = nn.Sequential(nn.Linear(dim, hidden_dim), nn.GELU())

        # Indices 0/1/2 of this Sequential make up the state-dict layout.
        self.net = nn.Sequential(
            first_layer,
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim_out),
        )

    def forward(self, x):
        return self.net(x)


def zero_module(module):
    """Set every parameter of ``module`` to zero in place and return ``module``."""
    with torch.no_grad():
        for param in module.parameters():
            param.zero_()
    return module


def Normalize(in_dim, num_groups=32, norm='group'):
    """Build a normalization layer over ``in_dim`` features.

    ``norm`` selects one of:

    * ``'group'``: ``nn.GroupNorm`` with ``num_groups`` groups, eps=1e-6, affine.
    * ``'layer'``: ``nn.LayerNorm(in_dim)``.
    * ``'batch'``: ``nn.BatchNorm1d(in_dim)``.

    Any other value raises ``NotImplementedError``.
    """
    if norm == 'group':
        return nn.GroupNorm(
            num_groups=num_groups,
            num_channels=in_dim,
            eps=1e-6,
            affine=True,
        )
    if norm == 'layer':
        return nn.LayerNorm(in_dim)
    if norm == 'batch':
        return nn.BatchNorm1d(in_dim)
    raise NotImplementedError(f"norm must in ['group', 'layer', 'batch'], but found {norm}")


class BatchedOperation:
    """Decorator that lets a function accept plain (unbatched) or batched tensors.

    Every tensor argument, positional or keyword, is inspected unless its
    name is in ``ignored_args``:

    * a tensor with ``plain_num_dim`` dimensions is treated as a single
      sample and gets a size-1 batch dimension inserted at ``batch_dim``;
    * a tensor with ``plain_num_dim + 1`` dimensions is treated as batched
      and passed through unchanged;
    * any other rank raises ``ValueError``.

    Args:
        batch_dim: Which dimension the batch goes.
        plain_num_dim: Number of dimensions for plain (i.e., no batch) inputs,
            which is used to determine whether the input is batched or not.
        ignored_args: Names of arguments excluded from automatic batch
            dimension handling (they are neither rank-checked nor expanded).
        squeeze_output_batch: If True, squeeze the batch dimension out of the
            output tensor again, but only when the call was made with plain
            inputs (a batch dimension was inserted and no inspected input was
            already batched). Batched calls keep their batch dimension even
            when the batch size is 1.

    """

    def __init__(
        self,
        batch_dim: int = 0,
        plain_num_dim: int = 2,
        ignored_args: Optional[List[str]] = None,
        squeeze_output_batch: bool = True,
    ):
        self.batch_dim = batch_dim
        self.plain_num_dim = plain_num_dim
        self.ignored_args = set(ignored_args or [])
        self.squeeze_output_batch = squeeze_output_batch
        # Unused; kept for backward compatibility. Per-call batching state is
        # tracked in locals of the wrapper, so a single decorator instance can
        # be shared between calls.
        self._is_batched = None

    def __call__(self, func):
        arg_names = inspect.getfullargspec(func).args

        @functools.wraps(func)
        def bounded_func(*args, **kwargs):
            added_batch = False  # a batch dim was inserted into some input
            saw_batched = False  # some inspected input already had a batch dim

            def prepare(arg_name, arg):
                """Return ``arg`` with a batch dim inserted if it is a plain tensor."""
                nonlocal added_batch, saw_batched
                if self.unsqueeze_batch_dim(arg_name, arg):
                    added_batch = True
                    return arg.unsqueeze(self.batch_dim)
                if isinstance(arg, torch.Tensor) and arg_name not in self.ignored_args:
                    # Inspected but not expanded, so is_batched() returned True.
                    saw_batched = True
                return arg

            new_args = [prepare(name, arg) for name, arg in zip(arg_names, args)]
            # Positional arguments beyond the named parameters (captured by a
            # *args parameter) have no name to check against ignored_args.
            # Forward them untouched rather than silently dropping them.
            new_args.extend(args[len(new_args):])
            new_kwargs = {name: prepare(name, arg) for name, arg in kwargs.items()}

            out = func(*new_args, **new_kwargs)

            # Only undo the batch dim we added ourselves; a caller-supplied
            # batch of size 1 must keep its shape.
            if (
                self.squeeze_output_batch
                and added_batch
                and not saw_batched
                and isinstance(out, torch.Tensor)
            ):
                out = out.squeeze(self.batch_dim)

            return out

        return bounded_func

    def unsqueeze_batch_dim(self, arg_name: str, arg_val: Any) -> bool:
        """Return True if ``arg_val`` is a non-ignored, plain (unbatched) tensor."""
        return (
            isinstance(arg_val, torch.Tensor)
            and (arg_name not in self.ignored_args)
            and (not self.is_batched(arg_val))
        )

    def is_batched(self, val: torch.Tensor) -> bool:
        """Return whether ``val`` carries a batch dimension, judged by its rank.

        Raises:
            ValueError: If ``val`` has neither ``plain_num_dim`` nor
                ``plain_num_dim + 1`` dimensions.
        """
        num_dim = len(val.shape)
        if num_dim == self.plain_num_dim:
            return False
        elif num_dim == self.plain_num_dim + 1:
            return True
        else:
            raise ValueError(
                f"Tensor should have either {self.plain_num_dim} or "
                f"{self.plain_num_dim + 1} number of dimension, got {num_dim}",
            )


class LinearAttention(PerformerAttention):
    """``PerformerAttention`` exposed with a ``CrossAttention``-style constructor.

    ``context_dim`` is accepted so this class can be built with the same
    keyword arguments as ``CrossAttention``. It must be ``None`` or equal to
    ``query_dim`` and is not forwarded. Every other argument is passed to
    ``PerformerAttention`` unchanged.

    NOTE(review): the ``kernel_fn`` default is a single ``nn.ReLU()`` instance
    created once at class-definition time and shared by all instances.
    """

    def __init__(
        self,
        query_dim,
        context_dim=None,
        causal=False,
        heads=8,
        dim_head=64,
        local_heads=0,
        local_window_size=256,
        nb_features=None,
        generalized_attention=False,
        kernel_fn=nn.ReLU(),
        dropout=0.,
        no_projection=False,
        qkv_bias=False
    ):
        dims_match = context_dim is None or context_dim == query_dim
        assert dims_match, f"Only allow context_dim = query_dim now, got {context_dim=}, {query_dim=}"
        performer_kwargs = dict(
            causal=causal,
            heads=heads,
            dim_head=dim_head,
            local_heads=local_heads,
            local_window_size=local_window_size,
            nb_features=nb_features,
            generalized_attention=generalized_attention,
            kernel_fn=kernel_fn,
            dropout=dropout,
            no_projection=no_projection,
            qkv_bias=qkv_bias,
        )
        super().__init__(query_dim, **performer_kwargs)


class CrossAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    Queries come from ``x``. Keys and values come from ``context``, or from
    ``x`` itself when no context is given (self-attention).

    Args:
        query_dim: Feature size of ``x`` and of the output.
        context_dim: Feature size of ``context``; defaults to ``query_dim``.
        heads: Number of attention heads.
        dim_head: Feature size per head.
        dropout: Dropout applied after the output projection.
        qkv_bias: Whether the q/k/v projections have a bias.
    """

    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., qkv_bias=False):
        super().__init__()
        inner_dim = dim_head * heads
        context_dim = default(context_dim, query_dim)

        self.scale = dim_head ** -0.5
        self.heads = heads

        self.to_q = nn.Linear(query_dim, inner_dim, bias=qkv_bias)
        self.to_k = nn.Linear(context_dim, inner_dim, bias=qkv_bias)
        self.to_v = nn.Linear(context_dim, inner_dim, bias=qkv_bias)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, query_dim),
            nn.Dropout(dropout)
        )

    def forward(self, x, *, context=None, mask=None):
        """Attend from ``x`` to ``context``.

        Args:
            x: Query tokens, ``(b, n, query_dim)``.
            context: Key/value tokens, ``(b, m, context_dim)``; defaults to ``x``.
            mask: Optional mask where True (nonzero) marks positions that may
                be attended to. Either ``(b, n, m)`` with one entry per
                query/key pair, or any ``(b, ...)`` whose trailing dims
                flatten to ``m``, i.e. a per-key mask shared by all queries.

        Returns:
            Tensor of shape ``(b, n, query_dim)``.
        """
        h = self.heads
        q = self.to_q(x)
        context = default(context, x)
        k = self.to_k(context)
        v = self.to_v(context)

        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
        sim = einsum('b i d, b j d -> b i j', q, k) * self.scale

        if exists(mask):
            # masked_fill_ needs a bool mask, and `~` on an int mask is a
            # bitwise not rather than a logical one.
            mask = mask.bool()
            if mask.shape[1:] == sim.shape[1:]:
                # Full (b, i, j) mask: replicate it for every head.
                mask = repeat(mask, 'b ... -> (b h) ...', h=h)
            else:
                # Per-key mask: flatten to (b, j) and broadcast over queries.
                mask = rearrange(mask, 'b ... -> b (...)')
                mask = repeat(mask, 'b j -> (b h) () j', h=h)
            sim.masked_fill_(~mask, max_neg_value(sim))

        # attention, what we cannot get enough of
        attn = sim.softmax(dim=-1)

        out = einsum('b i j, b j d -> b i d', attn, v)
        out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
        return self.to_out(out)


class TokenSpecificCrossAttention(CrossAttention):
    """Cross-attention where each query attends only to its own context tokens.

    ``context`` is a flat sequence of key/value tokens laid out in query
    order. The first ``expand_factors[0]`` context tokens belong to query 0,
    the next ``expand_factors[1]`` to query 1, and so on. Each query's
    softmax runs only over its own group of context tokens.
    """

    def forward(self, x, *, context=None, mask=None, expand_factors=None):
        """Attend from every query in ``x`` to its own slice of ``context``.

        Args:
            x: Query tokens, ``(b, n, query_dim)``.
            context: Grouped key/value tokens, ``(b, M, context_dim)`` with
                ``M == expand_factors.sum()``.
            mask: Must be ``None``; masking is not supported here.
            expand_factors: Integer counts, one per query, used as
                ``repeat_interleave`` repeats (presumably a 1-D tensor of
                length ``n``; TODO confirm with callers).

        Returns:
            Tensor of shape ``(b, n, query_dim)``. A query that owns no
            context tokens gets a zero attention output before ``to_out``.
        """
        assert context is not None and expand_factors is not None and mask is None
        h = self.heads
        q = self.to_q(x)
        k = self.to_k(context)
        v = self.to_v(context)

        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
        batch_heads, num_queries = q.shape[0], q.shape[1]

        # Index of the owning query for every context token, shape (M,).
        idxs = torch.arange(num_queries, device=q.device).repeat_interleave(expand_factors, dim=0)
        score_idx = idxs.view(1, -1, 1).expand(batch_heads, -1, 1)  # (b h, M, 1)

        # Score each context token only against its owning query.
        q_expanded = q.repeat_interleave(expand_factors, dim=1)  # expand to cond tokens
        logits = (q_expanded * k).sum(-1, keepdim=True) * self.scale  # (b h, M, 1)

        # Segment softmax. Subtract each group's max logit before exp() so
        # large logits cannot overflow (fp16 overflows above exp(~11)).
        # Softmax is invariant to a per-group shift, so the max needs no grad.
        group_max = torch.full(
            (batch_heads, num_queries, 1), float('-inf'), dtype=logits.dtype, device=logits.device
        )
        group_max = group_max.scatter_reduce(
            1, score_idx, logits.detach(), reduce='amax', include_self=True
        )
        sim_exp = (logits - group_max.gather(1, score_idx)).exp()

        # Per-query normaliser. The buffer has one slot per query (not per
        # context token), so queries with zero context tokens stay in bounds.
        denom = torch.zeros_like(group_max).scatter_add_(1, score_idx, sim_exp).gather(1, score_idx)
        attn = sim_exp / denom

        # Sum the weighted values of each group back onto its query.
        idx_expanded = idxs.view(1, -1, 1).expand_as(q_expanded)
        out = torch.zeros_like(q).scatter_add_(1, idx_expanded, attn * v)

        out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
        return self.to_out(out)


class BasicTransformerBlock(nn.Module):
    """Pre-norm transformer block.

    The layers run in order: optional cross-attention, then optional
    self-attention, then a feed-forward layer. Each is applied as
    ``x = layer(norm(x)) + x``, and ``final_act`` (if given) is applied at
    the end.

    Args:
        dim: Token feature size.
        n_heads: Number of attention heads.
        d_head: Feature size per head.
        self_attn: Add a self-attention layer (``attn2``).
        cross_attn: Add a cross-attention layer (``attn1``) attending to
            ``context``.
        ts_cross_attn: Use ``TokenSpecificCrossAttention``; requires
            ``self_attn=False`` and ``linear_attn=False``.
        final_act: Optional module applied to the block output.
        dropout: Dropout passed to the attention and feed-forward layers.
        context_dim: Feature size of ``context`` for the cross-attention layer.
        gated_ff: Use a GEGLU input projection in the feed-forward layer.
        checkpoint: Gradient checkpointing flag; must be False (not available).
        qkv_bias: Bias on the q/k/v projections.
        linear_attn: Use ``LinearAttention`` instead of ``CrossAttention``.
    """

    def __init__(
        self,
        dim: int,
        n_heads: int,
        d_head: int = 64,
        self_attn: bool = True,
        cross_attn: bool = False,
        ts_cross_attn: bool = False,
        final_act: Optional[nn.Module] = None,
        dropout: float = 0.,
        context_dim: Optional[int] = None,
        gated_ff: bool = True,
        checkpoint: bool = False,
        qkv_bias: bool = False,
        linear_attn: bool = False,
    ):
        super().__init__()
        assert self_attn or cross_attn, 'At least one attention layer'
        self.self_attn = self_attn
        self.cross_attn = cross_attn
        self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
        if ts_cross_attn:
            assert not (self_attn or linear_attn)
            attn_cls = TokenSpecificCrossAttention
        else:
            attn_cls = LinearAttention if linear_attn else CrossAttention
        if self.cross_attn:
            self.attn1 = attn_cls(
                query_dim=dim,
                context_dim=context_dim,
                heads=n_heads,
                dim_head=d_head,
                dropout=dropout,
                qkv_bias=qkv_bias,
            )  # is self-attn if context is none
        if self.self_attn:
            self.attn2 = attn_cls(
                query_dim=dim,
                heads=n_heads,
                dim_head=d_head,
                dropout=dropout,
                qkv_bias=qkv_bias,
            )  # is a self-attention
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.norm3 = nn.LayerNorm(dim)
        self.act = final_act
        self.checkpoint = checkpoint
        assert not self.checkpoint, 'Checkpointing not available yet'

    # def forward(self, x, context=None):
    #     return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)

    # `expand_factors` (for TokenSpecificCrossAttention) is a per-query count
    # tensor used as repeat_interleave repeats, not a token tensor. Without
    # ignoring it, BatchedOperation rank-checks it and raises ValueError for a
    # 1-D tensor.
    @BatchedOperation(batch_dim=0, plain_num_dim=2, ignored_args=['expand_factors'])
    def forward(self, x, context=None, cross_mask=None, self_mask=None, **kwargs):
        """Apply the block to ``x``.

        Accepts plain ``(n, dim)`` or batched ``(b, n, dim)`` tensors; plain
        tensor inputs get a batch dimension added and removed around the call.
        Extra keyword arguments (e.g. ``expand_factors``) are forwarded to the
        attention layers.
        """
        if self.cross_attn:
            x = self.attn1(self.norm1(x), context=context, mask=cross_mask, **kwargs) + x
        if self.self_attn:
            x = self.attn2(self.norm2(x), mask=self_mask, **kwargs) + x
        x = self.ff(self.norm3(x)) + x
        if self.act is not None:
            x = self.act(x)
        return x


class Transformer(nn.Module):
    """Stack of ``BasicTransformerBlock``s applied to one feature vector per sample.

    Each input row goes through these steps:

    1. Normalize it and project it to ``n_heads * d_head`` features.
    2. Run it through the blocks as a length-1 token sequence.
    3. Project it to ``out_dim`` and add it back onto the input as a residual.

    ``proj_out`` is zero-initialised, so at initialisation the output equals
    the input.

    Args:
        in_dim: Input feature size; normalized with ``Normalize(in_dim)``
            (GroupNorm with 32 groups by default).
        out_dim: Output feature size; must broadcast against ``in_dim`` for
            the residual sum (in practice equal to it).
        n_heads: Attention heads per block.
        d_head: Feature size per head.
        depth: Number of transformer blocks.
        dropout: Dropout passed to each block.
        context_dim: Passed to each block as ``context_dim``.
    """

    def __init__(self, in_dim, out_dim, n_heads, d_head,
                 depth=1, dropout=0., context_dim=None):
        super().__init__()
        self.in_dim = in_dim
        inner_dim = n_heads * d_head
        self.norm = Normalize(in_dim)

        self.proj_in = nn.Linear(in_dim, inner_dim)

        # NOTE(review): blocks use the default cross_attn=False, so `context`
        # passed to forward() is currently not consumed by any layer.
        self.transformer_blocks = nn.ModuleList(
            [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim)
                for _ in range(depth)]
        )

        self.proj_out = zero_module(nn.Linear(inner_dim, out_dim))

    def forward(self, x, context=None):
        """Map ``x`` of shape ``(B, in_dim)`` to ``(B, out_dim)``."""
        x_in = x.unsqueeze(1)  # (B, 1, in_dim) residual
        # Normalize the 2-D input. GroupNorm / BatchNorm1d read channels from
        # dim 1, which after unsqueeze(1) would have size 1 and raise.
        h = self.norm(x)
        h = self.proj_in(h).unsqueeze(1)  # (B, 1, inner_dim): one token per sample
        for block in self.transformer_blocks:
            h = block(h, context=context)
        h = self.proj_out(h)
        # The residual sum broadcasts to (B, 1, out_dim) before the token axis
        # is dropped.
        return (h + x_in)[:, 0, :]
