from typing import List, Optional
import torch

# Public API of this module: sequence padding plus boolean attention-mask builders.
__all__ = [
    "pad_sequence", "make_mask_by_length",
    "make_self_attn_mask_from_mask", "make_cross_attn_mask_from_mask",
    "apply_local_mask_to_attn_mask", "make_autoregressive_mask_from_shape",
]


def pad_sequence(sequences: List[torch.Tensor],
                 output_batch_first: bool = True,
                 padding_value=0,
                 pad_to_multiple: int = 1) -> torch.Tensor:
    """
    Pad list of sequence to maximum length.
    Unlike torch.nn.utils.rnn.pad_sequence, input is always B x (T, ...).
    :param sequences:               List of (seq_length, ...) tensors. All must share the same
                                    trailing dims; output dtype/device follow sequences[0].
    :param output_batch_first:      If True return (B, T, ...), otherwise (T, B, ...).
    :param padding_value:           Value written into padded positions.
    :param pad_to_multiple:         If > 1, round the padded length up to a multiple of this value.
    :return:
            padded sequence:        (batch_size, max_seq_length, ...) if output_batch_first
                                    else (max_seq_length, batch_size, ...)
    :raises ValueError:             If sequences is empty or trailing dims differ between entries.
    """
    if len(sequences) == 0:
        raise ValueError("pad_sequence requires at least one sequence, got an empty list.")

    trailing_dims = sequences[0].shape[1:]
    for idx, seq in enumerate(sequences):
        # Without this check, e.g. a (T, 1) entry would silently broadcast into (T, D) on assignment.
        if seq.shape[1:] != trailing_dims:
            raise ValueError(f"pad_sequence trailing dims mismatch, expected {tuple(trailing_dims)}, "
                             f"got {tuple(seq.shape[1:])} at index {idx}.")

    max_len = max(seq.shape[0] for seq in sequences)
    if pad_to_multiple > 1:
        # Ceil-round to the next multiple (no-op if already a multiple).
        max_len = -(-max_len // pad_to_multiple) * pad_to_multiple

    out_dims = (len(sequences), max_len) + tuple(trailing_dims)
    out_tensor = sequences[0].new_full(out_dims, padding_value)
    for idx, seq in enumerate(sequences):
        out_tensor[idx, :seq.shape[0], ...] = seq

    if output_batch_first:
        return out_tensor  # (B, T, ...)
    return out_tensor.transpose(0, 1).contiguous()  # (T, B, ...)


@torch.no_grad()
def make_mask_by_length(lengths: torch.Tensor,
                        max_length: Optional[int] = None,
                        right_align: bool = False) -> torch.Tensor:
    """
    Make boolean T/F mask which indicates padding or not.
    :param lengths:         (batch_size,)   long, number of valid positions per row
    :param max_length:      width of the mask; defaults to lengths.max(). Must be >= lengths.max().
    :param right_align:     if True, valid positions are placed at the end of each row instead of the start.
    :return:
            mask:           (batch_size, max_length)    bool, T: valid, F: pad
    :raises ValueError:     if lengths is not 1D, or the maximum length exceeds max_length.
    """
    if lengths.ndim != 1:
        raise ValueError(f"make_mask_by_length expects 1D lengths (B,), got {lengths.shape}.")

    if lengths.numel() == 0:
        # max() is undefined on an empty tensor; an empty batch yields an empty mask.
        width = 0 if max_length is None else max_length
        return torch.zeros(0, width, dtype=torch.bool, device=lengths.device)

    longest = lengths.max().item()  # computed (and host-synced) once
    if max_length is None:
        max_length = longest
    elif longest > max_length:
        raise ValueError(f"make_mask_by_length length overflow, "
                         f"maximum of lengths: {longest}, but given max_length: {max_length}.")

    positions = torch.arange(0, max_length, dtype=torch.long, device=lengths.device)  # (s,)
    # Broadcasting (1, s) against (b, 1) gives (b, s).
    mask = positions.unsqueeze(0) < lengths.unsqueeze(1)
    if right_align:
        mask = torch.fliplr(mask)
    return mask.contiguous()


@torch.no_grad()
def make_self_attn_mask_from_mask(mask: torch.Tensor) -> torch.Tensor:
    """
    Build a pairwise validity mask from a per-position mask.
    Entry (b, i, j) is True only when positions i and j are both valid in row b.
    :param mask:            (batch_size, max_length)                bool
    :return:
            mask:           (batch_size, max_length, max_length)    bool
    :raises ValueError:     if mask is not 2D.
    """
    if mask.ndim == 2:
        query_side = mask.unsqueeze(2)  # (b, s, 1)
        key_side = mask.unsqueeze(1)  # (b, 1, s)
        return torch.logical_and(query_side, key_side).contiguous()  # (b, s, s)
    raise ValueError(f"make_self_attn_mask_from_mask is for 2D tensor (B, T), got {mask.shape}.")


@torch.no_grad()
def make_cross_attn_mask_from_mask(mask_self: torch.Tensor,
                                   mask_cross: torch.Tensor) -> torch.Tensor:
    """
    Make cross-attentive mask (outer-product of self and cross)
    Entry (b, i, j) is True only when self position i and cross position j are both valid in row b.
    :param mask_self:       (batch_size, max_self_length)       bool
    :param mask_cross:      (batch_size, max_cross_length)      bool
    :return:
            mask:           (batch_size, max_self_length, max_cross_length) bool
    :raises ValueError:     if either mask is not 2D or their batch sizes differ.
    """
    if mask_self.ndim != 2:
        raise ValueError(f"make_cross_attn_mask_from_mask is for 2D tensor (B, T), got {mask_self.shape} for self.")
    if mask_cross.ndim != 2:
        raise ValueError(f"make_cross_attn_mask_from_mask is for 2D tensor (B, T), got {mask_cross.shape} for cross.")
    if mask_self.shape[0] != mask_cross.shape[0]:
        # Explicit check rather than `assert` (stripped under `python -O`); without it a
        # batch-1 mask would silently broadcast against the other batch.
        raise ValueError(f"make_cross_attn_mask_from_mask batch size mismatch, "
                         f"got {mask_self.shape} for self and {mask_cross.shape} for cross.")

    attn_mask = torch.logical_and(mask_self.unsqueeze(2), mask_cross.unsqueeze(1))  # (b, s_a, s_c)
    return attn_mask.contiguous()


@torch.no_grad()
def apply_local_mask_to_attn_mask(attn_mask: torch.Tensor,
                                  left_context: int = -1,
                                  right_context: int = -1) -> torch.Tensor:
    """
    Restrict an attention mask to a local window around each query.
    Queries are aligned to the LAST s_q keys: query i corresponds to key position i + (s_k - s_q).
    Key j stays visible to query i only if
        (i + s_k - s_q) - left_context <= j <= (i + s_k - s_q) + right_context
    where a negative context disables that side's restriction.
    :param attn_mask:       (..., s_q, s_k)     bool, requires s_q <= s_k
    :param left_context:    number of earlier keys visible (0: none, < 0: unlimited)
    :param right_context:   number of later keys visible (0: none, < 0: unlimited)
    :return:
            mask:           (..., s_q, s_k)     bool. If both contexts are negative, attn_mask
                            itself is returned unchanged.
    :raises ValueError:     if attn_mask has fewer than 2 dims or s_q > s_k.
    """
    if (left_context < 0) and (right_context < 0):  # shortcut: nothing to restrict
        return attn_mask

    if attn_mask.ndim < 2:
        raise ValueError(f"apply_local_mask_to_attn_mask needs at least 2D tensor, got {attn_mask.shape}.")
    s_q, s_k = attn_mask.shape[-2:]
    if s_q > s_k:
        raise ValueError(f"apply_local_mask_to_attn_mask shape invalid, got {attn_mask.shape}.")

    offset = s_k - s_q  # diagonal index of each query's own (aligned) key position
    local_mask = torch.ones(s_q, s_k, dtype=torch.bool, device=attn_mask.device)
    if left_context >= 0:  # drop keys more than left_context before the aligned position
        local_mask = torch.triu(local_mask, diagonal=offset - left_context)
    if right_context >= 0:  # drop keys more than right_context after the aligned position
        local_mask = torch.tril(local_mask, diagonal=offset + right_context)

    # (s_q, s_k) broadcasts over any leading batch/head dims of attn_mask.
    return torch.logical_and(attn_mask, local_mask)


@torch.no_grad()
def make_autoregressive_mask_from_shape(q_length, k_length, same_length: bool = False) -> torch.Tensor:
    """
    Build a causal (q_length, k_length) bool mask with queries aligned to the last q_length keys.
    Query i may attend keys j <= i + (k_length - q_length).
    With same_length=True, keys before position i are also hidden (j >= i), so every query row
    sees exactly k_length - q_length + 1 keys.
    :raises ValueError:     if q_length > k_length.
    """
    if k_length < q_length:
        raise ValueError(f"make_autoregressive_mask_from_shape shape invalid, got ({q_length}, {k_length}).")
    left = (k_length - q_length) if same_length else -1
    full_mask = torch.ones(q_length, k_length, dtype=torch.bool)
    return apply_local_mask_to_attn_mask(full_mask, left_context=left, right_context=0)


if __name__ == '__main__':
    # Quick visual smoke check of the mask helpers; prints the resulting boolean masks.
    for q_len, k_len in ((3, 7), (5, 5)):
        for same in (False, True):
            print(f"autoregressive q={q_len} k={k_len} same_length={same}")
            print(make_autoregressive_mask_from_shape(q_len, k_len, same_length=same))

    a = torch.ones(5, 5, dtype=torch.bool)
    for left, right in ((-1, -1), (-1, 0), (-1, 1), (2, 1)):
        print(f"local left_context={left} right_context={right}")
        print(apply_local_mask_to_attn_mask(a, left, right))
