from pathlib import Path
from typing import Any

import yaml
from accelerate.state import PartialState
from datasets import DatasetDict, load_dataset
from pydantic import BaseModel, ConfigDict

from src.data.utils import CustomColName, DatasetConfig
from src.utils.logging_utils import get_logger

logger = get_logger(name=__name__)


class RawDataLoaderConfig(BaseModel):
    """Validated settings for ``RawDataLoader``.

    Unknown keys are rejected (``extra="forbid"``), so a misspelled option
    fails at construction time instead of being silently ignored.
    """

    model_config = ConfigDict(extra="forbid")

    # Worker count forwarded as ``num_proc`` to the dataset ``map``/``filter`` calls.
    num_proc: int | None = 8
    # NOTE(review): not read anywhere in the visible part of this file - confirm the consumer.
    add_id: bool = True
    # When set (and non-zero), the train split is cut down to its first N examples.
    max_train_examples: int | None = None
    # NOTE(review): not read anywhere in the visible part of this file - confirm the consumer.
    max_validation_examples: int | None = None
    # NOTE(review): not read anywhere in the visible part of this file - presumably a shuffling/sampling seed.
    seed: int = 42
    # When true, the loader cleans the dataset cache files and loads the splits a second time.
    clean_cache: bool = False
    # When set, the eval split is partitioned into one sub-dataset per unique value of this column.
    eval_split_column: str | None = None
    # Free-form extra data; not interpreted in the visible part of this file.
    meta: dict[str, Any] | None = None


class RawDataLoader:
    """Load the raw train/eval splits of a dataset and optionally post-process them.

    Keyword arguments given to the constructor are validated by
    ``RawDataLoaderConfig``; unknown keys raise a validation error.

    NOTE(review): ``max_validation_examples``, ``add_id``, ``seed`` and ``meta``
    are accepted by the config but are not applied by this class.
    """

    def __init__(self, **kwargs):
        self.config = RawDataLoaderConfig(**kwargs)

    def load(self, *, dataset_config: DatasetConfig, clean_data: bool = False, revision: str | None = None) -> DatasetDict:
        """Load the configured splits and apply the configured post-processing.

        Steps, in order: load the train split (and the eval split, if one is
        named), optionally clean the cache and reload, check that no split is
        empty, optionally truncate the train split, optionally drop examples
        with blank messages, optionally partition the eval split by column.

        Args:
            dataset_config: Names the dataset, its splits and any extra
                ``load_dataset`` keyword arguments.
            clean_data: When true, run ``clean_messages_data`` on the result.
            revision: Dataset revision forwarded to ``load_dataset``.

        Returns:
            A ``DatasetDict`` keyed by split name. When ``eval_split_column``
            is configured, the eval entry is a plain dict mapping each unique
            column value (as ``str``) to its filtered dataset.

        Raises:
            AssertionError: If a loaded split is empty, or if
                ``eval_split_column`` is set but no eval split was loaded.
        """
        with PartialState().local_main_process_first():  # Let rank 0 download the dataset
            dataset_kwargs = dataset_config.datasets_kwargs or {}
            splits = [dataset_config.train_split_name] + (
                [dataset_config.eval_split_name] if dataset_config.eval_split_name else []
            )

            raw_dataset = self._load_splits(
                dataset_config=dataset_config,
                splits=splits,
                dataset_kwargs=dataset_kwargs,
                revision=revision,
            )

            if self.config.clean_cache:
                raw_dataset.cleanup_cache_files()  # type: ignore
                # Drop the reference to the first load before loading again.
                raw_dataset = None
                # Reload with the same revision; omitting it here would silently
                # switch a pinned revision to the default one.
                raw_dataset = self._load_splits(
                    dataset_config=dataset_config,
                    splits=splits,
                    dataset_kwargs=dataset_kwargs,
                    revision=revision,
                )

        for split in splits:
            assert len(raw_dataset[split]) > 0, f"Dataset {split} is empty"
        logger.info(f"Raw dataset loaded. Name={dataset_config.name_or_path}:\n{raw_dataset}")

        train_split = dataset_config.train_split_name
        if self.config.max_train_examples and train_split in raw_dataset:
            # Clamp to the split size: selecting indices past the end raises.
            limit = min(self.config.max_train_examples, len(raw_dataset[train_split]))
            raw_dataset[train_split] = raw_dataset[train_split].select(range(limit))
            logger.info(f"Limiting train dataset to {self.config.max_train_examples} examples: {raw_dataset}")

        # Clean while every value is still a dataset: once the eval split has
        # been replaced by a plain dict of sub-datasets, the DatasetDict-level
        # map/filter used for cleaning can no longer be applied to it.
        if clean_data:
            raw_dataset = self.clean_messages_data(ds=raw_dataset, ds_name=dataset_config.name_or_path)

        if self.config.eval_split_column:
            eval_split = dataset_config.eval_split_name
            assert eval_split in raw_dataset, (
                f"eval_split_column is set but eval split {eval_split!r} was not loaded"
            )
            raw_dataset[eval_split] = self.split_dataset_on_column(
                dataset=raw_dataset[eval_split],
                column=self.config.eval_split_column,
            )
            logger.info("Eval dataset: %s", raw_dataset)

        return raw_dataset

    def _load_splits(self, *, dataset_config, splits, dataset_kwargs, revision):
        """Load each named split with ``load_dataset`` and return them as a ``DatasetDict``."""
        return DatasetDict(
            {
                split: load_dataset(
                    dataset_config.name_or_path,
                    split=split,  # type: ignore
                    revision=revision,
                    **dataset_kwargs,
                )  # type: ignore
                for split in splits
            }
        )

    def split_dataset_on_column(self, *, dataset, column):
        """Partition ``dataset`` into one filtered dataset per unique value of ``column``.

        Returns:
            A dict mapping ``str(value)`` to the rows whose ``column`` equals
            that value.

        Raises:
            AssertionError: If the column has fewer than two unique values.
        """
        splits = {}
        unique_values = dataset.unique(column)
        assert len(unique_values) > 1, f"Column {column} needs at least two unique values to split on"
        for value in unique_values:
            split = dataset.filter(lambda example: example[column] == value)
            splits[str(value)] = split
        logger.info(f"Split dataset on column {column}:\n{splits}")
        return splits

    def clean_messages_data(self, *, ds: DatasetDict, ds_name: str) -> DatasetDict:
        """Left-strip every message's content and drop examples containing a blank message.

        An example is dropped when the content of any of its messages is empty
        or whitespace-only. Each example is expected to carry a ``"messages"``
        list of dicts with a string ``"content"`` entry.

        Args:
            ds: Splits to clean.
            ds_name: Dataset name; currently unused.

        Returns:
            The cleaned splits, without the temporary flag column.
        """

        def _parse_ex_chatml_attr(example, **kwargs):
            messages = example["messages"]
            flagged = False
            for message in messages:
                # One blank message is enough to flag the whole example; the
                # flag must not be reset by later non-blank messages.
                if message["content"].strip() == "":
                    flagged = True
                message["content"] = message["content"].lstrip()
            # Write back explicitly and always set the flag (also for an empty
            # message list) so every row has the same columns.
            example["messages"] = messages
            example[CustomColName.FLAGGED.value] = flagged
            return example

        ds = ds.map(
            _parse_ex_chatml_attr,
            num_proc=self.config.num_proc,
            desc="Parsing examples",
        )

        ds = self._remove_flagged_examples(ds)

        return ds

    def _remove_flagged_examples(self, ds):
        """Drop rows whose flag column is truthy, remove that column, and log how many rows went."""
        rows_before = self._count_rows(ds)
        ds = ds.filter(
            lambda x: not x[CustomColName.FLAGGED.value],
            num_proc=self.config.num_proc,
            desc="Filtering flagged examples",
        )
        ds = ds.remove_columns([CustomColName.FLAGGED.value])
        logger.info(f"Removed {(rows_before - self._count_rows(ds))} flagged examples")

        return ds

    @staticmethod
    def _count_rows(ds):
        """Return the total row count of a single dataset or of a mapping of splits.

        ``len()`` of a mapping of splits is the number of splits, not rows, so
        the per-split lengths are summed in that case.
        """
        if isinstance(ds, dict):
            return sum(len(split) for split in ds.values())
        return len(ds)
