import pandas as pd
from datasets import Dataset, DatasetDict
from tqdm import tqdm

from src.data.raw_data_loader import RawDataLoader
from src.data.utils import DatasetConfig
from src.utils.logging_utils import get_logger

# Module-level logger, obtained from the project's get_logger helper and named after this module.
logger = get_logger(name=__name__)


class RawShareGPTDataLoader(RawDataLoader):
    """Load a JSON-lines conversation dataset and strip extra per-turn keys.

    Each row is expected to carry a ``conversations`` list whose items are
    dicts with at least a ``from`` key (the role) and a ``value`` key.
    """

    def load(self, *, dataset_config: DatasetConfig) -> DatasetDict:
        """Read the JSON-lines source and return it as a single ``train`` split.

        Every key other than ``from`` and ``value`` is deleted from each
        conversation turn before the frame is converted to a ``Dataset``.
        The removed keys, the keys that remain, and the set of roles seen
        are logged for inspection.

        Args:
            dataset_config: Configuration whose ``name_or_path`` is passed to
                ``pandas.read_json`` as a JSON-lines source.

        Returns:
            A ``DatasetDict`` containing one ``train`` split built from the
            cleaned DataFrame.

        Raises:
            KeyError: If a non-empty frame has no ``conversations`` column, or
                if a turn has no ``from`` key.
        """
        logger.info("Loading using pandas")
        df = pd.read_json(dataset_config.name_or_path, lines=True)

        kept_keys = ("from", "value")
        all_keys = set()
        removed_keys = set()
        all_roles = set()

        # Only the ``conversations`` column is inspected, so iterate it directly
        # instead of converting every row (and every column) into a dict.
        # An empty frame may have no columns at all; it has nothing to clean.
        conversations = [] if df.empty else df["conversations"]
        for turns in tqdm(conversations, desc="Checking examples"):
            for turn in turns:
                all_roles.add(turn["from"])

                # Snapshot the extra keys first: a dict must not change size
                # while it is being iterated.
                extra_keys = [key for key in turn if key not in kept_keys]
                # NOTE: the cleanup relies on ``turn`` being the very dict object
                # stored in the DataFrame's object column, so deleting in place
                # is what makes the change visible to ``Dataset.from_pandas``.
                for key in extra_keys:
                    del turn[key]

                removed_keys.update(extra_keys)
                all_keys.update(turn)

        logger.info(f"Keys after cleanup in conversations: {all_keys}. Removed keys: {removed_keys}")
        logger.info(f"Roles in the dataset: {all_roles}")

        ds = Dataset.from_pandas(df)
        logger.info(f"Dataset loaded. {ds}")
        return DatasetDict(train=ds)
