import torch
from torch import nn
from .resnet_casual import Resnet1D
from .conv_layer import CausalConv1d, CausalConvTranspose1d
from mGPT.archs.vqvae.utils.activation_function import get_activation


class CasualIdentity(nn.Module):
    """No-op module exposing the same ``forward`` / ``inference`` pair as the causal layers.

    It lets a slot in a ``ModuleList`` be filled with "do nothing" while callers
    can still invoke either ``module(x)`` or ``module.inference(x)`` uniformly
    (e.g. the Decoder uses it in place of an upsampling layer when
    ``stride_t <= 1``). Both methods return the very same object they receive;
    no copy is made.
    """

    def __init__(self):
        super(CasualIdentity, self).__init__()

    def forward(self, x):
        # Identity: hand the input straight back.
        output = x
        return output

    def inference(self, x):
        # Identity on the inference path as well; mirrors ``forward``.
        output = x
        return output


class Encoder(nn.Module):
    """Causal 1D convolutional encoder built from CausalConv1d and Resnet1D units.

    Layer stack, applied in order:
      1. ``conv1``: CausalConv1d(input_emb_width -> width, 3, 1), then ``self.activation``.
      2. ``down_t`` stages of
         [CausalConv1d(width -> width, 2 * stride_t, stride_t), Resnet1D(width, ...)].
      3. ``layers`` stages of
         [CausalConv1d(width -> width, 2, 1), Resnet1D(width, ...)].
      4. ``conv2``: CausalConv1d(width -> output_emb_width, 3, 1).

    NOTE(review): CausalConv1d's positional arguments are presumably
    (in_channels, out_channels, kernel_size, stride); confirm in ``conv_layer``.

    Args:
        input_emb_width: Number of input channels fed to ``conv1``.
        output_emb_width: Number of output channels produced by ``conv2``.
        down_t: Number of strided conv + Resnet1D stages.
        stride_t: Stride of each strided stage; its kernel size is ``2 * stride_t``.
        layers: Number of extra stride-1 conv + Resnet1D stages appended after
            the strided ones.
        width: Hidden channel count used by every intermediate layer.
        depth: Forwarded to Resnet1D.
        dilation_growth_rate: Forwarded to Resnet1D.
        activation: Activation name; used to build ``self.activation`` via
            ``get_activation`` and forwarded to Resnet1D.
        activation_params: Extra parameters passed to ``get_activation``.
            ``None`` (the default) is treated as an empty dict.
        norm: Forwarded to Resnet1D.
    """

    def __init__(
        self,
        input_emb_width=3,
        output_emb_width=512,
        down_t=3,
        stride_t=2,
        layers=0,
        width=512,
        depth=3,
        dilation_growth_rate=3,
        activation="relu",
        activation_params=None,
        norm=None,
    ):
        super().__init__()

        # Fresh dict per instance instead of a mutable default shared by all calls.
        if activation_params is None:
            activation_params = {}

        filter_t = stride_t * 2

        self.activation = get_activation(activation, activation_params)
        self.conv1 = CausalConv1d(input_emb_width, width, 3, 1)

        # Every stage maps width -> width. Using ``width`` directly (instead of an
        # alias assigned only inside the down_t loop) keeps ``down_t=0, layers>0``
        # from raising UnboundLocalError.
        self.res_units = torch.nn.ModuleList()
        for _ in range(down_t):
            self.res_units += [
                CausalConv1d(width, width, filter_t, stride_t),
                Resnet1D(
                    width, depth, dilation_growth_rate, activation=activation, norm=norm
                ),
            ]
        for _ in range(layers):
            self.res_units += [
                CausalConv1d(width, width, 2, 1),
                Resnet1D(
                    width, depth, dilation_growth_rate, activation=activation, norm=norm
                ),
            ]

        self.conv2 = CausalConv1d(width, output_emb_width, 3, 1)

    def forward(self, x):
        """Encode ``x`` by calling each submodule's ``forward`` in stack order.

        NOTE(review): ``x`` is presumably (batch, input_emb_width, time) as for
        ``nn.Conv1d``; confirm against CausalConv1d.
        """
        x = self.conv1(x)
        x = self.activation(x)
        for res_unit in self.res_units:
            x = res_unit(x)
        x = self.conv2(x)
        return x

    def inference(self, x):
        """Same stack as ``forward``, but calls each conv/Resnet1D unit's ``inference``.

        The activation module is called directly (it has no ``inference``
        method here). NOTE(review): ``inference`` is presumably the
        incremental/streaming path of the causal layers; confirm in ``conv_layer``.
        """
        x = self.conv1.inference(x)
        x = self.activation(x)
        for res_unit in self.res_units:
            x = res_unit.inference(x)
        x = self.conv2.inference(x)
        return x


class Decoder(nn.Module):
    """Causal 1D convolutional decoder mirroring ``Encoder``.

    Layer stack, applied in order:
      1. ``conv1``: CausalConv1d(output_emb_width -> width, 3, 1), then ``self.activation``.
      2. ``down_t`` stages of
         [Resnet1D(width, ..., reverse_dilation=True),
          CausalConvTranspose1d(width -> width, kernel_size=4, stride=stride_t)
          (or CasualIdentity when ``stride_t <= 1``),
          CausalConv1d(width -> width, 3, 1)].
      3. ``layers`` stages of
         [Resnet1D(width, ..., reverse_dilation=True), CausalConv1d(width -> width, 3, 1)].
      4. ``conv2``: CausalConv1d(width -> width, 3, 1), then ``self.activation``.
      5. ``conv3``: CausalConv1d(width -> input_emb_width, 3, 1).

    NOTE(review): CausalConv1d's positional arguments are presumably
    (in_channels, out_channels, kernel_size, stride); confirm in ``conv_layer``.

    Args:
        input_emb_width: Number of channels produced by ``conv3``.
        output_emb_width: Number of channels fed to ``conv1``.
        down_t: Number of Resnet1D + (transposed conv | identity) + conv stages.
        stride_t: Stride of each transposed conv; ``<= 1`` replaces it with
            CasualIdentity.
        layers: Number of extra Resnet1D + conv stages appended after those.
        width: Hidden channel count used by every intermediate layer.
        depth: Forwarded to Resnet1D.
        dilation_growth_rate: Forwarded to Resnet1D.
        activation: Activation name; used to build ``self.activation`` via
            ``get_activation`` and forwarded to Resnet1D.
        activation_params: Extra parameters passed to ``get_activation``.
            ``None`` (the default) is treated as an empty dict.
        norm: Forwarded to Resnet1D.
    """

    def __init__(
        self,
        input_emb_width=3,
        output_emb_width=512,
        down_t=3,
        stride_t=2,
        layers=0,
        width=512,
        depth=3,
        dilation_growth_rate=3,
        activation="relu",
        activation_params=None,
        norm=None,
    ):
        super().__init__()

        # Fresh dict per instance instead of a mutable default shared by all calls.
        if activation_params is None:
            activation_params = {}

        self.activation = get_activation(activation, activation_params)

        self.conv1 = CausalConv1d(output_emb_width, width, 3, 1)

        # Every stage maps width -> width. Using ``width`` directly (instead of an
        # alias assigned only inside the down_t loop) keeps ``down_t=0, layers>0``
        # from raising UnboundLocalError.
        self.res_units = torch.nn.ModuleList()
        for _ in range(down_t):
            self.res_units += [
                Resnet1D(
                    width,
                    depth,
                    dilation_growth_rate,
                    reverse_dilation=True,
                    activation=activation,
                    norm=norm,
                ),
                # NOTE(review): kernel_size is hard-coded to 4, which equals the
                # Encoder's ``2 * stride_t`` only when stride_t == 2. Left as-is
                # because changing it would presumably change parameter shapes
                # and break existing checkpoints.
                CausalConvTranspose1d(
                    in_channels=width,
                    out_channels=width,
                    kernel_size=4,
                    stride=stride_t,
                )
                if stride_t > 1
                else CasualIdentity(),
                CausalConv1d(width, width, 3, 1),
            ]
        for _ in range(layers):
            self.res_units += [
                Resnet1D(
                    width,
                    depth,
                    dilation_growth_rate,
                    reverse_dilation=True,
                    activation=activation,
                    norm=norm,
                ),
                CausalConv1d(width, width, 3, 1),
            ]

        self.conv2 = CausalConv1d(width, width, 3, 1)
        self.conv3 = CausalConv1d(width, input_emb_width, 3, 1)

    def forward(self, x):
        """Decode ``x`` by calling each submodule's ``forward`` in stack order.

        NOTE(review): ``x`` is presumably (batch, output_emb_width, time) as for
        ``nn.Conv1d``; confirm against CausalConv1d.
        """
        x = self.conv1(x)
        x = self.activation(x)
        for res_unit in self.res_units:
            x = res_unit(x)
        x = self.conv2(x)
        x = self.activation(x)
        x = self.conv3(x)
        return x

    def inference(self, x):
        """Same stack as ``forward``, but calls each layer's ``inference`` method.

        The activation module is called directly. CasualIdentity provides an
        ``inference`` pass-through so the ``res_units`` loop stays uniform.
        NOTE(review): ``inference`` is presumably the incremental/streaming path
        of the causal layers; confirm in ``conv_layer``.
        """
        x = self.conv1.inference(x)
        x = self.activation(x)
        for res_unit in self.res_units:
            x = res_unit.inference(x)
        x = self.conv2.inference(x)
        x = self.activation(x)
        x = self.conv3.inference(x)
        return x
