import torch
import torch.nn as nn
import torch.nn.functional as F

from codes.data.models.dgms.base import BaseDGM

class VectorQuantizer(nn.Module):
    """Vector-quantization bottleneck with a learnable codebook.

    Every channel vector of the input (one per sequence position) is replaced
    by its nearest codebook entry under squared Euclidean distance. Gradients
    are passed straight through the non-differentiable lookup to the input.

    Args:
        embedding_dim (int): Size of each codebook vector. Must equal the
            channel dimension (dim 1) of the tensors passed to ``forward``.
        num_embeddings (int): Number of codebook vectors.
        commitment_cost (float): Weight of the commitment term that pulls the
            inputs toward their assigned codes.
    """

    def __init__(self, embedding_dim, num_embeddings, commitment_cost):
        super().__init__()

        self.embedding_dim = embedding_dim
        self.num_embeddings = num_embeddings
        self.commitment_cost = commitment_cost

        # Codebook: one row per code, initialised uniformly in +-1/num_embeddings.
        self.embedding = nn.Embedding(num_embeddings, embedding_dim)
        self.embedding.weight.data.uniform_(-1 / num_embeddings, 1 / num_embeddings)

    def forward(self, inputs):
        """Quantize ``inputs`` against the codebook.

        Args:
            inputs (tensor): Tensor of size (batch_size, embedding_dim, sequence_length).

        Returns:
            tuple:
                loss (tensor): Scalar VQ loss,
                    ``mse(codes, sg[inputs]) + commitment_cost * mse(sg[codes], inputs)``.
                quantized (tensor): Same size as ``inputs``. Its value is the
                    selected codes; its gradient w.r.t. ``inputs`` is the
                    identity (straight-through estimator).
                perplexity (tensor): Scalar ``exp(entropy)`` of the average
                    code usage over the batch.
                encodings (tensor): One-hot code assignments of size
                    (batch_size * sequence_length, num_embeddings), rows in
                    batch-major order.

        Raises:
            ValueError: If ``inputs`` is not 3-D or its channel dimension is
                not ``embedding_dim``. Such a mismatch would otherwise be
                silently regrouped into meaningless vectors by ``view``.
        """
        if inputs.dim() != 3 or inputs.size(1) != self.embedding_dim:
            raise ValueError(
                f"expected inputs of shape (batch, {self.embedding_dim}, length), "
                f"got {tuple(inputs.shape)}"
            )

        # N x C x L -> N x L x C so each row of the flattened view is one vector.
        inputs = inputs.permute(0, 2, 1).contiguous()
        input_shape = inputs.size()
        input_flattened = inputs.view(-1, self.embedding_dim)

        # Squared distances ||x||^2 - 2 x.e + ||e||^2, shape (N*L, num_embeddings).
        codebook = self.embedding.weight
        distances = (
            torch.sum(input_flattened ** 2, dim=1, keepdim=True)
            - 2 * torch.matmul(input_flattened, codebook.t())
            + torch.sum(codebook ** 2, dim=1)
        )

        # Nearest code per vector. Look it up directly rather than multiplying a
        # one-hot matrix by the codebook: same values and gradients, and no
        # dtype mismatch when the codebook is not float32.
        encoding_indices = torch.argmin(distances, dim=1)
        encodings = F.one_hot(encoding_indices, self.num_embeddings).to(codebook.dtype)
        quantized = self.embedding(encoding_indices).view(input_shape)

        # The codebook loss moves codes toward the (frozen) inputs; the
        # commitment loss moves inputs toward the (frozen) codes.
        e_latent_loss = F.mse_loss(quantized.detach(), inputs)
        q_latent_loss = F.mse_loss(quantized, inputs.detach())
        loss = q_latent_loss + self.commitment_cost * e_latent_loss

        # Straight-through estimator: forward value = codes, backward = identity.
        quantized = inputs + (quantized - inputs).detach()

        avg_probs = torch.mean(encodings, dim=0)
        perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))

        return loss, quantized.permute(0, 2, 1).contiguous(), perplexity, encodings

class VQVAE(nn.Module):
    """VQ-VAE: encoder -> linear projection -> vector quantizer -> decoder.

    Args:
        encoder (nn.Module): Maps an input batch to encoder features.
        decoder (nn.Module): Maps quantized latents back to input space.
        enc_out_dim (int): Size of the last dimension of the encoder output;
            ``pre_vq_conv`` projects it to ``z_dim``.
        z_dim (int): Dimensionality of each codebook vector.
        num_embedding (int): Number of codebook vectors.
        commitment_cost (float): Weight of the commitment term in the VQ loss.
    """

    def __init__(
        self,
        encoder,
        decoder,
        enc_out_dim,
        z_dim,
        num_embedding,
        commitment_cost
    ):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

        # Despite its name this is an nn.Linear, so it acts on the LAST dim.
        self.pre_vq_conv = nn.Linear(enc_out_dim, z_dim)
        # VectorQuantizer takes (embedding_dim, num_embeddings, commitment_cost).
        # Codes live in z-space, so embedding_dim is z_dim. Keywords prevent the
        # previous positional swap of these two sizes.
        self.vq_layer = VectorQuantizer(
            embedding_dim=z_dim,
            num_embeddings=num_embedding,
            commitment_cost=commitment_cost,
        )

    def forward_encoder(self, x):
        """Encode ``x`` and project the features to ``z_dim`` on the last dim."""
        enc_out = self.encoder(x)
        z = self.pre_vq_conv(enc_out)
        return z

    def generate(self, z):
        """Decode latents ``z`` with the decoder."""
        return self.decoder(z)

    def forward(self, x, return_only_loss=True):
        """Compute the training loss (MSE reconstruction + VQ loss) for ``x``.

        Args:
            x (tensor): Input batch; also the reconstruction target.
            return_only_loss (bool): Only ``True`` is supported.

        Returns:
            tensor: Scalar ``mse(decoder(quantized), x) + vq_loss``.

        Raises:
            NotImplementedError: If ``return_only_loss`` is false.
        """
        # Check up front (and not via ``assert``, which ``-O`` strips) so an
        # unsupported call fails before running the whole model.
        if not return_only_loss:
            raise NotImplementedError(
                "VQVAE.forward only supports return_only_loss=True"
            )

        # NOTE(review): pre_vq_conv produces z with z_dim on the LAST axis, while
        # VectorQuantizer expects (batch, channels, length) with channels on
        # dim 1. Confirm the encoder's output layout; a transpose may be needed.
        z = self.forward_encoder(x)
        vq_loss, quantized, _perplexity, _ = self.vq_layer(z)
        x_recon = self.generate(quantized)

        recon_loss = F.mse_loss(x_recon, x)
        return recon_loss + vq_loss