import torch
import torch.nn as nn
import torch.nn.functional as F
from pantheonrl.algos.diffusion_human_ai.ldm.vqvae import VectorQuantizer
from transformers import BertTokenizer, BertModel


class VQVAE(nn.Module):
    """Convolutional encoder + vector quantizer + MLP decoder.

    The encoder treats the sequence axis of the input as Conv1d channels and the
    feature axis as the convolved length. Its latent code is quantized by
    ``VectorQuantizer`` and decoded to ``output_dim`` values.

    Args:
        max_seq_len: number of sequence positions (Conv1d input channels).
        n_embed: number of codebook entries passed to ``VectorQuantizer``.
        embed_dim: codebook vector size; must equal ``latent_dim``.
        input_dim: feature size per position (Conv1d input length).
        hidden_layers: widths used by the conv stack and the decoder (only the
            first two entries are read). A tuple, to avoid a shared mutable default.
        latent_dim: size of the encoder output.
        output_dim: size of the decoder output.

    Raises:
        ValueError: if ``embed_dim != latent_dim``.
    """

    def __init__(self, max_seq_len=64, n_embed=512, embed_dim=128, input_dim=768, hidden_layers=(256, 256), latent_dim=128, output_dim=25):
        super().__init__()

        # The quantizer's codebook vectors replace encoder outputs, so the sizes must agree.
        # Raised explicitly (instead of `assert`) so the check survives `python -O`.
        if embed_dim != latent_dim:
            raise ValueError(f"embed_dim ({embed_dim}) must equal latent_dim ({latent_dim})")

        def LinearBlock(input_dim, output_dim, normalize=True):
            # Linear -> (optional) BatchNorm1d -> ReLU
            layers = [nn.Linear(input_dim, output_dim)]
            if normalize:
                layers.append(nn.BatchNorm1d(output_dim))
            layers.append(nn.ReLU())
            return layers

        def Conv1dBlock(in_channels, out_channels, kernel_size, stride, padding, normalize=True):
            # Conv1d -> (optional) BatchNorm1d -> ReLU
            layers = [nn.Conv1d(in_channels, out_channels, kernel_size, stride, padding)]
            if normalize:
                layers.append(nn.BatchNorm1d(out_channels))
            layers.append(nn.ReLU())
            return layers

        # A Conv1d with kernel 5, stride 3 and padding 1 maps a length L >= 3 to
        # (L - 3) // 3 + 1 == L // 3. Two such convs therefore yield input_dim // 9
        # (768 -> 256 -> 85).
        conv_out_len = input_dim // 9

        self.encoder = nn.Sequential(
            *Conv1dBlock(max_seq_len, hidden_layers[0], 5, 3, 1),  # (B, max_seq_len, input_dim) -> (B, h0, input_dim // 3)
            *Conv1dBlock(hidden_layers[0], hidden_layers[1], 5, 3, 1),  # (B, h0, input_dim // 3) -> (B, h1, input_dim // 9)
            nn.Flatten(),
            *LinearBlock(hidden_layers[1] * conv_out_len, 2048),
            nn.Linear(2048, latent_dim),
        )

        # NOTE(review): the last decoder block ends in BatchNorm1d + ReLU, so every
        # reconstruction is >= 0. Confirm the targets are non-negative.
        self.decoder = nn.Sequential(
            *LinearBlock(latent_dim, hidden_layers[0]),
            # Input width must equal the previous block's output width (hidden_layers[0]).
            *LinearBlock(hidden_layers[0], output_dim),
        )

        self.vector_quantizer = VectorQuantizer(n_embed, embed_dim, beta=0.25)

    def forward(self, x):
        """Encode, quantize and decode ``x``.

        Args:
            x: tensor of shape (batch, max_seq_len, input_dim).

        Returns:
            (x_hat, z_e, loss, perplexity). ``x_hat`` is the decoder output of
            width ``output_dim``, and ``z_e`` is the un-quantized encoder output.
            ``loss`` and ``perplexity`` are passed through from ``VectorQuantizer``.
        """
        z_e = self.encoder(x)
        z_q, loss, perplexity = self.vector_quantizer(z_e)
        x_hat = self.decoder(z_q)
        return x_hat, z_e, loss, perplexity

    def encode(self, x):
        """Return the quantized latent code for ``x``, discarding the quantizer loss/perplexity."""
        z_e = self.encoder(x)
        z_q, _loss, _perplexity = self.vector_quantizer(z_e)
        return z_q

    def decode(self, z):
        """Decode a (presumably quantized) latent code ``z`` to ``output_dim`` values."""
        return self.decoder(z)
    

class VAE(nn.Module):
    """Fully connected VAE mapping flattened context features to ``output_dim`` values.

    The encoder maps ``max_seq_len * context_dim`` features to a mean and a
    log-variance of size ``latent_dim``. A latent is sampled with the
    reparameterization trick and decoded to ``output_dim`` values.

    Args:
        max_seq_len: number of sequence positions in the input.
        context_dim: feature size per position.
        hidden_layers: widths of the two hidden layers (only the first two
            entries are read). A tuple, to avoid a shared mutable default.
        latent_dim: size of the latent code.
        output_dim: size of the decoder output.
    """

    def __init__(self, max_seq_len=32, context_dim=768, hidden_layers=(1024, 1024), latent_dim=64, output_dim=25):
        super().__init__()

        def LinearBlock(input_dim, output_dim, normalize=True):
            # Linear -> (optional) BatchNorm1d -> SiLU
            layers = [nn.Linear(input_dim, output_dim)]
            if normalize:
                layers.append(nn.BatchNorm1d(output_dim))
            layers.append(nn.SiLU())
            return layers

        self.encoder = nn.Sequential(
            *LinearBlock(max_seq_len * context_dim, hidden_layers[0]),
            *LinearBlock(hidden_layers[0], hidden_layers[1]),
            # 2 * latent_dim outputs: the first half is the mean, the second the log-variance.
            nn.Linear(hidden_layers[1], latent_dim * 2),
        )

        self.decoder = nn.Sequential(
            *LinearBlock(latent_dim, hidden_layers[0]),
            *LinearBlock(hidden_layers[0], hidden_layers[1]),
            nn.Linear(hidden_layers[1], output_dim),
        )

    def forward(self, x):
        """Encode ``x``, sample a latent and decode it.

        Args:
            x: tensor of shape (batch, max_seq_len * context_dim), or any tensor
                of rank > 2 (e.g. (batch, max_seq_len, context_dim)). Higher-rank
                input is flattened from dim 1 onwards.

        Returns:
            (y, latent, mean, logvar). ``logvar`` is clamped to [-30, 20].
            Sampling happens in both train and eval mode; there is no
            ``self.training`` check.
        """
        if x.dim() > 2:
            x = x.flatten(start_dim=1)

        mean, logvar = torch.chunk(self.encoder(x), 2, dim=-1)
        # Clamp to keep exp() finite and non-degenerate.
        logvar = torch.clamp(logvar, -30, 20)
        stdev = torch.exp(0.5 * logvar)  # sqrt(exp(logvar)) in one step

        # Reparameterization trick: latent = mean + stdev * N(0, 1).
        noise = torch.randn_like(mean)
        latent = mean + stdev * noise

        y = self.decoder(latent)
        return y, latent, mean, logvar


class Translator(nn.Module):
    """Map text descriptions to event-info vectors via BERT features and a VAE.

    Args:
        event_info_dim: output size of the default ``VAE``'s decoder.
        vae: optional module applied to the BERT features. ``convert`` unpacks
            four values from it, so it must return a 4-tuple whose first element
            is the prediction. When None, a ``VAE`` sized for the features is
            built and moved to ``device``. A supplied ``vae`` is used as-is.
        finetuned_bert: optional BERT model used as-is (not moved to ``device``).
        max_seq_len: length every tokenized description is padded/truncated to.
        pooler_output: if True, use BERT's ``pooler_output`` (one vector per
            input). Otherwise use the flattened ``last_hidden_state``.
        bert_path: passed to ``from_pretrained`` for the tokenizer and, when
            ``finetuned_bert`` is None, for the model.
        device: device for the default BERT model and default VAE.

    This module does not change ``requires_grad`` on the BERT parameters.
    """

    def __init__(self, event_info_dim, vae=None, finetuned_bert=None, max_seq_len=32, pooler_output=False,
                 bert_path="diffusion_human_ai/models/bert-base-uncased", device="cuda"):
        super().__init__()

        self.max_seq_len = max_seq_len
        self.pooler_output = pooler_output
        self.tokenizer = BertTokenizer.from_pretrained(bert_path)
        if finetuned_bert is not None:
            self.bert = finetuned_bert
        else:
            self.bert = BertModel.from_pretrained(bert_path).to(device)

        self.flatten = nn.Flatten()

        # The pooled output is a single vector per input, i.e. an effective sequence length of 1.
        # NOTE(review): relies on VAE's default context_dim (768) matching the BERT hidden size.
        input_seq_len = 1 if pooler_output else max_seq_len
        # Compare against None rather than relying on truthiness: some modules
        # (e.g. an empty nn.Sequential) are falsy.
        if vae is not None:
            self.vae = vae
        else:
            self.vae = VAE(output_dim=event_info_dim, max_seq_len=input_seq_len).to(device)

    def convert(self, desc):
        """Return only the prediction (first element of ``forward``'s 4-tuple) for ``desc``."""
        event_desc, _, _, _ = self.forward(desc)
        return event_desc

    def forward(self, desc):
        """Tokenize ``desc``, encode it with BERT and run the features through ``self.vae``.

        Args:
            desc: text accepted by the tokenizer (presumably a str or list of str).

        Returns:
            Whatever ``self.vae`` returns. For the default ``VAE`` this is
            (y, latent, mean, logvar).
        """
        # truncation=True guarantees exactly max_seq_len tokens per description.
        # Without it, longer inputs are not cut, and the flattened features would
        # no longer match the VAE's input width (or the batch could not be stacked).
        encoded_input = self.tokenizer(
            desc,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=self.max_seq_len,
        ).to(self.bert.device)
        bert_out = self.bert(**encoded_input)
        if self.pooler_output:
            features = bert_out.pooler_output
        else:
            features = self.flatten(bert_out.last_hidden_state)

        return self.vae(features)