from typing import Callable, Dict

import numpy as np
import torch
from ablator import Enum


class MaskType(Enum):
    """Names of the attention-mask patterns defined in this module.

    ``Enum`` here is ablator's ``Enum`` (imported from ``ablator``), not the
    standard-library one. Each member's value is the lowercase string name of
    the pattern; ``make_mask`` looks builders up by member in ``mask_fn_map``.
    """

    # Mixed pattern, built by mix_attention.
    MIX = "mix"
    # Pattern built by global_mask: only row 0 and column 0 are left unmasked.
    GLOBAL = "global"
    # Banded (sliding-window) pattern, built by local_attention.
    LOCAL = "local"
    # Pattern built by full_attention.
    FULL = "full"
    # Random pattern, built by random_mask (which also needs an ``alpha``).
    RANDOM = "random"


def full_mask(sz: int) -> torch.Tensor:
    """Return an ``(sz, sz)`` boolean tensor with every entry ``False``.

    Under the ``True``-means-masked-out convention that causal_mask follows
    (NOTE(review): assumed from the other builders; confirm against the
    attention layer that consumes these masks), this mask restricts nothing.
    It also serves as an empty starting point that other masks can be OR-ed
    into.

    Args:
        sz: Side length of the square mask (sequence length).

    Returns:
        A ``torch.bool`` tensor of shape ``(sz, sz)`` filled with ``False``.
    """
    return torch.zeros((sz, sz), dtype=torch.bool)


def random_mask(sz: int, alpha: float):
    """Return an ``(sz, sz)`` boolean tensor whose entries are ``True`` at random.

    Each entry comes from an independent ``torch.rand`` draw in ``[0, 1)``. It
    is ``True`` when the draw exceeds ``1 - alpha``, i.e. with probability
    ``alpha`` for ``0 <= alpha <= 1``. A larger ``alpha`` therefore produces a
    larger (more ``True``) mask. The global torch RNG is used, so the result
    depends on the current seed.

    Args:
        sz: Side length of the square mask.
        alpha: Target fraction of ``True`` entries.
    """
    threshold = 1 - alpha
    draws = torch.rand((sz, sz))
    return draws > threshold


def global_mask(sz: int) -> torch.Tensor:
    """Return an ``(sz, sz)`` boolean tensor open only on row 0 and column 0.

    Every entry is ``True`` except those in the first row and the first
    column, which are ``False``. If ``True`` means "masked out" (the
    convention causal_mask follows; NOTE(review): confirm against the
    consumer), token 0 attends to and is attended by every position. All
    other pairs are masked, including each other token with itself.

    Raises:
        IndexError: If ``sz`` is 0, since there is no row/column 0 to open.
    """
    blocked = torch.ones((sz, sz), dtype=torch.bool)
    blocked[0, :] = False  # open the whole first row
    blocked[:, 0] = False  # open the whole first column
    return blocked


def local_attention(sz: int, stride_len=None):
    """Return an ``(sz, sz)`` boolean tensor that is ``False`` inside a diagonal band.

    Entry ``[i, j]`` is ``False`` when ``-stride_len <= j - i < stride_len``
    and ``True`` everywhere else.

    NOTE(review): the band is asymmetric. It extends ``stride_len`` positions
    below the diagonal but only ``stride_len - 1`` above it. Confirm this is
    intended rather than an off-by-one.

    Args:
        sz: Side length of the square mask (sequence length).
        stride_len: Band half-width. ``None`` selects ``ceil(0.1 * sz) + 1``.
    """
    if stride_len is None:
        window = int(np.ceil(sz * 0.1) + 1)
    else:
        window = stride_len

    ones = torch.ones(sz, sz, dtype=torch.bool)
    # torch.triu(x, diagonal=d) keeps entries where j - i >= d.
    beyond_upper = torch.triu(ones, diagonal=window)  # j - i >= window
    within_lower = torch.triu(ones, diagonal=-window)  # j - i >= -window
    # Masked iff outside the band: j - i >= window  OR  j - i < -window.
    return beyond_upper | ~within_lower


def causal_mask(sz: int) -> torch.Tensor:
    """Return an ``(sz, sz)`` boolean tensor that is ``True`` strictly above the diagonal.

    Entry ``[i, j]`` is ``True`` when ``j > i`` and ``False`` on and below the
    diagonal. Passed as a boolean mask where ``True`` means "not allowed to
    attend" (as with ``torch.nn.MultiheadAttention``'s bool ``attn_mask``), it
    prevents each position from attending to later positions. It is the
    boolean counterpart of the float mask with ``-inf`` above the diagonal and
    ``0`` elsewhere.

    Args:
        sz: Side length of the square mask (sequence length).

    Returns:
        A ``torch.bool`` tensor of shape ``(sz, sz)``.
    """
    return torch.triu(torch.ones(sz, sz, dtype=torch.bool), diagonal=1)


def mix_attention(sz: int, stride_len=None) -> torch.Tensor:
    """Return a mask that combines the global and local attention patterns.

    A position pair is left open (``False``) if either global_mask or
    local_attention leaves it open. It is masked (``True``) only when both
    patterns mask it.

    NOTE(review): this assumes MIX means "global + local" (sliding window plus
    global token). Confirm whether a random component was also intended.

    Args:
        sz: Side length of the square mask (sequence length).
        stride_len: Band half-width forwarded to local_attention. ``None``
            keeps local_attention's default.

    Returns:
        A ``torch.bool`` tensor of shape ``(sz, sz)``.

    Raises:
        IndexError: If ``sz`` is 0 (raised by global_mask).
    """
    # Accumulate the *allowed* (False-in-mask) positions of each pattern,
    # then invert back to the masked-out convention the builders return.
    # Without the local term the result would be identical to global_mask(sz).
    allowed = full_mask(sz)
    allowed |= ~global_mask(sz)
    allowed |= ~local_attention(sz, stride_len=stride_len)
    return ~allowed

def full_attention(sz: int) -> torch.Tensor:
    """Return an ``(sz, sz)`` mask that masks nothing (all ``False``).

    The other builders here mark masked-out positions with ``True``:
    causal_mask is ``True`` above the diagonal, and local_attention is
    ``True`` outside its band. Unrestricted attention is therefore the
    all-``False`` mask, the same as ``full_mask(sz)``. Returning its inverse
    would mask every position, which a ``True``-means-masked consumer would
    turn into an all ``-inf`` score row.

    NOTE(review): relies on the consumer treating ``True`` as "masked out";
    confirm against the attention layer.

    Args:
        sz: Side length of the square mask (sequence length).

    Returns:
        A ``torch.bool`` tensor of shape ``(sz, sz)`` filled with ``False``.
    """
    return torch.zeros((sz, sz), dtype=torch.bool)

# Registry used by make_mask: each entry maps a MaskType member to a builder
# that is called with the sequence length only. random_mask is not registered
# because it also requires an ``alpha`` argument, which make_mask does not
# supply.
# NOTE(review): keys are MaskType members. Looking up with a plain string
# relies on ablator's Enum equality/hash semantics; confirm before doing so.
mask_fn_map: Dict[str, Callable] = {
    MaskType.GLOBAL: global_mask,
    MaskType.LOCAL: local_attention,
    MaskType.MIX: mix_attention,
    MaskType.FULL: full_attention,
}


def make_mask(mask_type: str, size: int):
    """Build the attention mask registered for ``mask_type``.

    Args:
        mask_type: Key into ``mask_fn_map`` (a MaskType member).
        size: Sequence length passed to the registered builder.

    Returns:
        Whatever the registered builder returns for ``size``. For the builders
        in this module, that is a square boolean tensor.

    Raises:
        KeyError: If no builder is registered for ``mask_type``. The message
            lists the registered keys.
    """
    try:
        build = mask_fn_map[mask_type]
    except KeyError as err:
        raise KeyError(
            f"No mask builder registered for {mask_type!r}; "
            f"available: {list(mask_fn_map)}"
        ) from err
    return build(size)
