import os
import re
import json
import shutil
import datetime
from tqdm import tqdm
from openai import OpenAI
from collections import defaultdict

# ========== Configuration ==========
# API key for the judge client. Taken from OPENAI_API_KEY when it is set; otherwise
# the "EMPTY" placeholder is used (passing api_key explicitly disables the SDK's
# own environment lookup, so the fallback has to be done here).
openai_api_key = os.environ.get("OPENAI_API_KEY", "EMPTY")
# Prompt template file read by load_judge_prompt().
judge_prompt_path = "./make_data/prompt/model_judge.txt"
# Model name sent with every judge chat-completion request.
judge_model = "gpt-4o"

# Shared OpenAI client used for all judge calls in this module.
judge_client = OpenAI(
    api_key=openai_api_key,
)

def load_judge_prompt():
    """Return the judge prompt template stored at ``judge_prompt_path``.

    The file is decoded as UTF-8 and leading/trailing whitespace is removed.
    """
    with open(judge_prompt_path, encoding="utf-8") as prompt_file:
        contents = prompt_file.read()
    return contents.strip()

def model_judge(pred: str, gt: str, judge_prompt: str) -> bool:
    """Ask the judge LLM whether a predicted answer and the ground truth agree.

    Args:
        pred: Predicted answer text, substituted for ``{pred}`` in the template.
        gt: Ground-truth answer text, substituted for ``{gt}`` in the template.
        judge_prompt: Prompt template formatted with ``pred`` and ``gt`` keyword
            arguments (``str.format`` raises ``KeyError`` if it references any
            other placeholder).

    Returns:
        True if the judge's reply contains the substring "same"
        (case-insensitive); False otherwise, including when the reply is empty.

    NOTE(review): substring matching would also accept a reply such as
    "not the same"; this relies on the prompt constraining the reply format —
    confirm against the prompt file.
    """
    msg = judge_prompt.format(pred=pred, gt=gt)
    resp = judge_client.chat.completions.create(
        model=judge_model,
        messages=[{"role": "system", "content": msg}],
    )
    # message.content is Optional in the SDK; treat a missing reply as "not same"
    # instead of crashing on None.lower().
    result = (resp.choices[0].message.content or "").lower()
    return "same" in result


# Captures the body of the first <answer>...</answer> tag (case-insensitive,
# may span lines), without surrounding whitespace.
_ANSWER_TAG_RE = re.compile(r'<answer>\s*(.*?)\s*</answer>', re.IGNORECASE | re.DOTALL)


def _to_float(value):
    """Return ``float(value)``, or None when ``value`` is not numeric (including None)."""
    try:
        return float(value)
    except (TypeError, ValueError):
        return None


def _numbers_match(p_val: float, g_val: float, tolerance: float) -> bool:
    """Return True when ``p_val / g_val`` lies in ``[1 - tolerance, 1 + tolerance]``.

    A zero ground truth has no usable ratio, so it only matches an exact zero.
    """
    if g_val == 0:
        return p_val == 0
    return (1 - tolerance) <= (p_val / g_val) <= (1 + tolerance)


def _ask_judge_for_number(answer: str, judge_prompt: str):
    """Send ``answer`` to the judge LLM and parse its reply as a float (None if it is not).

    ``judge_prompt`` is formatted with a single ``answer`` keyword argument.
    """
    msg = judge_prompt.format(answer=answer)
    resp = judge_client.chat.completions.create(
        model=judge_model,
        messages=[{"role": "system", "content": msg}],
    )
    # message.content is Optional in the SDK; a missing reply counts as non-numeric.
    return _to_float(resp.choices[0].message.content)


def evaluate_answer(raw_pred_ans: str, gt_ans: str, judge_prompt: str, tolerance: float = 0.02) -> bool:
    """Judge whether the model's predicted answer matches the ground truth.

    Steps:
    1. Extract the prediction: the body of the first ``<answer>…</answer>`` tag
       (case-insensitive), or the whole stripped output when there is no tag.
    2. Exact match: the extracted text equals ``str(gt_ans).strip()``.
    3. Numeric match (only when the ground truth is numeric): the prediction is
       parsed as a float directly or, if that fails, via the judge LLM using
       ``judge_prompt``. It is correct when ``pred / gt`` lies in
       ``[1 - tolerance, 1 + tolerance]``; a zero ground truth needs an exact zero.

    Args:
        raw_pred_ans: Raw model output.
        gt_ans: Ground-truth answer (string or number).
        judge_prompt: Template formatted with ``answer=<extracted prediction>`` for
            the LLM fallback in step 3.
        tolerance: Relative tolerance for the numeric comparison (0.02 = ±2%).

    Returns:
        True if the prediction is judged correct, else False. A non-numeric
        ground truth is only accepted on an exact string match.
    """
    if not raw_pred_ans:
        print("⚠️ Predicted answer is empty, cannot evaluate")
        return False

    # 1. Extract the answer text.
    m = _ANSWER_TAG_RE.search(raw_pred_ans)
    final_pred = m.group(1).strip() if m else raw_pred_ans.strip()

    # 2. Exact string comparison.
    if final_pred == str(gt_ans).strip():
        return True

    # 3. Numeric comparison. A non-numeric ground truth can never match
    #    numerically, so return early instead of making an LLM call whose result
    #    would be discarded.
    g_val = _to_float(gt_ans)
    if g_val is None:
        return False

    p_val = _to_float(final_pred)
    if p_val is None:
        # The prediction is not a plain number (units, words, separators, ...):
        # ask the judge LLM to reduce it to one.
        p_val = _ask_judge_for_number(final_pred, judge_prompt)
        if p_val is None:
            return False

    return _numbers_match(p_val, g_val, tolerance)

def _safe_ratio(numerator, denominator):
    """Return ``numerator / denominator``, or 0 when ``denominator`` is 0."""
    return numerator / denominator if denominator else 0


def _write_jsonl_atomic(path, entries):
    """Write ``entries`` to ``path`` as JSON Lines, replacing the file atomically.

    The data goes to a sibling ``.tmp`` file first and is then moved over
    ``path`` with ``os.replace``, so an interrupted write cannot truncate the
    original results file.
    """
    tmp_path = f"{path}.tmp"
    with open(tmp_path, "w", encoding="utf-8") as f:
        for entry in entries:
            f.write(json.dumps(entry, ensure_ascii=False) + "\n")
    os.replace(tmp_path, path)


def evaluate_accuracy(jsonl_path):
    """Score every prediction in a JSONL results file and report accuracy.

    Entries whose ``QA.model_answer`` is present are scored question by
    question with ``evaluate_answer``. The per-question booleans are stored in
    ``QA.is_correct``, and the whole file is rewritten in place; entries without
    predictions are kept unchanged. Accuracy is printed overall and per
    ``chart_type``, and one summary line is appended to
    ``new_evaluation_log.txt``.

    Args:
        jsonl_path: Path to the JSONL file; it is read and then overwritten.

    Returns:
        ``(overall_acc, stats)``, where ``stats`` maps a chart-type label to
        ``{"questions": int, "correct": int}``.
    """
    print(f"📊 Evaluating accuracy for: {jsonl_path}")

    judge_prompt = load_judge_prompt()

    with open(jsonl_path, "r", encoding="utf-8") as f:
        raw_data = [json.loads(line) for line in f if line.strip()]
    # Only entries that carry predictions are scored. They are the same dict
    # objects as in raw_data, so the is_correct flags added below are included
    # when raw_data is written back.
    data = [
        entry for entry in raw_data
        if entry.get("QA", {}).get("model_answer") is not None
    ]

    # NOTE(review): this denominator also counts the questions of entries that
    # are skipped below for a prediction/answer count mismatch (they effectively
    # score 0), while the per-type accuracy excludes them — confirm this is intended.
    total_q = sum(len(entry["QA"]["question_list"]) for entry in data)
    print(f"📊 Loaded {len(data)} entries, total {total_q} QA pairs.")

    stats = defaultdict(lambda: {"questions": 0, "correct": 0})

    for entry in tqdm(data, desc="Evaluating entries"):
        ctype = entry.get("chart_type", "Unknown")
        if ctype is None:
            # An explicit null would crash the padded f-string report below.
            ctype = "Unknown"
        preds = entry["QA"]["model_answer"]
        gts = entry["QA"]["answer_list"]
        if len(preds) != len(gts):
            print(f"⚠️ Prediction/ground truth count mismatch, skipping {entry.get('plot_id')}")
            continue

        # These chart types are reported under a "Real_<type>_Charts" label.
        if ctype in ("Composite", "Bar", "Line"):
            ctype = f"Real_{ctype}_Charts"
        is_correct_list = [
            bool(evaluate_answer(p, g, judge_prompt))
            for p, g in zip(preds, gts)
        ]
        entry["QA"]["is_correct"] = is_correct_list

        stats[ctype]["questions"] += len(is_correct_list)
        stats[ctype]["correct"] += sum(is_correct_list)

    # Write back every loaded entry (not only the scored subset) so that entries
    # without predictions are not dropped from the file.
    _write_jsonl_atomic(jsonl_path, raw_data)

    # Overall accuracy
    total_correct = sum(v["correct"] for v in stats.values())
    overall_acc = _safe_ratio(total_correct, total_q)
    print(f"\n✅ Overall Accuracy: {overall_acc:.2%}  ({total_correct}/{total_q})\n")

    print("📊 Accuracy by chart type:")
    for ctype, v in stats.items():
        q_cnt = v["questions"]
        c_cnt = v["correct"]
        acc = _safe_ratio(c_cnt, q_cnt)
        print(f"  - {ctype:<15}: {acc:.2%}  ({c_cnt}/{q_cnt})")

    # Append a one-line summary to the log file.
    log_file = "new_evaluation_log.txt"
    now = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    # The ratio is guarded: a chart type can end up with 0 questions (empty
    # prediction/answer lists), which would otherwise raise ZeroDivisionError.
    type_acc_strs = [
        f"{ctype} {v['correct']}/{v['questions']} ({_safe_ratio(v['correct'], v['questions']):.2%})"
        for ctype, v in stats.items()
    ]
    log_line = (
        f'***tolerance:2%*** |'
        f"{now} | File: {os.path.basename(jsonl_path)} | "
        f"Overall {total_correct}/{total_q} ({overall_acc:.2%}) | "
        + " | ".join(type_acc_strs)
        + "\n"
    )
    with open(log_file, "a", encoding="utf-8") as logf:
        logf.write(log_line)
    print(f"📝 Logged: {log_line.strip()}")

    return overall_acc, stats




if __name__ == "__main__":
    import sys

    # JSONL results file to score; may be overridden by the first CLI argument.
    jsonl_path = sys.argv[1] if len(sys.argv) > 1 else './test/gemini-2.0-flash_direct_2.jsonl'
    # evaluate_accuracy prints the report, writes per-question `is_correct`
    # flags back into the file, and appends a summary line to its log file.
    evaluate_accuracy(jsonl_path)