
import torch
import torch.nn as nn
import numpy as np
from models.quantizer import VectorQuantizer, SoftConvexQuantizer_Approx, GumbelQuantize
from vqtorch.nn.vq import VectorQuant
from vqtorch.nn.rvq import ResidualVectorQuant

from typing import Tuple, Optional
# Default compute device for this module: CUDA when a GPU is visible, else CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    
class VAE(nn.Module):
    """Convolutional VQ-VAE.

    The pipeline is encoder -> 1x1 conv -> quantizer -> 1x1 conv -> decoder.

    The quantizer is chosen by ``quant_type``:

    * ``'scq'`` -- ``SoftConvexQuantizer_Approx``
    * ``'vq'`` -- ``VectorQuantizer``
    * ``'vqreplace'``, ``'vqaffineopt'``, ``'vqreplaceaffineopt'`` -- vqtorch
      ``VectorQuant`` with the corresponding replacement / affine options
    * ``'rv'`` -- vqtorch ``ResidualVectorQuant``
    * anything else -- ``GumbelQuantize`` (``'gumbel'`` is the intended value)

    Args:
        in_channel: Channels of the input image.
        out_channel: Channels of the reconstructed image.
        hidden_channel: Channels of the encoder/decoder feature maps.
        n_res_block: Residual blocks in both encoder and decoder.
        n_res_channel: Hidden channels inside each residual block.
        embed_dim: Dimension of each codebook vector.
        n_embed: Number of codebook vectors.
        decay: Accepted for interface compatibility; currently unused.
        quant_type: Quantizer selector, see above.
        iterations_scq: First positional argument to
            ``SoftConvexQuantizer_Approx`` (presumably its iteration count --
            confirm against the quantizer).
        generate_latents: When True, ``forward`` stores per-position code
            index maps in ``self.last_latent_images`` (non-vqtorch types only).
        num_samples_latent: Number of top-scoring codes kept per position for
            ``'scq'`` when generating latent index maps.
    """

    # quant_type values backed by vqtorch modules, which return
    # ``(quantized, results_dict)``.  Every other quant_type -- including the
    # GumbelQuantize fallback in ``__init__`` -- returns a 6-tuple.
    _VQTORCH_TYPES = frozenset({'vqreplace', 'vqaffineopt', 'vqreplaceaffineopt', 'rv'})

    def __init__(
        self,
        in_channel: int=3,
        out_channel: int=3,
        hidden_channel: int=128,
        n_res_block: int=2,
        n_res_channel: int=32,
        embed_dim: int=64,
        n_embed: int=512,
        decay: float=0.99,
        quant_type: str='scq',
        iterations_scq: int= 20,
        generate_latents: bool=False,
        num_samples_latent: int=1,
    ):
        super().__init__()

        self.encoder = Encoder(in_channel, hidden_channel, n_res_block, n_res_channel, stride=2)
        self.decoder = Decoder(hidden_channel, out_channel, hidden_channel, n_res_block, n_res_channel, stride=2,)
        self.pre_quantize = nn.Conv2d(hidden_channel, embed_dim, 1)
        self.post_quantize = nn.Conv2d(embed_dim, hidden_channel, 1)
        self.generate_latents = generate_latents
        self.num_samples_latents = num_samples_latent
        self.codebook_vec_dim = embed_dim
        self.codebook_size = n_embed
        self.quant_type = quant_type
        # Index maps from the most recent forward pass (None if not generated).
        self.last_latent_images = None
        if quant_type == 'scq':
            self.quantization = SoftConvexQuantizer_Approx(iterations_scq, embed_dim, n_embed, beta=0.25)
        elif quant_type == 'vq':
            self.quantization = VectorQuantizer(n_embed, embed_dim, beta=0.25)
        elif quant_type == 'vqreplace':
            self.quantization = VectorQuant(embed_dim, n_embed, beta=0.25, replace_freq=20)
        elif quant_type == 'vqaffineopt':
            self.quantization = VectorQuant(embed_dim, n_embed, sync_nu=0.1, affine_lr=0.1, beta=0.25)
        elif quant_type == 'vqreplaceaffineopt':
            self.quantization = VectorQuant(embed_dim, n_embed, sync_nu=0.1, affine_lr=0.1, beta=0.25, replace_freq=20)
        elif quant_type == 'rv':
            self.quantization = ResidualVectorQuant(embed_dim, n_embed)
        else:
            self.quantization = GumbelQuantize(n_embed, embed_dim)

    def forward(self, x):
        """Encode, quantize and decode ``x``.

        Returns:
            ``(embedding_loss, reconstruction, perplexity)``.  The loss and
            perplexity are as reported by the quantizer.

        Side effect:
            Sets ``self.last_latent_images``. It holds the code index maps when
            ``generate_latents`` is enabled and the quantizer is not a vqtorch
            type; otherwise it is ``None``.
        """
        h = self.pre_quantize(self.encoder(x))
        latent_images = None
        # Dispatch must mirror __init__: only the explicitly listed vqtorch
        # types use the (h, results) protocol; everything else (including the
        # GumbelQuantize fallback for unrecognised names) returns a 6-tuple.
        if self.quant_type in self._VQTORCH_TYPES:
            h, results = self.quantization(h)
            embedding_loss = results['loss']
            perplexity = results['perplexity']
        else:
            embedding_loss, h, perplexity, _quant_loss, latents, _latents_prob = self.quantization(h)
            if self.generate_latents:
                latent_images = self.generate_latent_images(latents, h.shape)
        self.last_latent_images = latent_images
        h = self.decoder(self.post_quantize(h))
        return embedding_loss, h, perplexity

    def generate_latent_images(self, latents, shape):
        """Convert per-position codebook scores into code-index maps.

        Args:
            latents: Tensor reshaped here to ``(-1, codebook_size)``. Each row
                is treated as scores over the codebook for one spatial position.
            shape: Shape of the quantized feature map ``(B, C, H, W)``. Only
                ``B``, ``H`` and ``W`` are used.

        Returns:
            Integer tensor ``(B, k, H, W)`` of codebook indices.  For ``'scq'``,
            ``k == num_samples_latents`` (top-k codes per position); for every
            other type, ``k == 1`` (argmax code).
        """
        scores = latents.reshape(-1, self.codebook_size)
        if self.quant_type == 'scq':
            _, indices = torch.topk(scores, self.num_samples_latents, dim=1)
        else:
            # argmax yields exactly one index per position, whatever
            # num_samples_latents is set to.
            indices = torch.argmax(scores, dim=1, keepdim=True)
        batch, height, width = shape[0], shape[2], shape[3]
        # NOTE(review): assumes the rows of ``scores`` are ordered
        # (batch, height, width), i.e. the quantizer flattened a channels-last
        # view -- confirm against the quantizer.  Reshape to that layout, then
        # move the sample axis to dim 1 so each sample is one (H, W) index map.
        return indices.view(batch, height, width, -1).permute(0, 3, 1, 2)

class ResBlock(nn.Module):
    """Residual unit computing ``x + Conv1x1(ReLU(Conv3x3(ReLU(x))))``.

    Args:
        in_channel: Channels of the input (and of the output).
        channel: Hidden channels of the 3x3 convolution.
    """

    def __init__(self, in_channel, channel):
        super().__init__()
        layers = [
            # Not in-place: ``input`` must stay intact for the skip connection.
            nn.ReLU(),
            nn.Conv2d(in_channel, channel, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(channel, in_channel, 1),
        ]
        # The attribute name ``conv`` and the layer order fix the state_dict keys.
        self.conv = nn.Sequential(*layers)

    def forward(self, input):
        """Return the residual branch applied to ``input``, plus ``input``."""
        return self.conv(input) + input


class Encoder(nn.Module):
    """Strided convolutional stem followed by residual blocks and a final ReLU.

    Args:
        in_channel: Channels of the input image.
        channel: Channels of the output feature map.
        n_res_block: Number of ``ResBlock`` layers after the stem.
        n_res_channel: Hidden channels inside each ``ResBlock``.
        stride: Downsampling factor of the stem; must be 2 (one stride-2
            conv) or 4 (two stride-2 convs).

    Raises:
        ValueError: If ``stride`` is not 2 or 4.
    """

    def __init__(self, in_channel, channel, n_res_block, n_res_channel, stride):
        super().__init__()

        if stride == 4:
            blocks = [
                nn.Conv2d(in_channel, channel // 2, 4, stride=2, padding=1),
                nn.ReLU(inplace=True),
                nn.Conv2d(channel // 2, channel, 4, stride=2, padding=1),
                nn.ReLU(inplace=True),
                nn.Conv2d(channel, channel, 3, padding=1),
            ]
        elif stride == 2:
            blocks = [
                nn.Conv2d(in_channel, channel // 2, 4, stride=2, padding=1),
                nn.ReLU(inplace=True),
                nn.Conv2d(channel // 2, channel, 3, padding=1),
            ]
        else:
            # Previously this fell through and failed with UnboundLocalError.
            raise ValueError(f"Encoder stride must be 2 or 4, got {stride!r}")

        blocks.extend(ResBlock(channel, n_res_channel) for _ in range(n_res_block))
        blocks.append(nn.ReLU(inplace=True))

        # Layer order defines the ``blocks.<idx>`` state_dict keys; keep it stable.
        self.blocks = nn.Sequential(*blocks)

    def forward(self, input):
        """Apply the encoder layers to ``input``."""
        return self.blocks(input)


class Decoder(nn.Module):
    """Conv + residual blocks followed by a transposed-conv upsampling head.

    Args:
        in_channel: Channels of the input feature map.
        out_channel: Channels of the reconstructed output.
        channel: Working channels of the conv / residual trunk.
        n_res_block: Number of ``ResBlock`` layers in the trunk.
        n_res_channel: Hidden channels inside each ``ResBlock``.
        stride: Upsampling factor of the head; must be 2 (one stride-2
            transposed conv) or 4 (two stride-2 transposed convs).

    Raises:
        ValueError: If ``stride`` is not 2 or 4.
    """

    def __init__(
        self, in_channel, out_channel, channel, n_res_block, n_res_channel, stride
    ):
        super().__init__()

        # Without this check an unsupported stride silently built a decoder
        # that neither upsampled nor produced ``out_channel`` channels.
        if stride not in (2, 4):
            raise ValueError(f"Decoder stride must be 2 or 4, got {stride!r}")

        blocks = [nn.Conv2d(in_channel, channel, 3, padding=1)]
        blocks.extend(ResBlock(channel, n_res_channel) for _ in range(n_res_block))
        blocks.append(nn.ReLU(inplace=True))

        if stride == 4:
            blocks.extend(
                [
                    nn.ConvTranspose2d(channel, channel // 2, 4, stride=2, padding=1),
                    nn.ReLU(inplace=True),
                    nn.ConvTranspose2d(
                        channel // 2, out_channel, 4, stride=2, padding=1
                    ),
                ]
            )
        else:
            blocks.append(
                nn.ConvTranspose2d(channel, out_channel, 4, stride=2, padding=1)
            )

        # Layer order defines the ``blocks.<idx>`` state_dict keys; keep it stable.
        self.blocks = nn.Sequential(*blocks)

    def forward(self, input):
        """Apply the decoder layers to ``input``."""
        return self.blocks(input)



       
          
       