from functools import wraps

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from torch import Tensor

from src.modules.act import GEGLU, ReGLU


def exists(val):
    """Tell whether a value has been provided.

    :param val: Any object.
    :return: ``False`` only when ``val`` is ``None``; falsy values such as ``0`` or ``""``
        still count as existing.
    :rtype: bool
    """
    if val is None:
        return False
    return True


def default(val, d):
    """Fall back to ``d`` when ``val`` was not provided.

    :param val: Preferred value; used whenever it is not ``None``.
    :param d: Fallback returned when ``val`` is ``None``.
    :return: ``val`` if it is not ``None``, otherwise ``d``.
    """
    if exists(val):
        return val
    return d


def check_act_fn(act):
    """Return the width multiplier a linear layer needs in front of ``act``.

    Gated activations (GEGLU, ReGLU) use half of their input as the gate, so the linear layer
    feeding them has to produce twice as many features. Every other activation needs no
    adjustment.

    :param act: Activation class, or an object exposing the class as ``.func`` (for example a
        ``functools.partial`` wrapping it).
    :return: ``2`` if the activation is a GEGLU/ReGLU subclass, otherwise ``1``.
    :rtype: int
    """
    # Unwrap partial-like objects so the subclass test sees the underlying class.
    act_cls = getattr(act, "func", act)
    is_gated = issubclass(act_cls, (GEGLU, ReGLU))
    return 2 if is_gated else 1


def cache_fn(f):
    """Wrap ``f`` so that its first result is remembered and returned on later calls.

    The cache holds a single value and is not keyed on the arguments: once a result has been
    stored, every cached call returns that same object regardless of what is passed in.

    :param f: Callable to wrap.
    :return: Wrapper accepting the same arguments as ``f`` plus a keyword-only ``_cache`` flag
        (default ``True``). With ``_cache=False`` the wrapper calls ``f`` directly and neither
        reads nor writes the stored result.
    """
    # A private sentinel (instead of ``None``) marks "nothing stored yet", so a function whose
    # result is legitimately ``None`` is also evaluated only once.
    missing = object()
    cache = missing

    @wraps(f)
    def cached_fn(*args, _cache=True, **kwargs):
        nonlocal cache
        if not _cache:
            return f(*args, **kwargs)
        if cache is missing:
            cache = f(*args, **kwargs)
        return cache

    return cached_fn


def dropout_seq(seq, mask, dropout):
    """Randomly keep a subset of sequence positions, chosen independently for every batch row.

    Each position gets a random score and the ``num_keep`` highest-scoring positions of every
    row are gathered, where ``num_keep = max(1, int((1 - dropout) * n))``.

    :param seq: Tensor whose first two axes are batch and sequence position.
    :param mask: Optional boolean tensor of shape ``(batch, n)``; ``False`` entries are given
        the lowest possible score so they are selected last. May be ``None``.
    :param dropout: Fraction of positions to drop.
    :return: Tuple ``(seq, mask)`` with the sequence axis reduced to ``num_keep``. When a mask
        was given, the returned mask is ``True`` only for kept positions that were valid in the
        input and fall within the row's budget of ``ceil(valid_count * (1 - dropout))``.
        When ``mask`` is ``None`` it is returned unchanged.
    """
    batch, length, *_ = seq.shape
    device = seq.device

    # One random score per position; the highest-scoring positions survive.
    scores = torch.randn(batch, length, device=device)
    if exists(mask):
        # Push masked-out positions to the very bottom of the ranking.
        lowest = -torch.finfo(scores.dtype).max
        scores = scores.masked_fill(~mask, lowest)

    keep_prob = 1.0 - dropout
    num_keep = max(1, int(keep_prob * length))
    keep_indices = scores.topk(num_keep, dim=1).indices

    # Column vector of row ids, so advanced indexing pairs every row with its own kept indices.
    rows = rearrange(torch.arange(batch, device=device), "b -> b 1")
    seq = seq[rows, keep_indices]

    if not exists(mask):
        return seq, mask

    # Rows with few valid positions may keep fewer than ``num_keep`` entries: only the first
    # ``ceil(valid_count * keep_prob)`` columns of the top-k result stay unmasked.
    valid_counts = mask.sum(dim=-1)
    row_budget = rearrange(torch.ceil(valid_counts * keep_prob).int(), "b -> b 1")
    within_budget = torch.arange(num_keep, device=device) < row_budget

    mask = mask[rows, keep_indices] & within_budget
    return seq, mask


class PreNorm(nn.Module):
    """Apply layer normalization to the input (and optionally a context) before calling ``fn``.

    :param dim: Feature size of the main input, passed to ``nn.LayerNorm``.
    :param fn: Callable invoked as ``fn(normed_x, **kwargs)``.
    :param context_dim: Feature size of the context. When given, a second ``nn.LayerNorm`` is
        created and every forward call must supply a ``context`` keyword argument.
    """

    def __init__(self, dim, fn, context_dim=None):
        super().__init__()
        self.fn = fn
        self.norm = nn.LayerNorm(dim)
        if exists(context_dim):
            self.norm_context = nn.LayerNorm(context_dim)
        else:
            self.norm_context = None

    def forward(self, x, **kwargs):
        """Normalize ``x`` (and ``kwargs["context"]`` if a context norm exists), then run ``fn``.

        :param x: Main input tensor.
        :param kwargs: Extra keyword arguments forwarded to ``fn``.
        :raises KeyError: If a context norm was configured but ``context`` is not in ``kwargs``.
        :return: Whatever ``fn`` returns.
        """
        normed = self.norm(x)

        if not exists(self.norm_context):
            return self.fn(normed, **kwargs)

        # Replace the raw context with its normalized version before forwarding.
        kwargs["context"] = self.norm_context(kwargs["context"])
        return self.fn(normed, **kwargs)


class FeedForward(nn.Module):
    """Stack of ``Linear -> activation`` blocks followed by a final linear projection.

    :param dim: Feature size after every activation and the input size of the final projection.
    :param depth: Number of ``Linear -> activation`` blocks; at least one block is always built.
    :param act: Activation factory, called with no arguments once per block.
    :param input_dim: Input feature size of the first block; defaults to ``dim``.
    :param output_dim: Output feature size of the final projection; defaults to ``dim``.
    """

    def __init__(
        self,
        dim: int,
        depth: int = 1,
        act: nn.Module = GEGLU,
        input_dim: int = None,
        output_dim: int = None,
    ):
        super().__init__()
        input_dim = default(input_dim, dim)
        output_dim = default(output_dim, dim)

        # Linear layers feeding the activation are widened by the factor ``check_act_fn``
        # reports for it (2 for gated activations, 1 otherwise).
        hidden_width = dim * check_act_fn(act)

        blocks = []
        in_features = input_dim
        for _ in range(max(depth, 1)):
            blocks.append(nn.Sequential(nn.Linear(in_features, hidden_width), act()))
            # Every block after the first consumes the ``dim``-sized activation output.
            in_features = dim

        self.net = nn.Sequential(*blocks, nn.Linear(dim, output_dim))

    def forward(self, x):
        """Run ``x`` through the whole stack and return the result."""
        return self.net(x)


class Attention(nn.Module):
    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64):
        super().__init__()
        inner_dim = dim_head * heads
        context_dim = default(context_dim, query_dim)
        self.scale = dim_head**-0.5
        self.heads = heads

        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias=False)
        self.to_out = nn.Linear(inner_dim, query_dim)

    def forward(self, x, context=None, mask=None):
        h = self.heads

        q = self.to_q(x)
        context = default(context, x)
        k, v = self.to_kv(context).chunk(2, dim=-1)

        q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v))

        out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, scale=self.scale)
        out = rearrange(out, "(b h) n d -> b n (h d)", h=h)
        return self.to_out(out)
