import os
from glob import glob
import pandas as pd
import json
import torch
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from sklearn.preprocessing import QuantileTransformer

from .dataTransformer import DataTransformer, IdentityColumnTransformer, StringTransformer, DatetimeTransformer, NumericalQuantileTransformer, PiecewiseLinearEncoderColumn

def _infer_datetime_granularity(values):
    """
    Classifies a datetime-like column as "datetime", "date", "time" or "unknown".

    Parameters:
    - values (pd.Series): Raw column values; parsed with pd.to_datetime(errors='coerce').

    Returns:
    - str: "datetime" if the date varies and a non-zero time of day is present,
           "date" if only the date varies, "time" if only a non-zero time of day
           is present, "unknown" otherwise (including columns with no parseable value).
    """
    parsed = pd.to_datetime(values, errors='coerce')

    # Missing/unparseable entries are coerced to NaT. Their hour/minute/second are NaN,
    # which never compare equal to 0, and NaT also counts as an extra "unique date".
    # Inspect only the valid timestamps so missing values cannot fake a time or date part.
    valid = parsed.dropna()

    # Any non-zero hour, minute or second means the time of day carries information
    has_time = bool(
        (valid.dt.hour != 0).any()
        or (valid.dt.minute != 0).any()
        or (valid.dt.second != 0).any()
    )

    # The date part is considered meaningful only if it varies across rows
    has_date = valid.dt.date.nunique() > 1

    if has_date and has_time:
        return "datetime"
    if has_date:
        return "date"
    if has_time:
        return "time"
    return "unknown"


def preprocess_one_table(df, config, transformer_mapping=None, quantile_transform=False):
    """
    Preprocesses a single dataframe and updates its configuration.

    The i-th entry of config['variables'] is used for the i-th column of df. Depending on
    the declared 'variable_type', the entry is enriched in place with:
    - 'categories' for categorical columns (unique values, cast to str),
    - 'datetime_granularity' for datetime columns,
    - 'quantile_params' or 'ple_params' for numerical columns.

    Parameters:
    - df (pd.DataFrame): The dataframe to preprocess
    - config (dict): The configuration dictionary for the dataframe (mutated in place)
    - transformer_mapping (dict, optional): Mapping of column types to transformer classes.
                                          Only its keys are used here, to validate column types.
    - quantile_transform (bool, optional): Whether to apply quantile transformation to numerical columns.
                                         If False, uses PiecewiseLinearEncoderColumn instead.

    Returns:
    - tuple: (column_type, config) where column_type is a dict mapping column names to types
              and config is the same (updated) configuration object

    Raises:
    - Exception: Any error raised while processing a column is re-raised after
                 diagnostic information about that column has been printed.
    """
    if transformer_mapping is None:
        # Store transformer classes, not instances
        transformer_mapping = {
            "numerical": NumericalQuantileTransformer if quantile_transform else PiecewiseLinearEncoderColumn,
            "categorical": StringTransformer,
            "datetime": DatetimeTransformer,
            "text": StringTransformer,
            "unknown": IdentityColumnTransformer
        }

    column_type = {}
    for idx, column in enumerate(df.columns):
        try:
            variable = config['variables'][idx]
            dtype = variable['variable_type'].lower()
            # NOTE: kept as an assert so the raised exception type stays AssertionError;
            # be aware that asserts are skipped when Python runs with -O.
            assert dtype in transformer_mapping, f"Column {column} has type {dtype} which is not in the transformer mapping: {transformer_mapping.keys()}."
            column_type[column] = dtype

            if dtype == "categorical":
                # Add categories for categorical columns. Needed for vectorization
                variable['categories'] = df[column].astype(str).unique().tolist()
            elif dtype == "datetime":
                # Record which parts (date / time of day) of the column are informative
                variable['datetime_granularity'] = _infer_datetime_granularity(df[column])
            elif dtype == "numerical":
                if quantile_transform:
                    # Fit a quantile transformer and keep its parameters for later
                    transformer = NumericalQuantileTransformer(
                        output_distribution='uniform',
                        n_quantiles=1000
                    )
                    transformer.fit(df[column].values.reshape(-1, 1))
                    transformer_config = transformer.get_config()
                    variable['quantile_params'] = transformer_config['quantile_params']
                else:
                    # Fit a piecewise linear encoder and keep its full config for later
                    transformer = PiecewiseLinearEncoderColumn(
                        n_bins=32,
                        strategy="quantile"
                    )
                    transformer.fit(df[column])
                    variable['ple_params'] = transformer.get_config()
        except Exception as e:
            # The config entry itself may be what is missing or malformed; look it up
            # defensively so the diagnostics never mask the original exception.
            try:
                declared_type = config['variables'][idx]['variable_type']
            except (KeyError, IndexError, TypeError):
                declared_type = "<unavailable>"
            print(f"Error processing column '{column}':")
            print(f"  - Specified type in config: {declared_type}")
            print(f"  - Pandas inferred dtype: {df[column].dtype}")
            print(f"  - Sample values: {df[column].head(5).tolist()}")
            print(f"  - Exception: {str(e)}")
            raise  # Re-raise the exception after printing diagnostic info

    return column_type, config

# Split every dataset that has both a parquet file and a JSON config into batch files
def process_datasets(
    df_folder, config_folder, new_folder, batch_size,
    data_transformer_class=None, transformer_mapping=None, quantile_transform=False
):
    """
    Processes and splits Parquet files into batches with matching JSON configurations.

    A dataset is processed only if a '<name>.parquet' file exists in df_folder and a
    '<name>.json' file exists in config_folder. Datasets are handled in sorted name
    order so the returned mapping has the same order on every run.

    Parameters:
    - df_folder (str): Path to the folder containing the Parquet files.
    - config_folder (str): Path to the folder containing the JSON config files.
    - new_folder (str): Path to the folder where batch files will be saved (created if missing).
    - batch_size (int): Number of rows per batch. Must be positive.
    - data_transformer_class (optional): Transformer class used to transform each data frame
      before splitting. It must provide from_config(config, transformer_mapping) returning
      an object whose transform(df) returns the transformed data frame.
    - transformer_mapping (dict, optional): Mapping of column types to transformer classes,
      forwarded to preprocess_one_table and to data_transformer_class.from_config.
    - quantile_transform (bool, optional): Forwarded to preprocess_one_table.

    Returns:
    - batch_file_to_config (dict): A dictionary mapping batch file paths to their JSON config.

    Raises:
    - ValueError: If batch_size is not positive.
    """
    # Fail early: a zero batch size would divide by zero and a negative one would
    # silently produce no batches at all.
    if batch_size <= 0:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size!r}.")

    print("Preprocessing local dataset...")
    # Load all parquet/config file names
    parquet_files = glob(os.path.join(df_folder, '*.parquet'))
    json_files = glob(os.path.join(config_folder, '*.json'))

    # Create a mapping from datasetName to parquet and json files
    parquet_dict = {os.path.splitext(os.path.basename(f))[0]: f for f in parquet_files}
    json_dict = {os.path.splitext(os.path.basename(f))[0]: f for f in json_files}

    # Find matching dataset names. Sorted because set iteration order over strings
    # is not stable between interpreter runs.
    dataset_names = sorted(set(parquet_dict) & set(json_dict))
    print(f"Found {len(dataset_names)} matching dataset names.")

    # Ensure the new folder exists
    os.makedirs(new_folder, exist_ok=True)

    # For mapping from batch file to config
    batch_file_to_config = {}

    # Process each dataset
    for dataset_name in tqdm(dataset_names, desc="Processing datasets"):
        parquet_path = parquet_dict[dataset_name]
        json_path = json_dict[dataset_name]

        # Load the JSON config (read as UTF-8 rather than the platform default encoding)
        with open(json_path, 'r', encoding='utf-8') as f:
            config = json.load(f)

        # Load the Parquet data frame
        df = pd.read_parquet(parquet_path)

        # Preprocess the table (enriches config with per-column parameters)
        column_type, config = preprocess_one_table(df, config, transformer_mapping, quantile_transform)

        # Apply data transformation if provided
        if data_transformer_class:
            # Create transformer from config.
            # NOTE(review): transformer_mapping is forwarded exactly as received, so it is
            # None when the caller did not supply one - confirm from_config handles that.
            tf = data_transformer_class.from_config(config, transformer_mapping)

            # Transform the data
            df = tf.transform(df)

        # Split into consecutive batches of at most batch_size rows
        num_rows = len(df)
        for batch_idx, start_idx in enumerate(range(0, num_rows, batch_size)):
            df_batch = df.iloc[start_idx:start_idx + batch_size]

            # Save the batch to new_folder
            batch_file_name = f"{dataset_name}_batch{batch_idx}.parquet"
            batch_file_path = os.path.join(new_folder, batch_file_name)
            df_batch.to_parquet(batch_file_path)

            # Map batch file to config (shallow copy: nested entries are shared between batches)
            batch_file_to_config[batch_file_path] = config.copy()

    return batch_file_to_config

class ParquetDataset(Dataset):
    """
    Dataset whose items are whole batch files stored as Parquet.

    Parameters:
    - batch_file_to_config (dict): Mapping of batch file paths to their configuration.
      The order of its keys defines the item order of the dataset.
    """

    def __init__(self, batch_file_to_config):
        self.batch_file_to_config = batch_file_to_config
        # Freeze the key order so integer indices map to stable file paths
        self.batch_files = [*batch_file_to_config.keys()]

    def __len__(self):
        """Returns the number of batch files."""
        return len(self.batch_files)

    def __getitem__(self, idx):
        """
        Reads the idx-th batch file from disk.

        Returns:
        - dict: 'config' (shallow copy of the batch's configuration), 'df_batch'
                (the data frame read from the file) and 'batch_file_path'.
        """
        path = self.batch_files[idx]
        batch_config = self.batch_file_to_config[path]
        frame = pd.read_parquet(path)

        item = {'config': batch_config.copy(), 'df_batch': frame}
        item["batch_file_path"] = path
        return item

def collate_fn(batch):
    """
    Returns the first element of the collated batch and ignores any others.

    Parameters:
    - batch (sequence): Items gathered for one step.

    Returns:
    - The item at position 0 of batch, unchanged.
    """
    first_item = batch[0]
    return first_item