from typing import Optional

import pandas as pd

from datasets import load_dataset, load_from_disk, Dataset
from sklearn.model_selection import train_test_split


# Splits that this module knows how to build from the raw Salad-Data set.
_SALAD_SPLITS = ("train", "test")


def _build_balanced_salad_splits(seed: int = 0, test_size: float = 0.2) -> dict:
    """Build category-balanced train/test datasets from Salad-Data ``base_set``.

    Every ``'3-category'`` group is down-sampled to the size of the smallest
    group, then each group is split into train/test so both splits stay
    balanced across categories.

    Args:
        seed: Seed used for both the per-category sampling and the
            train/test split, so repeated builds yield the same rows.
        test_size: Fraction of each category assigned to the test split.

    Returns:
        A dict mapping ``"train"`` and ``"test"`` to in-memory ``Dataset``s.
    """
    salad_dataset = load_dataset("OpenSafetyLab/Salad-Data", name='base_set', split='train')
    df = pd.DataFrame(salad_dataset)

    # Size of the smallest category (counting non-null questions); every
    # category is down-sampled to this size.
    min_num_question = int(df.groupby('3-category')['question'].count().min())

    # Sampling without replacement is safe: every category has at least
    # min_num_question rows by construction. GroupBy.sample keeps the
    # grouping column in the result, unlike groupby().apply() on recent
    # pandas versions.
    sampled_df = (
        df.groupby('3-category')
        .sample(n=min_num_question, replace=False, random_state=seed)
        .reset_index(drop=True)
    )

    # Split category by category so train and test are each balanced.
    train_parts = []
    test_parts = []
    for _, group in sampled_df.groupby('3-category'):
        train_group, test_group = train_test_split(
            group, test_size=test_size, random_state=seed)
        train_parts.append(train_group)
        test_parts.append(test_group)

    train_df = pd.concat(train_parts).reset_index(drop=True)
    test_df = pd.concat(test_parts).reset_index(drop=True)

    return {
        "train": Dataset.from_pandas(train_df),
        "test": Dataset.from_pandas(test_df),
    }


def get_balanced_salad_dataset(
        split: str = 'train', sanity_check: bool = False,
        cache_dir: Optional[str] = None) -> Dataset:
    """Return a category-balanced split of the Salad-Data ``base_set``.

    If ``cache_dir`` is given, the split is loaded from
    ``{cache_dir}/{split}`` when present. On a cache miss the raw dataset is
    downloaded, balanced over ``'3-category'``, split 80/20 per category, and
    both splits are written to ``cache_dir`` before the requested one is
    loaded back from disk.

    Args:
        split: Which split to return. ``"train"`` and ``"test"`` can be
            built; any other name only works if it already exists under
            ``cache_dir``.
        sanity_check: If true, truncate the result to its first 100 rows.
        cache_dir: Directory used as an on-disk cache. When ``None`` nothing
            is read from or written to disk and the split is rebuilt
            in memory on every call.

    Returns:
        The requested ``Dataset``.

    Raises:
        ValueError: If ``split`` is not cached and is not one of the
            splits this function can build.
    """
    dataset = None

    if cache_dir is not None:
        try:
            dataset = load_from_disk(f"{cache_dir}/{split}")
            print(f"Loaded {split} dataset from cache.")
        except FileNotFoundError:
            dataset = None

    if dataset is None:
        # Fail before the expensive download if the split cannot be built.
        if split not in _SALAD_SPLITS:
            raise ValueError(
                f"Unknown split {split!r}; expected one of {_SALAD_SPLITS}.")

        print(f"No cached {split} dataset found, loading and processing...")
        splits = _build_balanced_salad_splits()

        if cache_dir is None:
            dataset = splits[split]
        else:
            # Persist both splits so the sibling split is a cache hit later.
            for split_name, split_dataset in splits.items():
                split_dataset.save_to_disk(f"{cache_dir}/{split_name}")
            dataset = load_from_disk(f"{cache_dir}/{split}")

    if sanity_check:
        dataset = dataset.select(range(min(len(dataset), 100)))

    return dataset
