import random

import torch

# TODO: technically this conversion should happen in the loader,
# and here we should just do padding.


def convert_raw_data_to_model_format(
    tokenizer, max_length, question, answer, model_configs
):
    """Tokenize one question/answer pair into fixed-length training tensors.

    The question is wrapped in the configured start/end tags, the answer is
    prefixed with the configured answer tag, and the concatenation is
    tokenized, truncated to ``max_length`` and right-padded with the
    tokenizer's EOS id.

    Args:
        tokenizer: Object used like a Hugging Face tokenizer: it must be
            callable (returning a mapping with ``"input_ids"`` and
            ``"attention_mask"`` lists), and provide ``tokenize()`` and
            ``eos_token_id``.
        max_length: Length of every returned sequence. The tokenizer is
            asked to truncate to this length and is assumed to honor it.
        question: Question text, without tags.
        answer: Answer text, without tags.
        model_configs: Mapping providing ``"question_start_tag"``,
            ``"question_end_tag"`` and ``"answer_tag"`` strings.

    Returns:
        A tuple ``(input_ids, labels, attention_mask)`` of 1-D tensors.
        ``labels`` is ``-100`` on every question token and on padding,
        except for the first padding position, which holds the EOS id.
        ``attention_mask`` is ``0`` on all padding positions.
    """
    question_text = (
        model_configs["question_start_tag"]
        + question
        + model_configs["question_end_tag"]
    )
    full_text = question_text + model_configs["answer_tag"] + answer

    # Count the question tokens on their own so they can be masked out of
    # the labels below.
    num_question_tokens = len(
        tokenizer.tokenize(question_text, add_special_tokens=True)
    )

    encoded = tokenizer(
        full_text,
        add_special_tokens=True,
        max_length=max_length,
        truncation=True,
    )
    # Copy so the tokenizer's own output is never mutated by the masking.
    input_ids = list(encoded["input_ids"])
    attention_mask = list(encoded["attention_mask"])

    eos_id = tokenizer.eos_token_id
    pad_length = max_length - len(input_ids)
    padded_input_ids = input_ids + [eos_id] * pad_length
    padded_attention_mask = attention_mask + [0] * pad_length

    if pad_length > 0:
        # Keep exactly one EOS as a target right after the text; the rest of
        # the padding is excluded from the labels with -100.
        labels = input_ids + [eos_id] + [-100] * (pad_length - 1)
    else:
        labels = list(input_ids)

    # Exclude the question tokens from the labels. The span is clamped because
    # a question longer than max_length would otherwise index past the end of
    # the (truncated) label list.
    num_masked = min(num_question_tokens, len(labels))
    labels[:num_masked] = [-100] * num_masked

    return (
        torch.tensor(padded_input_ids),
        torch.tensor(labels),
        torch.tensor(padded_attention_mask),
    )


class DataCollatorQADPO:
    """Collate forget/retain pairs together with an "I don't know" variant.

    For every pair in a batch, three examples are encoded with
    ``convert_raw_data_to_model_format``: the forget prompt with its original
    answer, the retain prompt with its answer, and the forget prompt again
    with a randomly chosen answer from the "I don't know" file (the preferred
    response).

    Args:
        tokenizer: Tokenizer forwarded to ``convert_raw_data_to_model_format``.
        max_length: Sequence length forwarded to the same function.
        model_configs: Tag configuration forwarded to the same function.
        idk_path: Path of the text file holding the alternative answers, one
            per line. Lines are used as plain text (they are not JSON-decoded,
            despite the default file's extension); surrounding whitespace is
            stripped and blank lines are ignored.

    Raises:
        ValueError: If the file at ``idk_path`` contains no usable line.
    """

    def __init__(
        self,
        tokenizer,
        max_length,
        model_configs,
        idk_path="egu/dataset/idontknow.jsonl",
    ):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.model_configs = model_configs
        # Close the file deterministically and drop the trailing newline of
        # each line so it does not end up tokenized as part of the answer.
        with open(idk_path, "r", encoding="utf-8") as idk_file:
            self.idk = [line.strip() for line in idk_file if line.strip()]
        if not self.idk:
            raise ValueError("no answers found in " + repr(idk_path))

    def _encode(self, prompt, answer):
        """Return ``(input_ids, labels, attention_mask)`` for one example."""
        return convert_raw_data_to_model_format(
            self.tokenizer, self.max_length, prompt, answer, self.model_configs
        )

    def __call__(self, batch):
        """Encode and stack a list of forget/retain pairs.

        Args:
            batch: Iterable of mappings, each with ``"forget"`` and
                ``"retain"`` entries that provide ``"prompt_formatted"`` and
                ``"answer"``.

        Returns:
            A dict with ``<prefix>_input_ids``, ``<prefix>_labels`` and
            ``<prefix>_attention_mask`` stacked tensors for each prefix in
            ``forget``, ``retain`` and ``idk``.
        """
        # One (input_ids, labels, attention_masks) triple of lists per group.
        columns = {
            prefix: ([], [], []) for prefix in ("forget", "retain", "idk")
        }

        for pair in batch:
            forget = pair["forget"]
            retain = pair["retain"]
            examples = {
                "forget": (forget["prompt_formatted"], forget["answer"]),
                "retain": (retain["prompt_formatted"], retain["answer"]),
                # Same prompt as the forget example, but paired with a
                # randomly sampled alternative answer.
                "idk": (forget["prompt_formatted"], random.choice(self.idk)),
            }
            for prefix, (prompt, answer) in examples.items():
                encoded = self._encode(prompt, answer)
                for column, tensor in zip(columns[prefix], encoded):
                    column.append(tensor)

        # Stack into batch tensors.
        collated = {}
        for prefix, (input_ids, labels, attention_masks) in columns.items():
            collated[prefix + "_input_ids"] = torch.stack(input_ids)
            collated[prefix + "_labels"] = torch.stack(labels)
            collated[prefix + "_attention_mask"] = torch.stack(attention_masks)
        return collated
