import math
import torch
import copy

import torch.nn as nn
import numpy as np
import torch.nn.functional as F

from sgcrl.models.image_based.misc import get_class
from vector_quantize_pytorch import VectorQuantize, ResidualVQ, RandomProjectionQuantizer
from tqdm import tqdm

# Sentinel values reserved for sequence construction. They are negative so
# they can never collide with a real (non-negative) token id.
special_tokens = dict(
    PADDING_VALUE=-2,
    EOS_TOKEN=-3,
    SOS_TOKEN=-1,
)

def apply_variance_scaling_init(net, scale=1.0):
    """Re-initialise ``net.weight`` uniformly in ``[-sqrt(scale/fan), sqrt(scale/fan)]``.

    Meant to be passed to ``nn.Module.apply``. ``fan`` is the mean of the
    weight's last two dimensions (fan_out and fan_in for an ``nn.Linear``).
    The bias, when the module has one, is zeroed.

    The module is left untouched when:

    - it has no ``weight``;
    - its ``weight`` is ``None`` (disabled);
    - its ``weight`` has fewer than two dimensions (e.g. normalisation
      layers).

    In those cases no fan can be computed from the last two dimensions.

    Args:
        net: module to re-initialise in place.
        scale: variance scale; the sampling bound is ``sqrt(scale / fan)``.
    """
    weight = getattr(net, 'weight', None)
    if not isinstance(weight, torch.Tensor) or weight.dim() < 2:
        return
    # NOTE(review): for conv weights (out, in, kh, kw) the last two dims are
    # the kernel size rather than channel fans — confirm before applying this
    # to convolutional layers.
    fan = (weight.size(-2) + weight.size(-1)) / 2
    init_w = math.sqrt(scale / fan)
    with torch.no_grad():
        weight.uniform_(-init_w, init_w)
        # Layers built with bias=False store None here; nothing to zero then.
        bias = getattr(net, 'bias', None)
        if isinstance(bias, torch.Tensor):
            bias.zero_()

class quantizer(nn.Module):
    """MLP auto-encoder with a ``VectorQuantize`` bottleneck.

    ``encoder`` maps ``input_dim`` -> ``dim``. ``quantize`` snaps that latent
    onto a codebook of ``codebook_size`` entries. ``decoder`` maps the result
    back to ``input_dim``. All three parts are re-initialised with
    ``apply_variance_scaling_init``.

    Args:
        dim: latent / codebook vector size.
        codebook_size: number of codebook entries.
        hidden_dim: width of the two hidden layers of each MLP.
        input_dim: size of the input (and reconstruction) vectors.
        kitchen_params: if True, the codebook uses ``threshold_ema_dead_code=2``,
            ``kmeans_init=True`` and ``kmeans_iters=10``; otherwise the
            ``VectorQuantize`` defaults apply.
    """

    def __init__(
            self,
            dim,
            codebook_size,
            hidden_dim=256,
            input_dim=2,
            kitchen_params=False
        ):
        super().__init__()
        vq_options = {'dim': dim, 'codebook_size': codebook_size}
        if kitchen_params:
            vq_options.update(threshold_ema_dead_code=2, kmeans_init=True, kmeans_iters=10)
        # Construction order (codebook, encoder, decoder) fixes the RNG draw order.
        self.quantize = VectorQuantize(**vq_options)

        self.encoder = self._mlp(input_dim, hidden_dim, dim)
        self.decoder = self._mlp(dim, hidden_dim, input_dim)

        for part in (self.encoder, self.decoder, self.quantize):
            part.apply(apply_variance_scaling_init)

    @staticmethod
    def _mlp(in_features, hidden_features, out_features):
        """Build Linear-ReLU-Linear-ReLU-Linear (Linear layers at Sequential indices 0, 2, 4)."""
        return nn.Sequential(
            nn.Linear(in_features, hidden_features),
            nn.ReLU(),
            nn.Linear(hidden_features, hidden_features),
            nn.ReLU(),
            nn.Linear(hidden_features, out_features),
        )

    def forward(self, x, return_latent=False, return_quantized=False):
        """Encode, quantize and decode ``x``.

        Returns:
            ``(x_hat, commitment_loss)``. A third element is appended when
            requested. With ``return_latent`` it is the pre-quantization
            latent; otherwise, with ``return_quantized``, it is the quantized
            latent. ``return_latent`` takes precedence when both are set.
        """
        latent = self.encoder(x)
        codes, _, commitment_loss = self.quantize(latent)
        # Straight-through estimator: forward value is the code, gradient reaches `latent`.
        codes = latent + (codes - latent).detach()
        x_hat = self.decoder(codes)

        if return_latent:
            return x_hat, commitment_loss, latent
        if return_quantized:
            return x_hat, commitment_loss, codes
        return x_hat, commitment_loss
        
class Tokenizer(nn.Module):
    """Assign discrete token ids to inputs using a frozen quantizer.

    Inputs are normalised as ``(x + offset) / norm`` and passed through the
    quantizer. The reconstruction is then mapped back with
    ``* norm - offset``; this de-normalised reconstruction is the token's
    "representant".

    Representants that agree element-wise within ``_TOKEN_ATOL`` share a
    token. New representants receive consecutive ids in order of first
    appearance. ``self._tokens`` maps ``tuple(representant) -> id``.
    """

    # Element-wise tolerance under which two representants are the same token.
    _TOKEN_ATOL = 1e-4

    def __init__(self, quantizer, save_folder, number_of_tokens, offset, norm, device, keys_to_tokenize, is_visual=False):
        """
        Args:
            quantizer: module called as
                ``quantizer(x, return_latent=..., return_quantized=...)``.
                A deep copy is moved to ``device`` and kept in eval mode.
            save_folder: stored as ``self._save_folder`` (not read by this class).
            number_of_tokens: ``init__tokens`` stops once at least this many
                tokens exist.
            offset: added before normalising, subtracted after de-normalising.
            norm: normalisation constants; empty (or ``None``) means 1.0.
            device: device for the quantizer copy and ``norm``.
            keys_to_tokenize: episode keys whose tensors are concatenated and
                tokenized.
            is_visual: selects the raw-input distance query in ``tokenize_tensor``.
        """
        nn.Module.__init__(self)
        self._quantizer = copy.deepcopy(quantizer).to(device)
        self._device = device
        self._number_of_tokens = number_of_tokens
        # {tuple(representant): token id}; ids follow insertion order.
        self._tokens = {}
        self._save_folder = save_folder
        self.keys_to_tokenize = keys_to_tokenize
        self.is_visual = is_visual

        if norm is not None and len(norm) > 0:
            self.norm = torch.tensor(norm, device=device)
        else:
            self.norm = torch.tensor(1.0, device=device)
        self.offset = offset
        self._quantizer.eval()

    def __call__(self, episode):
        return self.tokenize(episode)

    @torch.no_grad()
    def init__tokens(self, dataloader):
        """Populate the token table by tokenizing batches from ``dataloader``.

        Each batch must be a dict of tensors; they are moved to
        ``self._device`` before tokenizing.
        """
        for batch in tqdm(dataloader):
            self.tokenize({k: v.to(self._device) for k, v in batch.items()})
            # `>=`: one batch can add several tokens and overshoot the target,
            # in which case `==` would never trigger the early stop.
            if len(self._tokens) >= self._number_of_tokens:
                break

        print(f"Initialized {len(self._tokens)} tokens")

    def to(self, device):
        """Move the quantizer copy and ``norm`` to ``device``; return ``self``."""
        self._quantizer.to(device)
        self._device = device
        self.norm = self.norm.to(device)
        return self

    @torch.no_grad()
    def tokenize(self, episode):
        """Tokenize ``episode`` in place and return it.

        The tensors under ``keys_to_tokenize`` are concatenated along dim 0
        (a leading size-1 dim is squeezed). Representants not seen before are
        registered as new tokens.

        Adds the following keys to ``episode``:

        - ``"token/ids"``: CPU long tensor, one id per row.
        - ``"token"``: de-normalised reconstructions.
        - ``"continuous_latent"``: the quantizer's pre-quantization latent.
        """
        self._quantizer.eval()
        pos = torch.cat([episode[key] for key in self.keys_to_tokenize]).detach().squeeze(0)
        pos = (pos + self.offset) / self.norm
        tokenized, _, continuous_latent = self._quantizer(pos, return_latent=True)
        tokenized = tokenized * self.norm - self.offset

        # One host copy for the whole batch instead of one per row.
        host = tokenized.detach().cpu().numpy()
        token_ids = torch.zeros(tokenized.shape[0], dtype=torch.long)
        for i in range(tokenized.shape[0]):
            # ravel flattens images and keeps single-element rows 1-D; the
            # previous squeeze() turned those into 0-d arrays that tuple() rejects.
            token = tuple(np.ravel(host[i]))
            key = self._find_token(token)
            if key is None:
                # Unseen representant: it becomes the token of a new category.
                self._tokens[token] = len(self._tokens)
                key = token
            token_ids[i] = self._tokens[key]

        episode["token/ids"] = token_ids
        episode["token"] = tokenized
        episode["continuous_latent"] = continuous_latent
        return episode

    @torch.no_grad()
    def tokenize_tensor(self, inpt, return_quantized=True):
        """Map ``inpt`` to the ids of the nearest stored representants.

        Unlike ``tokenize``, no new tokens are registered.

        Returns:
            ``(token_ids, quantized)`` when ``return_quantized`` is True;
            otherwise ``(token_ids, tokenized)``, where ``tokenized`` is the
            de-normalised reconstruction. For ``(N, D)`` inputs the ids have
            shape ``(N, 1)``.

        Raises:
            RuntimeError: if no tokens have been registered yet.
        """
        if not self._tokens:
            raise RuntimeError("Tokenizer has no tokens yet; call init__tokens/tokenize first")
        self._quantizer.eval()
        normalized = (inpt + self.offset) / self.norm
        if return_quantized:
            tokenized, _, quantized = self._quantizer(normalized, return_latent=False, return_quantized=True)
        else:
            tokenized, _ = self._quantizer(normalized, return_latent=False, return_quantized=False)
        tokenized = tokenized * self.norm - self.offset

        keys_tensor = self._keys_tensor(inpt.device)
        if getattr(self, 'is_visual', False):
            # NOTE(review): this compares the raw input, not its reconstruction,
            # to the stored representants (the non-visual branch uses the
            # reconstruction). Confirm the asymmetry is intended.
            query = inpt.view(-1, keys_tensor.shape[-1]).unsqueeze(1)
        else:
            query = tokenized.unsqueeze(1)
        token_ids = torch.cdist(query, keys_tensor.unsqueeze(0)).argmin(dim=-1)

        if return_quantized:
            return token_ids, quantized
        return token_ids, tokenized

    def _keys_tensor(self, device):
        """Stack the stored representants (in id order) into a float32 tensor on ``device``."""
        return torch.as_tensor(np.asarray(list(self._tokens), dtype=np.float32), device=device)

    def _find_token(self, token):
        """Return the first stored key matching ``token`` within ``_TOKEN_ATOL``, else None."""
        if not self._tokens:
            return None
        keys = list(self._tokens)
        close = np.abs(np.asarray(keys) - np.asarray(token)) < self._TOKEN_ATOL
        matches = np.flatnonzero(close.all(axis=-1))
        return keys[matches[0]] if matches.size else None

    def get_token(self, clusters):
        """Return the representants for token ids ``clusters``.

        Works because ids are assigned in insertion order, so id ``k`` is the
        ``k``-th stored key.
        """
        return self._keys_tensor(clusters.device)[clusters]
    

class ResidualLayer(nn.Module):
    """Bias-free residual block computing ``input + conv1x1(relu(conv3x3(input)))``.

    The 3x3 conv is padded by 1, so the spatial size is preserved. The skip
    connection therefore requires ``in_channels == out_channels``.
    """

    def __init__(
            self,
            in_channels: int,
            out_channels: int):
        super().__init__()
        spatial_conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False)
        pointwise_conv = nn.Conv2d(out_channels, out_channels, kernel_size=1, bias=False)
        self.resblock = nn.Sequential(spatial_conv, nn.ReLU(inplace=True), pointwise_conv)

    def forward(self, input: torch.Tensor):
        residual = self.resblock(input)
        return input + residual

class image_quantizer_v0(nn.Module):
    """Convolutional VQ auto-encoder whose inputs carry channels in the last axis.

    Encoder: six stride-2 4x4 convs (128, then 256 channels), a 3x3 conv, six
    ``ResidualLayer``s and a 1x1 projection to ``dim`` channels. ``forward``
    squeezes the two spatial axes, so the encoder output must be 1x1; with
    six halvings that means a 64x64 input.

    Decoder: mirrors the encoder with transposed convs and ends in ``Tanh``.

    Args:
        dim: latent / codebook vector size.
        codebook_size: number of codebook entries.
        beta: ``commitment_weight`` of the ``VectorQuantize`` layer.
        hidden_dim: accepted for interface compatibility; not used.
        norm: stored as ``self.norm`` (a tensor); not used by this class.
    """

    def __init__(self, dim, codebook_size, beta, hidden_dim=256, norm=((27.5, 20), (27.5, 20))):
        super().__init__()
        self.quantize = VectorQuantize(dim=dim, codebook_size=codebook_size, commitment_weight=beta)

        in_channels = 3
        # Encoder
        hidden_sizes = [128, 256, 256, 256, 256, 256]
        modules = []
        for h_dim in hidden_sizes:
            modules.append(
                nn.Sequential(
                    nn.Conv2d(in_channels, out_channels=h_dim,
                              kernel_size=4, stride=2, padding=1),
                    nn.LeakyReLU())
            )
            in_channels = h_dim

        modules.append(
            nn.Sequential(
                nn.Conv2d(in_channels, in_channels,
                          kernel_size=3, stride=1, padding=1),
                nn.LeakyReLU())
        )
        for _ in range(6):
            modules.append(ResidualLayer(in_channels, in_channels))
        modules.append(nn.LeakyReLU())
        modules.append(
            nn.Sequential(
                nn.Conv2d(in_channels, dim,
                          kernel_size=1, stride=1),
                nn.LeakyReLU())
        )
        self.encoder = nn.Sequential(*modules)

        # Decoder
        modules = []
        modules.append(
            nn.Sequential(
                nn.Conv2d(dim,
                          hidden_sizes[-1],
                          kernel_size=3,
                          stride=1,
                          padding=1),
                nn.LeakyReLU())
        )
        for _ in range(6):
            modules.append(ResidualLayer(hidden_sizes[-1], hidden_sizes[-1]))
        modules.append(nn.LeakyReLU())
        # Walk the channel sizes back from 256 down to 128.
        hidden_sizes.reverse()
        for i in range(len(hidden_sizes) - 1):
            modules.append(
                nn.Sequential(
                    nn.ConvTranspose2d(hidden_sizes[i],
                                       hidden_sizes[i + 1],
                                       kernel_size=4,
                                       stride=2,
                                       padding=1),
                    nn.LeakyReLU())
            )
        modules.append(
            nn.Sequential(
                nn.ConvTranspose2d(hidden_sizes[-1],
                                   out_channels=3,
                                   kernel_size=4,
                                   stride=2, padding=1),
                nn.Tanh()))
        self.decoder = nn.Sequential(*modules)

        # Was hard-coded to 'cuda', which made construction fail on CPU-only hosts.
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.norm = torch.tensor(norm, device=device)

    def forward(self, x, return_latent=False):
        """Reconstruct ``x``.

        ``x`` is permuted with ``(0, 3, 1, 2)`` before encoding, and the
        reconstruction is permuted back with ``(0, 2, 3, 1)``.

        Returns:
            ``(x_hat, commitment_loss)``, plus the pre-quantization latent of
            shape ``(B, dim, 1, 1)`` when ``return_latent`` is True.
        """
        x = x.permute(0, 3, 1, 2)
        z = self.encoder(x)
        # (B, dim, 1, 1) -> (B, 1, dim): a length-1 sequence for VectorQuantize.
        z = z.squeeze(-1).squeeze(-1).unsqueeze(1)
        quantized, _, commitment_loss = self.quantize(z)
        quantized = quantized.squeeze(1).unsqueeze(-1).unsqueeze(-1)
        x_hat = self.decoder(quantized)
        x_hat = x_hat.permute(0, 2, 3, 1)
        z = z.squeeze(1).unsqueeze(-1).unsqueeze(-1)

        if return_latent:
            return x_hat, commitment_loss, z
        return x_hat, commitment_loss

class image_quantizer_v1(nn.Module):
    """Adapter around the VQ-VAE class returned by ``get_class(model_class)``.

    ``forward`` and ``encode`` permute their input with ``(0, 3, 1, 2)``
    (presumably channels-last -> channels-first) before calling the wrapped
    model. ``forward`` permutes the reconstruction back with ``(0, 2, 3, 1)``.
    """

    def __init__(
        self,
        model_class,
        in_channels,
        latent_dim,
        codebook_size,
        hidden_dims: list = None,
        decay: float = 0.99,
        alpha: float = 1.0,
        beta: float = 0.25,
        use_ema: bool = False
    ):
        super().__init__()
        vq_vae_cls = get_class(model_class)
        # The wrapped model uses different keyword names:
        #   latent_dim -> embedding_dim
        #   codebook_size -> num_embeddings
        #   hidden_dims -> hidden_sizes
        self.vq_vae = vq_vae_cls(
            in_channels=in_channels,
            embedding_dim=latent_dim,
            num_embeddings=codebook_size,
            hidden_sizes=hidden_dims,
            decay=decay,
            alpha=alpha,
            beta=beta,
            use_ema=use_ema,
        )

    @staticmethod
    def _channels_first(x):
        """Move the last axis to position 1."""
        return x.permute(0, 3, 1, 2)

    def forward(self, x, return_latent=False):
        """Return ``(recons, vq_loss)``, plus the latent ``z`` when ``return_latent`` is True."""
        vq_loss, recons, z, _ids = self.vq_vae.compute_loss(self._channels_first(x))
        recons_out = recons.permute(0, 2, 3, 1)
        if not return_latent:
            return recons_out, vq_loss
        return recons_out, vq_loss, z

    def encode(self, x):
        """Encode ``x`` with the wrapped model after permuting it."""
        return self.vq_vae.encode(self._channels_first(x))

    def decode(self, z):
        """Decode ``z`` with the wrapped model."""
        # NOTE(review): unlike forward(), the output is not permuted back —
        # confirm callers expect the wrapped model's native layout here.
        return self.vq_vae.decode(z)

    def get_embeddings(self):
        """Return the codebook weights (``.data``, detached from autograd)."""
        return self.vq_vae.vq_layer.embedding.weight.data

