import numpy as np
import torch
import torch.nn as nn
import torch.nn.init as nn_init
import torch.nn.functional as F
from torch import Tensor

import typing as ty
import math


import torch
import os
import json
import time
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tqdm import tqdm


class Tokenizer(nn.Module):
    """Embed numerical and categorical features as a sequence of tokens.

    The output sequence is laid out as ``[CLS], numerical..., categorical...``:
    the [CLS] token and every numerical feature get one learned ``d_token``
    vector (scaled by the feature value; [CLS] is scaled by 1), and every
    categorical feature is looked up in a single shared embedding table.

    Args:
        d_numerical: number of numerical features.
        categories: list with the cardinality of each categorical feature, or
            ``None`` when there are no categorical features.
        d_token: width of each token.
        bias: whether to add a learned per-feature bias (never to [CLS]).
    """

    def __init__(self, d_numerical, categories, d_token, bias):
        super().__init__()
        if categories is None:
            d_bias = d_numerical
            self._n_categorical = 0
            self.category_offsets = None
            self.category_embeddings = None
        else:
            d_bias = d_numerical + len(categories)
            # Tracked separately because the offsets buffer below always has at
            # least one entry, even for an empty ``categories`` list.
            self._n_categorical = len(categories)
            # Offset of each feature's first row inside the shared embedding table.
            category_offsets = torch.tensor([0] + categories[:-1]).cumsum(0)
            self.register_buffer('category_offsets', category_offsets)
            self.category_embeddings = nn.Embedding(sum(categories), d_token)
            nn_init.kaiming_uniform_(self.category_embeddings.weight, a=math.sqrt(5))

        # take [CLS] token into account
        self.weight = nn.Parameter(Tensor(d_numerical + 1, d_token))
        self.bias = nn.Parameter(Tensor(d_bias, d_token)) if bias else None
        # The initialization is inspired by nn.Linear
        nn_init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            nn_init.kaiming_uniform_(self.bias, a=math.sqrt(5))

    @property
    def n_tokens(self):
        """Number of tokens produced per row: [CLS] + numerical + categorical."""
        return len(self.weight) + self._n_categorical

    def forward(self, x_num, x_cat):
        """Tokenize one batch.

        Args:
            x_num: ``(batch, d_numerical)`` float tensor, or ``None``.
            x_cat: ``(batch, n_categorical)`` integer tensor of per-feature
                category indices, or ``None``.

        Returns:
            Tensor of shape ``(batch, n_tokens, d_token)``.
        """
        x_some = x_num if x_cat is None else x_cat
        assert x_some is not None
        x_num = torch.cat(
            [torch.ones(len(x_some), 1, device=x_some.device)]  # [CLS]
            + ([] if x_num is None else [x_num]),
            dim=1,
        )

        # Scale each feature's learned vector by the feature value.
        x = self.weight[None] * x_num[:, :, None]

        if x_cat is not None:
            x = torch.cat(
                [x, self.category_embeddings(x_cat + self.category_offsets[None])],
                dim=1,
            )
        if self.bias is not None:
            # A zero row is prepended so the [CLS] token receives no bias.
            bias = torch.cat(
                [
                    torch.zeros(1, self.bias.shape[1], device=x.device),
                    self.bias,
                ]
            )
            x = x + bias[None]

        return x

class MLP(nn.Module):
    """Two-layer perceptron: Linear -> ReLU -> Dropout -> Linear."""

    def __init__(self, input_dim, hidden_dim, output_dim, dropout=0.5):
        super().__init__()
        self.input_dim, self.hidden_dim, self.output_dim = (
            input_dim,
            hidden_dim,
            output_dim,
        )

        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)
        # ``self.dropout`` holds the dropout module, not the probability.
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x):
        """Map ``x`` (last dimension ``input_dim``) to ``output_dim`` features."""
        hidden = self.dropout(F.relu(self.fc1(x)))
        return self.fc2(hidden)

class MultiheadAttention(nn.Module):
    """Scaled dot-product attention with ``n_heads`` heads.

    Inputs are ``(batch, n_tokens, d)`` tensors. With a single head the output
    projection ``W_out`` is omitted.

    Args:
        d: token width; must be divisible by ``n_heads`` when ``n_heads > 1``.
        n_heads: number of attention heads.
        dropout: dropout probability applied to the attention weights; a falsy
            value (the default ``0.0``) disables dropout entirely.
        initialization: ``'xavier'`` re-initialises the projection weights with
            Xavier-uniform; ``'kaiming'`` keeps the ``nn.Linear`` default.
    """

    def __init__(self, d, n_heads, dropout = 0.0, initialization = 'kaiming'):

        if n_heads > 1:
            assert d % n_heads == 0
        assert initialization in ['xavier', 'kaiming']

        super().__init__()
        self.W_q = nn.Linear(d, d)
        self.W_k = nn.Linear(d, d)
        self.W_v = nn.Linear(d, d)
        self.W_out = nn.Linear(d, d) if n_heads > 1 else None
        self.n_heads = n_heads
        self.dropout = nn.Dropout(dropout) if dropout else None

        for m in [self.W_q, self.W_k, self.W_v]:
            if initialization == 'xavier' and (n_heads > 1 or m is not self.W_v):
                # gain is needed since W_qkv is represented with 3 separate layers
                nn_init.xavier_uniform_(m.weight, gain=1 / math.sqrt(2))
            nn_init.zeros_(m.bias)
        if self.W_out is not None:
            nn_init.zeros_(self.W_out.bias)

    def _reshape(self, x):
        """Split heads: ``(batch, tokens, d)`` -> ``(batch * n_heads, tokens, d_head)``."""
        batch_size, n_tokens, d = x.shape
        d_head = d // self.n_heads
        return (
            x.reshape(batch_size, n_tokens, self.n_heads, d_head)
            .transpose(1, 2)
            .reshape(batch_size * self.n_heads, n_tokens, d_head)
        )

    def forward(self, x_q, x_kv, key_compression = None, value_compression = None):
        """Attend from ``x_q`` to ``x_kv``.

        Args:
            x_q: ``(batch, n_q_tokens, d)`` query source.
            x_kv: ``(batch, n_kv_tokens, d)`` key/value source.
            key_compression: optional callable applied to the keys along the
                token axis; must be given together with ``value_compression``.
            value_compression: optional callable applied to the values along
                the token axis.

        Returns:
            Tensor of shape ``(batch, n_q_tokens, d)``.
        """
        q, k, v = self.W_q(x_q), self.W_k(x_kv), self.W_v(x_kv)
        for tensor in [q, k, v]:
            assert tensor.shape[-1] % self.n_heads == 0
        if key_compression is not None:
            assert value_compression is not None
            # Compression acts on the token axis, hence the transposes.
            k = key_compression(k.transpose(1, 2)).transpose(1, 2)
            v = value_compression(v.transpose(1, 2)).transpose(1, 2)
        else:
            assert value_compression is None

        batch_size = len(q)
        d_head_key = k.shape[-1] // self.n_heads
        d_head_value = v.shape[-1] // self.n_heads
        n_q_tokens = q.shape[1]

        q = self._reshape(q)
        k = self._reshape(k)

        scores = q @ k.transpose(1, 2)
        attention = F.softmax(scores / math.sqrt(d_head_key), dim=-1)

        if self.dropout is not None:
            attention = self.dropout(attention)
        x = attention @ self._reshape(v)
        # Merge heads back: (batch * n_heads, tokens, d_head) -> (batch, tokens, d).
        x = (
            x.reshape(batch_size, self.n_heads, n_q_tokens, d_head_value)
            .transpose(1, 2)
            .reshape(batch_size, n_q_tokens, self.n_heads * d_head_value)
        )
        if self.W_out is not None:
            x = self.W_out(x)

        return x
        
class Transformer(nn.Module):
    """Stack of attention + feed-forward residual blocks over token sequences.

    ``forward`` returns the full token sequence with the same shape as its
    input. The ``head``, ``last_normalization`` and ``last_activation``
    modules are created (and therefore appear in the state dict) but are not
    applied by ``forward``. The ``activation`` argument is accepted but
    ignored: ReLU is always used.
    """

    def __init__(
        self,
        n_layers: int,
        d_token: int,
        n_heads: int,
        d_out: int,
        d_ffn_factor: int,
        attention_dropout = 0.0,
        ffn_dropout = 0.0,
        residual_dropout = 0.0,
        activation = 'relu',
        prenormalization = True,
        initialization = 'kaiming',
    ):
        super().__init__()

        ffn_width = int(d_token * d_ffn_factor)
        blocks = []
        for depth in range(n_layers):
            block = nn.ModuleDict()
            block['attention'] = MultiheadAttention(
                d_token, n_heads, attention_dropout, initialization
            )
            block['linear0'] = nn.Linear(d_token, ffn_width)
            block['linear1'] = nn.Linear(ffn_width, d_token)
            block['norm1'] = nn.LayerNorm(d_token)
            # With pre-normalization the first block has no attention norm.
            if depth > 0 or not prenormalization:
                block['norm0'] = nn.LayerNorm(d_token)
            blocks.append(block)
        self.layers = nn.ModuleList(blocks)

        self.activation = nn.ReLU()
        self.last_activation = nn.ReLU()
        self.prenormalization = prenormalization
        self.last_normalization = nn.LayerNorm(d_token) if prenormalization else None
        self.ffn_dropout = ffn_dropout
        self.residual_dropout = residual_dropout
        self.head = nn.Linear(d_token, d_out)

    def _start_residual(self, x, layer, norm_idx):
        """Return the residual-branch input (normalised first in pre-norm mode)."""
        if not self.prenormalization:
            return x
        key = f'norm{norm_idx}'
        return layer[key](x) if key in layer else x

    def _end_residual(self, x, x_residual, layer, norm_idx):
        """Add the branch output back to ``x`` (normalised after in post-norm mode)."""
        if self.residual_dropout:
            x_residual = F.dropout(x_residual, self.residual_dropout, self.training)
        merged = x + x_residual
        if self.prenormalization:
            return merged
        return layer[f'norm{norm_idx}'](merged)

    def forward(self, x):
        """Run every block over all tokens and return the transformed sequence."""
        for layer in self.layers:
            # Self-attention sub-block.
            branch = self._start_residual(x, layer, 0)
            branch = layer['attention'](branch, branch)
            x = self._end_residual(x, branch, layer, 0)

            # Feed-forward sub-block.
            branch = self._start_residual(x, layer, 1)
            branch = self.activation(layer['linear0'](branch))
            if self.ffn_dropout:
                branch = F.dropout(branch, self.ffn_dropout, self.training)
            branch = layer['linear1'](branch)
            x = self._end_residual(x, branch, layer, 1)
        return x


class AE(nn.Module):
    """Attention auto-encoder: one self-attention encoder, one self-attention decoder.

    Args:
        hid_dim: token width of both attention modules.
        n_head: number of attention heads of both attention modules.
    """

    def __init__(self, hid_dim, n_head):
        super(AE, self).__init__()

        self.hid_dim = hid_dim
        self.n_head = n_head

        # The dropout probability is passed explicitly (0.0 = no attention
        # dropout); constructing the attention modules without it raised a
        # TypeError.
        self.encoder = MultiheadAttention(hid_dim, n_head, 0.0)
        self.decoder = MultiheadAttention(hid_dim, n_head, 0.0)

    def get_embedding(self, x):
        """Return the encoder output for ``x``, detached from the autograd graph."""
        return self.encoder(x, x).detach()

    def forward(self, x):
        """Encode ``x`` with self-attention, then decode the result the same way."""
        z = self.encoder(x, x)
        h = self.decoder(z, z)

        return h

class VAE(nn.Module):
    """Token-level variational auto-encoder built from Transformer stacks.

    Inputs are tokenized, two Transformer encoders produce the per-token mean
    and log-variance of the latent, and a Transformer decoder maps the sampled
    latent (without the leading [CLS] token) back to token space.

    Args:
        d_numerical: number of numerical features.
        categories: cardinality of each categorical feature (passed to the
            tokenizer).
        num_layers: number of layers in each Transformer.
        hid_dim: token width used throughout.
        n_head: number of attention heads.
        factor: feed-forward width multiplier of the Transformers.
        bias: whether the tokenizer uses a per-feature bias.
    """

    def __init__(self, d_numerical, categories, num_layers, hid_dim, n_head = 1, factor = 4, bias = True):
        super(VAE, self).__init__()

        self.d_numerical = d_numerical
        self.categories = categories
        self.hid_dim = hid_dim
        d_token = hid_dim
        self.n_head = n_head

        self.Tokenizer = Tokenizer(d_numerical, categories, d_token, bias = bias)

        self.encoder_mu = Transformer(num_layers, hid_dim, n_head, hid_dim, factor)
        self.encoder_logvar = Transformer(num_layers, hid_dim, n_head, hid_dim, factor)

        self.decoder = Transformer(num_layers, hid_dim, n_head, hid_dim, factor)

    def get_embedding(self, x):
        """Return the latent mean for already-tokenized ``x``, detached from autograd."""
        # The encoder takes a single input sequence; passing ``x`` twice raised
        # a TypeError.
        return self.encoder_mu(x).detach()

    def reparameterize(self, mu, logvar):
        """Sample ``mu + eps * exp(0.5 * logvar)`` with ``eps ~ N(0, I)``."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x_num, x_cat):
        """Encode, sample and decode one batch.

        Args:
            x_num: numerical features, or ``None`` if the tokenizer is given
                categorical features only.
            x_cat: categorical features, or ``None``.

        Returns:
            Tuple ``(h, mu_z, logvar_z)``: the decoded tokens (without [CLS])
            and the latent mean and log-variance for every token (with [CLS]).
        """
        x = self.Tokenizer(x_num, x_cat)

        mu_z = self.encoder_mu(x)
        # The second encoder's output is consumed as a log-variance.
        logvar_z = self.encoder_logvar(x)

        z = self.reparameterize(mu_z, logvar_z)

        # Drop the leading [CLS] token before decoding.
        h = self.decoder(z[:, 1:])

        return h, mu_z, logvar_z

class Reconstructor(nn.Module):
    """Map decoded tokens back to feature space.

    Each numerical token is reduced to a scalar by a dot product with a
    learned per-feature vector; each categorical token is mapped to class
    logits by its own linear layer.

    Args:
        d_numerical: number of numerical features.
        categories: cardinality of each categorical feature, or ``None`` /
            empty when there are no categorical features.
        d_token: token width.
    """

    def __init__(self, d_numerical, categories, d_token):
        super(Reconstructor, self).__init__()

        self.d_numerical = d_numerical
        self.categories = categories
        self.d_token = d_token

        self.weight = nn.Parameter(Tensor(d_numerical, d_token))
        nn.init.xavier_uniform_(self.weight, gain=1 / math.sqrt(2))
        self.cat_recons = nn.ModuleList()

        # ``categories`` may be None (no categorical features), as accepted by
        # the Tokenizer in this file.
        for d in categories or []:
            recon = nn.Linear(d_token, d)
            nn.init.xavier_uniform_(recon.weight, gain=1 / math.sqrt(2))
            self.cat_recons.append(recon)

    def forward(self, h):
        """Reconstruct features from tokens.

        Args:
            h: ``(batch, d_numerical + n_categorical, d_token)`` tokens, with
                the numerical tokens first and no [CLS] token.

        Returns:
            Tuple ``(recon_x_num, recon_x_cat)``: a ``(batch, d_numerical)``
            tensor and a list with one ``(batch, cardinality)`` logits tensor
            per categorical feature.
        """
        h_num = h[:, :self.d_numerical]
        h_cat = h[:, self.d_numerical:]

        recon_x_num = torch.mul(h_num, self.weight.unsqueeze(0)).sum(-1)
        recon_x_cat = [recon(h_cat[:, i]) for i, recon in enumerate(self.cat_recons)]

        return recon_x_num, recon_x_cat


class Model_VAE(nn.Module):
    """VAE plus feature reconstructor: raw features in, reconstructed features out.

    Args:
        num_layers: number of layers in each Transformer of the VAE.
        d_numerical: number of numerical features.
        categories: cardinality of each categorical feature.
        d_token: token width.
        n_head: number of attention heads.
        factor: feed-forward width multiplier.
        bias: whether the tokenizer uses a per-feature bias.
    """

    def __init__(self, num_layers, d_numerical, categories, d_token, n_head = 1, factor = 4,  bias = True):
        super(Model_VAE, self).__init__()

        self.VAE = VAE(d_numerical, categories, num_layers, d_token, n_head = n_head, factor = factor, bias = bias)
        self.Reconstructor = Reconstructor(d_numerical, categories, d_token)

    def get_embedding(self, x_num, x_cat):
        """Return the detached latent mean for raw features ``(x_num, x_cat)``."""
        # The tokenizer and the mean encoder both live on the wrapped VAE;
        # this module has no ``Tokenizer`` attribute of its own.
        x = self.VAE.Tokenizer(x_num, x_cat)
        return self.VAE.encoder_mu(x).detach()

    def forward(self, x_num, x_cat):
        """Reconstruct one batch.

        Returns:
            Tuple ``(recon_x_num, recon_x_cat, mu_z, std_z)``, where the last
            two are the second and third outputs of the wrapped VAE (the third
            is used as a log-variance by ``VAE.reparameterize``).
        """
        h, mu_z, std_z = self.VAE(x_num, x_cat)

        # ``h`` already excludes the [CLS] token (the VAE drops it before decoding).
        recon_x_num, recon_x_cat = self.Reconstructor(h)

        return recon_x_num, recon_x_cat, mu_z, std_z


class Encoder_model(nn.Module):
    """Standalone encoder: a tokenizer followed by one Transformer stack."""

    def __init__(self, num_layers, d_numerical, categories, d_token, n_head, factor, bias = True):
        super().__init__()
        self.Tokenizer = Tokenizer(d_numerical, categories, d_token, bias)
        self.VAE_Encoder = Transformer(num_layers, d_token, n_head, d_token, factor)

    def load_weights(self, Pretrained_VAE):
        """Copy the tokenizer and mean-encoder weights from a trained ``Model_VAE``."""
        source = Pretrained_VAE.VAE
        self.Tokenizer.load_state_dict(source.Tokenizer.state_dict())
        self.VAE_Encoder.load_state_dict(source.encoder_mu.state_dict())

    def forward(self, x_num, x_cat):
        """Tokenize ``(x_num, x_cat)`` and return the encoder output."""
        tokens = self.Tokenizer(x_num, x_cat)
        return self.VAE_Encoder(tokens)

class Decoder_model(nn.Module):
    """Standalone decoder: one Transformer stack followed by a Reconstructor.

    ``bias`` and any extra keyword arguments are accepted but not used.
    """

    def __init__(self, num_layers, d_numerical, categories, d_token, n_head, factor, bias = True, **kwargs):
        super().__init__()
        self.VAE_Decoder = Transformer(num_layers, d_token, n_head, d_token, factor)
        self.Detokenizer = Reconstructor(d_numerical, categories, d_token)

    def load_weights(self, Pretrained_VAE):
        """Copy the decoder and reconstructor weights from a trained ``Model_VAE``."""
        decoder_state = Pretrained_VAE.VAE.decoder.state_dict()
        self.VAE_Decoder.load_state_dict(decoder_state)
        reconstructor_state = Pretrained_VAE.Reconstructor.state_dict()
        self.Detokenizer.load_state_dict(reconstructor_state)

    def forward(self, z):
        """Decode latent tokens ``z`` into ``(x_hat_num, x_hat_cat)``."""
        decoded = self.VAE_Decoder(z)
        x_hat_num, x_hat_cat = self.Detokenizer(decoded)
        return x_hat_num, x_hat_cat
    




# Module-level hyperparameter constants. None of them is referenced by the code
# visible in this file; presumably callers pass them to train_vae_model as the
# similarly named arguments — TODO confirm.
LR = 1e-3  # presumably the optimizer learning rate (`lr`)
WD = 0  # presumably the optimizer weight decay (`wd`)
D_TOKEN = 4  # presumably the token width (`d_token`)
TOKEN_BIAS = True  # presumably the tokenizer `bias` flag

N_HEAD = 1  # presumably the number of attention heads (`n_head`)
FACTOR = 32  # presumably the feed-forward width multiplier (`factor`)
NUM_LAYERS = 2  # presumably the number of Transformer layers (`num_layers`)

def compute_loss(X_num, X_cat, Recon_X_num, Recon_X_cat, mu_z=None, logvar_z=None,**kwargs):
    """Compute the reconstruction and KL terms of the VAE objective.

    Args:
        X_num: ground-truth numerical features (tensor; may have zero elements).
        X_cat: ground-truth categorical indices, one column per feature.
        Recon_X_num: reconstructed numerical features, same shape as ``X_num``.
        Recon_X_cat: sequence with one logits tensor per categorical feature.
        mu_z: latent mean, or ``None`` to skip the KL term.
        logvar_z: latent log-variance, or ``None`` to skip the KL term.
        **kwargs: ignored.

    Returns:
        Tuple ``(mse_loss, ce_loss, loss_kld, acc)`` of scalar tensors on
        ``X_num``'s device. Terms that do not apply are zero tensors;
        ``ce_loss`` is averaged over the categorical features and ``acc`` is
        the fraction of correctly predicted categorical entries.
    """
    device = X_num.device
    mse_loss = torch.tensor(0.0, device=device)
    ce_loss = torch.tensor(0.0, device=device)
    acc = torch.tensor(0.0, device=device)

    # Compute MSE loss only if X_num and Recon_X_num have non-zero size
    if X_num.numel() > 0 and Recon_X_num.numel() > 0:
        mse_loss = (X_num - Recon_X_num).pow(2).mean()

    # Compute CrossEntropy loss and accuracy for categorical columns if present
    if len(Recon_X_cat) > 0 and X_cat.numel() > 0:
        ce_loss_fn = nn.CrossEntropyLoss()
        total_num = 0
        for idx, logits in enumerate(Recon_X_cat):
            if logits.numel() == 0:  # Skip empty predictions
                continue
            target = X_cat[:, idx]
            ce_loss = ce_loss + ce_loss_fn(logits, target)
            acc = acc + (logits.argmax(dim=-1) == target).float().sum()
            total_num += logits.shape[0]

        # Normalize cross-entropy loss by the number of categorical features
        ce_loss = ce_loss / len(Recon_X_cat)

        # Calculate accuracy, avoiding division by zero
        acc = acc / max(total_num, 1)

    # KL divergence between N(mu, exp(logvar)) and N(0, I), if applicable
    if mu_z is not None and logvar_z is not None:
        temp = 1 + logvar_z - mu_z.pow(2) - logvar_z.exp()
        loss_kld = -0.5 * temp.mean(-1).mean()
    else:
        # Return a tensor here as well so callers can always use ``.item()``.
        loss_kld = torch.tensor(0.0, device=device)

    return mse_loss, ce_loss, loss_kld, acc


import numpy as np
import pandas as pd
from sklearn.base import TransformerMixin, BaseEstimator
from sklearn.preprocessing import LabelEncoder, MinMaxScaler

class TabularTokenizerTransformer(BaseEstimator, TransformerMixin):
    """Split a DataFrame into scaled numerical and label-encoded categorical arrays.

    Numerical columns (``np.number`` dtypes) are scaled with ``MinMaxScaler``;
    ``object`` columns are encoded with one ``LabelEncoder`` each. Columns of
    any other dtype are ignored.
    """

    def __init__(self):
        self.label_encoders = {}
        self.numerical_columns = []
        self.categorical_columns = []
        self.categories = []
        self.scaler = None

    def fit(self, df, y=None):
        """Learn the column split, label encodings and scaling from ``df``.

        Args:
            df: training DataFrame.
            y: ignored; accepted for scikit-learn API compatibility.

        Returns:
            ``self``.
        """
        # Reset fitted state so that calling fit() again does not accumulate
        # encoders and category counts from a previous fit.
        self.label_encoders = {}
        self.categories = []
        self.scaler = None

        # Identify numerical and categorical columns
        self.numerical_columns = df.select_dtypes(include=[np.number]).columns.tolist()
        self.categorical_columns = df.select_dtypes(include=['object']).columns.tolist()

        # Fit label encoders for each categorical column
        for col in self.categorical_columns:
            le = LabelEncoder()
            le.fit(df[col])
            self.label_encoders[col] = le
            self.categories.append(len(le.classes_))

        # Fit the MinMaxScaler for numerical columns. It is fitted on the raw
        # array because transform() also passes a raw array.
        if self.numerical_columns:
            self.scaler = MinMaxScaler()
            self.scaler.fit(df[self.numerical_columns].values)

        return self

    def transform(self, df):
        """Return ``(x_num, x_cat)`` arrays for ``df``.

        Both arrays always have ``len(df)`` rows; a side without columns is
        returned with shape ``(len(df), 0)``.
        """
        n_rows = len(df)

        # Transform numerical columns using MinMaxScaler
        if self.numerical_columns:
            x_num = self.scaler.transform(df[self.numerical_columns].values)
        else:
            x_num = np.empty((n_rows, 0), dtype=np.float64)

        # Transform categorical columns using the fitted label encoders
        if self.categorical_columns:
            x_cat = np.column_stack([
                self.label_encoders[col].transform(df[col])
                for col in self.categorical_columns
            ])
        else:
            # np.column_stack raises on an empty list.
            x_cat = np.empty((n_rows, 0), dtype=np.int64)

        return x_num, x_cat

    def inverse_transform(self, x_num, x_cat):
        """Rebuild a DataFrame from ``(x_num, x_cat)`` arrays.

        The result lists the numerical columns first, followed by the
        categorical columns, regardless of the original column order.
        """
        if self.numerical_columns:
            # Inverse transform numerical columns using MinMaxScaler
            x_num = self.scaler.inverse_transform(x_num)
            inv_df = pd.DataFrame(data=x_num, columns=self.numerical_columns)
        else:
            inv_df = pd.DataFrame(index=range(len(x_cat)))

        # Inverse transform categorical columns
        for idx, col in enumerate(self.categorical_columns):
            inv_df[col] = self.label_encoders[col].inverse_transform(x_cat[:, idx])

        return inv_df


from torch.utils.data import Dataset

class TabularDataset(Dataset):
    """Dataset pairing numerical and categorical feature tensors row by row."""

    def __init__(self, X_num, X_cat):
        """
        Args:
            X_num (torch.Tensor): Tensor of numerical features.
            X_cat (torch.Tensor): Tensor of categorical features.
        """
        self.X_num = X_num
        self.X_cat = X_cat

    def __len__(self):
        # The sample count is taken from the numerical tensor's first dimension.
        return self.X_num.size(0)

    def __getitem__(self, idx):
        """
        Args:
            idx (int): Index of the sample to retrieve; a tensor index is
                converted with ``tolist()`` first.

        Returns:
            tuple: (X_num, X_cat) corresponding to the given index.
        """
        position = idx.tolist() if torch.is_tensor(idx) else idx
        return self.X_num[position], self.X_cat[position]

def train_vae_model(df, dataname, num_layers, d_token, n_head, factor, lr, wd, max_beta=1e-2, min_beta=1e-5, lambd=0.7, device='gpu', num_epochs=4000, batch_size=4096):
    """Train a ``Model_VAE`` on a DataFrame and checkpoint the best weights.

    The DataFrame is preprocessed with ``TabularTokenizerTransformer``; the
    same data is used for training and for the per-epoch evaluation. The
    weights with the lowest evaluation cross-entropy so far are saved to
    ``<this file's directory>/ckpt/<dataname>/model.pt``.

    Args:
        df: input DataFrame.
        dataname: name of the checkpoint sub-directory.
        num_layers, d_token, n_head, factor: model hyperparameters passed to
            ``Model_VAE``.
        lr: Adam learning rate.
        wd: Adam weight decay.
        max_beta: initial weight of the KL term.
        min_beta: the KL weight is only decayed while it is above this value.
        lambd: multiplicative decay applied to the KL weight.
        device: torch device string. The legacy value ``'gpu'`` selects CUDA
            when available and the CPU otherwise.
        num_epochs: number of training epochs.
        batch_size: mini-batch size.

    Returns:
        The model with its final-epoch weights (the best weights are on disk).
    """
    # 'gpu' is not a device type torch understands; resolve it explicitly.
    if device == 'gpu':
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    curr_dir = os.path.dirname(os.path.abspath(__file__))
    ckpt_dir = f'{curr_dir}/ckpt/{dataname}'
    os.makedirs(ckpt_dir, exist_ok=True)

    model_save_path = f'{ckpt_dir}/model.pt'

    # Use the TabularTokenizerTransformer to process the DataFrame
    transformer = TabularTokenizerTransformer()
    transformer.fit(df)
    X_num, X_cat = transformer.transform(df)

    categories = transformer.categories
    d_numerical = len(transformer.numerical_columns)

    X_train_num = torch.from_numpy(X_num).float().to(device)
    X_train_cat = torch.from_numpy(X_cat).long().to(device)
    # Evaluation runs on the training data; the tensors are never modified, so
    # they are shared rather than copied.
    X_test_num, X_test_cat = X_train_num, X_train_cat

    train_data = TabularDataset(X_train_num, X_train_cat)

    train_loader = DataLoader(
        train_data,
        batch_size=batch_size,
        shuffle=True,
        num_workers=0,
    )

    model = Model_VAE(num_layers, d_numerical, categories, d_token, n_head=n_head, factor=factor, bias=True)
    model = model.to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=wd)
    # Learning-rate changes are reported by the explicit print below.
    scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.95, patience=10)

    best_train_loss = float('inf')
    current_lr = optimizer.param_groups[0]['lr']
    patience = 0
    beta = max_beta
    start_time = time.time()

    for epoch in range(num_epochs):
        pbar = tqdm(train_loader, total=len(train_loader))
        pbar.set_description(f"Epoch {epoch+1}/{num_epochs}")

        curr_loss_multi = 0.0
        curr_loss_gauss = 0.0
        curr_loss_kl = 0.0
        curr_count = 0

        model.train()
        for batch_num, batch_cat in pbar:
            optimizer.zero_grad()

            batch_num = batch_num.to(device)
            batch_cat = batch_cat.to(device)

            Recon_X_num, Recon_X_cat, mu_z, std_z = model(batch_num, batch_cat)

            loss_mse, loss_ce, loss_kld, train_acc = compute_loss(batch_num, batch_cat, Recon_X_num, Recon_X_cat, mu_z, std_z)

            loss = loss_mse + loss_ce + beta * loss_kld
            loss.backward()
            optimizer.step()

            # Accumulate sample-weighted losses for the epoch averages.
            batch_length = batch_num.shape[0]
            curr_count += batch_length
            curr_loss_multi += loss_ce.item() * batch_length
            curr_loss_gauss += loss_mse.item() * batch_length
            curr_loss_kl += loss_kld.item() * batch_length

        num_loss = curr_loss_gauss / curr_count
        cat_loss = curr_loss_multi / curr_count
        kl_loss = curr_loss_kl / curr_count

        # Evaluation
        model.eval()
        with torch.no_grad():
            Recon_X_num, Recon_X_cat, mu_z, std_z = model(X_test_num, X_test_cat)

            val_mse_loss, val_ce_loss, val_kl_loss, val_acc = compute_loss(X_test_num, X_test_cat, Recon_X_num, Recon_X_cat, mu_z, std_z)
            # Only the cross-entropy drives the scheduler and checkpointing;
            # the zero-weighted MSE term contributes nothing unless it is
            # non-finite.
            val_loss = val_mse_loss.item() * 0 + val_ce_loss.item()

            scheduler.step(val_loss)
            new_lr = optimizer.param_groups[0]['lr']

            if new_lr != current_lr:
                current_lr = new_lr
                print(f"Learning rate updated: {current_lr}")

            train_loss = val_loss
            if train_loss < best_train_loss:
                best_train_loss = train_loss
                patience = 0
                torch.save(model.state_dict(), model_save_path)
            else:
                patience += 1
                # The KL weight decays once, on the 10th consecutive epoch
                # without improvement.
                if patience == 10:
                    if beta > min_beta:
                        beta = beta * lambd

        print(f'Epoch {epoch+1}: beta = {beta:.6f}, Train MSE: {num_loss:.6f}, Train CE: {cat_loss:.6f}, Train KL: {kl_loss:.6f}, Val MSE: {val_mse_loss.item():.6f}, Val CE: {val_ce_loss.item():.6f}, Train ACC: {train_acc.item():.6f}, Val ACC: {val_acc.item():.6f}')

    end_time = time.time()
    print(f'Training time: {(end_time - start_time)/60:.4f} mins')

    return model