import torch
import datasets
import copy

# Label value skipped by the loss: -100 is the default ``ignore_index`` of
# torch.nn.CrossEntropyLoss, so positions carrying it do not contribute.
IGNORE_INDEX: int = -100

# Layout of one question/answer example. There is intentionally no trailing
# space after "Answer:"; the code that appends an answer adds " " itself.
PROMPT_TEMPLATE: str = "Question: {question}\nAnswer:"

def apply_prompt_template_arc(test_dataset, train_dataset, idx_mat):
    """Build few-shot (in-context learning) prompts for ARC-style test items.

    Every training row is rendered as ``PROMPT_TEMPLATE`` followed by a space
    and its gold answer text. For test row ``i`` the demonstrations selected by
    ``idx_mat[i]`` are joined with blank lines, and the test question (rendered
    with the same template but without an answer) is appended after one more
    blank line. The separator is added even when ``idx_mat[i]`` is empty.

    Args:
        test_dataset: iterable of rows with keys "question", "answerKey" and
            "choices" (a mapping that has a "text" list).
        train_dataset: iterable of rows with the same keys. ``answerKey`` is
            used directly as an index into ``choices["text"]``, so it is
            presumably an integer position (letter keys such as "A" would
            fail) -- confirm against the data loader.
        idx_mat: indexable by test position; ``idx_mat[i]`` yields positions in
            ``train_dataset`` to use, in order, as demonstrations for test row i.

    Returns:
        list of dicts, one per test row, with keys "icl_prompt", "answerKey",
        "choices" and "options" (the row's ``choices["text"]``).
    """
    # Render each training example once as "<prompt> <answer>"; demonstrations
    # are then picked from this list by position.
    demonstrations = [
        PROMPT_TEMPLATE.format(question=row["question"])
        + " "
        + row["choices"]["text"][row["answerKey"]]
        for row in train_dataset
    ]

    icl_list = []
    # Iterate the test rows once instead of re-indexing test_dataset per field.
    for i, row in enumerate(test_dataset):
        context = "\n\n".join(demonstrations[j] for j in idx_mat[i])
        query = PROMPT_TEMPLATE.format(question=row["question"])
        icl_list.append(
            {
                "icl_prompt": context + "\n\n" + query,
                "answerKey": row["answerKey"],
                "choices": row["choices"],
                "options": row["choices"]["text"],
            }
        )

    return icl_list
        
# `sample` is one item as produced by apply_prompt_template_arc: a dict with
# keys "icl_prompt", "answerKey", "choices" and "options".
def icl_tokenize_arc(sample, tokenizer, max_words=1024):
    """Tokenize an ICL prompt once per answer option, for per-option scoring.

    For every option the text ``icl_prompt + " " + option`` is encoded, an EOS
    token id is appended, and the row is right-padded to exactly ``max_words``
    tokens. Labels are ``IGNORE_INDEX`` everywhere except the option tokens,
    so the loss of each row covers only that option's continuation.

    Args:
        sample: dict with at least "icl_prompt" (str), "options" (list of str)
            and "answerKey".
        tokenizer: object exposing ``encode(str) -> list[int]`` and
            ``eos_token_id`` (presumably a Hugging Face tokenizer).
        max_words: fixed row length in tokens.

    Returns:
        dict with "input_ids" and "labels" (int64) and "attention_mask"
        (float, 1.0 on real tokens incl. the appended EOS, 0.0 on padding),
        each of shape ``(len(options), max_words)``, plus "gold", which is
        ``sample["answerKey"]`` passed through unchanged.

    Raises:
        ValueError: if an encoded option row (including the appended EOS) is
            longer than ``max_words``.
    """
    prompt_text = sample["icl_prompt"]
    # The prompt is the same for every option, so tokenize it only once.
    prompt_len = len(tokenizer.encode(prompt_text))

    input_rows = []
    label_rows = []
    mask_rows = []
    for choice in sample["options"]:
        token_ids = tokenizer.encode(prompt_text + " " + choice)
        # If encode() already emitted an EOS id, the prompt-masking span below
        # is shortened by one -- presumably to offset an EOS that the prompt's
        # own encoding also contains. NOTE(review): confirm for the tokenizer.
        if_eos = int(tokenizer.eos_token_id in token_ids)

        token_ids.append(tokenizer.eos_token_id)
        # Real (unpadded) length including the appended EOS. Computed for every
        # row so that a row of exactly max_words tokens is handled too.
        example_len = len(token_ids)
        padding = max_words - example_len
        if padding < 0:
            raise ValueError("padding < 0, please make the max_words larger")

        # Pad with -1 so real tokens and padding can be told apart below.
        example = torch.tensor(token_ids + [-1] * padding, dtype=torch.int64)

        labels = example.clone()
        # Do not score the prompt tokens ...
        labels[: prompt_len - if_eos] = -1
        # ... nor the appended EOS and the padding after it.
        labels[example_len - 1:] = -1
        labels[labels < 0] = IGNORE_INDEX

        attention_mask = example.ge(0)
        example[~attention_mask] = 0  # padding positions get token id 0

        input_rows.append(example)
        label_rows.append(labels)
        mask_rows.append(attention_mask.float())

    return {
        "input_ids": torch.stack(input_rows),
        "labels": torch.stack(label_rows),
        "attention_mask": torch.stack(mask_rows),
        "gold": sample["answerKey"],
    }