import os

import jsonlines
import argparse
import csv
import numpy as np
from collections import defaultdict
from math_verify.metric import math_metric
from math_verify.parser import LatexExtractionConfig, ExprExtractionConfig
from math_verify import parse, verify
from common.utils import get_tl_length
from math_evaluator import DatasetEvaluator


# Shared grader used by verify_math for answers containing "boxed"; it is
# constructed once, at import time.
# NOTE(review): the meaning of the 'debug' argument is defined by
# math_evaluator.DatasetEvaluator — confirm before changing it.
evaluator = DatasetEvaluator('debug')

def verify_math(gold, answer):
    """Return True if ``answer`` is graded as matching ``gold``, else False.

    Answers containing the substring ``'boxed'`` are graded by the
    module-level ``evaluator`` via ``evaluator.evaluate(answer, gold)``.
    All other answers are parsed with math_verify's ``parse`` and compared
    with ``verify``. A grade equal to 1 counts as correct.

    Any exception raised while grading is treated as an incorrect answer.
    This includes ``answer`` being None, where the substring test raises.
    """
    try:
        if 'boxed' not in answer:
            grade = verify(parse(gold), parse(answer))
        else:
            grade = evaluator.evaluate(answer, gold)
        return bool(grade == 1)
    except Exception:
        # Deliberately best-effort: unparsable or malformed answers count as wrong.
        return False

# def verify_math(gold, answer):
#     res = evaluator.evaluate(answer, gold)
#     if res == 1:
#         return True
#     else:
#         return False
# def verify_math(gold, answer):
#     try:
#         gold = parse(gold)
#         answer = parse(answer)
#         grade = verify(gold, answer)
#         # Create the verification function
#         # verify_func = math_metric(
#         #     gold_extraction_target=(ExprExtractionConfig(),),
#         #     pred_extraction_target=(ExprExtractionConfig(), LatexExtractionConfig()),
#         #     aggregation_function=max,
#         #     precision=6
#         # )
#         # # Use the verification function
#         # grade, extracted_answers = verify_func([gold], [answer])
#
#         if grade == 1:
#             return True
#         else:
#             # print(grade)
#             # print(extracted_answers)
#             # input()
#             return False
#
#     except Exception:
#         # print(1)
#         return False


def get_raw_tl(file):
    """Return the mean thought length of the eligible records in a JSONL file.

    A record is eligible when its ``answer`` is not None and its
    ``judge_res`` is falsy. The thought length of a record is the first
    value returned by ``get_tl_length``. If no record is eligible,
    ``np.mean`` of an empty list is returned, which is nan with a
    RuntimeWarning.
    """
    with jsonlines.open(file) as reader:
        records = list(reader)

    lengths = [
        get_tl_length(record)[0]
        for record in records
        if record['answer'] is not None and not record['judge_res']
    ]
    return np.mean(lengths)


def _safe_mean(values):
    """Return ``np.mean(values)``, or 0.0 when ``values`` is empty (avoids a nan + warning)."""
    return np.mean(values) if len(values) > 0 else 0.0


def _pct(numerator, denominator):
    """Return ``numerator / denominator`` as a percentage, or 0.0 when the denominator is 0."""
    return numerator / denominator * 100 if denominator else 0.0


def report_res(file, base_tl):
    """Print per-level and overall statistics for the records in JSONL ``file``.

    Records whose ``answer`` is None are skipped. Records with a truthy
    ``judge_res`` are tallied as "Cls". The remaining "Non Cls" records are
    graded with ``verify_math`` against ``metadata['answer']``, falling back
    to ``metadata['solution']`` when that is None.

    For each ``metadata['level']`` the function prints:
    - the Cls rate;
    - the solve rate among Non Cls records;
    - the mean thought lengths, taken from the first value of
      ``get_tl_length``.

    It then prints:
    - the overall rates;
    - the average thought lengths divided by ``base_tl``;
    - the summed cost, taken from the second value of ``get_tl_length``.

    A rate with a zero denominator is reported as 0.00% instead of raising
    ZeroDivisionError. An average over no records is reported as 0.00
    instead of nan.

    Args:
        file: Path of the JSONL file to report on.
        base_tl: Reference thought length that the averages are divided by.
            ``main`` computes it with ``get_raw_tl``.
    """
    with jsonlines.open(file) as reader:
        data = list(reader)

    cls_count, solved_count, total_count = 0, 0, 0
    costs = 0
    level_total_question = defaultdict(int)
    level_solved_question = defaultdict(int)
    cls_total_question = defaultdict(int)
    non_cls_total_question = defaultdict(int)
    cls_thought_lengths = defaultdict(list)
    non_cls_thought_lengths = defaultdict(list)
    non_cls_solved_thought_lengths = defaultdict(list)

    for item in data:
        if item['answer'] is None:
            continue
        total_count += 1
        level = item['metadata']['level']
        level_total_question[level] += 1
        tl, cost = get_tl_length(item)
        costs += cost

        if item['judge_res']:
            cls_count += 1
            cls_total_question[level] += 1
            cls_thought_lengths[level].append(tl)
            continue

        metadata = item['metadata']
        gold = metadata['answer'] if metadata['answer'] is not None else metadata['solution']
        non_cls_total_question[level] += 1
        non_cls_thought_lengths[level].append(tl)
        if verify_math(gold, item['answer']):
            level_solved_question[level] += 1
            solved_count += 1
            non_cls_solved_thought_lengths[level].append(tl)

    # A None level sorts first; every other level must be float-convertible.
    levels = sorted(level_total_question, key=lambda x: float(x) if x is not None else -1)
    for level in levels:
        total_n = level_total_question[level]
        cls_n = cls_total_question[level]
        non_cls_n = non_cls_total_question[level]
        solved_n = level_solved_question[level]
        print(
            f"Level: {level}. "
            f"Cls: {cls_n}/{total_n} -> {_pct(cls_n, total_n):.2f}%\t"
            f"Solved: {solved_n}/{non_cls_n} -> {_pct(solved_n, non_cls_n):.2f}%\t"
            f"Cls Thought Length: {_safe_mean(cls_thought_lengths[level]):.2f}\t"
            f"Non Cls Thought Length: {_safe_mean(non_cls_thought_lengths[level]):.2f}"
        )

    non_cls_count = total_count - cls_count
    print(f"Total Cls: {cls_count}/{total_count} -> {_pct(cls_count, total_count):.2f}%")
    print(f"Total Solved: {solved_count}/{non_cls_count} -> {_pct(solved_count, non_cls_count):.2f}%")

    all_cls_tl, all_non_cls_tl, all_non_cls_solved_tl = [], [], []
    for level in levels:
        all_cls_tl.extend(cls_thought_lengths[level])
        all_non_cls_tl.extend(non_cls_thought_lengths[level])
        all_non_cls_solved_tl.extend(non_cls_solved_thought_lengths[level])
    avg_cls_tl = _safe_mean(all_cls_tl)
    avg_non_cls_tl = _safe_mean(all_non_cls_tl)
    avg_non_cls_solved_tl = _safe_mean(all_non_cls_solved_tl)
    print(f"Avg Cls TL: {avg_cls_tl:.2f}/{base_tl:.2f} -> {avg_cls_tl/base_tl:.2f}")
    print(f"Avg Non Cls TL: {avg_non_cls_tl:.2f}/{base_tl:.2f} -> {avg_non_cls_tl/base_tl:.2f}")
    print(f"Avg Non Cls Solved TL: {avg_non_cls_solved_tl:.2f}/{base_tl:.2f} -> {avg_non_cls_solved_tl/base_tl:.2f}")
    print(costs)

def main():
    """Parse the CLI arguments and print the report.

    The baseline thought length is computed from ``--input_path_raw`` with
    ``get_raw_tl``. It is then passed to ``report_res`` along with
    ``--input_path_cls``. Both paths are required; omitting one now fails
    with a clear argparse error instead of a crash inside the file reader.
    """
    parser = argparse.ArgumentParser(
        description="Report judge/solve statistics and thought lengths "
                    "relative to a raw baseline."
    )
    parser.add_argument(
        "--input_path_raw", type=str, required=True,
        help="JSONL file used to compute the baseline thought length.",
    )
    parser.add_argument(
        '--input_path_cls', type=str, required=True,
        help="JSONL file with judged results to report on.",
    )
    args = parser.parse_args()

    base_tl = get_raw_tl(args.input_path_raw)
    report_res(args.input_path_cls, base_tl)


# Run the report only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
