import argparse
import json
import os
import threading

from ARCA import *
from models import *

def build_parser():
    """Return the command-line parser for one AQuA evaluation run.

    Declares every attribute that ``run`` and ``process_data`` read from the
    parsed namespace; previously only ``--func`` existed, so any run failed
    with ``AttributeError`` as soon as ``run`` touched ``args.begin_task``.
    """
    parser = argparse.ArgumentParser(description='AQUA dataset')

    parser.add_argument('--func', type=str, default='ARCA', help='Function to execute')
    # NOTE(review): the defaults below are placeholders chosen so the script can
    # start; confirm them against the intended experiment configuration. solve()
    # receives the whole namespace and may read further options not declared here.
    parser.add_argument('--dataset_path', type=str, required=True,
                        help='JSON-lines dataset file, one question record per line')
    parser.add_argument('--begin_task', type=int, default=0,
                        help='Index of the first dataset line to evaluate (inclusive)')
    parser.add_argument('--end_task', type=int, required=True,
                        help='Index one past the last dataset line to evaluate (exclusive)')
    parser.add_argument('--thread_n', type=int, default=1,
                        help='Number of worker threads')
    parser.add_argument('--n_generate_sample', type=int, default=1,
                        help='Generation sample count (recorded in log/result; available to solve via args)')
    parser.add_argument('--n_select_sample', type=int, default=1,
                        help='Selection sample count (recorded in log/result; available to solve via args)')
    parser.add_argument('--n_evaluate_time', type=int, default=1,
                        help='Evaluation count (recorded in log/result; available to solve via args)')
    parser.add_argument('--shuffle', action='store_true',
                        help='Shuffle flag (recorded in the log file name; available to solve via args)')
    return parser


def main():
    """Parse command-line options and run one evaluation with run id 1."""
    args = build_parser().parse_args()

    run(args, 1)

# Values of --func that process_data knows how to dispatch.
_SUPPORTED_FUNCS = ('ARCA',)


def split_into_chunks(items, n_chunks):
    """Split *items* into *n_chunks* contiguous, balanced slices.

    Every element lands in exactly one slice and slice lengths differ by at
    most one; trailing slices may be empty when ``len(items) < n_chunks``.

    Returns:
        A list of ``(offset, chunk)`` pairs, where ``offset`` is the position of
        ``chunk[0]`` in *items*.

    Raises:
        ValueError: if ``n_chunks < 1``.
    """
    if n_chunks < 1:
        raise ValueError(f"n_chunks must be >= 1, got {n_chunks}")
    base, extra = divmod(len(items), n_chunks)
    chunks = []
    start = 0
    for k in range(n_chunks):
        # The first `extra` chunks absorb the remainder, one element each.
        size = base + (1 if k < extra else 0)
        chunks.append((start, items[start:start + size]))
        start += size
    return chunks


def build_aqua_question(data):
    """Render one parsed dataset record as the prompt text passed to the solver.

    Reads ``data["question"]`` and ``data["options"]``. Option strings
    presumably look like ``"A)21"`` (TODO confirm against the dataset);
    e.g. options ``["A)1", "B)2"]`` render as ``"Answer Choices: (A) 1 (B) 2"``.
    """
    choice = "(" + "(".join(data["options"])
    choice = choice.replace("(", " (").replace(")", ") ")
    choice = "Answer Choices:" + choice
    return data["question"].strip() + " " + choice + "\nSteps:\n"


def run(args, id):
    """Evaluate ``args.func`` on a slice of the dataset using worker threads.

    Reads ``args.dataset_path`` (one JSON record per line), takes lines
    ``[args.begin_task, args.end_task)``, splits them across ``args.thread_n``
    threads running :func:`process_data`, then writes all per-question info
    dicts to a JSON log under ``logs/<func>/`` and appends one summary line
    (run config, correct indices, ``gpt_usage()`` counters, accuracy) to
    ``results/<func>.jsonl``.

    Args:
        args: parsed command-line namespace (see ``main``).
        id: run identifier embedded in the log file name and the summary.
            (Shadows the builtin; name kept for backward compatibility.)

    Raises:
        ValueError: if ``args.func`` is unsupported, ``args.thread_n < 1``, or
            the requested range selects no dataset lines.
        RuntimeError: if a worker thread dies before storing its results.
    """
    print('*****************************')
    print(args)
    print('*****************************')

    if args.func not in _SUPPORTED_FUNCS:
        raise ValueError(f"Unsupported func {args.func!r}; expected one of {_SUPPORTED_FUNCS}")

    log_file = f"logs/{args.func}/{args.begin_task}-{args.end_task}_generate{args.n_generate_sample}_select{args.n_select_sample}_evaluate{args.n_evaluate_time}_shuffle{args.shuffle}({id}).json"
    result_file = f"results/{args.func}.jsonl"
    os.makedirs(os.path.dirname(log_file), exist_ok=True)
    os.makedirs(os.path.dirname(result_file), exist_ok=True)

    with open(args.dataset_path, "r", encoding="utf-8") as input_file:
        lines = input_file.readlines()
    # Slice once so a range running past the end of the file is clamped, and
    # the accuracy denominator counts only questions that are actually solved.
    selected = lines[args.begin_task:args.end_task]
    total_num = len(selected)
    if total_num == 0:
        raise ValueError(
            f"No tasks selected: begin_task={args.begin_task}, "
            f"end_task={args.end_task}, dataset has {len(lines)} lines"
        )
    # Balanced split covering every selected line. The old fixed-step split
    # silently dropped the remainder when total_num % thread_n != 0 and crashed
    # (zero range step) when total_num < thread_n.
    chunks = split_into_chunks(selected, args.thread_n)
    dics = [{} for _ in range(args.thread_n)]

    threads = []
    for k, (offset, chunk) in enumerate(chunks):
        thread = threading.Thread(
            target=process_data,
            args=(args, chunk, dics, k),
            kwargs={'offset': offset},
            name=f"{args.func}-worker-{k}",
        )
        threads.append(thread)
        thread.start()
    for thread in threads:
        thread.join()

    # A worker that raised leaves its slot as {}; fail with a clear message
    # instead of a KeyError or an accuracy computed from partial results.
    missing = [k for k, d in enumerate(dics) if 'infos' not in d]
    if missing:
        raise RuntimeError(f"Worker thread(s) {missing} did not finish; see the traceback(s) above")

    infos = []
    correct_list = []
    for d in dics:
        infos.extend(d['infos'])
        correct_list.extend(d['correct_list'])

    result_dict = {
        'id': id,
        'begin': args.begin_task,
        'end': args.end_task,
        'generate': args.n_generate_sample,
        'select': args.n_select_sample,
        'evaluate': args.n_evaluate_time,
        'correct_list': correct_list,
        'correct_num': len(correct_list),
    }
    result_dict.update(gpt_usage())
    result_dict['accuracy'] = result_dict['correct_num'] / total_num
    print("Accuracy:", result_dict['accuracy'])

    with open(log_file, 'w', encoding='utf-8') as f:
        json.dump(infos, f, indent=4)
    with open(result_file, 'a', encoding='utf-8') as f:
        f.write(json.dumps(result_dict) + "\n")


def process_data(args, lines, dics, index, offset=None):
    """Solve each question in *lines* and store the results in ``dics[index]``.

    Used as a thread target by ``run``; each worker writes only its own slot
    ``dics[index]``.

    Args:
        args: parsed namespace; ``args.func`` selects the solver and the whole
            namespace is passed to ``solve``.
        lines: raw JSON lines, each with ``"question"``, ``"options"`` and
            ``"correct"`` fields.
        dics: shared list of per-worker result slots.
        index: this worker's slot in *dics*.
        offset: position of ``lines[0]`` within the selected task range, used
            for the ``idx`` recorded per question. Defaults to
            ``len(lines) * index`` (the previous formula), which is only right
            when every worker received a chunk of the same length.

    Stores ``{'infos': [...], 'correct_list': [...]}`` in ``dics[index]``.

    Raises:
        ValueError: if ``args.func`` is not a supported solver.
    """
    # Checked up front: for an unknown func, `answer`/`info` were never bound
    # and the loop crashed with UnboundLocalError on the first question.
    if args.func not in _SUPPORTED_FUNCS:
        raise ValueError(f"Unsupported func {args.func!r}; expected one of {_SUPPORTED_FUNCS}")
    if offset is None:
        offset = len(lines) * index

    infos = []
    correct_list = []
    correct_count = 0
    for processed_count, line in enumerate(lines, start=1):
        idx = offset + processed_count - 1
        data = json.loads(line)
        question = build_aqua_question(data)
        standard_answer = data["correct"]

        print('*************************')
        print(f"No. {idx} data")
        answer, info = solve(question, args, to_print=False)

        print("Correct answer:", standard_answer)
        info.update({'idx': idx, 'answer': answer, 'correct answer': standard_answer, 'usage_so_far': gpt_usage()})
        infos.append(info)
        # NOTE(review): guards against solve() returning None; otherwise the
        # membership test relies on answer's type (substring for str) — confirm.
        if answer is not None and standard_answer in answer:
            correct_list.append(idx)
            correct_count += 1

        current_accuracy = correct_count / processed_count
        print(f"{args.func} Current accuracy: {current_accuracy:.4f} ({correct_count}/{processed_count})")
        info['current accuracy'] = current_accuracy

    dics[index] = {'infos': infos, 'correct_list': correct_list}

if __name__ == '__main__':
    # Script entry point; importing this module does not start a run.
    main()



