import os

import jsonlines
import argparse
import re
import numpy as np
from collections import defaultdict
from transformers import AutoTokenizer
from math_verify import parse, verify
from math_evaluator import DatasetEvaluator


# Shared grader instance; verify_math uses it for answers that contain 'boxed'.
evaluator = DatasetEvaluator('debug')
# Tokenizer used by get_tl to measure lengths in tokens. It is loaded at import
# time from a hard-coded local checkpoint path.
tokenizer = AutoTokenizer.from_pretrained('/data/model_path/models/Qwen3-8B-Base')

def verify_math(gold, answer):
    """Return True when ``answer`` is graded as matching ``gold``.

    Answers containing the substring ``'boxed'`` are graded with
    ``evaluator.evaluate(answer, gold)``; all other answers are parsed with
    ``parse`` (gold first, then answer) and compared with ``verify``.
    A grade counts as correct only when it compares equal to 1.

    Any exception raised while grading is swallowed and reported as an
    incorrect answer (``False``).
    """
    try:
        if 'boxed' not in answer:
            # Plain answer: parse both sides, then compare symbolically.
            grade = verify(parse(gold), parse(answer))
        else:
            grade = evaluator.evaluate(answer, gold)
        return bool(grade == 1)
    except Exception:
        return False


def get_thought_answer(output):
    """Split a model response into its reasoning ("thought") and its answer.

    Args:
        output: The raw response string.

    Returns:
        A ``(thought, answer)`` tuple of strings:

        * No ``</think>`` in ``output``: ``('', output)`` (output unchanged).
        * Otherwise ``answer`` is the stripped text following the first
          ``</think>`` (up to a second ``</think>``, if any), and ``thought``
          is the stripped text between the first ``<think>`` and the
          following ``</think>``.
        * If ``</think>`` is present but ``<think>`` is not (e.g. the opening
          tag was part of the prompt rather than the response), everything
          before the first ``</think>`` is treated as the thought instead of
          raising ``IndexError``.
    """
    if '</think>' not in output:
        return '', output

    if '<think>' in output:
        thought = output.split('<think>')[1].split('</think>')[0].strip()
    else:
        # Opening tag missing: the reasoning starts at the very beginning.
        thought = output.split('</think>')[0].strip()
    answer = output.split('</think>')[1].strip()
    return thought, answer


def get_tl(item):
    """Return the token length used to measure ``item['response']``.

    The response is split with ``get_thought_answer``. The thought is
    tokenized when it is non-empty; otherwise the answer is tokenized.
    The result is the number of tokens produced by the module-level
    ``tokenizer``.
    """
    reasoning, final_answer = get_thought_answer(item['response'])
    # Fall back to the answer text when there is no reasoning to measure.
    measured_text = reasoning if reasoning else final_answer
    return len(tokenizer.tokenize(measured_text))

def get_raw_tl(file):
    """Return the mean token length (per ``get_tl``) over a jsonlines file.

    Records whose ``'response'`` is ``None`` or whose lower-cased response
    contains ``'clarification'`` are excluded from the average.
    """
    with jsonlines.open(file) as reader:
        records = list(reader)
    lengths = [
        get_tl(record)
        for record in records
        if record['response'] is not None
        and 'clarification' not in record['response'].lower()
    ]
    return np.mean(lengths)


def _mean_or_zero(values):
    """Return the mean of ``values``, or 0.0 when ``values`` is empty."""
    return float(np.mean(values)) if values else 0.0


def _safe_div(numerator, denominator):
    """Return ``numerator / denominator``, or 0.0 when the denominator is zero."""
    return numerator / denominator if denominator else 0.0


def report_res(file, base_tl):
    """Print per-level and overall clarification / accuracy / length statistics.

    Each record of the jsonlines file must provide ``item['response']`` (a
    string) and ``item['metadata']`` with the keys ``'level'``, ``'answer'``
    and, when ``'answer'`` is ``None``, ``'solution'``.

    A response whose answer part (see ``get_thought_answer``) contains
    ``'clarification'`` (case-insensitive) is counted as a clarification
    ("cls") response and is not graded. Every other response is graded with
    ``verify_math`` against the gold answer.

    Args:
        file: Path of the jsonlines file to report on.
        base_tl: Baseline token length; average lengths are also printed as a
            ratio to this value (the ratio is printed as 0.00 when it is zero).

    Returns:
        None. All results are printed to stdout. Empty groups are reported as
        0 rather than raising or printing ``nan``.
    """
    with jsonlines.open(file) as reader:
        data = list(reader)

    cls_count, solved_count, total_count = 0, 0, 0
    level_total_question = defaultdict(int)
    level_solved_question = defaultdict(int)
    cls_total_question = defaultdict(int)
    cls_thought_lengths = defaultdict(list)
    non_cls_total_question = defaultdict(int)
    non_cls_thought_lengths = defaultdict(list)
    non_cls_solved_thought_lengths = defaultdict(list)

    for item in data:
        total_count += 1
        metadata = item['metadata']
        level = metadata['level']
        # Prefer the short gold answer; fall back to the full solution text.
        solution = metadata['answer'] if metadata['answer'] is not None else metadata['solution']
        level_total_question[level] += 1
        _, answer = get_thought_answer(item['response'])
        tl = get_tl(item)

        if 'clarification' in answer.lower():
            # Clarification responses are counted but never graded.
            cls_count += 1
            cls_total_question[level] += 1
            cls_thought_lengths[level].append(tl)
            continue

        non_cls_total_question[level] += 1
        non_cls_thought_lengths[level].append(tl)
        if verify_math(solution, answer):
            level_solved_question[level] += 1
            solved_count += 1
            non_cls_solved_thought_lengths[level].append(tl)

    # Levels are ordered numerically; a missing (None) level sorts first.
    levels = sorted(level_total_question, key=lambda x: float(x) if x is not None else -1)
    for level in levels:
        level_total = level_total_question[level]
        level_cls = cls_total_question[level]
        level_non_cls = non_cls_total_question[level]
        level_solved = level_solved_question[level]
        print(
            f"Level: {level}. "
            f"Cls: {level_cls}/{level_total} -> {_safe_div(level_cls, level_total)*100:.2f}%\t"
            f"Solved: {level_solved}/{level_non_cls} -> {_safe_div(level_solved, level_non_cls)*100:.2f}%\t"
            f"Cls Thought Length: {_mean_or_zero(cls_thought_lengths[level]):.2f}\t"
            f"Non Cls Thought Length: {_mean_or_zero(non_cls_thought_lengths[level]):.2f}"
        )

    print(f"Total Cls: {cls_count}/{total_count} -> {_safe_div(cls_count, total_count)*100:.2f}%")
    # NOTE: the overall solved rate is relative to all questions, whereas the
    # per-level rate above is relative to non-clarification questions only.
    print(f"Total Solved: {solved_count}/{total_count} -> {_safe_div(solved_count, total_count)*100:.2f}%")

    all_cls_tl, all_non_cls_tl, all_non_cls_solved_tl = [], [], []
    for level in levels:
        all_cls_tl.extend(cls_thought_lengths[level])
        all_non_cls_tl.extend(non_cls_thought_lengths[level])
        all_non_cls_solved_tl.extend(non_cls_solved_thought_lengths[level])

    avg_cls_tl = _mean_or_zero(all_cls_tl)
    avg_non_cls_tl = _mean_or_zero(all_non_cls_tl)
    avg_non_cls_solved_tl = _mean_or_zero(all_non_cls_solved_tl)
    print(f"Avg Cls TL: {avg_cls_tl:.2f}/{base_tl:.2f} -> {_safe_div(avg_cls_tl, base_tl):.2f}")
    print(f"Avg Non Cls TL: {avg_non_cls_tl:.2f}/{base_tl:.2f} -> {_safe_div(avg_non_cls_tl, base_tl):.2f}")
    print(f"Avg Non Cls Solved TL: {avg_non_cls_solved_tl:.2f}/{base_tl:.2f} -> {_safe_div(avg_non_cls_solved_tl, base_tl):.2f}")
    # Longest clarification / non-clarification lengths (0 when a group is empty).
    print(max(all_cls_tl, default=0))
    print(max(all_non_cls_tl, default=0))

def main():
    """Command-line entry point.

    Computes the baseline token length from ``--input_path_raw`` with
    ``get_raw_tl`` and then prints the report for ``--input_path_cls`` with
    ``report_res``. Both arguments are required; argparse exits with a usage
    message when either is missing.
    """
    parser = argparse.ArgumentParser(
        description="Report clarification rate, accuracy and token lengths for model responses."
    )
    parser.add_argument(
        "--input_path_raw",
        type=str,
        required=True,
        help="jsonlines file whose responses define the baseline token length",
    )
    parser.add_argument(
        '--input_path_cls',
        type=str,
        required=True,
        help="jsonlines file of responses (with metadata) to report on",
    )
    args = parser.parse_args()

    base_tl = get_raw_tl(args.input_path_raw)
    report_res(args.input_path_cls, base_tl)


# Run the report only when this file is executed as a script.
if __name__ == '__main__':
    main()
