import json
import math
from pathlib import Path
import torch
from datasets import Dataset, concatenate_datasets
from accelerate import PartialState
from src.data.metadata_utils import add_example_metadata
from src.utils.logging_utils import get_logger
from transformers import Trainer, TrainingArguments, EvalPrediction, DataCollatorForSeq2Seq
from copy import deepcopy
from src.data.tulu_3_sft_utils import encode_sft_dataset, create_chat_tokenizer
from src.data.metadata_utils import flatten_metadata
from src.data.utils import SelectionMetric, DsColumn
from src.data.hf_hub_utils import push_ds_to_hub
from src.data.utils import create_ds_cfg

# Module-level logger, named after this module, shared by every function below.
logger = get_logger(name=__name__)


class ExLossMetric:
    """Accumulate per-example loss and perplexity over evaluation batches.

    Instances are meant to be used through :meth:`compute` as a
    ``compute_metrics`` callback that is invoked once per batch. Records are
    never cleared by this class, so they keep accumulating across successive
    evaluation runs that share the same instance.
    """

    def __init__(self) -> None:
        super().__init__()
        # Maps the stringified running record count ("0", "1", ...) to a dict
        # holding the example "_id", its loss and its perplexity.
        self.losses = {}

    def compute(self, eval_preds: EvalPrediction, compute_result: bool = False, **kwargs) -> dict:
        """Record the losses of one batch and optionally return everything so far.

        Args:
            eval_preds: Batch predictions; ``inputs["_id"]`` and ``losses`` must
                both support ``.tolist()`` and are paired element-wise.
            compute_result: When true, return a deep copy of all accumulated
                records; otherwise return an empty dict.
            **kwargs: Accepted for interface compatibility and ignored.

        Returns:
            A deep copy of ``self.losses`` if ``compute_result`` is true,
            else ``{}``.
        """
        example_ids = eval_preds.inputs["_id"].tolist()
        batch_losses = eval_preds.losses.tolist()

        for example_id, loss_value in zip(example_ids, batch_losses):
            record = {
                "_id": example_id,
                SelectionMetric.LOSS: loss_value,
                # Perplexity is the exponential of the loss.
                SelectionMetric.PPL: math.exp(loss_value),
            }
            # The key is the number of records stored before this insertion.
            self.losses[str(len(self.losses))] = record

        return deepcopy(self.losses) if compute_result else {}



def add_selection_metrics_meta(
    *,
    dataset: Dataset,
    model_name_or_path: str,
    model_id: str,
    dataset_id: str,
    is_name_template: bool = False,
    use_liger: bool = True,
    num_proc: int = 1,
    chat_template: str | None = None,
    eval_batch_size: int = 4,
    tokenizer_name_or_path: str | None = None,
    cache_dir: Path | None = None,
    **kwargs,
) -> Dataset:
    """Score every example with a causal LM and store loss/perplexity in its metadata.

    The dataset is tokenized with ``encode_sft_dataset``, evaluated with a
    ``Trainer`` in 20 contiguous shards, and the per-example results collected
    by :class:`ExLossMetric` are written back onto the *original* dataset under
    ``{model_id: {PPL, LOSS}}`` via ``add_example_metadata``.

    Args:
        dataset: Dataset to score; it must have ``_meta`` and ``_ds_id`` columns
            after tokenization (they are dropped before evaluation).
        model_name_or_path: Model to load with ``from_pretrained``.
        model_id: Key under which the metrics are stored; also part of the
            cache file names.
        dataset_id: Identifier used only to build the cache file names.
        is_name_template: When true, ``model_name_or_path``, ``model_id`` and
            ``tokenizer_name_or_path`` are ``str.format``-ed with ``**kwargs``.
        use_liger: Load the model through ``AutoLigerKernelForCausalLM``
            instead of ``AutoModelForCausalLM``.
        num_proc: Number of processes used for tokenization.
        chat_template: Chat template forwarded to the tokenizer helpers.
        eval_batch_size: Per-device evaluation batch size.
        tokenizer_name_or_path: Tokenizer to load; defaults to the model path.
        cache_dir: Optional directory for per-shard JSON caches of the losses.
        **kwargs: Format fields for the name templates; otherwise unused.

    Returns:
        A copy of ``dataset`` whose examples carry the new metadata entry.
    """
    if is_name_template:
        model_name_or_path = model_name_or_path.format(**kwargs)
        model_id = model_id.format(**kwargs)
        if tokenizer_name_or_path:
            tokenizer_name_or_path = tokenizer_name_or_path.format(**kwargs)
        logger.info(
            f"Formatted name template. model_name={model_name_or_path}, model_id={model_id}, tokenizer_name={tokenizer_name_or_path}"
        )

    tokenizer_name_or_path = tokenizer_name_or_path or model_name_or_path
    tokenizer = create_chat_tokenizer(tokenizer_name=tokenizer_name_or_path, chat_template=chat_template)
    if tokenizer.pad_token_id is None:
        # Padding is required by the collator; reuse EOS when no pad token exists.
        tokenizer.pad_token_id = tokenizer.eos_token_id

    # Let the main process tokenize first (the other ranks wait on it).
    with PartialState().main_process_first():
        tokenized_dataset = encode_sft_dataset(
            dataset=dataset,
            tokenizer_name=model_name_or_path,
            max_len=None,
            tokenizer=tokenizer,
            num_proc=num_proc,
            chat_template=chat_template,
        )

        tokenized_dataset = tokenized_dataset.remove_columns(["_meta", "_ds_id"])

    model_load_kwargs = dict(
        torch_dtype=torch.bfloat16,
        attn_implementation="flash_attention_2",
    )
    if use_liger:
        from liger_kernel.transformers import AutoLigerKernelForCausalLM

        model = AutoLigerKernelForCausalLM.from_pretrained(model_name_or_path, **model_load_kwargs)
    else:
        from transformers import AutoModelForCausalLM

        model = AutoModelForCausalLM.from_pretrained(model_name_or_path, **model_load_kwargs)

    logger.info(f"Loaded model: {model_name_or_path}")

    # NOTE(review): in several transformers releases `report_to=None` falls back
    # to all installed integrations and the string "none" is what disables
    # reporting -- confirm which behaviour is intended here.
    training_args = TrainingArguments(
        output_dir="./tmp",
        per_device_eval_batch_size=eval_batch_size,
        num_train_epochs=0,
        batch_eval_metrics=True,
        report_to=None,
        include_for_metrics=["inputs", "loss"],
        remove_unused_columns=False,
        do_train=False,
    )
    metric = ExLossMetric()

    trainer = Trainer(
        model=model,
        args=training_args,
        compute_metrics=metric.compute,
        data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model, padding="longest"),
    )

    cache_root = Path(cache_dir) if cache_dir else None
    num_shard = 20
    for idx in range(num_shard):
        cache_file = None
        if cache_root is not None:
            cache_key = f"{model_id}-{dataset_id}-{idx}"
            cache_file = cache_root / f"{cache_key}.json"
            if cache_file.exists():
                # `metric.losses` is cumulative, so the cache written after
                # shard `idx` already contains the records of all earlier shards.
                with cache_file.open("r") as f:
                    metric.losses = json.load(f)
                logger.info(f"Loading cached losses from {cache_file}")
                continue

        logger.info(f"Computing metrics for shard {idx}")
        shard = tokenized_dataset.shard(num_shards=num_shard, index=idx, contiguous=True)
        # The return value is not needed: the records are collected in `metric`.
        trainer.evaluate(eval_dataset=shard)

        if cache_file is not None:
            logger.info(f"Caching losses to {cache_file}")
            # Create the target directory on demand so a fresh cache_dir works.
            cache_file.parent.mkdir(parents=True, exist_ok=True)
            with cache_file.open("w") as f:
                json.dump(metric.losses, f)

    del trainer
    del model

    logger.info(f"Computed metrics using model: {model_name_or_path}")

    # Read the results from the metric object instead of the value returned by
    # the last `trainer.evaluate` call: when every shard is served from the
    # cache no evaluation runs at all, and that return value would not exist.
    example_losses = metric.losses
    if len(example_losses) != len(dataset):
        logger.warning(
            f"Collected {len(example_losses)} loss records for a dataset of {len(dataset)} examples"
        )

    def _add_selection_metrics_meta(example, idx):
        # Records are keyed by their position in evaluation order, which
        # follows dataset order because the shards are contiguous.
        ex_result = example_losses[str(idx)]
        new_meta = {
            model_id: {
                SelectionMetric.PPL: ex_result[SelectionMetric.PPL],
                SelectionMetric.LOSS: ex_result[SelectionMetric.LOSS],
            }
        }
        example = add_example_metadata(example, new_metadata=new_meta)
        return example

    _ds = dataset.map(
        _add_selection_metrics_meta,
        batched=False,
        with_indices=True,
        num_proc=1,
        desc="Adding Selection Metrics Metadata to data",
    )

    return _ds


def create_segment_ds(
    *,
    segment: str,
    **kwargs,
):
    """Build a segment dataset by dispatching on the segment name.

    Looks up the module-level callable named ``create_<segment>_segment_ds``
    and calls it with ``**kwargs``.

    Raises:
        KeyError: If no module-level name matches ``segment``.
    """
    factory_name = f"create_{segment}_segment_ds"
    segment_factory = globals()[factory_name]
    return segment_factory(**kwargs)


def create_bottom_segment_ds(
    *,
    dataset: Dataset,
    subset_ratio: float,
    tagging_model_id: str,
    metric_name: str,
    num_proc: int = 1,
    per_ds: bool = False,
    **kwargs,
) -> Dataset:
    """Select the examples with the lowest metric values.

    The segment spans percentile 0 up to ``subset_ratio`` expressed as a whole
    percentile.

    Args:
        dataset: Dataset carrying the tagged metric in its metadata.
        subset_ratio: Fraction of the data to keep, in ``[0, 1)``.
        tagging_model_id: Id of the model whose metric is used.
        metric_name: Name of the selection metric.
        num_proc: Number of worker processes for dataset operations.
        per_ds: Compute the percentiles separately per source dataset instead
            of over the whole dataset.
        **kwargs: Forwarded to the underlying segment builder.

    Returns:
        The selected subset of ``dataset``.
    """
    assert 0 <= subset_ratio < 1
    logger.info(f"Creating bottom segment dataset with subset ratio {subset_ratio}")

    # round() rather than int(): binary floating point makes e.g.
    # 0.29 * 100 == 28.999999999999996, which int() would truncate to 28.
    right_percentile = round(subset_ratio * 100)

    build_segment = _create_per_ds_segment_ds if per_ds else _create_segment_ds
    return build_segment(
        dataset=dataset,
        left_percentile=0,
        right_percentile=right_percentile,
        tagging_model_id=tagging_model_id,
        metric_name=metric_name,
        num_proc=num_proc,
        subset_ratio=subset_ratio,
        **kwargs,
    )


def create_top_segment_ds(
    *,
    dataset: Dataset,
    subset_ratio: float,
    tagging_model_id: str,
    metric_name: str,
    num_proc: int = 1,
    per_ds: bool = False,
    **kwargs,
) -> Dataset:
    """Select the examples with the highest metric values.

    The segment spans from ``100 - subset_ratio`` (as a whole percentile) up to
    percentile 100.

    Args:
        dataset: Dataset carrying the tagged metric in its metadata.
        subset_ratio: Fraction of the data to keep, in ``[0, 1)``.
        tagging_model_id: Id of the model whose metric is used.
        metric_name: Name of the selection metric.
        num_proc: Number of worker processes for dataset operations.
        per_ds: Compute the percentiles separately per source dataset instead
            of over the whole dataset.
        **kwargs: Forwarded to the underlying segment builder.

    Returns:
        The selected subset of ``dataset``.
    """
    assert 0 <= subset_ratio < 1
    logger.info(f"Creating top segment dataset with subset ratio {subset_ratio}")

    # round() rather than int(): binary floating point makes e.g.
    # 0.29 * 100 == 28.999999999999996, which int() would truncate to 28 and
    # silently shrink the segment by one percentile.
    left_percentile = 100 - round(subset_ratio * 100)

    build_segment = _create_per_ds_segment_ds if per_ds else _create_segment_ds
    return build_segment(
        dataset=dataset,
        left_percentile=left_percentile,
        right_percentile=100,
        tagging_model_id=tagging_model_id,
        metric_name=metric_name,
        num_proc=num_proc,
        subset_ratio=subset_ratio,
        **kwargs,
    )


def create_middle_segment_ds(
    *,
    dataset: Dataset,
    subset_ratio: float,
    tagging_model_id: str,
    metric_name: str,
    num_proc: int = 1,
    per_ds: bool = False,
    **kwargs,
):
    """Select the examples whose metric lies in a window centred on the median.

    The window covers ``subset_ratio`` of the percentile range, split evenly
    around percentile 50; both bounds are truncated to whole percentiles.
    With ``per_ds`` the window is applied per source dataset rather than over
    the whole dataset. Extra ``**kwargs`` are forwarded to the segment builder.
    """
    assert 0 <= subset_ratio < 1
    logger.info(f"Creating middle segment dataset with subset ratio {subset_ratio}")

    # Half of the window width, in percentile points.
    half_width = (subset_ratio * 100) / 2
    percentile_window = dict(
        left_percentile=int((50 - half_width)),
        right_percentile=int((50 + half_width)),
    )

    segment_builder = _create_per_ds_segment_ds if per_ds else _create_segment_ds
    return segment_builder(
        dataset=dataset,
        tagging_model_id=tagging_model_id,
        metric_name=metric_name,
        num_proc=num_proc,
        subset_ratio=subset_ratio,
        **percentile_window,
        **kwargs,
    )


def _create_per_ds_segment_ds(
    *,
    dataset: Dataset,
    num_proc: int,
    subset_ratio: float,
    **kwargs,
) -> Dataset:
    """Build a percentile segment separately for every source dataset.

    The dataset is split by the ``DsColumn.DS_ID`` column, a segment is created
    for each group with ``_create_segment_ds`` (its own size check disabled),
    and the per-group segments are concatenated.

    Args:
        dataset: Dataset to segment.
        num_proc: Number of worker processes for filtering and flattening.
        subset_ratio: Target fraction of the data, used for the size check.
        **kwargs: Forwarded to ``_create_segment_ds``; must supply the
            remaining required arguments such as the percentile bounds.

    Returns:
        The concatenation of all per-group segments.

    Raises:
        AssertionError: If the result is larger than expected by two or more
            examples per source dataset.
    """
    logger.info(f"Creating per dataset segment dataset with subset ratio {subset_ratio}")
    unique_ds_ids = dataset.unique(DsColumn.DS_ID)
    sub_segments = []
    for ds_id in unique_ds_ids:
        sub_ds = dataset.filter(lambda x: x[DsColumn.DS_ID] == ds_id, num_proc=num_proc)
        sub_segment = _create_segment_ds(
            dataset=sub_ds,
            num_proc=num_proc,
            subset_ratio=subset_ratio,
            check_size=False,
            **kwargs,
        )
        sub_segments.append(sub_segment)

    subset = concatenate_datasets(sub_segments)

    expected_size = len(dataset) * subset_ratio
    max_excess = 2 * len(unique_ds_ids)
    # NOTE(review): this check is one-sided (no abs()), so a subset that is far
    # too small passes -- confirm whether that is intentional.
    # The failure message replaces a leftover `breakpoint()` call, which would
    # have dropped into a debugger whenever the check failed.
    assert len(subset) - expected_size < max_excess, (
        f"Per-dataset segment has {len(subset)} examples, expected about {expected_size} "
        f"(allowed excess: {max_excess})"
    )

    return subset


def _create_segment_ds(
    *,
    dataset: Dataset,
    left_percentile: int,
    right_percentile: int,
    tagging_model_id: str,
    metric_name: str,
    num_proc: int = 1,
    subset_ratio: float,
    check_size: bool = True,
    **kwargs,
) -> Dataset:
    """Select the examples whose metric lies between two percentiles (inclusive).

    Args:
        dataset: Dataset carrying the tagged metric in its metadata.
        left_percentile: Lower percentile bound, ``0 <= left < right``.
        right_percentile: Upper percentile bound, ``right <= 100``.
        tagging_model_id: Id of the model whose metric is used.
        metric_name: Selection metric name; validated via ``SelectionMetric``.
        num_proc: Number of worker processes used when flattening metadata.
        subset_ratio: Expected fraction of the data, used only for the size
            check.
        check_size: When true, assert that the result is within 20 examples of
            ``len(dataset) * subset_ratio``.
        **kwargs: Accepted for call-site convenience and ignored.

    Returns:
        The rows of ``dataset`` that fall inside the percentile window, in
        their original order.

    Raises:
        AssertionError: On invalid percentile bounds, a failed order sanity
            check, or (with ``check_size``) an unexpected subset size.
    """
    assert 0 <= left_percentile < right_percentile <= 100
    logger.info(f"Creating segment dataset with {left_percentile} <= p <= {right_percentile}")
    metric_name = SelectionMetric(metric_name)  # This will validate the metric name value

    meta_join_char = "."
    metric_col_name = f"{tagging_model_id}{meta_join_char}{metric_name}"
    # The metric is read below from a column named `metric_col_name`, which
    # flatten_metadata is expected to produce from the nested metadata.
    _ds = flatten_metadata(
        dataset=dataset, num_proc=num_proc, meta_join_char=meta_join_char, meta_to_keep=[metric_col_name]
    )

    # Keep only the id columns and the metric to make the pandas copy small.
    _ds = _ds.remove_columns([c for c in _ds.column_names if c not in [DsColumn.ID, DsColumn.DS_ID, metric_col_name]])
    ds_df = _ds.to_pandas()

    left_p_val = ds_df[metric_col_name].quantile(left_percentile / 100)
    right_p_val = ds_df[metric_col_name].quantile(right_percentile / 100)

    logger.info(f"Left percentile value: {left_p_val} Right percentile value: {right_p_val}")
    subset_df = ds_df[(ds_df[metric_col_name] >= left_p_val) & (ds_df[metric_col_name] <= right_p_val)]

    indices = subset_df.index.tolist()
    # Sanity check, make sure that indices are in the same order as the dataset
    assert dataset[indices[0]][DsColumn.ID] == subset_df.iloc[0][DsColumn.ID], (
        "First selected row does not match the dataset row at the same index"
    )
    assert dataset[indices[-1]][DsColumn.ID] == subset_df.iloc[-1][DsColumn.ID], (
        "Last selected row does not match the dataset row at the same index"
    )

    subset = dataset.select(indices)
    if check_size:
        expected_size = len(dataset) * subset_ratio
        # The failure message replaces a leftover `breakpoint()` call, which
        # would have dropped into a debugger whenever the check failed.
        assert abs(len(subset) - expected_size) < 20, (
            f"Segment has {len(subset)} examples, expected about {expected_size}"
        )

    return subset


def create_ds_id(
    *,
    original_dataset_id: str,
    segment: str,
    subset_ratio: float,
    tagging_model_id: str,
    metric_name: str,
    per_ds: bool = False,
    **kwargs,
):
    """Compose the identifier of a segment dataset.

    The id has the form
    ``<original>-<tagging model>-<metric>[-per_ds]-<segment>-<ratio>``, where
    the ``-per_ds`` marker appears only when ``per_ds`` is true. Any extra
    keyword arguments are ignored.
    """
    per_ds_str = ""
    if per_ds:
        per_ds_str = "-per_ds"

    return f"{original_dataset_id}-{tagging_model_id}-{metric_name}{per_ds_str}-{segment}-{subset_ratio}"


def push_segment_ds_to_hub(
    *,
    dataset: Dataset,
    original_dataset_id: str,
    segment: str,
    subset_ratio: float,
    tagging_model_id: str,
    metric_name: str,
    per_ds: bool = False,
    **kwargs,
):
    """Push a segment dataset to the hub under its generated id.

    The dataset name is derived with ``create_ds_id`` from the segment
    parameters; all remaining ``**kwargs`` are handed to ``push_ds_to_hub``,
    whose return value is returned unchanged.
    """
    id_fields = {
        "original_dataset_id": original_dataset_id,
        "segment": segment,
        "subset_ratio": subset_ratio,
        "tagging_model_id": tagging_model_id,
        "metric_name": metric_name,
        "per_ds": per_ds,
    }
    hub_dataset_name = create_ds_id(**id_fields)

    return push_ds_to_hub(dataset=dataset, dataset_name=hub_dataset_name, **kwargs)


def create_segment_ds_cfg(
    *,
    dataset: Dataset,
    hf_space_name: str,
    original_dataset_id: str,
    segment: str,
    subset_ratio: float,
    tagging_model_id: str,
    metric_name: str,
    per_ds: bool = False,
    **kwargs,
) -> Dataset:
    """Create the dataset config for a segment dataset.

    The dataset id is derived with ``create_ds_id`` and the config is created
    through ``create_ds_cfg`` with the name ``<hf_space_name>/<dataset_id>``.

    Args:
        dataset: The segment dataset the config describes.
        hf_space_name: Namespace prefixed to the dataset id.
        original_dataset_id: Id of the dataset the segment was taken from.
        segment: Segment name (part of the generated id).
        subset_ratio: Fraction of the data in the segment.
        tagging_model_id: Id of the model whose metric defined the segment.
        metric_name: Name of the selection metric.
        per_ds: Whether the segment was built per source dataset.
        **kwargs: Accepted for call-site convenience and ignored.

    Returns:
        ``dataset``, unchanged, so the call can be chained.
    """
    # Only the naming fields are passed: the id does not depend on the dataset
    # itself, matching how push_segment_ds_to_hub derives the same id.
    dataset_id = create_ds_id(
        original_dataset_id=original_dataset_id,
        segment=segment,
        subset_ratio=subset_ratio,
        tagging_model_id=tagging_model_id,
        metric_name=metric_name,
        per_ds=per_ds,
    )

    create_ds_cfg(
        dataset=dataset,
        dataset_id=dataset_id,
        dataset_name=f"{hf_space_name}/{dataset_id}",
    )

    return dataset
