import threading
import argparse
import os
import json
from models import *
from ARCA import *


def is_answer_correct(given_answer, correct_answer):
    """Return True if *given_answer* matches *correct_answer*.

    Both values are first compared numerically: if both convert with
    ``float()``, they match when they differ by less than ``1e-3``.
    Otherwise (non-numeric text, or a value such as ``None`` that ``float()``
    rejects with ``TypeError``), the answer matches when the string form of
    *correct_answer* occurs as a substring of the string form of
    *given_answer*.

    Args:
        given_answer: The answer produced by the solver (any type).
        correct_answer: The reference answer (any type).

    Returns:
        bool: whether the answer is considered correct.
    """
    try:
        given_num = float(given_answer)
        correct_num = float(correct_answer)
    except (ValueError, TypeError):
        # TypeError covers non-string/non-number inputs such as None, which
        # previously escaped and killed the calling worker thread.
        return str(correct_answer) in str(given_answer)
    return abs(given_num - correct_num) < 1e-3


def main():
    """Parse command-line options and run a single evaluation pass (id 1).

    ``run`` and ``process_data`` read every option declared here from the
    parsed namespace. The original parser declared only ``--func``, so every
    invocation crashed with ``AttributeError``.
    """
    parser = argparse.ArgumentParser(description='GSM8K dataset')
    parser.add_argument('--func', type=str, default='ARCA', help='Function to execute')
    parser.add_argument('--dataset_path', type=str, required=True,
                        help='Path to the JSON dataset (a list of {"question", "answer"} items)')
    parser.add_argument('--begin_task', type=int, default=0,
                        help='Index of the first dataset item to evaluate (inclusive)')
    parser.add_argument('--end_task', type=int, required=True,
                        help='Index one past the last dataset item to evaluate (exclusive)')
    parser.add_argument('--thread_n', type=int, default=1,
                        help='Number of worker threads')
    # NOTE(review): these values are only interpolated into the log file name
    # here; presumably solve() from ARCA consumes them too. The defaults are
    # placeholders, so confirm the intended values (and whether solve() reads
    # further options not declared here).
    parser.add_argument('--n_generate_sample', type=int, default=1,
                        help='Number of samples to generate')
    parser.add_argument('--n_select_sample', type=int, default=1,
                        help='Number of samples to select')
    parser.add_argument('--n_evaluate_time', type=int, default=1,
                        help='Number of evaluation rounds')
    parser.add_argument('--shuffle', action='store_true',
                        help='Shuffle flag (recorded in the log file name)')
    args = parser.parse_args()

    run(args, 1)


def _split_into_chunks(items, n_chunks):
    """Split *items* into *n_chunks* contiguous slices that cover every element.

    Each slice holds ``ceil(len(items) / n_chunks)`` elements, except that the
    last non-empty slice may be shorter and any slices after it are empty.
    """
    step = -(-len(items) // n_chunks)  # ceiling division
    return [items[k * step:(k + 1) * step] for k in range(n_chunks)]


def run(args, id):
    """Evaluate ``args.func`` on ``data[begin_task:end_task]`` using worker threads.

    The dataset slice is split across ``args.thread_n`` threads, each running
    ``process_data``. Per-item infos are written as JSON to a log file under
    ``logs/<func>/``. A summary dict (settings, correct indices, ``gpt_usage()``
    fields, accuracy) is appended as one JSON line to ``results/<func>.jsonl``.

    Args:
        args: Parsed CLI namespace (see ``main``).
        id: Run identifier, stored in the summary and used in the log file name.

    Raises:
        ValueError: if ``args.thread_n < 1`` or the selected range is empty.
    """
    print('*****************************')
    print(args)
    print('*****************************')

    if args.thread_n < 1:
        raise ValueError(f"thread_n must be >= 1, got {args.thread_n}")

    log_file = f"logs/{args.func}/{args.begin_task}-{args.end_task}_generate{args.n_generate_sample}_select{args.n_select_sample}_evaluate{args.n_evaluate_time}_shuffle{args.shuffle}({id}).json"
    result_file = f"results/{args.func}.jsonl"
    os.makedirs(os.path.dirname(log_file), exist_ok=True)
    os.makedirs(os.path.dirname(result_file), exist_ok=True)

    with open(args.dataset_path, "r", encoding="utf-8") as input_file:
        data = json.load(input_file)

    # Slice once so the accuracy denominator is the number of items actually
    # evaluated, even if end_task runs past the end of the dataset.
    subset = data[args.begin_task:args.end_task]
    total_num = len(subset)
    if total_num == 0:
        raise ValueError(
            f"No items selected: begin_task={args.begin_task}, "
            f"end_task={args.end_task}, dataset size={len(data)}"
        )

    # Every item is assigned to some thread. The previous step-based slicing
    # dropped the remainder when total_num was not divisible by thread_n, and
    # crashed (step 0) when total_num < thread_n.
    # NOTE(review): process_data labels items as i + len(chunk) * index, which
    # is exact for the full-size chunks. Items in a shorter final chunk may get
    # imprecise 'idx' labels; correct_num and accuracy are unaffected.
    chunks = _split_into_chunks(subset, args.thread_n)
    dics = [{} for _ in range(args.thread_n)]

    threads = []
    for i, chunk in enumerate(chunks):
        thread = threading.Thread(target=process_data, args=(args, chunk, dics, i))
        threads.append(thread)
        thread.start()
    for thread in threads:
        thread.join()

    result_dict = {
        'id': id,
        "begin": args.begin_task,
        "end": args.end_task,
        "generate": args.n_generate_sample,
        "select": args.n_select_sample,
        "evaluate": args.n_evaluate_time,
        'correct_list': [],
        'correct_num': 0,
    }
    result_dict.update(gpt_usage())
    infos = []
    for d in dics:
        infos.extend(d['infos'])
        result_dict['correct_list'].extend(d['correct_list'])
    result_dict['correct_num'] = len(result_dict['correct_list'])
    result_dict['accuracy'] = result_dict['correct_num'] * 1.0 / total_num
    print("Accuracy:", result_dict['accuracy'])

    with open(log_file, 'w', encoding="utf-8") as f:
        json.dump(infos, f, indent=4)
    with open(result_file, 'a', encoding="utf-8") as f:
        f.write(json.dumps(result_dict) + "\n")


def process_data(args, data_chunk, dics, index):
    """Solve every item of *data_chunk* and store the outcome in ``dics[index]``.

    Used as a ``threading.Thread`` target. Each item must provide
    ``"question"`` and ``"answer"`` keys. For each item, ``solve`` returns
    ``(answer, info)``. The ``info`` dict is augmented with the item index,
    the answers, ``gpt_usage()`` and the running accuracy, then collected.

    Args:
        args: Parsed CLI namespace; ``args.func`` selects the solver.
        data_chunk: List of dataset items assigned to this worker.
        dics: Shared list of per-worker result slots.
        index: This worker's slot in *dics* (also used to derive item labels).

    Side effects:
        Sets ``dics[index] = {'infos': [...], 'correct_list': [...]}``.

    Raises:
        ValueError: if ``args.func`` is not a supported solver (only 'ARCA').
    """
    if args.func != 'ARCA':
        # Previously this fell through and crashed with UnboundLocalError on
        # `answer`/`info`; fail fast with an explicit message instead.
        raise ValueError(f"Unsupported func {args.func!r}; only 'ARCA' is supported")

    chunk_len = len(data_chunk)
    infos = []
    correct_list = []
    for i, item in enumerate(data_chunk):
        # Label = position within this chunk plus an offset derived from this
        # chunk's own length; exact only when every preceding chunk has the
        # same length as this one.
        idx = i + chunk_len * index
        question = item["question"]
        standard_answer = item["answer"]

        print('*************************')
        print(f"No. {idx} data")
        answer, info = solve(question, args, to_print=False)

        print("Correct answer:", standard_answer)
        info.update({'idx': idx, 'answer': answer, 'correct answer': standard_answer,
                     'usage_so_far': gpt_usage()})
        infos.append(info)
        if is_answer_correct(answer, standard_answer):
            correct_list.append(idx)
        current_accuracy = len(correct_list) / (i + 1)
        print(f'{args.func} current accuracy', current_accuracy)
        info.update({'current accuracy': current_accuracy})

    dics[index] = {'infos': infos, 'correct_list': correct_list}


# Run the command-line entry point only when executed as a script, not on import.
if __name__ == "__main__":
    main()
