import os
# Limit the GPUs visible to this process to physical device index 6.
# This assignment is deliberately placed before `import torch` so the variable
# is already in the environment when CUDA is initialized.
# NOTE(review): the device index is hard-coded and machine-specific; adjust it
# (or export CUDA_VISIBLE_DEVICES in the shell and drop this line) as needed.
os.environ["CUDA_VISIBLE_DEVICES"] = "6"

import torch
import argparse
import evaluate

import pandas as pd

from dataset.utils import get_dataloader
from model import get_model, get_tokenizer
from algorithms import get_retriever
from inference import get_inferencer
from common import get_prompt_label, setup_seed, get_input




def main(args):
    """Run the demonstration-selection + inference pipeline and print the scores.

    For every seed in the (currently single-element) seed list this:
      1. loads the demonstration (ICE) and test datasets,
      2. builds a 'ppl' inferencer and a retriever, and retrieves the
         demonstration indices,
      3. assembles the in-context prompt input,
      4. runs inference over ``test_dataset['text']`` and scores the
         predictions against ``test_dataset['label']`` with the ``accuracy``
         metric from ``evaluate``.

    The list of per-seed metric results is printed at the end; nothing is
    returned.

    Args:
        args: Parsed command-line namespace. Reads ``pretrained_model_name``,
            ``task``, ``imbalance_type``, ``imbalance_ratio``, ``batch_size``,
            ``test_retrieving``, ``im_retrieving`` and ``ice_num``.
    """
    setup_seed(100)
    device = torch.device("cuda")
    model = get_model(args.pretrained_model_name)
    tokenizer = get_tokenizer(args.pretrained_model_name)
    model.eval()

    seeds = [100]
    results = []
    for seed in seeds:
        # Re-seed at the start of each run so every seed's run is reproducible.
        setup_seed(seed)

        # ---- data ----
        ice_dataset, test_dataset = get_dataloader(
            args.task,
            args.imbalance_type,
            args.imbalance_ratio,
        )

        # ---- demonstration selection ----
        inferencer = get_inferencer(
            'ppl',
            model_name=model,
            tokenizer_name=tokenizer,
            device=device,
            batch_size=args.batch_size,
        )
        retriever = get_retriever(
            args.test_retrieving,
            args.task,
            ice_dataset,
            test_dataset,
            inferencer,
            device,
            model,
            tokenizer,
        )
        ice_idx_list = retriever.retrieve(args.im_retrieving, args.ice_num)
        template, template_dict, label = get_prompt_label(args.task)
        ice = get_input(args.task, ice_idx_list, template, template_dict, ice_dataset)

        # ---- inference ----
        candidates = test_dataset['text']
        # Candidate labels are passed as integer ids 0..len(label)-1.
        label_ids = list(range(len(label)))
        test_predictions = inferencer.inference(
            task=args.task,
            ice=ice,
            candidate=candidates,
            labels=label_ids,
            ice_template=template_dict,
        )

        # ---- scoring ----
        metric = evaluate.load('accuracy')
        score = metric.compute(
            predictions=test_predictions,
            references=test_dataset['label'],
        )
        results.append(score)

    print(results)
 


def build_parser():
    """Build the command-line argument parser for this script.

    Returns:
        argparse.ArgumentParser: parser exposing the task, retriever,
        class-imbalance, model and batch-size options consumed by ``main``.
    """
    parser = argparse.ArgumentParser()

    # task and prompt
    parser.add_argument('--task', type=str, choices=['agnews'], default='agnews', help='task.')

    # retriever
    parser.add_argument('--test_retrieving', type=str, choices=['topk', 'rm'], default='topk', help='Choose demonstration selection method.')
    parser.add_argument('--im_retrieving', type=str, choices=['naive', 'rif'], default='rif', help='Choose imbalance retriever.')
    parser.add_argument('--ice_num', type=int, default=16)

    # class imbalance
    # The ratio is fractional (default 0.01), so it must be parsed as a float;
    # with type=int any value such as "0.01" given on the command line was
    # rejected with "invalid int value".
    parser.add_argument('--imbalance_ratio', type=float, default=0.01, help='imbalance ratio.')
    parser.add_argument('--imbalance_type', type=str, choices=['exp', 'real'], default='exp', help='imbalance type.')

    # model
    parser.add_argument('--pretrained_model_name', '-m', choices=['opt'], type=str, default='opt', help='Choose pretrained model.')

    # others
    parser.add_argument('--batch_size', type=int, default=1, help='Test batch size.')
    return parser


if __name__ == '__main__':
    parser = build_parser()
    args = parser.parse_args()
    main(args)

