import random

import torch


class DataCollatorQAPair:
    """Collate paired forget/retain QA examples into stacked tensor batches.

    Each element of the incoming batch is read as a mapping with a
    ``"forget"`` and a ``"retain"`` entry. Each entry provides
    ``"prompt_formatted"`` and ``"answer"``. Both sides are encoded with
    ``convert_raw_data_to_model_format`` and stacked along a new leading
    batch dimension.

    NOTE(review): ``convert_raw_data_to_model_format`` is not defined or
    imported in this file as shown; confirm it is provided at module scope.
    """

    def __init__(self, tokenizer, max_length, model_configs):
        """Store the settings forwarded to the per-example converter.

        Args:
            tokenizer: Tokenizer passed through to
                ``convert_raw_data_to_model_format``.
            max_length: Maximum sequence length passed through to the converter.
            model_configs: Model-specific configuration passed through to the
                converter.
        """
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.model_configs = model_configs

    def _encode_split(self, examples):
        """Encode one side (forget or retain) of the pairs and stack it.

        Args:
            examples: Sequence of mappings with ``"prompt_formatted"`` and
                ``"answer"`` keys.

        Returns:
            Tuple ``(input_ids, labels, attention_mask)``, each the
            ``torch.stack`` of the converter's per-example outputs.
        """
        input_ids, labels, attention_masks = [], [], []
        for example in examples:
            ids, lbls, mask = convert_raw_data_to_model_format(
                self.tokenizer,
                self.max_length,
                example["prompt_formatted"],
                example["answer"],
                self.model_configs,
            )
            input_ids.append(ids)
            labels.append(lbls)
            attention_masks.append(mask)
        # torch.stack requires identically shaped tensors. This presumably
        # relies on the converter padding/truncating to max_length (confirm).
        # torch.stack also raises on an empty list, so an empty batch fails here.
        return (
            torch.stack(input_ids),
            torch.stack(labels),
            torch.stack(attention_masks),
        )

    def __call__(self, batch):
        """Build the model-ready batch dict from a list of forget/retain pairs.

        Args:
            batch: Sequence of mappings, each with ``"forget"`` and
                ``"retain"`` sub-mappings.

        Returns:
            Dict with ``forget_*`` and ``retain_*`` ``input_ids``, ``labels``
            and ``attention_mask`` tensors.
        """
        forget_input_ids, forget_labels, forget_attention_mask = self._encode_split(
            [pair["forget"] for pair in batch]
        )
        retain_input_ids, retain_labels, retain_attention_mask = self._encode_split(
            [pair["retain"] for pair in batch]
        )
        return {
            "forget_input_ids": forget_input_ids,
            "forget_labels": forget_labels,
            "forget_attention_mask": forget_attention_mask,
            "retain_input_ids": retain_input_ids,
            "retain_labels": retain_labels,
            "retain_attention_mask": retain_attention_mask,
        }
