import os
import pandas as pd
from transformers import AutoTokenizer
from datasets import Dataset, DatasetDict



def custom_loading_dataset(dataset_name, train_name='train.parquet', test_name='test.parquet', max_length=512, tokenizer=None):
    """
    Load a dataset from Parquet files and drop samples whose instruction is too long.

    Each split is read with pandas, its columns are renamed to a common schema
    (e.g. 'input' -> 'instruction', 'question' -> 'problem'), and two helper
    columns are added: 'split' (the split name) and 'length' (the number of
    token ids the tokenizer produces for the 'instruction' column). Rows whose
    'length' exceeds ``max_length`` are removed before conversion to a
    ``datasets.Dataset``.

    Args:
        dataset_name (str): The base directory of the dataset.
        train_name (str, optional): The name of the training file. Defaults to 'train.parquet'.
        test_name (str, optional): The name of the test file. Defaults to 'test.parquet'.
        max_length (int, optional): Maximum tokenized length of the samples to keep
            (inclusive). Defaults to 512.
        tokenizer (str or callable, optional): Either a model name/path that is
            loaded with ``AutoTokenizer.from_pretrained``, or an already
            constructed tokenizer object. Defaults to 'bert-base-uncased'.

    Returns:
        DatasetDict: A dictionary-like object with the keys 'train' and 'test'.
        The 'test' entry is ``None`` when the test file does not exist.

    Raises:
        FileNotFoundError: If the training file does not exist.
        KeyError: If a split has no 'instruction' column after renaming.
    """
    train_path = os.path.join(dataset_name, train_name)
    test_path = os.path.join(dataset_name, test_name)

    # The training split is mandatory: re-raise with the resolved path.
    try:
        train_data = pd.read_parquet(train_path)
    except FileNotFoundError as err:
        raise FileNotFoundError(f"Training file not found at {train_path}") from err

    # The test split is optional: report and carry on without it.
    try:
        test_data = pd.read_parquet(test_path)
    except FileNotFoundError:
        print(f"Test file not found at {test_path}. Skipping test data.")
        test_data = None

    # Resolve the tokenizer only after the files were read, so that a missing
    # training file is reported without first loading a tokenizer.
    if tokenizer is None:
        tokenizer = 'bert-base-uncased'
    if isinstance(tokenizer, str):
        tokenizer = AutoTokenizer.from_pretrained(tokenizer)

    column_mapping = {
        'ground_truth_answer': 'ground_truth',
        'subject': 'topic',
        'target': 'solution',
        'input': 'instruction',
        'question': 'problem',
    }

    def get_length(text):
        # No padding and no truncation, so the second dimension of the
        # returned tensor is the full token count of this single text.
        inputs = tokenizer(text, return_tensors="pt", padding=False, truncation=False)
        return inputs["input_ids"].shape[1]

    def prepare(frame, split_name):
        # Rename to the common schema, tag the split, measure and filter.
        frame = frame.rename(columns=column_mapping)
        if 'instruction' not in frame.columns:
            raise KeyError(
                f"Column 'instruction' not found in {split_name} data; "
                f"available columns: {list(frame.columns)}"
            )
        frame['split'] = split_name
        frame['length'] = frame['instruction'].apply(get_length)
        return frame[frame['length'] <= max_length]

    train_dataset = Dataset.from_pandas(prepare(train_data, 'train'))

    # A missing test split is kept as an explicit None entry so callers that
    # look up dataset_dict['test'] keep working.
    if test_data is not None:
        test_dataset = Dataset.from_pandas(prepare(test_data, 'test'))
    else:
        test_dataset = None

    return DatasetDict({
        'train': train_dataset,
        'test': test_dataset,
    })