import torch
import torch.nn as nn
import torch.nn.functional as F

class VectorQuantizer(nn.Module):
    """Vector-quantisation layer with a gradient-trained codebook.

    Every ``D``-dimensional latent vector is replaced by its nearest codebook
    entry (squared L2 distance). The codebook (an ``nn.Embedding``) is trained
    through the embedding loss; the encoder receives gradients through a
    straight-through estimator and a ``beta``-weighted commitment loss.

    Args:
        num_embeddings: number of codebook entries ``K``.
        embedding_dim: size ``D`` of each entry; must equal the channel
            dimension of the tensor passed to :meth:`forward`.
        beta: weight of the commitment loss.
    """

    def __init__(
            self,
            num_embeddings: int,
            embedding_dim: int,
            beta: float = 0.25
    ):
        super().__init__()
        self.K = num_embeddings
        self.D = embedding_dim
        self.beta = beta
        self.embedding = nn.Embedding(self.K, self.D)
        # Small uniform init in [-1/K, 1/K].
        self.embedding.weight.data.uniform_(-1/self.K, 1/self.K)

    def forward(self, z: torch.Tensor):
        """Quantize ``z`` and compute the VQ loss.

        Args:
            z: latent map of shape ``[B x D x H x W]``.

        Returns:
            ``(quantized, vq_loss, encoding_inds)`` where ``quantized`` has the
            shape of ``z`` (straight-through gradients to ``z``), ``vq_loss`` is
            ``embedding_loss + beta * commitment_loss`` and ``encoding_inds`` is
            the ``[B*H*W x 1]`` index tensor of the selected codes.
        """
        z = z.permute(0, 2, 3, 1).contiguous()  # [B x D x H x W] -> [B x H x W x D]
        flat_z = z.view(-1, self.D)  # [BHW x D]
        codebook = self.embedding.weight  # [K x D]

        # The distances only feed argmin (non-differentiable), so skip building
        # an autograd graph for them.
        with torch.no_grad():
            distances = torch.sum(flat_z ** 2, dim=1, keepdim=True) + \
                torch.sum(codebook ** 2, dim=1) - \
                2 * torch.matmul(flat_z, codebook.t())  # [BHW x K]
            encoding_inds = torch.argmin(distances, dim=1).unsqueeze(1)  # [BHW x 1]

        # Build the one-hot in the codebook's dtype: a default-dtype (float32)
        # one-hot cannot be multiplied with a float64/float16 codebook.
        encoding_one_hot = torch.zeros(
            encoding_inds.size(0), self.K, device=z.device, dtype=codebook.dtype
        )
        encoding_one_hot.scatter_(1, encoding_inds, 1)  # [BHW x K]

        # Selecting rows via matmul with the (non-detached) codebook lets the
        # embedding loss below train the codebook.
        quantized_z = torch.matmul(encoding_one_hot, codebook).view(z.shape)  # [B x H x W x D]

        # Embedding loss moves codes towards latents; commitment loss pulls
        # latents towards their codes.
        embedding_loss = F.mse_loss(quantized_z, z.detach())
        commitment_loss = F.mse_loss(quantized_z.detach(), z)
        vq_loss = embedding_loss + self.beta * commitment_loss

        # Straight-through estimator: forward value is the code, gradient goes to z.
        quantized_z = z + (quantized_z - z).detach()
        return quantized_z.permute(0, 3, 1, 2).contiguous(), vq_loss, encoding_inds  # [B x D x H x W]

class VectorQuantizerEMA(nn.Module):
    """Vector-quantisation layer whose codebook is updated by moving averages.

    The codebook is a buffer (no gradients). In training mode every forward
    pass updates per-code assignment counts (``cluster_size``) and per-code
    latent sums (``ema_embedding``) with decay ``decay``, then sets the codebook
    to their ratio using Laplace-smoothed counts. Only a ``beta``-weighted
    commitment loss is produced for the encoder.

    Args:
        num_embeddings: number of codebook entries ``K``.
        embedding_dim: size ``D`` of each entry; must equal the channel
            dimension of the tensor passed to :meth:`forward`.
        decay: EMA decay factor.
        beta: weight of the commitment loss.
        epsilon: additive Laplace-smoothing constant for the counts
            (previously hard-coded as ``1e-5``).
    """

    def __init__(
            self,
            num_embeddings: int,
            embedding_dim: int,
            decay: float,
            beta: float = 0.25,
            epsilon: float = 1e-5,
    ):
        super().__init__()
        self.K = num_embeddings
        self.D = embedding_dim
        self.decay = decay
        self.beta = beta
        self.epsilon = epsilon

        embedding = torch.empty(self.K, self.D)
        embedding.uniform_(-1/self.K, 1/self.K)
        self.register_buffer('embedding', embedding)

        self.register_buffer('cluster_size', torch.zeros(self.K))
        self.register_buffer('ema_embedding', torch.zeros(self.K, self.D))
        # NOTE: drawn independently of ``embedding`` (a separate random init).
        self.ema_embedding.uniform_(-1/self.K, 1/self.K)

    def _ema_update(self, flat_z: torch.Tensor, encoding_one_hot: torch.Tensor):
        """Update the EMA statistics and the codebook in place (call under no_grad)."""
        n_i = torch.sum(encoding_one_hot, dim=0)  # assignments per code [K]
        self.cluster_size.copy_(self.cluster_size * self.decay + n_i * (1 - self.decay))
        z_sum = torch.matmul(encoding_one_hot.t(), flat_z)  # summed latents per code [K x D]
        self.ema_embedding.copy_(self.ema_embedding * self.decay + z_sum * (1 - self.decay))
        # Laplace smoothing avoids division by zero for unused codes while
        # keeping the total count ``n``. The smoothed counts are written back
        # into ``cluster_size`` (same as before this refactor).
        n = torch.sum(self.cluster_size)
        self.cluster_size.copy_(
            (self.cluster_size + self.epsilon) / (n + self.K * self.epsilon) * n
        )
        self.embedding.copy_(self.ema_embedding / self.cluster_size.unsqueeze(-1))

    def forward(self, z: torch.Tensor):
        """Quantize ``z``; in training mode also update the codebook by EMA.

        Args:
            z: latent map of shape ``[B x D x H x W]``.

        Returns:
            ``(quantized, vq_loss, commitment_loss, encoding_inds)``:
            ``quantized`` has the shape of ``z`` with straight-through gradients
            and uses the codebook *before* this step's update; ``vq_loss`` is
            ``beta * commitment_loss``; ``commitment_loss`` is the unweighted MSE
            between ``z`` and its codes; ``encoding_inds`` is ``[B*H*W x 1]``.
            Note this is four values, unlike ``VectorQuantizer``'s three.
        """
        z = z.permute(0, 2, 3, 1).contiguous()  # [B x D x H x W] -> [B x H x W x D]
        flat_z = z.view(-1, self.D)  # [BHW x D]

        # Nothing below needs gradients: the codebook is a buffer and argmin is
        # non-differentiable. Running under no_grad avoids building a graph
        # through ``z`` for the EMA statistics.
        with torch.no_grad():
            distances = torch.sum(flat_z ** 2, dim=1, keepdim=True) + \
                torch.sum(self.embedding ** 2, dim=1) - \
                2 * torch.matmul(flat_z, self.embedding.t())  # [BHW x K]
            encoding_inds = torch.argmin(distances, dim=1).unsqueeze(1)  # [BHW x 1]

            # One-hot in the codebook's dtype so .double()/.half() modules work.
            encoding_one_hot = torch.zeros(
                encoding_inds.size(0), self.K, device=z.device, dtype=self.embedding.dtype
            )
            encoding_one_hot.scatter_(1, encoding_inds, 1)  # [BHW x K]

            quantized_z = torch.matmul(encoding_one_hot, self.embedding).view(z.shape)  # [B x H x W x D]

            if self.training:
                self._ema_update(flat_z, encoding_one_hot)

        commitment_loss = F.mse_loss(quantized_z.detach(), z)
        vq_loss = self.beta * commitment_loss

        # Straight-through estimator: forward value is the code, gradient goes to z.
        quantized_z = z + (quantized_z - z).detach()
        return quantized_z.permute(0, 3, 1, 2).contiguous(), vq_loss, commitment_loss, encoding_inds  # [B x D x H x W]

class VQ_VAE_2L(nn.Module):
    """VQ-VAE with one global latent vector and two conv layers per side.

    The encoder maps an image to a single ``embedding_dim`` vector, which is
    quantized by ``self.vq_layer`` and decoded back to an image. The layer
    sizes fix the geometry: the encoder flattens a ``16 x 16 x 16`` map
    (after two stride-2 convs, e.g. from 64x64 inputs) and the decoder always
    outputs ``in_channels x 64 x 64`` values squashed by a sigmoid.

    Args:
        in_channels: number of image channels.
        embedding_dim: latent / codebook-entry size.
        num_embeddings: codebook size.
        hidden_sizes: not used by this class.
        decay: EMA decay, only used when ``use_ema`` is True.
        alpha: stored on the instance but not used by this class.
        beta: commitment-loss weight passed to the quantizer.
        use_ema: use ``VectorQuantizerEMA`` instead of ``VectorQuantizer``.
    """

    def __init__(
            self,
            in_channels: int,
            embedding_dim: int,
            num_embeddings: int,
            hidden_sizes: list = None,
            decay: float = 0.99,
            alpha: float = 1.0,
            beta: float = 0.25,
            use_ema: bool = False
    ):
        super().__init__()

        self.embedding_dim = embedding_dim
        self.num_embeddings = num_embeddings
        self.decay = decay
        self.alpha = alpha
        self.beta = beta

        # Encoder: [B x C x 64 x 64] -> [B x 16 x 16 x 16] -> [B x embedding_dim]
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels, 8, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(8, 16, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(16 * 16 * 16, embedding_dim)
        )

        # VectorQuantizer
        if use_ema:
            self.vq_layer = VectorQuantizerEMA(num_embeddings, embedding_dim, decay, beta)
        else:
            self.vq_layer = VectorQuantizer(num_embeddings, embedding_dim, beta)

        # Decoder: [B x embedding_dim] -> [B x 16 x 16 x 16] -> [B x C x 64 x 64]
        self.decoder = nn.Sequential(
            nn.Linear(embedding_dim, 16 * 16 * 16),
            nn.ReLU(),
            nn.Unflatten(1, (16, 16, 16)),
            nn.ConvTranspose2d(16, 8, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(8, in_channels, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.Sigmoid()
        )

    def encode(self, inputs):
        """Return the un-quantized ``[B x embedding_dim]`` latent."""
        return self.encoder(inputs)

    def decode(self, z):
        """Decode a ``[B x embedding_dim]`` latent to an image."""
        return self.decoder(z)

    def _quantize(self, z):
        """Quantize a batch of flat latents ``z`` of shape ``[B x D]``.

        The quantizers operate on ``[B x D x H x W]`` maps, so ``z`` is viewed
        as ``[B x D x 1 x 1]`` and squeezed back afterwards. The quantizer's
        output is unpacked by position (first, second, last) because
        ``VectorQuantizer`` returns 3 values and ``VectorQuantizerEMA``
        returns 4.

        Returns:
            ``(quantized_z [B x D], vq_loss, encoding indices [B x 1])``.
        """
        outputs = self.vq_layer(z.unsqueeze(-1).unsqueeze(-1))
        quantized_z, vq_loss, ids = outputs[0], outputs[1], outputs[-1]
        return quantized_z.squeeze(-1).squeeze(-1), vq_loss, ids

    def forward(self, inputs):
        """Encode, quantize and decode ``inputs``; return the reconstruction."""
        quantized_z, _, _ = self._quantize(self.encode(inputs))
        return self.decode(quantized_z)

    def compute_loss(self, inputs):
        """Run the model and return ``(vq_loss, recons, z, ids)``.

        ``z`` is the un-quantized encoder output and ``ids`` the selected code
        indices. No reconstruction loss is computed here.
        """
        z = self.encode(inputs)
        quantized_z, vq_loss, ids = self._quantize(z)
        recons = self.decode(quantized_z)
        return vq_loss, recons, z, ids

    def sample(self, num_samples):
        # TODO by learning a PixelCNN on the latent space with our dataset
        # and regenerating z
        raise NotImplementedError

    def generate(self, inputs):
        """Return the reconstruction of ``inputs`` (same as ``forward``)."""
        outputs = self.forward(inputs)
        return outputs

class VQ_VAE_3L(nn.Module):
    """VQ-VAE with one global latent vector and three conv layers per side.

    The encoder maps an image to a single ``embedding_dim`` vector, which is
    quantized by ``self.vq_layer`` and decoded back to an image. The layer
    sizes fix the geometry: the encoder flattens a ``64 x 8 x 8`` map (after
    three stride-2 convs, e.g. from 64x64 inputs) and the decoder always
    outputs ``in_channels x 64 x 64`` values squashed by a sigmoid.

    Args:
        in_channels: number of image channels.
        embedding_dim: latent / codebook-entry size.
        num_embeddings: codebook size.
        hidden_sizes: not used by this class.
        decay: EMA decay, only used when ``use_ema`` is True.
        alpha: stored on the instance but not used by this class.
        beta: commitment-loss weight passed to the quantizer.
        use_ema: use ``VectorQuantizerEMA`` instead of ``VectorQuantizer``.
    """

    def __init__(
            self,
            in_channels: int,
            embedding_dim: int,
            num_embeddings: int,
            hidden_sizes: list = None,
            decay: float = 0.99,
            alpha: float = 1.0,
            beta: float = 0.25,
            use_ema: bool = False
    ):
        super().__init__()

        self.embedding_dim = embedding_dim
        self.num_embeddings = num_embeddings
        self.decay = decay
        self.alpha = alpha
        self.beta = beta

        # Encoder: [B x C x 64 x 64] -> [B x 64 x 8 x 8] -> [B x embedding_dim]
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels, 16, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(64 * 8 * 8, embedding_dim)
        )

        # VectorQuantizer
        if use_ema:
            self.vq_layer = VectorQuantizerEMA(num_embeddings, embedding_dim, decay, beta)
        else:
            self.vq_layer = VectorQuantizer(num_embeddings, embedding_dim, beta)

        # Decoder: [B x embedding_dim] -> [B x 64 x 8 x 8] -> [B x C x 64 x 64]
        self.decoder = nn.Sequential(
            nn.Linear(embedding_dim, 64 * 8 * 8),
            nn.ReLU(),
            nn.Unflatten(1, (64, 8, 8)),
            nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 16, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(16, in_channels, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.Sigmoid()
        )

    def encode(self, inputs):
        """Return the un-quantized ``[B x embedding_dim]`` latent."""
        return self.encoder(inputs)

    def decode(self, z):
        """Decode a ``[B x embedding_dim]`` latent to an image."""
        return self.decoder(z)

    def _quantize(self, z):
        """Quantize a batch of flat latents ``z`` of shape ``[B x D]``.

        The quantizers operate on ``[B x D x H x W]`` maps, so ``z`` is viewed
        as ``[B x D x 1 x 1]`` and squeezed back afterwards. The quantizer's
        output is unpacked by position (first, second, last) because
        ``VectorQuantizer`` returns 3 values and ``VectorQuantizerEMA``
        returns 4.

        Returns:
            ``(quantized_z [B x D], vq_loss, encoding indices [B x 1])``.
        """
        outputs = self.vq_layer(z.unsqueeze(-1).unsqueeze(-1))
        quantized_z, vq_loss, ids = outputs[0], outputs[1], outputs[-1]
        return quantized_z.squeeze(-1).squeeze(-1), vq_loss, ids

    def forward(self, inputs):
        """Encode, quantize and decode ``inputs``; return the reconstruction."""
        quantized_z, _, _ = self._quantize(self.encode(inputs))
        return self.decode(quantized_z)

    def compute_loss(self, inputs):
        """Run the model and return ``(vq_loss, recons, z, ids)``.

        ``z`` is the un-quantized encoder output and ``ids`` the selected code
        indices. No reconstruction loss is computed here.
        """
        z = self.encode(inputs)
        quantized_z, vq_loss, ids = self._quantize(z)
        recons = self.decode(quantized_z)
        return vq_loss, recons, z, ids

    def sample(self, num_samples):
        # TODO by learning a PixelCNN on the latent space with our dataset
        # and regenerating z
        raise NotImplementedError

    def generate(self, inputs):
        """Return the reconstruction of ``inputs`` (same as ``forward``)."""
        return self.forward(inputs)

class VQ_VAE_Test(nn.Module):
    """VQ-VAE variant with a per-pixel colour MLP in front of the conv encoder.

    ``color_encoder`` maps each pixel's ``in_channels`` values to 3 channels.
    Two kernel-2/stride-2 convs and a linear layer then produce one
    ``embedding_dim`` vector, which is quantized by ``self.vq_layer`` and decoded.
    The layer sizes fix the geometry: the encoder flattens a ``32 x 16 x 16``
    map (e.g. from 64x64 inputs) and the decoder always outputs
    ``in_channels x 64 x 64`` values squashed by a sigmoid.

    Args:
        in_channels: number of image channels.
        embedding_dim: latent / codebook-entry size.
        num_embeddings: codebook size.
        hidden_sizes: not used by this class.
        decay: EMA decay, only used when ``use_ema`` is True.
        alpha: stored on the instance but not used by this class.
        beta: commitment-loss weight passed to the quantizer.
        use_ema: use ``VectorQuantizerEMA`` instead of ``VectorQuantizer``.
    """

    def __init__(
            self,
            in_channels: int,
            embedding_dim: int,
            num_embeddings: int,
            hidden_sizes: list = None,
            decay: float = 0.99,
            alpha: float = 1.0,
            beta: float = 0.25,
            use_ema: bool = False
    ):
        super().__init__()

        self.embedding_dim = embedding_dim
        self.num_embeddings = num_embeddings
        self.decay = decay
        self.alpha = alpha
        self.beta = beta

        # Per-pixel MLP applied on channel-last tensors: in_channels -> 3.
        self.color_encoder = nn.Sequential(
            nn.Linear(in_channels, 16),
            nn.ReLU(),
            nn.Linear(16, 16),
            nn.ReLU(),
            nn.Linear(16, 3),
        )

        # Encoder: [B x 3 x 64 x 64] -> [B x 32 x 16 x 16] -> [B x embedding_dim]
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=2, stride=2, padding=0),
            nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=2, stride=2, padding=0),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(32 * 16 * 16, embedding_dim)
        )

        # VectorQuantizer
        if use_ema:
            self.vq_layer = VectorQuantizerEMA(num_embeddings, embedding_dim, decay, beta)
        else:
            self.vq_layer = VectorQuantizer(num_embeddings, embedding_dim, beta)

        # Decoder: [B x embedding_dim] -> [B x 64 x 8 x 8] -> [B x C x 64 x 64]
        self.decoder = nn.Sequential(
            nn.Linear(embedding_dim, 64 * 8 * 8),
            nn.ReLU(),
            nn.Unflatten(1, (64, 8, 8)),
            nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 16, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(16, in_channels, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.Sigmoid()
        )

    def encode(self, inputs):
        """Return the un-quantized ``[B x embedding_dim]`` latent of ``inputs``.

        The input is permuted to channel-last so the colour MLP's ``nn.Linear``
        layers act on each pixel's channel vector, then permuted back.
        """
        x = inputs.permute(0, 2, 3, 1)
        x = self.color_encoder(x)
        x = x.permute(0, 3, 1, 2)
        return self.encoder(x)

    def decode(self, z):
        """Decode a ``[B x embedding_dim]`` latent to an image."""
        return self.decoder(z)

    def _quantize(self, z):
        """Quantize a batch of flat latents ``z`` of shape ``[B x D]``.

        The quantizers operate on ``[B x D x H x W]`` maps, so ``z`` is viewed
        as ``[B x D x 1 x 1]`` and squeezed back afterwards. The quantizer's
        output is unpacked by position (first, second, last) because
        ``VectorQuantizer`` returns 3 values and ``VectorQuantizerEMA``
        returns 4.

        Returns:
            ``(quantized_z [B x D], vq_loss, encoding indices [B x 1])``.
        """
        outputs = self.vq_layer(z.unsqueeze(-1).unsqueeze(-1))
        quantized_z, vq_loss, ids = outputs[0], outputs[1], outputs[-1]
        return quantized_z.squeeze(-1).squeeze(-1), vq_loss, ids

    def forward(self, inputs):
        """Encode, quantize and decode ``inputs``; return the reconstruction."""
        quantized_z, _, _ = self._quantize(self.encode(inputs))
        return self.decode(quantized_z)

    def compute_loss(self, inputs):
        """Run the model and return ``(vq_loss, recons, z, ids)``.

        ``z`` is the un-quantized encoder output and ``ids`` the selected code
        indices. No reconstruction loss is computed here.
        """
        z = self.encode(inputs)
        quantized_z, vq_loss, ids = self._quantize(z)
        recons = self.decode(quantized_z)
        return vq_loss, recons, z, ids

    def sample(self, num_samples):
        # TODO by learning a PixelCNN on the latent space with our dataset
        # and regenerating z
        raise NotImplementedError

    def generate(self, inputs):
        """Return the reconstruction of ``inputs`` (same as ``forward``)."""
        outputs = self.forward(inputs)
        return outputs

class VQ_VAE_IS_3L(nn.Module):
    """VQ-VAE with one global latent vector and three conv layers per side.

    The encoder maps an image to a single ``embedding_dim`` vector, which is
    quantized by ``self.vq_layer`` and decoded back to an image. The layer
    sizes fix the geometry: the encoder flattens a ``64 x 8 x 8`` map (after
    three stride-2 convs, e.g. from 64x64 inputs) and the decoder always
    outputs ``in_channels x 64 x 64`` values squashed by a sigmoid.

    Args:
        in_channels: number of image channels.
        embedding_dim: latent / codebook-entry size.
        num_embeddings: codebook size.
        hidden_sizes: not used by this class.
        decay: EMA decay, only used when ``use_ema`` is True.
        alpha: stored on the instance but not used by this class.
        beta: commitment-loss weight passed to the quantizer.
        use_ema: use ``VectorQuantizerEMA`` instead of ``VectorQuantizer``.
    """

    def __init__(
            self,
            in_channels: int,
            embedding_dim: int,
            num_embeddings: int,
            hidden_sizes: list = None,
            decay: float = 0.99,
            alpha: float = 1.0,
            beta: float = 0.25,
            use_ema: bool = False
    ):
        super().__init__()

        self.embedding_dim = embedding_dim
        self.num_embeddings = num_embeddings
        self.decay = decay
        self.alpha = alpha
        self.beta = beta

        # Encoder: [B x C x 64 x 64] -> [B x 64 x 8 x 8] -> [B x embedding_dim]
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels, 16, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(64 * 8 * 8, embedding_dim)
        )

        # VectorQuantizer
        if use_ema:
            self.vq_layer = VectorQuantizerEMA(num_embeddings, embedding_dim, decay, beta)
        else:
            self.vq_layer = VectorQuantizer(num_embeddings, embedding_dim, beta)

        # Decoder: [B x embedding_dim] -> [B x 64 x 8 x 8] -> [B x C x 64 x 64]
        self.decoder = nn.Sequential(
            nn.Linear(embedding_dim, 64 * 8 * 8),
            nn.ReLU(),
            nn.Unflatten(1, (64, 8, 8)),
            nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 16, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(16, in_channels, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.Sigmoid()
        )

    def encode(self, inputs):
        """Return the un-quantized ``[B x embedding_dim]`` latent."""
        return self.encoder(inputs)

    def decode(self, z):
        """Decode a ``[B x embedding_dim]`` latent to an image."""
        return self.decoder(z)

    def _quantize(self, z):
        """Quantize a batch of flat latents ``z`` of shape ``[B x D]``.

        The quantizers operate on ``[B x D x H x W]`` maps, so ``z`` is viewed
        as ``[B x D x 1 x 1]`` and squeezed back afterwards. The quantizer's
        output is unpacked by position (first, second, last) because
        ``VectorQuantizer`` returns 3 values and ``VectorQuantizerEMA``
        returns 4.

        Returns:
            ``(quantized_z [B x D], vq_loss, encoding indices [B x 1])``.
        """
        outputs = self.vq_layer(z.unsqueeze(-1).unsqueeze(-1))
        quantized_z, vq_loss, ids = outputs[0], outputs[1], outputs[-1]
        return quantized_z.squeeze(-1).squeeze(-1), vq_loss, ids

    def forward(self, inputs):
        """Encode, quantize and decode ``inputs``; return the reconstruction."""
        quantized_z, _, _ = self._quantize(self.encode(inputs))
        return self.decode(quantized_z)

    def compute_loss(self, inputs):
        """Run the model and return ``(vq_loss, recons, z, ids)``.

        ``z`` is the un-quantized encoder output and ``ids`` the selected code
        indices. No reconstruction loss is computed here.
        """
        z = self.encode(inputs)
        quantized_z, vq_loss, ids = self._quantize(z)
        recons = self.decode(quantized_z)
        return vq_loss, recons, z, ids

    def sample(self, num_samples):
        # TODO by learning a PixelCNN on the latent space with our dataset
        # and regenerating z
        raise NotImplementedError

    def generate(self, inputs):
        """Return the reconstruction of ``inputs`` (same as ``forward``)."""
        outputs = self.forward(inputs)
        return outputs