from pytorch_lightning import LightningModule
from torch import nn
from typing import Optional
from transformers import AutoModel, AutoConfig
from mGPT.archs.vqvae.utils.resnet import Resnet1D
from mGPT.archs.vqvae.utils.quantize_cnn import (
    QuantizeEMAReset,
    Quantizer,
    QuantizeEMA,
    QuantizeReset,
    QuantizeCVQVAE,
    QuantizeResidualVQVAE,
)
from collections import OrderedDict
from mGPT.archs.vqvae.utils.residual_vq import ResidualVQ
from mGPT.archs.vqvae.utils.residual_lfq import (
    ResidualLFQ,
    MultiscaleLFQ,
    MultiscaleLFQ2,
)
from mGPT.archs.vqvae.utils.quantize_vqgan import VectorQuantizer2
from mGPT.archs.vqvae.utils.multiscale_vq import MultiscaleVQ
from mGPT.archs.vqvae.utils.multiscale_vq2 import (
    MultiscaleVectorQuantize,
    ResidualVectorQuantize,
)
from mGPT.archs.vqvae.utils.multiscale_vq5 import MultiscaleVectorQuantize5
from mGPT.archs.vqvae.utils.multiscale_vq6 import MultiscaleVectorQuantize6
from mGPT.archs.vqvae.utils.multiscale_vq7 import MultiscaleVectorQuantize7
from .multiscale_vq8 import MultiscaleVectorQuantize8
from .cvq import MultiscaleVectorQuantize9
from .fsq import MultiscaleVectorQuantize10
from .multiscale_vq11 import MultiscaleVectorQuantize11
from .multiscale_vq12 import MultiscaleVectorQuantize12

from mGPT.archs.mld_vae import MldEnc
from .conv_layer import CausalConv1d, CausalConvTranspose1d
from mGPT.archs.vqvae.utils.activation_function import get_activation


class BaseVQVAE(LightningModule):
    """Base LightningModule holding shared builders for VQ-VAE components.

    All constructor arguments are captured as hyperparameters. The
    ``_init_encoder`` / ``_init_decoder`` / ``_init_quantizer`` /
    ``_init_discriminator`` builders are not invoked here; presumably
    subclasses call the ones they need from their own constructors.
    """

    def __init__(self, *args, **kwargs):
        """Save hyperparameters and normalise the multi-scale settings.

        ``v_lengths`` is re-sorted in descending order. When more than one
        length is given and ``num_quantizers`` is still 1, ``num_quantizers``
        is set to the number of lengths so both settings agree.
        """
        super().__init__()
        self.save_hyperparameters()
        self.hparams.v_lengths = sorted(self.hparams.v_lengths, reverse=True)
        if len(self.hparams.v_lengths) > 1 and self.hparams.num_quantizers == 1:
            self.hparams.num_quantizers = len(self.hparams.v_lengths)

    def _load_pretrained(self, name_or_path: str) -> nn.Module:
        """Load a Hugging Face ``AutoModel`` with hparams-derived overrides.

        ``num_hidden_layers`` comes from ``hparams.depth``. Both
        ``n_positions`` and ``max_position_embeddings`` come from
        ``hparams.max_length``. ``ignore_mismatched_sizes=True`` lets the
        checkpoint load even when these overrides change weight shapes.

        Args:
            name_or_path: Hub identifier or local path of the model.

        Returns:
            The instantiated pretrained model.
        """
        config = AutoConfig.from_pretrained(
            name_or_path,
            num_hidden_layers=self.hparams.depth,
            n_positions=self.hparams.max_length,
            max_position_embeddings=self.hparams.max_length,
        )
        return AutoModel.from_pretrained(
            name_or_path,
            config=config,
            ignore_mismatched_sizes=True,
        )

    def _init_encoder(self, double_z: bool = False) -> None:
        """Build ``self.encoder`` according to ``hparams.encoder_type``.

        Args:
            double_z: When True (resnet1d encoders only), the encoder output
                width is ``2 * output_emb_width`` instead of
                ``output_emb_width``.

        Raises:
            ValueError: If ``hparams.encoder_type`` is neither a resnet1d
                variant nor a Hugging Face model id (containing ``"/"``).
        """
        encoder_type = self.hparams.encoder_type

        if "resnet1d" in encoder_type:
            # The causal variant lives in a separate module with the same API.
            if encoder_type == "resnet1d_casual":
                from .encdec_casual import Encoder
            else:
                from .encdec import Encoder

            out_width = (
                self.hparams.output_emb_width * 2
                if double_z
                else self.hparams.output_emb_width
            )
            self.encoder = Encoder(
                self.hparams.nfeats,
                out_width,
                self.hparams.down_t,
                self.hparams.stride_t,
                self.hparams.layers,
                self.hparams.width,
                self.hparams.depth,
                self.hparams.dilation_growth_rate,
                activation=self.hparams.activation,
                norm=self.hparams.norm,
            )

        elif "/" in encoder_type:
            # NOTE(review): `double_z` is not applied to pretrained encoders.
            self.encoder = self._load_pretrained(encoder_type)
        else:
            raise ValueError(f"Unknown encoder type: {encoder_type}")

    def _init_decoder(self) -> None:
        """Build ``self.decoder`` according to ``hparams.decoder_type``.

        Raises:
            ValueError: If ``hparams.decoder_type`` is neither a resnet1d
                variant nor a Hugging Face model id (containing ``"/"``).
        """
        decoder_type = self.hparams.decoder_type

        if "resnet1d" in decoder_type:
            if decoder_type == "resnet1d_casual":
                from .encdec_casual import Decoder
            else:
                from .encdec import Decoder
            self.decoder = Decoder(
                self.hparams.nfeats,
                self.hparams.output_emb_width,
                self.hparams.down_t,
                self.hparams.stride_t,
                self.hparams.layers,
                self.hparams.width,
                self.hparams.depth,
                self.hparams.dilation_growth_rate,
                activation=self.hparams.activation,
                norm=self.hparams.norm,
            )
        elif "/" in decoder_type:
            self.decoder = self._load_pretrained(decoder_type)
        else:
            raise ValueError(f"Unknown decoder type: {decoder_type}")

    def _init_quantizer(self, module_name: str = "quantizer") -> None:
        """Build ``self.quantizer`` according to ``hparams.quantizer``.

        After construction, ``hparams.quantizer_type`` is set to a normalised
        family name. Several LFQ / multiscale variants are mapped to
        ``"rvq_guo"`` or ``"mvq"``; all other names are stored unchanged.
        Presumably downstream code dispatches on ``quantizer_type``; verify
        against callers.

        Args:
            module_name: Currently unused; the quantizer is always stored on
                ``self.quantizer``.

        Raises:
            ValueError: If ``hparams.quantizer`` is not a known name.
        """
        quantizer = self.hparams.quantizer
        if quantizer == "ema_reset":
            self.quantizer = QuantizeEMAReset(
                self.hparams.code_num, self.hparams.code_dim, mu=0.99
            )
        elif quantizer == "orig":
            self.quantizer = Quantizer(
                self.hparams.code_num, self.hparams.code_dim, beta=1.0
            )
        elif quantizer == "ema":
            self.quantizer = QuantizeEMA(
                self.hparams.code_num, self.hparams.code_dim, mu=0.99
            )
        elif quantizer == "reset":
            self.quantizer = QuantizeReset(self.hparams.code_num, self.hparams.code_dim)
        elif quantizer == "cvq":
            self.quantizer = QuantizeCVQVAE(
                self.hparams.code_num, self.hparams.code_dim
            )
        elif quantizer == "rvq":
            self.quantizer = QuantizeResidualVQVAE(
                self.hparams.code_num,
                self.hparams.code_dim,
                num_quantizers=self.hparams.num_quantizers,
                shared_codebook=self.hparams.shared_codebook,
            )
        elif quantizer == "rvq_guo":
            self.quantizer = ResidualVQ(
                self.hparams.code_num,
                self.hparams.code_dim,
                num_quantizers=self.hparams.num_quantizers,
                shared_codebook=self.hparams.shared_codebook,
                quantize_dropout_prob=self.hparams.quantize_dropout_prob,
                quantize_dropout_cutoff_index=self.hparams.quantize_dropout_cutoff_index,
            )
        elif quantizer == "rvq_lfq":
            self.quantizer = ResidualLFQ(
                dim=self.hparams.output_emb_width,
                num_quantizers=self.hparams.num_quantizers,
                codebook_size=self.hparams.code_num,
                commitment_loss_weight=1.0,
                quantize_dropout_prob=self.hparams.quantize_dropout_prob,
                quantize_dropout_cutoff_index=self.hparams.quantize_dropout_cutoff_index,
            )
            quantizer = "rvq_guo"
        elif quantizer == "rvq_vqgan":
            self.quantizer = VectorQuantizer2(
                vocab_size=self.hparams.code_num,
                Cvae=self.hparams.code_dim,
                beta=self.hparams.beta,
                share_quant_resi=self.hparams.share_quant_resi,
                v_lengths=self.hparams.v_lengths,
            )
        elif quantizer == "multiscale":
            self.quantizer = MultiscaleVQ(
                self.hparams.code_num,
                self.hparams.code_dim,
                v_lengths=self.hparams.v_lengths,
                shared_codebook=self.hparams.shared_codebook,
                quantize_dropout_prob=self.hparams.quantize_dropout_prob,
                quantize_dropout_cutoff_index=self.hparams.quantize_dropout_cutoff_index,
            )
        elif quantizer == "rvq_rvqgan":
            self.quantizer = ResidualVectorQuantize(
                input_dim=self.hparams.output_emb_width,
                n_codebooks=self.hparams.num_quantizers,
                codebook_size=self.hparams.code_num,
                codebook_dim=self.hparams.code_dim,
                quantizer_dropout=self.hparams.quantize_dropout_prob,
            )
        elif quantizer == "multiscale2":
            self.quantizer = MultiscaleVectorQuantize(
                input_dim=self.hparams.output_emb_width,
                v_lengths=self.hparams.v_lengths,
                codebook_size=self.hparams.code_num,
                codebook_dim=self.hparams.code_dim,
                shared_codebook=self.hparams.shared_codebook,
                quantizer_dropout=self.hparams.quantize_dropout_prob,
            )
            quantizer = "mvq"
        elif quantizer == "multiscale3":
            self.quantizer = MultiscaleLFQ(
                dim=self.hparams.output_emb_width,
                v_lengths=self.hparams.v_lengths,
                codebook_size=self.hparams.code_num,
                commitment_loss_weight=1.0,
                quantize_dropout_prob=self.hparams.quantize_dropout_prob,
                quantize_dropout_cutoff_index=self.hparams.quantize_dropout_cutoff_index,
            )
            quantizer = "rvq_guo"
        elif quantizer == "multiscale4":
            self.quantizer = MultiscaleLFQ2(
                dim=self.hparams.output_emb_width,
                v_lengths=self.hparams.v_lengths,
                codebook_size=self.hparams.code_num,
                commitment_loss_weight=1.0,
                quantize_dropout_prob=self.hparams.quantize_dropout_prob,
                quantize_dropout_cutoff_index=self.hparams.quantize_dropout_cutoff_index,
            )
            quantizer = "rvq_guo"
        elif quantizer == "multiscale5":
            self.quantizer = MultiscaleVectorQuantize5(
                input_dim=self.hparams.output_emb_width,
                v_lengths=self.hparams.v_lengths,
                codebook_size=self.hparams.code_num,
                codebook_dim=self.hparams.code_dim,
                shared_codebook=self.hparams.shared_codebook,
                quantizer_dropout=self.hparams.quantize_dropout_prob,
            )
            # NOTE(review): unlike multiscale2/6-12 this maps to "rvq_guo",
            # not "mvq" -- confirm this is intentional.
            quantizer = "rvq_guo"
        elif quantizer == "multiscale6":
            self.quantizer = MultiscaleVectorQuantize6(
                input_dim=self.hparams.output_emb_width,
                v_lengths=self.hparams.v_lengths,
                codebook_size=self.hparams.code_num,
                codebook_dim=self.hparams.code_dim,
                shared_codebook=self.hparams.shared_codebook,
                quantizer_dropout=self.hparams.quantize_dropout_prob,
            )
            quantizer = "mvq"
        elif quantizer == "multiscale7":
            self.quantizer = MultiscaleVectorQuantize7(
                input_dim=self.hparams.output_emb_width,
                v_lengths=self.hparams.v_lengths,
                codebook_size=self.hparams.code_num,
                codebook_dim=self.hparams.code_dim,
                shared_codebook=self.hparams.shared_codebook,
                quantizer_dropout=self.hparams.quantize_dropout_prob,
            )
            quantizer = "mvq"
        elif quantizer == "multiscale8":
            self.quantizer = MultiscaleVectorQuantize8(
                input_dim=self.hparams.output_emb_width,
                v_lengths=self.hparams.v_lengths,
                codebook_size=self.hparams.code_num,
                codebook_dim=self.hparams.code_dim,
                shared_codebook=self.hparams.shared_codebook,
                quantizer_dropout=self.hparams.quantize_dropout_prob,
            )
            quantizer = "mvq"
        elif quantizer == "multiscale9":
            self.quantizer = MultiscaleVectorQuantize9(
                input_dim=self.hparams.output_emb_width,
                v_lengths=self.hparams.v_lengths,
                codebook_size=self.hparams.code_num,
                codebook_dim=self.hparams.code_dim,
                shared_codebook=self.hparams.shared_codebook,
                quantizer_dropout=self.hparams.quantize_dropout_prob,
            )
            quantizer = "mvq"
        elif quantizer == "multiscale10":
            self.quantizer = MultiscaleVectorQuantize10(
                input_dim=self.hparams.output_emb_width,
                levels=self.hparams.levels,
                n_codebooks=self.hparams.num_quantizers,
                shared_codebook=self.hparams.shared_codebook,
                quantizer_dropout=self.hparams.quantize_dropout_prob,
            )
            quantizer = "mvq"
        elif quantizer == "multiscale11":
            self.quantizer = MultiscaleVectorQuantize11(
                input_dim=self.hparams.output_emb_width,
                v_lengths=self.hparams.v_lengths,
                codebook_size=self.hparams.code_num,
                codebook_dim=self.hparams.code_dim,
                shared_codebook=self.hparams.shared_codebook,
                quantizer_dropout=self.hparams.quantize_dropout_prob,
                codebook_dropout=self.hparams.codebook_dropout,
            )
            quantizer = "mvq"
        elif quantizer == "multiscale12":
            self.quantizer = MultiscaleVectorQuantize12(
                input_dim=self.hparams.output_emb_width,
                v_lengths=self.hparams.v_lengths,
                codebook_size=self.hparams.code_num,
                codebook_dim=self.hparams.code_dim,
                shared_codebook=self.hparams.shared_codebook,
                quantizer_dropout=self.hparams.quantize_dropout_prob,
                codebook_dropout=self.hparams.codebook_dropout,
            )
            quantizer = "mvq"
        else:
            raise ValueError(f"Unknown quantizer type: {quantizer}")

        self.hparams.quantizer_type = quantizer

    def _init_discriminator(self) -> None:
        """Build ``self.discriminator`` from ``hparams.discriminator_type``.

        Supported types are ``"linear"``, ``"resnet1d"``, ``"mld"``, and a
        Hugging Face model id (containing ``"/"``). Every variant ends in an
        ``nn.Sigmoid``.

        Raises:
            ValueError: If ``hparams.discriminator_type`` is not supported.
        """
        disc_type = self.hparams.discriminator_type

        if disc_type == "linear":
            layers = []
            in_dim = self.hparams.nfeats
            # NOTE(review): one activation instance is registered after every
            # hidden Linear; if it has parameters they are shared across layers.
            activation = get_activation(
                self.hparams.activation, self.hparams.activation_params
            )
            for i, out_dim in enumerate(self.hparams.discriminator_widths):
                layers.append((f"linear{i}", nn.Linear(in_dim, out_dim)))
                layers.append((f"activation{i}", activation))
                in_dim = out_dim
            # Merge the first two dims so the head scores each element of dim 1
            # independently (presumably (batch, time) -> batch*time; confirm).
            layers.append(("flatten", nn.Flatten(0, 1)))
            layers.append(("linear", nn.Linear(in_dim, 1)))
            layers.append(("sigmoid", nn.Sigmoid()))
            self.discriminator = nn.Sequential(OrderedDict(layers))

        elif disc_type == "resnet1d":

            class Discriminator(nn.Module):
                """Conv1d stem + ``depth`` (Conv1d, Resnet1D) stages + 1x1 head."""

                def __init__(
                    self, in_dim, width, depth, dilation_growth_rate, activation, norm
                ):
                    super().__init__()
                    self.layers = nn.ModuleList()
                    # The stem uses a fixed ReLU; `activation` only reaches Resnet1D.
                    self.layers.append(nn.Conv1d(in_dim, width, 3, 1, 1))
                    self.layers.append(nn.ReLU())
                    for _ in range(depth):
                        self.layers.append(
                            nn.Sequential(
                                nn.Conv1d(width, width, 3, 1, 1),
                                Resnet1D(
                                    width,
                                    depth,
                                    dilation_growth_rate,
                                    activation=activation,
                                    norm=norm,
                                ),
                            )
                        )
                    self.layers.append(nn.Conv1d(width, 1, 1))
                    self.layers.append(nn.Sigmoid())

                def forward(self, x):
                    # Channels-last -> channels-first for Conv1d (assumes x is
                    # (batch, time, nfeats) -- confirm against callers).
                    x = x.permute(0, 2, 1)
                    for layer in self.layers:
                        x = layer(x)
                    # Drop the singleton channel produced by the 1x1 head.
                    return x.squeeze(1)

            self.discriminator = Discriminator(
                self.hparams.nfeats,
                self.hparams.discriminator_widths[0],
                self.hparams.discriminator_depths,
                self.hparams.dilation_growth_rate,
                self.hparams.activation,
                self.hparams.norm,
            )

        elif disc_type == "mld":
            self.discriminator = nn.Sequential(
                MldEnc(
                    nfeats=self.hparams.nfeats,
                    ff_size=self.hparams.discriminator_ffsize,
                    num_layers=self.hparams.discriminator_depths,
                    num_heads=self.hparams.discriminator_heads,
                    dropout=self.hparams.discriminator_dropout,
                ),
                nn.Sigmoid(),
            )

        elif "/" in disc_type:
            # NOTE(review): AutoModel.forward usually returns a ModelOutput, not
            # a tensor, so chaining it into nn.Linear via nn.Sequential may fail;
            # the Linear also assumes hidden_size == output_emb_width. Confirm
            # how this discriminator is actually invoked.
            self.discriminator = nn.Sequential(
                self._load_pretrained(disc_type),
                nn.Linear(self.hparams.output_emb_width, 1),
                nn.Sigmoid(),
            )

        else:
            raise ValueError(
                f"Unknown discriminator type: {self.hparams.discriminator_type}"
            )

    def preprocess(self, x):
        """
        Preprocess input tensor.
        (bs, T, Jx3) -> (bs, Jx3, T)
        """
        return x.permute(0, 2, 1)

    def postprocess(self, x):
        """
        Postprocess input tensor.
        (bs, Jx3, T) ->  (bs, T, Jx3)
        """
        return x.permute(0, 2, 1)

    def reset_buffer(self, target_module: Optional[str] = None):
        """Call ``reset_buffer()`` on every causal conv layer.

        Applies to every ``CausalConv1d`` / ``CausalConvTranspose1d``
        submodule. Presumably this clears cached streaming state; see those
        classes to confirm.

        Args:
            target_module: Name of a child module attribute (e.g.
                ``"encoder"``) to restrict the reset to. When falsy, the
                whole model is traversed.
        """

        def _reset_buffer(m):
            if isinstance(m, (CausalConv1d, CausalConvTranspose1d)):
                m.reset_buffer()

        if target_module:
            getattr(self, target_module).apply(_reset_buffer)
        else:
            self.apply(_reset_buffer)
