import json
import re
import os

# Matches, in priority order: thousands-grouped numbers ("1,234,567", with an
# optional decimal part), plain decimals ("3.14") and bare integers ("42").
# Signs are not captured, so negative gold answers are never matched.
_NUMBER_RE = re.compile(r'\d{1,3}(?:,\d{3})+(?!\d)(?:\.\d+)?|\d+\.\d+|\d+')


def _canonical_number(text):
    """Return *text* without thousands separators or trailing decimal zeros."""
    cleaned = text.strip().replace(",", "")
    if "." in cleaned:
        # "18.00" -> "18", "0.50" -> "0.5"
        cleaned = cleaned.rstrip("0").rstrip(".")
    return cleaned


def eval_gsm(dataset, gold_data):
    """Count the generations that contain their gold answer.

    A generation is counted as correct when *any* number appearing in it
    equals the gold answer. Both sides are canonicalised before comparison,
    so formatting differences such as "1,000" vs "1000" or "18.00" vs "18"
    do not cause a mismatch.

    Args:
        dataset: Sequence of generated answer strings.
        gold_data: Sequence of gold answers, index-aligned with ``dataset``.
            Each entry is either the bare answer or a GSM8K-style solution
            whose final answer follows the last "####" marker.

    Returns:
        The number of correct generations. Entries are paired positionally;
        surplus entries on either side are ignored.
    """
    correct = 0

    for generated, gold_entry in zip(dataset, gold_data):
        gold = _canonical_number(gold_entry.split("####")[-1])
        candidates = {
            _canonical_number(token) for token in _NUMBER_RE.findall(generated)
        }

        if gold in candidates:
            correct += 1

    return correct


def load_gold(args):
    """Load the gold answers for ``args.dataset_name`` as a list of strings.

    * ``"gsm"``: reads ``./data/gsm/answer.json`` (a JSON list of strings)
      and keeps the text after the last "####" marker of each entry.
    * ``"gsm_ko"``: reads ``./data/gsm/ko_gsm.json`` (a JSON list of objects)
      and stringifies each object's ``answer`` field.
    * any other name: loads the ``'en'`` configuration of ``juletxara/mgsm``
      through the ``datasets`` library (cached under
      ``args.data_cache_dir``) and stringifies ``answer_number`` for the
      first 250 examples of its test split.

    Args:
        args: Object exposing ``dataset_name`` and, for the fallback branch,
            ``data_cache_dir``.

    Returns:
        A list of gold answers as strings.
    """
    name = args.dataset_name

    if name == "gsm":
        with open('./data/gsm/answer.json', 'r', encoding="utf-8") as handle:
            raw_answers = json.load(handle)
        return [entry.split("####")[-1].strip() for entry in raw_answers]

    if name == "gsm_ko":
        with open('./data/gsm/ko_gsm.json', 'r', encoding="utf-8") as handle:
            records = json.load(handle)
        return [str(record['answer']) for record in records]

    # Fallback: every other dataset name is treated as English MGSM.
    from datasets import load_dataset
    mgsm = load_dataset("juletxara/mgsm", 'en', cache_dir=args.data_cache_dir)
    return [str(mgsm['test'][idx]['answer_number']) for idx in range(250)]


def get_acc(args):
    """Score every not-yet-evaluated generation file and record its accuracy.

    Generation files are read from
    ``./results/gen_text/<dataset_name>/<model_name>_<model_size>/`` and one
    ``"<key> :: <accuracy>"`` line per file is appended to
    ``./results/acc/<dataset_name>/<model_name>_<model_size>/acc_results.txt``.
    Files whose key is already present in the results file are skipped, so
    the function can be re-run safely.

    Accuracy is the percentage of gold answers matched, rounded to two
    decimals. The denominator is the number of gold answers loaded for the
    dataset, so generations missing from a file count as wrong.

    Args:
        args: Object exposing ``dataset_name``, ``model_name`` and
            ``model_size``, plus whatever ``load_gold`` reads from it.

    Returns:
        None. Results are written to disk.
    """
    run_tag = f"{args.model_name}_{args.model_size}"

    save_dir = f"./results/acc/{args.dataset_name}/{run_tag}/"
    os.makedirs(save_dir, exist_ok=True)
    savepath = os.path.join(save_dir, "acc_results.txt")

    if os.path.exists(savepath):
        with open(savepath, 'r', encoding='utf-8') as f:
            exists_lines = f.readlines()
    else:
        # Create the (empty) results file up front so it exists even when
        # there is nothing to evaluate.
        with open(savepath, 'w', encoding='utf-8') as f:
            f.write("")
        exists_lines = []

    # Keys already scored in a previous run: the text before "::" on each line.
    exists_keys = {line.split("::")[0].strip() for line in exists_lines}

    base_path = f"./results/gen_text/{args.dataset_name}/{run_tag}/"
    gold_data = load_gold(args)
    # Use the real gold-set size rather than a fixed GSM8K constant, so that
    # datasets of a different size (e.g. 250 MGSM items) are scored correctly.
    total = len(gold_data)

    # Sorted for a deterministic order of appended result lines.
    for file_name in sorted(os.listdir(base_path)):
        # Only JSON generation files are scored; prompt dumps are skipped.
        if 'json' not in file_name or 'prompt' in file_name:
            continue

        key = f"{args.dataset_name}_{args.model_name}_{args.model_size}_{file_name}"
        if key in exists_keys:
            continue

        with open(os.path.join(base_path, file_name), 'r', encoding="utf-8") as f:
            gen = json.load(f)

        cor = eval_gsm(gen, gold_data)
        acc = round((cor / total) * 100, 2) if total else 0.0

        with open(savepath, 'a', encoding='utf-8') as f:
            f.write(f"{key} :: {acc}\n")