from typing import Iterable, Tuple, Union
import math
import torch
import torch.nn as nn
from omegaconf import DictConfig, OmegaConf

from pado.nn.modules import Conv2d, MaskedBatchNorm, Linear, MaxPool2d, Dropout, LayerNorm
from pado.nn.modules.activation import get_activation_cls
from pado.nn.utils import make_mask_by_length
from pado.tasks.asr.asr_subsampling import BaseASRSubsampling

# Public API: the two ASR subsampling front-ends defined in this module.
__all__ = ["VGGSubsampling", "StrideSubsampling"]


class VGGSubsampling(BaseASRSubsampling):
    """VGG-style convolutional subsampling front-end for ASR features.

    Each of the ``num_layers`` stages applies two stride-1 convolutions (each one
    optionally followed by ``MaskedBatchNorm``, then an activation) and a 2x2,
    stride-2 max-pool with ``ceil_mode=True``. Time and frequency are therefore
    each halved (rounded up) per stage. The final feature map is flattened over
    (channels, frequency), passed through dropout and a linear projection to
    ``out_dim``, and optionally layer-normalized.
    """

    def __init__(self,
                 feature_dim: int,
                 out_dim: int,
                 num_layers: int = 2,
                 num_channels: Union[int, Iterable[int]] = 32,
                 kernel_size: int = 3,
                 drop_prob: float = 0.0,
                 bn: bool = False,
                 eps: float = 1e-5,
                 momentum: float = 0.01, *,
                 act_type: str = "relu",
                 out_norm: bool = False,
                 sync_bn: bool = True) -> None:
        """
        :param feature_dim:     input feature (frequency) dimension.
        :param out_dim:         output dimension of the final linear projection.
        :param num_layers:      number of conv-conv-pool stages.
        :param num_channels:    output channels per stage; an int is repeated for every stage.
        :param kernel_size:     conv kernel size; every conv is zero-padded by ``kernel_size // 2``.
        :param drop_prob:       dropout probability applied right before the linear projection.
        :param bn:              if True, add a MaskedBatchNorm after each conv and disable the conv bias.
        :param eps:             epsilon for MaskedBatchNorm and for the optional output LayerNorm.
        :param momentum:        momentum for MaskedBatchNorm.
        :param act_type:        activation name, resolved through ``get_activation_cls``.
        :param out_norm:        if True, apply LayerNorm to the output and disable the linear bias.
        :param sync_bn:         forwarded to MaskedBatchNorm (presumably cross-process synchronized BN; verify).
        :raises ValueError:     if ``num_channels`` does not contain exactly ``num_layers`` entries.
        """
        super().__init__()
        self.feature_dim = feature_dim
        self.out_dim = out_dim
        self.num_layers = num_layers

        if isinstance(num_channels, int):
            num_channels = [num_channels] * num_layers
        else:
            # Materialize so arbitrary iterables (e.g. generators) support len() and indexing.
            num_channels = list(num_channels)

        if len(num_channels) != num_layers:
            raise ValueError(f"VGGSubsampling length of num_channels {len(num_channels)} != {num_layers}.")
        self.num_channels = num_channels
        self.kernel_size = kernel_size

        conv_layers = []
        bn_layers = []
        act_layers = []
        pool_layers = []

        conv_channels = 1  # the input is treated as a single-channel "image" (batch, 1, time, freq)

        for channels in num_channels:
            # ------------------ 1st
            conv_layers.append(Conv2d(conv_channels, channels,
                                      kernel_size=kernel_size, padding=kernel_size // 2, bias=not bn, partial=False))
            if bn:
                bn_layers.append(MaskedBatchNorm(channels, eps=eps, momentum=momentum, sync_bn=sync_bn))
            act_layers.append(get_activation_cls(act_type, inplace=False))

            # ------------------ 2nd
            conv_layers.append(Conv2d(channels, channels,
                                      kernel_size=kernel_size, padding=kernel_size // 2, bias=not bn, partial=False))
            if bn:
                bn_layers.append(MaskedBatchNorm(channels, eps=eps, momentum=momentum, sync_bn=sync_bn))
            act_layers.append(get_activation_cls(act_type, inplace=False))

            # ------------------ pool
            pool_layers.append(MaxPool2d(2, 2, padding=0, ceil_mode=True))

            # Track the frequency size through both convs and the pool. The pool window is 2
            # (not kernel_size); getting this wrong makes the Linear input size below disagree
            # with the flattened tensor produced in forward().
            feature_dim = self._conv_out_size(feature_dim, kernel_size)
            feature_dim = self._conv_out_size(feature_dim, kernel_size)
            feature_dim = self._pool_out_size(feature_dim)
            conv_channels = channels

        self.reduced_feature_dim = int(feature_dim)

        self.conv_layers = nn.ModuleList(conv_layers)
        if bn:
            self.bn_layers = nn.ModuleList(bn_layers)
        else:
            self.bn_layers = None
        self.act_layers = nn.ModuleList(act_layers)
        self.pool_layers = nn.ModuleList(pool_layers)

        self.drop = Dropout(drop_prob, inplace=True)
        # forward() flattens (channels, reduced freq) into one axis before this projection.
        self.linear = Linear(self.reduced_feature_dim * conv_channels, out_dim, bias=not out_norm)

        if out_norm:
            self.out_norm = LayerNorm(out_dim, eps=eps)
        else:
            self.out_norm = None

        self._initialize_parameters()

    @staticmethod
    def _conv_out_size(size: int, kernel_size: int) -> int:
        """Output size of a stride-1 conv zero-padded by ``kernel_size // 2`` (unchanged for odd kernels)."""
        return size + 2 * (kernel_size // 2) - (kernel_size - 1)

    @staticmethod
    def _pool_out_size(size: int) -> int:
        """Output size of a 2x2 / stride-2 / no-padding max-pool with ``ceil_mode=True``."""
        return math.ceil((size - (2 - 1) - 1) / 2 + 1)

    def forward(self,
                x: torch.Tensor,
                lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """VGGSubsampling forward
        :param x:               (batch_size, 1, seq_length, feature_dim); a 3-D
                                (batch_size, seq_length, feature_dim) input gets a channel axis added.
        :param lengths:         (batch_size,) valid time steps per utterance; not modified in place.
        :return:
                output:         (batch_size, out_seq_length, out_dim)
                new_lengths:    (batch_size,) which contains reduced sequence length.
        """
        if x.ndim == 3:
            x = x.unsqueeze(1)

        device = x.device
        lengths = lengths.clone().detach_()

        for i in range(self.num_layers):
            # One mask per stage is shared by both convs. This assumes the padded convs keep the
            # time length (true for odd kernel_size) -- NOTE(review): even kernels are not handled here.
            mask = make_mask_by_length(lengths, max_length=x.shape[2]).to(device)  # (batch_size, input_seq_length)
            x = self.conv_layers[2 * i](x, mask)
            if self.bn_layers is not None:
                x = self.bn_layers[2 * i](x, mask)
            x = self.act_layers[2 * i](x)

            x = self.conv_layers[2 * i + 1](x, mask)
            if self.bn_layers is not None:
                x = self.bn_layers[2 * i + 1](x, mask)
            x = self.act_layers[2 * i + 1](x)

            x = self.pool_layers[i](x)
            # 2x2 max-pool (ceil mode): length -> ceil(length / 2)
            lengths = torch.ceil((lengths - (2 - 1) - 1) / 2 + 1).long().to(device)

        b, c, t, f = x.shape
        # (b, c, t, f) -> (b, t, c, f) -> (b, t, c * f)
        x = x.transpose(1, 2).contiguous().view(b, t, c * f)
        x = self.drop(x)
        x = self.linear(x)
        if self.out_norm is not None:
            x = self.out_norm(x)
        return x, lengths

    @torch.no_grad()
    def calculate_length(self, lengths: torch.Tensor) -> torch.Tensor:
        """Compute the output sequence lengths without running the network.

        :param lengths:         (batch_size,)
        :return:
                new_lengths:    (batch_size,) long tensor; each stage maps length -> ceil(length / 2).
        """
        for _ in range(self.num_layers):
            # 2x2 max-pool
            lengths = torch.ceil((lengths - (2 - 1) - 1) / 2 + 1)
        return lengths.detach_().long()

    @classmethod
    def from_config(cls, cfg: DictConfig):
        """Build from an OmegaConf config whose keys match ``__init__`` keyword arguments."""
        cfg = OmegaConf.to_container(cfg, resolve=True)
        return cls(**cfg)


class StrideSubsampling(BaseASRSubsampling):
    """Strided-convolution subsampling front-end for ASR features.

    Each of the ``num_layers`` stages is a stride-2 convolution, zero-padded by
    ``kernel_size // 2``. It is optionally followed by ``MaskedBatchNorm`` (masked
    with the post-conv lengths) and then an activation. The final feature map is
    flattened over (channels, frequency), passed through dropout and a linear
    projection to ``out_dim``, and optionally layer-normalized.
    """

    def __init__(self,
                 feature_dim: int,
                 out_dim: int,
                 num_layers: int = 2,
                 num_channels: Union[int, Iterable[int]] = 32,
                 kernel_size: int = 3,
                 drop_prob: float = 0.0,
                 bn: bool = True,
                 eps: float = 1e-5,
                 momentum: float = 0.01, *,
                 act_type: str = "relu",
                 out_norm: bool = False,
                 sync_bn: bool = True) -> None:
        """
        :param feature_dim:     input feature (frequency) dimension.
        :param out_dim:         output dimension of the final linear projection.
        :param num_layers:      number of stride-2 conv stages.
        :param num_channels:    output channels per stage; an int is repeated for every stage.
        :param kernel_size:     conv kernel size; every conv is zero-padded by ``kernel_size // 2``.
        :param drop_prob:       dropout probability applied right before the linear projection.
        :param bn:              if True, add a MaskedBatchNorm after each conv and disable the conv bias.
        :param eps:             epsilon for MaskedBatchNorm and for the optional output LayerNorm.
        :param momentum:        momentum for MaskedBatchNorm.
        :param act_type:        activation name, resolved through ``get_activation_cls``.
        :param out_norm:        if True, apply LayerNorm to the output and disable the linear bias.
        :param sync_bn:         forwarded to MaskedBatchNorm (presumably cross-process synchronized BN; verify).
        :raises ValueError:     if ``num_channels`` does not contain exactly ``num_layers`` entries.
        """
        super().__init__()
        self.feature_dim = feature_dim
        self.out_dim = out_dim
        self.num_layers = num_layers

        if isinstance(num_channels, int):
            num_channels = [num_channels] * num_layers
        else:
            # Materialize so arbitrary iterables (e.g. generators) support len() and indexing.
            num_channels = list(num_channels)

        if len(num_channels) != num_layers:
            raise ValueError(f"StrideSubsampling length of num_channels {len(num_channels)} mismatch to {num_layers}.")
        self.num_channels = num_channels
        self.kernel_size = kernel_size

        conv_layers = []
        bn_layers = []
        act_layers = []

        conv_channels = 1  # the input is treated as a single-channel "image" (batch, 1, time, freq)
        pad = kernel_size // 2

        for channels in num_channels:
            conv_layers.append(
                Conv2d(conv_channels, channels,
                       kernel_size=kernel_size, stride=2, padding=pad, bias=not bn, partial=False))
            if bn:
                bn_layers.append(MaskedBatchNorm(channels, eps=eps, momentum=momentum, sync_bn=sync_bn))
            act_layers.append(get_activation_cls(act_type, inplace=False))

            # Standard conv output-size formula (dilation 1, stride 2) applied to the frequency axis.
            feature_dim = math.floor((feature_dim + (2 * pad) - (kernel_size - 1) - 1) / 2 + 1)
            conv_channels = channels

        self.reduced_feature_dim = int(feature_dim)

        self.conv_layers = nn.ModuleList(conv_layers)
        if bn:
            self.bn_layers = nn.ModuleList(bn_layers)
        else:
            self.bn_layers = None
        self.act_layers = nn.ModuleList(act_layers)

        self.drop = Dropout(drop_prob, inplace=True)
        # forward() flattens (channels, reduced freq) into one axis before this projection.
        self.linear = Linear(self.reduced_feature_dim * conv_channels, out_dim, bias=not out_norm)

        if out_norm:
            self.out_norm = LayerNorm(out_dim, eps=eps)
        else:
            self.out_norm = None

        self._initialize_parameters()

    def _reduce_length(self, lengths: torch.Tensor) -> torch.Tensor:
        """Apply one stride-2 conv's output-size formula element-wise; returns a floored float tensor."""
        pad = self.kernel_size // 2
        return torch.floor((lengths + (2 * pad) - (self.kernel_size - 1) - 1) / 2 + 1)

    def forward(self,
                x: torch.Tensor,
                lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """StrideSubsampling forward
        :param x:               (batch_size, 1, seq_length, feature_dim); a 3-D
                                (batch_size, seq_length, feature_dim) input gets a channel axis added.
        :param lengths:         (batch_size,) valid time steps per utterance; not modified in place.
        :return:
                output:         (batch_size, out_seq_length, out_dim)
                new_lengths:    (batch_size,) which contains reduced sequence length.
        """
        if x.ndim == 3:
            x = x.unsqueeze(1)

        device = x.device
        lengths = lengths.clone().detach_()

        for i in range(self.num_layers):
            # The conv sees a mask for the pre-conv lengths ...
            mask = make_mask_by_length(lengths, max_length=x.shape[2]).to(device)  # (batch_size, input_seq_length)
            x = self.conv_layers[i](x, mask)
            # stride 2 conv
            lengths = self._reduce_length(lengths).long().to(device)

            # ... while BN gets a fresh mask for the (shorter) post-conv lengths.
            mask = make_mask_by_length(lengths, max_length=x.shape[2]).to(device)  # (batch_size, output_seq_length)
            if self.bn_layers is not None:
                x = self.bn_layers[i](x, mask)
            x = self.act_layers[i](x)

        b, c, t, f = x.shape
        # (b, c, t, f) -> (b, t, c, f) -> (b, t, c * f)
        x = x.transpose(1, 2).contiguous().view(b, t, c * f)
        x = self.drop(x)
        x = self.linear(x)
        if self.out_norm is not None:
            x = self.out_norm(x)
        return x, lengths

    @torch.no_grad()
    def calculate_length(self, lengths: torch.Tensor) -> torch.Tensor:
        """Compute the output sequence lengths without running the network.

        :param lengths:         (batch_size,)
        :return:
                new_lengths:    (batch_size,) long tensor after ``num_layers`` stride-2 convs.
        """
        for _ in range(self.num_layers):
            # one conv
            lengths = self._reduce_length(lengths)
        return lengths.detach_().long()

    @classmethod
    def from_config(cls, cfg: DictConfig):
        """Build from an OmegaConf config whose keys match ``__init__`` keyword arguments."""
        cfg = OmegaConf.to_container(cfg, resolve=True)
        return cls(**cfg)
