import csv
import os
from typing import List, Dict
import inspect

# Hide every CUDA device so all models run on the CPU (avoids CUDA issues);
# CUDA_LAUNCH_BLOCKING='1' additionally requests synchronous CUDA launches.
os.environ.update(
    CUDA_VISIBLE_DEVICES='',
    CUDA_LAUNCH_BLOCKING='1',
)


# Directory containing this script, and the "output" folder beneath it that
# holds the per-model pipeline result directories.
BASE_DIR = os.path.dirname(__file__)
OUTPUT_DIR = os.path.join(BASE_DIR, 'output')

# 自动检测可用的CSV文件
def get_available_csv_files(model_dirs=('gpt_output', 'qwen_output'),
                            pipeline_types=('baseline', 'kg'),
                            output_dir=None) -> List[str]:
    """Find at most one pipeline result CSV per (model directory, pipeline type).

    For every directory in ``model_dirs`` under ``output_dir`` (``OUTPUT_DIR``
    when None) and every entry of ``pipeline_types``, the alphabetically first
    file named ``test_set_{pipeline}_*.csv`` is selected. Missing directories
    (or paths that are not directories) are reported and skipped.

    Args:
        model_dirs: Model sub-directory names to scan, in output order.
        pipeline_types: Pipeline names to look for, in output order.
        output_dir: Root directory holding the model sub-directories.

    Returns:
        Paths relative to ``output_dir`` formatted as ``"<model_dir>/<file>"``,
        ordered by model directory, then pipeline type. Empty if none matched.
    """
    if output_dir is None:
        output_dir = OUTPUT_DIR

    available_files: List[str] = []
    for model_dir in model_dirs:
        model_path = os.path.join(output_dir, model_dir)
        # isdir (not exists): a regular file here would make listdir raise.
        if not os.path.isdir(model_path):
            print(f"⚠️ 目录不存在: {model_path}")
            continue

        # List once and sort: os.listdir order is arbitrary, so sorting makes the
        # choice deterministic when several files match the same pipeline.
        file_names = sorted(os.listdir(model_path))
        for pipeline in pipeline_types:
            prefix = f'test_set_{pipeline}_'
            match = next(
                (name for name in file_names
                 if name.startswith(prefix) and name.endswith('.csv')),
                None,
            )
            if match is not None:
                available_files.append(f"{model_dir}/{match}")

    if not available_files:
        expected = '|'.join(pipeline_types)
        print("⚠️  在指定目录中未找到任何pipeline结果文件")
        print(f"   期望的文件名格式: test_set_{{{expected}}}_{{model_name}}.csv")

    return available_files

# Fallback list of pipeline result CSVs, relative to OUTPUT_DIR: one baseline
# and one KG result file per model.
# NOTE(review): the __main__ flow only uses get_available_csv_files() and exits
# when it finds nothing, so this list is not consulted there.
DEFAULT_CSV_FILES = [
    f'{model}_output/test_set_{pipeline}_{model}.csv'
    for model in ('gpt', 'qwen')
    for pipeline in ('baseline', 'kg')
]
# Standard Hugging Face model id used by the BERTScore (and, when supported,
# MoverScore) computations.
BERT_MODEL_NAME = 'bert-base-uncased'


def read_csv_rows(csv_path: str) -> List[Dict[str, str]]:
    """Read a pipeline result CSV into a list of row dicts keyed by header name.

    The file is decoded as UTF-8, tolerating a leading BOM (``utf-8-sig``).
    The header must contain the ``Ground_Truth`` and ``LLM_Output`` columns;
    a header-only file is valid and yields an empty list.

    Args:
        csv_path: Path to the CSV file.

    Returns:
        One dict per data row, as produced by ``csv.DictReader``.

    Raises:
        ValueError: If a required column is missing from the header, or the
            file has no header at all.
    """
    required_cols = {'Ground_Truth', 'LLM_Output'}
    with open(csv_path, 'r', encoding='utf-8-sig', newline='') as f:
        reader = csv.DictReader(f)
        # Validate the header itself rather than the first data row, so a file
        # with the right columns but no rows is not reported as "missing columns".
        missing = required_cols - set(reader.fieldnames or [])
        if missing:
            raise ValueError(f"CSV缺少必要列: {sorted(missing)}")
        rows: List[Dict[str, str]] = list(reader)
    return rows


def compute_bertscore_f1(refs: List[str], hyps: List[str]) -> float:
    """Return the mean BERTScore F1 of ``hyps`` against ``refs``, computed on CPU.

    Uses ``BERT_MODEL_NAME`` with 12 layers and baseline rescaling. Rescaled
    values are not confined to [0, 1] and can be negative.

    Args:
        refs: Reference texts, aligned one-to-one with ``hyps``.
        hyps: Candidate (generated) texts.

    Returns:
        The mean F1 over all pairs. Returns 0.0 for empty input. Also returns
        0.0, after printing a warning, when ``bert_score`` is unavailable or
        scoring fails, so 0.0 does not necessarily mean "no similarity".
    """
    # Keep CUDA hidden (also done at module import); device='cpu' below is what
    # actually pins the computation to the CPU.
    os.environ['CUDA_VISIBLE_DEVICES'] = ''

    # Nothing to score: skip loading the model (an empty batch would otherwise
    # fail or average to NaN).
    if not refs or not hyps:
        return 0.0

    try:
        from bert_score import score as bert_score

        _, _, f1 = bert_score(
            cands=hyps,
            refs=refs,
            lang='en',
            model_type=BERT_MODEL_NAME,
            num_layers=12,
            # Linearly rescales scores against a precomputed baseline for
            # readability; it does not change how pairs rank, and values may be < 0.
            rescale_with_baseline=True,
            device='cpu',
            batch_size=16,  # small batches keep memory use modest
        )
        # f1 is expected to be a tensor of per-pair scores; fall back to a plain mean.
        return float(f1.mean().item()) if hasattr(f1, 'mean') else float(sum(f1) / len(f1))
    except Exception as e:
        print(f"⚠️ BERTScore计算失败: {e}")
        return 0.0


# English stop words used when the installed moverscore package does not
# provide a ``get_stop_words`` helper.
_FALLBACK_EN_STOP_WORDS = [
    'the', 'a', 'an', 'and', 'or', 'but', 'in', 'on', 'at', 'to', 'for', 'of',
    'with', 'by', 'is', 'are', 'was', 'were', 'be', 'been', 'have', 'has',
    'had', 'do', 'does', 'did', 'will', 'would', 'could', 'should',
]


def _import_moverscore():
    """Return ``(get_idf_dict, word_mover_score, get_stop_words_or_None)``.

    Tries the ``moverscore`` package first and falls back to ``moverscore_v2``.
    ``get_stop_words`` is optional in both, so its absence never aborts the import.
    """
    try:
        from moverscore import get_idf_dict, word_mover_score  # type: ignore
        try:
            from moverscore import get_stop_words  # type: ignore
        except ImportError:
            get_stop_words = None
    except ImportError:
        from moverscore_v2 import get_idf_dict, word_mover_score  # type: ignore
        try:
            from moverscore_v2 import get_stop_words  # type: ignore
        except ImportError:
            get_stop_words = None
    return get_idf_dict, word_mover_score, get_stop_words


def compute_moverscore_offline(refs: List[str], hyps: List[str]) -> float:
    """Return the mean MoverScore of ``hyps`` against ``refs``.

    Works with either the ``moverscore`` or ``moverscore_v2`` package. Optional
    keyword arguments (stop words, unigram matching, subword handling, model
    name, CPU device) are passed only when ``word_mover_score`` accepts them.
    Despite the name, Hugging Face offline mode is NOT forced, so models may be
    downloaded.

    Args:
        refs: Reference texts, aligned one-to-one with ``hyps``.
        hyps: Candidate (generated) texts.

    Returns:
        The mean score. Returns 0.0 for empty input, and 0.0 (after printing a
        warning) when the scoring call itself fails.

    Raises:
        ImportError: If neither moverscore package can be imported. Errors from
            ``get_idf_dict`` are also propagated.
    """
    # Keep CUDA hidden. The TRANSFORMERS_NO_* variables are presumably meant to
    # stop transformers from loading TF/Flax/JAX backends.
    # NOTE(review): confirm that the installed transformers version honours these names.
    os.environ['CUDA_VISIBLE_DEVICES'] = ''
    os.environ.setdefault('TRANSFORMERS_NO_TF', '1')
    os.environ.setdefault('TRANSFORMERS_NO_FLAX', '1')
    os.environ.setdefault('TRANSFORMERS_NO_JAX', '1')

    # Check before importing: the moverscore import may load a model.
    if not refs or not hyps:
        return 0.0

    get_idf_dict, word_mover_score, get_stop_words = _import_moverscore()

    if get_stop_words is None:
        stop_words = list(_FALLBACK_EN_STOP_WORDS)
    else:
        try:
            stop_words = get_stop_words('en') or []
        except Exception:
            stop_words = []

    idf_ref = get_idf_dict(refs)
    idf_hyp = get_idf_dict(hyps)

    # Package versions differ in their word_mover_score signature, so pass only
    # the keyword arguments that it actually declares.
    params = inspect.signature(word_mover_score).parameters
    kwargs = {}
    if 'stop_words' in params:
        kwargs['stop_words'] = stop_words
    if 'n_gram' in params:
        kwargs['n_gram'] = 1
    if 'remove_subwords' in params:
        kwargs['remove_subwords'] = False
    if 'model_name' in params:
        kwargs['model_name'] = BERT_MODEL_NAME
    if 'device' in params:
        # The default here may be a CUDA device (moverscore_v2 presumably uses
        # 'cuda:0'), which cannot work while CUDA is hidden.
        kwargs['device'] = 'cpu'

    try:
        if 'idf_dict' in params:
            kwargs['idf_dict'] = {**idf_ref, **idf_hyp}
            scores = word_mover_score(refs, hyps, **kwargs)
        else:
            try:
                scores = word_mover_score(refs, hyps, idf_ref, idf_hyp, **kwargs)
            except TypeError:
                scores = word_mover_score(refs=refs, hyps=hyps, idf_ref=idf_ref, idf_hyp=idf_hyp, **kwargs)  # type: ignore

        # len() rather than truthiness, so an array-like result is also handled.
        return float(sum(scores) / len(scores)) if len(scores) else 0.0
    except Exception as e:
        print(f"⚠️ MoverScore计算失败: {e}")
        return 0.0


# ==== English POS-weighted ROUGE-1 ====
# Weight of a reference token, keyed by its Penn Treebank POS tag. Content words
# weigh more than function words; the scorer falls back to 0.5 for unlisted tags.
POS_WEIGHT = {
    **dict.fromkeys(('NN', 'NNS', 'NNP', 'NNPS'), 1.0),               # nouns
    **dict.fromkeys(('VB', 'VBD', 'VBG', 'VBN', 'VBP', 'VBZ'), 0.9),  # verbs
    **dict.fromkeys(('JJ', 'JJR', 'JJS'), 0.7),                       # adjectives
    **dict.fromkeys(('RB', 'RBR', 'RBS'), 0.7),                       # adverbs
    # pronouns, determiners, conjunctions, prepositions
    **dict.fromkeys(('PRP', 'PRP$', 'DT', 'CC', 'IN'), 0.3),
}


# (resource path for nltk.data.find, package id for nltk.download). Older NLTK
# releases use 'punkt' / 'averaged_perceptron_tagger'; NLTK 3.9+ looks for the
# 'punkt_tab' / 'averaged_perceptron_tagger_eng' variants instead.
_NLTK_RESOURCES = (
    ('tokenizers/punkt', 'punkt'),
    ('tokenizers/punkt_tab', 'punkt_tab'),
    ('taggers/averaged_perceptron_tagger', 'averaged_perceptron_tagger'),
    ('taggers/averaged_perceptron_tagger_eng', 'averaged_perceptron_tagger_eng'),
)
_nltk_resources_ready = False


def _ensure_nltk_resources() -> None:
    """Download missing NLTK tokenizer/tagger data, at most once per process."""
    global _nltk_resources_ready
    if _nltk_resources_ready:
        return
    import nltk
    for resource_path, package in _NLTK_RESOURCES:
        try:
            nltk.data.find(resource_path)
        except LookupError:
            # By default nltk.download reports failures (e.g. a package name unknown
            # to this NLTK version) instead of raising.
            nltk.download(package, quiet=True)
    _nltk_resources_ready = True


def weighted_rouge1_en(gen_text: str, ref_text: str) -> float:
    """Return the POS-weighted unigram recall of ``gen_text`` against ``ref_text``.

    Both texts are tokenized with NLTK's ``word_tokenize``. Reference tokens are
    POS-tagged and weighted via ``POS_WEIGHT``, using 0.5 for tags not in the
    table. A word's weight is taken from its first occurrence in the reference.
    Matches use clipped counts, as in ROUGE-1. Comparison is case-sensitive, and
    punctuation tokens count as words.

    Args:
        gen_text: Generated text.
        ref_text: Reference text.

    Returns:
        ``sum(min(gen_count, ref_count) * weight) / sum(ref_count * weight)``
        over the distinct reference tokens, or 0.0 if the reference has no tokens.
    """
    from collections import Counter
    from nltk import pos_tag, word_tokenize

    # Previously nltk.download() ran on every call, i.e. once per sample.
    _ensure_nltk_resources()

    # Weight each distinct reference token by the POS tag of its first occurrence.
    ref_words = word_tokenize(ref_text)
    ref_weights: Dict[str, float] = {}
    for word, tag in pos_tag(ref_words):
        ref_weights.setdefault(word, POS_WEIGHT.get(tag, 0.5))

    ref_counts = Counter(ref_words)
    gen_counts = Counter(word_tokenize(gen_text))

    match_weight = 0.0
    ref_total_weight = 0.0
    for word, ref_cnt in ref_counts.items():
        weight = ref_weights.get(word, 0.5)
        match_weight += min(gen_counts.get(word, 0), ref_cnt) * weight
        ref_total_weight += ref_cnt * weight

    return (match_weight / ref_total_weight) if ref_total_weight > 0 else 0.0


def compute_weighted_rouge1_en_avg(refs: List[str], hyps: List[str]) -> float:
    """Return the mean of ``weighted_rouge1_en`` over aligned (hyp, ref) pairs.

    Pairs are formed with ``zip``, so surplus items in the longer list are
    ignored.

    Args:
        refs: Reference texts.
        hyps: Generated texts, aligned with ``refs``.

    Returns:
        The average score, or 0.0 when there are no pairs (either list empty).
    """
    scores = [weighted_rouge1_en(hyp, ref) for hyp, ref in zip(hyps, refs)]
    # Guard on the pairs actually scored: checking only `refs` would divide by
    # zero when `hyps` is empty.
    if not scores:
        return 0.0
    return float(sum(scores) / len(scores))


# Pipeline types evaluated, in display order.
_PIPELINES = ('baseline', 'kg')

# (result-dict key, label printed in the "best score" lines), in display order.
_BEST_METRICS = (
    ('bertscore_f1', '最佳BERTScore F1'),
    ('moverscore', '最佳MoverScore'),
    ('en_weighted_rouge1', '最佳英文加权ROUGE-1'),
)


def _fmt(x):
    """Format a metric with four decimals, or 'N/A' when it is None."""
    return 'N/A' if x is None else f"{x:.4f}"


def _pipeline_type(fname):
    """Return the pipeline encoded as ``test_set_{pipeline}_*`` in the file's basename, else None."""
    # Match on the basename prefix, not a substring of the whole path, so a model
    # directory or file name that merely contains "kg"/"baseline" is not misread.
    base_name = os.path.basename(fname)
    for pipeline in _PIPELINES:
        if base_name.startswith(f'test_set_{pipeline}_'):
            return pipeline
    return None


def _evaluate_csv(fname):
    """Score one result CSV given as ``"<model_dir>/<file>"`` relative to OUTPUT_DIR.

    Rows are skipped when the output is empty, the ground truth is empty, or the
    output starts with 'LLM' (treated as an LLM error marker). Each skipped row
    is counted under exactly one reason, checked in that order, so the reason
    counts add up to ``skipped_total``.

    Returns:
        A result dict with sample counts and the three metric values (None when
        not computed), or None when the file is missing or cannot be read.
    """
    csv_path = os.path.join(OUTPUT_DIR, fname)
    if not os.path.exists(csv_path):
        print(f"⚠️ 文件不存在: {csv_path}")
        return None

    try:
        rows = read_csv_rows(csv_path)
        refs_raw = [(r.get('Ground_Truth') or '').strip() for r in rows]
        hyps_raw = [(r.get('LLM_Output') or '').strip() for r in rows]
    except Exception as e:
        print(f"❌ 读取文件失败 {fname}: {e}")
        return None

    empty_hyp = empty_ref = llm_error = 0
    valid_pairs = []
    for h, r in zip(hyps_raw, refs_raw):
        if h == '':
            empty_hyp += 1
        elif r == '':
            empty_ref += 1
        elif h.startswith('LLM'):
            llm_error += 1
        else:
            valid_pairs.append((h, r))

    result = {
        'num_samples': len(rows),
        'valid_samples': len(valid_pairs),
        'skipped_empty_hyp': empty_hyp,
        'skipped_empty_ref': empty_ref,
        'skipped_llm_error': llm_error,
        'skipped_total': empty_hyp + empty_ref + llm_error,
        'bertscore_f1': None,
        'moverscore': None,
        'en_weighted_rouge1': None,
    }
    if not valid_pairs:
        return result

    hyps = [h for h, _ in valid_pairs]
    refs = [r for _, r in valid_pairs]

    result['bertscore_f1'] = compute_bertscore_f1(refs, hyps)
    # Metric failures are reported and recorded as None (shown as N/A), not fatal.
    try:
        result['moverscore'] = compute_moverscore_offline(refs, hyps)
    except Exception as e:
        print(f"⚠️ MoverScore计算失败: {e}")
    try:
        result['en_weighted_rouge1'] = compute_weighted_rouge1_en_avg(refs, hyps)
    except Exception as e:
        print(f"⚠️ 英文加权ROUGE-1计算失败: {e}")
    return result


def _group_by_model(results):
    """Regroup ``{fname: result}`` into ``{model_dir: {pipeline: result}}``."""
    model_results = {}
    for fname, res in results.items():
        model_name = fname.split('/')[0]  # model directory, e.g. 'gpt_output'
        pipeline_type = _pipeline_type(fname)
        if pipeline_type is None:
            print(f"⚠️ 无法识别pipeline类型: {fname}")
            continue
        model_results.setdefault(model_name, {})[pipeline_type] = res
    return model_results


def _print_summary_table(model_results):
    """Print one table row per (model, pipeline)."""
    print("=" * 100)
    print("📊 多模型Pipeline性能评估结果")
    print("=" * 100)

    print(f"{'Model':<12} {'Pipeline':<12} {'Samples':<8} {'Valid':<6} {'Skipped':<8} {'BERTScore':<10} {'MoverScore':<10} {'ROUGE-1':<10}")
    print("-" * 100)

    for model_name in sorted(model_results):
        model_data = model_results[model_name]
        for pipeline in _PIPELINES:
            if pipeline not in model_data:
                continue
            res = model_data[pipeline]
            print(f"{model_name:<12} {pipeline.title():<12} {res['num_samples']:<8} {res['valid_samples']:<6} {res['skipped_total']:<8} "
                  f"{_fmt(res['bertscore_f1']):<10} {_fmt(res['moverscore']):<10} {_fmt(res['en_weighted_rouge1']):<10}")

    print("-" * 100)


def _print_details(model_results):
    """Print per-model, per-pipeline sample counts, skip reasons and metrics."""
    print("\n📋 详细结果:")
    for model_name in sorted(model_results):
        print(f"\n🤖 模型: {model_name.upper()}")
        model_data = model_results[model_name]

        for pipeline in _PIPELINES:
            if pipeline not in model_data:
                continue
            res = model_data[pipeline]
            print(f"\n  🔹 {pipeline.title()} Pipeline:")
            print(f"     样本总数: {res['num_samples']}")
            print(f"     有效样本: {res['valid_samples']}")
            print(f"     跳过样本: {res['skipped_total']}")
            if res['skipped_total'] > 0:
                print(f"       - 空输出: {res['skipped_empty_hyp']}")
                print(f"       - 空真值: {res['skipped_empty_ref']}")
                print(f"       - LLM异常(以'LLM'开头): {res['skipped_llm_error']}")
            print(f"     BERTScore F1: {_fmt(res['bertscore_f1'])}")
            print(f"     MoverScore: {_fmt(res['moverscore'])}")
            print(f"     英文加权ROUGE-1: {_fmt(res['en_weighted_rouge1'])}")


def _print_comparison(model_results):
    """Print the best model per pipeline and metric, and the best pipeline per model."""
    print("\n" + "=" * 100)
    print("📈 性能对比分析")
    print("=" * 100)

    for pipeline in _PIPELINES:
        print(f"\n🔹 {pipeline.title()} Pipeline 最佳性能:")
        for metric_key, label in _BEST_METRICS:
            candidates = [
                (model_name, model_data[pipeline][metric_key])
                for model_name, model_data in model_results.items()
                if pipeline in model_data and model_data[pipeline][metric_key] is not None
            ]
            if candidates:
                best_model, best_score = max(candidates, key=lambda x: x[1])
                print(f"   🏆 {label}: {best_score:.4f} (模型: {best_model})")

    # Best pipeline per model by BERTScore. Take the max over the available scores
    # instead of starting from 0: rescaled BERTScore can be negative. On ties the
    # first pipeline in _PIPELINES wins.
    print("\n🤖 各模型最佳Pipeline:")
    for model_name in sorted(model_results):
        model_data = model_results[model_name]
        scored = [
            (pipeline, model_data[pipeline]['bertscore_f1'])
            for pipeline in _PIPELINES
            if pipeline in model_data and model_data[pipeline]['bertscore_f1'] is not None
        ]
        if scored:
            best_pipeline, best_score = max(scored, key=lambda x: x[1])
            print(f"   {model_name}: {best_pipeline.title()} Pipeline (BERTScore: {best_score:.4f})")


def main():
    """Evaluate every detected pipeline result CSV and print the comparison report."""
    import sys

    csv_files = get_available_csv_files()
    if not csv_files:
        print("❌ 未找到可用的CSV文件")
        sys.exit(1)  # the bare exit() builtin is not guaranteed (e.g. under python -S)

    print(f"📊 找到 {len(csv_files)} 个pipeline结果文件:")
    for fname in csv_files:
        print(f"  - {fname}")
    print()

    results = {}
    for fname in csv_files:
        res = _evaluate_csv(fname)
        if res is not None:
            results[fname] = res

    model_results = _group_by_model(results)
    _print_summary_table(model_results)
    _print_details(model_results)
    _print_comparison(model_results)


if __name__ == '__main__':
    main()