"""
Convolution layers from https://github.com/facebookresearch/AudioDec
"""

import math
import torch
import torch.nn as nn


def int2tuple(variable, length):
    """Broadcast an int to a tuple, or validate the length of a sequence.

    Args:
        variable: Either an int or a sized sequence.
        length (int): Number of elements the result must have.

    Returns:
        A tuple holding ``length`` copies of ``variable`` when it is an int;
        otherwise ``variable`` itself, unchanged.

    Raises:
        AssertionError: If ``variable`` is not an int and ``len(variable)``
            differs from ``length``.
    """
    if not isinstance(variable, int):
        # Sequences are passed through untouched once their length is checked.
        assert len(variable) == length, f"The length of {variable} is not {length}!"
        return variable
    return (variable,) * length


class Conv1d1x1(nn.Conv1d):
    """Pointwise 1D convolution: an ``nn.Conv1d`` with the kernel size fixed to 1."""

    def __init__(self, in_channels, out_channels, bias=True):
        """
        Args:
            in_channels (int): Number of input channels.
            out_channels (int): Number of output channels.
            bias (bool): Whether the convolution has a bias term.
        """
        super().__init__(in_channels, out_channels, 1, bias=bias)


class NonCausalConv1d(nn.Module):
    """1D non-causal convolution with padding on both sides."""

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=-1,
        dilation=1,
        groups=1,
        bias=True,
    ):
        """
        Args:
            in_channels (int): Number of input channels.
            out_channels (int): Number of output channels.
            kernel_size (int): Kernel size of the convolution.
            stride (int): Stride of the convolution.
            padding (int): Padding added to both sides. A negative value
                selects ``(kernel_size - 1) // 2 * dilation`` instead.
            dilation (int): Dilation of the convolution.
            groups (int): Number of channel groups.
            bias (bool): Whether the convolution has a bias term.
        """
        super().__init__()
        # A negative padding is a sentinel: derive it from kernel and dilation.
        resolved_padding = (
            padding if padding >= 0 else (kernel_size - 1) // 2 * dilation
        )
        self.in_channels, self.out_channels = in_channels, out_channels
        self.kernel_size = kernel_size
        self.dilation = dilation
        self.conv = nn.Conv1d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=resolved_padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
        )

    def forward(self, x):
        """Apply the wrapped convolution.

        Args:
            x (Tensor): Float tensor variable with the shape (B, C, T).
        Returns:
            Tensor: Output of ``self.conv`` applied to ``x``.
        """
        return self.conv(x)


class NonCausalConvTranspose1d(nn.Module):
    """1D non-causal transposed convolution."""

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride,
        padding=-1,
        output_padding=-1,
        groups=1,
        bias=True,
    ):
        """
        Args:
            in_channels (int): Number of input channels.
            out_channels (int): Number of output channels.
            kernel_size (int): Kernel size of the transposed convolution.
            stride (int): Stride (upsampling factor) of the transposed convolution.
            padding (int): Padding argument of ``nn.ConvTranspose1d``. A
                negative value selects ``(stride + 1) // 2``.
            output_padding (int): Output padding. A negative value selects
                1 for an odd stride and 0 for an even stride.
            groups (int): Number of channel groups.
            bias (bool): Whether the transposed convolution has a bias term.
        """
        super().__init__()
        # Negative values are sentinels meaning "derive from the stride".
        resolved_padding = padding if padding >= 0 else (stride + 1) // 2
        if output_padding >= 0:
            resolved_output_padding = output_padding
        else:
            resolved_output_padding = 0 if stride % 2 == 0 else 1
        self.deconv = nn.ConvTranspose1d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=resolved_padding,
            output_padding=resolved_output_padding,
            groups=groups,
            bias=bias,
        )

    def forward(self, x):
        """Apply the wrapped transposed convolution.

        Args:
            x (Tensor): Float tensor variable with the shape (B, C, T).
        Returns:
            Tensor: Float tensor variable with the shape (B, C', T').
        """
        return self.deconv(x)


class CausalConv1d(NonCausalConv1d):
    """1D causal convolution with left-side padding and a streaming buffer.

    ``forward`` zero-pads ``pad_length = (kernel_size - 1) * dilation`` frames
    on the left of the input. ``inference`` instead prepends ``pad_buffer``,
    which holds the trailing ``pad_length`` frames of the previously processed
    input, so that consecutive chunks can be fed one after another.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        dilation=1,
        groups=1,
        bias=True,
        pad_buffer=None,
    ):
        """
        Args:
            in_channels (int): Number of input channels.
            out_channels (int): Number of output channels.
            kernel_size (int): Kernel size of the convolution.
            stride (int): Stride of the convolution.
            dilation (int): Dilation of the convolution.
            groups (int): Number of channel groups.
            bias (bool): Whether the convolution has a bias term.
            pad_buffer (Tensor, optional): Initial streaming context. Defaults
                to zeros with the shape (1, in_channels, pad_length).
        """
        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=0,  # padding is applied manually, on the left side only
            dilation=dilation,
            groups=groups,
            bias=bias,
        )
        self.stride = stride
        self.pad_length = (kernel_size - 1) * dilation
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.groups = groups
        self.bias = bias  # the constructor flag (bool), used by __repr__
        self.kernel_size = kernel_size
        if pad_buffer is None:
            pad_buffer = torch.zeros(1, in_channels, self.pad_length).contiguous()
        self.register_buffer("pad_buffer", pad_buffer)

    def _update_buffer(self, x):
        """Store the last ``pad_length`` frames of ``x`` as the streaming context."""
        # An explicit start index is required: with pad_length == 0
        # (kernel_size == 1) the slice ``x[:, :, -0:]`` selects the WHOLE
        # tensor, which would make the next inference() call prepend the
        # entire previous chunk.
        start = max(x.size(-1) - self.pad_length, 0)
        # Detached so the buffer does not keep the autograd graph of the
        # input alive between calls (a non-leaf buffer also breaks deepcopy).
        self.pad_buffer = x[:, :, start:].detach()

    def forward(self, x):
        """Convolve ``x`` after zero-padding it on the left.

        Args:
            x (Tensor): Float tensor variable with the shape (B, C, T).
        Returns:
            Tensor: Convolution output, cast to float32.
        """
        x = nn.functional.pad(x, (self.pad_length, 0), mode="constant", value=0.0)
        self._update_buffer(x)
        return self.conv(x).float()

    def inference(self, x):
        """Convolve a chunk using the buffered tail of the previous input as context.

        Args:
            x (Tensor): Float tensor variable with the shape (B, C, T).
        Returns:
            Tensor: Convolution output, cast to float32.
        """
        self.align_buffer(x)
        x = torch.cat((self.pad_buffer, x), -1)
        self._update_buffer(x)
        return self.conv(x).float()

    def align_buffer(self, x):
        """Make the batch size of ``pad_buffer`` equal to the batch size of ``x``.

        A smaller buffer is tiled along the batch dimension (and truncated if
        the batch size is not a multiple); a larger buffer is truncated.
        """
        batch_x = x.size(0)
        batch_pad = self.pad_buffer.size(0)
        if batch_pad < batch_x:
            # Number of repeats needed to cover the batch (ceiling division).
            repeat_factor = -(-batch_x // batch_pad)
            self.pad_buffer = self.pad_buffer.repeat(repeat_factor, 1, 1)[:batch_x]
        elif batch_pad > batch_x:
            self.pad_buffer = self.pad_buffer[:batch_x]

    def reset_buffer(self):
        """Reset the streaming context to zeros with batch size 1."""
        self.pad_buffer = torch.zeros(
            1,
            self.in_channels,
            self.pad_length,
            device=self.pad_buffer.device,
            dtype=self.pad_buffer.dtype,
        ).contiguous()

    def __repr__(self):
        return f"CausalConv1d(in_channels={self.in_channels}, out_channels={self.out_channels}, kernel_size={self.kernel_size}, stride={self.stride}, dilation={self.dilation}, groups={self.groups}, bias={self.bias})"


class CausalConvTranspose1d(NonCausalConvTranspose1d):
    """1D causal transposed convolution with a streaming buffer.

    ``forward`` replicate-pads ``pad_length = ceil(kernel_size / stride) - 1``
    frames on the left of the input. ``inference`` instead prepends
    ``pad_buffer``, which holds the trailing ``pad_length`` frames of the
    previously processed input. In both cases ``stride`` samples are trimmed
    from each end of the transposed-convolution output.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride,
        bias=True,
        pad_buffer=None,
    ):
        """
        Args:
            in_channels (int): Number of input channels.
            out_channels (int): Number of output channels.
            kernel_size (int): Kernel size of the transposed convolution.
            stride (int): Stride (upsampling factor) of the transposed convolution.
            bias (bool): Whether the transposed convolution has a bias term.
            pad_buffer (Tensor, optional): Initial streaming context. Defaults
                to zeros with the shape (1, in_channels, pad_length).
        """
        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=0,  # context is added manually, on the left side only
            output_padding=0,
            bias=bias,
        )
        self.stride = stride
        self.pad_length = math.ceil(kernel_size / stride) - 1
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.bias = bias  # the constructor flag (bool), used by __repr__
        self.kernel_size = kernel_size
        if pad_buffer is None:
            pad_buffer = torch.zeros(1, in_channels, self.pad_length).contiguous()
        self.register_buffer("pad_buffer", pad_buffer)

    def _update_buffer(self, x):
        """Store the last ``pad_length`` frames of ``x`` as the streaming context."""
        # An explicit start index is required: with pad_length == 0
        # (kernel_size <= stride) the slice ``x[:, :, -0:]`` selects the WHOLE
        # tensor, which would make the next inference() call prepend the
        # entire previous chunk.
        start = max(x.size(-1) - self.pad_length, 0)
        # Detached so the buffer does not keep the autograd graph of the
        # input alive between calls (a non-leaf buffer also breaks deepcopy).
        self.pad_buffer = x[:, :, start:].detach()

    def forward(self, x):
        """Upsample ``x`` after replicate-padding it on the left.

        Args:
            x (Tensor): Float tensor variable with the shape (B, C, T).
        Returns:
            Tensor: Trimmed transposed-convolution output, cast to float32.
        """
        x = nn.functional.pad(x, (self.pad_length, 0), mode="replicate")
        self._update_buffer(x)
        return self.deconv(x)[:, :, self.stride : -self.stride].float()

    def inference(self, x):
        """Upsample a chunk using the buffered tail of the previous input as context.

        Args:
            x (Tensor): Float tensor variable with the shape (B, C, T).
        Returns:
            Tensor: Trimmed transposed-convolution output, cast to float32.
        """
        self.align_buffer(x)
        x = torch.cat((self.pad_buffer, x), -1)
        self._update_buffer(x)
        return self.deconv(x)[:, :, self.stride : -self.stride].float()

    def align_buffer(self, x):
        """Make the batch size of ``pad_buffer`` equal to the batch size of ``x``.

        A smaller buffer is tiled along the batch dimension (and truncated if
        the batch size is not a multiple); a larger buffer is truncated.
        """
        batch_x = x.size(0)
        batch_pad = self.pad_buffer.size(0)
        if batch_pad < batch_x:
            # Number of repeats needed to cover the batch (ceiling division).
            repeat_factor = -(-batch_x // batch_pad)
            self.pad_buffer = self.pad_buffer.repeat(repeat_factor, 1, 1)[:batch_x]
        elif batch_pad > batch_x:
            self.pad_buffer = self.pad_buffer[:batch_x]

    def reset_buffer(self):
        """Reset the streaming context to zeros with batch size 1."""
        self.pad_buffer = torch.zeros(
            1,
            self.in_channels,
            self.pad_length,
            device=self.pad_buffer.device,
            dtype=self.pad_buffer.dtype,
        ).contiguous()

    def __repr__(self):
        return f"CausalConvTranspose1d(in_channels={self.in_channels}, out_channels={self.out_channels}, kernel_size={self.kernel_size}, stride={self.stride}, bias={self.bias})"


class NonCausalConv2d(nn.Module):
    """2D non-causal convolution with padding on all four sides."""

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=-1,
        dilation=1,
        groups=1,
        bias=True,
    ):
        """
        Args:
            in_channels (int): Number of input channels.
            out_channels (int): Number of output channels.
            kernel_size (int or pair): Kernel size; an int is used for both axes.
            stride (int or pair): Stride of the convolution.
            padding (int or pair): Padding per axis. A negative int selects
                ``(kernel_size - 1) // 2 * dilation`` for each axis.
            dilation (int or pair): Dilation; an int is used for both axes.
            groups (int): Number of channel groups.
            bias (bool): Whether the convolution has a bias term.
        """
        super().__init__()
        kernel_pair = int2tuple(kernel_size, 2)
        dilation_pair = int2tuple(dilation, 2)
        self.in_channels, self.out_channels = in_channels, out_channels
        self.kernel_size = kernel_pair
        self.dilation = dilation_pair

        # Only a negative *int* acts as the "derive it" sentinel; an explicit
        # tuple (or non-negative int) is forwarded to nn.Conv2d unchanged.
        use_default_padding = isinstance(padding, int) and padding < 0
        if use_default_padding:
            padding = tuple(
                (kernel_pair[axis] - 1) // 2 * dilation_pair[axis] for axis in (0, 1)
            )

        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
        )

    def forward(self, x):
        """Apply the wrapped 2D convolution.

        Args:
            x (Tensor): Float tensor variable; presumably with the shape
                (B, C, H, W), as expected by ``nn.Conv2d`` (confirm with callers).
        Returns:
            Tensor: Output of ``self.conv`` applied to ``x``.
        """
        return self.conv(x)


class CasualIdentity(nn.Module):
    """Pass-through module exposing both ``forward`` and ``inference``.

    Both methods return their input unchanged, so the module can stand in
    for a layer that offers the same two entry points.
    """

    def forward(self, x):
        """Return ``x`` unchanged."""
        return x

    def inference(self, x):
        """Return ``x`` unchanged."""
        return x
