import torch
import torch.nn as nn
import torch.nn.functional as F

from math import sqrt


def Conv1d(*args, **kwargs):
    """Create an ``nn.Conv1d`` whose weight is re-initialised with Kaiming-normal.

    Every positional and keyword argument is forwarded untouched to
    ``nn.Conv1d``.  Only the weight is re-initialised; the bias (when present)
    keeps whatever ``nn.Conv1d`` assigned to it at construction.

    :return: the freshly constructed ``nn.Conv1d`` module
    """
    conv = nn.Conv1d(*args, **kwargs)
    # In-place initialiser: overwrites the default weight values of `conv`.
    nn.init.kaiming_normal_(conv.weight)
    return conv


def make_causal_conv1d_weight_mask(conv_weight, lookahead_tokens=0):
    """Build a 0/1 mask that hides the "future" taps of a 1-D conv kernel.

    The tap at index ``kernel_size // 2`` is treated as the current time step;
    taps before it are the past and taps after it are the future.  The mask
    keeps the past taps, the current tap and up to ``lookahead_tokens`` future
    taps, and zeroes the rest.
    NOTE(review): this only yields a causal convolution when the caller pads
    symmetrically so that the centre tap lines up with the current step —
    confirm at each call site.

    :param conv_weight: convolution weight of shape
        ``(out_channels, in_channels, kernel_size)``; only its shape, dtype and
        device are used, never its values
    :param lookahead_tokens: number of future taps to keep.  Values large
        enough keep the whole kernel; negative values additionally hide the
        current/past taps, down to an all-zero mask.
    :return: tensor with the same shape, dtype and device as ``conv_weight``.
        It is an expanded (stride-0) view of a single ``kernel_size`` vector,
        so clone it before modifying it in place.
    """
    out_ch, in_ch, kernel_size = conv_weight.shape
    center = kernel_size // 2
    # Allow current + past + up to `lookahead_tokens` future taps.  Clamp to
    # [0, kernel_size]: without the lower bound a very negative lookahead
    # produces a negative slice end, which wraps around and silently
    # re-enables taps instead of masking everything.
    end = max(0, min(kernel_size, center + 1 + lookahead_tokens))
    mask_1d = torch.zeros(kernel_size, dtype=conv_weight.dtype, device=conv_weight.device)
    mask_1d[:end] = 1.0
    return mask_1d.view(1, 1, kernel_size).expand(out_ch, in_ch, kernel_size)


@torch.jit.script
def silu(x: torch.Tensor) -> torch.Tensor:
    """Element-wise SiLU activation: ``x * sigmoid(x)``."""
    gate = torch.sigmoid(x)
    return x * gate


class ResidualBlock(nn.Module):
    """Gated dilated 1-D convolution block with a residual and a skip output.

    A 3-tap dilated convolution maps ``residual_channels`` to
    ``2 * residual_channels`` channels, which are split along the channel axis
    into a gate half and a filter half and combined as
    ``sigmoid(gate) * tanh(filter)``.  A 1x1 convolution then produces
    ``2 * residual_channels`` channels again, split into the residual update
    and the skip connection.
    """

    def __init__(self, residual_channels, dilation, is_causal=False, lookahead_tokens=0):
        """
        :param residual_channels: channel count of the block input and of each
            of the two outputs
        :param dilation: dilation of the 3-tap convolution; also used as its
            symmetric padding, so the time length is preserved
        :param is_causal: when True, ``forward`` masks the convolution weight
            with ``make_causal_conv1d_weight_mask`` before applying it
        :param lookahead_tokens: look-ahead budget in time steps; it is
            converted to a number of future taps by integer division by
            ``dilation`` (the future tap sits ``dilation`` steps ahead)
        """
        super().__init__()
        self.is_causal = is_causal
        self.lookahead_tokens = lookahead_tokens // dilation
        self.dilated_conv = Conv1d(
            residual_channels,
            2 * residual_channels,
            3,
            padding=dilation,
            dilation=dilation,
        )

        self.output_projection = Conv1d(residual_channels, 2 * residual_channels, 1)

    def forward(self, x):
        """Apply the block.

        :param x: tensor of shape ``(batch, residual_channels, time)``
        :return: tuple ``(residual_out, skip)`` where ``residual_out`` is
            ``(x + residual) / sqrt(2)`` and both tensors have the shape of ``x``
        """
        conv = self.dilated_conv
        if self.is_causal:
            # The mask depends only on the weight's shape/dtype/device, so it is
            # a constant w.r.t. autograd.  It comes back as an expanded view and
            # broadcasting handles that directly — no per-call copy is needed.
            with torch.no_grad():
                causal_mask = make_causal_conv1d_weight_mask(
                    conv.weight, self.lookahead_tokens
                )
            # Masking the weight (rather than the input) keeps gradients flowing
            # only into the taps that are allowed to be used.
            y = F.conv1d(
                x,
                conv.weight * causal_mask,
                bias=conv.bias,
                stride=conv.stride,
                padding=conv.padding,
                dilation=conv.dilation,
                groups=conv.groups,
            )
        else:
            y = conv(x)

        # First half of the channels gates, second half is the filter.
        gate, filt = torch.chunk(y, 2, dim=1)
        y = torch.sigmoid(gate) * torch.tanh(filt)

        y = self.output_projection(y)
        residual, skip = torch.chunk(y, 2, dim=1)
        # Dividing by sqrt(2) rescales the sum of the two branches.
        return (x + residual) / sqrt(2.0), skip


class Wave(nn.Module):
    """Stack of gated dilated residual blocks with summed skip connections.

    The input is projected to ``residual_channels`` with a 1x1 convolution,
    passed through ``residual_layers`` :class:`ResidualBlock` instances whose
    dilations cycle through ``1, 2, 4, ...`` with period
    ``dilation_cycle_length``, and the accumulated skip outputs are projected
    back to ``input_channels``.
    """

    def __init__(
        self,
        input_channels: int,
        residual_channels: int,
        dilation_cycle_length: int,
        residual_layers: int,
        is_causal: bool = False,
        lookahead_tokens: int = 0
    ):
        """
        :param input_channels: channel count of the input and of the output
        :param residual_channels: channel count used inside the residual stack
        :param dilation_cycle_length: block ``i`` uses dilation
            ``2 ** (i % dilation_cycle_length)``; must be at least 1
        :param residual_layers: number of residual blocks; must be at least 1
        :param is_causal: forwarded to every residual block
        :param lookahead_tokens: forwarded unchanged to every residual block.
            NOTE(review): each block applies this budget independently, so the
            look-ahead of the whole stack can exceed ``lookahead_tokens`` —
            confirm this is intended.
        :raises ValueError: if ``residual_layers`` or ``dilation_cycle_length``
            is smaller than 1
        """
        super().__init__()

        # Fail early with a clear message: without any block `forward` would
        # divide an empty skip sum (None) by sqrt(0), and a non-positive cycle
        # length breaks the dilation schedule (modulo by zero / fractional
        # dilations).
        if residual_layers < 1:
            raise ValueError(
                f"residual_layers must be >= 1, got {residual_layers}"
            )
        if dilation_cycle_length < 1:
            raise ValueError(
                f"dilation_cycle_length must be >= 1, got {dilation_cycle_length}"
            )

        self.input_projection = Conv1d(input_channels, residual_channels, 1)

        self.residual_layers = nn.ModuleList(
            [
                ResidualBlock(
                    residual_channels,
                    2 ** (i % dilation_cycle_length),
                    is_causal=is_causal,
                    lookahead_tokens=lookahead_tokens,
                )
                for i in range(residual_layers)
            ]
        )
        self.skip_projection = Conv1d(residual_channels, residual_channels, 1)
        self.output_projection = Conv1d(residual_channels, input_channels, 1)
        # Zero weight: right after construction the output is just the bias of
        # the last projection, independent of the input.
        nn.init.zeros_(self.output_projection.weight)

    def forward(self, input):
        """Run the stack.

        :param input: tensor of shape ``(batch, input_channels, time)``
        :return: tensor with ``input_channels`` channels
        """
        hidden = F.relu(self.input_projection(input))

        skip_total = None
        for block in self.residual_layers:
            hidden, skip_out = block(hidden)
            if skip_total is None:
                skip_total = skip_out
            else:
                skip_total = skip_out + skip_total

        # Normalise the summed skips by sqrt(number of blocks).
        out = skip_total / sqrt(len(self.residual_layers))
        out = F.relu(self.skip_projection(out))
        return self.output_projection(out)
