from utils import *
import time
import argparse
import torch.nn.functional as F
import sys
import json
import os
from sklearn.metrics import pairwise
from tqdm import tqdm
import random
from sentence_transformers import SentenceTransformer
from transformers import AutoModel, AutoTokenizer,AutoModelForCausalLM
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from reverse_engineer import *
from context_prompt import *
import os



def main():
    """Run the evaluation pipeline from the command line.

    Parses the CLI arguments, seeds the RNG, draws a random sample of test
    questions, builds the few-shot / reverse-engineering / segmentation
    prompts for the selected method and dataset, and hands everything to
    ``inference_cot``. The wall-clock time of the inference call is printed.

    Raises:
        NotImplementedError: if ``args.method`` has no prompt-building branch
            (currently ``zero_shot``).
    """
    # load arguments from terminal
    args = arg_parser()
    print('*****************************')
    print(args)
    print('*****************************')

    # Report only whether a key is configured; printing the secret itself
    # would leak it into terminal scrollback and captured logs.
    print(f"API_KEY: {'<set>' if API_KEY else '<missing>'}")

    set_random_seed(args.random_seed)

    # load dataset
    dataloader = create_dataloader(args)
    # random.sample raises ValueError when k exceeds the population size, so
    # cap the request at the number of available questions.
    sample_size = min(args.test_size, len(dataloader))
    dataloader = random.sample(dataloader, k=sample_size)
    print(f"dataloader size: {len(dataloader)}")

    # Build the demonstration prompt for the chosen method.
    if args.method == "few_shot":
        input_prompt = create_input_prompt(args, cot_flag=False)
    elif args.method in ("few_shot_cot", "auto_cot", "active_cot"):
        input_prompt = create_input_prompt(args, cot_flag=True)
        # Demonstrations are separated by blank lines; drop the trailing
        # piece after the last separator and keep the first `shot` examples.
        demonstrations = input_prompt.split('\n\n')[:-1][:args.shot]
        input_prompt = '\n\n'.join(demonstrations) + '\n\n'
    elif args.method == "zero_shot_cot":
        input_prompt = ''
    else:
        raise NotImplementedError(f"method {args.method!r} is not supported")

    # inference_cot needs these two prompts for every method, so build them
    # unconditionally (previously `few_shot` left them unbound and crashed).
    if args.dataset == "coin_flip":
        reverse_prompt = coin_prompt()
    elif args.dataset == "object_tracking":
        reverse_prompt = object_prompt()
    else:
        reverse_prompt = generate_prompt()
    segment_context = segment_text()

    # --qes_limit documents 0 as "no limit"; pass None so the inference loop
    # does not treat it as "stop after zero questions".
    qes_limit = args.qes_limit or None

    start = time.time()
    print("Inference Start")
    if args.multipath != 1:
        print("Self-consistency Enabled, output each inference result is not available")

    final_result, double_check_final_result = inference_cot(
        args, dataloader, qes_limit, input_prompt, reverse_prompt, segment_context
    )
    end = time.time()
    print(f"Execution time: {end - start} seconds")




def _append_json(path, payload):
    """Append ``payload`` to ``path`` as indented JSON followed by a newline."""
    with open(path, 'a+') as f:
        f.write(json.dumps(payload, indent=4))
        f.write('\n')


def _transition_label(first_ok, later_ok):
    """Name a correctness transition, e.g. ``error_to_right``."""
    before = 'right' if first_ok else 'error'
    after = 'right' if later_ok else 'error'
    return f"{before}_to_{after}"


def _run_single_path(args, qes, prompt, given_prompt, reverse_prompt, segment_context):
    """Answer one question once and return the full Q&A record as a dict.

    Steps: answer the question, regenerate a question from that answer,
    split both questions into conditions, compare them in both directions,
    run ``reflex`` if any mismatch is found, and always run ``double_check``.
    """
    # Reset per path so a record never carries reflex output from an
    # earlier path of the same question.
    second_ans = ""
    second_pred_ans = ""
    revise_prompt = ""
    exact_second_pred_ans = ""
    question = qes['question']

    # get question answer
    answer_responses, exact_pred_answer_responses = get_answer(args, question, given_prompt)
    pred_ans = answer_extraction(args, exact_pred_answer_responses)

    # reverse generate question
    generate_question_text = reverse_get_question(args, answer_responses, reverse_prompt)

    # segment question text to condition and question
    ori_condition_list, _ = segment_question(args, question, segment_context)
    generate_condition_list, _ = segment_question(args, generate_question_text, segment_context)

    # With --prompt "condition list" each condition is checked against the
    # other side's condition list; otherwise against the full question text.
    use_condition_list = args.prompt == "condition list"

    # original conditions -> generated question
    gen_reference = generate_condition_list if use_condition_list else generate_question_text
    ori_to_gen = [
        compare_condition_condition_list(args, condition, gen_reference, ori_to_gen=True, flag=args.flag)
        for condition in ori_condition_list
    ]
    ignored_condition = [r for r in ori_to_gen if r['answer'] == 'no']

    # generated conditions -> original question
    ori_reference = ori_condition_list if use_condition_list else question
    gen_to_ori = [
        compare_condition_condition_list(args, condition, ori_reference, ori_to_gen=False, flag=args.flag)
        for condition in generate_condition_list
    ]
    inconsistent_condition = [r for r in gen_to_ori if r['answer'] == 'no']

    # original question -> generated question
    compare_question_result = compare_question(args, question, generate_question_text, flag=args.flag)

    previous_info = [prompt, answer_responses]
    first_answer_ok = (
        not ignored_condition
        and not inconsistent_condition
        and compare_question_result['answer'] == 'yes'
    )
    if not first_answer_ok:
        # reflex: ask for a revised answer given the detected mismatches
        second_ans, revise_prompt, exact_second_pred_ans = reflex(
            args, ignored_condition, inconsistent_condition, compare_question_result,
            ori_condition_list, generate_condition_list, question,
            generate_question_text, previous_info,
        )
        second_pred_ans = answer_extraction(args, exact_second_pred_ans)

    # double_check
    double_check_ans, exact_double_check_ans = double_check(args, previous_info)
    pred_double_check_ans = answer_extraction(args, exact_double_check_ans)

    # Record of this Q&A for later review; key order is the on-disk order.
    return {
        'qes_idx': qes['question_idx'],
        'Q': question,
        'first_A': answer_responses,
        'second_A': second_ans,
        'double_check_A': double_check_ans,
        'G': generate_question_text,
        'first_exact_answer': exact_pred_answer_responses,
        'second_exact_answer': exact_second_pred_ans,
        'double_check_exact_answer': exact_double_check_ans,
        'first_pre_ans': pred_ans,
        'second_pre_ans': second_pred_ans,
        'double_check_ans': pred_double_check_ans,
        'GT': qes['answer'],
        'ori_condition_list': ori_condition_list,
        'generate_condition_list': generate_condition_list,
        'ori_to_gen': ori_to_gen,
        'gen_to_ori': gen_to_ori,
        'ignored_condition': ignored_condition,
        'inconsistent_condition': inconsistent_condition,
        'compare_question_result': compare_question_result,
        'revise_prompt': revise_prompt,
    }


def inference_cot(args, question_pool, qes_limit, given_prompt, reverse_prompt, segment_context):
    """Evaluate every question with the reverse-check and double-check methods.

    Each question is answered ``args.multipath`` times; every path's record
    is appended to ``all.txt`` under
    ``{args.output_dir}/{args.dataset}-{args.random_seed}-{args.test_size}``.
    Only the LAST path of a question is scored. Scored records are also
    appended to a per-category file (``right_to_right.txt``,
    ``error_to_right.txt``, ``first_error_without_second.txt``, ...), and the
    two summary dicts are appended to ``all.txt`` at the end.

    Args:
        args: parsed CLI namespace (see ``arg_parser``).
        question_pool: iterable of dicts with ``question``, ``answer`` and
            ``question_idx`` keys.
        qes_limit: maximum number of questions to process; ``None``, 0 or a
            negative value means no limit.
        given_prompt: demonstration prefix prepended to each question.
        reverse_prompt: prompt passed to ``reverse_get_question``.
        segment_context: context passed to ``segment_question``.

    Returns:
        ``(final_result, double_check_final_result)`` -- counter dicts for
        the reverse-check method and the double-check method respectively.

    Raises:
        ValueError: if ``args.multipath`` is smaller than 1.
    """
    if args.multipath < 1:
        # Without at least one path there is no answer to score.
        raise ValueError("args.multipath must be at least 1")

    final_result = {
        'first_correct': 0,
        'first_error_without_second': 0,
        'final_correct': 0,
        'right_to_right_count': 0,
        'error_to_error_count': 0,
        'error_to_right_count': 0,
        'right_to_error_count': 0,
    }
    double_check_final_result = {
        'right_to_right_count': 0,
        'error_to_error_count': 0,
        'error_to_right_count': 0,
        'right_to_error_count': 0,
        'final_correct': 0,
    }

    # Create the output directory up front so the final summary write also
    # succeeds when no question is processed.
    result_dir = f"{args.output_dir}/{args.dataset}-{args.random_seed}-{args.test_size}"
    os.makedirs(result_dir, exist_ok=True)
    all_path = f"{result_dir}/all.txt"

    # The CLI documents 0 as "unlimited"; treat None/0/negative the same way.
    limit = qes_limit if qes_limit and qes_limit > 0 else None

    for qes_num, qes in enumerate(question_pool):
        if limit is not None and qes_num >= limit:
            break

        if args.dataset == "last_letters" and args.use_code_style_prompt is True:
            # code style prompt
            prompt = given_prompt + "Q: " + qes['question'] + "\nA: Let's think step by step in Python."
        else:
            prompt = given_prompt + "Q: " + qes['question'] + "\nA: Let's think step by step."

        # Normalised after `prompt` is built, so `prompt` keeps the original
        # question text exactly as before.
        if args.dataset == "object_tracking":
            qes['question'] = qes['question'].replace("\n\n", "")

        # enable self-consistency if multipath > 1
        for _ in range(args.multipath):
            QA = _run_single_path(args, qes, prompt, given_prompt, reverse_prompt, segment_context)
            _append_json(all_path, QA)
            print(f"Question number: {qes_num}")

        # Score the last path only.
        truth = qes['answer']
        first_ok = QA['first_pre_ans'] == truth
        second_pred_ans = QA['second_pre_ans']

        # reverse_method
        if second_pred_ans == "":
            # No revised answer is available: the first answer is final.
            if first_ok:
                final_result['first_correct'] += 1
                final_result['final_correct'] += 1
            else:
                final_result['first_error_without_second'] += 1
                _append_json(f"{result_dir}/first_error_without_second.txt", QA)
        else:
            second_ok = second_pred_ans == truth
            label = _transition_label(first_ok, second_ok)
            final_result[f"{label}_count"] += 1
            if second_ok:
                final_result['final_correct'] += 1
            _append_json(f"{result_dir}/{label}.txt", QA)

        # double_check method
        double_check_ok = QA['double_check_ans'] == truth
        label = _transition_label(first_ok, double_check_ok)
        double_check_final_result[f"{label}_count"] += 1
        if double_check_ok:
            double_check_final_result['final_correct'] += 1

    _append_json(all_path, final_result)
    _append_json(all_path, double_check_final_result)
    return final_result, double_check_final_result


# dataset name -> (test-file path, answer trigger phrase). Insertion order is
# the order shown in the --dataset choices.
_DATASET_SETTINGS = {
    "gsm8k": ("./dataset/GSM8K/test.jsonl", "\nTherefore, the answer (arabic numerals) is"),
    "svamp": ("./dataset/SVAMP/SVAMP.json", "\nTherefore, the answer (arabic numerals) is"),
    "aqua": ("./dataset/AQuA/test.json", "The answer is"),
    "csqa": ("./dataset/CSQA/dev_rand_split.jsonl", "So the answer is"),
    "asdiv": ("./dataset/ASDiv/ASDiv.json", "\nTherefore, the answer (arabic numerals) is"),
    "last_letters": ("./dataset/last_letters/last_letters_test.json", "\nTherefore, the answer is"),
    "addsub": ("./dataset/MAWPS/AddSub.json", "\nTherefore, the answer (arabic numerals) is"),
    "singleeq": ("./dataset/MAWPS/SingleEq.json", "\nTherefore, the answer (arabic numerals) is"),
    "strategyqa": ("./dataset/strategyQA/task.json", "\nTherefore, the answer (Yes or No) is"),
    "multiarith": ("./dataset/MAWPS/MultiArith.json", "\nTherefore, the answer (arabic numerals) is"),
    "coin_flip": ("./dataset/coin_flip/coin_flip.json", "\nTherefore, the answer (Yes or No) is"),
    "bigbench_date": ("./dataset/Bigbench_Date/task.json", "\nTherefore, among A through F, the answer is"),
    "object_tracking": ("./dataset/Bigbench_object_tracking/task.json", "\nTherefore, among A through C, the answer is"),
}


def _str_to_bool(value):
    """Parse a command-line boolean such as ``True``, ``false``, ``1`` or ``no``.

    ``type=bool`` cannot be used with argparse because ``bool("False")`` is
    ``True`` (any non-empty string is truthy).
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("true", "t", "yes", "y", "1"):
        return True
    if lowered in ("false", "f", "no", "n", "0", ""):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def arg_parser():
    """Parse ``sys.argv`` and return the namespace used by the pipeline.

    Besides the declared options, the returned namespace carries derived
    attributes: ``dataset_path``, ``direct_answer_trigger``,
    ``direct_answer_trigger_for_zeroshot``,
    ``direct_answer_trigger_for_zeroshot_cot``,
    ``direct_answer_trigger_for_fewshot`` and ``cot_trigger``.

    Note that ``temperature`` is always overwritten: 0.7 when
    ``--multipath`` is greater than 1, otherwise 0.

    Raises:
        ValueError: if the dataset has no entry in the settings table.
    """
    parser = argparse.ArgumentParser(description="CoT")
    parser.add_argument("--random_seed", type=int, default=1, help="random seed")
    parser.add_argument(
        "--dataset", type=str, default="gsm8k", choices=list(_DATASET_SETTINGS),
        help="dataset to inference",
    )
    parser.add_argument(
        "--prompt_path", type=str, default="./inference_prompts/gsm8k_k=10", help="prompts to use"
    )
    parser.add_argument(
        "--model", type=str, default="code-davinci-002",
        choices=["text-davinci-002", "code-davinci-002", 'gpt-3.5-turbo', 'vicuna-13b', 'alpaca-30b'],
        help="model used for decoding.",
    )
    parser.add_argument(
        "--method", type=str, default="active_cot",
        choices=["zero_shot", "zero_shot_cot", "few_shot", "few_shot_cot", "auto_cot", "active_cot"],
        help="method",
    )
    parser.add_argument(
        "--output_dir", type=str, default="./results", help="output directory"
    )
    parser.add_argument(
        "--max_length_cot", type=int, default=512,
        help="maximum length of output tokens by model for reasoning extraction",
    )
    parser.add_argument(
        "--qes_limit", type=int, default=0,
        help="whether to limit test dataset size. if 0, the dataset size is unlimited and we use all the samples in the dataset for testing.",
    )
    parser.add_argument(
        "--api_time_interval", type=float, default=1.0, help="how many seconds to sleep between each request"
    )
    parser.add_argument(
        "--temperature", type=float, default=0, help=""
    )
    parser.add_argument(
        "--multipath", type=int, default=1, help="self-consistency path num"
    )
    parser.add_argument(
        "--concat_length", type=int, default=4,
        help='Used for task last_letters, indicates length of last letter to concat, i.e. Elon Musk -> nk, use concat length of 2',
    )
    parser.add_argument(
        # _str_to_bool instead of bool: with type=bool, "False" parsed as True.
        "--use_code_style_prompt", type=_str_to_bool, default=False,
        help='Use code-style prompt as mentioned in paper for last_letters dataset',
    )
    parser.add_argument(
        "--test_size", type=int, default=1000, help='test size'
    )
    parser.add_argument(
        "--shot", type=int, default=4, help='n-shot for in context learning'
    )
    parser.add_argument(
        "--load_model", type=str, default="", help=''
    )
    parser.add_argument(
        "--tokenizer", type=str, default="", help=''
    )
    parser.add_argument(
        "--flag", type=str, default="gpt", help=''
    )
    parser.add_argument(
        "--file_idx", type=str, default="", help=''
    )
    parser.add_argument(
        "--api_num", type=int, default=0, help=''
    )
    parser.add_argument(
        "--prompt", type=str, choices=["condition list", "context"], default="context", help=''
    )
    args = parser.parse_args()

    # The temperature is derived from --multipath; any value given on the
    # command line is replaced here.
    args.temperature = 0.7 if args.multipath > 1 else 0
    print(f"Temperature: {args.temperature}")

    try:
        args.dataset_path, args.direct_answer_trigger = _DATASET_SETTINGS[args.dataset]
    except KeyError:
        raise ValueError("dataset is not properly defined ...") from None

    # Zero-shot trigger: drop the leading "\nTherefore, " and capitalise.
    trigger = args.direct_answer_trigger.replace("\nTherefore, ", "")
    args.direct_answer_trigger_for_zeroshot = trigger[0].upper() + trigger[1:]
    args.direct_answer_trigger_for_zeroshot_cot = args.direct_answer_trigger
    args.direct_answer_trigger_for_fewshot = "Therefore, the final answer is"
    args.cot_trigger = "Let's think step by step."

    return args


# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()