#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
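sample_file_avg_judge.py
Randomly sample up to 10 JSON files from each results_xxx model folder,
judge all entries of each file together as one conversation, and average the scores.
Outputs:
1. Per-model average score -> model_avg_scores.json
2. Per-file average score -> file_avg_scores.json
3. Overall judge result for each file -> score_results/<model>/<file>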
"""

import os, sys, json, pathlib, random, re
from tqdm import tqdm
from openai import OpenAI

# ---------- Configuration ----------
# API key read from the DEEPSEEK_API_KEY environment variable ("" when unset).
API_KEY  = os.getenv("DEEPSEEK_API_KEY", "")
# Endpoint handed to the OpenAI SDK client as base_url.
BASE_URL = "https://api.deepseek.com"
# Model name sent with each chat-completion request.
MODEL    = "deepseek-chat"
# Directory scanned for results_* model folders; "./" is resolved when the
# module is imported, i.e. against the working directory at that moment.
ROOT     = pathlib.Path("./").resolve()
# Upper bound on the number of JSON files sampled per model folder.
FILES_PER_MODEL = 10
# Root folder for per-file judge results (<SCORE_ROOT>/<model>/<file>).
SCORE_ROOT = pathlib.Path("./score_results").resolve()
# -------------------------

# Scoring rubric and output-format instructions; judge_file() places this text
# at the top of the user message, followed by the conversation transcript.
# NOTE(review): the header asks for a "1–2 sentence explanation" while the JSON
# template below says "<one-sentence reason>" — confirm which one is intended.
PROMPT_HEADER = """You are a senior counseling evaluator. You need to score an entire counseling dialogue (all turns together) on a 5-point scale and give a 1–2 sentence explanation. 

Scoring dimensions = Cognitive Empathy & Emotional Empathy & Deep Understanding (abbreviated as CC / EC / DL)

Scoring rules:
1 Very Poor (1 point): Lacks empathy, contains denial/blame/criticism/self-centeredness/empty comfort/ineffective advice.
2 Poor (2 points): Has cognitive empathy (CC) but lacks emotional empathy (EC); only restates facts or shows surface-level understanding.
3 Average (3 points): CC + EC present but inaccurate or partial; misjudged/mixed/omitted emotions.
4 Good (4 points): CC + EC accurate and nuanced, helps the client notice their difficulties and emotions, but does not touch on deeper motivations.
5 Excellent (5 points): On top of level 4, further connects to the client’s deeper needs/beliefs/traumas (DL).

Output format: Provide only one line of valid JSON, without adding any other characters or Markdown code block.
JSON format:
{
\"score\": <integer 1-5>,
\"reason\": \"<one-sentence reason>\"
}"""

# Module-level OpenAI SDK client pointed at BASE_URL. It is constructed at
# import time, i.e. before main() gets a chance to check that API_KEY is set.
client = OpenAI(api_key=API_KEY, base_url=BASE_URL)


def _strip_md_fence(text: str) -> str:
    """Return the ``{...}`` payload wrapped in a Markdown code fence.

    A fence opened with ```` ```json ```` is looked for first, then a bare
    ```` ``` ```` fence. When neither is found, *text* is returned unchanged.
    """
    fence_patterns = (
        r"```json\s*(\{.*\})\s*```",
        r"```\s*(\{.*\})\s*```",
    )
    for pattern in fence_patterns:
        found = re.search(pattern, text, re.S)
        if found is not None:
            # Only the captured object is kept, trimmed of outer whitespace.
            return found.group(1).strip()
    return text


def extract_qa(rec: dict):
    """Return the ``(question, answer)`` pair stored in one record.

    The question is read from ``"input"`` with ``"question"`` as fallback;
    the answer is read from ``"model_output"`` with ``"answer"`` as fallback.
    When both keys of a pair are missing or falsy, ``""`` is returned for it.
    """
    def first_truthy(*keys):
        # Walk the candidate keys in order and keep the first usable value.
        for key in keys:
            value = rec.get(key)
            if value:
                return value
        return ""

    question = first_truthy("input", "question")
    answer = first_truthy("model_output", "answer")
    return question, answer


def judge_file(turns: list):
    """Ask the judge model for ONE overall score for a whole conversation.

    Every record that yields both a question and an answer (via
    ``extract_qa``) is rendered as a numbered Seeker/Supporter turn. The
    joined transcript is sent after ``PROMPT_HEADER`` in a single
    chat-completion request.

    Args:
        turns: Records loaded from one result file.

    Returns:
        ``{"score": int, "reason": str}`` when the reply is JSON carrying a
        score inside the rubric's 1-5 range. Otherwise
        ``{"score": None, "reason": ...}`` where the reason is
        ``"empty file"`` (no usable turn), ``"empty response"`` (the reply
        had no text) or the raw reply text (unparseable / out-of-range).
    """
    conv_text = []
    for i, rec in enumerate(turns, 1):
        q, a = extract_qa(rec)
        if q and a:
            # Turn numbers follow the record position, so skipped records
            # leave gaps in the numbering.
            conv_text.append(f"[Turn {i}]\nSeeker: {q}\nSupporter: {a}")
    if not conv_text:
        return {"score": None, "reason": "empty file"}

    text_block = "\n\n".join(conv_text)
    user_msg = {
        "role": "user",
        "content": f"""{PROMPT_HEADER}

Now evaluate the ENTIRE conversation as a whole (all turns together). 
Provide ONE overall score and reason.

{text_block}

Please score now."""
    }

    sys_msg = {"role": "system", "content": "You are a helpful assistant."}
    resp = client.chat.completions.create(
        model=MODEL,
        messages=[sys_msg, user_msg],
        stream=False,
    )

    # The SDK types message.content as optional; guard against None instead
    # of crashing on .strip().
    content = resp.choices[0].message.content
    txt = _strip_md_fence((content or "").strip())
    if not txt:
        return {"score": None, "reason": "empty response"}

    try:
        data = json.loads(txt)
        score = int(data["score"])
        reason = str(data.get("reason", ""))
    except (KeyError, TypeError, ValueError, OverflowError):
        # Not JSON, not an object, no "score" key, or a non-numeric score.
        return {"score": None, "reason": txt}

    # The rubric only defines 1..5; anything else would distort the averages,
    # so it is treated like an unparseable reply.
    if not 1 <= score <= 5:
        return {"score": None, "reason": txt}
    return {"score": score, "reason": reason}


def process_file(path: pathlib.Path, model_name: str):
    """Judge one result file as a whole and persist the judge's verdict.

    Args:
        path: JSON file holding the records of one conversation.
        model_name: Name of the model folder; used as the sub-directory
            under ``SCORE_ROOT``.

    Returns:
        A one-element list with the score, or an empty list when the judge
        produced no usable score.
    """
    # Context managers make sure both handles are closed (and the output is
    # flushed) even if parsing or serialisation raises.
    with path.open(encoding="utf-8") as src:
        data = json.load(src)

    result = judge_file(data)

    # Save the verdict.
    # NOTE: the output name is path.name only, so two input files that share
    # a file name map to the same output path and the later one overwrites.
    out_path = SCORE_ROOT / model_name / path.name
    out_path.parent.mkdir(parents=True, exist_ok=True)
    with out_path.open("w", encoding="utf-8") as dst:
        json.dump(result, dst, ensure_ascii=False, indent=2)

    score = result.get("score")
    return [] if score is None else [score]


def process_model_dir(model_dir: pathlib.Path):
    """Judge a random sample of the JSON files found under one model folder.

    Up to ``FILES_PER_MODEL`` files are drawn from a recursive ``*.json``
    search. A file whose processing raises is reported on stderr and skipped.

    Returns:
        ``(model_average, per_file)`` where ``model_average`` is the mean of
        all collected scores rounded to 3 decimals (``None`` when there are
        none) and ``per_file`` maps file name to that file's rounded score.
    """
    candidates = list(model_dir.rglob("*.json"))
    if len(candidates) == 0:
        return None, {}

    sample_size = min(FILES_PER_MODEL, len(candidates))
    picked = random.sample(candidates, sample_size)

    per_file = {}
    collected = []
    for json_path in tqdm(picked, desc=model_dir.name):
        try:
            file_scores = process_file(json_path, model_dir.name)
            if not file_scores:
                # No usable score for this file: leave it out of the results.
                continue
            file_mean = sum(file_scores) / len(file_scores)
            per_file[json_path.name] = round(file_mean, 3)
            collected += file_scores
        except Exception as e:
            print(f"[WARN] {json_path} 失败: {e}", file=sys.stderr)

    if not collected:
        return None, per_file
    return round(sum(collected) / len(collected), 3), per_file


def main():
    """Score every ``results_*`` folder under ``ROOT`` and write the summaries.

    Exits with status 1 when ``API_KEY`` is empty. Otherwise each model
    folder is processed, the per-model averages are printed, and two JSON
    files are written to the current working directory:
    ``model_avg_scores.json`` and ``file_avg_scores.json``. Models without
    any usable score are left out of both files.
    """
    if not API_KEY:
        print("❌ 请先 export DEEPSEEK_API_KEY=sk-xxxx", file=sys.stderr)
        sys.exit(1)

    model_dirs = [p for p in ROOT.iterdir()
                  if p.is_dir() and p.name.startswith("results_")]

    results = {}
    file_level_results = {}

    for d in model_dirs:
        avg, file_results = process_model_dir(d)
        if avg is not None:
            results[d.name] = avg
            file_level_results[d.name] = file_results

    print("\n=== 模型平均分 ===")
    for k, v in results.items():
        print(f"{k:25s}: {v}")

    # Write both summaries through context managers so the handles are
    # closed and the data is flushed deterministically.
    summaries = (
        ("model_avg_scores.json", results),
        ("file_avg_scores.json", file_level_results),
    )
    for out_name, payload in summaries:
        with open(out_name, "w", encoding="utf-8") as fh:
            json.dump(payload, fh, ensure_ascii=False, indent=2)


# Run the evaluation only when this file is executed as a script.
if __name__ == "__main__":
    main()
