from typing import Optional, Union, Tuple
import math
import logging
import torch
import torch.nn as nn
import torch.nn.functional as F
from pado.core import PadoModule
from pado.nn.parameter import ParameterModule

# Public names exported by ``from <this module> import *``.
__all__ = ["Conv1d", "Conv2d"]

# Logger registered under the name "pado"; used below to warn about option combinations.
logger = logging.getLogger("pado")


def _int_to_tuple(i: Union[int, tuple, list], n: int) -> Tuple[int, ...]:
    if isinstance(i, int):
        return (i,) * n
    return tuple(i)


def _reverse_repeat_tuple(t: tuple, n: int) -> Tuple[int, ...]:
    return tuple(x for x in reversed(t) for _ in range(n))


class _ConvNd(PadoModule):
    """Shared configuration and parameter setup for the N-dimensional convolution layers.

    Subclasses provide ``conv_forward`` (the actual convolution call) and, when partial
    convolution is enabled, ``_generate_partial_scale``.

    :param in_channels:     number of input channels, must be divisible by ``groups``
    :param out_channels:    number of output channels, must be divisible by ``groups``
    :param kernel_size:     kernel size, one entry per spatial dimension
    :param stride:          stride, one entry per spatial dimension
    :param padding:         padding, one entry per spatial dimension ('same' is not supported)
    :param dilation:        dilation, one entry per spatial dimension
    :param transposed:      if True the weight is allocated as
                            (in_channels, out_channels // groups, *kernel_size),
                            otherwise (out_channels, in_channels // groups, *kernel_size)
    :param output_padding:  stored on the module; not used elsewhere in this class
    :param groups:          number of channel groups (positive integer)
    :param bias:            whether to allocate a bias of shape (out_channels,)
    :param padding_mode:    "zeros", "reflect", "replicate" or "circular" (case-insensitive)
    :param partial:         enable partial convolution; forces ``padding_mode`` to "zeros"
    :raises ValueError:     on non-positive ``groups``, channel counts not divisible by
                            ``groups``, or an unknown ``padding_mode``
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: Tuple[int, ...],
                 stride: Tuple[int, ...],
                 padding: Tuple[int, ...],
                 dilation: Tuple[int, ...],
                 transposed: bool,
                 output_padding: Tuple[int, ...],
                 groups: int,
                 bias: bool = True,
                 padding_mode: str = "zeros",
                 partial: bool = False) -> None:
        super().__init__()
        # Reject groups <= 0 explicitly; otherwise groups == 0 surfaces as a bare
        # ZeroDivisionError from the modulo checks below.
        if groups < 1:
            raise ValueError(f"Conv groups {groups} must be a positive integer.")
        if in_channels % groups != 0:
            raise ValueError(f"Conv in_channels {in_channels} not divisible by groups {groups}.")
        if out_channels % groups != 0:
            raise ValueError(f"Conv out_channels {out_channels} not divisible by groups {groups}.")
        padding_mode = padding_mode.lower()
        if padding_mode not in ("zeros", "reflect", "replicate", "circular"):
            raise ValueError(f"Conv padding_mode {padding_mode} is invalid.")
        # currently we do not support 'same' padding.

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.transposed = transposed
        self.output_padding = output_padding
        self.groups = groups
        self.padding_mode = padding_mode

        # Padding laid out for F.pad: dimensions in reverse order, each value repeated for
        # both sides, e.g. padding for (h, w = 0, 1) -> (w_left, w_right, h_top, h_bottom = 1, 1, 0, 0)
        self._reversed_padding_x2 = _reverse_repeat_tuple(self.padding, 2)

        if transposed:
            self.weight = ParameterModule(torch.empty(in_channels, out_channels // groups, *kernel_size))
        else:
            self.weight = ParameterModule(torch.empty(out_channels, in_channels // groups, *kernel_size))
        if bias:
            self.bias = ParameterModule(torch.zeros(out_channels))
        else:
            self.bias = None

        # partial conv
        self.partial = partial
        if partial:
            # Uniform averaging kernel (all entries 1 / numel): convolving a 0/1 mask with it
            # yields the fraction of valid positions under each kernel window.
            dummy_weight = torch.ones(1, 1, *kernel_size, dtype=torch.float32)
            dummy_weight /= dummy_weight.numel()
            if self.padding_mode != "zeros":
                # Report the concrete subclass name: this path is shared by every subclass.
                logger.warning("%s with partial=True requires padding mode to be zeros, got %s.",
                               type(self).__name__, self.padding_mode)
                self.padding_mode = "zeros"  # force set
        else:
            dummy_weight = None
        # Non-persistent: the buffer is rebuilt from kernel_size, so it is kept out of state_dict.
        self.register_buffer("dummy_weight", dummy_weight, persistent=False)

        self._initialize_parameters()

    def _initialize_parameters(self):
        """Fill the weight uniformly in [-bound, bound] and zero the bias (if any).

        ``bound = sqrt(6 / (in_channels + out_channels))``: a Xavier-style bound computed
        from the channel counts only.
        """
        # intentionally remove the effect of #groups in xavier_init.
        bound = math.sqrt(6.0 / (self.in_channels + self.out_channels))
        nn.init.uniform_(self.weight.data, -bound, bound)
        if self.bias is not None:
            nn.init.zeros_(self.bias.data)

    @torch.no_grad()
    def _generate_partial_scale(self, mask: torch.Tensor) -> torch.Tensor:
        """Compute the partial-convolution rescaling factor; implemented by subclasses."""
        raise NotImplementedError

    def conv_forward(self, x: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor]) -> torch.Tensor:
        """Apply the dimension-specific convolution; implemented by subclasses."""
        raise NotImplementedError

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Run ``conv_forward`` with the current weight and bias.

        ``mask`` is accepted for signature compatibility but is not used by this base
        implementation.
        """
        weight = self.weight()
        bias = self.bias() if (self.bias is not None) else None
        return self.conv_forward(x, weight, bias)


class Conv1d(_ConvNd):
    """1D convolution with optional padding mask and partial-convolution rescaling.

    Scalar ``kernel_size`` / ``stride`` / ``padding`` / ``dilation`` are expanded to
    1-tuples. ``partial`` is keyword-only.
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: Union[int, Tuple[int, ...]],
                 stride: Union[int, Tuple[int, ...]] = 1,
                 padding: Union[int, Tuple[int, ...]] = 0,
                 dilation: Union[int, Tuple[int, ...]] = 1,
                 groups: int = 1,
                 bias: bool = True,
                 padding_mode: str = "zeros",
                 *, partial: bool = False) -> None:
        kernel_size = _int_to_tuple(kernel_size, 1)
        stride = _int_to_tuple(stride, 1)
        padding = _int_to_tuple(padding, 1)
        dilation = _int_to_tuple(dilation, 1)
        # Not transposed; output_padding fixed to (0,).
        super().__init__(in_channels, out_channels, kernel_size, stride, padding, dilation,
                         False, (0,), groups, bias, padding_mode, partial)

    @torch.no_grad()
    def _generate_partial_scale(self, mask: torch.Tensor) -> torch.Tensor:
        """
        Compute the partial-convolution scale: mask / (window average of mask),
        clamped to [1, 2] and zeroed wherever the mask is zero.

        NOTE(review): ``mask / scale`` divides the length-s mask by the convolved mask, so the
        shapes only line up when the convolution keeps the sequence length - confirm callers
        only enable ``partial`` with such settings.

        :param mask:        (batch_size, 1, seq_length)
        :return:
                            (batch_size, 1, seq_length)
        """
        b, _, s = mask.shape
        assert mask.shape == (b, 1, s)

        device = self.dummy_weight.device
        mask = mask.float().to(device)
        # dummy_weight is a uniform averaging kernel -> fraction of valid positions per window.
        scale = F.conv1d(mask, self.dummy_weight, None, self.stride, self.padding, self.dilation, groups=1)
        scale = mask / scale.add_(1e-6)  # epsilon avoids division by zero on fully-masked windows
        scale = scale.clamp_(1.0, 2.0)  # scale cannot overflow. just for sure.
        scale = scale.mul_(mask).detach_()  # to be sure that zero is still zero
        return scale

    def conv_forward(self, x: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor]) -> torch.Tensor:
        """Apply ``F.conv1d``; non-zero padding modes are padded explicitly via ``F.pad`` first."""
        if self.padding_mode != "zeros":
            return F.conv1d(F.pad(x, self._reversed_padding_x2, mode=self.padding_mode),
                            weight, bias, self.stride, (0,), self.dilation, self.groups)
        return F.conv1d(x, weight, bias, self.stride, self.padding, self.dilation, self.groups)

    def forward(self,
                x: torch.Tensor,
                mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Convolve ``x``. With a mask and zero padding, masked-out positions of the input are
        zeroed before the convolution. With ``partial=True`` the output is additionally
        multiplied by the partial-convolution scale.

        The input tensor ``x`` is never modified in place.

        :param x:       (batch_size, in_channels, s)
        :param mask:    (batch_size, s)                     bool, T: valid, F: pad
                        (an already expanded (batch_size, 1, s) mask is used as-is)
        :return:
                        (batch_size, out_channels, out_s)
        """
        weight = self.weight()
        bias = self.bias() if (self.bias is not None) else None

        if (mask is not None) and (mask.ndim == 2):
            mask = mask.unsqueeze(1)  # (b, 1, s)

        if (mask is not None) and (self.padding_mode == "zeros"):
            # Out-of-place on purpose: masked_fill_ would overwrite the caller's tensor and can
            # break autograd when x is a leaf requiring grad or is saved for backward upstream.
            x = x.masked_fill(torch.logical_not(mask), 0.0)

        y = self.conv_forward(x, weight, bias)
        if self.partial:
            if mask is None:
                # No mask given: treat every position as valid.
                b, _, s = y.shape
                mask = torch.ones(b, 1, s, dtype=torch.bool, device=y.device)
            scale = self._generate_partial_scale(mask)
            y = y * scale
        return y


class Conv2d(_ConvNd):
    """2D convolution with optional padding mask and partial-convolution rescaling.

    The mask is defined along the first spatial dimension (h) only and is broadcast over
    the second one (w). Scalar ``kernel_size`` / ``stride`` / ``padding`` / ``dilation``
    are expanded to 2-tuples. ``partial`` is keyword-only.
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: Union[int, Tuple[int, ...]],
                 stride: Union[int, Tuple[int, ...]] = 1,
                 padding: Union[int, Tuple[int, ...]] = 0,
                 dilation: Union[int, Tuple[int, ...]] = 1,
                 groups: int = 1,
                 bias: bool = True,
                 padding_mode: str = "zeros",
                 *, partial: bool = False) -> None:
        kernel_size = _int_to_tuple(kernel_size, 2)
        stride = _int_to_tuple(stride, 2)
        padding = _int_to_tuple(padding, 2)
        dilation = _int_to_tuple(dilation, 2)
        # Not transposed; output_padding fixed to (0, 0).
        super().__init__(in_channels, out_channels, kernel_size, stride, padding, dilation,
                         False, (0, 0), groups, bias, padding_mode, partial)

    @torch.no_grad()
    def _generate_partial_scale(self, mask: torch.Tensor, dim: int) -> torch.Tensor:
        """
        Compute the partial-convolution scale: mask / (window average of mask),
        clamped to [1, 2] and zeroed wherever the mask is zero.

        NOTE(review): ``mask / scale`` divides the expanded input mask by the convolved mask,
        so the shapes only line up when the convolution keeps the spatial size - confirm
        callers only enable ``partial`` with such settings.

        :param mask:        (batch_size, 1, seq_length, 1)
        :param dim:         size the last mask dimension is expanded to before convolving
        :return:
                            (batch_size, 1, seq_length, dim)
        """
        b, _, s, _ = mask.shape
        assert mask.shape == (b, 1, s, 1)

        device = self.dummy_weight.device
        mask = mask.float().to(device).expand(b, 1, s, dim)
        # dummy_weight is a uniform averaging kernel -> fraction of valid positions per window.
        scale = F.conv2d(mask, self.dummy_weight, None, self.stride, self.padding, self.dilation, groups=1)
        scale = mask / scale.add_(1e-6)  # epsilon avoids division by zero on fully-masked windows
        scale = scale.clamp_(1.0, 2.0)  # scale cannot overflow. just for sure.
        scale = scale.mul_(mask).detach_()  # to be sure that zero is still zero
        return scale

    def conv_forward(self, x: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor]) -> torch.Tensor:
        """Apply ``F.conv2d``; non-zero padding modes are padded explicitly via ``F.pad`` first."""
        if self.padding_mode != "zeros":
            # Padding was already applied by F.pad, so convolve with explicit 2D zero padding.
            return F.conv2d(F.pad(x, self._reversed_padding_x2, mode=self.padding_mode),
                            weight, bias, self.stride, (0, 0), self.dilation, self.groups)
        return F.conv2d(x, weight, bias, self.stride, self.padding, self.dilation, self.groups)

    def forward(self,
                x: torch.Tensor,
                mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Convolve ``x``. With a mask and zero padding, masked-out rows of the input are zeroed
        before the convolution. With ``partial=True`` the output is additionally multiplied
        by the partial-convolution scale.

        The input tensor ``x`` is never modified in place.

        :param x:       (batch_size, in_channels, h, w)
        :param mask:    (batch_size, h)                     bool, T: valid, F: pad
                        (an already expanded (batch_size, 1, h, 1) mask is used as-is)
        :return:
                        (batch_size, out_channels, out_h, out_w)
        """
        weight = self.weight()
        bias = self.bias() if (self.bias is not None) else None

        if (mask is not None) and (mask.ndim == 2):
            mask = mask.unsqueeze(1).unsqueeze(-1)  # (b, 1, s, 1)

        if (mask is not None) and (self.padding_mode == "zeros"):
            # Out-of-place on purpose: masked_fill_ would overwrite the caller's tensor and can
            # break autograd when x is a leaf requiring grad or is saved for backward upstream.
            x = x.masked_fill(torch.logical_not(mask), 0.0)

        y = self.conv_forward(x, weight, bias)
        if self.partial:
            if mask is None:
                # No mask given: treat every row as valid.
                b, _, s, _ = y.shape
                mask = torch.ones(b, 1, s, 1, dtype=torch.bool, device=y.device)
            scale = self._generate_partial_scale(mask, dim=y.shape[-1])
            y = y * scale
        return y
