import torch
from torch import nn
from torch.nn import functional as F

from einops import rearrange
# from vector_quantize_pytorch import ResidualVQ, FSQ
from dac.nn.quantize import ResidualVectorQuantize as DACResidualVQ


class Bottleneck(nn.Module):
    """Common base for the latent bottleneck layers in this module.

    The base class only stores the ``is_discrete`` flag; ``encode`` and
    ``decode`` are placeholders that concrete subclasses override.
    """

    def __init__(self, is_discrete: bool):
        """Initialise the module and remember whether it is discrete.

        Args:
            is_discrete: Stored unchanged on ``self.is_discrete``.
        """
        super().__init__()
        self.is_discrete = is_discrete

    def encode(self, x, return_info: bool, **kwargs):
        """Placeholder; subclasses provide the implementation.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError

    def decode(self, x):
        """Placeholder; subclasses provide the implementation.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError


class DiscreteBottleneck(Bottleneck):
    """Base for bottlenecks that produce discrete tokens.

    Always initialises the parent with ``is_discrete=True`` and records the
    quantizer count, codebook size and token identifier passed in.
    """

    def __init__(self, num_quantizers, codebook_size, tokens_id):
        """Store the discrete-bottleneck configuration.

        Args:
            num_quantizers: Stored on ``self.num_quantizers``.
            codebook_size: Stored on ``self.codebook_size``.
            tokens_id: Stored on ``self.tokens_id``.
        """
        super().__init__(is_discrete=True)
        self.tokens_id = tokens_id
        self.codebook_size = codebook_size
        self.num_quantizers = num_quantizers

    def decode_tokens(self, codes, **kwargs):
        """Placeholder; subclasses turn ``codes`` back into latents.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError


class TanhBottleneck(Bottleneck):
    """Continuous bottleneck that squashes latents with ``tanh`` on encode."""

    def __init__(self):
        super().__init__(is_discrete=False)
        # Registered as a submodule, although encode() calls torch.tanh
        # directly rather than going through this layer.
        self.tanh = nn.Tanh()

    def encode(self, x, return_info=False):
        """Apply ``tanh`` element-wise to ``x``.

        Args:
            x: Input tensor.
            return_info: When true, also return an (empty) info dict.

        Returns:
            The squashed tensor, or ``(tensor, {})`` if ``return_info``.
        """
        squashed = torch.tanh(x)
        if return_info:
            return squashed, {}
        return squashed

    def decode(self, x):
        """Return ``x`` unchanged."""
        return x


def vae_sample(mean, scale, min_logvar=-30.0, max_logvar=20.0):
    """Draw a reparameterised Gaussian sample and compute its KL term.

    ``scale`` is treated as a log-variance: the standard deviation is
    ``exp(0.5 * scale)`` and the variance is ``exp(scale)``. (An earlier
    softplus-based parameterisation of ``scale`` is no longer used.)

    The log-variance is clamped to ``[min_logvar, max_logvar]`` before it is
    exponentiated. Without the clamp a large ``scale`` makes ``exp(scale)``
    overflow to ``inf``, which turns the KL term non-finite. Inside the clamp
    range the result is identical to the unclamped computation.

    Args:
        mean: Tensor of Gaussian means.
        scale: Tensor of log-variances, broadcastable against ``mean``.
        min_logvar: Lower clamp bound applied to ``scale``.
        max_logvar: Upper clamp bound applied to ``scale``.

    Returns:
        A ``(latents, kl)`` tuple. ``latents`` is ``mean + stdev * noise``
        with ``noise`` drawn by ``torch.randn_like(mean)``. ``kl`` is
        ``0.5 * (mean^2 + var - logvar - 1)`` summed over dim 1 and then
        averaged over the remaining dimensions.
    """
    logvar = torch.clamp(scale, min_logvar, max_logvar)
    stdev = torch.exp(0.5 * logvar)
    var = torch.exp(logvar)

    latents = torch.randn_like(mean) * stdev + mean
    kl = 0.5 * (mean * mean + var - logvar - 1).sum(1).mean()
    return latents, kl


class VAEBottleneck(Bottleneck):
    """Continuous bottleneck that samples latents from a diagonal Gaussian."""

    def __init__(self):
        super().__init__(is_discrete=False)

    def encode(self, x, return_info=False):
        """Split ``x`` into mean / log-variance halves and sample.

        ``x`` is chunked in two along dim 1; the first half is the mean and
        the second half the log-variance, which is clamped to ``[-30, 20]``
        before being handed to ``vae_sample``.

        Args:
            x: Input tensor whose dim 1 holds mean and log-variance halves.
            return_info: When true, also return ``{"kl": kl}``.

        Returns:
            The sampled latents, or ``(latents, {"kl": kl})``.
        """
        mean, logvar = torch.chunk(x, 2, dim=1)
        # Bound the log-variance so the exponentials in vae_sample stay finite.
        logvar = logvar.clamp(-30.0, 20.0)
        sampled, kl = vae_sample(mean, logvar)

        if return_info:
            return sampled, {"kl": kl}
        return sampled

    def decode(self, x):
        """Return ``x`` unchanged."""
        return x


def compute_mean_kernel(x, y):
    """Return the mean of an exponential kernel over all row pairs of x and y.

    For every pair ``(x[i], y[j])`` the squared difference is averaged over
    dim 2 and then divided by ``x.shape[-1]``; the kernel value is the
    exponential of the negated result. The mean over all pairs is returned.

    Args:
        x: Tensor indexed as ``x[i]`` per sample (broadcast as ``x[:, None]``).
        y: Tensor indexed as ``y[j]`` per sample (broadcast as ``y[None]``).

    Returns:
        A scalar tensor.
    """
    pairwise_diff = x.unsqueeze(1) - y.unsqueeze(0)
    scaled_sq_dist = pairwise_diff.pow(2).mean(2) / x.shape[-1]
    kernel_values = torch.exp(-scaled_sq_dist)
    return kernel_values.mean()


def compute_mmd(latents):
    """Estimate an MMD between ``latents`` and standard-normal noise.

    ``latents`` must be 3-D: dims 1 and 2 are swapped and the result is
    flattened to rows of width ``latents.shape[1]``. A noise tensor of the
    same shape is drawn with ``torch.randn_like`` and the three mean-kernel
    terms are combined as ``k(l, l) + k(n, n) - 2 * k(l, n)``.

    Args:
        latents: 3-D tensor; dim 1 is used as the feature dimension.

    Returns:
        A scalar tensor.
    """
    feature_dim = latents.shape[1]
    flat_latents = latents.permute(0, 2, 1).reshape(-1, feature_dim)
    reference = torch.randn_like(flat_latents)

    k_latents = compute_mean_kernel(flat_latents, flat_latents)
    k_reference = compute_mean_kernel(reference, reference)
    k_cross = compute_mean_kernel(flat_latents, reference)

    discrepancy = k_latents + k_reference - 2 * k_cross
    return discrepancy.mean()


class WassersteinBottleneck(Bottleneck):
    """Continuous bottleneck that reports an MMD term and can append noise.

    ``encode`` passes latents through unchanged and, in training mode, adds
    an ``"mmd"`` entry to the info dict. ``decode`` optionally concatenates
    ``noise_augment_dim`` channels of Gaussian noise along dim 1.
    """

    def __init__(self, noise_augment_dim: int = 0):
        """
        Args:
            noise_augment_dim: Number of noise channels appended in
                ``decode``; no noise is added when this is not positive.
        """
        super().__init__(is_discrete=False)
        self.noise_augment_dim = noise_augment_dim

    def encode(self, x, return_info=False):
        """Return ``x`` unchanged, optionally with an info dict.

        The MMD is only computed when the module is in training mode *and*
        ``return_info`` is true.

        Returns:
            ``x``, or ``(x, info)`` where ``info`` holds ``"mmd"`` in
            training mode and is empty otherwise.
        """
        if not return_info:
            return x

        info = {}
        if self.training:
            info["mmd"] = compute_mmd(x)
        return x, info

    def decode(self, x):
        """Append noise channels along dim 1 when noise augmentation is on.

        Returns:
            ``x`` itself, or ``x`` concatenated with a noise tensor of shape
            ``(x.shape[0], noise_augment_dim, x.shape[-1])``.
        """
        if self.noise_augment_dim <= 0:
            return x

        batch, length = x.shape[0], x.shape[-1]
        # Drawn with default dtype/device, then converted to match x.
        noise = torch.randn(batch, self.noise_augment_dim, length).type_as(x)
        return torch.cat([x, noise], dim=1)


class L2Bottleneck(Bottleneck):
    """Continuous bottleneck that normalises latents along dim 1."""

    def __init__(self):
        super().__init__(is_discrete=False)

    def encode(self, x, return_info=False):
        """Normalise ``x`` along dim 1 with ``F.normalize``.

        Returns:
            The normalised tensor, or ``(tensor, {})`` if ``return_info``.
        """
        normalised = F.normalize(x, dim=1)
        if return_info:
            return normalised, {}
        return normalised

    def decode(self, x):
        """Normalise ``x`` along dim 1 again and return it."""
        return F.normalize(x, dim=1)


class RVQBottleneck(DiscreteBottleneck):
    """Discrete bottleneck backed by a ``ResidualVQ`` quantizer.

    Tokens are reported in the info dict under ``"quantizer_indices"``.
    """

    def __init__(self, **quantizer_kwargs):
        """Build the quantizer from ``quantizer_kwargs``.

        Args:
            **quantizer_kwargs: Forwarded unchanged to ``ResidualVQ``. Must
                contain ``"num_quantizers"`` and ``"codebook_size"``.

        Raises:
            KeyError: If either required key is missing.
        """
        # The parent already stores num_quantizers / codebook_size / tokens_id.
        super().__init__(
            num_quantizers=quantizer_kwargs["num_quantizers"],
            codebook_size=quantizer_kwargs["codebook_size"],
            tokens_id="quantizer_indices",
        )
        # NOTE(review): the import that provides ResidualVQ is commented out
        # in this file's import block; unless the name is supplied some other
        # way, this line raises NameError. Confirm the dependency.
        self.quantizer = ResidualVQ(**quantizer_kwargs)

    def encode(self, x, return_info=False, **kwargs):
        """Quantize ``x`` and optionally return indices and loss.

        ``x`` is rearranged from ``b c n`` to ``b n c`` for the quantizer and
        back again afterwards.

        Args:
            x: Tensor laid out as ``b c n``.
            return_info: When true, also return the info dict.
            **kwargs: Accepted and ignored.

        Returns:
            The quantized tensor (``b c n``), or ``(tensor, info)`` where
            ``info`` holds ``"quantizer_indices"`` and ``"quantizer_loss"``
            (the mean of the quantizer's loss output).
        """
        x = rearrange(x, "b c n -> b n c")
        x, indices, loss = self.quantizer(x)
        x = rearrange(x, "b n c -> b c n")

        if not return_info:
            return x

        info = {"quantizer_indices": indices, "quantizer_loss": loss.mean()}
        return x, info

    def decode(self, x, **kwargs):
        """Return ``x`` unchanged.

        Extra keyword arguments are accepted and ignored so that
        ``decode_tokens`` can forward its ``**kwargs`` here without raising
        ``TypeError``.
        """
        return x

    def decode_tokens(self, codes, **kwargs):
        """Look up latents for ``codes`` and pass them through ``decode``.

        NOTE(review): the quantizer output is used as returned, without the
        ``b n c -> b c n`` rearrange that ``encode`` applies; verify the
        layout callers expect.
        """
        latents = self.quantizer.get_outputs_from_indices(codes)
        return self.decode(latents, **kwargs)


class RVQVAEBottleneck(DiscreteBottleneck):
    """VAE sampling followed by ``ResidualVQ`` quantization.

    Tokens are reported in the info dict under ``"quantizer_indices"``.
    """

    def __init__(self, **quantizer_kwargs):
        """Build the quantizer from ``quantizer_kwargs``.

        Args:
            **quantizer_kwargs: Forwarded unchanged to ``ResidualVQ``. Must
                contain ``"num_quantizers"`` and ``"codebook_size"``.

        Raises:
            KeyError: If either required key is missing.
        """
        # The parent already stores num_quantizers / codebook_size / tokens_id.
        super().__init__(
            num_quantizers=quantizer_kwargs["num_quantizers"],
            codebook_size=quantizer_kwargs["codebook_size"],
            tokens_id="quantizer_indices",
        )
        # NOTE(review): the import that provides ResidualVQ is commented out
        # in this file's import block; unless the name is supplied some other
        # way, this line raises NameError. Confirm the dependency.
        self.quantizer = ResidualVQ(**quantizer_kwargs)

    def encode(self, x, return_info=False):
        """Sample latents from ``x`` and quantize them.

        ``x`` is chunked in two along dim 1 into mean and log-variance. The
        log-variance is clamped to ``[-30, 20]`` (the same range
        ``VAEBottleneck`` uses) so the exponentials in ``vae_sample`` cannot
        overflow. The sample is then rearranged to ``b n c``, quantized and
        rearranged back to ``b c n``.

        Args:
            x: Tensor laid out as ``b c n`` whose dim 1 holds mean and
                log-variance halves.
            return_info: When true, also return the info dict.

        Returns:
            The quantized tensor, or ``(tensor, info)`` where ``info`` holds
            ``"kl"``, ``"quantizer_indices"`` and ``"quantizer_loss"``.
        """
        mean, logvar = x.chunk(2, dim=1)
        logvar = torch.clamp(logvar, -30.0, 20.0)
        x, kl = vae_sample(mean, logvar)

        x = rearrange(x, "b c n -> b n c")
        x, indices, loss = self.quantizer(x)
        x = rearrange(x, "b n c -> b c n")

        if not return_info:
            return x

        info = {"kl": kl, "quantizer_indices": indices, "quantizer_loss": loss.mean()}
        return x, info

    def decode(self, x, **kwargs):
        """Return ``x`` unchanged.

        Extra keyword arguments are accepted and ignored so that
        ``decode_tokens`` can forward its ``**kwargs`` here without raising
        ``TypeError``.
        """
        return x

    def decode_tokens(self, codes, **kwargs):
        """Look up latents for ``codes`` and pass them through ``decode``.

        NOTE(review): the quantizer output is used as returned, without the
        ``b n c -> b c n`` rearrange that ``encode`` applies; verify the
        layout callers expect.
        """
        latents = self.quantizer.get_outputs_from_indices(codes)
        return self.decode(latents, **kwargs)


class DACRVQBottleneck(DiscreteBottleneck):
    """Discrete bottleneck backed by the DAC residual vector quantizer.

    Tokens are reported in the info dict under ``"codes"``. With
    ``quantize_on_decode`` the quantizer is skipped in ``encode`` and applied
    in ``decode`` instead.
    """

    def __init__(self, quantize_on_decode=False, **quantizer_kwargs):
        """Build the quantizer from ``quantizer_kwargs``.

        Args:
            quantize_on_decode: If true, quantize in ``decode`` rather than
                in ``encode``.
            **quantizer_kwargs: Forwarded unchanged to ``DACResidualVQ``.
                Must contain ``"n_codebooks"`` and ``"codebook_size"``.

        Raises:
            KeyError: If either required key is missing.
        """
        # The parent already stores num_quantizers / codebook_size / tokens_id.
        super().__init__(
            num_quantizers=quantizer_kwargs["n_codebooks"],
            codebook_size=quantizer_kwargs["codebook_size"],
            tokens_id="codes",
        )
        self.quantizer = DACResidualVQ(**quantizer_kwargs)
        self.quantize_on_decode = quantize_on_decode

    def encode(self, x, return_info=False, **kwargs):
        """Quantize ``x`` unless quantization is deferred to ``decode``.

        Args:
            x: Pre-quantization latents.
            return_info: When true, also return the info dict.
            **kwargs: Forwarded to the quantizer call.

        Returns:
            With ``quantize_on_decode``: ``x`` unchanged (info holds only
            ``"pre_quantizer"``). Otherwise the quantized tensor ``z``; info
            additionally holds ``"z"``, ``"codes"``, ``"latents"`` and the
            two losses, each divided by ``num_quantizers``.
        """
        info = {"pre_quantizer": x}

        if self.quantize_on_decode:
            return (x, info) if return_info else x

        z, codes, latents, commitment_loss, codebook_loss = self.quantizer(x, **kwargs)

        # Out-of-place division so the tensors returned by the quantizer are
        # not mutated.
        info.update({
            "z": z,
            "codes": codes,
            "latents": latents,
            "vq/commitment_loss": commitment_loss / self.num_quantizers,
            "vq/codebook_loss": codebook_loss / self.num_quantizers,
        })
        return (z, info) if return_info else z

    def decode(self, x, **kwargs):
        """Quantize ``x`` if ``quantize_on_decode`` is set, else pass through.

        Extra keyword arguments are accepted and ignored so that
        ``decode_tokens`` can forward its ``**kwargs`` here without raising
        ``TypeError``.
        """
        if self.quantize_on_decode:
            # First element of the quantizer's output is the quantized tensor.
            x = self.quantizer(x)[0]

        return x

    def decode_tokens(self, codes, **kwargs):
        """Rebuild latents from ``codes`` and pass them through ``decode``."""
        latents, _, _ = self.quantizer.from_codes(codes)
        return self.decode(latents, **kwargs)


class DACRVQVAEBottleneck(DiscreteBottleneck):
    """VAE sampling followed by the DAC residual vector quantizer.

    Tokens are reported in the info dict under ``"codes"``. With
    ``quantize_on_decode`` the quantizer is skipped in ``encode`` and applied
    in ``decode`` instead.
    """

    def __init__(self, quantize_on_decode=False, **quantizer_kwargs):
        """Build the quantizer from ``quantizer_kwargs``.

        Args:
            quantize_on_decode: If true, quantize in ``decode`` rather than
                in ``encode``.
            **quantizer_kwargs: Forwarded unchanged to ``DACResidualVQ``.
                Must contain ``"n_codebooks"`` and ``"codebook_size"``.

        Raises:
            KeyError: If either required key is missing.
        """
        # The parent already stores num_quantizers / codebook_size / tokens_id.
        super().__init__(
            num_quantizers=quantizer_kwargs["n_codebooks"],
            codebook_size=quantizer_kwargs["codebook_size"],
            tokens_id="codes",
        )
        self.quantizer = DACResidualVQ(**quantizer_kwargs)
        self.quantize_on_decode = quantize_on_decode

    def encode(self, x, return_info=False, n_quantizers: int = None):
        """Sample latents from ``x`` and, unless deferred, quantize them.

        ``x`` is chunked in two along dim 1 into mean and log-variance. The
        log-variance is clamped to ``[-30, 20]`` (the same range
        ``VAEBottleneck`` uses) so the exponentials in ``vae_sample`` cannot
        overflow.

        Args:
            x: Tensor whose dim 1 holds mean and log-variance halves.
            return_info: When true, also return the info dict.
            n_quantizers: Forwarded to the quantizer call.

        Returns:
            With ``quantize_on_decode``: the sampled latents (info holds
            ``"pre_quantizer"`` and ``"kl"``). Otherwise the quantized tensor
            ``z``; info additionally holds ``"z"``, ``"codes"``,
            ``"latents"`` and the two losses, each divided by
            ``num_quantizers``.
        """
        mean, logvar = x.chunk(2, dim=1)
        logvar = torch.clamp(logvar, -30.0, 20.0)
        x, kl = vae_sample(mean, logvar)

        info = {"pre_quantizer": x, "kl": kl}

        if self.quantize_on_decode:
            return (x, info) if return_info else x

        z, codes, latents, commitment_loss, codebook_loss = self.quantizer(
            x, n_quantizers=n_quantizers
        )

        # Out-of-place division so the tensors returned by the quantizer are
        # not mutated.
        info.update({
            "z": z,
            "codes": codes,
            "latents": latents,
            "vq/commitment_loss": commitment_loss / self.num_quantizers,
            "vq/codebook_loss": codebook_loss / self.num_quantizers,
        })
        return (z, info) if return_info else z

    def decode(self, x, **kwargs):
        """Quantize ``x`` if ``quantize_on_decode`` is set, else pass through.

        Extra keyword arguments are accepted and ignored so that
        ``decode_tokens`` can forward its ``**kwargs`` here without raising
        ``TypeError``.
        """
        if self.quantize_on_decode:
            # First element of the quantizer's output is the quantized tensor.
            x = self.quantizer(x)[0]

        return x

    def decode_tokens(self, codes, **kwargs):
        """Rebuild latents from ``codes`` and pass them through ``decode``."""
        latents, _, _ = self.quantizer.from_codes(codes)
        return self.decode(latents, **kwargs)


class FSQBottleneck(DiscreteBottleneck):
    """Discrete bottleneck backed by an ``FSQ`` quantizer.

    Uses a single quantizer with ``levels ** dim`` codes; tokens are reported
    in the info dict under ``"quantizer_indices"``.
    """

    def __init__(self, dim, levels):
        """Build an FSQ quantizer with ``dim`` dimensions of ``levels`` each.

        Args:
            dim: Number of quantized dimensions.
            levels: Number of levels used for every dimension.
        """
        super().__init__(
            num_quantizers=1,
            codebook_size=levels ** dim,
            tokens_id="quantizer_indices",
        )
        # NOTE(review): the import that provides FSQ is commented out in this
        # file's import block; unless the name is supplied some other way,
        # this line raises NameError. Confirm the dependency.
        self.quantizer = FSQ(levels=[levels] * dim)

    def encode(self, x, return_info=False):
        """Quantize ``x`` and optionally return the indices.

        ``x`` is rearranged from ``b c n`` to ``b n c`` for the quantizer and
        back again afterwards.

        Returns:
            The quantized tensor (``b c n``), or ``(tensor, info)`` where
            ``info`` holds ``"quantizer_indices"``.
        """
        x = rearrange(x, "b c n -> b n c")
        x, indices = self.quantizer(x)
        x = rearrange(x, "b n c -> b c n")

        if not return_info:
            return x
        return x, {"quantizer_indices": indices}

    def decode(self, x, **kwargs):
        """Return ``x`` unchanged.

        Extra keyword arguments are accepted and ignored so that
        ``decode_tokens`` can forward its ``**kwargs`` here without raising
        ``TypeError``.
        """
        return x

    def decode_tokens(self, tokens, **kwargs):
        """Convert ``tokens`` to codes and pass them through ``decode``.

        NOTE(review): the quantizer output is used as returned, without the
        ``b n c -> b c n`` rearrange that ``encode`` applies; verify the
        layout callers expect.
        """
        latents = self.quantizer.indices_to_codes(tokens)
        return self.decode(latents, **kwargs)
