from typing import Any, Literal, Callable
import math
from pathlib import Path

import torch
import torch.nn as nn
from torch.nn.utils import weight_norm
import torchaudio
from alias_free_torch import Activation1d

from models.common import LoadPretrainedBase
from models.autoencoder.autoencoder_base import AutoEncoderBase
from utils.torch_utilities import remove_key_prefix_factory, create_mask_from_length


# Compiled with TorchScript; the original author measured roughly a 1.4x
# speed-up and lower GPU memory use compared to eager execution.
@torch.jit.script
def snake_beta(x, alpha, beta):
    """Return ``x + (1 / beta) * sin(x * alpha) ** 2`` (SnakeBeta activation).

    ``beta`` is offset by 1e-9 before being inverted so a zero ``beta`` does
    not divide by zero. ``alpha`` and ``beta`` must broadcast against ``x``.
    """
    inv_beta = 1.0 / (beta + 0.000000001)
    periodic = pow(torch.sin(x * alpha), 2)
    return x + inv_beta * periodic


class SnakeBeta(nn.Module):
    """Per-channel SnakeBeta activation ``x + (1 / beta) * sin(alpha * x) ** 2``.

    Args:
        in_features: Number of channels; one ``alpha`` and one ``beta`` are
            learned per channel.
        alpha: Multiplier applied to the initial value of both ``alpha`` and
            ``beta``. In log-scale mode both start as zeros, and zeros times a
            finite ``alpha`` stay zero, so this argument has no effect there.
        alpha_trainable: Whether ``alpha`` and ``beta`` receive gradients.
        alpha_logscale: If True the parameters are stored as logarithms and
            exponentiated in ``forward``.
    """

    def __init__(
        self,
        in_features,
        alpha=1.0,
        alpha_trainable=True,
        alpha_logscale=True
    ):
        super().__init__()
        self.in_features = in_features
        self.alpha_logscale = alpha_logscale

        # exp(0) == 1, so log-scale parameters start at zero while
        # linear-scale parameters start at one.
        make_init = torch.zeros if alpha_logscale else torch.ones
        self.alpha = nn.Parameter(make_init(in_features) * alpha)
        self.beta = nn.Parameter(make_init(in_features) * alpha)

        for param in (self.alpha, self.beta):
            param.requires_grad = alpha_trainable

    def forward(self, x):
        """Apply the activation.

        The (C,) parameters are viewed as (1, C, 1) so they broadcast over
        ``x``, which is assumed to be laid out as (B, C, T).
        """
        alpha = self.alpha[None, :, None]
        beta = self.beta[None, :, None]
        if self.alpha_logscale:
            alpha = alpha.exp()
            beta = beta.exp()
        return snake_beta(x, alpha, beta)


def WNConv1d(*args, **kwargs):
    """Build an ``nn.Conv1d`` (same arguments) wrapped with ``weight_norm``."""
    conv = nn.Conv1d(*args, **kwargs)
    return weight_norm(conv)


def WNConvTranspose1d(*args, **kwargs):
    """Build an ``nn.ConvTranspose1d`` (same arguments) wrapped with ``weight_norm``."""
    deconv = nn.ConvTranspose1d(*args, **kwargs)
    return weight_norm(deconv)


def get_activation(
    activation: Literal["elu", "snake", "none"],
    antialias=False,
    channels=None
) -> nn.Module:
    if activation == "elu":
        act = nn.ELU()
    elif activation == "snake":
        act = SnakeBeta(channels)
    elif activation == "none":
        act = nn.Identity()
    else:
        raise ValueError(f"Unknown activation {activation}")

    if antialias:
        act = Activation1d(act)

    return act


class ResidualUnit(nn.Module):
    """Residual block: act -> dilated 7-tap conv -> act -> 1x1 conv, plus skip.

    The dilated conv is padded so the time length is preserved, which lets
    the output be added back onto the input.

    Args:
        in_channels: Channels of the input.
        out_channels: Channels produced by both convolutions. The skip
            addition requires this to match ``in_channels``; note that the
            first activation is also sized with ``out_channels``.
        dilation: Dilation of the 7-tap convolution.
        use_snake: Use ``SnakeBeta`` activations instead of ELU.
        antialias_activation: Wrap both activations in ``Activation1d``.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        dilation,
        use_snake=False,
        antialias_activation=False
    ):
        super().__init__()

        self.dilation = dilation

        kernel_size = 7
        # Length-preserving padding for an odd kernel with this dilation.
        same_padding = (dilation * (kernel_size - 1)) // 2
        act_name = "snake" if use_snake else "elu"

        def make_act():
            return get_activation(
                act_name,
                antialias=antialias_activation,
                channels=out_channels
            )

        self.layers = nn.Sequential(
            make_act(),
            WNConv1d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                dilation=dilation,
                padding=same_padding
            ),
            make_act(),
            WNConv1d(
                in_channels=out_channels,
                out_channels=out_channels,
                kernel_size=1
            )
        )

    def forward(self, x):
        """Return ``layers(x) + x``."""
        return self.layers(x) + x


class EncoderBlock(nn.Module):
    """Three dilated residual units (dilations 1, 3, 9) and a strided conv.

    The final convolution (kernel ``2 * stride``, stride ``stride``) shrinks
    the time axis by roughly ``stride`` and maps ``in_channels`` to
    ``out_channels``.

    Args:
        in_channels: Input (and residual-unit) channel count.
        out_channels: Channel count after the downsampling convolution.
        stride: Downsampling factor.
        use_snake: Use ``SnakeBeta`` activations instead of ELU.
        antialias_activation: Wrap the pre-downsampling activation in
            ``Activation1d``.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        stride,
        use_snake=False,
        antialias_activation=False
    ):
        super().__init__()

        # NOTE(review): `antialias_activation` is not forwarded to the
        # residual units, only to the activation before the downsampling
        # conv. Kept as-is: forwarding it would likely change the module tree
        # (and therefore checkpoint keys) whenever the flag is set.
        residual_units = [
            ResidualUnit(
                in_channels=in_channels,
                out_channels=in_channels,
                dilation=dilation,
                use_snake=use_snake
            )
            for dilation in (1, 3, 9)
        ]
        pre_downsample_act = get_activation(
            "snake" if use_snake else "elu",
            antialias=antialias_activation,
            channels=in_channels
        )
        downsample = WNConv1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=2 * stride,
            stride=stride,
            padding=math.ceil(stride / 2)
        )

        self.layers = nn.Sequential(
            *residual_units, pre_downsample_act, downsample
        )

    def forward(self, x):
        """Run the residual units and then downsample."""
        return self.layers(x)


class DecoderBlock(nn.Module):
    """Activation, upsampling layer, then three dilated residual units (1, 3, 9).

    Upsampling uses either a transposed conv (kernel ``2 * stride``), which
    lengthens time by ``stride`` exactly for even strides, or
    nearest-neighbour interpolation by ``stride`` followed by a
    ``padding='same'`` conv. Either path maps ``in_channels`` to
    ``out_channels``.

    Args:
        in_channels: Channels entering the block.
        out_channels: Channels after upsampling (and in the residual units).
        stride: Upsampling factor.
        use_snake: Use ``SnakeBeta`` activations instead of ELU.
        antialias_activation: Wrap the pre-upsampling activation in
            ``Activation1d``.
        use_nearest_upsample: Select the interpolation + conv path.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        stride,
        use_snake=False,
        antialias_activation=False,
        use_nearest_upsample=False
    ):
        super().__init__()

        # The upsampler is built first so parameter initialisation consumes
        # the RNG in the same order as before.
        if not use_nearest_upsample:
            upsample = WNConvTranspose1d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=2 * stride,
                stride=stride,
                padding=math.ceil(stride / 2)
            )
        else:
            upsample = nn.Sequential(
                nn.Upsample(scale_factor=stride, mode="nearest"),
                WNConv1d(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    kernel_size=2 * stride,
                    stride=1,
                    bias=False,
                    padding='same'
                )
            )

        pre_upsample_act = get_activation(
            "snake" if use_snake else "elu",
            antialias=antialias_activation,
            channels=in_channels
        )
        # NOTE(review): as in EncoderBlock, `antialias_activation` is not
        # forwarded to the residual units; kept for checkpoint compatibility.
        residual_units = [
            ResidualUnit(
                in_channels=out_channels,
                out_channels=out_channels,
                dilation=dilation,
                use_snake=use_snake
            )
            for dilation in (1, 3, 9)
        ]

        self.layers = nn.Sequential(
            pre_upsample_act, upsample, *residual_units
        )

    def forward(self, x):
        """Upsample ``x`` and refine it with the residual units."""
        return self.layers(x)


class OobleckEncoder(nn.Module):
    """Convolutional waveform encoder (Oobleck architecture).

    Layer sequence:
      * a 7-tap input conv from ``in_channels`` to ``channels``;
      * one ``EncoderBlock`` per entry of ``c_mults``, each widening to
        ``c_mults[i] * channels`` and shrinking time by ``strides[i]``;
      * a final activation and 3-tap conv projecting to ``latent_dim``.

    Args:
        in_channels: Input waveform channels.
        channels: Base channel width.
        latent_dim: Output channels.
        c_mults: Channel multiplier per block. A leading 1 (for the stem) is
            prepended internally; the caller's sequence is not modified.
        strides: Downsampling stride per block. It needs at least
            ``len(c_mults)`` entries; extra entries are ignored.
        use_snake: Use ``SnakeBeta`` activations instead of ELU.
        antialias_activation: Wrap the final activation in ``Activation1d``.

    Raises:
        ValueError: If ``strides`` is shorter than ``c_mults``.
    """

    def __init__(
        self,
        in_channels=2,
        channels=128,
        latent_dim=32,
        c_mults=(1, 2, 4, 8),
        strides=(2, 4, 8, 8),
        use_snake=False,
        antialias_activation=False
    ):
        super().__init__()

        # Unpacking builds a fresh list from any sequence (list, tuple or a
        # config list object) without touching the caller's object.
        c_mults = [1, *c_mults]

        self.depth = len(c_mults)
        num_blocks = self.depth - 1
        if len(strides) < num_blocks:
            raise ValueError(
                f"OobleckEncoder needs at least {num_blocks} strides "
                f"(one per c_mult), got {len(strides)}"
            )

        layers = [
            WNConv1d(
                in_channels=in_channels,
                out_channels=c_mults[0] * channels,
                kernel_size=7,
                padding=3
            )
        ]

        # NOTE(review): `antialias_activation` is not passed to the blocks,
        # unlike OobleckDecoder. This is kept as-is so the module tree and
        # checkpoint keys are unchanged when the flag is set.
        for i in range(num_blocks):
            layers.append(
                EncoderBlock(
                    in_channels=c_mults[i] * channels,
                    out_channels=c_mults[i + 1] * channels,
                    stride=strides[i],
                    use_snake=use_snake
                )
            )

        layers += [
            get_activation(
                "snake" if use_snake else "elu",
                antialias=antialias_activation,
                channels=c_mults[-1] * channels
            ),
            WNConv1d(
                in_channels=c_mults[-1] * channels,
                out_channels=latent_dim,
                kernel_size=3,
                padding=1
            )
        ]

        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        """Encode ``x`` (presumably (B, in_channels, T)) into latents."""
        return self.layers(x)


class OobleckDecoder(nn.Module):
    """Convolutional waveform decoder mirroring ``OobleckEncoder``.

    Layer sequence:
      * a 7-tap conv from ``latent_dim`` to ``c_mults[-1] * channels``;
      * one ``DecoderBlock`` per entry of ``c_mults``, walked from widest to
        narrowest, each upsampling time by the matching stride;
      * a final activation, a bias-free 7-tap conv to ``out_channels``, and an
        optional ``tanh``.

    Args:
        out_channels: Output waveform channels.
        channels: Base channel width.
        latent_dim: Input latent channels.
        c_mults: Channel multiplier per block (encoder order). A leading 1 is
            prepended internally; the caller's sequence is not modified.
        strides: Upsampling stride per block (encoder order). It needs at
            least ``len(c_mults)`` entries; extra entries are ignored.
        use_snake: Use ``SnakeBeta`` activations instead of ELU.
        antialias_activation: Passed to each ``DecoderBlock`` and used for the
            final activation.
        use_nearest_upsample: Passed to each ``DecoderBlock``.
        final_tanh: Append ``nn.Tanh()`` (otherwise ``nn.Identity()``).

    Raises:
        ValueError: If ``strides`` is shorter than ``c_mults``.
    """

    def __init__(
        self,
        out_channels=2,
        channels=128,
        latent_dim=32,
        c_mults=(1, 2, 4, 8),
        strides=(2, 4, 8, 8),
        use_snake=False,
        antialias_activation=False,
        use_nearest_upsample=False,
        final_tanh=True
    ):
        super().__init__()

        # Fresh list from any sequence; the caller's object is never mutated.
        c_mults = [1, *c_mults]

        self.depth = len(c_mults)
        num_blocks = self.depth - 1
        if len(strides) < num_blocks:
            raise ValueError(
                f"OobleckDecoder needs at least {num_blocks} strides "
                f"(one per c_mult), got {len(strides)}"
            )

        layers = [
            WNConv1d(
                in_channels=latent_dim,
                out_channels=c_mults[-1] * channels,
                kernel_size=7,
                padding=3
            ),
        ]

        # Walk the multipliers from widest to narrowest (reverse of encoder).
        for i in range(num_blocks, 0, -1):
            layers.append(
                DecoderBlock(
                    in_channels=c_mults[i] * channels,
                    out_channels=c_mults[i - 1] * channels,
                    stride=strides[i - 1],
                    use_snake=use_snake,
                    antialias_activation=antialias_activation,
                    use_nearest_upsample=use_nearest_upsample
                )
            )

        layers += [
            get_activation(
                "snake" if use_snake else "elu",
                antialias=antialias_activation,
                channels=c_mults[0] * channels
            ),
            WNConv1d(
                in_channels=c_mults[0] * channels,
                out_channels=out_channels,
                kernel_size=7,
                padding=3,
                bias=False
            ),
            nn.Tanh() if final_tanh else nn.Identity()
        ]

        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        """Decode latents ``x`` into a waveform."""
        return self.layers(x)


class Bottleneck(nn.Module):
    """Interface for the layer between an encoder and a decoder.

    Subclasses override ``encode`` and ``decode``; both raise
    ``NotImplementedError`` here.
    """

    def __init__(self, is_discrete: bool = False):
        super().__init__()
        # Flag read by owners (StableVAE derives its own `is_discrete` from it).
        self.is_discrete = is_discrete

    def encode(self, x, return_info=False, **kwargs):
        """Map encoder output to latents (subclass responsibility)."""
        raise NotImplementedError

    def decode(self, x):
        """Map latents back to decoder input (subclass responsibility)."""
        raise NotImplementedError


@torch.jit.script
def vae_sample(mean, scale) -> dict[str, torch.Tensor]:
    """Draw reparameterised Gaussian latents and the matching KL term.

    The standard deviation is ``softplus(scale) + 1e-4``. Latents are
    ``noise * std + mean`` with ``noise ~ N(0, 1)`` shaped like ``mean``.
    The KL term is ``mean^2 + var - log(var) - 1`` summed over dim 1 and
    averaged over the remaining dims. There is no 0.5 factor, so any scaling
    is left to the loss weight.

    Returns:
        ``{"latents": latents, "kl": scalar kl}``.
    """
    std = nn.functional.softplus(scale) + 1e-4
    variance = std * std
    log_variance = torch.log(variance)
    noise = torch.randn_like(mean)
    latents = noise * std + mean

    kl_terms = mean * mean + variance - log_variance - 1
    kl = kl_terms.sum(1).mean()
    return {"latents": latents, "kl": kl}


class VAEBottleneck(Bottleneck):
    """Continuous Gaussian VAE bottleneck.

    ``encode`` splits its input into mean and scale halves along dim 1 and
    samples latents with ``vae_sample``. ``decode`` is the identity.
    """

    def __init__(self):
        super().__init__(is_discrete=False)

    def encode(
        self,
        x,
        return_info=False,
        **kwargs
    ) -> tuple[torch.Tensor, dict[str, torch.Tensor]] | torch.Tensor:
        """Sample latents from the encoder output ``x``.

        Args:
            x: Encoder output whose dim 1 holds ``[mean, scale]``. It is
                assumed to have an even size on dim 1 so the halves match.
            return_info: If True, also return ``{"kl": kl}``.
            **kwargs: Accepted for interface compatibility; ignored.

        Returns:
            ``latents``, or ``(latents, {"kl": kl})`` when ``return_info``.
        """
        mean, scale = x.chunk(2, dim=1)
        sampled = vae_sample(mean, scale)
        latents = sampled["latents"]

        if return_info:
            return latents, {"kl": sampled["kl"]}
        return latents

    def decode(self, x):
        """Return ``x`` unchanged; the sampled latents feed the decoder directly."""
        return x


def compute_mean_kernel(x, y):
    """Return the mean Gaussian-style kernel value over all (x_i, y_j) pairs.

    For every row pair the squared difference is averaged over dim 2, divided
    again by ``x.shape[-1]``, and passed through ``exp(-.)``. The result is
    the mean over all pairs. ``x`` and ``y`` are assumed to be (N, D) and
    (M, D).
    """
    pairwise_diff = x.unsqueeze(1) - y.unsqueeze(0)
    mean_sq = pairwise_diff.pow(2).mean(2)
    kernel_input = mean_sq / x.shape[-1]
    return torch.exp(-kernel_input).mean()


class Pretransform(nn.Module):
    """Interface for transforms applied around an autoencoder.

    All four methods raise ``NotImplementedError`` here; subclasses override
    the ones they support.
    """

    def __init__(self, enable_grad, io_channels, is_discrete):
        super().__init__()

        # Configuration passed in by the subclass.
        self.enable_grad = enable_grad
        self.io_channels = io_channels
        self.is_discrete = is_discrete

        # Not set here; presumably filled in by concrete subclasses.
        self.encoded_channels = None
        self.downsampling_ratio = None

    def encode(self, x):
        """Transform input ``x`` (subclass responsibility)."""
        raise NotImplementedError

    def decode(self, z):
        """Invert the transform for ``z`` (subclass responsibility)."""
        raise NotImplementedError

    def tokenize(self, x):
        """Produce tokens for ``x`` (subclass responsibility)."""
        raise NotImplementedError

    def decode_tokens(self, tokens):
        """Decode ``tokens`` (subclass responsibility)."""
        raise NotImplementedError


class StableVAE(LoadPretrainedBase, AutoEncoderBase):
    """Waveform autoencoder built from an encoder, an optional bottleneck and a decoder.

    ``encode`` maps waveforms to latents plus a mask derived from the
    waveform lengths. ``decode`` maps latents back to waveforms. Weights can
    be loaded from a checkpoint at construction time.

    Args:
        encoder: Module applied to the waveform (e.g. ``OobleckEncoder``).
        decoder: Module applied to latents (e.g. ``OobleckDecoder``).
        latent_dim: Latent channel count. It is reported to
            ``AutoEncoderBase`` as ``latent_shape=(latent_dim, None)``.
        downsampling_ratio: Waveform samples per latent frame; used to derive
            latent lengths in ``encode``.
        sample_rate: Audio sample rate, forwarded to ``AutoEncoderBase``.
        io_channels: Default for both ``in_channels`` and ``out_channels``.
        bottleneck: Optional bottleneck applied after the encoder and before
            the decoder (e.g. ``VAEBottleneck``). If None, latents pass
            through unchanged.
        pretransform: Stored only; not used by ``encode``/``decode`` here.
        in_channels: Overrides ``io_channels`` on the input side if given.
        out_channels: Overrides ``io_channels`` on the output side if given.
        soft_clip: Stored only; not used by ``encode``/``decode`` here.
        pretrained_ckpt: If given, passed to ``load_pretrained`` at the end of
            construction.
    """

    def __init__(
        self,
        encoder,
        decoder,
        latent_dim,
        downsampling_ratio,
        sample_rate,
        io_channels=2,
        bottleneck: Bottleneck | None = None,
        pretransform: Pretransform | None = None,
        in_channels=None,
        out_channels=None,
        soft_clip=False,
        pretrained_ckpt: str | Path | None = None
    ):
        LoadPretrainedBase.__init__(self)
        AutoEncoderBase.__init__(
            self,
            downsampling_ratio=downsampling_ratio,
            sample_rate=sample_rate,
            latent_shape=(latent_dim, None)
        )

        self.latent_dim = latent_dim
        self.io_channels = io_channels
        self.in_channels = io_channels
        self.out_channels = io_channels
        # Presumably the shortest input yielding one latent frame; confirm
        # against users of `min_length`.
        self.min_length = self.downsampling_ratio

        if in_channels is not None:
            self.in_channels = in_channels

        if out_channels is not None:
            self.out_channels = out_channels

        self.bottleneck = bottleneck
        self.encoder = encoder
        self.decoder = decoder
        self.pretransform = pretransform
        self.soft_clip = soft_clip
        self.is_discrete = self.bottleneck is not None and self.bottleneck.is_discrete

        self.remove_autoencoder_prefix_fn: Callable = remove_key_prefix_factory(
            "autoencoder."
        )
        if pretrained_ckpt is not None:
            self.load_pretrained(pretrained_ckpt)

    def process_state_dict(self, model_dict, state_dict):
        """Extract this model's weights from a loaded checkpoint dict.

        Takes the nested ``"state_dict"`` entry and strips the
        ``"autoencoder."`` key prefix via the helper built by
        ``remove_key_prefix_factory``. This is presumably invoked by
        ``LoadPretrainedBase.load_pretrained``.
        """
        state_dict = state_dict["state_dict"]
        state_dict = self.remove_autoencoder_prefix_fn(model_dict, state_dict)
        return state_dict

    def encode(
        self, waveform: torch.Tensor, waveform_lengths: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Encode waveforms into latents.

        Args:
            waveform: Batch of waveforms fed to ``self.encoder``.
            waveform_lengths: Valid sample count per batch item.

        Returns:
            ``(z, z_mask)``. ``z_mask`` is built by ``create_mask_from_length``
            from ``waveform_lengths // downsampling_ratio``.
        """
        z = self.encoder(waveform)
        # The bottleneck is optional (default None); skip it when absent.
        if self.bottleneck is not None:
            z = self.bottleneck.encode(z)
        z_length = waveform_lengths // self.downsampling_ratio
        z_mask = create_mask_from_length(z_length)
        return z, z_mask

    def decode(self, latents: torch.Tensor) -> torch.Tensor:
        """Decode latents into waveforms.

        The bottleneck's ``decode`` is applied first when a bottleneck is
        set; for ``VAEBottleneck`` this is the identity.
        """
        if self.bottleneck is not None:
            latents = self.bottleneck.decode(latents)
        waveform = self.decoder(latents)
        return waveform


class StableVAEProjectorWrapper(nn.Module):
    """Encode waveforms with a frozen ``StableVAE`` and linearly project latents.

    Args:
        vae_dim: Input size of the projection. It must equal the latent
            channel count of ``model.encode`` output.
        embed_dim: Output size of the projection.
        model: VAE used for encoding. It must be set before ``forward`` runs.
    """

    def __init__(
        self,
        vae_dim: int,
        embed_dim: int,
        model: StableVAE | None = None,
    ):
        super().__init__()
        self.model = model
        self.proj = nn.Linear(vae_dim, embed_dim)

    def forward(
        self, waveform: torch.Tensor, waveform_lengths: torch.Tensor
    ) -> dict[str, torch.Tensor]:
        """Return ``{"output": projected latents, "mask": latent mask}``.

        Raises:
            RuntimeError: If no ``model`` has been set.
        """
        if self.model is None:
            raise RuntimeError(
                "StableVAEProjectorWrapper.model must be set before calling forward"
            )
        # Keep the VAE in eval mode even if this wrapper is in training mode.
        # Encoding runs under no_grad, so only `proj` receives gradients.
        self.model.eval()
        with torch.no_grad():
            z, z_mask = self.model.encode(waveform, waveform_lengths)
        # Assumes z is (B, C, T); transpose so nn.Linear acts on channels.
        z = self.proj(z.transpose(1, 2))
        return {"output": z, "mask": z_mask}


if __name__ == '__main__':
    # Round-trip smoke test: build the autoencoder from its Hydra config,
    # encode and decode an example file (path from $EXAMPLE_WAV, default
    # ./example.wav), then write the result to ./reconstructed.wav.
    import os

    import hydra
    from utils.config import generate_config_from_command_line_overrides
    model_config = generate_config_from_command_line_overrides(
        "configs/model/autoencoder/stable_vae.yaml"
    )
    autoencoder: StableVAE = hydra.utils.instantiate(model_config)
    autoencoder.eval()

    example_wav = os.environ.get("EXAMPLE_WAV", "./example.wav")
    waveform, sr = torchaudio.load(example_wav)
    # Down-mix to mono and add an explicit batch axis, giving (1, 1, T).
    # Without a batch axis, Conv1d treats the 2-D input as unbatched, and the
    # bottleneck's chunk along dim 1 could then split time instead of
    # channels.
    waveform = waveform.mean(0, keepdim=True).unsqueeze(0)
    waveform = torchaudio.functional.resample(
        waveform, sr, model_config["sample_rate"]
    )
    print("waveform: ", waveform.shape)
    with torch.no_grad():
        latent, latent_mask = autoencoder.encode(
            waveform, torch.as_tensor([waveform.shape[-1]])
        )
        print("latent: ", latent.shape)
        reconstructed = autoencoder.decode(latent)
        print("reconstructed: ", reconstructed.shape)
    import soundfile as sf
    sf.write(
        "./reconstructed.wav",
        reconstructed[0, 0].numpy(),
        samplerate=model_config["sample_rate"]
    )
