import csv
import os
from typing import List, Dict
import inspect


# Directory that contains this script; the data and model paths below are built from it.
BASE_DIR = os.path.dirname(__file__)
# Directory that is scanned for the per-pipeline result CSV files.
DATA_DIR = os.path.join(BASE_DIR, 'output_gpt')
# Auto-detect the available CSV files
def get_available_csv_files(data_dir=None) -> List[str]:
    """Find one result CSV per pipeline inside the data directory.

    For each pipeline type (baseline, prototype, kg — in that order) the
    first file named ``test_set_<pipeline>_*.csv`` is selected. The directory
    listing is sorted first, so the choice is deterministic when several
    models have produced results for the same pipeline.

    Args:
        data_dir: Directory to scan. Defaults to the module-level ``DATA_DIR``.

    Returns:
        The matching file names (not full paths), ordered baseline, prototype,
        kg. Empty when the directory is missing or nothing matches.
    """
    if data_dir is None:
        data_dir = DATA_DIR

    # isdir (rather than exists) also rejects a regular file at that path,
    # which would otherwise make os.listdir raise.
    if not os.path.isdir(data_dir):
        print(f"❌ 目录不存在: {data_dir}")
        return []

    pipeline_types = ['baseline', 'prototype', 'kg']
    # os.listdir returns entries in arbitrary order; sort for a stable pick.
    candidates = sorted(os.listdir(data_dir))

    available_files = []
    for pipeline in pipeline_types:
        prefix = f'test_set_{pipeline}_'
        match = next(
            (name for name in candidates
             if name.startswith(prefix) and name.endswith('.csv')),
            None,
        )
        if match is not None:
            available_files.append(match)

    if not available_files:
        print(f"⚠️  在 {data_dir} 中未找到任何pipeline结果文件")
        print(f"   期望的文件名格式: test_set_{{baseline|prototype|kg}}_{{model_name}}.csv")

    return available_files

# Default CSV file list (used when auto-detection finds nothing)
DEFAULT_CSV_FILES = [
    'test_set_baseline_deepseek_v3.csv',
    'test_set_prototype_deepseek_v3.csv',
    'test_set_kg_deepseek_v3.csv',
]
# Path handed to bert_score as `model_type`, and to MoverScore as `model_name`
# when its scoring function accepts that parameter.
BERT_MODEL_DIR = os.path.join(BASE_DIR, 'bert-base-chinese')


def read_csv_rows(csv_path: str) -> List[Dict[str, str]]:
    """Load every data row of a result CSV as a dict keyed by column name.

    The file is decoded as ``utf-8-sig`` so a leading BOM does not end up in
    the first column name.

    Args:
        csv_path: Path of the CSV file to read.

    Returns:
        One dict per data row; an empty list when the file only has a header.

    Raises:
        ValueError: If the header lacks ``Ground_Truth`` or ``LLM_Output``
            (including the case of a completely empty file).
    """
    required_cols = {'Ground_Truth', 'LLM_Output'}

    with open(csv_path, 'r', encoding='utf-8-sig', newline='') as f:
        reader = csv.DictReader(f)
        rows: List[Dict[str, str]] = list(reader)
        # Validate against the header rather than the first row, so that a
        # header-only file with the right columns is not reported as broken.
        # fieldnames is None when the file is empty.
        header = reader.fieldnames or []

    missing = required_cols - set(header)
    if missing:
        raise ValueError(f"CSV缺少必要列: {sorted(missing)}")
    return rows


def compute_bertscore_f1(refs: List[str], hyps: List[str]) -> float:
    """Return the mean BERTScore F1 of `hyps` against `refs`.

    Scoring runs on CPU with the model located at ``BERT_MODEL_DIR``, using
    12 layers and no baseline rescaling.

    Args:
        refs: Reference texts.
        hyps: Candidate texts, aligned with `refs`.

    Returns:
        The average F1 over all pairs as a plain float.
    """
    from bert_score import score as bert_score

    scoring_options = {
        'lang': 'zh',
        'model_type': BERT_MODEL_DIR,
        'num_layers': 12,
        'rescale_with_baseline': False,
        'device': 'cpu',
    }
    _precision, _recall, f1_scores = bert_score(cands=hyps, refs=refs, **scoring_options)

    # Tensor-like results expose .mean(); otherwise average as a plain sequence.
    if hasattr(f1_scores, 'mean'):
        return float(f1_scores.mean().item())
    return float(sum(f1_scores) / len(f1_scores))


def compute_moverscore_offline(refs: List[str], hyps: List[str]) -> float:
    """Return the mean MoverScore of `hyps` against `refs` without network access.

    Offline/backend environment flags are set (only if not already present)
    before MoverScore is imported. The function adapts to whichever
    ``word_mover_score`` signature is installed by inspecting its parameters.

    Args:
        refs: Reference texts.
        hyps: Candidate texts, aligned with `refs`.

    Returns:
        The average score, or 0.0 when either input is empty or the scorer
        returns nothing.

    Raises:
        ImportError: If neither ``moverscore`` nor ``moverscore_v2`` provides
            ``get_idf_dict`` and ``word_mover_score``.
    """
    for flag in (
        'TRANSFORMERS_OFFLINE',
        'HF_HUB_OFFLINE',
        'TRANSFORMERS_NO_TF',
        'TRANSFORMERS_NO_FLAX',
        'TRANSFORMERS_NO_JAX',
    ):
        os.environ.setdefault(flag, '1')

    # get_stop_words is treated as optional in both module variants, so a
    # build without it still produces a score instead of an ImportError.
    get_stop_words = None
    try:
        from moverscore import get_idf_dict, word_mover_score  # type: ignore
        try:
            from moverscore import get_stop_words  # type: ignore
        except Exception:
            get_stop_words = None  # type: ignore
    except ImportError:
        from moverscore_v2 import get_idf_dict, word_mover_score  # type: ignore
        try:
            from moverscore_v2 import get_stop_words  # type: ignore
        except ImportError:
            get_stop_words = None  # type: ignore

    if not refs or not hyps:
        return 0.0

    # The helper lives in this function's local scope; looking it up through
    # globals() (as before) never found it, so stop words were always empty.
    stop_words = []
    if callable(get_stop_words):
        try:
            stop_words = get_stop_words('zh') or []  # type: ignore
        except Exception:
            stop_words = []

    idf_ref = get_idf_dict(refs)
    idf_hyp = get_idf_dict(hyps)

    # Only pass the keyword arguments the installed scorer actually accepts.
    params = inspect.signature(word_mover_score).parameters

    kwargs = {}
    if 'stop_words' in params:
        kwargs['stop_words'] = stop_words
    if 'n_gram' in params:
        kwargs['n_gram'] = 1
    if 'remove_subwords' in params:
        kwargs['remove_subwords'] = False
    if 'model_name' in params:
        kwargs['model_name'] = BERT_MODEL_DIR

    if 'idf_dict' in params:
        # Single-dictionary API: merge both IDF tables.
        kwargs['idf_dict'] = {**idf_ref, **idf_hyp}
        scores = word_mover_score(refs, hyps, **kwargs)
    else:
        try:
            scores = word_mover_score(refs, hyps, idf_ref, idf_hyp, **kwargs)
        except TypeError:
            scores = word_mover_score(refs=refs, hyps=hyps, idf_ref=idf_ref, idf_hyp=idf_hyp, **kwargs)  # type: ignore

    # A length check works for lists and array-likes alike; truth-testing an
    # array with more than one element would raise.
    if scores is None or len(scores) == 0:
        return 0.0
    return float(sum(scores) / len(scores))



# ==== Chinese weighted ROUGE-1 (part-of-speech based weights) ====
# Keys are compared with the first character of the POS tag returned by
# jieba.posseg; tags not listed here fall back to 0.5 in weighted_rouge1_zh.
# NOTE(review): the letters presumably follow jieba's tag set (n = noun,
# v = verb, a = adjective, ...) — confirm against the jieba documentation.
POS_WEIGHT_ZH = {
    'n': 1.0,
    'v': 0.9,
    'a': 0.7,
    'd': 0.7,
    'm': 0.6,
    't': 0.8,
    'r': 0.3,
    'p': 0.3,
    'c': 0.3,
    'u': 0.2,
}


def weighted_rouge1_zh(gen_text: str, ref_text: str) -> float:
    """Score one generated text against one reference with POS-weighted unigram recall.

    The reference is segmented with ``jieba.posseg`` and every distinct token
    gets the weight that ``POS_WEIGHT_ZH`` assigns to the first character of
    its tag (0.5 when unlisted). The generated text is segmented with
    ``jieba.lcut``. The result is the weight of the clipped token overlap
    divided by the total weight of the reference.

    Args:
        gen_text: Generated (candidate) text.
        ref_text: Reference text.

    Returns:
        A value in [0, 1]; 0.0 when the reference carries no weight.
    """
    from collections import Counter

    import jieba
    import jieba.posseg as pseg

    reference_tokens = []
    token_weight: Dict[str, float] = {}
    for token, pos_tag in pseg.cut(ref_text):
        reference_tokens.append(token)
        # The tag seen at a token's first occurrence decides its weight.
        token_weight.setdefault(token, POS_WEIGHT_ZH.get((pos_tag or '')[:1], 0.5))

    # NOTE(review): the reference goes through jieba.posseg and the candidate
    # through jieba.lcut — confirm both yield the same token boundaries.
    reference_counts = Counter(reference_tokens)
    generated_counts = Counter(jieba.lcut(gen_text))

    matched_weight = 0.0
    total_weight = 0.0
    for token, occurrences in reference_counts.items():
        w = token_weight.get(token, 0.5)
        # Clip the overlap so repeated candidate tokens cannot over-count.
        overlap = min(generated_counts.get(token, 0), occurrences)
        matched_weight += overlap * w
        total_weight += occurrences * w

    if total_weight > 0:
        return matched_weight / total_weight
    return 0.0


def compute_weighted_rouge1_zh_avg(refs: List[str], hyps: List[str]) -> float:
    """Return the mean POS-weighted ROUGE-1 over aligned reference/candidate pairs.

    Pairs are formed with ``zip``, so surplus items in the longer list are
    ignored.

    Args:
        refs: Reference texts.
        hyps: Candidate texts, aligned with `refs`.

    Returns:
        The average per-pair score, or 0.0 when there is no pair to score.
    """
    scores = [weighted_rouge1_zh(h, r) for h, r in zip(hyps, refs)]
    # Guard on the scored pairs, not on refs alone: an empty hyps list would
    # otherwise lead to a division by zero.
    if not scores:
        return 0.0
    return float(sum(scores) / len(scores))


if __name__ == '__main__':
    # Script entry point: score every pipeline result file and print a report.

    # Auto-detect the result files; fall back to the built-in default list.
    csv_files = get_available_csv_files()
    if not csv_files:
        print("❌ 未找到可用的CSV文件，使用默认列表")
        csv_files = DEFAULT_CSV_FILES

    print(f"📊 找到 {len(csv_files)} 个pipeline结果文件:")
    for fname in csv_files:
        print(f"  - {fname}")
    print()

    def fmt(x):
        """Render a metric with four decimals, or 'N/A' when it is missing."""
        return 'N/A' if x is None else f"{x:.4f}"

    def pipeline_label(fname):
        """Extract the pipeline type from 'test_set_<pipeline>_<model>.csv'."""
        return fname.split('_')[2].title()

    results = {}
    for fname in csv_files:
        csv_path = os.path.join(DATA_DIR, fname)
        if not os.path.exists(csv_path):
            # Report the skip so a missing pipeline is not silently absent.
            print(f"⚠️  文件不存在，已跳过: {csv_path}")
            continue
        rows = read_csv_rows(csv_path)
        refs_raw = [(r.get('Ground_Truth') or '').strip() for r in rows]
        hyps_raw = [(r.get('LLM_Output') or '').strip() for r in rows]

        # Keep only pairs where both sides have content and the output is not
        # an LLM error message (those start with 'LLM').
        valid_pairs = [(h, r) for h, r in zip(hyps_raw, refs_raw)
                       if h != '' and r != '' and not h.startswith('LLM')]

        # Per-reason counters; a row can fall into more than one category, so
        # the total is derived from the valid count instead of summing them.
        empty_hyp = sum(1 for h in hyps_raw if h == '')
        empty_ref = sum(1 for r in refs_raw if r == '')
        llm_error = sum(1 for h in hyps_raw if h.startswith('LLM'))
        skipped_total = len(rows) - len(valid_pairs)

        bert_f1 = None
        mover = None
        zh_rouge = None
        if valid_pairs:
            hyps = [h for h, _ in valid_pairs]
            refs = [r for _, r in valid_pairs]

            bert_f1 = compute_bertscore_f1(refs, hyps)
            # MoverScore and ROUGE are best-effort: report the failure and
            # continue with 'N/A' instead of aborting the whole run.
            try:
                mover = compute_moverscore_offline(refs, hyps)
            except Exception as exc:
                print(f"⚠️  MoverScore 计算失败 ({fname}): {exc}")
                mover = None
            try:
                zh_rouge = compute_weighted_rouge1_zh_avg(refs, hyps)
            except Exception as exc:
                print(f"⚠️  中文加权ROUGE-1 计算失败 ({fname}): {exc}")
                zh_rouge = None

        results[fname] = {
            'num_samples': len(rows),
            'valid_samples': len(valid_pairs),
            'skipped_empty_hyp': empty_hyp,
            'skipped_empty_ref': empty_ref,
            'skipped_llm_error': llm_error,
            'skipped_total': skipped_total,
            'bertscore_f1': bert_f1,
            'moverscore': mover,
            'zh_weighted_rouge1': zh_rouge,
        }

    # Summary table
    print("=" * 80)
    print("📊 三个Pipeline性能评估结果")
    print("=" * 80)

    # Table header
    print(f"{'Pipeline':<15} {'Samples':<10} {'Valid':<8} {'Skipped':<10} {'BERTScore':<12} {'MoverScore':<12} {'ROUGE-1':<12}")
    print("-" * 90)

    for fname, res in results.items():
        pipeline_type = pipeline_label(fname)
        print(f"{pipeline_type:<15} {res['num_samples']:<10} {res['valid_samples']:<8} {res['skipped_total']:<10} "
              f"{fmt(res['bertscore_f1']):<12} {fmt(res['moverscore']):<12} {fmt(res['zh_weighted_rouge1']):<12}")

    print("-" * 80)

    # Detailed per-pipeline results
    print("\n📋 详细结果:")
    for fname, res in results.items():
        pipeline_type = pipeline_label(fname)
        print(f"\n🔹 {pipeline_type} Pipeline ({fname}):")
        print(f"   样本总数: {res['num_samples']}")
        print(f"   有效样本: {res['valid_samples']}")
        print(f"   跳过样本: {res['skipped_total']}")
        if res['skipped_total'] > 0:
            print(f"     - 空输出: {res['skipped_empty_hyp']}")
            print(f"     - 空真值: {res['skipped_empty_ref']}")
            print(f"     - LLM异常(以'LLM'开头): {res['skipped_llm_error']}")
        print(f"   BERTScore F1: {fmt(res['bertscore_f1'])}")
        print(f"   MoverScore: {fmt(res['moverscore'])}")
        print(f"   中文加权ROUGE-1: {fmt(res['zh_weighted_rouge1'])}")

    # Comparison: best pipeline per metric
    print("\n" + "=" * 80)
    print("📈 性能对比分析")
    print("=" * 80)

    for metric_key, title in (
        ('bertscore_f1', '最佳BERTScore F1'),
        ('moverscore', '最佳MoverScore'),
        ('zh_weighted_rouge1', '最佳中文加权ROUGE-1'),
    ):
        available = [res[metric_key] for res in results.values() if res[metric_key] is not None]
        # Test for presence, not truthiness: a best score of 0.0 is still a result.
        if not available:
            continue
        best = max(available)
        print(f"🏆 {title}: {best:.4f}")
        for fname, res in results.items():
            if res[metric_key] == best:
                print(f"   由 {pipeline_label(fname)} Pipeline 获得")

    # CSV report generation has been removed


