from sklearn.preprocessing import QuantileTransformer, LabelEncoder, OneHotEncoder, OrdinalEncoder
import pandas as pd
import os
import json

from transformers import RobertaTokenizer, RobertaModel, AutoModel, AutoTokenizer
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
import random
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split

def create_df_dict_from_dir(csv_dir, test_size=0, random_state=42):
    """
    Load every CSV file of a directory into a dictionary of data frames.

    Only regular files whose name ends in ".csv" (case-insensitive) are read; any other
    directory entry (sub-folders, hidden files, notes, ...) is skipped instead of being
    handed to ``pd.read_csv``. Files are visited in sorted name order so that the key
    order of the returned dictionaries is reproducible across runs and platforms.

    If test_size > 0, a train/test split is performed on each data frame using the
    same random_state.

    Args:
    - csv_dir (str): Directory path where CSV files are stored.
    - test_size (float): Fraction of data to use as test set, between 0 and 1. Default is 0 (no split).
    - random_state (int): Random seed for splitting the data. Default is 42.

    Returns:
    - train_dict (dict): If test_size is 0, a single dictionary mapping the file name
      (without its extension) to the full data frame.
    - train_dict (dict), test_dict (dict): If test_size > 0, two dictionaries with the
      same keys, one holding the training part and one the test part of each file.
    """
    train_dict = {}
    test_dict = {}

    for entry in sorted(os.listdir(csv_dir)):
        # splitext only strips the real extension, unlike str.replace which would
        # also remove ".csv" appearing in the middle of a file name.
        file_name, extension = os.path.splitext(entry)
        full_path = os.path.join(csv_dir, entry)
        if extension.lower() != ".csv" or not os.path.isfile(full_path):
            continue

        df = pd.read_csv(full_path, low_memory=False)

        if test_size > 0:
            train_df, test_df = train_test_split(df, test_size=test_size, random_state=random_state)
            train_dict[file_name] = train_df
            test_dict[file_name] = test_df
        else:
            train_dict[file_name] = df

    if test_size > 0:
        return train_dict, test_dict
    return train_dict
    
def create_config_dict_from_dir(config_dir):
    """
    Load every JSON file of a directory into a dictionary.

    Only regular files whose name ends in ".json" (case-insensitive) are parsed; other
    directory entries are skipped. Each file is opened in a ``with`` block so its handle
    is closed as soon as it has been parsed.

    Args:
    - config_dir (str): Directory path where the JSON config files are stored.

    Returns:
    - config_dict (dict): Maps each file name (without its extension) to the parsed
      JSON content of that file.
    """
    config_dict = {}

    for entry in sorted(os.listdir(config_dir)):
        file_name, extension = os.path.splitext(entry)
        full_path = os.path.join(config_dir, entry)
        if extension.lower() != ".json" or not os.path.isfile(full_path):
            continue

        # JSON text is UTF-8 by specification; do not depend on the locale default.
        with open(full_path, "r", encoding="utf-8") as handle:
            config_dict[file_name] = json.load(handle)

    return config_dict

def batch_data_frames(df_dict, batch_size, shuffle=True):
    """
    Cut every data frame into consecutive row chunks of at most `batch_size` rows.

    A chunk never mixes rows of different data frames. The last chunk of a data frame
    holds the remaining rows when its length is not a multiple of `batch_size`.

    Parameters:
    df_dict: Dictionary of pandas data frames to be batched.
    batch_size: Size of each batch.
    shuffle: Whether to shuffle the rows of each data frame before batching. Default is True.

    Returns:
    A list of (chunk, key) tuples, where `chunk` is a slice of the data frame stored
    under `key`. Chunks appear in the iteration order of `df_dict`.
    """
    if shuffle:
        print("Shuffling and batching data frames!")
    else:
        print("Batching data frame without shuffling!")

    batches = []
    for name, frame in df_dict.items():
        # Ceiling division: one extra chunk when some rows are left over.
        full_chunks, leftover = divmod(len(frame), batch_size)
        chunk_count = full_chunks + (1 if leftover > 0 else 0)

        if shuffle:
            # Random row permutation with a fresh 0..n-1 index
            frame = frame.sample(frac=1).reset_index(drop=True)

        for chunk_no in range(chunk_count):
            start = chunk_no * batch_size
            batches.append((frame.iloc[start:start + batch_size], name))

    return batches


def fit_column_transformers(df_dict, dist='normal', cat_encoder='ordinal'):
    """
    For each column in each data frame, fit a column transformer for numerical and categorical columns.

    A column is treated as numerical when pandas reports a numeric dtype for it; it then
    gets a QuantileTransformer. Every other column is treated as categorical and gets
    the encoder selected by `cat_encoder`.

    df_dict: dictionary, keys are dataset names and values are pandas data frames.
    dist: string, distribution to be used by QuantileTransformer for numerical columns ('normal' or 'uniform').
    cat_encoder: string, encoder to be used for categorical columns
                 ('label', 'onehot' or 'ordinal'). Default is 'ordinal'.

    Returns:
    transformers_dict: dictionary, where keys are dataset names and values are dictionaries
                       with column names as keys and fitted transformers as values.

    Raises:
    ValueError: if a categorical column is encountered and `cat_encoder` is not one of
                the supported values.
    """
    transformers_dict = {}

    for key, df in df_dict.items():
        column_transformers = {}
        for col in df.columns:
            if pd.api.types.is_numeric_dtype(df[col]):
                transformer = QuantileTransformer(output_distribution=dist)
                transformer.fit(df[[col]])  # Keep as 2D array for numerical transformer
            elif cat_encoder == 'label':
                transformer = LabelEncoder()
                transformer.fit(df[col])  # Use 1D array for LabelEncoder
            elif cat_encoder == 'onehot':
                transformer = OneHotEncoder(sparse_output=False, handle_unknown='ignore')
                transformer.fit(df[[col]])  # Keep as 2D array for OneHotEncoder
            elif cat_encoder == 'ordinal':
                # Categories unseen during fit are encoded as -1000 at transform time
                transformer = OrdinalEncoder(unknown_value=-1000, handle_unknown='use_encoded_value')
                transformer.fit(df[[col]])  # Keep as 2D array for OrdinalEncoder
            else:
                raise ValueError(
                    f"Invalid cat_encoder value {cat_encoder!r}. "
                    "Choose 'label', 'onehot' or 'ordinal'."
                )
            column_transformers[col] = transformer

        transformers_dict[key] = column_transformers

    return transformers_dict

def apply_transformers(df_dict, transformers_dict):
    """
    Apply the fitted transformers to the corresponding columns of each data frame in df_dict.

    Every transformed column stays a single column of the data frame:
    - QuantileTransformer / OrdinalEncoder: the (n, 1) output is flattened to n values.
    - LabelEncoder: receives the column as a 1-D Series, which is the input shape it
      is fitted on in this file.
    - OneHotEncoder: the (n, k) indicator matrix cannot be stored in one column, so
      each row is collapsed to the index of its active category. A row with no active
      category (possible with handle_unknown='ignore') gets -1.
      NOTE(review): this assumes the encoder was built without `drop`, so that column
      i of the matrix corresponds to categories_[0][i] — confirm for custom encoders.

    df_dict: dictionary, keys are dataset names and values are pandas data frames.
    transformers_dict: dictionary, where keys are dataset names and values are dictionaries
                       with column names as keys and fitted transformers as values.

    Returns:
    transformed_df_dict: dictionary, same as df_dict but with transformed columns.
                         The input data frames are not modified.

    Raises:
    KeyError: if a dataset or column has no entry in transformers_dict.
    ValueError: if the transformer of a column is of an unsupported type.
    """
    transformed_df_dict = {}

    for dataset_name, df in df_dict.items():
        transformed_df = df.copy()

        for col in df.columns:
            transformer = transformers_dict[dataset_name][col]

            if isinstance(transformer, LabelEncoder):
                # LabelEncoder works on 1-D input; a 2-D frame triggers a conversion warning
                transformed_df[col] = np.asarray(transformer.transform(df[col])).ravel()

            elif isinstance(transformer, OneHotEncoder):
                encoded = transformer.transform(df[[col]])
                if hasattr(encoded, "toarray"):
                    # Encoder configured for sparse output
                    encoded = encoded.toarray()
                encoded = np.asarray(encoded)
                labels = encoded.argmax(axis=1)
                # An all-zero row means "unknown category"; argmax would wrongly report 0
                labels[encoded.sum(axis=1) == 0] = -1
                transformed_df[col] = labels

            elif isinstance(transformer, (QuantileTransformer, OrdinalEncoder)):
                transformed_df[col] = np.asarray(transformer.transform(df[[col]])).ravel()

            else:
                raise ValueError(f"Unsupported transformer for column {col}")

        # Save the transformed data frame
        transformed_df_dict[dataset_name] = transformed_df

    return transformed_df_dict



def _encoder_category_levels(transformer):
    """Return the fitted category levels of a categorical encoder as a list, or None for any other object."""
    if isinstance(transformer, LabelEncoder):
        # LabelEncoder stores its levels in `classes_`; it has no `categories_`
        return np.asarray(transformer.classes_).tolist()
    if isinstance(transformer, (OrdinalEncoder, OneHotEncoder)):
        return np.asarray(transformer.categories_[0]).tolist()
    return None


def get_unique_embeddings(df_dict, config_dict, transformers_dict, batch_size=512, language_model='GTE'):
    """
    Find all unique strings in all the data frames, encode them with a selected language model.
    By saving the embeddings, the training is accelerated.

    The strings that get encoded are: every column name, every dataset name, every
    category level of every column whose fitted transformer is a categorical encoder
    (LabelEncoder, OrdinalEncoder or OneHotEncoder), and the 'metadata' text of every
    dataset of df_dict that has an entry in config_dict.

    After the embeddings are computed, a new dictionary is created, mapping each dataset name to:
    - metadata: embedding of the metadata text provided in config_dict
      (only present when config_dict has an entry for the dataset).
    - column_name: a dictionary mapping each column name to an lm embedding.
    - categories: a dictionary in the format {categorical_column_name: {integer_labels: embedding_of_corresponding_category}}.

    The mapping between integer_labels and category levels is provided by transformers_dict:
    the integer label of a level is its position in the encoder's fitted level list.

    Args:
    - df_dict (dict): dataset name -> pandas data frame.
    - config_dict (dict or None): dataset name -> config dictionary with a 'metadata' string.
    - transformers_dict (dict): dataset name -> {column name -> fitted transformer}.
    - batch_size (int): number of strings encoded per forward pass of the language model.
    - language_model (str): 'GTE' loads "thenlper/gte-base"; any other value loads 'roberta-base'.

    Returns:
    - result_dict (dict): the per-dataset dictionary described above. Only this
      dictionary is returned; the flat string -> embedding table is discarded.
    """

    # Load pre-trained model and tokenizer depending on the language_model argument
    if language_model == 'GTE':
        model_name = "thenlper/gte-base"
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModel.from_pretrained(model_name)
    else:
        # Fallback to RoBERTa-base as a default
        model_name = 'roberta-base'
        tokenizer = RobertaTokenizer.from_pretrained(model_name)
        model = RobertaModel.from_pretrained(model_name)

    # Check if CUDA is available and move model to GPU if so
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)
    model.eval()  # Set model to evaluation mode

    def average_pool(last_hidden_states, attention_mask):
        # Zero out padded positions, then average over the unmasked tokens of each string
        last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
        return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]

    # Category levels per dataset and column, taken from the fitted encoders so that the
    # column typing matches the one used when the data frames are transformed.
    category_levels = {}
    for dataset_name, df in df_dict.items():
        dataset_transformers = transformers_dict[dataset_name]
        levels_by_column = {}
        for col in df.columns:
            levels = _encoder_category_levels(dataset_transformers.get(col))
            if levels is not None:
                levels_by_column[col] = levels
        category_levels[dataset_name] = levels_by_column

    # Collect all unique column names, categorical values, and metadata strings
    unique_names = set()
    print("Adding unique column name/value strings!")
    for dataset_name, df in df_dict.items():
        unique_names.update(df.columns)  # Add column names
        unique_names.add(dataset_name)  # Add dataset name
        for levels in category_levels[dataset_name].values():
            unique_names.update(levels)  # Add unique categories from transformer

    # Include embedding for metadata, if provided
    if config_dict is not None:
        print("Adding metadata strings!")
        # Only the datasets that are actually embedded need their metadata encoded
        for dataset_name in df_dict:
            if dataset_name in config_dict:
                unique_names.add(config_dict[dataset_name]['metadata'])

    # Convert unique names to a list for batching
    unique_names = list(unique_names)

    # Create a dictionary to store embeddings
    embeddings_dict = {}

    # Pre-bind the loop temporaries so the cleanup below is valid even when there is nothing to encode
    batch_dict = None
    outputs = None

    # Process unique names in batches
    print("Encoding unique strings:")
    for i in range(0, len(unique_names), batch_size):
        batch = unique_names[i:i + batch_size]

        # Tokenize and encode the batch of strings
        batch_dict = tokenizer(batch, max_length=512, padding=True, truncation=True, return_tensors='pt')
        batch_dict = {k: v.to(device) for k, v in batch_dict.items()}

        with torch.no_grad():
            outputs = model(**batch_dict)

        # Compute embeddings using average pooling of last hidden states
        embeddings = average_pool(outputs.last_hidden_state, batch_dict['attention_mask']).cpu().numpy()

        # Store the embeddings in the dictionary
        for j, name in enumerate(batch):
            embeddings_dict[name] = embeddings[j]

    # Clear the GPU memory after the embeddings are processed
    del model, tokenizer, batch_dict, outputs
    torch.cuda.empty_cache()

    # Now create the new dictionary mapping each dataset to metadata, column names, and category embeddings
    result_dict = {}

    for dataset_name, df in df_dict.items():
        dataset_dict = {}

        # Add metadata embeddings
        if config_dict and dataset_name in config_dict:
            metadata = config_dict[dataset_name]['metadata']
            dataset_dict['metadata'] = embeddings_dict[metadata]

        # Add column name embeddings
        dataset_dict['column_name'] = {col: embeddings_dict[col] for col in df.columns}

        # Add category embeddings for each categorical column; the integer label of a
        # level is its position in the encoder's fitted level list.
        dataset_dict['categories'] = {
            col: {int_label: embeddings_dict[orig_cat] for int_label, orig_cat in enumerate(levels)}
            for col, levels in category_levels[dataset_name].items()
        }

        # Add to result dictionary
        result_dict[dataset_name] = dataset_dict

    return result_dict

class LatentConditionDataset(Dataset):
    """Dataset pairing two equally long indexable collections; item i is (tensor1[i], tensor2[i])."""

    def __init__(self, tensor1, tensor2):
        self.tensor1, self.tensor2 = tensor1, tensor2
        # Both collections must describe the same examples
        assert len(tensor1) == len(tensor2), "Tensors must have the same number of examples"

    def __len__(self):
        """Number of examples, taken from the first collection."""
        n_examples = len(self.tensor1)
        return n_examples

    def __getitem__(self, idx):
        """Return the idx-th entry of both collections as a tuple."""
        first = self.tensor1[idx]
        second = self.tensor2[idx]
        return first, second

class TableLatentDataset(Dataset):
    """
    Dataset whose items are whole pre-built row batches of encoded data frames.

    Each item is built from one (data frame chunk, dataset name) pair produced by
    `batch_data_frames`. Every row is turned into a sequence of length 2 * num_columns
    in which position 2*j holds the embedding of column j's name and position 2*j + 1
    holds the embedding of the row's value in column j.
    """

    def __init__(self, df_dict, result_dict, config_dict, lm_embedding_dim, transformers_dict, return_label=False, batch_size=512,fixed_batch=False):
        """
        Pre-build the row batches and store the lookup tables needed to embed them.

        df_dict: dictionary mapping dataset names to pandas data frames. The values are
                 used directly as lookup keys and numeric inputs, so categorical columns
                 are expected to hold integer labels and numerical columns their
                 transformed values.
        result_dict: dictionary created from get_unique_embeddings. It stores metadata, column names, and category embeddings for each dataset.
        config_dict: dictionary of dataset configs; stored on the instance.
        lm_embedding_dim: integer, dimension of LM embeddings in result_dict.
        transformers_dict: dictionary of dictionaries. Maps each dataset's columns to the transformer used.
        return_label: Boolean flag indicating whether to return labels or not.
        batch_size: maximum number of rows per pre-built batch.
        fixed_batch: when True the rows are batched in their original order; when False
                     each data frame is shuffled before batching.
        """
        self.df_dict = df_dict
        self.batches = batch_data_frames(df_dict, batch_size,shuffle=(not fixed_batch))
        self.result_dict = result_dict
        self.config_dict = config_dict
        self.lm_embedding_dim = lm_embedding_dim
        self.return_label = return_label
        self.transformers_dict = transformers_dict
        self.batch_size = batch_size

    def shuffle_batches(self):
        # Reshuffle batches for more diverse contrastive comparison
        self.batches = batch_data_frames(self.df_dict, self.batch_size, shuffle=True)

    def __len__(self):
        return len(self.batches)

    def __getitem__(self, idx):
        if isinstance(idx, list):  # Handle batch of indices
            return self._get_batch(idx)
        else:  # Handle single index
            return self._get_single(idx)

    @staticmethod
    def _category_levels(transformer):
        """Return the fitted category levels of a categorical encoder, or None for any other object."""
        if isinstance(transformer, LabelEncoder):
            # LabelEncoder exposes `classes_`; it has no `categories_` attribute
            return transformer.classes_
        if isinstance(transformer, (OrdinalEncoder, OneHotEncoder)):
            return transformer.categories_[0]
        return None

    def _get_single(self, idx):
        """
        Process a single index and return a single batch.

        Returns (batch_sequence, categories_tensor) when return_label is False, otherwise
        (batch_sequence, label_tensor, categories_tensor, meta_tensor, unique_embeddings):
        - batch_sequence: (rows, 2 * num_columns, lm_embedding_dim) tensor of interleaved
          column-name and value embeddings. A numerical value is repeated across the
          whole embedding dimension.
        - label_tensor: (rows, num_columns) float32 tensor of the raw cell values.
        - categories_tensor: per column, the number of category levels, or 0 for a
          numerical column.
        - meta_tensor: the dataset's metadata embedding.
        - unique_embeddings: embeddings of all category levels of all categorical
          columns, stacked in column order (an empty tensor if there are none).

        NOTE: every returned tensor is passed through `.squeeze()`, so any dimension of
        size 1 (a one-row batch, a single column, a single category level) is dropped.

        Raises:
        KeyError: if a categorical value has no embedding, or if return_label is set and
                  the dataset has no metadata embedding.
        """
        batch_df, original_key = self.batches[idx]
        columns = list(batch_df.columns)
        num_columns = len(columns)
        batch_size = len(batch_df)

        # Get the preloaded embeddings from the result_dict
        dataset_embeddings = self.result_dict[original_key]
        dataset_transformers = self.transformers_dict[original_key]

        # Create an empty tensor to hold the entire batch sequence
        batch_sequence = torch.zeros((batch_size, num_columns * 2, self.lm_embedding_dim))

        # Create categories_tensor to represent the number of unique categories per column
        # (stays 0 for numerical columns)
        categories_tensor = torch.zeros(num_columns)

        # Accumulates the level embeddings of all categorical columns, in column order
        unique_embeddings_list = []

        # Work column by column: the name embedding, the transformer and the level table
        # are identical for every row, so they are looked up once per column.
        for j, col in enumerate(columns):
            # Column name embedding, or a -1000 filled vector when the name is unknown
            col_embedding = dataset_embeddings['column_name'].get(col, np.full(self.lm_embedding_dim, -1000))
            batch_sequence[:, 2 * j] = torch.as_tensor(np.asarray(col_embedding), dtype=batch_sequence.dtype)

            values = batch_df[col].to_numpy()

            # Determine column type using transformers_dict
            column_categories = self._category_levels(dataset_transformers.get(col))
            if column_categories is not None:
                # Categorical column
                categories_tensor[j] = len(column_categories)  # Store number of unique categories
                level_embeddings = dataset_embeddings['categories'][col]
                if batch_size > 0:
                    value_block = np.stack([level_embeddings[value] for value in values])
                    batch_sequence[:, 2 * j + 1] = torch.as_tensor(value_block, dtype=batch_sequence.dtype)
                unique_embeddings_list.extend(
                    level_embeddings[category_idx] for category_idx in range(len(column_categories))
                )
            else:
                # Numerical column: broadcast each scalar over the embedding dimension
                numeric_values = torch.as_tensor(values.astype(np.float64), dtype=batch_sequence.dtype)
                batch_sequence[:, 2 * j + 1] = numeric_values.unsqueeze(1)

        if not self.return_label:
            return batch_sequence.squeeze(), categories_tensor.squeeze()

        # Create the label tensor from the raw cell values (batch_df is assumed to be already transformed)
        label_matrix = batch_df.to_numpy(dtype=np.float64).reshape(batch_size, num_columns)
        label_tensor = torch.tensor(label_matrix, dtype=torch.float32)

        # metadata tensor
        try:
            metadata_embedding = dataset_embeddings['metadata']
        except KeyError as err:
            raise KeyError(
                f"No metadata embedding stored for dataset {original_key!r}; "
                "it is required when return_label is True."
            ) from err
        meta_tensor = torch.from_numpy(metadata_embedding).float()

        # Stack unique embeddings into a tensor
        if len(unique_embeddings_list) > 0:
            unique_embeddings = torch.tensor(np.stack(unique_embeddings_list))
        else:
            unique_embeddings = torch.tensor([])

        return batch_sequence.squeeze(), label_tensor.squeeze(), categories_tensor.squeeze(), meta_tensor.squeeze(), unique_embeddings.squeeze()


    def _get_batch(self, indices):
        """
        Process a list of indices (batch) and return a batch of batches.

        All indices must refer to batches of the same dataset. The per-index results of
        `_get_single` are stacked along a new leading dimension, so they must have
        identical shapes.

        Returns (batch_sequences, label_tensors, dtype_tensors, meta_tensors) when
        return_label is True, otherwise (batch_sequences, dtype_tensors). The unique
        category embeddings returned by `_get_single` are not part of the result.
        """
        # Cheap consistency check first, before any embedding work is done
        original_keys = {self.batches[idx][1] for idx in indices}
        assert len(original_keys) == 1, f"All samples in one batch must come from the same dataset! Got {original_keys}"

        singles = [self._get_single(idx) for idx in indices]

        if self.return_label:
            # _get_single yields five values in this mode; the last one is dropped here
            batch_sequences, label_tensors, dtype_tensors, meta_tensors, _ = zip(*singles)
            return (
                torch.stack(batch_sequences),
                torch.stack(label_tensors),
                torch.stack(dtype_tensors),
                torch.stack(meta_tensors),
            )

        batch_sequences, dtype_tensors = zip(*singles)
        return (
            torch.stack(batch_sequences),
            torch.stack(dtype_tensors),
        )


def get_table_latent_dataloader(
    csv_dir=None,
    df_dict=None,
    config_dir=None,
    config_dict=None,
    transformers_dict=None,
    batch_size=512,
    embedding_dim=768,
    return_label=True,
    dist='uniform',
    cat_encoder='ordinal',
    shuffle_loader=True,
    fixed_batch=True,
    result_dict=None,
    lm_emb_batch_size=256,
):
    """
    Build a DataLoader over a TableLatentDataset.

    Missing inputs are derived on the fly: data frames are read from `csv_dir`, configs
    from `config_dir`, transformers are fitted on the data frames and the language-model
    embeddings are computed from the untransformed data frames. The data frames are
    transformed only after that, right before the dataset is created.

    Args:
        csv_dir (str): Directory containing the CSV files. Used only if df_dict is None.
        df_dict (dict): Dictionary containing dataset name -> pd.DataFrame
        config_dir (str): Directory containing the JSON configs. Used only if config_dict is None.
        config_dict (dict): Dictionary containing dataset name -> config.
        transformers_dict (dict): Dictionary containing dataset name -> mapping from col names to transformers.
        batch_size (int): Number of rows in each pre-built batch of the dataset.
        embedding_dim (int): Dimension of the embeddings.
        return_label (bool): Whether to return labels.
        dist (str): Distribution for numerical column transformations.
        cat_encoder (str): Type of encoder for categorical columns.
        shuffle_loader (bool): Whether to shuffle the data in the DataLoader.
        fixed_batch (bool): Whether to create fixed batches.
        result_dict (dict): Precomputed output of get_unique_embeddings; computed if None.
        lm_emb_batch_size (int): Batch size used when computing the embeddings.

    Returns:
        DataLoader whose own batch size is 1, because every dataset item already is a
        full batch of rows.

    Raises:
        ValueError: if neither df_dict nor csv_dir, or neither config_dict nor
        config_dir, is provided.
    """
    # Resolve the data frames
    if df_dict is None:
        if csv_dir is None:
            raise ValueError("One of df_dict or csv_dir must be provided!")
        df_dict = create_df_dict_from_dir(csv_dir)
        print(f"Df_dict created from {csv_dir}")

    # Resolve the configs
    if config_dict is None:
        if config_dir is None:
            raise ValueError("One of config_dict or config_dir must be provided!")
        config_dict = create_config_dict_from_dir(config_dir)
        print(f"Config_dict created from {config_dir}")

    if transformers_dict is None:
        transformers_dict = fit_column_transformers(df_dict, dist, cat_encoder)

    if result_dict is None:
        result_dict = get_unique_embeddings(
            df_dict,
            config_dict,
            transformers_dict,
            batch_size=lm_emb_batch_size,
        )

    encoded_frames = apply_transformers(df_dict, transformers_dict)

    dataset = TableLatentDataset(
        encoded_frames,
        result_dict,
        config_dict,
        embedding_dim,
        transformers_dict,
        return_label,
        batch_size,
        fixed_batch=fixed_batch,
    )

    # Loader batch size always 1 as batches are pre-created by the dataset
    return DataLoader(dataset, batch_size=1, shuffle=shuffle_loader)


if __name__ == "__main__":
    import time
    from tqdm import tqdm

    # Smoke test: build a training loader and time one full pass over it.
    # Alternative inputs used during development:
    #   csv:     /home/ubuntu/ml/CTSyn/csv/OpenTabPretrain, /home/ubuntu/ml/CTSyn/csv/health
    #   configs: /home/ubuntu/ml/CTSyn/configs/OpenTabPretrain
    csv_dir = "/home/ubuntu/ml/CTSyn/csv/test/diabetes"
    config_dir = "/home/ubuntu/ml/CTSyn/configs/health"

    train_df_dict, test_df_dict = create_df_dict_from_dir(csv_dir, test_size=0.2)
    config_dict = create_config_dict_from_dir(config_dir)

    train_dataloader_for_decode = get_table_latent_dataloader(
        df_dict=train_df_dict,
        config_dict=config_dict,
        return_label=True,
        batch_size=1024,
    )

    # The timed section covers the reshuffle as well as the traversal
    start_time = time.time()
    train_dataloader_for_decode.dataset.shuffle_batches()

    for batch, label, categories_tensor, meta_tensor, unique_embeddings in tqdm(train_dataloader_for_decode):
        print("Train batch:", batch.shape, label.shape, categories_tensor, meta_tensor.shape, unique_embeddings.shape)

    total_time = time.time() - start_time
    print(f"Total run time for traversing: {total_time:.2f} seconds")




