import argparse
from utils import *

def main():
    """Estimate a confidence score for each candidate answer of every question.

    For every (question, answer) pair in the dataloader the function:

    1. Obtains the candidate answer(s): the supplied answer if there is one,
       otherwise ``args.num_sample`` answers generated by the selected engine.
    2. Samples ``args.num_self_consistency`` additional answers
       ("observed consistency") and scores how well they agree with each
       candidate, via the NLI classifier or, with ``--disable_nli``, via
       ``jaccard_similarity``. An exact-match indicator is computed as well.
    3. Asks the engine to judge each candidate with two self-reflection
       prompts and averages the parsed scores.
    4. Combines the three scores with ``args.weight1`` and ``args.weight2``.

    Results are printed and handed to ``save_resuls`` after every sample.
    Any exception raised while processing one sample is printed and the
    sample is skipped, so a single failure does not abort the whole run.

    Raises:
        ValueError: if ``args.engine`` has no matching API wrapper.
    """
    args = parse_arguments()
    fix_seed(args.random_seed)

    # The NLI classifier is only loaded when it will actually be used.
    debert_classify = None
    if not args.disable_nli:
        debert_classify = ClassifyWrapper(dir=args.transformer_dir)

    # Prepare the dataloader: a single question from the command line, or
    # whatever data_reader yields for the JSON input.
    if args.question_json_dir is None:
        dataloader = [[args.question, args.answer]]
    else:
        dataloader = data_reader(args)

    # When generating a single reference answer use temp=0.0; when generating
    # multiple reference answers use the configured sampling temperature.
    temp_ref = args.temperature if args.num_sample > 1 else 0.0

    # Engine selection: gpt3, chatgpt, gpt4, llama2.
    # NOTE: compare with ==; `x in ('gpt4')` is a substring test on a str,
    # not tuple membership.
    if args.engine == 'gpt4':
        api = async_gpt4_generate
    elif args.engine == 'chatgpt':
        api = async_chatgpt_generate
    elif args.engine == 'gpt3':
        api = async_gpt3_generate
    elif args.engine == 'llama2':
        llama_wrapper = llama_initial(args.llama2_dir)
        api = llama_wrapper.batch_pred
    else:
        raise ValueError("Define your API wrapper...")

    # Prompt templates are identical for every sample, so build them once.
    task_prompt = 'Please strictly use the following template to provide answer: explanation: [insert step-by-step analysis], \nanswer: [provide your answer].'
    option_prompt = ['Is the proposed answer: (A) Correct (B) Incorrect. The output should strictly use the following template: explanation: [insert analysis], \nanswer: [choose one letter between choices A and B]',
                     'Are you really sure the proposed answer is correct? Choose again: (A) Correct (B) Incorrect. The output should strictly use the following template: explanation: [insert analysis], \nanswer: [choose one letter between choices A and B]']

    total = 0
    save_list = []

    for data in dataloader:

        try:
            print('******************************************************')
            total += 1
            print("{}st data".format(total))
            print('******************************************************')
            question, answer = data[0], data[1]

            # Prepare the question template and get the reference answer(s).
            x_org = "Question: " + question + "\n" + args.direct_answer_trigger
            if answer is None:
                reply = api([x_org for _ in range(args.num_sample)], temp=temp_ref)
                ans_ref = [answer_cleansing(args, raw) for raw in reply]
            else:
                ans_ref = [answer]

            # Observed consistency: sample several independent answers.
            prompt = [task_prompt + "\n" + x_org for _ in range(args.num_self_consistency)]
            reply = api(prompt, temp=args.temperature)
            ans = [ans_uq_cleansing(args, raw) for raw in reply]

            # For each candidate answer, compute the confidence score.
            save_one_data = []
            for ans_ref_one in ans_ref:

                # Without the NLI model, fall back to jaccard_similarity.
                if args.disable_nli:
                    consistency_score_average = np.average(
                        [jaccard_similarity(sampled, ans_ref_one) for sampled in ans]
                    )
                else:
                    debert_score = debert_classify.batch_pred(
                        args.direct_answer_trigger, ans_ref_one, ans
                    ).squeeze()
                    consistency_score_average = np.average(debert_score)

                # Indicator function: exact match between sampled and candidate answer.
                indicator_score = [float(sampled == ans_ref_one) for sampled in ans]
                indicator_score_average = np.average(indicator_score)

                # Self-reflection: ask the engine to judge the candidate (deterministic, temp=0.0).
                prompt = [x_org + "\n" + ans_ref_one + "\n" + option for option in option_prompt]
                reply = api(prompt, temp=0.0)
                option_score = [option_uq_cleansing(raw) for raw in reply]
                verbalized_score_average = np.average(option_score)

                # Ensemble all scores.
                confidence_score = args.weight1 * consistency_score_average + (1-args.weight1) * indicator_score_average
                confidence_score = args.weight2 * confidence_score + (1-args.weight2) * verbalized_score_average

                save_one_data.append([confidence_score, ans_ref_one])

            # Save question and answers to a dict.
            data_dict = {
                'id': total,
                'Question': question.replace('"', "'").replace("\n", " "),
            }
            # Iterate over the candidates that actually exist. When a reference
            # answer is supplied there is exactly one candidate regardless of
            # args.num_sample, so indexing by range(args.num_sample) would
            # raise IndexError and silently drop the sample.
            for k, (confidence, candidate) in enumerate(save_one_data):
                data_dict[f'Answer{k}'] = candidate.replace('"', "'").replace("\n", " ")
                data_dict[f'Confidence{k}'] = confidence
            save_list.append(data_dict)

            # Print and save the results accumulated so far.
            print_result(args, save_one_data, answer)
            save_resuls(args, save_list, list(data_dict.keys()))

        except Exception as e:
            # Deliberate best-effort: report the failure and move on to the
            # next sample instead of aborting the whole run.
            print(e)
            print("Please Wait! Skip this sample.")
   



def parse_arguments(argv=None):
    """Build the command-line parser and parse the arguments.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse reads ``sys.argv[1:]``, so existing
            callers that pass nothing behave exactly as before.

    Returns:
        argparse.Namespace holding every option defined below.
    """
    parser = argparse.ArgumentParser(description="BSDetector")

    parser.add_argument(
        "--save_log", type=str, default='./save_log', help="log directory"
    )
    parser.add_argument(
        "--transformer_dir", type=str, default='./',
        help="directory of the transformer model used by the NLI classifier"
    )
    parser.add_argument(
        "--llama2_dir", type=str, default='./', help="llama directory"
    )
    parser.add_argument(
        "--random_seed", type=int, default=1, help="random seed"
    )
    parser.add_argument(
        "--question", type=str, default=None
    )
    parser.add_argument(
        "--question_json_dir", type=str, default=None
    )
    parser.add_argument(
        "--question_type", type=str, default="open_form", choices=["mathqa", "multiple_choice", "open_form"], help="the type of question"
    )
    parser.add_argument(
        "--answer", type=str, default=None
    )
    parser.add_argument(
        "--direct_answer_trigger", type=str, default="Therefore, the answer is"
    )
    parser.add_argument(
        "--engine", type=str, default="chatgpt", choices=["chatgpt", "gpt3", "gpt4", "llama2"], help="api"
    )
    parser.add_argument(
        "--weight1", type=float, default=0.9,
        help="weight of the consistency score against the exact-match indicator score"
    )
    parser.add_argument(
        "--weight2", type=float, default=0.6,
        help="weight of the combined consistency score against the self-reflection score"
    )
    parser.add_argument(
        "--num_self_consistency", type=int, default=5, help="number of self-consistency"
    )
    parser.add_argument(
        "--num_sample", type=int, default=1, help="number of samples for generating candidate answer"
    )
    parser.add_argument(
        "--temperature", type=float, default=1.0,
        help="sampling temperature used when generating multiple answers"
    )
    parser.add_argument(
        "--disable_nli", action="store_true"
    )

    # argv=None makes argparse fall back to sys.argv[1:].
    return parser.parse_args(argv)

# Run the pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
