import os

import jsonlines
import argparse
import csv
import numpy as np
from collections import defaultdict
from math_verify.metric import math_metric
from math_verify.parser import LatexExtractionConfig, ExprExtractionConfig
from math_verify import parse, verify
from common.utils import get_tl_length
from math_evaluator import DatasetEvaluator
from scipy import stats


# Module-level evaluator instance shared by verify_math, which uses it to
# grade answers that contain the substring 'boxed'.
# NOTE(review): the meaning of the 'debug' argument is defined by
# DatasetEvaluator (not visible here) — confirm before changing it.
evaluator = DatasetEvaluator('debug')

def verify_math(gold, answer):
    """Return True when ``answer`` is graded as matching ``gold``, else False.

    Answers containing the substring ``'boxed'`` are graded by the module-level
    ``evaluator``; every other answer is parsed together with the reference
    and compared with ``verify``. Grading is best-effort: any exception raised
    while parsing or comparing is treated as "not correct".

    Args:
        gold: Reference solution.
        answer: Model answer to grade.

    Returns:
        bool: True only if the grade compares equal to 1.
    """
    try:
        if 'boxed' not in answer:
            parsed_gold = parse(gold)
            parsed_answer = parse(answer)
            outcome = verify(parsed_gold, parsed_answer)
        else:
            outcome = evaluator.evaluate(answer, gold)
        # Both graders are reduced to a plain bool; a grade equal to 1 means correct.
        return bool(outcome == 1)
    except Exception:
        return False


def level_map(level):
    """Collapse a raw difficulty level into one of five integer buckets.

    Mapping: 1 and 1.5 -> 1, 2 and 2.25 -> 2, 2.5 -> 3, 3 and 3.5 -> 4, 4 -> 5.

    Args:
        level: Raw level value.

    Returns:
        int: Bucket index between 1 and 5.

    Raises:
        Exception: If ``level`` is not one of the known raw values.
    """
    buckets = (
        ((1, 1.5), 1),
        ((2, 2.25), 2),
        ((2.5,), 3),
        ((3, 3.5), 4),
        ((4,), 5),
    )
    for raw_values, bucket in buckets:
        if level in raw_values:
            return bucket
    raise Exception(f'Unknown level: {level}')


def get_raw_tl(file):
    """Collect the length values of all usable records in a jsonlines file.

    Records whose ``answer`` is ``None`` or contains "clarification"
    (case-insensitive) are skipped. For the rest, the first element returned
    by ``get_tl_length`` is recorded, both globally and per level. When the
    path contains ``'omni'`` the level is first bucketed with ``level_map``.

    Args:
        file: Path of the jsonlines file to read.

    Returns:
        tuple: ``(mean_length, lengths_by_level)`` where ``lengths_by_level``
        is a ``defaultdict(list)`` keyed by level.
    """
    with jsonlines.open(file) as reader:
        records = list(reader)

    lengths = []
    lengths_by_level = defaultdict(list)
    for record in records:
        reply = record['answer']
        if reply is None:
            continue
        if 'clarification' in reply.lower():
            continue
        length = get_tl_length(record)[0]
        lengths.append(length)
        if 'omni' in file:
            bucket = level_map(record['metadata']['level'])
        else:
            bucket = record['metadata']['level']
        lengths_by_level[bucket].append(length)
    return np.mean(lengths), lengths_by_level


def report_res(file, base_tl):
    """Summarise clarification rate and relative thought lengths for one file.

    Records whose ``answer`` is ``None`` are skipped. Every other record is
    classified as a "clarification" record when its answer contains the word
    "clarification" (case-insensitive), otherwise as a regular record. The
    ``(length, cost)`` pair returned by ``get_tl_length`` is collected for
    both kinds.

    Args:
        file: Path of the jsonlines result file. When the path contains
            ``'omni'`` the record levels are bucketed with ``level_map``.
        base_tl: Reference length the mean lengths are divided by.

    Returns:
        tuple: ``(cls_ratio, cls_tl_ratio, non_cls_tl_ratio, task_set,
        num_records, costs)`` where

        * ``cls_ratio`` is the share of answered records that are
          clarification records (NaN when no record has an answer),
        * ``cls_tl_ratio`` / ``non_cls_tl_ratio`` are the mean lengths of the
          clarification / regular records divided by ``base_tl`` (NaN when
          the respective group is empty),
        * ``task_set`` is the set of ``item['task']`` values of the
          clarification records,
        * ``num_records`` is the number of records in the file, including
          the skipped ones,
        * ``costs`` is the sum of the cost values returned by
          ``get_tl_length``.

    Raises:
        Exception: Propagated from ``level_map`` for an unknown level in an
            'omni' file.
    """
    with jsonlines.open(file) as reader:
        data = list(reader)

    # The path decides once, for the whole file, whether levels are bucketed.
    is_omni = 'omni' in file
    cls_count, total_count = 0, 0
    seen_levels = {}  # dict used as an insertion-ordered set of levels
    cls_thought_lengths = defaultdict(list)
    non_cls_thought_lengths = defaultdict(list)
    task_set = set()
    costs = 0
    for item in data:
        if item['answer'] is None:
            continue
        total_count += 1
        raw_level = item['metadata']['level']
        level = level_map(raw_level) if is_omni else raw_level
        seen_levels[level] = None

        tl, cost = get_tl_length(item)
        costs += cost
        if 'clarification' in item['answer'].lower():
            task_set.add(item['task'])
            cls_count += 1
            cls_thought_lengths[level].append(tl)
        else:
            non_cls_thought_lengths[level].append(tl)

    # Concatenate the per-level lists in ascending level order; records
    # without a level (None) sort first.
    levels = sorted(seen_levels, key=lambda x: float(x) if x is not None else -1)
    all_cls_tl, all_non_cls_tl = [], []
    for level in levels:
        all_cls_tl.extend(cls_thought_lengths[level])
        all_non_cls_tl.extend(non_cls_thought_lengths[level])

    # Avoid ZeroDivisionError when the file holds no answered record; NaN is
    # what the two length ratios below already yield for an empty group.
    cls_ratio = cls_count / total_count if total_count else float('nan')
    return (
        cls_ratio,
        np.mean(all_cls_tl) / base_tl,
        np.mean(all_non_cls_tl) / base_tl,
        task_set,
        len(data),
        costs,
    )

def mean_confidence_interval(data, confidence=0.95):
    """Return the sample mean and the half-width of its confidence interval.

    The half-width is the two-sided Student-t critical value for
    ``len(data) - 1`` degrees of freedom multiplied by the standard error of
    the mean (sample standard deviation with ``ddof=1`` over ``sqrt(n)``).
    The interval itself is ``mean - half_width .. mean + half_width``.

    Args:
        data: One-dimensional numpy array or other iterable of numbers.
        confidence: Confidence level, 0.95 by default.

    Returns:
        tuple: ``(mean, half_width)``.
    """
    samples = np.array(data)
    count = len(samples)
    center = np.mean(samples)
    std_err = np.std(samples, ddof=1) / np.sqrt(count)
    tail = 1 - confidence
    quantile = stats.t.ppf(1 - tail / 2, df=count - 1)
    return center, quantile * std_err

def main():
    """Aggregate clarification statistics over several runs and print them.

    Both path arguments are templates: each is expanded with
    ``str.format(i)`` for every run index ``i`` in ``range(num_runs)``.
    The mean length over the raw files is used as the reference length for
    ``report_res``.

    Printed, in order:
        * ``mean_confidence_interval`` of the per-run clarification ratios,
        * ``mean_confidence_interval`` of the per-run clarification length ratios,
        * ``mean_confidence_interval`` of the per-run non-clarification length ratios,
        * number of distinct clarification tasks over all runs, and the
          record count of the last processed file,
        * total cost summed over all runs.
    """
    parser = argparse.ArgumentParser()
    # Both templates are required: without them the .format() calls below
    # would fail on None with an opaque AttributeError.
    parser.add_argument(
        "--input_path_raw", type=str, required=True,
        help="path template of the raw result files; formatted with the run index",
    )
    parser.add_argument(
        '--input_path_cls', type=str, required=True,
        help="path template of the result files to analyse; formatted with the run index",
    )
    parser.add_argument(
        '--num_runs', type=int, default=5,
        help="number of runs to aggregate (run indices 0..num_runs-1)",
    )
    args = parser.parse_args()
    if args.num_runs < 1:
        parser.error("--num_runs must be at least 1")
    run_ids = range(args.num_runs)

    # Reference length: mean over the per-run means of the raw files.
    base_tl = np.mean([get_raw_tl(args.input_path_raw.format(i))[0] for i in run_ids])

    crs, tlcs, tlncs = [], [], []
    crts = set()
    len_data = 0
    costs = 0
    for i in run_ids:
        cr, tlc, tlnc, crt, len_data, cost = report_res(args.input_path_cls.format(i), base_tl)
        crs.append(cr)
        tlcs.append(tlc)
        tlncs.append(tlnc)
        crts.update(crt)
        costs += cost
    print(mean_confidence_interval(crs))
    print(mean_confidence_interval(tlcs))
    print(mean_confidence_interval(tlncs))
    print(len(crts), len_data)
    print(costs)


# Run the aggregation only when this file is executed as a script.
if __name__ == '__main__':
    main()
