import torch
from torch import nn
from torch.nn import functional as F


# Vector Quantization Class

class VectorQuantizerEMA(nn.Module):
    """Vector quantizer whose codebook is maintained by exponential moving averages.

    Every feature vector of the input is replaced by its nearest codebook
    entry (squared Euclidean distance).  While the module is in training
    mode the codebook is not learned by gradient descent: each forward pass
    updates moving averages of the per-code assignment counts and of the
    summed assigned vectors, and the codebook is set to their ratio.

    Args:
        num_embeddings: Number of codebook entries.
        embedding_dim: Size of each codebook entry; must equal the channel
            dimension of the tensors given to ``forward``.
        commitment_cost: Weight applied to the commitment loss.
        decay: Decay factor of the moving averages.
        epsilon: Constant used for Laplace smoothing of the cluster sizes.
    """

    def __init__(self, num_embeddings, embedding_dim, commitment_cost=0.25, decay=0.99, epsilon=1e-5):
        super(VectorQuantizerEMA, self).__init__()

        self._embedding_dim = embedding_dim
        self._num_embeddings = num_embeddings

        self._embedding = nn.Embedding(self._num_embeddings, self._embedding_dim)
        self._embedding.weight.data.normal_()
        self._commitment_cost = commitment_cost

        self.register_buffer('_ema_cluster_size', torch.zeros(num_embeddings))
        # torch.empty replaces the legacy torch.Tensor(...) constructor; the
        # values are overwritten by normal_() right below.
        self._ema_w = nn.Parameter(torch.empty(num_embeddings, self._embedding_dim))
        self._ema_w.data.normal_()

        self._decay = decay
        self._epsilon = epsilon

    def forward(self, inputs):
        """Quantize ``inputs`` and, in training mode, update the codebook.

        Args:
            inputs: Tensor laid out as (batch, channel, length) with
                ``channel == embedding_dim``.

        Returns:
            A tuple ``(loss, quantized, perplexity, encodings,
            encoding_indices, distances)``:

            * ``loss`` - commitment loss scaled by ``commitment_cost``.
            * ``quantized`` - quantized tensor in (batch, channel, length)
              layout, built with the straight-through estimator so that
              gradients reach ``inputs``.
            * ``perplexity`` - exponential of the entropy of the average
              code usage over the flattened input.
            * ``encodings`` - one-hot assignments, one row per flattened
              input vector and one column per codebook entry.
            * ``encoding_indices`` - index of the selected code per
              flattened vector, as a column.
            * ``distances`` - squared distances between every flattened
              vector and every codebook entry.
        """
        # convert inputs from BCL -> BLC
        inputs = inputs.permute(0, 2, 1).contiguous()
        input_shape = inputs.shape

        # Flatten input
        flat_input = inputs.view(-1, self._embedding_dim)

        codebook = self._embedding.weight

        # Squared distances via ||x||^2 + ||e||^2 - 2 * x.e
        distances = (torch.sum(flat_input**2, dim=1, keepdim=True)
                     + torch.sum(codebook**2, dim=1)
                     - 2 * torch.matmul(flat_input, codebook.t()))

        # Encoding
        encoding_indices = torch.argmin(distances, dim=1).unsqueeze(1)
        # Allocate the one-hot matrix in the codebook's dtype so the matmuls
        # below also work after the module has been cast (e.g. ``.double()``).
        encodings = torch.zeros(encoding_indices.shape[0], self._num_embeddings,
                                dtype=codebook.dtype, device=inputs.device)
        encodings.scatter_(1, encoding_indices, 1)

        # Quantize and unflatten
        quantized = torch.matmul(encodings, codebook).view(input_shape)

        # Use EMA to update the embedding vectors
        if self.training:
            # The update is pure bookkeeping; no_grad avoids building an
            # autograd graph through flat_input that would be discarded.
            with torch.no_grad():
                cluster_size = (self._ema_cluster_size * self._decay
                                + (1 - self._decay) * torch.sum(encodings, 0))

                # Laplace smoothing of the cluster size
                n = torch.sum(cluster_size)
                self._ema_cluster_size = (
                    (cluster_size + self._epsilon)
                    / (n + self._num_embeddings * self._epsilon) * n)

                dw = torch.matmul(encodings.t(), flat_input)
                # NOTE: both tensors are rebound as fresh Parameters rather
                # than updated in place, so the tensors used above for
                # ``distances`` and ``quantized`` are left untouched.
                self._ema_w = nn.Parameter(self._ema_w * self._decay + (1 - self._decay) * dw)
                self._embedding.weight = nn.Parameter(
                    self._ema_w / self._ema_cluster_size.unsqueeze(1))

        # Loss
        e_latent_loss = F.mse_loss(quantized.detach(), inputs)
        loss = self._commitment_cost * e_latent_loss

        # Straight Through Estimator
        quantized = inputs + (quantized - inputs).detach()
        avg_probs = torch.mean(encodings, dim=0)
        perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))

        # convert quantized from BLC -> BCL
        return loss, quantized.permute(0, 2, 1).contiguous(), perplexity, encodings, encoding_indices, distances



# One residual block

class ResBlock(nn.Module):
    """Residual 1-D convolution block computing ``input + conv(input)``.

    The residual branch is LeakyReLU -> Conv1d(kernel 3, padding 1) ->
    LeakyReLU -> Conv1d(kernel 1), mapping ``in_channel`` channels to
    ``channel`` and back to ``in_channel``.

    Args:
        in_channel: Channels of the input (and of the output).
        channel: Channels used inside the residual branch.
    """

    def __init__(self, in_channel, channel):
        super().__init__()

        # Built in the same order in which they appear in the Sequential.
        widen = nn.Conv1d(in_channel, channel, 3, padding=1)
        restore = nn.Conv1d(channel, in_channel, 1)

        self.conv = nn.Sequential(nn.LeakyReLU(), widen, nn.LeakyReLU(), restore)

    def forward(self, input):
        """Return the residual branch output with ``input`` added to it."""
        # The sum is accumulated in place into the branch output; ``input``
        # itself is not modified.
        return self.conv(input).add_(input)


#ENCODERS

class Encoder(nn.Module):
    """Stack of stride-2 1-D convolutions followed by residual blocks.

    One ``Conv1d`` (stride 2, padding 1) plus ``LeakyReLU`` is created per
    entry of ``k_s_list``; the first convolution maps ``in_channel`` to
    ``channel`` and the remaining ones keep ``channel`` channels.  They are
    followed by ``n_res_block`` residual blocks and a final ``LeakyReLU``.

    Args:
        in_channel: Channels of the input tensor.
        channel: Channels produced by every convolution.
        k_s_list: Kernel size of each strided convolution, in order.
        n_res_block: Number of residual blocks appended after the convolutions.
        n_res_channel: Inner channel count of each residual block.
    """

    def __init__(self, in_channel, channel, k_s_list, n_res_block, n_res_channel):
        super().__init__()

        layers = []

        # Only the first convolution reads from ``in_channel``.
        feeding_channels = in_channel
        for kernel_size in k_s_list:
            layers.append(nn.Conv1d(feeding_channels, channel, kernel_size, stride=2, padding=1))
            layers.append(nn.LeakyReLU())
            feeding_channels = channel

        layers.extend(ResBlock(channel, n_res_channel) for _ in range(n_res_block))
        layers.append(nn.LeakyReLU())

        self.blocks = nn.Sequential(*layers)

    def forward(self, input):
        """Run ``input`` through the whole layer stack."""
        return self.blocks(input)



#DECODERS

class Decoder_ConvTranspose(nn.Module):
    """Stack of stride-2, bias-free 1-D transposed convolutions.

    One ``ConvTranspose1d`` (stride 2, padding 1, no bias) is created per
    entry of ``k_s_list``.  The first layer reads ``in_channel`` channels,
    the last layer writes ``out_channel`` channels, and every layer except
    that last one is followed by ``LeakyReLU``.

    NOTE: when ``k_s_list`` has exactly one entry the single layer is
    treated as a *first* layer: it maps ``in_channel`` to ``channel``, keeps
    its ``LeakyReLU``, and ``out_channel`` is not used.

    Args:
        in_channel: Channels of the input tensor.
        out_channel: Channels written by the final layer.
        channel: Channels used between layers.
        k_s_list: Kernel size of each transposed convolution, in order.
    """

    def __init__(self, in_channel, out_channel, channel, k_s_list):
        super().__init__()

        final_position = len(k_s_list) - 1
        layers = []

        for position, kernel_size in enumerate(k_s_list):
            opens_stack = position == 0
            # The first layer never counts as the output layer, even when it
            # is also the last one.
            closes_stack = (not opens_stack) and position == final_position

            layers.append(nn.ConvTranspose1d(
                in_channel if opens_stack else channel,
                out_channel if closes_stack else channel,
                kernel_size,
                stride=2, padding=1, bias=False))

            if not closes_stack:
                layers.append(nn.LeakyReLU())

        self.blocks = nn.Sequential(*layers)

    def forward(self, input):
        """Run ``input`` through the transposed-convolution stack."""
        return self.blocks(input)


#Hierarchical Vector Quantization Auto-Encoder

class Hierarchical_VQ_AE(nn.Module):
    """Two-level (bottom / middle) vector-quantized auto-encoder for 1-D inputs.

    The bottom encoder processes the input and the middle encoder processes
    the bottom features.  The middle features are quantized first; the
    quantized middle code is then decoded, concatenated with the bottom
    features along the channel axis and quantized again to give the bottom
    code.  Each code has its own decoder, and the reconstruction is the mean
    of the two channel-averaged decoder outputs.

    Args:
        k_s_list_encod_b: Kernel sizes of the bottom encoder.
        k_s_list_encod_m: Kernel sizes of the middle encoder.
        ks_list_decod_mid: Kernel sizes of the middle-code decoder.
        ks_list_decod_bot: Kernel sizes of the bottom-code decoder.
        ks_list_decod_mid_to_bot: Kernel sizes of the decoder that brings the
            middle code to the bottom level.
        in_channel: Channels of the input tensor.
        channel: Channels used inside encoders and decoders.
        out_channel_deconv: Output channels of the two reconstruction decoders.
        n_res_block: Residual blocks per encoder.
        n_res_channel: Inner channels of each residual block.
        embedd_dim: Dimension of the codebook vectors (both levels).
        n_embed_bot: Codebook size of the bottom quantizer.
        n_embed_mid: Codebook size of the middle quantizer.
        comit_cost: Commitment cost given to both quantizers.
        seq_len: Accepted for interface compatibility; not used here.
        seq_bot_len: Accepted for interface compatibility; not used here.
        seq_mid_len: Accepted for interface compatibility; not used here.
    """

    def __init__(self, k_s_list_encod_b,
                 k_s_list_encod_m,
                 ks_list_decod_mid,
                 ks_list_decod_bot,
                 ks_list_decod_mid_to_bot,
                 in_channel,
                 channel,
                 out_channel_deconv,
                 n_res_block,
                 n_res_channel,
                 embedd_dim,
                 n_embed_bot,
                 n_embed_mid,
                 comit_cost,
                 seq_len,
                 seq_bot_len,
                 seq_mid_len):
        super().__init__()

        def make_decoder(out_channels, kernel_sizes):
            # Every decoder consumes codebook vectors of size embedd_dim.
            return Decoder_ConvTranspose(in_channel=embedd_dim, out_channel=out_channels,
                                         channel=channel, k_s_list=kernel_sizes)

        def make_quantizer(codebook_size):
            return VectorQuantizerEMA(num_embeddings=codebook_size, embedding_dim=embedd_dim,
                                      commitment_cost=comit_cost)

        # Submodules are created in a fixed order so that registration order
        # and random initialisation stay the same.

        # Two-level encoder
        self.enc_b = Encoder(in_channel, channel, k_s_list_encod_b, n_res_block, n_res_channel)
        self.enc_m = Encoder(channel, channel, k_s_list_encod_m, n_res_block, n_res_channel)

        # Middle-level quantization
        self.quantize_conv_m = nn.Conv1d(channel, embedd_dim, 1)
        self.quantize_m = make_quantizer(n_embed_mid)

        # Decoders
        self.decode_mid_to_bot = make_decoder(embedd_dim, ks_list_decod_mid_to_bot)
        self.decode_mid = make_decoder(out_channel_deconv, ks_list_decod_mid)
        self.decode_bot = make_decoder(out_channel_deconv, ks_list_decod_bot)

        # Bottom-level quantization; its input is the decoded middle code
        # concatenated with the bottom encoder features.
        self.quantize_conv_b = nn.Conv1d(embedd_dim + channel, embedd_dim, 1)
        self.quantize_b = make_quantizer(n_embed_bot)

    def encode(self, input):
        """Encode ``input`` into quantized middle and bottom codes.

        Returns:
            ``(quant_m, quant_b, loss, encoding_indices_m, encoding_indices_b,
            distances_m, distances_b)`` where ``loss`` is the sum of the two
            quantizer losses.
        """
        bottom_features = self.enc_b(input)
        middle_features = self.enc_m(bottom_features)

        # Middle part
        loss_m, quant_m, _, _, indices_m, distances_m = self.quantize_m(
            self.quantize_conv_m(middle_features))

        # Bottom part: condition the bottom code on the decoded middle code
        conditioning = self.decode_mid_to_bot(quant_m)
        fused = torch.cat([conditioning, bottom_features], 1)
        loss_b, quant_b, _, _, indices_b, distances_b = self.quantize_b(
            self.quantize_conv_b(fused))

        total_loss = loss_m + loss_b
        return quant_m, quant_b, total_loss, indices_m, indices_b, distances_m, distances_b

    def decode(self, quant_m, quant_b):
        """Reconstruct from both codes.

        Each decoder output is averaged over its channel axis, and the two
        results are then averaged with each other.
        """
        per_level = [
            self.decode_bot(quant_b).mean(dim=1),
            self.decode_mid(quant_m).mean(dim=1),
        ]
        return torch.stack(per_level).mean(dim=0)

    def forward(self, input):
        """Encode and decode ``input``.

        Returns:
            ``(recons, loss, quant_b, quant_m, encoding_indices_m,
            encoding_indices_b)`` with both index tensors reshaped to one row
            per batch element.
        """
        quant_m, quant_b, loss, indices_m, indices_b, _, _ = self.encode(input)
        recons = self.decode(quant_m, quant_b)

        batch_size = input.shape[0]
        return (recons, loss, quant_b, quant_m,
                indices_m.view(batch_size, -1), indices_b.view(batch_size, -1))
 











