from typing import Optional

import torch
from torch import nn


def _mean_pooling(x: torch.Tensor,
                  mask: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Average ``x`` over dim 1, optionally skipping masked-out positions.

    Args:
        x: Tensor to pool; dim 1 is the axis that gets reduced. Expected to
            have exactly one more trailing dim than ``mask`` (for example
            ``(batch, seq_len, features)`` against ``(batch, seq_len)``).
        mask: Optional tensor whose zero entries mark positions that must
            not contribute to the average. When omitted, a plain mean over
            dim 1 is returned.

    Returns:
        ``x`` with dim 1 reduced away. A row whose mask is entirely zero
        is divided by 1 instead of 0, so it yields zeros (for finite
        ``x``) rather than NaN.
    """
    if mask is None:
        return x.mean(dim=1)

    # Trailing singleton axis lets the mask broadcast over the last dim of x.
    mask = mask.view((*mask.shape, 1))
    lengths = mask.sum(dim=1)
    # A fully masked row has length 0; dividing by it would give 0 / 0 = NaN,
    # which then poisons everything downstream. Only exact zeros are touched
    # so every other input produces the same result as before.
    lengths = lengths.masked_fill(lengths == 0, 1)
    return (x * mask).sum(dim=1) / lengths


def _max_pooling(x: torch.Tensor,
                 mask: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Take the maximum of ``x`` over dim 1, optionally skipping masked-out
    positions.

    Args:
        x: Floating-point tensor to pool; dim 1 is the axis that gets
            reduced. Expected to have exactly one more trailing dim than
            ``mask``.
        mask: Optional tensor whose zero entries mark positions that must
            not take part in the maximum.

    Returns:
        The element-wise maximum over dim 1 (the ``values`` part of
        ``Tensor.max``). A row whose mask is entirely zero yields ``-inf``.
    """
    if mask is not None:
        # Overwrite excluded positions instead of adding -inf to them:
        # with addition, a padded +inf or NaN turns into NaN, and max()
        # propagates NaN into the pooled result.
        excluded = (mask == 0).unsqueeze(-1)
        x = x.masked_fill(excluded, -torch.inf)
    return x.max(dim=1, keepdim=False).values


def _index_pooling(x: torch.Tensor,
                   index: int,
                   mask: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Pick a single position along dim 1 of ``x``.

    Valid (unmasked) positions are assumed to come first along dim 1, with
    padding after them; under that assumption a non-negative ``index`` means
    the same position with or without a mask.

    Args:
        x: Tensor to pool; dim 1 is the axis that gets indexed away.
        index: Position to select. Non-negative values are absolute
            offsets from the start. Negative values count back from the end
            of dim 1 or, when ``mask`` is given, from the end of each row's
            valid span (``-1`` is the last unmasked position).
        mask: Optional ``(batch, seq_len)`` tensor; its per-row sum is used
            as the number of valid positions in that row.

    Returns:
        ``x`` with dim 1 removed.
    """
    if mask is None or index >= 0:
        return x.select(index=index, dim=1)

    # Cast to long so float masks still produce valid index tensors.
    lengths = mask.sum(dim=1).long()
    positions = lengths + index
    rows = torch.arange(len(x), device=x.device)
    return x[rows, positions]


def _last_pooling(x: torch.Tensor,
                  mask: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Pool ``x`` by delegating to ``_index_pooling`` with ``index=-1``.

    Args:
        x: Tensor to pool.
        mask: Optional mask, forwarded unchanged to ``_index_pooling``.

    Returns:
        Whatever ``_index_pooling`` returns for ``index=-1``.
    """
    return _index_pooling(x, mask=mask, index=-1)


def _first_pooling(x: torch.Tensor,
                   mask: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Pool ``x`` by delegating to ``_index_pooling`` with ``index=0``.

    Args:
        x: Tensor to pool.
        mask: Optional mask, forwarded unchanged to ``_index_pooling``.

    Returns:
        Whatever ``_index_pooling`` returns for ``index=0``.
    """
    return _index_pooling(x, mask=mask, index=0)


class Pooling(nn.Module):
    """Module wrapper that applies one of the named pooling functions.

    The pooling name is matched case-insensitively against
    ``supported_pooling``.
    """

    # Maps each accepted (lower-case) pooling name to its implementation.
    supported_pooling = {
        'mean': _mean_pooling,
        'max': _max_pooling,
        'last': _last_pooling,
        'first': _first_pooling
    }

    def __init__(self, pooling: str = 'mean'):
        """Select the pooling function by name.

        Args:
            pooling: Key of ``supported_pooling``; letter case is ignored.

        Raises:
            KeyError: If ``pooling`` does not name a supported function.
        """
        super().__init__()
        self.pooling = pooling.lower()
        # Validate and look up with the normalised name; using the raw
        # argument would reject inputs such as 'Mean' even though the name
        # is lower-cased above.
        if self.pooling not in self.supported_pooling:
            raise KeyError(f'Pooling function {pooling} is not supported. '
                           f'Supported list: {list(self.supported_pooling.keys())}')
        # Same mapping for the check and the lookup, so the two cannot
        # disagree if a subclass overrides ``supported_pooling``.
        self._pooling_fn = self.supported_pooling[self.pooling]

    def forward(self,
                x: torch.Tensor,
                mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Apply the selected pooling function to ``x``.

        Args:
            x: Tensor to pool.
            mask: Optional mask, forwarded unchanged to the pooling function.

        Returns:
            The pooled tensor produced by the selected function.
        """
        return self._pooling_fn(x, mask)

    def extra_repr(self):
        """Return the configured pooling name for the module's repr."""
        return f'pooling={self.pooling}'