import numpy as np
import random
from pathlib import Path
from typing import Any

import yaml
from datasets import Dataset, DatasetDict, concatenate_datasets

from src.data.raw_data_loader import RawDataLoader
from src.data.utils import CustomColName, DatasetConfig
from src.utils.logging_utils import get_logger
from src.utils.pipeline_utils import PipelineStep

logger = get_logger(name=__name__)


class CreateRandomDs(PipelineStep):
    """Pipeline step that builds randomly mixed training subsets from the configured datasets.

    For every prefix ``dataset_cfg_list[:k]`` (k = 1 .. len(dataset_cfg_list)) it creates
    ``num_subsets`` subsets of exactly ``num_examples`` rows. Each subset is pushed to the
    Hugging Face Hub as a private dataset (unless ``push_to_hf_hub`` is False), and a
    matching dataset YAML config is written next to ``dataset_cfg_template``.
    """

    class Config(PipelineStep.Config):
        class Config:
            arbitrary_types_allowed = True

        hf_token: str
        num_proc: int
        dataset_cfg_list: list[DatasetConfig]
        # When False, subsets and YAML configs are still produced but nothing is uploaded.
        push_to_hf_hub: bool = True
        # Number of rows in every generated subset.
        num_examples: int = 6000
        # Number of subsets generated per dataset prefix.
        num_subsets: int = 3
        # Seeds both the dataset shuffles and the random mixture ratios.
        data_seed: int = 42
        dataset_cfg_template: Path = Path("config/train/dataset/_dataset_template.yaml")

    def __init__(
        self,
        *,
        raw_data_loader: RawDataLoader,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        self.config = self.Config(**kwargs)
        self.raw_data_loader = raw_data_loader
        # Private RNG for the mixture ratios: reproducible from ``data_seed`` without
        # reseeding (or being disturbed by) the process-wide ``random`` state.
        self._rng = random.Random(self.config.data_seed)

    def _call(self, **kwargs) -> Any:
        """Build, upload and register random subsets for every prefix of ``dataset_cfg_list``.

        Subsets are named ``<ds1>-<ds2>-...-Rand-<idx>-<num_examples // 1000>k``.
        """
        dataset_cfgs = self.config.dataset_cfg_list
        # Load and shuffle every training split once; each iteration below only needs a
        # prefix of this list, so reloading per prefix would repeat identical work.
        all_ds = self._load_train_splits(num_ds=len(dataset_cfgs))
        for num_ds in range(1, len(dataset_cfgs) + 1):
            ds_names = [self._clean_ds_name(ds_cfg.id) for ds_cfg in dataset_cfgs[:num_ds]]

            sub_samples, ratios = self._create_subsets(num_ds=num_ds, ds_list=all_ds)
            for idx, ds in enumerate(sub_samples):
                subset_name = f'{"-".join(ds_names)}-Rand-{idx}-{self.config.num_examples // 1000}k'
                self._push_to_hub(dataset=ds, dataset_name=subset_name)
                self._create_ds_cfg(dataset_name=subset_name, ds_names=ds_names, ratios=ratios[idx] if ratios else None)

    def _clean_ds_name(self, ds_name: str) -> str:
        """Shorten a dataset id for use inside subset names.

        Removes every occurrence of ``"ds"`` and of ``"-20k"``.
        NOTE(review): ``"ds"`` is removed anywhere in the id, not only as a prefix/suffix;
        confirm ids never contain ``"ds"`` elsewhere.
        """
        ds_name = ds_name.replace("ds", "")
        ds_name = ds_name.replace("-20k", "")
        return ds_name

    def _create_ds_cfg(self, *, dataset_name: str, ds_names: list[str], ratios: list[float] | None) -> None:
        """Write ``<template dir>/<dataset_name>.yaml`` derived from the dataset config template.

        Args:
            dataset_name: Subset name; appended to the template's ``name_or_path`` and used
                as ``meta.id_suffix`` and as the output file stem.
            ds_names: Short names of the mixed datasets (only used for the ratio comments).
            ratios: Per-dataset mixing fractions, aligned with ``ds_names``; when given they
                are recorded as trailing YAML comments.
        """
        with self.config.dataset_cfg_template.open("r", encoding="utf-8") as f:
            ds_cfg = yaml.safe_load(f)

        ds_cfg["name_or_path"] += dataset_name
        ds_cfg["eval_split_name"] = None
        ds_cfg["meta"]["id_suffix"] = dataset_name

        out_file = self.config.dataset_cfg_template.parent / f"{dataset_name}.yaml"
        with out_file.open("w", encoding="utf-8") as f:
            yaml.dump(ds_cfg, f)
            if ratios:
                # Every appended line must be a YAML comment: a bare "Ratios:" line would
                # add a `Ratios: null` key to the generated config.
                f.write("\n# Ratios:\n")
                for ds_name, ratio in zip(ds_names, ratios):
                    f.write(f"# {ds_name}: {ratio}\n")

        logger.info(f"Dataset config created: {out_file}")

    def _push_to_hub(self, *, dataset: Dataset, dataset_name: str) -> None:
        """Upload ``dataset`` as the ``train`` split of a private Hub repo ``dataset_name``.

        Does nothing except log when ``config.push_to_hf_hub`` is False.
        """
        if not self.config.push_to_hf_hub:
            logger.info(f"Skipping push of dataset {dataset_name} (push_to_hf_hub=False)")
            return
        _ds_dict = DatasetDict(train=dataset)
        _ds_dict.push_to_hub(dataset_name, private=True, token=self.config.hf_token)
        logger.info(f"Dataset {dataset_name} pushed to hub")

    def _verify_ds_ids(self, *, ds_list: list[Dataset]) -> None:
        """Check that the datasets carry distinct dataset ids (read from each first row).

        Raises:
            ValueError: if any dataset is empty (its id cannot be read).
            AssertionError: if two datasets share the same id.
        """
        empty = [idx for idx, ds in enumerate(ds_list) if len(ds) == 0]
        if empty:
            raise ValueError(f"Datasets at positions {empty} are empty")
        ds_ids = {ds[0][CustomColName.DS_ID.value] for ds in ds_list}
        assert len(ds_ids) == len(ds_list), "Dataset IDs are not unique"

    def _load_train_splits(self, *, num_ds: int) -> list[Dataset]:
        """Load the training split of the first ``num_ds`` configured datasets, each shuffled with ``data_seed``."""
        return [
            self.raw_data_loader.load(dataset_config=ds_cfg)[ds_cfg.train_split_name].shuffle(seed=self.config.data_seed)
            for ds_cfg in self.config.dataset_cfg_list[:num_ds]
        ]

    @staticmethod
    def _split_counts(ratios: list[float], total: int) -> list[int]:
        """Split ``total`` into non-negative integer counts proportional to ``ratios``.

        Uses the largest-remainder method: floor every share, then give the leftover units
        to the shares with the largest fractional parts (ties broken by position). For
        ratios summing to 1 the counts sum exactly to ``total``.
        """
        raw = [r * total for r in ratios]
        counts = [int(x) for x in raw]
        leftover = max(total - sum(counts), 0)
        by_fraction = sorted(range(len(raw)), key=lambda i: (-(raw[i] - counts[i]), i))
        for i in by_fraction[:leftover]:
            counts[i] += 1
        return counts

    @staticmethod
    def _check_enough_rows(*, ds: Dataset, needed: int) -> None:
        """Raise a ``ValueError`` when ``ds`` has fewer than ``needed`` rows."""
        if needed > len(ds):
            raise ValueError(f"Requested {needed} examples but dataset only has {len(ds)} rows")

    def _create_subsets(
        self, *, num_ds: int, ds_list: list[Dataset] | None = None
    ) -> tuple[list[Dataset], list[list[float]] | None]:
        """Create ``config.num_subsets`` subsets of exactly ``config.num_examples`` rows.

        Args:
            num_ds: How many datasets (a prefix of ``config.dataset_cfg_list``) to mix.
            ds_list: Optional pre-loaded training splits as returned by
                ``_load_train_splits``; only the first ``num_ds`` are used. Loaded on demand
                when omitted.

        Returns:
            ``(subsets, ratios)``. With a single dataset ``ratios`` is ``None`` and every
            subset is a differently seeded random sample. With several datasets
            ``ratios[i]`` are the per-dataset mixing fractions of ``subsets[i]``: the first
            is uniform, the rest come from random integer weights in [1, 100].

        Raises:
            ValueError: if no dataset is selected, a dataset is empty, or a dataset has too
                few rows for its share.
        """
        ds_list = self._load_train_splits(num_ds=num_ds) if ds_list is None else ds_list[:num_ds]
        if not ds_list:
            raise ValueError("At least one dataset is required to create subsets")
        self._verify_ds_ids(ds_list=ds_list)
        logger.info(f"Creating random dataset with {len(ds_list)} datasets")

        num_examples = self.config.num_examples
        n_sources = len(ds_list)
        subsets: list[Dataset] = []
        ratios: list[list[float]] | None = None

        if n_sources == 1:
            only_ds = ds_list[0]
            self._check_enough_rows(ds=only_ds, needed=num_examples)
            for subset_idx in range(self.config.num_subsets):
                # Distinct seed per subset: reusing one seed made every subset identical.
                _ds = only_ds.shuffle(seed=self.config.data_seed + subset_idx)
                subsets.append(_ds.select(range(num_examples)))
        else:
            ratios = [[1 / n_sources] * n_sources]  # uniform distribution
            for _ in range(self.config.num_subsets - 1):
                # random ratios from integer weights
                rand_values = [self._rng.randint(1, 100) for _ in range(n_sources)]
                weight_sum = sum(rand_values)
                ratios.append([v / weight_sum for v in rand_values])

            # Every ratio samples from the same seeded permutation of each source; the
            # shuffle is loop-invariant, so compute it once.
            shuffled = [ds.shuffle(seed=self.config.data_seed) for ds in ds_list]
            for ratio in ratios:
                assert len(ratio) == n_sources
                assert np.isclose(sum(ratio), 1.0, rtol=1e-10)
                logger.info(f"Creating random dataset with ratios: {ratio}")
                counts = self._split_counts(ratio, num_examples)
                subset_parts = []
                for count, ds in zip(counts, shuffled):
                    self._check_enough_rows(ds=ds, needed=count)
                    subset_parts.append(ds.select(range(count)))

                subset = concatenate_datasets(subset_parts)
                assert len(subset) == num_examples, f"Expected {num_examples} examples, got {len(subset)}"

                subsets.append(subset)

        assert len(subsets) == self.config.num_subsets

        return subsets, ratios
