import csv
import os
from typing import List, Dict
import inspect

# Directory containing this script; every path below is resolved against it.
BASE_DIR = os.path.dirname(__file__)

# Output directory name -> location of that directory under BASE_DIR.
OUTPUT_DIRS = {
    dir_name: os.path.join(BASE_DIR, dir_name)
    for dir_name in ('output_claude', 'output_ds', 'output_qwen', 'output_gpt')
}

# Output directory name -> CSV result files looked up inside that directory.
OUTPUT_FILES = {
    'output_claude': ['claude_output_baseline.csv', 'claude_output_generate_module.csv'],
    'output_ds': ['deepseek_output_baseline.csv', 'deepseek_output_generate_module.csv'],
    'output_qwen': ['qwen_output_baseline.csv', 'qwen_output_generate_module.csv'],
    'output_gpt': ['gpt_output_baseline.csv', 'gpt_output_generate_module.csv'],
}

# Local model directory handed to the embedding-based metrics.
BERT_MODEL_DIR = os.path.join(BASE_DIR, 'bert-base-chinese')


def read_csv_rows(csv_path: str) -> List[Dict[str, str]]:
    """Load all data rows of a result CSV as dicts keyed by column name.

    The file is opened as UTF-8 with an optional BOM. The header is validated
    before any row is consumed, so a malformed file fails immediately.

    Args:
        csv_path: Path of the CSV file to read.

    Returns:
        One dict per data row, in file order. A file that has a complete
        header but no data rows yields an empty list.

    Raises:
        ValueError: If the header lacks any required column (this includes a
            completely empty file, which has no header at all).
    """
    required_cols = {
        'Character',
        'Ground_Truth_Type',
        'Ground_Truth_Reasoning',
        'Predicted_Type',
        'Predicted_Reasoning',
    }
    with open(csv_path, 'r', encoding='utf-8-sig', newline='') as f:
        reader = csv.DictReader(f)
        # Check the header itself rather than the first data row: a
        # header-only file is a valid (empty) result, not a schema error.
        # `fieldnames` is None when the file has no header line.
        missing = required_cols - set(reader.fieldnames or [])
        if missing:
            raise ValueError(f"CSV缺少必要列: {sorted(missing)}")
        return list(reader)


def compute_type_acc(rows: List[Dict[str, str]]) -> float:
    """Return the fraction of rows whose predicted type equals the ground truth.

    Both labels are stripped of surrounding whitespace before comparison, and
    a missing/None label is treated as an empty string. Rows in which both
    labels are empty are left out of the denominator.

    Args:
        rows: Row dicts carrying 'Ground_Truth_Type' and 'Predicted_Type'.

    Returns:
        Accuracy in [0, 1]; 0.0 when no row is counted.
    """
    label_pairs = [
        (
            (row.get('Ground_Truth_Type') or '').strip(),
            (row.get('Predicted_Type') or '').strip(),
        )
        for row in rows
    ]
    # A pair is scored as soon as at least one side carries a label.
    scored = [(truth, guess) for truth, guess in label_pairs if truth or guess]
    if not scored:
        return 0.0
    hits = sum(1 for truth, guess in scored if truth == guess)
    return hits / len(scored)


def compute_bertscore_f1(refs: List[str], hyps: List[str]) -> float:
    """Return the mean BERTScore F1 of ``hyps`` scored against ``refs``.

    ``bert_score`` is imported lazily so the module can be loaded without it.
    Scoring runs on CPU with the model at ``BERT_MODEL_DIR``, 12 layers and no
    baseline rescaling.

    Args:
        refs: Reference texts.
        hyps: Candidate texts, passed to ``bert_score`` alongside ``refs``.

    Returns:
        The average F1 over all pairs as a plain float.
    """
    from bert_score import score as bert_score

    scoring_options = {
        'lang': 'zh',
        'model_type': BERT_MODEL_DIR,
        'num_layers': 12,
        'rescale_with_baseline': False,
        'device': 'cpu',
    }
    _precision, _recall, f1 = bert_score(cands=hyps, refs=refs, **scoring_options)

    # Tensor-like results expose .mean(); anything else is averaged by hand.
    if hasattr(f1, 'mean'):
        return float(f1.mean().item())
    return float(sum(f1) / len(f1))


def compute_moverscore_offline(refs: List[str], hyps: List[str]) -> float:
    """Compute the mean MoverScore of ``hyps`` against ``refs`` without network access.

    Offline / no-TF / no-Flax / no-JAX environment flags are set (only when
    not already present) before the MoverScore package is imported.
    ``moverscore`` is tried first, then ``moverscore_v2``. The signature of
    ``word_mover_score`` is inspected and only the options it accepts are
    passed, so no fixed signature is assumed.

    Args:
        refs: Reference texts.
        hyps: Candidate texts.

    Returns:
        The average score as a float; 0.0 when either input list is empty or
        the scorer returns nothing.

    Raises:
        ImportError: If neither MoverScore package can be imported. Callers
            are expected to handle this.
    """
    os.environ.setdefault('TRANSFORMERS_OFFLINE', '1')
    os.environ.setdefault('HF_HUB_OFFLINE', '1')
    os.environ.setdefault('TRANSFORMERS_NO_TF', '1')
    os.environ.setdefault('TRANSFORMERS_NO_FLAX', '1')
    os.environ.setdefault('TRANSFORMERS_NO_JAX', '1')

    try:
        from moverscore import get_idf_dict, word_mover_score  # type: ignore
        import moverscore as backend  # type: ignore
    except ImportError:
        from moverscore_v2 import get_idf_dict, word_mover_score  # type: ignore
        import moverscore_v2 as backend  # type: ignore

    # The stop-word helper is optional in both packages; its absence must not
    # prevent scoring.
    get_stop_words = getattr(backend, 'get_stop_words', None)

    if not refs or not hyps:
        return 0.0

    # Use the helper found above directly (it lives in this function's scope,
    # not in module globals). Stop words stay best-effort: any failure falls
    # back to an empty list.
    stop_words = []
    if callable(get_stop_words):
        try:
            stop_words = get_stop_words('zh') or []
        except Exception:
            stop_words = []

    idf_ref = get_idf_dict(refs)
    idf_hyp = get_idf_dict(hyps)

    params = inspect.signature(word_mover_score).parameters

    kwargs = {}
    if 'stop_words' in params:
        kwargs['stop_words'] = stop_words
    if 'n_gram' in params:
        kwargs['n_gram'] = 1
    if 'remove_subwords' in params:
        kwargs['remove_subwords'] = False
    if 'model_name' in params:
        kwargs['model_name'] = BERT_MODEL_DIR

    if 'idf_dict' in params:
        # Single-dictionary variant: merge both IDF tables.
        kwargs['idf_dict'] = {**idf_ref, **idf_hyp}
        scores = word_mover_score(refs, hyps, **kwargs)
    else:
        try:
            scores = word_mover_score(refs, hyps, idf_ref, idf_hyp, **kwargs)
        except TypeError:
            scores = word_mover_score(refs=refs, hyps=hyps, idf_ref=idf_ref, idf_hyp=idf_hyp, **kwargs)  # type: ignore

    # len() instead of truthiness so array-like results are handled too.
    if scores is None or len(scores) == 0:
        return 0.0
    return float(sum(scores) / len(scores))


# ==== Chinese weighted ROUGE-1 (weights keyed by part of speech) ====
# Maps a one-letter part-of-speech key to the weight of a word with that tag.
POS_WEIGHT_ZH = dict(
    n=1.0,  # noun
    v=0.9,  # verb
    a=0.7,  # adjective
    d=0.7,  # adverb
    m=0.6,  # numeral / measure word
    t=0.8,  # time word
    r=0.3,  # pronoun
    p=0.3,  # preposition
    c=0.3,  # conjunction
    u=0.2,  # particle
)


def weighted_rouge1_zh(gen_text: str, ref_text: str) -> float:
    """Score ``gen_text`` against ``ref_text`` with a POS-weighted unigram recall.

    The reference is segmented with part-of-speech tagging; each distinct
    reference word gets the ``POS_WEIGHT_ZH`` weight of the first letter of
    the tag seen on its first occurrence (0.5 when the tag is not listed).
    The generated text is segmented without tagging. The score is the weight
    of reference words matched in the generated text (clipped to the
    reference count) divided by the total reference weight.

    Args:
        gen_text: Generated (candidate) text.
        ref_text: Reference text.

    Returns:
        The weighted recall, or 0.0 when the reference carries no weight.

    Raises:
        ImportError: If jieba cannot be imported.
    """
    try:
        import jieba
        import jieba.posseg as pseg
    except Exception as e:
        raise ImportError('未安装 jieba，请先 pip install jieba') from e
    from collections import Counter

    # Single pass over the tagged reference: count words and remember the
    # weight attached to each word's first occurrence.
    ref_counts = Counter()
    first_weight: Dict[str, float] = {}
    for token, pos in pseg.cut(ref_text):
        pos_key = pos[:1] if pos else ''
        ref_counts[token] += 1
        # Words are used as-is (no case folding for Chinese text).
        first_weight.setdefault(token, POS_WEIGHT_ZH.get(pos_key, 0.5))

    gen_counts = Counter(jieba.lcut(gen_text))

    matched = 0.0
    available = 0.0
    for token, ref_n in ref_counts.items():
        token_weight = first_weight.get(token, 0.5)
        overlap = min(gen_counts.get(token, 0), ref_n)
        matched += overlap * token_weight
        available += ref_n * token_weight

    if available > 0:
        return matched / available
    return 0.0


def compute_weighted_rouge1_zh_avg(refs: List[str], hyps: List[str]) -> float:
    """Average ``weighted_rouge1_zh`` over aligned (hypothesis, reference) pairs.

    Pairs are formed positionally with ``zip``, so any surplus items in the
    longer list are ignored.

    Args:
        refs: Reference texts.
        hyps: Generated texts, aligned with ``refs`` by position.

    Returns:
        The mean score, or 0.0 when there is no pair to score (either list
        empty).
    """
    scores = [weighted_rouge1_zh(h, r) for h, r in zip(hyps, refs)]
    # Guard on the pairs actually scored: checking only `refs` would still
    # divide by zero when `hyps` is empty.
    if not scores:
        return 0.0
    return float(sum(scores) / len(scores))


def _try_metric(label, metric_fn, refs, hyps):
    """Run one text metric on (refs, hyps); log and return None if it fails."""
    print(f"   计算{label}...")
    try:
        return metric_fn(refs, hyps)
    except Exception as e:
        # Best-effort: one unavailable metric must not abort the whole batch.
        print(f"   {label}计算失败: {e}")
        return None


def evaluate_csv_files(csv_files: List[str], data_dir: str, prefix: str = ""):
    """Evaluate each CSV in ``csv_files`` under ``data_dir`` and collect metrics.

    Files that do not exist are reported and skipped. For every other file the
    rows are loaded, rows whose reference or predicted reasoning is empty are
    dropped for the text metrics, and type accuracy (over all rows),
    BERTScore F1, MoverScore and the Chinese weighted ROUGE-1 are computed.
    Each text metric is best-effort: a failure is printed and stored as None.

    Args:
        csv_files: File names to evaluate.
        data_dir: Directory the file names are joined to.
        prefix: Text prepended to the file name in progress messages only.

    Returns:
        Dict mapping file name to a dict with keys 'num_samples',
        'valid_samples', 'type_acc', 'bertscore_f1', 'moverscore' and
        'zh_weighted_rouge1'. When a file has no valid reasoning pair, every
        metric (including 'type_acc') is None.
    """
    results = {}
    for fname in csv_files:
        csv_path = os.path.join(data_dir, fname)
        if not os.path.exists(csv_path):
            print(f"⚠️  文件不存在: {csv_path}")
            continue

        print(f"📊 正在评估: {prefix}{fname}")
        rows = read_csv_rows(csv_path)
        refs_raw = [(r.get('Ground_Truth_Reasoning') or '').strip() for r in rows]
        hyps_raw = [(r.get('Predicted_Reasoning') or '').strip() for r in rows]

        # Keep only samples where both sides have text.
        valid_pairs = [(h, r) for h, r in zip(hyps_raw, refs_raw) if h and r]

        if not valid_pairs:
            results[fname] = {
                'num_samples': len(rows),
                'valid_samples': 0,
                'type_acc': None,
                'bertscore_f1': None,
                'moverscore': None,
                'zh_weighted_rouge1': None,
            }
            continue

        hyps = [h for h, _ in valid_pairs]
        refs = [r for _, r in valid_pairs]

        type_acc = compute_type_acc(rows)
        # All three text metrics share the same failure handling.
        bert_f1 = _try_metric('BERTScore', compute_bertscore_f1, refs, hyps)
        mover = _try_metric('MoverScore', compute_moverscore_offline, refs, hyps)
        zh_rouge = _try_metric('中文加权ROUGE-1', compute_weighted_rouge1_zh_avg, refs, hyps)

        results[fname] = {
            'num_samples': len(rows),
            'valid_samples': len(valid_pairs),
            'type_acc': type_acc,
            'bertscore_f1': bert_f1,
            'moverscore': mover,
            'zh_weighted_rouge1': zh_rouge,
        }
        print("   ✅ 评估完成")

    return results


if __name__ == '__main__':
    # Script entry point: evaluate every configured output directory, print a
    # per-file summary, compare models per method, then report the best model
    # of each method by type accuracy.
    SEPARATOR = "=" * 60

    # Output directory name -> file-name prefix of that model's result files.
    model_mapping = {
        'output_claude': 'claude',
        'output_ds': 'deepseek',
        'output_qwen': 'qwen',
        'output_gpt': 'gpt'
    }

    # Result key -> heading printed above that metric's per-model values.
    metric_headings = (
        ('type_acc', "类型准确度对比:"),
        ('bertscore_f1', "\nBERTScore F1对比:"),
        ('moverscore', "\nMoverScore对比:"),
        ('zh_weighted_rouge1', "\n中文加权ROUGE-1对比:"),
    )

    def fmt(x):
        """Format a metric with four decimals, or 'NA' when it is None."""
        return 'NA' if x is None else f"{x:.4f}"

    def type_acc_rank(item):
        """Sort key for (model, result) pairs; a missing accuracy ranks last."""
        acc = item[1]['type_acc']
        # -inf rather than 0 so an unavailable score never ties with (and
        # positionally beats) a genuine accuracy of 0.0.
        return float('-inf') if acc is None else acc

    def report_method(title, file_suffix):
        """Print the per-metric model comparison for one method.

        Returns a dict mapping model name to that model's result for the file
        named ``<model><file_suffix>``; models without such a result are
        omitted.
        """
        print("\n" + SEPARATOR)
        print(title)
        print(SEPARATOR)

        per_model = {}
        for dir_name in OUTPUT_DIRS:
            model_name = model_mapping[dir_name]
            result_key = f"{model_name}{file_suffix}"
            if result_key in all_results:
                per_model[model_name] = all_results[result_key]

        if per_model:
            for metric_key, heading in metric_headings:
                print(heading)
                for model, res in per_model.items():
                    print(f"  {model.upper()}: {fmt(res[metric_key])}")
        return per_model

    print("🚀 开始评估所有结果...\n")

    # Evaluate every directory and merge the per-file results.
    all_results = {}
    for dir_name, dir_path in OUTPUT_DIRS.items():
        print(SEPARATOR)
        print(f"📁 评估 {dir_name}/ 目录:")
        print(SEPARATOR)
        all_results.update(
            evaluate_csv_files(OUTPUT_FILES[dir_name], dir_path, f"{dir_name}/")
        )

    # Summary, grouped by output directory.
    print("\n" + SEPARATOR)
    print("📊 评估结果汇总:")
    print(SEPARATOR)

    for dir_name in OUTPUT_DIRS:
        print(f"\n🔍 {dir_name.upper()} 结果:")
        model_name = model_mapping[dir_name]
        for fname, res in all_results.items():
            # A file belongs to this directory when its name starts with the
            # model's prefix.
            if not fname.startswith(model_name):
                continue
            print(
                f"  [{fname}] samples={res['num_samples']}, valid={res['valid_samples']}, "
                f"type_acc={fmt(res['type_acc'])}, bertscore_f1={fmt(res['bertscore_f1'])}, "
                f"moverscore={fmt(res['moverscore'])}, zh_weighted_rouge1={fmt(res['zh_weighted_rouge1'])}")

    # Per-method comparison across models.
    baseline_results = report_method("🔍 Baseline方法对比分析:", "_output_baseline.csv")
    generation_results = report_method("🔍 Generation Module方法对比分析:", "_output_generate_module.csv")

    # Best model of each method, and the gap between the two best scores.
    print("\n" + SEPARATOR)
    print("🏆 最佳模型分析:")
    print(SEPARATOR)

    if baseline_results and generation_results:
        best_baseline = max(baseline_results.items(), key=type_acc_rank)
        print(f"最佳Baseline模型: {best_baseline[0].upper()} (类型准确度: {fmt(best_baseline[1]['type_acc'])})")

        best_generation = max(generation_results.items(), key=type_acc_rank)
        print(f"最佳Generation Module模型: {best_generation[0].upper()} (类型准确度: {fmt(best_generation[1]['type_acc'])})")

        if best_baseline[1]['type_acc'] is not None and best_generation[1]['type_acc'] is not None:
            improvement = best_generation[1]['type_acc'] - best_baseline[1]['type_acc']
            print(f"Generation Module相对Baseline的提升: {improvement:+.4f} ({improvement*100:+.2f}%)")

    print("\n🎉 评估完成！")



