import numpy as np
import torch
from torch import nn
from torch.nn import functional as F

from einops import rearrange
from vector_quantize_pytorch import ResidualVQ, FSQ
from dac.nn.quantize import ResidualVectorQuantize as DACResidualVQ

class Bottleneck(nn.Module):
    """Abstract base class for bottleneck modules.

    The base class only records whether the bottleneck is discrete; both
    ``encode`` and ``decode`` raise ``NotImplementedError`` and must be
    overridden by subclasses.
    """

    def __init__(self, is_discrete: bool = False):
        """Store the ``is_discrete`` flag on the module."""
        super().__init__()
        self.is_discrete = is_discrete

    def encode(self, x, return_info=False, **kwargs):
        """Abstract hook: always raises ``NotImplementedError`` here."""
        raise NotImplementedError

    def decode(self, x):
        """Abstract hook: always raises ``NotImplementedError`` here."""
        raise NotImplementedError

class DiscreteBottleneck(Bottleneck):
    """Base class for bottlenecks flagged as discrete (``is_discrete=True``).

    Stores the quantizer count, the codebook size and the name of the info-dict
    entry under which the subclass reports its tokens. ``decode_tokens`` is an
    abstract hook that raises ``NotImplementedError`` here.
    """

    def __init__(self, num_quantizers, codebook_size, tokens_id):
        super().__init__(is_discrete=True)

        # Plain bookkeeping attributes; nothing in this class reads them back.
        self.tokens_id = tokens_id
        self.codebook_size = codebook_size
        self.num_quantizers = num_quantizers

    def decode_tokens(self, codes, **kwargs):
        """Abstract hook: always raises ``NotImplementedError`` here."""
        raise NotImplementedError
    
class TanhBottleneck(Bottleneck):
    """Continuous bottleneck that applies ``tanh`` to the latents on encode."""

    def __init__(self):
        super().__init__(is_discrete=False)
        # Registered as a submodule, although encode() calls torch.tanh directly.
        self.tanh = nn.Tanh()

    def encode(self, x, return_info=False):
        """Return ``tanh(x)``; with ``return_info`` also return an empty dict."""
        squashed = torch.tanh(x)

        if not return_info:
            return squashed

        return squashed, {}

    def decode(self, x):
        """Return ``x`` unchanged."""
        return x

def vae_sample(mean, scale):
    """Draw a reparameterised Gaussian sample and compute a KL-style penalty.

    The standard deviation is ``softplus(scale) + 1e-4`` and the sample is
    ``mean + stdev * eps`` with ``eps`` drawn via ``torch.randn_like(mean)``.

    Returns:
        A ``(latents, kl)`` pair, where ``kl`` is
        ``mean**2 + var - log(var) - 1`` summed over dim 1 and then averaged
        over the remaining dimensions.
    """
    std = F.softplus(scale) + 1e-4
    variance = std * std

    sample = torch.randn_like(mean) * std + mean

    kl_terms = mean * mean + variance - torch.log(variance) - 1
    return sample, kl_terms.sum(1).mean()

class VAEBottleneck(Bottleneck):
    """Continuous bottleneck that samples latents through ``vae_sample``."""

    def __init__(self):
        super().__init__(is_discrete=False)

    def encode(self, x, return_info=False, **kwargs):
        """Split ``x`` in two along dim 1 and sample from the result.

        The first half is used as the mean and the second as the scale passed
        to ``vae_sample``. Extra keyword arguments are accepted and ignored.

        Returns:
            The sampled latents, or ``(latents, info)`` when ``return_info`` is
            true, where ``info["kl"]`` holds the KL term from ``vae_sample``.
        """
        mean, scale = x.chunk(2, dim=1)
        sampled, kl = vae_sample(mean, scale)

        if not return_info:
            return sampled

        return sampled, {"kl": kl}

    def decode(self, x):
        """Return ``x`` unchanged."""
        return x

def compute_mean_kernel(x, y):
    """Return the mean of an exponential kernel over all row pairs of x and y.

    Rows of ``x`` are broadcast against rows of ``y``; for each pair the squared
    difference is averaged over dim 2, divided once more by ``x.shape[-1]``,
    and passed through ``exp(-.)``. The result is the mean over all pairs.
    """
    pairwise_diff = x.unsqueeze(1) - y.unsqueeze(0)
    scaled_sq_dist = pairwise_diff.pow(2).mean(2) / x.shape[-1]
    return (-scaled_sq_dist).exp().mean()

def compute_mmd(latents):
    """Return an MMD-style discrepancy between ``latents`` and Gaussian noise.

    ``latents`` is permuted with ``(0, 2, 1)`` and flattened so that dim 1 of
    the input becomes the last dimension of a 2-D tensor. A same-shaped
    ``torch.randn_like`` sample serves as the reference, and the result is
    ``k(latents, latents) + k(noise, noise) - 2 * k(latents, noise)`` with
    ``k`` being ``compute_mean_kernel``.
    """
    channels = latents.shape[1]
    flat = latents.permute(0, 2, 1).reshape(-1, channels)
    reference = torch.randn_like(flat)

    k_latents = compute_mean_kernel(flat, flat)
    k_reference = compute_mean_kernel(reference, reference)
    k_cross = compute_mean_kernel(flat, reference)

    return (k_latents + k_reference - 2 * k_cross).mean()

class WassersteinBottleneck(Bottleneck):
    """Pass-through bottleneck that reports an MMD penalty while training.

    Args:
        noise_augment_dim: Number of Gaussian-noise channels that ``decode``
            concatenates to its input along dim 1 (``0`` disables this).
        bypass_mmd: When true, ``encode`` reports a constant zero instead of
            calling ``compute_mmd``.
    """

    def __init__(self, noise_augment_dim: int = 0, bypass_mmd: bool = False):
        super().__init__(is_discrete=False)

        self.noise_augment_dim = noise_augment_dim
        self.bypass_mmd = bypass_mmd

    def encode(self, x, return_info=False):
        """Return ``x`` unchanged, optionally with an info dict.

        ``info["mmd"]`` is only populated when the module is in training mode
        and ``return_info`` is true; otherwise the info dict is empty.
        """
        info = {}

        if self.training and return_info:
            if self.bypass_mmd:
                # Keep the placeholder on the same device as the input so it
                # matches the tensor compute_mmd(x) would have produced.
                mmd = torch.tensor(0.0, device=x.device)
            else:
                mmd = compute_mmd(x)

            info["mmd"] = mmd

        if return_info:
            return x, info

        return x

    def decode(self, x):
        """Return ``x``, with noise channels appended if configured."""
        if self.noise_augment_dim > 0:
            # Noise has shape (x.shape[0], noise_augment_dim, x.shape[-1]) and is
            # cast to x's type before concatenation along dim 1.
            noise = torch.randn(x.shape[0], self.noise_augment_dim,
                                x.shape[-1]).type_as(x)
            x = torch.cat([x, noise], dim=1)

        return x

class L2Bottleneck(Bottleneck):
    """Continuous bottleneck that normalises along dim 1 with ``F.normalize``."""

    def __init__(self):
        super().__init__(is_discrete=False)

    def encode(self, x, return_info=False):
        """Return ``F.normalize(x, dim=1)``; with ``return_info`` add an empty dict."""
        normalized = F.normalize(x, dim=1)

        if not return_info:
            return normalized

        return normalized, {}

    def decode(self, x):
        """Normalise ``x`` along dim 1 again and return it."""
        return F.normalize(x, dim=1)
        
class RVQBottleneck(DiscreteBottleneck):
    """Discrete bottleneck wrapping a ``ResidualVQ`` quantizer.

    All keyword arguments are forwarded to ``ResidualVQ``; ``num_quantizers``
    and ``codebook_size`` are required because they are also recorded on the
    base class. Tokens are reported under ``"quantizer_indices"``.
    """

    def __init__(self, **quantizer_kwargs):
        n_quantizers = quantizer_kwargs["num_quantizers"]
        super().__init__(
            num_quantizers=n_quantizers,
            codebook_size=quantizer_kwargs["codebook_size"],
            tokens_id="quantizer_indices",
        )
        self.quantizer = ResidualVQ(**quantizer_kwargs)
        self.num_quantizers = n_quantizers

    def encode(self, x, return_info=False, **kwargs):
        """Quantize ``x`` and optionally return the indices and loss.

        ``x`` is rearranged from ``"b c n"`` to ``"b n c"`` for the quantizer
        and back afterwards. Extra keyword arguments are accepted and ignored.

        Returns:
            The quantized tensor, or ``(quantized, info)`` when ``return_info``
            is true, with ``info`` holding ``"quantizer_indices"`` and the
            mean ``"quantizer_loss"``.
        """
        quantized, indices, loss = self.quantizer(rearrange(x, "b c n -> b n c"))
        quantized = rearrange(quantized, "b n c -> b c n")

        info = {
            "quantizer_indices": indices,
            "quantizer_loss": loss.mean(),
        }

        return (quantized, info) if return_info else quantized

    def decode(self, x):
        """Return ``x`` unchanged."""
        return x

    def decode_tokens(self, codes, **kwargs):
        """Look up latents for ``codes`` and pass them through ``decode``.

        NOTE(review): ``decode`` takes no extra arguments, so any non-empty
        ``kwargs`` forwarded here raises ``TypeError``.
        """
        latents = self.quantizer.get_outputs_from_indices(codes)
        return self.decode(latents, **kwargs)
    
class RVQVAEBottleneck(DiscreteBottleneck):
    """Discrete bottleneck: ``vae_sample`` followed by a ``ResidualVQ`` quantizer.

    All keyword arguments are forwarded to ``ResidualVQ``; ``num_quantizers``
    and ``codebook_size`` are required because they are also recorded on the
    base class. Tokens are reported under ``"quantizer_indices"``.
    """

    def __init__(self, **quantizer_kwargs):
        n_quantizers = quantizer_kwargs["num_quantizers"]
        super().__init__(
            num_quantizers=n_quantizers,
            codebook_size=quantizer_kwargs["codebook_size"],
            tokens_id="quantizer_indices",
        )
        self.quantizer = ResidualVQ(**quantizer_kwargs)
        self.num_quantizers = n_quantizers

    def encode(self, x, return_info=False):
        """Sample from ``x`` and quantize the sample.

        ``x`` is split in two along dim 1 and both halves are handed to
        ``vae_sample``; the sample is rearranged from ``"b c n"`` to
        ``"b n c"`` for the quantizer and back afterwards.

        Returns:
            The quantized tensor, or ``(quantized, info)`` when ``return_info``
            is true, with ``info`` holding ``"kl"``, ``"quantizer_indices"``
            and the mean ``"quantizer_loss"``.
        """
        sampled, kl = vae_sample(*x.chunk(2, dim=1))
        info = {"kl": kl}

        quantized, indices, loss = self.quantizer(rearrange(sampled, "b c n -> b n c"))
        quantized = rearrange(quantized, "b n c -> b c n")

        info["quantizer_indices"] = indices
        info["quantizer_loss"] = loss.mean()

        return (quantized, info) if return_info else quantized

    def decode(self, x):
        """Return ``x`` unchanged."""
        return x

    def decode_tokens(self, codes, **kwargs):
        """Look up latents for ``codes`` and pass them through ``decode``.

        NOTE(review): ``decode`` takes no extra arguments, so any non-empty
        ``kwargs`` forwarded here raises ``TypeError``.
        """
        latents = self.quantizer.get_outputs_from_indices(codes)
        return self.decode(latents, **kwargs)

class DACRVQBottleneck(DiscreteBottleneck):
    """Discrete bottleneck wrapping the DAC residual vector quantizer.

    Args:
        quantize_on_decode: When true, ``encode`` returns its input untouched
            and quantization is applied inside ``decode`` instead.
        noise_augment_dim: Number of Gaussian-noise channels that ``decode``
            concatenates along dim 1 (``0`` disables this).
        **quantizer_kwargs: Forwarded to ``DACResidualVQ``; ``n_codebooks`` and
            ``codebook_size`` are required because they are also recorded on
            the base class. Tokens are reported under ``"codes"``.
    """

    def __init__(self, quantize_on_decode=False, noise_augment_dim=0, **quantizer_kwargs):
        super().__init__(
            num_quantizers=quantizer_kwargs["n_codebooks"],
            codebook_size=quantizer_kwargs["codebook_size"],
            tokens_id="codes",
        )
        self.quantizer = DACResidualVQ(**quantizer_kwargs)
        self.num_quantizers = quantizer_kwargs["n_codebooks"]
        self.quantize_on_decode = quantize_on_decode
        self.noise_augment_dim = noise_augment_dim

    def encode(self, x, return_info=False, **kwargs):
        """Quantize ``x`` (unless deferred to decode) and optionally return info.

        Extra keyword arguments are forwarded to the quantizer call.

        Returns:
            The output tensor, or ``(output, info)`` when ``return_info`` is
            true. ``info`` always contains ``"pre_quantizer"``; when
            quantization runs here it also contains ``"z"``, ``"codes"``,
            ``"latents"`` and the two VQ losses divided by ``num_quantizers``.
        """
        info = {"pre_quantizer": x}

        if self.quantize_on_decode:
            # Parentheses are required: without them the conditional binds only
            # to the second tuple element and (x, x) is returned when
            # return_info is false.
            return (x, info) if return_info else x

        z, codes, latents, commitment_loss, codebook_loss = self.quantizer(x, **kwargs)

        info.update({
            "z": z,
            "codes": codes,
            "latents": latents,
            # Out-of-place division leaves the quantizer's own tensors untouched.
            "vq/commitment_loss": commitment_loss / self.num_quantizers,
            "vq/codebook_loss": codebook_loss / self.num_quantizers,
        })

        if return_info:
            return z, info

        return z

    def decode(self, x):
        """Optionally quantize ``x`` and append noise channels.

        With ``quantize_on_decode`` the first output of the quantizer replaces
        ``x``; with ``noise_augment_dim > 0`` that many channels of
        ``torch.randn`` noise are concatenated along dim 1.
        """
        if self.quantize_on_decode:
            x = self.quantizer(x)[0]

        if self.noise_augment_dim > 0:
            noise = torch.randn(x.shape[0], self.noise_augment_dim,
                                x.shape[-1]).type_as(x)
            x = torch.cat([x, noise], dim=1)

        return x

    def decode_tokens(self, codes, **kwargs):
        """Convert ``codes`` to latents via ``from_codes`` and decode them.

        NOTE(review): ``decode`` takes no extra arguments, so any non-empty
        ``kwargs`` forwarded here raises ``TypeError``.
        """
        latents, _, _ = self.quantizer.from_codes(codes)

        return self.decode(latents, **kwargs)

class DACRVQVAEBottleneck(DiscreteBottleneck):
    """Discrete bottleneck: ``vae_sample`` followed by the DAC residual VQ.

    Args:
        quantize_on_decode: When true, ``encode`` returns the sampled latents
            unquantized and quantization is applied inside ``decode`` instead.
        **quantizer_kwargs: Forwarded to ``DACResidualVQ``; ``n_codebooks`` and
            ``codebook_size`` are required because they are also recorded on
            the base class. Tokens are reported under ``"codes"``.
    """

    def __init__(self, quantize_on_decode=False, **quantizer_kwargs):
        super().__init__(
            num_quantizers=quantizer_kwargs["n_codebooks"],
            codebook_size=quantizer_kwargs["codebook_size"],
            tokens_id="codes",
        )
        self.quantizer = DACResidualVQ(**quantizer_kwargs)
        self.num_quantizers = quantizer_kwargs["n_codebooks"]
        self.quantize_on_decode = quantize_on_decode

    def encode(self, x, return_info=False, n_quantizers: int = None):
        """Sample from ``x``, then quantize the sample unless deferred to decode.

        ``x`` is split in two along dim 1 into the mean and scale passed to
        ``vae_sample``. ``n_quantizers`` is forwarded to the quantizer call.

        Returns:
            The output tensor, or ``(output, info)`` when ``return_info`` is
            true. ``info`` always contains ``"pre_quantizer"`` (the sampled
            latents) and ``"kl"``; when quantization runs here it also contains
            ``"z"``, ``"codes"``, ``"latents"`` and the two VQ losses divided
            by ``num_quantizers``.
        """
        mean, scale = x.chunk(2, dim=1)

        x, kl = vae_sample(mean, scale)

        info = {"pre_quantizer": x, "kl": kl}

        if self.quantize_on_decode:
            # Parentheses are required: without them the conditional binds only
            # to the second tuple element and (x, x) is returned when
            # return_info is false.
            return (x, info) if return_info else x

        z, codes, latents, commitment_loss, codebook_loss = self.quantizer(
            x, n_quantizers=n_quantizers
        )

        info.update({
            "z": z,
            "codes": codes,
            "latents": latents,
            # Out-of-place division leaves the quantizer's own tensors untouched.
            "vq/commitment_loss": commitment_loss / self.num_quantizers,
            "vq/codebook_loss": codebook_loss / self.num_quantizers,
        })

        if return_info:
            return z, info

        return z

    def decode(self, x):
        """Return ``x``, replaced by the quantizer's first output if deferred."""
        if self.quantize_on_decode:
            x = self.quantizer(x)[0]

        return x

    def decode_tokens(self, codes, **kwargs):
        """Convert ``codes`` to latents via ``from_codes`` and decode them.

        NOTE(review): ``decode`` takes no extra arguments, so any non-empty
        ``kwargs`` forwarded here raises ``TypeError``.
        """
        latents, _, _ = self.quantizer.from_codes(codes)

        return self.decode(latents, **kwargs)
    
class FSQBottleneck(DiscreteBottleneck):
    """Discrete bottleneck wrapping an ``FSQ`` quantizer.

    Args:
        noise_augment_dim: Number of Gaussian-noise channels that ``decode``
            concatenates along dim 1 (``0`` disables this).
        **kwargs: Forwarded to ``FSQ``. ``levels`` is required (its product is
            recorded as the codebook size); ``num_codebooks`` defaults to 1 for
            the recorded quantizer count. Tokens are reported under
            ``"quantizer_indices"``.
    """

    def __init__(self, noise_augment_dim=0, **kwargs):
        super().__init__(
            num_quantizers=kwargs.get("num_codebooks", 1),
            codebook_size=np.prod(kwargs["levels"]),
            tokens_id="quantizer_indices",
        )

        self.noise_augment_dim = noise_augment_dim

        self.quantizer = FSQ(
            **kwargs,
            allowed_dtypes=[torch.float16, torch.float32, torch.float64],
        )

    def encode(self, x, return_info=False):
        """Quantize ``x`` in float32 and cast the result back to ``x``'s dtype.

        ``x`` is rearranged from ``"b c n"`` to ``"b n c"`` for the quantizer
        and back afterwards; the indices are rearranged from ``"b n q"`` to
        ``"b q n"``.

        Returns:
            The quantized tensor, or ``(quantized, info)`` when ``return_info``
            is true, with ``info["quantizer_indices"]`` holding the indices.
        """
        input_dtype = x.dtype

        quantized, indices = self.quantizer(rearrange(x.float(), "b c n -> b n c"))
        quantized = rearrange(quantized, "b n c -> b c n").to(input_dtype)

        # Reorder indices to match the expected format
        info = {"quantizer_indices": rearrange(indices, "b n q -> b q n")}

        if not return_info:
            return quantized

        return quantized, info

    def decode(self, x):
        """Return ``x``, with noise channels appended if configured."""
        if self.noise_augment_dim <= 0:
            return x

        # Noise has shape (x.shape[0], noise_augment_dim, x.shape[-1]) and is
        # cast to x's type before concatenation along dim 1.
        noise = torch.randn(x.shape[0], self.noise_augment_dim,
                            x.shape[-1]).type_as(x)
        return torch.cat([x, noise], dim=1)

    def decode_tokens(self, tokens, **kwargs):
        """Map ``tokens`` to codes via ``indices_to_codes`` and decode them.

        NOTE(review): ``decode`` takes no extra arguments, so any non-empty
        ``kwargs`` forwarded here raises ``TypeError``.
        """
        latents = self.quantizer.indices_to_codes(tokens)
        return self.decode(latents, **kwargs)