import torch
import torch.nn as nn
from einops import rearrange

from quant.fsq import FSQ


class Downsample(nn.Module):
    """Strided ``Conv1d`` that shortens a sequence by ``factor``.

    The window is ``2 * factor + 1`` wide and padded by ``factor`` on each side,
    so an input of length ``L`` yields ``ceil(L / factor)`` output steps.
    """

    def __init__(self, in_channels, out_channels, factor):
        super().__init__()
        half_window = factor
        self.conv = nn.Conv1d(
            in_channels,
            out_channels,
            kernel_size=2 * half_window + 1,
            stride=factor,
            padding=half_window,
        )

    def forward(self, X):
        return self.conv(X)


class Upsample(nn.Module):
    """Transposed ``Conv1d`` that lengthens a sequence by exactly ``factor``.

    Kernel ``2 * factor`` with padding ``factor // 2 + (factor odd)`` and one
    extra ``output_padding`` step for odd factors gives an output length of
    ``L * factor`` for an input of length ``L``.

    NOTE(review): for ``factor == 1`` this sets ``output_padding == stride``,
    which PyTorch's ConvTranspose1d does not allow; use ``factor >= 2``.
    """

    def __init__(self, in_channels, out_channels, factor):
        super().__init__()
        is_odd = factor % 2
        self.conv = nn.ConvTranspose1d(
            in_channels,
            out_channels,
            kernel_size=2 * factor,
            stride=factor,
            padding=factor // 2 + is_odd,
            output_padding=is_odd,
        )

    def forward(self, X):
        return self.conv(X)


class ConvBlock1d(nn.Module):
    """Pre-activation 1-D conv block: GroupNorm -> SiLU -> Conv1d.

    Args:
        in_channels: Channels of the input tensor.
        out_channels: Channels produced by the convolution.
        kernel_size: Conv1d kernel size.
        stride: Conv1d stride.
        padding: Conv1d padding. With the defaults (kernel 3, stride 1,
            padding 1, dilation 1) the sequence length is preserved.
        dilation: Conv1d dilation. For a length-preserving odd kernel use
            ``padding = dilation * (kernel_size - 1) // 2``.
        num_groups: Requested GroupNorm group count. It is clamped to
            ``in_channels`` and, if needed, lowered to the largest value that
            divides ``in_channels`` (GroupNorm requires exact divisibility).
    """

    def __init__(
            self,
            in_channels: int,
            out_channels: int,
            kernel_size: int = 3,
            stride: int = 1,
            padding: int = 1,
            dilation: int = 1,
            num_groups: int = 8,
    ) -> None:
        super().__init__()

        # GroupNorm raises unless num_channels % num_groups == 0. Start from the
        # requested count (clamped to in_channels) and step down to the nearest
        # divisor rather than failing at construction time.
        groups = max(1, min(num_groups, in_channels))
        while in_channels % groups:
            groups -= 1

        self.groupnorm = nn.GroupNorm(num_groups=groups, num_channels=in_channels)
        self.activation = nn.SiLU()
        self.project = nn.Conv1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
        )

    def forward(self, x):
        x = self.groupnorm(x)
        x = self.activation(x)
        return self.project(x)


class ResnetBlock1d(nn.Module):
    """Residual block: two ConvBlock1d layers plus a shortcut connection.

    ``block1`` uses the caller's kernel size, stride and padding. ``block2`` is a
    default ConvBlock1d (kernel 3, padding 1) that keeps the length. The
    shortcut is an identity when neither channels nor length change; otherwise
    it is a 1x1 Conv1d with the same stride as ``block1``.

    NOTE(review): the residual sum only lines up when ``block1`` keeps the
    length up to striding, i.e. an odd ``kernel_size`` with
    ``padding == kernel_size // 2``. Other combinations will fail at the add.
    """

    def __init__(
            self,
            in_channels: int,
            out_channels: int,
            kernel_size: int = 3,
            stride: int = 1,
            padding: int = 1,
            num_groups: int = 8,
    ) -> None:
        super().__init__()
        self.block1 = ConvBlock1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            num_groups=num_groups,
        )

        self.block2 = ConvBlock1d(
            in_channels=out_channels,
            out_channels=out_channels,
            num_groups=num_groups,
        )

        # A strided block1 shortens the sequence, so the shortcut must stride as
        # well. A 1x1 conv with the same stride produces floor((L - 1) / s) + 1
        # steps, matching block1 under the padding rule in the class docstring.
        needs_projection = in_channels != out_channels or stride != 1
        self.to_out = (
            nn.Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=stride)
            if needs_projection
            else nn.Identity()
        )

    def forward(self, x):
        return self.block2(self.block1(x)) + self.to_out(x)


class Downsampler(nn.Module):
    """Encoder: flat signal ``(B, N)`` -> latent sequence ``(B, emb_channels, n')``.

    Processing steps:

    1. The input is folded into ``N // patch_channels`` time steps of
       ``patch_channels`` channels; consecutive samples become the channels of
       one step.
    2. A kernel-7 residual block lifts the folded input to ``channels[0]``.
    3. ``len(channels) - 1`` stages follow. Each stage halves the length with a
       strided ``Downsample`` (rounding up) into ``channels[i + 1]``, then
       applies ``resnet_count`` residual blocks.
    4. A final 1x1 conv maps to ``emb_channels``.

    NOTE(review): if ``N // patch_channels`` is not divisible by
    ``2 ** (len(channels) - 1)``, the rounding-up means a x2-per-stage decoder
    will not reproduce ``N`` exactly.

    Args:
        patch_channels: Number of consecutive samples folded into one time step.
        channels: Channel width per resolution level; must be non-empty.
        emb_channels: Channels of the returned latent.
        resnet_count: Residual blocks after each downsampling step.

    Raises:
        ValueError: If ``channels`` is empty. ``forward`` also raises it if
            the input is not 2-D or its length is not a multiple of
            ``patch_channels``.
    """

    def __init__(self, patch_channels, channels, emb_channels, resnet_count):
        super().__init__()
        if len(channels) == 0:
            raise ValueError("channels must contain at least one entry")
        self.patch_channels = patch_channels

        layers = [
            ResnetBlock1d(in_channels=patch_channels,
                          out_channels=channels[0],
                          kernel_size=7,
                          padding=3)
        ]
        for c_in, c_out in zip(channels[:-1], channels[1:]):
            layers.append(Downsample(in_channels=c_in, out_channels=c_out, factor=2))
            layers.extend(ResnetBlock1d(in_channels=c_out, out_channels=c_out)
                          for _ in range(resnet_count))
        layers.append(nn.Conv1d(in_channels=channels[-1], out_channels=emb_channels, kernel_size=1))
        # Sequential indices double as state_dict keys; keep this layer order stable.
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        # Explicit checks rather than `assert`, which is stripped under `python -O`.
        if x.ndim != 2:
            raise ValueError(f"expected a 2-D (batch, samples) tensor, got shape {tuple(x.shape)}")
        num_samples = x.shape[1]
        if num_samples % self.patch_channels != 0:
            raise ValueError(
                f"input length {num_samples} is not a multiple of "
                f"patch_channels={self.patch_channels}"
            )
        x = rearrange(x, "b (n pc) -> b pc n", pc=self.patch_channels)
        return self.net(x)


class Upsampler(nn.Module):
    """Decoder that mirrors Downsampler: latent ``(B, emb_channels, n)`` -> flat ``(B, N)``.

    Processing steps:

    1. A kernel-7 residual block maps the latent to ``channels[-1]``.
    2. The channel list is walked from the deepest level back to
       ``channels[0]``. Each step doubles the length with ``Upsample`` and
       applies ``resnet_count`` residual blocks.
    3. A 1x1 conv produces ``patch_channels`` channels.
    4. The channel and time axes are unfolded back into a single sample axis.
    """

    def __init__(self, patch_channels, channels, emb_channels, resnet_count):
        super().__init__()
        self.patch_channels = patch_channels

        stages = [
            ResnetBlock1d(in_channels=emb_channels,
                          out_channels=channels[-1],
                          kernel_size=7,
                          padding=3)
        ]
        # Deepest pair first: channels[-1] -> channels[-2], ..., channels[1] -> channels[0].
        for level in reversed(range(len(channels) - 1)):
            wide, narrow = channels[level + 1], channels[level]
            stages.append(Upsample(in_channels=wide, out_channels=narrow, factor=2))
            stages.extend(ResnetBlock1d(in_channels=narrow, out_channels=narrow)
                          for _ in range(resnet_count))
        stages.append(nn.Conv1d(in_channels=channels[0], out_channels=patch_channels, kernel_size=1))
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        decoded = self.net(x)
        _, pc, _ = decoded.shape
        # Internal invariant: the final 1x1 conv emits exactly patch_channels.
        assert pc == self.patch_channels
        return rearrange(decoded, "b pc n -> b (n pc)", pc=pc)


class Autoencoder(nn.Module):
    """Downsampler/Upsampler pair with an optional FSQ bottleneck.

    Latent width depends on the FSQ setting:

    - Without FSQ, the latent has ``channels[-1]`` channels.
    - With FSQ, it has ``len(fsq_levels)`` channels. The latent is moved to
      channel-last layout before calling the FSQ module.

    NOTE(review): the FSQ contract (``fsq(x)`` returning codes shaped like
    ``x``, and ``codes_to_indices`` / ``indices_to_codes`` round-tripping) is
    defined in ``quant.fsq`` and is assumed here — confirm there.

    Raises:
        ValueError: If ``use_fsq`` is truthy but ``fsq_levels`` is None.
    """

    def __init__(self, patch_channels, channels, resnet_count, use_fsq=False, fsq_levels=None):
        super().__init__()
        self.use_fsq = use_fsq
        if not use_fsq:
            emb_channels = channels[-1]
        elif fsq_levels is None:
            raise ValueError("fsq_levels should be present when using FSQ")
        else:
            self.fsq = FSQ(fsq_levels)
            emb_channels = len(fsq_levels)
        self.downsampler = Downsampler(patch_channels, channels, emb_channels, resnet_count)
        self.upsampler = Upsampler(patch_channels, channels, emb_channels, resnet_count)

    def _quantize(self, latents):
        # Transpose (b, c, n) -> (b, n, c) so the FSQ sees channels last, then quantize.
        return self.fsq(rearrange(latents, "b c n -> b n c"))

    def forward(self, x):
        # Reconstruct x. With FSQ, the quantized codes (not integer indices) go
        # straight to the decoder, unlike decode(encode(x)).
        latents = self.downsampler(x)
        if self.use_fsq:
            latents = rearrange(self._quantize(latents), "b n c -> b c n")
        return self.upsampler(latents)

    def encode(self, x):
        # Without FSQ: continuous (b, c, n) latents. With FSQ: whatever
        # fsq.codes_to_indices returns for channel-last codes.
        latents = self.downsampler(x)
        if not self.use_fsq:
            return latents
        return self.fsq.codes_to_indices(self._quantize(latents))

    def decode(self, x):
        # Inverse of encode: FSQ indices are mapped back to channel-last codes
        # and transposed to (b, c, n) before decoding.
        if self.use_fsq:
            codes = self.fsq.indices_to_codes(x)
            x = rearrange(codes, "b n c -> b c n")
        return self.upsampler(x)


if __name__ == "__main__":
    # Smoke test: push a random batch through the encoder and decoder and
    # print the tensor shape at each stage.
    shared = dict(patch_channels=16, channels=[32, 32, 32, 32], emb_channels=8, resnet_count=2)
    encoder = Downsampler(**shared)
    decoder = Upsampler(**shared)

    signal = torch.randn((3, 40960))
    print("inp", signal.shape)
    latent = encoder(signal)
    print("emb", latent.shape)
    recon = decoder(latent)
    print("out", recon.shape)
