import numpy as np
import torch
from transformers import DataCollatorForLanguageModeling

# Marker text whose token ids MaskedDataCollatorForLM searches for in the
# labels to locate where the instruction part of a sample begins.
INSTRUCTION_TEMPLATE = "### Human:\n"
# Marker text whose token ids MaskedDataCollatorForLM searches for in the
# labels to locate where the response part of a sample begins.
RESPONSE_TEMPLATE = "### Response:\n"
# Value written into label positions that are masked out by the collator.
# NOTE(review): presumably chosen to match the default ``ignore_index`` of
# torch's cross-entropy loss -- confirm against the training loop.
IGNORE_INDEX = -100


class MaskedDataCollatorForLM(DataCollatorForLanguageModeling):
    """Language-modeling collator (``mlm=False``) that masks parts of the labels.

    The parent collator builds the batch; afterwards every label position
    before a per-sample cut-off index is overwritten with ``ignore_index``:

    * ``is_completion_only`` truthy: the cut-off is the end of the last
      response template found in the labels, so the template and everything
      before it are masked. If no response template is found the cut-off is
      the row length, i.e. the whole row is masked.
    * otherwise: the cut-off is the start of the last instruction template
      found in the labels, but never less than 1. If no instruction template
      is found the cut-off is 1.

    When ``is_weight_by_token`` is truthy, per-token masks registered through
    :meth:`set_token_mask` are additionally applied to the part of each row
    that lies at or after the cut-off.
    """

    def __init__(self, tokenizer, is_weight_by_token, is_completion_only):
        """Configure the collator and pre-tokenize both templates.

        Args:
            tokenizer: Tokenizer handed to the parent collator; also used to
                encode the instruction and response templates (without
                special tokens).
            is_weight_by_token: Whether ``torch_call`` applies the masks
                registered with :meth:`set_token_mask`.
            is_completion_only: Whether labels are masked up to the end of
                the response template instead of up to the start of the
                instruction template.
        """
        super().__init__(tokenizer=tokenizer, mlm=False, pad_to_multiple_of=4)
        self.instruction_template = INSTRUCTION_TEMPLATE
        self.instruction_token_ids = self.tokenizer.encode(
            self.instruction_template, add_special_tokens=False
        )
        self.response_template = RESPONSE_TEMPLATE
        self.response_token_ids = self.tokenizer.encode(
            self.response_template, add_special_tokens=False
        )
        self.ignore_index = IGNORE_INDEX
        self.is_weight_by_token = is_weight_by_token
        # Filled by set_token_mask(); maps a sequence key to
        # [next mask position, list of boolean masks].
        self.mask_dict = None
        self.is_completion_only = is_completion_only

    def super_touch_call(self, examples):
        """Return the batch produced by the parent class's ``torch_call``.

        The method name (including its spelling) is kept unchanged because it
        is part of the class's existing interface.
        """
        return super().torch_call(examples)

    @staticmethod
    def _sequence_key(ids):
        """Return the ``mask_dict`` key for a sequence of token ids.

        NOTE(review): the key is the hash of the int32-cast id tuple, not the
        tuple itself, so two different sequences with equal hashes would
        share one entry.
        """
        return hash(tuple(np.int32(ids)))

    @staticmethod
    def _find_last_template_start(label_row, template_ids):
        """Return the start index of the last ``template_ids`` match, or None.

        Args:
            label_row: One row of ``batch["labels"]``.
            template_ids: Token ids of the template to look for.

        Returns:
            The index (as ``int``) at which the last complete occurrence of
            ``template_ids`` starts in ``label_row``, or ``None`` when there
            is no occurrence or ``template_ids`` is empty.
        """
        if not template_ids:
            # Nothing to search for; treat it like "template not found"
            # instead of failing on template_ids[0].
            return None
        width = len(template_ids)
        candidates = np.where(label_row == template_ids[0])[0]
        # Walk the candidates back to front: the first full match seen this
        # way is the last occurrence in the row.
        for idx in candidates[::-1]:
            if template_ids == label_row[idx : idx + width].tolist():
                return int(idx)
        return None

    def get_indices(self, batch):
        """Compute the instruction-start and response-end index of every row.

        Args:
            batch: Mapping with a ``"labels"`` entry holding one row of token
                ids per sample.

        Returns:
            A ``(start_indices, end_indices)`` tuple of lists with one entry
            per row. ``start_indices[i]`` is the start of the last
            instruction template in row ``i`` (at least 1; 1 when absent).
            ``end_indices[i]`` is the index just past the last response
            template in row ``i`` (the row length when absent).
        """
        start_indices = []
        end_indices = []
        response_width = len(self.response_token_ids)
        for label_row in batch["labels"]:
            instruction_start = self._find_last_template_start(
                label_row, self.instruction_token_ids
            )
            if instruction_start is None:
                # Skip the first token because it is not learnable anyway.
                start_indices.append(1)
            else:
                start_indices.append(max(instruction_start, 1))

            response_start = self._find_last_template_start(
                label_row, self.response_token_ids
            )
            if response_start is None:
                end_indices.append(len(label_row))
            else:
                end_indices.append(response_start + response_width)
        return start_indices, end_indices

    def set_token_mask(self, dataset, t):
        """Register the per-token masks used when weighting by token.

        For every sample, ``dataset[f"adaboost_count_{t}"]`` is cast to bool
        and inverted, so positions whose count is zero become ``True`` and
        are later overwritten with ``ignore_index`` by ``torch_call``.
        Samples with identical ``sft_ids`` share one entry whose masks are
        handed out in rotation.

        Args:
            dataset: Object indexable by column name that provides the
                ``"sft_ids"`` and ``f"adaboost_count_{t}"`` columns.
            t: Suffix selecting which ``adaboost_count_*`` column to read.

        Raises:
            RuntimeError: If the collator was created with
                ``is_weight_by_token`` falsy.
        """
        if not self.is_weight_by_token:
            raise RuntimeError("Cannot set token mask when not using token weights.")
        # Build the table locally and publish it only once it is complete, so
        # a failure while reading the dataset cannot leave a half-initialised
        # mask_dict behind.
        mask_dict = {}
        for ids, adaboost_count in zip(
            dataset["sft_ids"],
            dataset[f"adaboost_count_{t}"],
        ):
            entry = mask_dict.setdefault(self._sequence_key(ids), [0, []])
            entry[1].append(torch.BoolTensor(~np.array(adaboost_count, dtype=bool)))
        self.mask_dict = mask_dict

    def torch_call(self, examples):
        """Collate ``examples`` and mask the labels that must not be trained on.

        Args:
            examples: Samples to collate, passed on to the parent collator.

        Returns:
            The batch from the parent collator with ``batch["labels"]``
            modified in place.

        Raises:
            RuntimeError: If ``is_weight_by_token`` is truthy but
                :meth:`set_token_mask` has not been called yet.
        """
        batch = self.super_touch_call(examples)
        start_indices, end_indices = self.get_indices(batch)
        indices = end_indices if self.is_completion_only else start_indices
        labels = batch["labels"]

        if self.is_weight_by_token:
            if self.mask_dict is None:
                raise RuntimeError("Mask dict must be set when using token weights.")
            for i, (start, cutoff) in enumerate(zip(start_indices, indices)):
                # Samples are looked up by the labels from the instruction
                # start onwards; rows without a registered mask are left as is.
                entry = self.mask_dict.get(self._sequence_key(labels[i, start:]))
                if entry is None:
                    continue
                position, masks = entry
                mask = masks[position]
                # Advance the rotation so a duplicate sample gets the next mask.
                entry[0] = (position + 1) % len(masks)
                # The slice is a view, so this writes into batch["labels"].
                labels[i, cutoff:][mask] = self.ignore_index

        for i, cutoff in enumerate(indices):
            labels[i, :cutoff] = self.ignore_index

        return batch
