import pandas as pd

import numpy as np

import torch
import os
import json
import time
from torch.utils.data import DataLoader, TensorDataset, Dataset
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tqdm import tqdm

import torch.nn.functional as F


from dataset.dataset import get_table_latent_dataloader, create_df_dict_from_dir, LatentConditionDataset
from sklearn.preprocessing import QuantileTransformer, LabelEncoder, OneHotEncoder, OrdinalEncoder
from torch.nn.utils.rnn import pad_sequence


from tableLatent.perceive.encoderDecoders import LatentAutoEncoder
from tableLatent.perceive.vaeTrainer import VAETrainer

from tableLatent.feature.VAE import compute_loss

def pad_condition_tensors(condition_tensors, pad_value=-1000):
    """
    Right-pad every tensor along its second dimension to a common length, then stack
    all of them along the first (batch) dimension.

    Args:
        condition_tensors (list of torch.Tensor): Tensors of shape (B, L, D); L may differ between entries.
        pad_value (int, optional): Fill value written into the padded positions. Default is -1000.

    Returns:
        torch.Tensor: Tensor of shape (sum of all B, longest L, D).
    """
    target_len = max(t.shape[1] for t in condition_tensors)

    padded = []
    for t in condition_tensors:
        missing = target_len - t.shape[1]
        # F.pad reads its pairs from the last dim backwards: (D_before, D_after, L_before, L_after),
        # so only the tail of the L dimension grows.
        padded.append(F.pad(t, (0, 0, 0, missing), value=pad_value))

    return torch.cat(padded, dim=0)

class TableVAETransformer:
    """
    TableVAETransformer encodes table rows into latent vectors using a VAE and decodes latent vectors
    back to the original table format. Handles both numerical and categorical columns.
    """
    # The docstring must be the first statement of the class body to become __doc__;
    # placed after an assignment it is just a discarded string expression.

    # Model class instantiated through self.MODEL_CLASS by fit() and load_checkpoint().
    MODEL_CLASS = LatentAutoEncoder

    def __init__(self, params=None):
        """
        Initialize TableVAETransformer with given parameters for the encoder and decoder models.

        Args:
            params (dict, optional): May contain two sub-dictionaries:
                - 'encoder_params': Hyperparameters for the VAE encoder.
                - 'decoder_params': Hyperparameters for the VAE decoder.
                Omitting the argument (or passing None) is the same as passing an empty dict.
        """
        # A None default avoids sharing one mutable dict object between every call.
        self.reset_fitting({} if params is None else params)

    def fit(
        self,
        train_df_dict,
        test_df_dict,
        config_dict,
        train_transformer_for_test=True,
        fitting_bsize=1024,
        retrain_vae=False,
        num_epochs=1000,
        learning_rate=1e-4,
        kl_reg=5e-5,
        distributed=False,
        world_size=1,
        rank=0,
        factor=0.95,
        patience=10,
        retrain_decoder_only=False,
    ):
        """
        Fits the VAE model (encoder and decoder networks). Always run this method to set/reset data attributes:
        the data attributes are refreshed even when training itself is skipped.

        Args:
            train_df_dict (dict): Training dataset, keys are dataset names and values are pandas DataFrames.
            test_df_dict (dict): Testing dataset, keys are dataset names and values are pandas DataFrames.
            config_dict (dict): Configuration for datasets.
            train_transformer_for_test (bool): Whether to use training transformers for test data.
            fitting_bsize (int): Batch size for fitting the VAE.
            retrain_vae (bool): Whether to keep training the VAE if it has already been fitted
                (the existing model is reused, not re-initialized).
            num_epochs (int): Number of training epochs.
            learning_rate (float): Learning rate for the AdamW optimizer.
            kl_reg (float): KL regularization weight, forwarded to the trainer.
            distributed (bool): Forwarded to the trainer.
            world_size (int): Forwarded to the trainer.
            rank (int): Forwarded to the trainer.
            factor (float): LR reduction factor for ReduceLROnPlateau.
            patience (int): Patience for ReduceLROnPlateau.
            retrain_decoder_only (bool): Train with the encoder parameters frozen. The freeze only
                lasts for this call; the previous requires_grad flags are restored afterwards.

        Returns:
            None
        """
        # Step 1: Create data loaders and set data attributes
        train_loader, test_loader = self._set_data_attribute(
            train_df_dict, test_df_dict, config_dict, train_transformer_for_test, fitting_bsize
        )

        # If model is fitted and no retraining needed, end fitting
        if self.fitted and not retrain_vae and not retrain_decoder_only:
            return

        # Step 2: Build a fresh model on first fit, otherwise continue from the stored one
        if not self.fitted:
            model = self.MODEL_CLASS(encoder_params=self.encoder_params, decoder_params=self.decoder_params)
        else:
            model = self.VAE

        # Remember which encoder parameters we freeze so the freeze does not leak into
        # later calls (e.g. a subsequent fit(..., retrain_vae=True) on the same model).
        frozen_params = []
        if retrain_decoder_only:
            print("Retraining decoder part of VAE, freezing the encoder!")
            for param in model.encoder.parameters():
                if param.requires_grad:
                    param.requires_grad = False
                    frozen_params.append(param)

        try:
            # Step 3: Optimizer, loss and LR schedule
            optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
            criterion = compute_loss
            scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=factor, patience=patience, verbose=True)

            # Step 4: Training loop; the distributed arguments are passed through to the trainer
            trainer = VAETrainer(
                model=model,
                dataloader=train_loader,
                test_dataloader=test_loader,
                optimizer=optimizer,
                criterion=criterion,
                scheduler=scheduler,
                device="cuda" if torch.cuda.is_available() else "cpu",
                distributed=distributed,
                rank=rank,
                world_size=world_size,
                kl_reg=kl_reg
            )

            trainer.train(num_epochs)
        finally:
            # Undo the temporary freeze even if training raised.
            # NOTE(review): assumes the trainer trains these same parameter objects rather than a copy — confirm.
            for param in frozen_params:
                param.requires_grad = True

        # Step 5: Save the fitted VAE model as an attribute
        self.VAE = trainer.get_model()

        self.fitted = True

    def _set_data_attribute(self, train_df_dict, test_df_dict, config_dict, train_transformer_for_test, batch_size=1024):
        """
        Build the train/test loaders and record the data attributes taken from the training dataset.

        Stores on the instance: ``transformers_per_table`` (the training dataset's
        ``transformers_dict``), ``config_dict`` and ``emb_dict`` (the training dataset's ``result_dict``).

        Args:
            train_df_dict (dict): Dictionary of training data.
            test_df_dict (dict): Dictionary of testing data.
            config_dict (dict): Dataset configuration dictionary.
            train_transformer_for_test (bool): If true, the test loader is built with the training
                transformers; otherwise it receives ``transformers_dict=None``.
            batch_size (int): Batch size for loading data.

        Returns:
            tuple: (train_loader, test_loader)
        """
        # Arguments common to both loader constructions.
        loader_kwargs = dict(config_dict=config_dict, return_label=True, fixed_batch=True, batch_size=batch_size)

        train_loader = get_table_latent_dataloader(df_dict=train_df_dict, **loader_kwargs)
        fitted_dataset = train_loader.dataset

        fitted_transformers = fitted_dataset.transformers_dict
        self.transformers_per_table = fitted_transformers
        self.config_dict = config_dict
        self.emb_dict = fitted_dataset.result_dict

        if train_transformer_for_test:
            transformers_for_test = fitted_transformers
        else:
            transformers_for_test = None

        test_loader = get_table_latent_dataloader(
            df_dict=test_df_dict,
            transformers_dict=transformers_for_test,
            **loader_kwargs
        )

        return train_loader, test_loader
    
    def reset_fitting(self, params=None):
        """
        Discard all fitted state and store the encoder/decoder hyperparameters.

        Args:
            params (dict, optional): May contain 'encoder_params' and 'decoder_params'
                sub-dictionaries; missing keys (or passing None) fall back to empty dicts.
        """
        # None instead of a shared mutable default; also tolerates an explicit None.
        params = {} if params is None else params

        self.fitted = False
        self.VAE = None
        self.transformers_per_table = {}
        self.config_dict = {}
        self.emb_dict = {}

        # Store the parameters for encoder and decoder
        self.encoder_params = params.get('encoder_params', {})
        self.decoder_params = params.get('decoder_params', {})

    def df_to_latent(self, df_dict, batch_size=512, return_format='tensor', device='cuda'):
        """
        Encodes tables into latent representations using the encoder part of the trained VAE.
        Additionally produces condition embeddings through the VAE decoder's ``encode_embs``.

        Args:
            df_dict (dictionary): A dictionary mapping dataset names to pandas DataFrames.
            batch_size (int, optional): The batch size to use for transformation. Default is 512.
            return_format (str, optional): The format to return the data. One of ['tensor', 'dataset', 'dataloader'].
                                        - 'tensor' returns the latents and conditions as torch tensors.
                                        - 'dataset' returns them as a LatentConditionDataset.
                                        - 'dataloader' returns a DataLoader over that dataset.
            device (str, optional): The device to perform the transformation on. Default is 'cuda'.

        Returns:
            If return_format is 'tensor', returns a tuple of:
                - latent (torch.Tensor): Latent representations of the input data.
                - conditions (torch.Tensor): Condition embeddings, padded to a common length with -1000.
            If return_format is 'dataset', returns a LatentConditionDataset of (latent, conditions).
            If return_format is 'dataloader', returns an unshuffled DataLoader of that dataset.

        Raises:
            AssertionError: If transformers are not fitted or the dataset names in df_dict were not seen during fitting.
            NotImplementedError: If return_format is not one of the supported values.
        """

        assert all([dataset_name in self.transformers_per_table for dataset_name in df_dict]), f"All dataset_name must have been seen during fitting! Acceptable dataset names: {self.transformers_per_table.keys()}"
        assert self.fitted, "VAE is not fitted, transformation failed!"

        # Reject an unsupported format up front rather than after encoding the whole dataset.
        if return_format not in ('tensor', 'dataset', 'dataloader'):
            raise NotImplementedError(f"Return format {return_format} is not implemented yet!")

        dataloader = get_table_latent_dataloader(df_dict=df_dict, config_dict=self.config_dict, return_label=True, fixed_batch=True, batch_size=batch_size, transformers_dict=self.transformers_per_table, result_dict=self.emb_dict)

        # Use VAE models
        aggregator = self.VAE.encoder
        cond_encode_fn = self.VAE.decoder.encode_embs
        self.VAE.eval()
        self.VAE.to(device)

        latent = []
        conditions = []

        # Pure inference: the outputs are detached anyway, so skip building autograd graphs.
        with torch.no_grad():
            for batch, _, dtype, meta, _ in tqdm(dataloader):
                # NOTE(review): squeeze() drops every size-1 dim, so a batch whose B, L or D is 1
                # would fail the three-way unpack below — confirm the loader never yields one.
                batch = batch.squeeze().to(device)
                B, L, D = batch.shape
                dtype, meta = dtype.to(device).long(), meta.to(device).float()

                # A position counts as real only if none of its features equals the -1000 padding value.
                attention_mask = torch.all(batch != -1000, dim=-1).long().to(device)
                batch_latent = aggregator(batch, attention_mask, dtype, meta).squeeze().detach().cpu()
                latent.append(batch_latent)

                meta_batch = meta.repeat(B, 1, 1)
                column_names = batch[:, ::2, :]  # Assuming column names are interleaved
                condition = cond_encode_fn(column_names, meta_batch).detach().cpu()
                conditions.append(condition)

                torch.cuda.empty_cache()

        latent = torch.cat(latent, dim=0)
        conditions = pad_condition_tensors(conditions)

        self.VAE.cpu()

        if return_format == 'tensor':
            return latent, conditions

        dataset = LatentConditionDataset(latent, conditions)
        if return_format == "dataset":
            return dataset
        return DataLoader(dataset, batch_size=batch_size, shuffle=False)
     

    def transform(self, df_dict):
        """
        Encode tables into latent space using df_to_latent with its default arguments.

        Args:
            df_dict (dict): Maps dataset names to pandas DataFrames.

        Returns:
            Whatever df_to_latent returns for its default return format.
        """
        # Encoding is df_to_latent's job; latent_to_df is the decoding direction and
        # cannot be called without a dataset name.
        return self.df_to_latent(df_dict)

    def inverse_transform(self, latent_data, dataset_name=None, **kwargs):
        """
        Decode latent data back into a DataFrame via latent_to_df.

        Args:
            latent_data (torch.Tensor or DataLoader): The latent data to decode.
            dataset_name (str, optional): Name of the dataset the latents belong to. May be omitted
                only when exactly one dataset is known to this transformer.
            **kwargs: Forwarded unchanged to latent_to_df.

        Returns:
            pandas.DataFrame

        Raises:
            ValueError: If dataset_name is omitted and cannot be inferred unambiguously.
        """
        if dataset_name is None:
            # latent_to_df requires a dataset name; infer it only when there is no ambiguity.
            known_names = list(self.transformers_per_table)
            if len(known_names) != 1:
                raise ValueError(
                    f"dataset_name is required when the number of fitted datasets is not exactly one. "
                    f"Known dataset names: {known_names}"
                )
            dataset_name = known_names[0]

        return self.latent_to_df(latent_data, dataset_name, **kwargs)

    def _decode_feature(self,Recon_X_num, Recon_X_cat, transformer_dict):
        """
        Map reconstructed numerical and categorical tensors back to original values and return a DataFrame.

        Columns are visited in the iteration order of ``transformer_dict``; numerical and
        categorical outputs are consumed with two independent running indices.

        Parameters:
        - Recon_X_num: 2D tensor (B, num_numerical_columns) of reconstructed numerical values.
        - Recon_X_cat: list of 2D tensors, one per categorical column, holding per-class scores
                       (the argmax over axis 1 is taken as the predicted class index).
        - transformer_dict: dictionary mapping column names to fitted transformers (QuantileTransformer for numerical
                            columns, OrdinalEncoder for categorical columns).

        Returns:
        - A pandas DataFrame with the original data form (inverse-transformed).

        Raises:
        - ValueError: if a transformer is neither a QuantileTransformer nor an OrdinalEncoder.
        """
        restored = {}
        next_num = 0  # position of the next numerical column inside Recon_X_num
        next_cat = 0  # position of the next entry inside Recon_X_cat

        for col_name, transformer in transformer_dict.items():
            if isinstance(transformer, QuantileTransformer):
                # Slice with a list index to keep the column 2D, as inverse_transform is fed (B, 1).
                scaled = Recon_X_num[:, [next_num]].detach().cpu().numpy()
                restored[col_name] = transformer.inverse_transform(scaled).flatten()
                next_num += 1
                continue

            if isinstance(transformer, OrdinalEncoder):
                scores = Recon_X_cat[next_cat].detach().cpu().numpy()
                codes = np.argmax(scores, axis=1)
                restored[col_name] = transformer.inverse_transform(codes.reshape(-1, 1)).flatten()
                next_cat += 1
                continue

            raise ValueError(f"Unknown transformer type for column: {col_name}")

        return pd.DataFrame(restored)
    
    def _extract_embeddings_and_metadata(self, dataset_name,dataset_embeddings=None,transformers=None):
        """
        Extracts column name embeddings, metadata embeddings, unique category embeddings,
        and dtype tensor for the given dataset.

        Args:
            dataset_name (str): The original key or dataset name. Only used to look up the
                defaults below when they are not supplied.
            dataset_embeddings (dict, optional): Embeddings of one dataset with keys 'column_name',
                'categories' and 'metadata'. Defaults to ``self.emb_dict[dataset_name]``.
            transformers (dict, optional): Maps column names to fitted transformers.
                Defaults to ``self.transformers_per_table[dataset_name]``.

        Returns:
            tuple:
                - column_name_embedding (torch.Tensor): One stacked row per column, in the order of ``transformers``.
                - metaembedding (torch.Tensor): Float metadata embedding with a leading dimension of 1.
                - unique_embedding_list (list of torch.Tensor): One stacked tensor per categorical column,
                  with one row per category.
                - dtypes_tensor (torch.Tensor): Long tensor of shape (num_columns,): 0 for columns whose
                  transformer is a LabelEncoder/OrdinalEncoder/OneHotEncoder, 1 for every other column.
                  NOTE(review): confirm this coding matches what the decoder was trained with.
        """
        # Retrieve dataset embeddings and transformers for the given dataset name
        if dataset_embeddings is None:
            dataset_embeddings = self.emb_dict[dataset_name]
        if transformers is None:
            transformers = self.transformers_per_table[dataset_name]

        column_name_embedding = []
        unique_embeddings_list = []
        dtypes_tensor = torch.zeros(len(transformers))

        # Dict keys are unique, so each column is visited (and recorded) exactly once.
        for j, (col, transformer) in enumerate(transformers.items()):
            column_name_embedding.append(torch.tensor(dataset_embeddings['column_name'][col]))

            if isinstance(transformer, (LabelEncoder, OrdinalEncoder, OneHotEncoder)):
                # Categorical column. LabelEncoder stores its vocabulary in `classes_`;
                # OrdinalEncoder/OneHotEncoder store one array per feature in `categories_`.
                if isinstance(transformer, LabelEncoder):
                    column_categories = transformer.classes_
                else:
                    column_categories = transformer.categories_[0]
                dtypes_tensor[j] = 0

                category_embeddings = dataset_embeddings['categories'][col]
                unique_embeddings_list.append(torch.stack([
                    torch.tensor(category_embeddings[category_idx])
                    for category_idx in range(len(column_categories))
                ]))
            else:
                # Numerical column
                dtypes_tensor[j] = 1

        # Retrieve the metadata embedding
        metaembedding = torch.from_numpy(dataset_embeddings['metadata']).float().unsqueeze(0)
        column_name_embedding = torch.stack(column_name_embedding)

        return column_name_embedding, metaembedding, unique_embeddings_list, dtypes_tensor.long()

    def latent_to_df(self, latent_data, dataset_name, transformer_dict=None,emb_dict=None,batch_size=512,device='cuda:0'):
        """
        Applies inverse transformation to a latent tensor/dataloader. All latent data must come from the same dataset, as indicated by the dataset_name argument

        Args:
            latent_data (torch.tensor or dataloader): The latent data. A tensor is wrapped in an
                unshuffled DataLoader; for a DataLoader, the first element of each batch is decoded.
            dataset_name (str): name of the dataset among data seen during fitting
            transformer_dict (dict, optional): Column transformers to decode with. Defaults to the
                transformers stored for dataset_name.
            emb_dict (dict, optional): Embeddings of this dataset, passed to
                _extract_embeddings_and_metadata (which falls back to the stored ones when None).
            batch_size (int): batch size during latent decoding (only used when latent_data is a tensor).
            device (str): device the decoder runs on. The decoder is moved back to CPU afterwards.

        Returns:
            pandas.DataFrame

        Raises:
            AssertionError: If decoder is not fitted.
            NotImplementedError: If latent_data is neither a tensor nor a DataLoader.
        """
        assert self.fitted, "VAE not fitted, inverse transformation failed!"

        decoder = self.VAE.decoder
        decoder.eval().to(device)

        if isinstance(latent_data, torch.Tensor):
            dataloader = DataLoader(TensorDataset(latent_data), batch_size=batch_size, shuffle=False)
        elif isinstance(latent_data, DataLoader):
            dataloader = latent_data
        else:
            raise NotImplementedError(f"Latent data input with type {type(latent_data)} is not supported!")

        # Take transformer_dict from new data
        # Else use existing data
        if transformer_dict is None:
            transformer_dict = self.transformers_per_table[dataset_name]

        column_names_emb, metadata_emb, unique_embedding_list, dtype = self._extract_embeddings_and_metadata(dataset_name,emb_dict,transformer_dict)

        # These inputs are identical for every batch, so move them to the device once.
        column_names_emb = column_names_emb.to(device)
        metadata_emb = metadata_emb.to(device)
        dtype = dtype.long().to(device)
        unique_embedding_list = [unq_emb.to(device) for unq_emb in unique_embedding_list]

        recon_df = []

        # Pure inference: _decode_feature detaches the outputs anyway, so skip autograd bookkeeping.
        with torch.no_grad():
            for batch_latent in dataloader:
                # NOTE(review): squeeze() drops every size-1 dim, so a batch holding a single row
                # loses its batch dimension and len() below no longer counts rows — confirm this cannot occur.
                batch_latent = batch_latent[0].squeeze().to(device)
                batch_length = len(batch_latent)
                Recon_X_num, Recon_X_cat = decoder(column_names_emb.repeat(batch_length, 1, 1),
                                                   metadata_emb.repeat(batch_length, 1, 1),
                                                   batch_latent,
                                                   dtype,
                                                   unique_embedding_list)
                recon_df.append(self._decode_feature(Recon_X_num, Recon_X_cat, transformer_dict))

        decoder.cpu()

        return pd.concat(recon_df).reset_index(drop=True)
    
    def fit_transform(self, df_dict, test_df_dict=None, config_dict=None, **fit_kwargs):
        """
        Fit the VAE on df_dict and return its latent encoding.

        Args:
            df_dict (dict): Training data, maps dataset names to pandas DataFrames.
            test_df_dict (dict, optional): Evaluation data passed to fit. Defaults to df_dict itself.
            config_dict (dict, optional): Dataset configuration passed to fit. Defaults to the
                configuration already stored on this transformer.
            **fit_kwargs: Any further keyword arguments accepted by fit.

        Returns:
            Whatever df_to_latent returns for its default return format.

        Raises:
            ValueError: If no config_dict is given and none is stored on the transformer.
        """
        # fit() cannot run without a test set and a config, so supply fallbacks here.
        if config_dict is None:
            if not self.config_dict:
                raise ValueError("config_dict must be provided: no dataset configuration is stored on this transformer.")
            config_dict = self.config_dict
        if test_df_dict is None:
            test_df_dict = df_dict

        self.fit(df_dict, test_df_dict, config_dict, **fit_kwargs)
        return self.df_to_latent(df_dict)

    def get_checkpoint(self):
        """
        Collect the fitted state of this transformer into a plain dictionary.

        Returns:
            dict: Checkpoint with the keys
                - 'vae_state': state_dict of the VAE,
                - 'transformers_per_table': the stored column transformers,
                - 'encoder_params' / 'decoder_params': the model hyperparameters,
                - 'config_dict': the dataset configuration,
                - 'emb_dict': the stored embeddings.

        Raises:
            AssertionError: If the VAE has not been fitted.
        """
        assert self.fitted, "VAE model must be fitted!"

        return {
            'vae_state': self.VAE.state_dict(),
            'transformers_per_table': self.transformers_per_table,
            'encoder_params': self.encoder_params,
            'decoder_params': self.decoder_params,
            'config_dict': self.config_dict,
            'emb_dict': self.emb_dict,
        }
    
    def load_checkpoint(self, checkpoint):
        """
        Restore the transformer from a checkpoint dictionary such as the one built by get_checkpoint.
        An existing fit is OVERWRITTEN.

        Args:
            checkpoint (dict): Must contain 'vae_state', 'transformers_per_table', 'encoder_params'
                and 'decoder_params'. 'config_dict' and 'emb_dict' are restored when present;
                when absent the values currently stored on the instance are kept.

        Raises:
            AssertionError: If any required key is missing from the checkpoint.
        """
        REQUIRED_STATES = ['vae_state', 'transformers_per_table', 'encoder_params', 'decoder_params']

        assert all([ST in checkpoint for ST in REQUIRED_STATES]), f"All models states and data attributes must be provide: {REQUIRED_STATES}. Currently have: {list(checkpoint.keys())}"

        self.transformers_per_table = checkpoint['transformers_per_table']

        self.encoder_params = checkpoint['encoder_params']
        self.decoder_params = checkpoint['decoder_params']

        # df_to_latent and latent_to_df read these attributes, so a loaded transformer is
        # unusable without them. They stay optional so checkpoints lacking them still load.
        self.config_dict = checkpoint.get('config_dict', self.config_dict)
        self.emb_dict = checkpoint.get('emb_dict', self.emb_dict)

        # Keyword arguments, matching how fit() constructs the model.
        self.VAE = self.MODEL_CLASS(encoder_params=self.encoder_params, decoder_params=self.decoder_params)
        self.VAE.load_state_dict(checkpoint['vae_state'])

        self.fitted = True
        print("Checkpoint loaded!")