import torch
import datasets
import copy
    
from .arc_prompt_utils import apply_prompt_template_arc, icl_tokenize_arc


def apply_icl_prompt(batch_size, test_dataset_name, train_dataset, test_dataset, idx_mat, tokenizer):
    """Build ICL prompts and tokenize them with a per-batch length budget.

    The prompts come from ``apply_prompt_template_arc(test_dataset,
    train_dataset, idx_mat)``.  They are processed in consecutive groups of
    ``batch_size``; for each group the longest ``icl_prompt + " " + option``
    encoding (measured with ``tokenizer.encode``) plus 8 is passed as
    ``max_words`` to ``icl_tokenize_arc`` for every prompt of that group.

    Args:
        batch_size: Number of prompts that share one ``max_words`` budget.
            Must be a positive integer.
        test_dataset_name: Unused by this function; kept for interface
            compatibility with existing callers.
        train_dataset: Forwarded to ``apply_prompt_template_arc``.
        test_dataset: Forwarded to ``apply_prompt_template_arc``.
        idx_mat: Forwarded to ``apply_prompt_template_arc``.
        tokenizer: Object exposing ``encode(text)``; also forwarded to
            ``icl_tokenize_arc``.

    Returns:
        A ``datasets.Dataset`` built from the tokenized samples, in the same
        order as the generated prompts.

    Raises:
        ValueError: If ``batch_size`` is not positive.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    icl_prompt_list = apply_prompt_template_arc(test_dataset, train_dataset, idx_mat)
    total = len(icl_prompt_list)

    icl_tokenized_list = []
    words_list = []

    # Step over batch start offsets so that no empty trailing batch is
    # produced when ``total`` is an exact multiple of ``batch_size`` (an empty
    # batch would otherwise contribute a bogus length to the statistic below).
    for start in range(0, total, batch_size):
        batch_indices = range(start, min(start + batch_size, total))

        # Longest prompt+option encoding in this batch; ``default=-1`` keeps
        # the previous result for a batch whose prompts carry no options.
        max_words = max(
            (
                len(tokenizer.encode(icl_prompt_list[j]["icl_prompt"] + " " + choice))
                for j in batch_indices
                for choice in icl_prompt_list[j]["options"]
            ),
            default=-1,
        )
        # 8 extra tokens of headroom on top of the longest sample.
        # NOTE(review): the reason for exactly 8 is not visible here -- confirm
        # against icl_tokenize_arc.
        max_words += 8
        words_list.append(max_words)

        for j in batch_indices:
            tokenized_sample = icl_tokenize_arc(icl_prompt_list[j], tokenizer, max_words=max_words)
            icl_tokenized_list.append(tokenized_sample)

    # Report the mean per-batch budget; skipped when there were no batches to
    # avoid dividing by zero.
    if words_list:
        print(f"words_list: {sum(words_list)/len(words_list)}")

    icl_tokenized_dataset = datasets.Dataset.from_list(icl_tokenized_list)
    return icl_tokenized_dataset
