from torch.utils.data import Dataset
import lmdb
import msgpack
import pandas as pd
import os
import json
from pathlib import Path
import random
import torch.distributed as dist

# NOTE: despite its name, this flag gates functional behaviour. When True,
# Dataset_lmdb.__getitem__ drops surplus text columns, numerical columns with
# a non-numeric dtype and columns beyond MAX_COLUMNS, then validates that the
# DataFrame columns match the config variables.
DEBUGGING = True
# Maximum number of 'text' variables kept per batch; any further text
# variables (and their DataFrame columns) are dropped.
MAX_TEXT_COLUMNS = 10
# Maximum number of variables/columns kept per batch after the text and dtype
# filtering; variables past this index are dropped.
MAX_COLUMNS = 512
# Upper bound on rows * columns of a returned batch; larger batches are
# row-sampled down to fit. Set to None to disable the cap.
# NOTE(review): the origin of the 128 * 48 factors is not documented here — confirm.
MAX_CELL_PER_BATCH = 128 * 48

class Dataset_lmdb(Dataset):
    """Map-style dataset serving pre-built batches stored in an LMDB database.

    Every LMDB value is a msgpack-encoded mapping with a ``"data"`` entry
    (loaded into a ``pandas.DataFrame``) and a ``"metadata"`` entry (the batch
    config, which carries a ``"variables"`` list of dicts with
    ``"variable_name"`` and ``"variable_type"``). The keys to serve are read
    from the ``Key`` column of a CSV log and shuffled once with a fixed seed.
    """

    # Column dtypes accepted for variables whose ``variable_type`` is
    # 'numerical'. Kept at class level so it is built once rather than once
    # per variable.
    _VALID_NUMERICAL_DTYPES = frozenset({
        "int8", "int16", "int32", "int64",
        "uint8", "uint16", "uint32", "uint64",
        "float16", "float32", "float64",
    })

    def __init__(self, lmdb_path, csv_log_path, shuffle_seed=42):
        """
        Open the LMDB environment read-only and load the shuffled key list.

        Parameters:
        - lmdb_path (str): Path to the LMDB database.
        - csv_log_path (str): Path to the CSV file whose ``Key`` column lists
          the LMDB keys to serve.
        - shuffle_seed (int): Seed for the one-off shuffle of the keys, so the
          key order is identical across runs.

        Raises:
        - Exception: whatever ``lmdb.open`` raises is logged and re-raised.
        """
        self.lmdb_path = lmdb_path
        self.csv_log_path = csv_log_path

        try:
            self.env = lmdb.open(lmdb_path, readonly=True, lock=False)
        except Exception as e:
            # dist.get_rank() itself raises when no process group has been
            # initialised, which would mask the real LMDB error; only query
            # the rank when it is actually available.
            if dist.is_available() and dist.is_initialized():
                rank = dist.get_rank()
            else:
                rank = "N/A"
            print(f"Error initializing LMDB on rank {rank}: {e}")
            raise

        # Load the CSV containing the keys.
        self.csv_log = pd.read_csv(csv_log_path)
        self.keys = self.csv_log['Key'].tolist()

        # Shuffle with a private RNG: this yields the same order as
        # random.seed(shuffle_seed) + random.shuffle(...) but does not reseed
        # the process-wide ``random`` state as a side effect.
        random.Random(shuffle_seed).shuffle(self.keys)

    def __len__(self):
        """Return the number of keys (i.e. stored batches) in the dataset."""
        return len(self.keys)

    def __getitem__(self, idx):
        """
        Load, decode and post-process the batch stored under ``self.keys[idx]``.

        Returns:
        - dict with ``"config"`` (batch metadata), ``"df_batch"`` (DataFrame)
          and ``"key"`` (the LMDB key as a string).

        Raises:
        - KeyError: the key is not present in LMDB.
        - ValueError: after filtering, the DataFrame columns do not match the
          config variables (only checked when ``DEBUGGING`` is true).
        """
        batch_key = self.keys[idx]

        # Hold the read transaction only for the lookup itself; the value is
        # returned as a bytes copy, so decoding can happen after it is closed.
        with self.env.begin() as txn:
            packed_data = txn.get(batch_key.encode())
        if packed_data is None:
            raise KeyError(f"Key {batch_key} not found in LMDB.")

        # Unpack data with Msgpack
        combined_data = msgpack.unpackb(packed_data, raw=False)
        df_batch = pd.DataFrame(combined_data["data"])
        config = combined_data["metadata"]

        if DEBUGGING:
            df_batch, config = self._filter_columns(df_batch, config, batch_key)

        if MAX_CELL_PER_BATCH is not None:
            df_batch = self._cap_cells(df_batch)

        return {"config": config, "df_batch": df_batch, "key": batch_key}

    def _filter_columns(self, df_batch, config, batch_key):
        """Drop surplus/invalid columns and align the DataFrame with the config.

        Returns the filtered DataFrame and a shallow copy of ``config`` whose
        ``'variables'`` list holds only the variables that were kept.
        """
        kept_variables = []
        dropped_columns = []        # text overflow, later also MAX_COLUMNS overflow
        invalid_dtype_columns = []  # 'numerical' variables with a non-numeric dtype
        text_count = 0

        # First pass: cap the number of text columns and check numerical dtypes.
        for var in config['variables']:
            var_type = var['variable_type']
            if var_type == 'text':
                if text_count >= MAX_TEXT_COLUMNS:
                    dropped_columns.append(var['variable_name'])
                    continue
                text_count += 1
            elif var_type == 'numerical':
                col_dtype = str(df_batch[var['variable_name']].dtype)
                if col_dtype not in self._VALID_NUMERICAL_DTYPES:
                    invalid_dtype_columns.append(var['variable_name'])
                    continue
            kept_variables.append(var)

        # Second pass: limit total columns to MAX_COLUMNS.
        if len(kept_variables) > MAX_COLUMNS:
            dropped_columns.extend(
                var['variable_name'] for var in kept_variables[MAX_COLUMNS:]
            )
            kept_variables = kept_variables[:MAX_COLUMNS]

        # Remove every rejected column from the DataFrame.
        columns_to_drop = dropped_columns + invalid_dtype_columns
        if columns_to_drop:
            df_batch = df_batch.drop(columns=columns_to_drop)

        # Work on a copy so the decoded metadata dict is not mutated in place.
        config = config.copy()
        config['variables'] = kept_variables

        # Validate column synchronization between config and DataFrame.
        config_cols = [var['variable_name'] for var in kept_variables]
        df_cols = list(df_batch.columns)

        if len(config_cols) != len(df_cols):
            raise ValueError(f"Column count mismatch for key {batch_key}. "
                             f"Config has {len(config_cols)} columns, "
                             f"DataFrame has {len(df_cols)} columns.")

        if config_cols != df_cols:
            # Same count but different order/names: reorder to the config order.
            try:
                df_batch = df_batch[config_cols]
            except KeyError as e:
                missing_cols = set(config_cols) - set(df_cols)
                extra_cols = set(df_cols) - set(config_cols)
                raise ValueError(f"Column mismatch for key {batch_key}. "
                                 f"Missing in DataFrame: {missing_cols}, "
                                 f"Extra in DataFrame: {extra_cols}") from e

        return df_batch, config

    def _cap_cells(self, df_batch):
        """Row-sample ``df_batch`` so rows * columns stays within MAX_CELL_PER_BATCH.

        Must only be called when ``MAX_CELL_PER_BATCH`` is not None.
        """
        n_rows = len(df_batch)
        n_cols = len(df_batch.columns)

        # Also covers n_cols == 0 (product is 0), so the division below is safe.
        if n_rows * n_cols <= MAX_CELL_PER_BATCH:
            return df_batch

        # Largest row count that still fits the cell budget; fixed random_state
        # keeps the sampled subset reproducible.
        max_rows = int(MAX_CELL_PER_BATCH // n_cols)
        return df_batch.sample(n=max_rows, random_state=42)

    def close(self):
        """Close the underlying LMDB environment."""
        self.env.close()


import torch
from torch.utils.data import DataLoader

import pandas as pd


def collate_fn(batch):
    """Return the first element of ``batch`` unchanged; other elements are ignored.

    NOTE(review): presumably meant for a DataLoader with ``batch_size=1``,
    where each dataset item already is a complete batch — confirm with callers.
    """
    first_item = batch[0]
    return first_item



