import numpy as np
import torch
import torch.nn as nn
import torch.nn.init as nn_init
import torch.nn.functional as F
from torch import Tensor

import typing as ty
import math
import pandas as pd



import torch
import os
import json
import time
from torch.utils.data import DataLoader, TensorDataset
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tqdm import tqdm

import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingLR


from dataset.dataset import get_table_latent_dataloader, create_df_dict_from_dir
from sklearn.preprocessing import QuantileTransformer, LabelEncoder, OneHotEncoder, OrdinalEncoder


from tableLatent.perceive.encoderDecoders import PerceiveAggregator, TransformerRowDecoder,TransformerCondDecoder
from tableLatent.perceive.contrastTrainer import SupConLoss,ContrastiveTrainer
from tableLatent.perceive.vaeTrainer import process_column_embeddings

from tableLatent.feature.VAE import Transformer,compute_loss, Decoder_model

class TableMixedTransformer:
    """
    Encode table rows into latent vectors and decode latent vectors back into rows.

    Encoding uses an aggregator network (PerceiveAggregator) trained with a supervised
    contrastive loss (SupConLoss); decoding uses a data-specific decoder (``DECODE_CLASS``).
    Numerical columns are handled through QuantileTransformer column transformers and
    categorical columns through OrdinalEncoder ones. Categorical columns must be
    represented in their string form.
    """
    DECODE_CLASS = Decoder_model

    def __init__(self, params=None) -> None:
        """
        Initialize TableMixedTransformer with given parameters for the aggregator and decoder models.

        Args:
            params: optional dict with two sub-dictionaries of hyperparameters:
                    - 'aggregator': keyword arguments forwarded to ``_train_encoder``.
                    - 'decoder': keyword arguments forwarded to ``_train_decoder``.
        """
        self.reset_fitting(params)

    def fit(self, train_df_dict, test_df_dict, config_dict, train_transformer_for_test=True, fitting_bsize=1024, retrain_aggregator=False, retrain_decoder=False):
        """
        Fit the aggregator (encoder) first, then the decoder.

        Args:
            train_df_dict: dict mapping dataset names to training pd.DataFrame objects.
            test_df_dict: dict mapping dataset names to held-out pd.DataFrame objects
                (used for decoder validation).
            config_dict: dataset configuration forwarded to ``get_table_latent_dataloader``.
            train_transformer_for_test: reuse the column transformers fitted on the
                training data for the test data.
            fitting_bsize: batch size of the training and test loaders.
            retrain_aggregator: continue training an already-fitted aggregator instead of
                keeping it unchanged.
            retrain_decoder: same as ``retrain_aggregator``, for the decoder.

        Returns:
            None
        """
        # Step 1: create datasets/data loaders and set data attributes.
        train_loader, test_loader = self._set_data_attribute(train_df_dict, test_df_dict, config_dict, train_transformer_for_test, fitting_bsize)

        # Step 2: fit encoder.
        self.aggregator = self._train_encoder(train_loader, test_loader, retrain_aggregator=retrain_aggregator, **self.aggregator_params)

        # Step 3: fit decoder.
        self.decoder = self._train_decoder(self.transformers_per_table, train_loader, test_loader, retrain_decoder=retrain_decoder, **self.decoder_params)

    def _set_data_attribute(self, train_df_dict, test_df_dict, config_dict, train_transformer_for_test, batch_size=1024):
        """Build train/test loaders and store the column transformers, config and embedding dict from the training dataset."""
        train_loader = get_table_latent_dataloader(df_dict=train_df_dict, config_dict=config_dict, return_label=True, fixed_batch=True, batch_size=batch_size)

        train_ds = train_loader.dataset

        train_transformer_dict = train_ds.transformers_dict
        self.transformers_per_table = train_transformer_dict
        self.config_dict = config_dict
        self.emb_dict = train_ds.result_dict

        # None lets the test dataloader fit its own transformers.
        test_transformer_dict = train_transformer_dict if train_transformer_for_test else None
        test_loader = get_table_latent_dataloader(df_dict=test_df_dict, config_dict=config_dict, return_label=True, fixed_batch=True, transformers_dict=test_transformer_dict, batch_size=batch_size)

        return train_loader, test_loader

    @staticmethod
    def _build_aggregator(dim=768, dim_latent=64, depth=4, dim_head=64, num_latent=16, max_input_seq_len=64, ff_mult=4, **_training_only):
        """
        Construct a PerceiveAggregator from a stored ``aggregator_params`` dict.

        Maps the stored key names onto the constructor's names (``num_latent`` ->
        ``num_latents``, ``max_input_seq_len`` -> ``max_seq_len``). Training-only keys
        such as ``learning_rate`` or ``epochs`` are swallowed by ``**_training_only``.
        """
        return PerceiveAggregator(
            dim=dim,
            dim_latent=dim_latent,
            depth=depth,
            dim_head=dim_head,
            num_latents=num_latent,
            max_seq_len=max_input_seq_len,
            ff_mult=ff_mult
        )

    def _train_encoder(self, dataloader, test_loader, dim=768, dim_latent=64, depth=4,
                       learning_rate=5e-4, epochs=1200, mask_ratio=0.3, max_input_seq_len=64, contrast_temp=0.1, device='cuda', num_latent=16, dim_head=64, ff_mult=4, retrain_aggregator=False
                       ):
        """
        Train (or continue training) the aggregator with a supervised contrastive loss.

        Returns the fitted aggregator unchanged when it is already fitted and
        ``retrain_aggregator`` is False. ``test_loader`` is currently unused.
        """
        if self.fitted['aggregator'] and not retrain_aggregator:
            return self.aggregator
        elif self.fitted['aggregator'] and retrain_aggregator:
            model = self.aggregator
        else:
            model = self._build_aggregator(
                dim=dim,
                dim_latent=dim_latent,
                depth=depth,
                dim_head=dim_head,
                num_latent=num_latent,
                max_input_seq_len=max_input_seq_len,
                ff_mult=ff_mult
            )

        self.aggregator_params = {
            'dim': dim,
            'dim_latent': dim_latent,
            'depth': depth,
            'learning_rate': learning_rate,
            'epochs': epochs,
            'mask_ratio': mask_ratio,
            'max_input_seq_len': max_input_seq_len,
            'contrast_temp': contrast_temp,
            'num_latent': num_latent,
            'dim_head': dim_head,
            "ff_mult": ff_mult
        }

        print("*"*40 + "\nTraining aggregator with parameters:\n"+ f"{self.aggregator_params}\n" + "*"*40)

        optimizer = optim.Adam(model.parameters(), lr=learning_rate)
        scheduler = CosineAnnealingLR(optimizer, epochs)
        criterion = SupConLoss(temperature=contrast_temp)
        trainer = ContrastiveTrainer(model, dataloader, optimizer, criterion, scheduler, device=device, mask_ratio=mask_ratio)

        trainer.train(num_epochs=epochs)

        self.fitted['aggregator'] = True

        return model

    @staticmethod
    def _column_layout(transformer_dict):
        """
        Derive the decoder's column layout from the first table's column transformers.

        Returns:
            (num_idx, cat_idx, categories): positions of QuantileTransformer (numerical)
            columns, positions of the other (categorical) columns, and for each
            categorical column its number of categories + 1 (one slot for unseen values).
        """
        num_idx, cat_idx, categories = [], [], []
        for tfs in transformer_dict.values():
            for col_idx, tf in enumerate(tfs.values()):
                if isinstance(tf, QuantileTransformer):
                    num_idx.append(col_idx)
                else:
                    categories.append(len(tf.categories_[0]) + 1)  # + 1 for unseen category
                    cat_idx.append(col_idx)
            # TODO: currently TableMixedTransformer can only be fitted on one data frame.
            # The design might be extended to multi-table fitting in the future.
            break
        return num_idx, cat_idx, categories

    @staticmethod
    def _safe_mean(total, count):
        """Return total / count, or NaN when no samples were seen (empty loader)."""
        return total / count if count else float('nan')

    def _decoder_forward(self, model, batch_emb, batch_label, col_flags, meta, num_idx, cat_idx, device):
        """
        Run one batch through the frozen aggregator and the decoder.

        Returns:
            (batch_num, batch_cat, recon_num, recon_cat): numerical targets, categorical
            targets (long), and the decoder's numerical / categorical reconstructions.
        """
        batch_emb = batch_emb.squeeze().to(device)
        batch_label = batch_label.squeeze()
        dtype = torch.where(col_flags < 1)[0].to(device).long()
        meta = meta.to(device).float()

        batch_num = batch_label[:, num_idx].to(device)
        batch_cat = batch_label[:, cat_idx].to(device).long()

        # -1000 marks padded positions in the row sequence.
        attention_mask = torch.all(batch_emb != -1000, dim=-1).long().to(device)
        # The aggregator is frozen here (eval mode, not in the optimizer), so no graph is needed.
        with torch.no_grad():
            batch_latent = self.aggregator(batch_emb, attention_mask, dtype, meta).squeeze()

        recon_num, recon_cat = model(batch_latent)
        return batch_num, batch_cat, recon_num, recon_cat

    def _train_decoder(self, transformer_dict, train_loader, test_loader, num_layers=2, d_emb=64, n_head=1, factor=32, lr=1e-3, wd=0, max_beta=1e-2, min_beta=1e-5, lambd=0.7, device='cuda:0', num_epochs=4000, retrain_decoder=False, d_numerical=None, categories=None):
        """
        Train (or continue training) the decoder mapping aggregator latents back to column values.

        The aggregator is kept frozen. Only the decoder is optimized (Adam) on
        MSE (numerical columns) + CE (categorical columns).

        Args:
            transformer_dict: per-dataset dict of fitted column transformers. Only the
                first dataset is used to derive the column layout.
            train_loader, test_loader: loaders built by ``_set_data_attribute``.
            num_layers, d_emb, n_head, factor: decoder hyperparameters.
            lr, wd: Adam learning rate and weight decay.
            max_beta, min_beta, lambd: schedule of ``beta``. Beta is only reported in the
                epoch log; it does not enter the loss.
            device: training device.
            num_epochs: number of epochs.
            retrain_decoder: continue training an already-fitted decoder.
            d_numerical, categories: accepted so that a stored ``decoder_params`` dict can be
                passed back in (as ``fit`` does). They are ignored and always recomputed from
                ``transformer_dict``.

        Returns:
            The trained decoder module (left on ``device``).
        """
        if self.fitted['decoder'] and not retrain_decoder:
            return self.decoder

        # The column layout is needed on every training path, including retraining an
        # existing decoder, so it is derived unconditionally.
        num_idx, cat_idx, categories = self._column_layout(transformer_dict)
        d_numerical = len(num_idx)

        if self.fitted['decoder']:
            model = self.decoder
        else:
            model = self.DECODE_CLASS(num_layers, d_numerical, categories, d_emb, n_head=n_head, factor=factor)

        model = model.to(device)

        self.decoder_params = {
            "num_layers": num_layers,
            "d_numerical": d_numerical,
            "categories": categories,
            "d_emb": d_emb,
            "n_head": n_head,
            "factor": factor
        }
        print("*"*40 + "Training decoder with parameters:\n"+ f"{self.decoder_params}" + "*"*40)

        self.aggregator.eval()
        self.aggregator = self.aggregator.to(device)

        optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=wd)
        # LR changes are reported manually below, so the deprecated `verbose` flag is not used.
        scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.95, patience=10)

        best_val_loss = float('inf')
        current_lr = optimizer.param_groups[0]['lr']
        patience = 0
        beta = max_beta
        start_time = time.time()

        for epoch in range(num_epochs):
            pbar = tqdm(train_loader, total=len(train_loader))
            pbar.set_description(f"Epoch {epoch+1}/{num_epochs}")

            model.train()
            train_ce_sum = 0.0
            train_mse_sum = 0.0
            train_count = 0
            last_train_acc = float('nan')

            for batch_emb, batch_label, col_flags, meta, _ in pbar:
                optimizer.zero_grad()
                batch_num, batch_cat, recon_num, recon_cat = self._decoder_forward(
                    model, batch_emb, batch_label, col_flags, meta, num_idx, cat_idx, device)

                loss_mse, loss_ce, _, train_acc = compute_loss(batch_num, batch_cat, recon_num, recon_cat)
                loss = loss_mse + loss_ce
                loss.backward()
                optimizer.step()

                batch_length = batch_num.shape[0]
                train_count += batch_length
                train_ce_sum += loss_ce.item() * batch_length
                train_mse_sum += loss_mse.item() * batch_length
                last_train_acc = train_acc.item()

            num_loss = self._safe_mean(train_mse_sum, train_count)
            cat_loss = self._safe_mean(train_ce_sum, train_count)

            # Evaluation
            model.eval()
            val_mse_sum = 0.0
            val_ce_sum = 0.0
            val_acc_sum = 0.0
            val_count = 0

            with torch.no_grad():
                for batch_emb, batch_label, col_flags, meta, _ in test_loader:
                    batch_num, batch_cat, recon_num, recon_cat = self._decoder_forward(
                        model, batch_emb, batch_label, col_flags, meta, num_idx, cat_idx, device)

                    batch_mse, batch_ce, _, batch_acc = compute_loss(batch_num, batch_cat, recon_num, recon_cat)

                    batch_length = batch_num.shape[0]
                    val_count += batch_length
                    val_mse_sum += batch_mse.item() * batch_length
                    val_ce_sum += batch_ce.item() * batch_length
                    val_acc_sum += batch_acc.item() * batch_length

            val_mse_loss = self._safe_mean(val_mse_sum, val_count)
            val_ce_loss = self._safe_mean(val_ce_sum, val_count)
            val_acc = self._safe_mean(val_acc_sum, val_count)

            scheduler.step(val_ce_loss)
            new_lr = optimizer.param_groups[0]['lr']

            if new_lr != current_lr:
                current_lr = new_lr
                print(f"Learning rate updated: {current_lr}")

            # Beta decays once per stretch of 10 non-improving epochs (on validation loss).
            val_total = val_ce_loss + val_mse_loss
            if val_total < best_val_loss:
                best_val_loss = val_total
                patience = 0
            else:
                patience += 1
                if patience == 10 and beta > min_beta:
                    beta = beta * lambd

            print(f'Epoch {epoch+1}: beta = {beta:.6f}, Train MSE: {num_loss:.6f}, Train CE: {cat_loss:.6f}, Val MSE: {val_mse_loss:.6f}, Val CE: {val_ce_loss:.6f}, Train ACC: {last_train_acc:.6f}, Val ACC: {val_acc:.6f}')

        end_time = time.time()
        print(f'Training time: {(end_time - start_time)/60:.4f} mins')

        self.fitted['decoder'] = True

        return model

    def df_to_latent(self, df_dict, batch_size=512, return_format='tensor', device='cuda:0'):
        """
        Encode the rows of one or more tables into latent vectors with the fitted aggregator.

        Args:
            df_dict (dict): maps dataset names (all seen during fitting) to pd.DataFrame objects.
            batch_size (int): batch size for encoding, and of the returned DataLoader.
            return_format (str): 'tensor' returns the concatenated latents; 'dataloader'
                returns an unshuffled DataLoader over a TensorDataset of the latents.
            device (str): device used for encoding. The aggregator is moved back to CPU afterwards.

        Returns:
            torch.Tensor or DataLoader: latents in dataloader iteration order.

        Raises:
            AssertionError: if a dataset name is unknown or the aggregator is not fitted.
            NotImplementedError: if ``return_format`` is not supported.
        """
        assert all([dataset_name in self.transformers_per_table for dataset_name in df_dict]), f"All dataset_name must have been seen during fitting! Acceptable dataset names: {self.transformers_per_table.keys()}"
        assert self.fitted['aggregator'], "Aggregator is not fitted, transformation failed!"
        # Validate up front so an unsupported format does not cost a full encoding pass.
        if return_format not in ('tensor', 'dataloader'):
            raise NotImplementedError(f"Return format {return_format} is not implemented yet!")

        dataloader = get_table_latent_dataloader(df_dict=df_dict, config_dict=self.config_dict, return_label=True, fixed_batch=True, batch_size=batch_size, transformers_dict=self.transformers_per_table)

        self.aggregator.eval()
        self.aggregator.to(device)

        latent = []
        with torch.no_grad():
            for batch, _, col_flags, meta, _ in dataloader:
                batch = batch.squeeze().to(device)
                dtype = torch.where(col_flags < 1)[0].to(device).long()
                meta = meta.to(device).float()

                # -1000 marks padded positions in the row sequence.
                attention_mask = torch.all(batch != -1000, dim=-1).long().to(device)
                batch_latent = self.aggregator(batch, attention_mask, dtype, meta).squeeze().detach().cpu()
                latent.append(batch_latent)

        latent = torch.cat(latent, dim=0)

        self.aggregator.cpu()

        if return_format == 'tensor':
            return latent
        return DataLoader(TensorDataset(latent), batch_size=batch_size, shuffle=False)

    def transform(self, df_dict, **kwargs):
        """Encode tables into latents; alias of ``df_to_latent`` (extra kwargs are forwarded)."""
        return self.df_to_latent(df_dict, **kwargs)

    def inverse_transform(self, latent_data, dataset_name=None, **kwargs):
        """
        Decode latents back into a DataFrame; alias of ``latent_to_df``.

        ``dataset_name`` may be omitted when exactly one dataset was seen during fitting.
        Extra kwargs are forwarded to ``latent_to_df``.

        Raises:
            ValueError: if ``dataset_name`` is omitted and it cannot be inferred.
        """
        if dataset_name is None:
            if len(self.transformers_per_table) != 1:
                raise ValueError("dataset_name must be given when the transformer was not fitted on exactly one dataset.")
            dataset_name = next(iter(self.transformers_per_table))
        return self.latent_to_df(latent_data, dataset_name, **kwargs)

    def _decode_feature(self, Recon_X_num, Recon_X_cat, transformer_dict):
        """
        Inverse transform numerical and categorical tensors back to their original form and return a DataFrame.

        Parameters:
        - Recon_X_num: 2D tensor (B, num_numerical_columns) of reconstructed numerical values.
        - Recon_X_cat: list of 2D tensors, one per categorical column, holding class scores.
          The argmax index is decoded.
        - transformer_dict: dictionary mapping column names to fitted transformers (QuantileTransformer for numerical
                            columns, OrdinalEncoder for categorical columns).

        Returns:
        - A pandas DataFrame with the original data form (inverse-transformed).

        Raises:
        - ValueError: if a column's transformer is neither QuantileTransformer nor OrdinalEncoder.
        """
        df_dict = {}

        num_idx = 0  # index into Recon_X_num columns
        cat_idx = 0  # index into Recon_X_cat list

        for col_name, transformer in transformer_dict.items():
            if isinstance(transformer, QuantileTransformer):
                inv_num_col = transformer.inverse_transform(Recon_X_num[:, [num_idx]].detach().cpu().numpy())
                df_dict[col_name] = inv_num_col.flatten()
                num_idx += 1
            elif isinstance(transformer, OrdinalEncoder):
                cat_probs = Recon_X_cat[cat_idx].detach().cpu().numpy()
                cat_preds = np.argmax(cat_probs, axis=1)
                inv_cat_col = transformer.inverse_transform(cat_preds.reshape(-1, 1))
                df_dict[col_name] = inv_cat_col.flatten()
                cat_idx += 1
            else:
                raise ValueError(f"Unknown transformer type for column: {col_name}")

        return pd.DataFrame(df_dict)

    def latent_to_df(self, latent_data, dataset_name, batch_size=512, device='cuda:0'):
        """
        Applies inverse transformation to a latent tensor/dataloader. All latent data must come from the same dataset, as indicated by the dataset_name argument.

        Args:
            latent_data (torch.Tensor or DataLoader): The latent data. A DataLoader must yield
                batches whose first element is the latent tensor.
            dataset_name (str): name of the dataset among data seen during fitting.
            batch_size (int): batch size during latent decoding (tensor input only).
            device (str): device used for decoding. The decoder is moved back to CPU afterwards.

        Returns:
            pandas.DataFrame with one row per latent vector, in input order.

        Raises:
            AssertionError: If decoder is not fitted.
            NotImplementedError: for unsupported ``latent_data`` types.
        """
        assert self.fitted['decoder'], "Decoder network are not fitted, inverse transformation failed!"

        if isinstance(latent_data, torch.Tensor):
            # shuffle=False keeps reconstructed rows aligned with the input latents.
            dataloader = DataLoader(TensorDataset(latent_data), batch_size=batch_size, shuffle=False)
        elif isinstance(latent_data, DataLoader):
            dataloader = latent_data
        else:
            raise NotImplementedError(f"Latent data input with type {type(latent_data)} is not supported!")

        transformer_dict = self.transformers_per_table[dataset_name]

        self.decoder.eval().to(device)

        recon_df = []
        with torch.no_grad():
            for batch_latent in dataloader:
                Recon_X_num, Recon_X_cat = self.decoder(batch_latent[0].to(device))
                recon_df.append(self._decode_feature(Recon_X_num, Recon_X_cat, transformer_dict))

        self.decoder.cpu()

        return pd.concat(recon_df).reset_index(drop=True)

    def fit_transform(self, df_dict, test_df_dict=None, config_dict=None, **fit_kwargs):
        """
        Fit on ``df_dict`` and return its latent encoding.

        Args:
            df_dict: dict of training DataFrames (also encoded at the end).
            test_df_dict: validation DataFrames; defaults to ``df_dict``.
            config_dict: dataset config; defaults to a previously stored one.
            **fit_kwargs: forwarded to ``fit``.

        Raises:
            ValueError: if no config_dict is given or stored.
        """
        if config_dict is None:
            config_dict = getattr(self, 'config_dict', None)
        if config_dict is None:
            raise ValueError("config_dict is required to fit the transformer.")
        self.fit(df_dict, df_dict if test_df_dict is None else test_df_dict, config_dict, **fit_kwargs)
        return self.df_to_latent(df_dict)

    def reset_fitting(self, params=None):
        """Clear all fitted state and (re)store the aggregator/decoder hyperparameters from ``params``."""
        params = {} if params is None else params
        self.transformers_per_table = {}
        self.fitted = {"aggregator": False, "decoder": False}
        self.aggregator = None
        self.decoder = None
        self.config_dict = None
        self.emb_dict = None

        # Copies, so later reassignments never alias the caller's dicts.
        self.aggregator_params = dict(params.get('aggregator', {}))
        self.decoder_params = dict(params.get('decoder', {}))

    def get_checkpoint(self):
        """
        Return a dict checkpoint of the fitted transformer.

        Always contains: 'aggregator_state', 'aggregator_params', 'transformers_per_table',
        'decoder_state', 'decoder_params'. Also contains 'config_dict' and 'emb_dict' when they
        are set, because ``df_to_latent`` / ``latent_to_df`` need them after loading.

        Raises:
            AssertionError: if the aggregator or the decoder is not fitted.
        """
        assert self.fitted['aggregator'] and self.fitted['decoder'], "Both encoder and decoder must be fitted!"

        checkpoint = {
            'aggregator_state': self.aggregator.state_dict(),
            'transformers_per_table': self.transformers_per_table,
            'aggregator_params': self.aggregator_params,
            'decoder_state': self.decoder.state_dict(),
            'decoder_params': self.decoder_params,
        }
        for key in ('config_dict', 'emb_dict'):
            value = getattr(self, key, None)
            if value is not None:
                checkpoint[key] = value

        return checkpoint

    def _build_decoder(self, decoder_params):
        """
        Instantiate ``DECODE_CLASS`` from a stored ``decoder_params`` dict.

        Keys that the constructor does not accept are dropped. For example, training-only
        entries such as ``num_epochs`` are stored alongside the architecture arguments.
        """
        import inspect

        try:
            signature = inspect.signature(self.DECODE_CLASS)
        except (TypeError, ValueError):
            return self.DECODE_CLASS(**decoder_params)
        if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in signature.parameters.values()):
            return self.DECODE_CLASS(**decoder_params)
        accepted = {k: v for k, v in decoder_params.items() if k in signature.parameters}
        return self.DECODE_CLASS(**accepted)

    def load_checkpoint(self, checkpoint, allow_partial_loading=True):
        """
        Restore the transformer from a checkpoint dict produced by ``get_checkpoint``. This OVERWRITES existing fitting.

        Args:
            checkpoint: dict with the aggregator states/data attributes and, optionally,
                the decoder states.
            allow_partial_loading: if True, a checkpoint without decoder states is accepted.
                Useful when using a pretrained encoder without a decoder. Any previously
                fitted decoder is then discarded.

        Raises:
            AssertionError: if required states are missing.
        """
        AGGREGATOR_STATES = ['transformers_per_table', 'aggregator_params', 'aggregator_state']
        DECODER_STATES = ['decoder_params', 'decoder_state']

        assert all([ST in checkpoint for ST in AGGREGATOR_STATES]), f"All aggregator states and data attributes must be provide: {AGGREGATOR_STATES}"
        assert all([ST in checkpoint for ST in DECODER_STATES]) or allow_partial_loading, f"Decoder states must be provide when allow_partial_loading is false: {DECODER_STATES}"

        self.transformers_per_table = checkpoint['transformers_per_table']
        self.aggregator_params = checkpoint['aggregator_params']
        # aggregator_params uses _train_encoder's key names; map them to the constructor's.
        self.aggregator = self._build_aggregator(**self.aggregator_params)
        self.aggregator.load_state_dict(checkpoint['aggregator_state'])
        self.fitted['aggregator'] = True
        print("Aggregator loaded!")

        for key in ('config_dict', 'emb_dict'):
            if key in checkpoint:
                setattr(self, key, checkpoint[key])

        if not all([ST in checkpoint for ST in DECODER_STATES]):
            # A decoder fitted for previously loaded/fitted data no longer matches.
            self.decoder = None
            self.fitted['decoder'] = False
            print("Decoder state not provided! Skipped due to partial loading enabled.")
        else:
            self.decoder_params = checkpoint['decoder_params']
            self.decoder = self._build_decoder(self.decoder_params)
            self.decoder.load_state_dict(checkpoint['decoder_state'])
            self.fitted['decoder'] = True


class TableLatentTransformer(TableMixedTransformer):
    """
    TableMixedTransformer variant whose decoder (TransformerCondDecoder) receives,
    besides the row latent, the column-name embeddings, the dataset metadata embedding,
    the column type flags and the candidate-category embeddings of categorical columns.
    """
    DECODE_CLASS = TransformerCondDecoder

    def __init__(self, params=None) -> None:
        # Pass a fresh dict instead of None so any base-class version can call .get on it.
        super().__init__({} if params is None else params)

    @staticmethod
    def _squeeze_keep_batch(x):
        """Squeeze all size-1 dimensions except the leading batch dimension."""
        for dim in reversed(range(1, x.dim())):
            if x.shape[dim] == 1:
                x = x.squeeze(dim)
        return x

    def _cond_decoder_forward(self, model, batch_emb, batch_label, dtype, meta, device):
        """
        Run one batch through the frozen aggregator and the conditional decoder.

        Returns:
            (batch_num, batch_cat_normalized, recon_num, recon_cat, batch_length)
        """
        batch_emb = batch_emb.squeeze().to(device)
        batch_length = batch_emb.shape[0]
        # dtype flags each column: 0 = categorical, 1 = numerical. Column indices are
        # read from its second dimension.
        cat_idx, num_idx = torch.where(dtype == 0)[1], torch.where(dtype == 1)[1]

        batch_label = batch_label.squeeze()
        dtype, meta = dtype.to(device).long(), meta.to(device).float()
        batch_num = batch_label[:, num_idx].to(device)
        batch_cat = batch_label[:, cat_idx].to(device).long()

        # Cell-value embeddings of categorical columns sit at odd positions (2*i + 1).
        batch_cat_embs = batch_emb[:, 2 * cat_idx + 1, :]
        batch_cat_normalized, unique_embedding_list = process_column_embeddings(batch_cat_embs, batch_cat)

        # -1000 marks padded positions in the row sequence.
        attention_mask = torch.all(batch_emb != -1000, dim=-1).long().to(device)
        # The aggregator is frozen here (eval mode, not in the optimizer), so no graph is needed.
        with torch.no_grad():
            batch_latent = self.aggregator(batch_emb, attention_mask, dtype, meta).squeeze()

        # Column-name embeddings are presumably at the even positions (values are at odd
        # ones, see above). This matches latent_to_df, which feeds exactly one name
        # embedding per column, and keeps the target value embeddings out of the decoder
        # input. NOTE(review): confirm against the dataset's row layout.
        column_names_emb = batch_emb[:, ::2, :]
        metadata_emb = meta.repeat(batch_length, 1, 1)

        recon_num, recon_cat = model(column_names_emb, metadata_emb, batch_latent, dtype, unique_embedding_list)
        return batch_num, batch_cat_normalized, recon_num, recon_cat, batch_length

    def _train_decoder(self, transformer_dict, train_loader, test_loader, device='cuda:0', num_epochs=4000, aggregated_dim=64, depth=4, output_dim=768, lm_emb=768, lr=1e-4, wd=0, factor=0.95, patience=10, retrain_decoder=False):
        """
        Train (or continue training) the conditional decoder while the aggregator stays frozen.

        Args:
            transformer_dict: kept for interface compatibility with the base class; unused here.
            train_loader, test_loader: loaders yielding (batch_emb, batch_label, dtype, meta).
            device: training device.
            num_epochs: number of epochs.
            aggregated_dim, depth, output_dim, lm_emb: decoder hyperparameters.
            lr, wd: Adam learning rate and weight decay.
            factor, patience: ReduceLROnPlateau settings (stepped on validation CE).
            retrain_decoder: continue training an already-fitted decoder.

        Returns:
            The trained decoder module (left on ``device``).
        """
        if self.fitted['decoder'] and not retrain_decoder:
            return self.decoder

        if self.fitted['decoder']:
            model = self.decoder
        else:
            model = self.DECODE_CLASS(aggregated_dim=aggregated_dim, depth=depth, lm_emb=lm_emb, output_dim=output_dim)

        model = model.to(device)

        self.decoder_params = {
            "depth": depth,
            "output_dim": output_dim,
            "aggregated_dim": aggregated_dim,
            "lm_emb": lm_emb,
            "num_epochs": num_epochs
        }
        print("*"*40 + "Training decoder with parameters:\n"+ f"{self.decoder_params}" + "*"*40)

        self.aggregator.eval()
        self.aggregator = self.aggregator.to(device)

        optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=wd)
        # LR changes are reported manually below, so the deprecated `verbose` flag is not used.
        scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=factor, patience=patience)

        current_lr = optimizer.param_groups[0]['lr']
        start_time = time.time()

        for epoch in range(num_epochs):
            pbar = tqdm(train_loader, total=len(train_loader))
            pbar.set_description(f"Epoch {epoch+1}/{num_epochs}")

            model.train()
            train_ce_sum = 0.0
            train_mse_sum = 0.0
            train_count = 0
            last_train_acc = float('nan')

            for batch_emb, batch_label, dtype, meta in pbar:
                optimizer.zero_grad()
                batch_num, batch_cat, recon_num, recon_cat, batch_length = self._cond_decoder_forward(
                    model, batch_emb, batch_label, dtype, meta, device)

                loss_mse, loss_ce, _, train_acc = compute_loss(batch_num, batch_cat, recon_num, recon_cat)
                loss = loss_mse + loss_ce
                loss.backward()
                optimizer.step()

                train_count += batch_length
                train_ce_sum += loss_ce.item() * batch_length
                train_mse_sum += loss_mse.item() * batch_length
                last_train_acc = train_acc.item()

            num_loss = train_mse_sum / train_count if train_count else float('nan')
            cat_loss = train_ce_sum / train_count if train_count else float('nan')

            # Evaluation: the model stays in eval mode for the whole validation pass.
            model.eval()
            val_mse_sum = 0.0
            val_ce_sum = 0.0
            val_acc_sum = 0.0
            val_count = 0

            with torch.no_grad():
                for batch_emb, batch_label, dtype, meta in test_loader:
                    batch_num, batch_cat, recon_num, recon_cat, batch_length = self._cond_decoder_forward(
                        model, batch_emb, batch_label, dtype, meta, device)

                    batch_mse, batch_ce, _, batch_acc = compute_loss(batch_num, batch_cat, recon_num, recon_cat)

                    val_count += batch_length
                    val_mse_sum += batch_mse.item() * batch_length
                    val_ce_sum += batch_ce.item() * batch_length
                    val_acc_sum += batch_acc.item() * batch_length

            val_mse_loss = val_mse_sum / val_count if val_count else float('nan')
            val_ce_loss = val_ce_sum / val_count if val_count else float('nan')
            val_acc = val_acc_sum / val_count if val_count else float('nan')

            scheduler.step(val_ce_loss)
            new_lr = optimizer.param_groups[0]['lr']

            if new_lr != current_lr:
                current_lr = new_lr
                print(f"Learning rate updated: {current_lr}")

            print(f'Epoch {epoch+1}: Train MSE: {num_loss:.6f}, Train CE: {cat_loss:.6f}, Val MSE: {val_mse_loss:.6f}, Val CE: {val_ce_loss:.6f}, Train ACC: {last_train_acc:.6f}, Val ACC: {val_acc:.6f}')

        end_time = time.time()
        print(f'Training time: {(end_time - start_time)/60:.4f} mins')

        self.fitted['decoder'] = True

        return model

    def latent_to_df(self, latent_data, dataset_name, result_dict=None, batch_size=512, device='cuda:0'):
        """
        Applies inverse transformation to a latent tensor/dataloader. All latent data must come from the same dataset, as indicated by the dataset_name argument.

        Args:
            latent_data (torch.Tensor or DataLoader): The latent data.
            dataset_name (str): Name of the dataset among data seen during fitting.
            result_dict (dict, optional): Dictionary with embeddings of column names, categories,
                and metadata. Defaults to the embedding dict stored during fitting (``self.emb_dict``).
            batch_size (int): Batch size during latent decoding (tensor input only).
            device (str): Device for model execution ('cuda:0', 'cpu', etc.).

        Returns:
            pandas.DataFrame: Reconstructed data, one row per latent vector, in input order.

        Raises:
            AssertionError: If decoder is not fitted.
            ValueError: If no result_dict is given and none is stored.
            NotImplementedError: For unsupported ``latent_data`` types.
        """
        assert self.fitted['decoder'], "Decoder network is not fitted, inverse transformation failed!"

        if result_dict is None:
            result_dict = getattr(self, 'emb_dict', None)
            if result_dict is None:
                raise ValueError("result_dict must be provided when no embedding dictionary is stored on the transformer.")

        # Step 1: prepare an order-preserving dataloader from the latent data.
        if isinstance(latent_data, torch.Tensor):
            # shuffle=False keeps reconstructed rows aligned with the input latents.
            dataloader = DataLoader(TensorDataset(latent_data), batch_size=batch_size, shuffle=False)
        elif isinstance(latent_data, DataLoader):
            dataloader = latent_data
        else:
            raise NotImplementedError(f"Latent data input type {type(latent_data)} is not supported!")

        self.decoder.eval().to(device)

        # Step 2: extract the dataset-specific conditioning inputs.
        dataset_info = result_dict[dataset_name]
        transformer_dict = self.transformers_per_table[dataset_name]
        columns_this_ds = list(transformer_dict.keys())

        # Column name embeddings, one per column: (1, num_cols, ...)
        column_names_emb = torch.tensor(
            [dataset_info['column_name'][col] for col in columns_this_ds], device=device
        ).unsqueeze(0)

        metadata_emb = torch.tensor(dataset_info['metadata'], device=device).unsqueeze(0)

        # 0 for categorical, 1 for numerical columns.
        # NOTE(review): this is 1-D, while during training dtype is indexed as 2-D
        # (torch.where(dtype == 0)[1]) - confirm the decoder accepts both.
        dtype_tensor = torch.tensor(
            [0 if col in dataset_info['categories'] else 1 for col in columns_this_ds], device=device
        )

        # Candidate-category embeddings per categorical column, in dataset_info order.
        unique_embedding_list = [
            torch.stack([torch.tensor(embedding, device=device) for embedding in dataset_info['categories'][col].values()])
            for col in columns_this_ds if col in dataset_info['categories']
        ]

        # Diagnostic: the predicted category index is decoded through the OrdinalEncoder, so
        # the two orders printed here are expected to match.
        for col in columns_this_ds:
            if col not in dataset_info['categories']:
                continue
            print("column:",col)
            print("Order in unique_embedding_list:",dataset_info['categories'][col].keys())
            print("Order in column transformer:",transformer_dict[col].categories_)

        # Step 3: decode batch by batch.
        recon_df = []
        with torch.no_grad():
            for batch in dataloader:
                batch_latent = batch[0] if isinstance(batch, (list, tuple)) else batch
                # Keep the batch dimension even for a batch of one row.
                row_latent_emb = self._squeeze_keep_batch(batch_latent).to(device)
                batch_length = row_latent_emb.shape[0]

                column_name_this_batch = column_names_emb.repeat(batch_length, 1, 1)
                metadata_emb_this_batch = metadata_emb.repeat(batch_length, 1, 1)

                recon_num, recon_cat = self.decoder(
                    column_name_this_batch, metadata_emb_this_batch, row_latent_emb, dtype_tensor, unique_embedding_list
                )

                recon_df.append(self._decode_feature(recon_num, recon_cat, transformer_dict))

        self.decoder.cpu()

        return pd.concat(recon_df).reset_index(drop=True)

