from typing import Optional
import torch
from torch import Tensor
import pdb

from examples.simultaneous_translation.utils.functions import (
    exclusive_cumprod,
    prob_check,
    moving_sum,
)


def expected_alignment_from_p_choose(
    p_choose: Tensor,
    padding_mask: Optional[Tensor] = None,
    eps: float = 1e-6
):
    """
    Compute the expected monotonic alignment from stepwise
    selection probabilities.

    Reference:
    Online and Linear-Time Attention by Enforcing Monotonic Alignments
    https://arxiv.org/pdf/1704.00784.pdf

    Recurrence:
    q_{i,j} = (1 - p_{i,j-1}) * q_{i,j-1} + a_{i-1,j}
    a_{i,j} = p_{i,j} * q_{i,j}

    Parallel solution:
    a_i = p_i * cumprod(1 - p_i) * cumsum(a_{i-1} / cumprod(1 - p_i))

    ============================================================
    Expected input size
    p_choose: bsz, tgt_len, src_len
    padding_mask: mask that is unsqueezed at dim 1 and broadcast
        against p_choose; masked positions get probability 0.
    eps: forwarded unchanged to the alignment kernel.

    Returns a tensor of the same shape and dtype as p_choose.
    """
    prob_check(p_choose)

    # Unpacking doubles as a check that the input is 3-D.
    bsz, tgt_len, src_len = p_choose.size()
    input_dtype = p_choose.dtype

    # Work in fp32; the result is cast back to the input dtype below.
    probs = p_choose.float()
    if padding_mask is not None:
        # Masked source positions can never be chosen.
        probs = probs.masked_fill(padding_mask.unsqueeze(1), 0.0)

    # Pick the compiled kernel that matches the tensor's device.
    if probs.is_cuda:
        probs = probs.contiguous()
        from alignment_train_cuda_binding import (
            alignment_train_cuda as alignment_train,
        )
    else:
        from alignment_train_cpu_binding import (
            alignment_train_cpu as alignment_train,
        )

    # The kernel's return value is not used; it is expected to write the
    # result into this pre-allocated buffer.
    expected = probs.new_zeros((bsz, tgt_len, src_len))
    alignment_train(probs, expected, eps)

    # Mixed precision: computation ran in fp32 to prevent overflow for fp16.
    expected = expected.type(input_dtype)
    prob_check(expected)
    return expected


def expected_soft_attention(
    alpha: Tensor,
    soft_energy: Tensor,
    padding_mask: Optional[Tensor] = None,
    chunk_size: Optional[int] = None,
    eps: float = 1e-10
):
    """
    Function to compute expected soft attention for
    monotonic infinite lookback attention from
    expected alignment and soft energy.

    Reference:
    Monotonic Chunkwise Attention
    https://arxiv.org/abs/1712.05382

    Monotonic Infinite Lookback Attention for Simultaneous Machine Translation
    https://arxiv.org/abs/1906.05218

    alpha: bsz, tgt_len, src_len
    soft_energy: bsz, tgt_len, src_len
    padding_mask: bsz, src_len (optional; converted to bool internally)
    chunk_size: if given, use chunkwise attention with this window;
        otherwise use infinite lookback.
    eps: small constant added to the exponentiated energies and to the
        denominators to avoid division by zero.

    Returns beta with the same shape and dtype as alpha, clamped to [0, 1].
    """
    bool_mask = None
    if padding_mask is not None:
        # Convert once so that both masked_fill calls in this function see
        # the same boolean mask, whatever dtype the caller passed in.
        bool_mask = padding_mask.unsqueeze(1).to(torch.bool)
        # alpha is intentionally NOT masked here: source padding is expected
        # to have been handled already in model.forward().
        soft_energy = soft_energy.masked_fill(bool_mask, -float("inf"))

    prob_check(alpha)

    dtype = alpha.dtype

    # Compute in fp32; the result is cast back to the input dtype below.
    alpha = alpha.float()
    soft_energy = soft_energy.float()

    # Subtract the per-row max before exponentiating for numerical stability.
    soft_energy = soft_energy - soft_energy.max(dim=2, keepdim=True)[0]
    exp_soft_energy = torch.exp(soft_energy) + eps

    if chunk_size is not None:
        # Chunkwise
        beta = (
            exp_soft_energy
            * moving_sum(
                alpha / (eps + moving_sum(exp_soft_energy, chunk_size, 1)),
                1, chunk_size
            )
        )
    else:
        # Infinite lookback
        # Notice that infinite lookback is a special case of chunkwise
        # where chunksize = inf
        inner_items = alpha / (eps + torch.cumsum(exp_soft_energy, dim=2))

        # flip -> cumsum -> flip yields a reverse (suffix) cumulative sum
        # along the source dimension.
        beta = (
            exp_soft_energy
            * torch.cumsum(inner_items.flip(dims=[2]), dim=2)
            .flip(dims=[2])
        )

    if bool_mask is not None:
        # exp(-inf) + eps leaves a tiny residue on padded positions;
        # zero it out explicitly.
        beta = beta.masked_fill(bool_mask, 0.0)

    # Mix precision to prevent overflow for fp16
    beta = beta.type(dtype)

    beta = beta.clamp(0, 1)

    prob_check(beta)

    return beta


def mass_preservation(
    alpha: Tensor,
    padding_mask: Optional[Tensor] = None,
    left_padding: bool = False
):
    """
    Function to compute the mass preservation for alpha.
    This means that the residual weights of alpha will be assigned
    to the last token.

    Reference:
    Monotonic Infinite Lookback Attention for Simultaneous Machine Translation
    https://arxiv.org/abs/1906.05218

    alpha: bsz, tgt_len, src_len
    padding_mask: bsz, src_len
    left_padding: bool

    Returns a new tensor; the ``alpha`` passed in by the caller is never
    modified in place.

    Raises AssertionError if ``left_padding`` is False but the mask marks
    the first source position of any sequence as padding.
    """

    prob_check(alpha)

    if padding_mask is not None:
        if not left_padding:
            assert not padding_mask[:, 0].any(), (
                "Find padding on the beginning of the sequence."
            )
        # masked_fill is out-of-place, so from here on ``alpha`` is a copy.
        alpha = alpha.masked_fill(padding_mask.unsqueeze(1), 0.0)

    if left_padding or padding_mask is None:
        # The last real token sits in the final column: give it whatever
        # probability mass the preceding columns leave over.
        residuals = 1 - alpha[:, :, :-1].sum(dim=-1).clamp(0, 1)
        if padding_mask is None:
            # No masked_fill copy was made above, so ``alpha`` is still the
            # caller's tensor; clone it before the in-place write below.
            alpha = alpha.clone()
        alpha[:, :, -1] = residuals
    else:
        # right padding
        _, tgt_len, src_len = alpha.size()
        residuals = 1 - alpha.sum(dim=-1, keepdim=True).clamp(0, 1)
        # Number of non-masked source positions per sequence; the last real
        # token is at index src_lens - 1.
        src_lens = src_len - padding_mask.sum(dim=1, keepdim=True)
        src_lens = src_lens.expand(-1, tgt_len).contiguous()
        # add back the last value
        residuals += alpha.gather(2, src_lens.unsqueeze(2) - 1)
        alpha = alpha.scatter(2, src_lens.unsqueeze(2) - 1, residuals)

        prob_check(alpha)

    return alpha
