import torch
import datasets
import copy

from .arc_prompt_utils import apply_prompt_template_arc, icl_tokenize_arc


def apply_icl_prompt(batch_size, test_dataset_name, train_dataset, test_dataset, idx_mat, tokenizer):
    """Build a tokenized in-context-learning (ICL) dataset, sized per batch.

    ICL prompts are produced by ``apply_prompt_template_arc`` and then
    tokenized in consecutive groups of ``batch_size``. Each group uses its own
    ``max_words``: the longest ``tokenizer.encode`` length over every
    "prompt + ' ' + option" string in that group, plus 8.

    Args:
        batch_size: Number of prompts that share one ``max_words`` value.
            Must be positive.
        test_dataset_name: Name used to select the prompt template. Only the
            ARC-style multiple-choice names listed below (and any name
            containing "winogrande" or "hellaswag") are supported.
        train_dataset: Pool the in-context examples are drawn from.
        test_dataset: Samples to build prompts for.
        idx_mat: In-context example indices. The caller must make sure it has
            size |test_dataset| * k (k = number of in-context examples) and
            that every index lies in [0, |train_dataset|).
        tokenizer: Object exposing ``encode(text)``; it is also forwarded to
            ``icl_tokenize_arc``.

    Returns:
        A ``datasets.Dataset`` built from the tokenized samples, in the same
        order as the generated prompts.

    Raises:
        ValueError: If the dataset name is not supported or ``batch_size`` is
            not positive.
    """
    # NOTE(review): "sociali_qa" is kept exactly as in the original list; it
    # may be a typo for another dataset name -- confirm against callers.
    supported_names = (
        "arc-easy",
        "arc-challenge",
        "hellaswag",
        "commonsense_qa",
        "sociali_qa",
        "openbook_qa",
        "commonsense_qa-all",
    )
    is_supported = (
        test_dataset_name in supported_names
        or "winogrande" in test_dataset_name
        or "hellaswag" in test_dataset_name
    )
    if not is_supported:
        raise ValueError("Dataset not implemented.")
    if batch_size <= 0:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size!r}")

    icl_prompt_list = apply_prompt_template_arc(test_dataset, train_dataset, idx_mat)
    total = len(icl_prompt_list)

    icl_tokenized_list = []
    words_list = []
    # Step through batch start offsets. Unlike `total // batch_size + 1`, this
    # never yields a trailing empty batch when `total` is an exact multiple of
    # `batch_size` (or zero), which used to add a bogus entry to `words_list`.
    for start in range(0, total, batch_size):
        stop = min(start + batch_size, total)
        # Index one element at a time (no slicing) so the container only has
        # to support `len()` and integer indexing, as in the original code.
        batch = [icl_prompt_list[j] for j in range(start, stop)]

        longest = max(
            (
                len(tokenizer.encode(sample["icl_prompt"] + " " + choice))
                for sample in batch
                for choice in sample["options"]
            ),
            default=-1,  # same starting value the original loop used
        )
        # 8 tokens of headroom on top of the longest candidate; presumably
        # room for special tokens added by icl_tokenize_arc -- TODO confirm.
        max_words = longest + 8
        words_list.append(max_words)

        for sample in batch:
            icl_tokenized_list.append(icl_tokenize_arc(sample, tokenizer, max_words=max_words))

    # With no prompts there are no batches, so there is no average to report.
    if words_list:
        print(f"words_list: {sum(words_list)/len(words_list)}")

    icl_tokenized_dataset = datasets.Dataset.from_list(icl_tokenized_list)
    return icl_tokenized_dataset