import wandb
from collections import defaultdict
from typing import Any, Dict, Tuple, List

# Initialise the W&B public API client.
api = wandb.Api()

# Full paths of the two runs to compare. The values here have the form
# "entity/project/run_id" and are passed as-is to api.run() further down.
run_id_1 = "astrid_tuning_llm/verl-qwen3-4b-oct/nq4o1e10"
run_id_2 = "astrid_tuning_llm/verl-qwen3-4b-oct/ltejsqyq"


def flatten_dict(d: Dict, parent_key: str = '', sep: str = '.') -> Dict[str, Any]:
    """Flatten a nested dictionary into a single-level one.

    Nested keys are joined with ``sep`` to form a path, e.g.
    ``{"a": {"b": 1}}`` becomes ``{"a.b": 1}``.

    Args:
        d: The (possibly nested) dictionary to flatten.
        parent_key: Path prefix prepended to every key of ``d``. Callers
            normally leave this empty; it is used by the recursion.
        sep: Separator placed between path components.

    Returns:
        A new dictionary mapping key paths to leaf values. An *empty* nested
        dictionary is kept as a leaf value (``{"a": {}}`` -> ``{"a": {}}``)
        rather than vanishing, so that its presence is not lost.
    """
    flat: Dict[str, Any] = {}
    for k, v in d.items():
        new_key = f"{parent_key}{sep}{k}" if parent_key else k
        # Only recurse into non-empty dicts: recursing into ``{}`` would
        # contribute no entries and silently erase the key.
        if isinstance(v, dict) and v:
            flat.update(flatten_dict(v, new_key, sep=sep))
        else:
            flat[new_key] = v
    return flat


class _Missing:
    """Marker type for a key that is absent from one of the two configs."""

    __slots__ = ()

    def __repr__(self) -> str:
        # ``str()`` falls back to ``__repr__``, so this is what gets displayed.
        return "<missing>"


# Single shared marker; distinct from ``None`` so that "explicitly None" and
# "not present at all" are not confused with each other.
_MISSING = _Missing()


def find_differences(config1: Dict, config2: Dict) -> List[Tuple[str, Any, Any]]:
    """Find every difference between two configs.

    Both configs are flattened with ``flatten_dict`` and compared key by key.

    Args:
        config1: First configuration dictionary (may be nested).
        config2: Second configuration dictionary (may be nested).

    Returns:
        A list of ``(key_path, value1, value2)`` tuples, ordered by key path,
        one per key whose values differ. When a key exists in only one
        config, the other side is reported as the ``_MISSING`` marker (shown
        as ``<missing>``), so a key set to ``None`` in one config and absent
        from the other is still reported as a difference.
    """
    flat1 = flatten_dict(config1)
    flat2 = flatten_dict(config2)

    differences = []
    for key in sorted(flat1.keys() | flat2.keys()):
        val1 = flat1.get(key, _MISSING)
        val2 = flat2.get(key, _MISSING)
        if val1 != val2:
            differences.append((key, val1, val2))

    return differences


def group_by_category(differences: List[Tuple[str, Any, Any]]) -> Dict[str, List[Tuple[str, Any, Any]]]:
    """Bucket differences by the first component of their key path.

    A key path containing a dot is filed under the text before its first dot.
    A key path without any dot is filed under the catch-all '其他' bucket.
    Buckets appear in the order in which they are first encountered.
    """
    buckets: Dict[str, List[Tuple[str, Any, Any]]] = {}
    for path, left, right in differences:
        head, dot, _rest = path.partition('.')
        # ``dot`` is empty exactly when the path has no separator at all.
        bucket_name = head if dot else '其他'
        buckets.setdefault(bucket_name, []).append((path, left, right))
    return buckets


def truncate_path(path: str, max_length: int = 80) -> str:
    """Shorten an over-long string, keeping its beginning and its end.

    Args:
        path: The string to shorten.
        max_length: Maximum length of the returned string.

    Returns:
        ``path`` unchanged when it already fits. Otherwise a string of exactly
        ``max_length`` characters made of a head, ``"..."`` and up to the last
        17 characters of ``path``. When ``max_length`` is too small to hold
        the ellipsis plus any tail, a plain prefix of ``path`` is returned.
    """
    if len(path) <= max_length:
        return path

    ellipsis = "..."
    if max_length <= len(ellipsis):
        # No room for "..." plus context; a bare prefix is the best we can do.
        return path[:max(max_length, 0)]

    # Prefer a 17-character tail, but shrink it when the budget is small so
    # the result never exceeds ``max_length`` (a negative head slice would
    # otherwise make the output longer than requested).
    tail_len = min(17, max_length - len(ellipsis))
    head_len = max_length - len(ellipsis) - tail_len
    return path[:head_len] + ellipsis + path[len(path) - tail_len:]


def format_value(val: Any) -> str:
    """Render a config value as text for the report.

    Strings longer than 80 characters are shortened with ``truncate_path``;
    shorter strings are returned unchanged. ``None`` is rendered as
    ``"None"`` and every other value goes through ``str()``.
    """
    if isinstance(val, str):
        # Any long string is shortened here, whether or not it is a path.
        return truncate_path(val) if len(val) > 80 else val
    return "None" if val is None else str(val)


# Only the W&B lookups live inside the ``try``: the "check your run ID" hint
# is meant for API failures, so a bug in the comparison or reporting code in
# the ``else`` branch surfaces with its real traceback instead of being
# misreported as a bad run path.
try:
    # Fetch both run objects.
    run1 = api.run(run_id_1)
    run2 = api.run(run_id_2)

    # Read their configs (the recorded parameters).
    config1 = run1.config
    config2 = run2.config

    # Use the run's display name, falling back to its path when empty.
    run1_name = run1.name or run_id_1
    run2_name = run2.name or run_id_2
except Exception as e:
    print(f"❌ 发生错误: {e}")
    print("请检查您的项目路径和 run ID 是否正确。")
else:
    heavy_rule = "=" * 80
    light_rule = "─" * 80

    # Compare the two config dictionaries.
    differences = find_differences(config1, config2)

    if not differences:
        print(heavy_rule)
        print("✓ 两个 run 的参数完全相同。")
        print(heavy_rule)
    else:
        # Group by category; the sorted order is reused by both loops below.
        grouped_diffs = group_by_category(differences)
        categories = sorted(grouped_diffs)

        # Header.
        print(heavy_rule)
        print(f"比较结果: {run1_name} vs {run2_name}")
        print(heavy_rule)
        print(f"\n总共发现 {len(differences)} 个参数差异，分布在 {len(grouped_diffs)} 个类别中。\n")

        # Per-category listing of every differing key.
        for category in categories:
            category_diffs = grouped_diffs[category]
            print(light_rule)
            print(f"📂 类别: {category} ({len(category_diffs)} 个差异)")
            print(light_rule)

            for key, val1, val2 in category_diffs:
                print(f"\n  🔹 {key}")
                print(f"    Run 1: {format_value(val1)}")
                print(f"    Run 2: {format_value(val2)}")

            print()

        # Summary: total count plus one line per category.
        print(heavy_rule)
        print("差异总结:")
        print(f"  总差异数: {len(differences)}")
        for category in categories:
            print(f"  {category}: {len(grouped_diffs[category])} 个差异")
        print(heavy_rule)