import csv
import os
from typing import List, Dict
import inspect


# Directory that contains this script; the paths below are resolved against it.
BASE_DIR = os.path.dirname(__file__)
# Folder holding the multi-agent run outputs; the CSV to evaluate is read from here.
DATA_DIR = os.path.join(BASE_DIR, "output_multi_agent")
# Path handed to the scoring libraries as the BERT model location.
BERT_MODEL_DIR = os.path.join(BASE_DIR, "bert-base-chinese")

# Name of the CSV file produced by the multi-agent system.
MULTI_AGENT_CSV_FILE = "test_set_multi_agent_qwen3_235b_a22b_instruct_2507_gpt_5.csv"


def read_csv_rows(csv_path: str) -> List[Dict[str, str]]:
    """Read a CSV file and return its data rows as dictionaries.

    The file is decoded as ``utf-8-sig`` so that a leading BOM does not end
    up inside the first column name.

    Args:
        csv_path: Path of the CSV file to read.

    Returns:
        One ``{column name: cell value}`` dict per data row, in file order.
        A file with a valid header but no data rows yields ``[]``.

    Raises:
        ValueError: If the header lacks ``Ground_Truth`` or ``LLM_Output``
            (this includes a completely empty file, which has no header).
    """
    required_cols = {'Ground_Truth', 'LLM_Output'}
    with open(csv_path, 'r', encoding='utf-8-sig', newline='') as f:
        reader = csv.DictReader(f)
        # Validate against the header itself rather than the first data row,
        # so a header-only file is not misreported as "missing columns".
        # ``fieldnames`` is None when the file is completely empty.
        missing = required_cols - set(reader.fieldnames or [])
        if missing:
            raise ValueError(f"CSV缺少必要列: {sorted(missing)}")
        return list(reader)


def compute_bertscore_f1(refs: List[str], hyps: List[str]) -> float:
    """Return the mean BERTScore F1 of *hyps* measured against *refs*.

    Scoring runs on CPU with the model located at ``BERT_MODEL_DIR``, using
    12 layers and no baseline rescaling.

    Args:
        refs: Reference (ground-truth) texts.
        hyps: Candidate texts, aligned index-by-index with ``refs``.

    Returns:
        The average of the per-pair F1 values, or ``0.0`` when either list is
        empty (same convention as the other metric helpers in this file).
    """
    # Nothing to score: answer before paying for the bert_score import.
    if not refs or not hyps:
        return 0.0

    # Imported lazily so the module can be loaded without bert_score installed.
    from bert_score import score as bert_score

    # Only the third element of the returned triple is used as F1.
    _, _, f1 = bert_score(
        cands=hyps,
        refs=refs,
        lang='zh',
        model_type=BERT_MODEL_DIR,
        num_layers=12,
        rescale_with_baseline=False,
        device='cpu',
    )
    # Tensor-like results expose .mean(); otherwise average as a plain sequence.
    if hasattr(f1, 'mean'):
        return float(f1.mean().item())
    return float(sum(f1) / len(f1))


def compute_moverscore_offline(refs: List[str], hyps: List[str]) -> float:
    """Return the mean MoverScore of *hyps* measured against *refs*.

    Offline-mode environment variables for transformers / the HF hub are set
    (only if not already defined) before the scoring library is imported.
    The scorer is taken from ``moverscore`` when importable, otherwise from
    ``moverscore_v2``. Because the two may expose different
    ``word_mover_score`` signatures, the signature is inspected and only the
    keyword arguments it actually declares are passed.

    Args:
        refs: Reference (ground-truth) texts.
        hyps: Candidate texts, aligned index-by-index with ``refs``.

    Returns:
        The average of the per-pair scores, or ``0.0`` when either input list
        (or the returned score list) is empty.

    Raises:
        ImportError: If neither ``moverscore`` nor ``moverscore_v2`` can be
            imported (only reached for non-empty input).
    """
    # setdefault keeps any value that was already exported by the caller.
    for flag in (
        'TRANSFORMERS_OFFLINE',
        'HF_HUB_OFFLINE',
        'TRANSFORMERS_NO_TF',
        'TRANSFORMERS_NO_FLAX',
        'TRANSFORMERS_NO_JAX',
    ):
        os.environ.setdefault(flag, '1')

    # Nothing to score: answer before importing the (heavy) scoring library.
    if not refs or not hyps:
        return 0.0

    # get_stop_words is an optional helper in either package; its absence is
    # deliberately tolerated (best effort) and simply means "no stop words".
    try:
        from moverscore import get_idf_dict, word_mover_score  # type: ignore
        try:
            from moverscore import get_stop_words  # type: ignore
        except Exception:
            get_stop_words = None  # type: ignore
    except ImportError:
        from moverscore_v2 import get_idf_dict, word_mover_score  # type: ignore
        try:
            from moverscore_v2 import get_stop_words  # type: ignore
        except Exception:
            get_stop_words = None  # type: ignore

    # The imports above bind *local* names, so the helper must be looked up
    # locally; a globals() lookup would never find it.
    stop_words = []
    if callable(get_stop_words):
        try:
            stop_words = get_stop_words('zh') or []  # type: ignore
        except Exception:
            # Best effort: fall back to scoring without stop words.
            stop_words = []

    idf_ref = get_idf_dict(refs)
    idf_hyp = get_idf_dict(hyps)

    params = inspect.signature(word_mover_score).parameters

    kwargs = {}
    if 'stop_words' in params:
        kwargs['stop_words'] = stop_words
    if 'n_gram' in params:
        kwargs['n_gram'] = 1
    if 'remove_subwords' in params:
        kwargs['remove_subwords'] = False
    if 'model_name' in params:
        kwargs['model_name'] = BERT_MODEL_DIR

    if 'idf_dict' in params:
        # Variant taking a single IDF table: merge both (hyp entries win on ties).
        kwargs['idf_dict'] = {**idf_ref, **idf_hyp}
        scores = word_mover_score(refs, hyps, **kwargs)
    else:
        # Variant taking separate IDF tables: try positional first, then keywords.
        try:
            scores = word_mover_score(refs, hyps, idf_ref, idf_hyp, **kwargs)
        except TypeError:
            scores = word_mover_score(refs=refs, hyps=hyps, idf_ref=idf_ref, idf_hyp=idf_hyp, **kwargs)  # type: ignore

    # len() instead of truthiness so array-like results are handled as well.
    count = len(scores)
    return float(sum(scores) / count) if count else 0.0


# ==== Chinese weighted ROUGE-1 (weights keyed by part of speech) ====
# Keys are compared with the first character of a jieba.posseg tag; tags that
# are not listed fall back to 0.5 in weighted_rouge1_zh.
# NOTE(review): the per-key glosses assume the usual jieba tag letters — confirm.
POS_WEIGHT_ZH = {
    "n": 1.0,  # noun
    "v": 0.9,  # verb
    "a": 0.7,  # adjective
    "d": 0.7,  # adverb
    "m": 0.6,  # numeral
    "t": 0.8,  # time word
    "r": 0.3,  # pronoun
    "p": 0.3,  # preposition
    "c": 0.3,  # conjunction
    "u": 0.2,  # auxiliary / particle
}


def weighted_rouge1_zh(gen_text: str, ref_text: str) -> float:
    """Return the POS-weighted ROUGE-1 recall of *gen_text* against *ref_text*.

    Each reference word carries the weight that ``POS_WEIGHT_ZH`` assigns to
    the first character of its part-of-speech tag (0.5 when the tag is not
    listed); a word that occurs several times keeps the weight of its first
    occurrence. The score is the weighted count of reference words that also
    occur in the generated text (clipped to the reference count), divided by
    the total reference weight.

    Both texts are segmented with ``jieba.posseg`` so that they share one
    segmentation; mixing it with ``jieba.lcut`` can split the same text
    differently and lose matches.

    Args:
        gen_text: Generated (candidate) text.
        ref_text: Reference (ground-truth) text.

    Returns:
        A value in ``[0, 1]``; ``0.0`` when the reference is empty or has no
        weight.
    """
    # An empty reference has zero total weight; skip loading jieba entirely.
    if not ref_text:
        return 0.0

    from collections import Counter

    import jieba.posseg as pseg

    ref_words = []
    ref_weights: Dict[str, float] = {}
    for w, tag in pseg.cut(ref_text):
        tag_key = tag[:1] if tag else ''
        ref_words.append(w)
        # First occurrence decides the weight of a repeated word.
        ref_weights.setdefault(w, POS_WEIGHT_ZH.get(tag_key, 0.5))

    ref_counts = Counter(ref_words)
    # Same segmenter as the reference; tags of the candidate are not needed.
    gen_counts = Counter(w for w, _ in pseg.cut(gen_text or ''))

    match_weight = 0.0
    ref_total_weight = 0.0
    for word, ref_cnt in ref_counts.items():
        weight = ref_weights.get(word, 0.5)
        # Clip so that repeating a word in the candidate cannot over-count.
        match_weight += min(gen_counts.get(word, 0), ref_cnt) * weight
        ref_total_weight += ref_cnt * weight

    return (match_weight / ref_total_weight) if ref_total_weight > 0 else 0.0


def compute_weighted_rouge1_zh_avg(refs: List[str], hyps: List[str]) -> float:
    """Return the mean POS-weighted Chinese ROUGE-1 over aligned text pairs.

    ``hyps`` and ``refs`` are paired with ``zip``, so surplus items of the
    longer list are ignored.

    Args:
        refs: Reference (ground-truth) texts.
        hyps: Candidate texts, aligned index-by-index with ``refs``.

    Returns:
        The average of ``weighted_rouge1_zh`` over all pairs, or ``0.0`` when
        there is no pair to score (either list empty).
    """
    scores = [weighted_rouge1_zh(h, r) for h, r in zip(hyps, refs)]
    # Guard on the pairs actually scored: checking only ``refs`` would divide
    # by zero when ``hyps`` is empty.
    if not scores:
        return 0.0
    return float(sum(scores) / len(scores))


def _print_rating(title: str, short_name: str, value: float, cutoffs) -> None:
    """Print one metric value followed by its qualitative band.

    ``cutoffs`` is ``(excellent, good, fair)``; every comparison is a strict
    ``>``, and anything at or below ``fair`` is reported as needing work.
    """
    excellent, good, fair = cutoffs
    print(f"   {title}: {value:.4f}")
    if value > excellent:
        print(f"     ✅ 优秀 ({short_name} > {excellent})")
    elif value > good:
        print(f"     ✅ 良好 ({short_name} > {good})")
    elif value > fair:
        print(f"     ⚠️ 一般 ({short_name} > {fair})")
    else:
        print(f"     ❌ 需要改进 ({short_name} ≤ {fair})")


def analyze_multi_agent_results():
    """Evaluate the multi-agent system's CSV output and report the metrics.

    Reads ``MULTI_AGENT_CSV_FILE`` from ``DATA_DIR``, drops pairs whose output
    or ground truth is empty or whose output starts with ``LLM`` /
    ``Multi-Agent`` (treated as error markers), then computes BERTScore F1,
    MoverScore and the weighted Chinese ROUGE-1 on the remaining pairs.
    Each metric is computed independently: a failure is reported and shown as
    ``N/A`` without stopping the others. Results are printed and also written
    to ``multi_agent_evaluation_results.txt`` inside ``DATA_DIR``.

    Returns:
        None. Returns early when the CSV is missing or no valid pair remains.
        Any other error is printed with its traceback instead of propagating.
    """
    csv_path = os.path.join(DATA_DIR, MULTI_AGENT_CSV_FILE)

    if not os.path.exists(csv_path):
        print(f"❌ 多智能体结果文件不存在: {csv_path}")
        print(f"   请确保文件路径正确: {MULTI_AGENT_CSV_FILE}")
        return

    print(f"📊 分析多智能体系统结果: {MULTI_AGENT_CSV_FILE}")
    print("=" * 80)

    try:
        rows = read_csv_rows(csv_path)
        print(f"✅ 成功读取 {len(rows)} 条记录")

        # Extract the two text columns; missing cells count as empty strings.
        refs_raw = [(r.get('Ground_Truth') or '').strip() for r in rows]
        hyps_raw = [(r.get('LLM_Output') or '').strip() for r in rows]

        # Count empty and error-marked entries (one row may hit several counters).
        empty_hyp = sum(1 for h in hyps_raw if h == '')
        empty_ref = sum(1 for r in refs_raw if r == '')
        llm_error = sum(1 for h in hyps_raw if h.startswith('LLM'))
        multi_agent_error = sum(1 for h in hyps_raw if h.startswith('Multi-Agent'))

        # Keep only pairs that are non-empty on both sides and not error-marked.
        valid_pairs = [
            (h, r)
            for h, r in zip(hyps_raw, refs_raw)
            if h != '' and r != '' and not h.startswith(('LLM', 'Multi-Agent'))
        ]

        print(f"\n📈 数据统计:")
        print(f"   总样本数: {len(rows)}")
        print(f"   有效样本: {len(valid_pairs)}")
        print(f"   跳过样本: {len(rows) - len(valid_pairs)}")
        print(f"     - 空输出: {empty_hyp}")
        print(f"     - 空真值: {empty_ref}")
        print(f"     - LLM异常: {llm_error}")
        print(f"     - Multi-Agent异常: {multi_agent_error}")

        if not valid_pairs:
            print("❌ 没有有效样本可以评估")
            return

        # Split the valid pairs back into parallel lists.
        hyps = [h for h, _ in valid_pairs]
        refs = [r for _, r in valid_pairs]

        print(f"\n🔍 开始计算评估指标...")

        # BERTScore F1 — guarded like the other metrics so that one failing
        # scorer does not discard the rest of the report.
        print("   计算BERTScore F1...")
        try:
            bert_f1 = compute_bertscore_f1(refs, hyps)
        except Exception as e:
            print(f"   ⚠️ BERTScore F1计算失败: {e}")
            bert_f1 = None

        # MoverScore
        print("   计算MoverScore...")
        try:
            mover = compute_moverscore_offline(refs, hyps)
        except Exception as e:
            print(f"   ⚠️ MoverScore计算失败: {e}")
            mover = None

        # Weighted Chinese ROUGE-1
        print("   计算中文加权ROUGE-1...")
        try:
            zh_rouge = compute_weighted_rouge1_zh_avg(refs, hyps)
        except Exception as e:
            print(f"   ⚠️ 中文加权ROUGE-1计算失败: {e}")
            zh_rouge = None

        # Summary table
        print("\n" + "=" * 80)
        print("📊 多智能体系统评估结果")
        print("=" * 80)

        def fmt(x):
            # A metric that failed to compute is rendered as N/A.
            return 'N/A' if x is None else f"{x:.4f}"

        print(f"{'指标':<20} {'分数':<12} {'说明'}")
        print("-" * 50)
        print(f"{'BERTScore F1':<20} {fmt(bert_f1):<12} 基于BERT的语义相似度")
        print(f"{'MoverScore':<20} {fmt(mover):<12} 基于词移动距离的语义相似度")
        print(f"{'中文加权ROUGE-1':<20} {fmt(zh_rouge):<12} 基于词性权重的中文ROUGE分数")

        # Detailed analysis; rows is non-empty here because valid_pairs is.
        quality_pct = len(valid_pairs) / len(rows) * 100
        print(f"\n📋 详细分析:")
        print(f"   有效样本数: {len(valid_pairs)}")
        print(f"   数据质量: {quality_pct:.1f}%")

        if bert_f1 is not None:
            _print_rating('BERTScore F1', 'BERTScore F1', bert_f1, (0.8, 0.6, 0.4))

        if mover is not None:
            _print_rating('MoverScore', 'MoverScore', mover, (0.3, 0.2, 0.1))

        if zh_rouge is not None:
            _print_rating('中文加权ROUGE-1', 'ROUGE-1', zh_rouge, (0.6, 0.4, 0.2))

        # Show the first five valid samples for a quick visual check.
        print(f"\n📝 样本分析 (前5个有效样本):")
        print("-" * 80)
        for i, (hyp, ref) in enumerate(valid_pairs[:5]):
            print(f"样本 {i+1}:")
            print(f"  真值: {ref}")
            print(f"  预测: {hyp}")
            print(f"  长度: 真值={len(ref)}字符, 预测={len(hyp)}字符")
            print()

        # Persist the summary next to the input data.
        result_file = os.path.join(DATA_DIR, 'multi_agent_evaluation_results.txt')
        with open(result_file, 'w', encoding='utf-8') as f:
            f.write("多智能体系统评估结果\n")
            f.write("=" * 50 + "\n")
            f.write(f"文件: {MULTI_AGENT_CSV_FILE}\n")
            f.write(f"总样本数: {len(rows)}\n")
            f.write(f"有效样本数: {len(valid_pairs)}\n")
            f.write(f"数据质量: {quality_pct:.1f}%\n\n")
            f.write("评估指标:\n")
            f.write(f"BERTScore F1: {fmt(bert_f1)}\n")
            f.write(f"MoverScore: {fmt(mover)}\n")
            f.write(f"中文加权ROUGE-1: {fmt(zh_rouge)}\n")

        print(f"✅ 评估结果已保存到: {result_file}")

    except Exception as e:
        # Top-level boundary of the script: report the error and keep the traceback.
        print(f"❌ 分析过程中发生错误: {e}")
        import traceback
        traceback.print_exc()


if __name__ == '__main__':
    print("🤖 多智能体系统评估工具")
    print("=" * 50)

    # The data directory is mandatory: without it there is nothing to evaluate.
    if not os.path.exists(DATA_DIR):
        print(f"❌ 数据目录不存在: {DATA_DIR}")
        print("   请确保 output_multi_agent 目录存在")
        # SystemExit instead of the exit() builtin: the latter is injected by
        # the site module and is missing under `python -S` or frozen builds.
        raise SystemExit(1)

    # A missing BERT model directory only triggers a warning; the run continues.
    if not os.path.exists(BERT_MODEL_DIR):
        print(f"⚠️ BERT模型目录不存在: {BERT_MODEL_DIR}")
        print("   将使用在线模型，可能影响性能")

    # Run the evaluation.
    analyze_multi_agent_results()

    print("\n🎉 评估完成!")
