import os
import json
import re
from transformers import AutoTokenizer
from tqdm import tqdm
from concurrent.futures import ProcessPoolExecutor, as_completed

# Directory that is listed for prediction files; the summary table is also
# written into it.
RESULIT_PATH = "./results"

# Sample ids that can be excluded from scoring.
skip_id = {
    '671b1480bb02136c067d4f14', '6713066fbb02136c067d3214', '66fb6d71bb02136c067c7c34',
    '66ec3a46821e116aacb1c49f', '67040276bb02136c067cd8ae', '66ec41d3821e116aacb1c874',
    '66fcf2f2bb02136c067c9169', '671b1335bb02136c067d4e88', '66fa7f1dbb02136c067c6e6b',
    '66eae4de5a08c7b9b35dd12d', '66f39aa7821e116aacb2da76', '66f56c9c821e116aacb33a45',
    '66ed1556821e116aacb1ea14', '66f958b3bb02136c067c5219', '66ec356a821e116aacb1c22b',
    '66ed910a821e116aacb2033b', '66fe8eb3bb02136c067ca35f', '66ebd5125a08c7b9b35e0616',
    '66ec3d1d821e116aacb1c622', '66f55d66821e116aacb33734', '66f578fa821e116aacb33c58',
}
# NOTE(review): the list above is overwritten right away, so no sample is
# actually skipped. Remove this line to re-enable the exclusion list.
skip_id = set()
print(f"skip num::{len(skip_id)}")

# Every entry of the results directory; non-JSON entries are filtered out later.
files = os.listdir(RESULIT_PATH)

# When True, a sample with no extractable answer is credited 0.25 instead of 0.
compensated = False

# Result table: the first row is the header, one row per result file follows.
header = ["Model", "Fail", "Overall", "Easy", "Hard", "Short", "Medium", "Long", "shorter_32k", "longer_32k"]
data = [header]

# An option letter that is not part of a longer word, e.g. the "B" in
# "\text{B" or in "(B)".
_STANDALONE_OPTION = re.compile(r'(?<![A-Za-z])[A-D](?![A-Za-z])')


def extract_boxed_option(response):
    r"""Return the option picked in the last ``\boxed{...}`` of *response*.

    Asterisks (markdown emphasis) are stripped before matching. The boxed
    content is the text between ``\boxed{`` and the first ``}`` that follows.

    Resolution order for the boxed content:

    1. If it starts with ``A``-``D``, that letter is returned.
    2. Otherwise the first standalone ``A``-``D`` letter is returned, so
       wrapped answers such as ``\boxed{\text{B}}``, ``\boxed{(B)}`` or
       ``\boxed{ B }`` resolve to ``B`` instead of ``\``, ``(`` or a space.
    3. Otherwise its first character is returned unchanged.

    Args:
        response: Raw model output.

    Returns:
        A one-character string, ``''`` when the last box is empty, or
        ``None`` when *response* contains no closed ``\boxed{...}``.
    """
    response = response.replace('*', '')
    matches = re.findall(r'\\boxed\{([^}]*)\}', response)
    if not matches:
        return None

    content = matches[-1]
    if not content:
        # Empty box: hand back the empty string so callers see a falsy value.
        return content
    if content[0] in 'ABCD':
        return content[0]

    # The first character is not an option letter (LaTeX wrapper, bracket,
    # whitespace, ...): look for an option letter standing on its own.
    standalone = _STANDALONE_OPTION.search(content)
    if standalone:
        return standalone.group(0)
    return content[0]

def extract_answer(response):
    """Pull the option letter out of a "The correct answer is ..." sentence.

    Asterisks are removed first so markdown emphasis does not break the
    match. The parenthesised form ``(X)`` is searched through the whole text
    before the bare form ``X`` is tried.

    Args:
        response: Raw model output.

    Returns:
        The letter ``A``-``D`` that was found, or ``None`` when neither
        pattern occurs.
    """
    cleaned = response.replace('*', '')
    # Order matters: a parenthesised answer anywhere wins over a bare one.
    patterns = (
        r'The correct answer is \(([A-D])\)',
        r'The correct answer is ([A-D])',
    )
    for pattern in patterns:
        found = re.search(pattern, cleaned)
        if found is not None:
            return found.group(1)
    return None

def process_pred(pred):
    """Score one prediction record against its reference answer.

    The predicted option is taken from a ``\\boxed{...}`` expression when the
    response contains one, otherwise from a "The correct answer is ..."
    sentence.

    Args:
        pred: Mapping with the keys ``_id``, ``response``, ``answer``,
            ``difficulty`` and ``length``. A ``response`` of ``None`` is
            treated as an empty response.

    Returns:
        ``None`` when the record's id is in ``skip_id``; otherwise a dict with
        ``difficulty`` and ``length`` (copied from *pred*), ``acc`` (1 or 0,
        or 0.25 when ``compensated`` is set and no answer was extracted),
        ``context_tokens`` and ``pred`` (the extracted option, possibly
        ``None`` or ``''``).
    """
    if pred['_id'] in skip_id:
        return None

    # Context length is not measured here; every record reports 0 tokens.
    context_tokens = 0

    response = pred['response']
    if response is None:
        # A null response would make the `in` test below raise TypeError;
        # score it as "no answer" instead.
        response = ''

    if "\\boxed{" in response:
        _pred = extract_boxed_option(response)
    else:
        _pred = extract_answer(response)

    acc = int(_pred == pred['answer'])
    if compensated and _pred is None:
        acc = 0.25

    return {
        'difficulty': pred['difficulty'],
        'length': pred['length'],
        'acc': acc,
        'context_tokens': context_tokens,
        'pred': _pred
    }


# Bucket names in the same order as the per-bucket columns of the table
# (the columns that follow "Model", "Fail" and "Overall").
BUCKETS = ("easy", "hard", "short", "medium", "long", "shorter_32k", "longer_32k")


def load_predictions(path):
    """Load prediction records from a JSON array file or a JSON-Lines file.

    The file is first parsed as one JSON document. If that fails it is read
    line by line as JSON Lines, ignoring blank lines.

    Args:
        path: Path of a UTF-8 encoded ``.json`` / ``.jsonl`` file.

    Returns:
        A list of records. A file holding a single JSON object (for example a
        one-line JSONL file) yields a one-element list.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    with open(path, encoding='utf-8') as f:
        try:
            records = json.load(f)
        except json.JSONDecodeError:
            f.seek(0)
            return [json.loads(line) for line in f if line.strip()]
    # Iterating a dict would walk its keys, not records, so wrap it.
    return [records] if isinstance(records, dict) else records


def percentage(acc_sum, count):
    """Return ``100 * acc_sum / count`` rounded to one decimal, or 0 if *count* is 0."""
    return round(100 * acc_sum / count, 1) if count else 0


def tally_results(results):
    """Aggregate per-sample results into per-bucket counts and accuracy sums.

    Each non-``None`` result is counted in three buckets: easy/hard (anything
    that is not ``"easy"`` is hard), short/medium/long (anything that is not
    ``"short"`` or ``"medium"`` is long) and shorter_32k/longer_32k
    (``context_tokens <= 33000`` is shorter).

    Args:
        results: Iterable of dicts with the keys ``difficulty``, ``length``,
            ``acc``, ``context_tokens`` and ``pred``; ``None`` items are
            ignored.

    Returns:
        ``(counts, acc_sums, total, fail)`` where ``counts`` and ``acc_sums``
        are dicts keyed by ``BUCKETS``, ``total`` is the number of scored
        samples and ``fail`` the number whose ``pred`` is falsy.
    """
    counts = dict.fromkeys(BUCKETS, 0)
    acc_sums = dict.fromkeys(BUCKETS, 0)
    total = 0
    fail = 0
    for result in results:
        if result is None:
            continue
        sample_buckets = (
            "easy" if result['difficulty'] == "easy" else "hard",
            "shorter_32k" if result['context_tokens'] <= 33000 else "longer_32k",
            result['length'] if result['length'] in ("short", "medium") else "long",
        )
        for bucket in sample_buckets:
            counts[bucket] += 1
            acc_sums[bucket] += result['acc']
        total += 1
        if not result['pred']:
            fail += 1
    return counts, acc_sums, total, fail


def main():
    """Score every result file and write the summary table to ``result.txt``.

    Appends one row per ``.json`` / ``.jsonl`` file to the module-level
    ``data`` table, prints a short per-file summary, and finally writes the
    column-aligned table to ``result.txt`` inside ``RESULIT_PATH``.
    """
    # One worker pool serves all files instead of being rebuilt per file.
    with ProcessPoolExecutor() as executor:
        # os.listdir() order is arbitrary; sort so the table is reproducible.
        for file in sorted(files):
            filename = os.path.join(RESULIT_PATH, file)
            if not filename.endswith((".jsonl", ".json")):
                print(f"skip file::{filename}")
                continue
            pred_data = load_predictions(filename)

            futures = [executor.submit(process_pred, pred) for pred in pred_data]
            progress = tqdm(as_completed(futures), total=len(futures))
            counts, acc_sums, total_num, fail = tally_results(
                future.result() for future in progress
            )

            name = '.'.join(file.split('.')[:-1])
            # percentage() also covers files with no scored sample, which
            # previously raised ZeroDivisionError for the overall column.
            overall_acc = percentage(acc_sums["easy"] + acc_sums["hard"], total_num)
            row = [name, str(fail), str(overall_acc)]
            row += [str(percentage(acc_sums[bucket], counts[bucket])) for bucket in BUCKETS]
            data.append(row)

            print(f"{name} :: fail num::{fail}, valid_num::{total_num}, short_num::{counts['short']}")

    # Left-align every cell to the widest entry of its column.
    column_widths = [max(len(str(cell)) for cell in column) for column in zip(*data)]
    output = [
        "  ".join(f"{str(cell):<{width}}" for cell, width in zip(row, column_widths))
        for row in data
    ]

    with open(f'{RESULIT_PATH}/result.txt', 'w', encoding='utf-8') as f:
        f.write('\n'.join(output) + '\n')


# Worker processes started with the "spawn" or "forkserver" method re-import
# this module; without the guard each worker would run the scoring loop again.
if __name__ == "__main__":
    main()
