import numpy as np
import random
from pathlib import Path
from typing import Any
from datasets import load_dataset, Dataset, ClassLabel
from transformers import AutoTokenizer
from collections import Counter

import yaml
from datasets import Dataset, DatasetDict, concatenate_datasets

from src.data.utils import CustomColName, DatasetConfig
from src.utils.logging_utils import get_logger
from src.utils.pipeline_utils import PipelineStep
from src.data.tulu_3_sft_utils import encode_sft_dataset

logger = get_logger(name=__name__)


class CreateTuluV3Subset(PipelineStep):
    """Create a stratified subset of a Hugging Face SFT dataset and publish it.

    The step loads ``original_dataset_name`` (split ``split_name``), drops examples
    whose tokenized length exceeds ``max_seq_length``, optionally excludes rows whose
    ``subset_name_col`` value is in ``subset_name_filter_list``, draws a
    ``subset_ratio`` share of the remaining rows stratified by ``subset_name_col``,
    pushes the result to the Hub as a private dataset and writes a matching dataset
    config YAML next to ``dataset_cfg_template``.
    """

    class Config(PipelineStep.Config):
        # Hub id of the dataset to subsample.
        original_dataset_name: str
        # Share of rows to keep; passed as ``test_size`` to ``train_test_split``.
        subset_ratio: float
        # Token used when pushing the subset to the Hub.
        hf_token: str
        num_proc: int = 32
        data_seed: int = 42
        # Examples whose tokenized length is above this are dropped.
        max_seq_length: int = 2048
        split_name: str = "train"
        # Values of ``subset_name_col`` whose rows are EXCLUDED from the subset.
        subset_name_filter_list: list[str] = []
        # Column used both for exclusion and for stratification.
        subset_name_col: str = "source"
        dataset_cfg_template: Path = Path("config/train/dataset/_dataset_template.yaml")
        # Tokenizer used only to measure example lengths.
        tokenizer_name: str = "philschmid/meta-llama-3-tokenizer"

    def __init__(self, **kwargs) -> None:
        super().__init__(**kwargs)
        self.config = self.Config(**kwargs)
        # Seeds the global ``random`` module; the split itself is seeded via ``data_seed``.
        random.seed(self.config.data_seed)

    def _call(self, **kwargs) -> Any:
        """Build the subset, push it to the Hub and write its dataset config.

        Returns:
            None.
        """
        cfg = self.config
        col = cfg.subset_name_col

        ds = load_dataset(cfg.original_dataset_name, split=cfg.split_name)
        ds = self.filter_long_examples(ds)

        if cfg.subset_name_filter_list:
            excluded = set(cfg.subset_name_filter_list)
            logger.info(f"Excluding rows whose {col} is in {cfg.subset_name_filter_list}")
            ds = ds.filter(lambda x: x[col] not in excluded, num_proc=cfg.num_proc)
            logger.info(f"Filtered dataset:\n{ds}")

        if not isinstance(ds.features[col], ClassLabel):
            logger.info(f"Converting {col} to ClassLabel")
            # Sort the names: iteration order of a set of strings depends on the
            # per-process hash seed, and the label->id mapping drives the stratified
            # split, so an unsorted list makes the subset irreproducible across runs.
            ds = ds.cast_column(col, ClassLabel(names=sorted(set(ds[col]))))

        subset = ds.train_test_split(
            test_size=cfg.subset_ratio,
            seed=cfg.data_seed,
            stratify_by_column=col,
            shuffle=True,
        )["test"]

        logger.info(f"Generated Subset:\n{subset}")

        label_feature = subset.features[col]
        metadata = []
        for label_id, n_rows in Counter(subset[col]).items():
            msg = f"Src={label_feature.int2str(label_id)}: {n_rows} samples"
            metadata.append(msg)
            logger.info(msg)

        # round() instead of int(): e.g. 0.29 * 100 == 28.999999999999996, which
        # int() would truncate to 28 and mislabel the dataset.
        ratio_pct = round(cfg.subset_ratio * 100)
        subset_name = f"{cfg.original_dataset_name.split('/')[-1]}-.{ratio_pct}"
        DatasetDict(train=subset).push_to_hub(
            subset_name,
            private=True,
            token=cfg.hf_token,
        )
        logger.info(f"Dataset {subset_name} pushed to hub")

        self._create_ds_cfg(dataset_name=subset_name, ds_id=f"T3-{ratio_pct}", metadata=metadata)

    def _create_ds_cfg(self, *, dataset_name: str, ds_id: str, metadata: list[str]) -> None:
        """Write ``<dataset_name>.yaml`` next to the template, derived from it.

        Appends ``dataset_name`` to the template's ``name_or_path``, sets
        ``eval_split_name`` to None and ``meta.id_suffix`` to ``ds_id``. Each
        ``metadata`` line is then appended as a ``#`` comment under a ``#Ratios:``
        header.
        """
        template = self.config.dataset_cfg_template
        with template.open("r", encoding="utf-8") as f:
            ds_cfg = yaml.safe_load(f)

        ds_cfg["name_or_path"] += dataset_name
        ds_cfg["eval_split_name"] = None
        ds_cfg["meta"]["id_suffix"] = ds_id

        out_file = template.parent / f"{dataset_name}.yaml"
        with out_file.open("w", encoding="utf-8") as f:
            yaml.dump(ds_cfg, f)
            # Per-source counts go in as YAML comments so they don't change the config.
            f.write("\n#Ratios:\n")
            f.writelines(f"#{line}\n" for line in metadata)

        logger.info(f"Dataset config created: {out_file}")

    def filter_long_examples(self, dataset: Dataset) -> Dataset:
        """Drop examples whose tokenized length exceeds ``max_seq_length``.

        The dataset is encoded with ``encode_sft_dataset`` using ``tokenizer_name``.
        Rows of the ORIGINAL (untokenized) dataset are then kept by matching their
        ``id`` against the ids of the encoded examples that are short enough.
        """
        cfg = self.config
        max_len = cfg.max_seq_length

        tokenizer = AutoTokenizer.from_pretrained(cfg.tokenizer_name)
        tokenized_dataset = encode_sft_dataset(
            dataset=dataset,
            tokenizer=tokenizer,
            max_len=max_len,
            num_proc=cfg.num_proc,
            cols_to_keep=["id"],
        )

        # NOTE(review): assumes encode_sft_dataset yields input_ids with a ``.shape``
        # (array/tensor format) — confirm against that helper.
        usable_ds = tokenized_dataset.filter(
            lambda x: x["input_ids"].shape[0] <= max_len,
            num_proc=cfg.num_proc,
            desc=f"Selecting examples with length <= {cfg.max_seq_length}",
        )

        # A set gives O(1) membership per row; testing against the list returned by
        # usable_ds["id"] made this filter quadratic in the dataset size.
        ids_to_use = set(usable_ds["id"])
        filtered_ds = dataset.filter(
            lambda x: x["id"] in ids_to_use,
            num_proc=cfg.num_proc,
            desc="Filtering dataset",
        )
        logger.info(f"Filtered dataset (max len): {filtered_ds}")
        return filtered_ds
