from typing import Dict, Callable, Optional, Any

import numpy as np

from mindspeed_rl.datasets.utils import _infer_seqlen, get_prompt_index

from mindspeed_rl.datasets.indexed_dataset import get_packed_indexed_dataset
from mindspeed_rl.datasets.base_dataset import BaseDataset
from mindspeed_rl.datasets.templates import get_model_template
from mindspeed_rl.datasets.utils import _build_index_mappings


class PromptDataset(BaseDataset):
    """Prompt dataset backed by a packed indexed dataset.

    Samples are fetched from ``res_dataset`` through ``shuffle_index`` and
    truncated to ``seq_length`` tokens, either as an instruction sample
    (``_cut_instruction_token``) or as a chosen/rejected pair
    (``_cut_pairwise_token``) depending on ``extra_param.is_pairwise_dataset``.
    Only packed data is supported; non-packed input raises
    ``NotImplementedError``.
    """

    # Label value used for prompt (non-response) positions.
    IGNORE_INDEX = -100

    def __init__(
            self,
            data_prefix: str = "",
            is_packed_data: bool = False,
            tokenizer: Callable = None,
            seq_length: int = 128,
            num_samples: int = None,
            name: str = "",
            documents: Any = None,
            seed: int = 42,
            full_shuffle_instruction_dataset: bool = False,
            token_param: Optional[Dict] = None,
            preprocess_template: Optional[str] = None,
            pad_token: int = 0,
            eos_token: int = 1,
            extra_param: Any = None,
            **kwargs,
    ):
        """Load the packed dataset and build its sample index mapping.

        Args:
            data_prefix: Path prefix of the packed dataset files.
            is_packed_data: Must be True; non-packed data is not supported.
            tokenizer: Stored on the instance; not used by this class directly.
            seq_length: Maximum number of tokens kept per sample.
            num_samples: Number of samples requested from ``_build_index_mappings``.
            name: Dataset name forwarded to ``_build_index_mappings``.
            documents: Sequence of document indices; its first element and its
                length select the document range.
            seed: Seed forwarded to ``_build_index_mappings``.
            full_shuffle_instruction_dataset: Forwarded to ``_build_index_mappings``.
            token_param: Stored on the instance.
            preprocess_template: Stored on the instance.
            pad_token: Stored on the instance.
            eos_token: Stored on the instance.
            extra_param: Argument namespace. This class reads
                ``max_prompt_length`` (optional), ``is_pairwise_dataset``,
                ``dataset_additional_keys``, ``prompt_type``,
                ``prompt_type_path`` and ``enable_thinking`` from it.
            **kwargs: ``parallel_state`` is forwarded to ``_build_index_mappings``.

        Raises:
            NotImplementedError: If ``is_packed_data`` is False.
        """
        self.data_prefix = data_prefix
        self.is_packed_data = is_packed_data
        self.tokenizer = tokenizer
        self.token_param = token_param
        self.seq_length = seq_length
        self.preprocess_template = preprocess_template
        self.pad_token = pad_token
        self.eos_token = eos_token
        self.num_samples = num_samples
        self.args = extra_param

        if self.is_packed_data:
            self.res_dataset = get_packed_indexed_dataset(data_prefix=self.data_prefix,
                                                          filter_length=getattr(extra_param, 'max_prompt_length', None),
                                                          is_pairwise_dataset=self.args.is_pairwise_dataset)
            self.shuffle_index = _build_index_mappings(name=name,
                                                       data_prefix=self.data_prefix,
                                                       start_index=documents[0],
                                                       nb_documents=len(documents),
                                                       num_samples=self.num_samples,
                                                       seed=seed,
                                                       full_shuffle_instruction_dataset=full_shuffle_instruction_dataset,
                                                       parallel_state=kwargs.get('parallel_state'),
                                                       no_shuffle=True)
            dataset_type = "Prompt_DS_Packed"
        else:
            raise NotImplementedError('non packed data are not supported yet.')

        super().__init__(self.res_dataset, dataset_type)

    def __len__(self):
        """Return the number of entries in the sample index mapping."""
        return len(self.shuffle_index)

    def __getitem__(self, index):
        """Return the truncated sample at position ``index`` of the index mapping."""
        doc_idx = self.shuffle_index[index]

        item = self.res_dataset[doc_idx]
        if self.args.is_pairwise_dataset:
            return self._cut_pairwise_token(item, np.int64)
        return self._cut_instruction_token(item, np.int64)

    def _cut_instruction_token(self, item, dtype):
        """Truncate an instruction sample to at most ``seq_length`` tokens.

        If ``item`` has ``labels`` and no ``dataset_additional_keys`` are
        configured, the labelled multi-turn path is used. Otherwise only
        ``input_ids`` is truncated, and every configured additional key that
        is present in ``item`` is copied into the result unchanged.

        Args:
            item: Mapping with at least ``input_ids`` and optionally ``labels``.
            dtype: NumPy dtype of the returned arrays.

        Returns:
            Dict with ``input_ids`` and ``attention_mask`` (all ones), plus
            ``labels`` for labelled samples or the additional keys otherwise.
        """
        # ``dataset_additional_keys`` may be None; treat that like "no keys"
        # instead of failing when iterating it below.
        additional_keys = self.args.dataset_additional_keys or []
        if "labels" in item.keys() and not additional_keys:
            return self._cut_labelled_instruction_token(item, dtype)

        input_ids = item["input_ids"][:self.seq_length]
        res = {
            "input_ids": input_ids.astype(dtype),
            "attention_mask": np.ones_like(input_ids).astype(dtype)
        }
        # Copy through configured extra fields that this sample actually has.
        res.update({key: item[key] for key in additional_keys if key in item.keys()})
        return res

    def _cut_labelled_instruction_token(self, item, dtype):
        """Truncate a labelled, possibly multi-turn, sample turn by turn.

        Samples of at most ``seq_length`` tokens are returned unchanged. Longer
        samples are split into (prompt, response) turns using the prompt
        ranges returned by ``get_prompt_index``. Each turn's lengths are chosen
        by ``_infer_seqlen`` from the remaining budget, and turns are
        concatenated until the budget is exhausted.

        When the chat template uses ``efficient_eos``, the final token is taken
        to be the EOS token. It is removed before truncation and appended
        exactly once at the end, with one slot of the budget reserved for it.
        """
        input_ids_all = item["input_ids"]
        labels_all = item["labels"]
        token_length = len(input_ids_all)
        if token_length <= self.seq_length:
            return {
                "input_ids": input_ids_all.astype(dtype),
                "attention_mask": np.ones_like(input_ids_all).astype(dtype),
                "labels": labels_all.astype(dtype)
            }

        template = None
        # get model chat template
        if getattr(self.args, "prompt_type", None) is not None:
            template = get_model_template(self.args.prompt_type, self.args.prompt_type_path, self.args.enable_thinking)
        efficient_eos = template is not None and template.efficient_eos

        prompt_begin_list, prompt_end_list = get_prompt_index(labels_all, self.IGNORE_INDEX)
        multi_turns = len(prompt_begin_list)
        total_length = 0

        if efficient_eos:
            # Reserve one slot for the trailing EOS token.
            total_length = 1
            # NOTE(review): prompt end offsets are shifted by one for
            # efficient_eos templates; depends on get_prompt_index semantics.
            prompt_end_list = [x - 1 for x in prompt_end_list]
            eos_token_id = input_ids_all[token_length - 1]
            # Strip the trailing EOS so it is not also kept inside the last
            # turn's response; it is re-appended once after truncation.
            # Work on local views so the caller's ``item`` is not mutated.
            input_ids_all = input_ids_all[:token_length - 1]
            labels_all = labels_all[:token_length - 1]

        cutoff_len = self.seq_length
        # Seed with empty arrays so concatenation also works when no turn fits.
        input_pieces = [np.array([], dtype=dtype)]
        label_pieces = [np.array([], dtype=dtype)]

        for turn_idx in range(multi_turns):
            if total_length >= cutoff_len:
                break
            prompt_begin = prompt_begin_list[turn_idx]
            prompt_end = prompt_end_list[turn_idx]
            source_ids = input_ids_all[prompt_begin:prompt_end]
            mask_ids = labels_all[prompt_begin:prompt_end]

            # The response spans from the end of this prompt to the start of
            # the next prompt, or to the end of the sample for the last turn.
            target_end = prompt_begin_list[turn_idx + 1] if turn_idx + 1 < multi_turns else None
            # NOTE(review): response tokens are read from labels and used for
            # input_ids as well; assumes labels equal input_ids on response
            # spans - confirm against the preprocessing code.
            target_ids = labels_all[prompt_end:target_end]

            source_len, target_len = _infer_seqlen(len(source_ids), len(target_ids), cutoff_len - total_length)

            total_length += source_len + target_len
            input_pieces += [source_ids[:source_len], target_ids[:target_len]]
            label_pieces += [mask_ids[:source_len], target_ids[:target_len]]

        if efficient_eos:
            eos_piece = np.array([eos_token_id], dtype=dtype)
            input_pieces.append(eos_piece)
            label_pieces.append(eos_piece)

        # Concatenate once instead of re-copying the growing arrays every turn.
        input_ids = np.concatenate(input_pieces, axis=0)
        labels = np.concatenate(label_pieces, axis=0)

        return {
            "input_ids": input_ids.astype(dtype),
            "attention_mask": np.ones_like(input_ids).astype(dtype),
            "labels": labels.astype(dtype)
        }

    def _cut_pairwise_token(self, item, dtype):
        """Cut prompt and response proportionally for pairwise datasets.

        The prompt length is the index of the first position in
        ``chosen_labels`` that is not ``IGNORE_INDEX``. The prompt is taken
        from ``chosen_input_ids``, and both responses start at that same
        offset. ``_infer_seqlen`` picks the prompt and response lengths using
        the longer of the two responses, and both responses are cut to that
        response length.

        Args:
            item: Mapping with ``chosen_input_ids``, ``chosen_labels`` and
                ``rejected_input_ids``.
            dtype: NumPy dtype of the returned arrays.

        Returns:
            Dict with ``input_ids``, ``attention_mask`` and ``labels`` for both
            the ``chosen_`` and ``rejected_`` sequences. Prompt positions of the
            labels are set to ``IGNORE_INDEX``.

        Raises:
            ValueError: If ``chosen_labels`` contains no response position.
        """
        response_positions = (item["chosen_labels"] != self.IGNORE_INDEX).nonzero()[0]
        if len(response_positions) == 0:
            raise ValueError("pairwise sample has no response tokens in 'chosen_labels'")
        prompt_length = response_positions[0]

        prompt_ids = item["chosen_input_ids"][:prompt_length]
        chosen_ids = item["chosen_input_ids"][prompt_length:]
        rejected_ids = item["rejected_input_ids"][prompt_length:]
        source_len, target_len = _infer_seqlen(
            len(prompt_ids), max(len(chosen_ids), len(rejected_ids)), self.seq_length
        )
        prompt_ids = prompt_ids[:source_len]
        chosen_ids = chosen_ids[:target_len]
        rejected_ids = rejected_ids[:target_len]

        prompt_labels = np.full(source_len, self.IGNORE_INDEX)
        chosen_input_ids = np.append(prompt_ids, chosen_ids)
        chosen_labels = np.append(prompt_labels, chosen_ids)
        rejected_input_ids = np.append(prompt_ids, rejected_ids)
        rejected_labels = np.append(prompt_labels, rejected_ids)

        return {
            "chosen_input_ids": chosen_input_ids.astype(dtype),
            "chosen_attention_mask": np.ones_like(chosen_input_ids).astype(dtype),
            "chosen_labels": chosen_labels.astype(dtype),
            "rejected_input_ids": rejected_input_ids.astype(dtype),
            "rejected_attention_mask": np.ones_like(rejected_input_ids).astype(dtype),
            "rejected_labels": rejected_labels.astype(dtype)
        }