#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Filename: compare_adaptive_strategies.py
# 用于直接对比 SAR-GNOT (静态容量) 和 MoR-GNOT (静态深度)
# 两种自适应策略的性能、吞吐量和负载均衡性。

import torch
import numpy as np
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import argparse
import os
import sys
from tqdm import tqdm
import json

# 确保项目路径在 sys.path 中
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

from data_utils import get_dataset, get_model, MIODataLoader
from utils import get_seed

def relative_l2_error(pred, target):
    """Return the relative L2 error ``||pred - target||_2 / (||target||_2 + 1e-8)``.

    Both tensors are compared element-wise after flattening. When the shapes
    differ, ``target`` is first reshaped to ``pred.shape``; this raises a
    ``RuntimeError`` if the element counts differ instead of silently
    broadcasting.

    Args:
        pred: Predicted tensor.
        target: Reference tensor with the same number of elements as ``pred``.

    Returns:
        The relative error as a Python float.
    """
    if pred.shape != target.shape:
        # reshape (not view): view() fails on non-contiguous tensors such as
        # transposed or sliced node features, reshape copies when it must.
        target = target.reshape(pred.shape)
    diff = pred.flatten() - target.flatten()
    numerator = torch.norm(diff, p=2)
    denominator = torch.norm(target.flatten(), p=2)
    # The 1e-8 guards against division by zero for an all-zero target.
    return (numerator / (denominator + 1e-8)).item()

def measure_throughput(model, loader, device):
    """Measure end-to-end inference throughput of ``model`` over ``loader``.

    Runs one full pass in ``eval`` mode under ``torch.no_grad`` and times it
    with a wall clock. On CUDA devices the device is synchronized before and
    after the pass so queued kernels are included in the measurement. On CPU
    no CUDA API is touched, so the function also works without a GPU.

    Args:
        model: Callable invoked as ``model(g, u_p, inputs_f)``; its output is
            discarded.
        loader: Iterable of ``(g_batch, u_p_batch, inputs_f_batch)``.
            ``g_batch`` must provide ``.to(device)`` and a ``batch_size``
            attribute; ``inputs_f_batch`` is moved only if it has ``.to``.
        device: ``torch.device`` (or device string) to run on.

    Returns:
        Samples per second, or 0 if the measured time is not positive.
    """
    import time  # local import keeps this block self-contained

    device = torch.device(device)
    use_cuda = device.type == 'cuda'
    model.eval()
    total_samples = 0

    if use_cuda:
        # Drain pending work so it is not billed to this measurement.
        torch.cuda.synchronize(device=device)
    start = time.perf_counter()

    with torch.no_grad():
        for g_batch, u_p_batch, inputs_f_batch in loader:
            g, u_p = g_batch.to(device), u_p_batch.to(device)
            if hasattr(inputs_f_batch, 'to'):
                inputs_f_batch = inputs_f_batch.to(device)
            model(g, u_p, inputs_f_batch)
            total_samples += g.batch_size

    if use_cuda:
        # Kernels run asynchronously; wait for them before stopping the clock.
        torch.cuda.synchronize(device=device)
    elapsed_sec = time.perf_counter() - start

    return total_samples / elapsed_sec if elapsed_sec > 0 else 0

def analyze_model(model, model_name, loader, device, args):
    """
    Analyze one model: per-sample relative error, per-layer token load, throughput.

    Two model families are handled:
      * class name containing 'StaticDepth': forward returns ``(out, aux)`` and
        ``aux["layer_capacities"]`` gives the per-layer loads for the sample;
      * otherwise (SR_GNOT_SS-style): forward returns ``out`` and the per-layer
        loads are rebuilt here as ``max(1, int(N * ratio))`` from
        ``model.capacity_ratios``, or from a linspace between 1.0 and
        ``model.final_keep_ratio`` over ``model.recursion_depth`` layers.

    Args:
        model: Model to analyze.
        model_name: Label used in progress output.
        loader: Detail loader (the per-sample pass expects batch_size=1); its
            ``.dataset`` is reused to build the throughput loader.
        device: Device to run on.
        args: Namespace providing ``benchmark_batch_size``.

    Returns:
        dict with keys 'avg_error', 'avg_loads_per_layer',
        'variance_loads_per_layer', 'total_avg_computation', 'throughput',
        'model_depth'.
    """
    model.eval()
    is_static_depth = 'StaticDepth' in model.__class__.__name__
    errors = []
    loads_per_sample = []

    # Pass 1: per-sample error and per-layer load.
    progress = tqdm(loader, desc=f"Analyzing {model_name} (details)")
    with torch.no_grad():
        for g_batch, u_p_batch, inputs_f_batch in progress:
            g = g_batch.to(device)
            u_p = u_p_batch.to(device)
            if hasattr(inputs_f_batch, 'to'):
                inputs_f_batch = inputs_f_batch.to(device)

            output = model(g, u_p, inputs_f_batch)

            if is_static_depth:
                # The static-depth model reports its own per-layer token counts.
                out, aux_info = output
                loads = aux_info["layer_capacities"]
            else:
                # The static-capacity model returns only the prediction, so its
                # fixed per-layer budget is reconstructed from the ratios.
                out = output
                n_nodes = g.number_of_nodes()
                ratios = model.capacity_ratios
                if ratios is None:
                    ratios = np.linspace(1.0, model.final_keep_ratio, model.recursion_depth)
                loads = [max(1, int(n_nodes * ratio)) for ratio in ratios]

            loads_per_sample.append(loads)
            errors.append(relative_l2_error(out, g.ndata['y']))

    # Aggregate statistics across samples (rows = samples, columns = layers).
    avg_error = np.mean(errors)
    load_matrix = np.array(loads_per_sample)
    mean_loads = np.mean(load_matrix, axis=0)
    var_loads = np.var(load_matrix, axis=0)
    total_computation = np.sum(mean_loads)

    # Pass 2: throughput with a larger batch size.
    print(f"Measuring throughput for {model_name} with batch_size={args.benchmark_batch_size}...")
    bench_loader = MIODataLoader(loader.dataset, batch_size=args.benchmark_batch_size, shuffle=False)
    throughput = measure_throughput(model, bench_loader, device)

    # Depth attribute name differs between model classes.
    if hasattr(model, 'n_layers'):
        depth = model.n_layers
    elif hasattr(model, 'recursion_depth'):
        depth = model.recursion_depth
    else:
        depth = model.max_depth

    return {
        'avg_error': avg_error,
        'avg_loads_per_layer': mean_loads.tolist(),
        'variance_loads_per_layer': var_loads.tolist(),
        'total_avg_computation': total_computation,
        'throughput': throughput,
        'model_depth': depth,
    }

def generate_comparison_report(results, output_dir, dataset_name):
    """
    Print a side-by-side metrics table, plot per-layer loads and dump all results to JSON.

    The SAR (static-capacity) entry is the first key containing 'SR_GNOT'. The
    MoR (static-depth) entry is the first key containing 'StaticDepth'. If
    either is missing, an error is printed and nothing is plotted or saved.

    Args:
        results: Mapping of model name -> dict as returned by ``analyze_model``.
        output_dir: Directory for the PNG and JSON outputs; created if missing.
        dataset_name: Used in the plot title and in the output file names.
    """
    os.makedirs(output_dir, exist_ok=True)

    sar_model_key = next((key for key in results if 'SR_GNOT' in key), None)
    mor_model_key = next((key for key in results if 'StaticDepth' in key), None)
    if not sar_model_key or not mor_model_key:
        print("Error: Could not find both SAR and MoR model results to compare.")
        return

    sar_res = results[sar_model_key]
    mor_res = results[mor_model_key]

    _print_comparison_table(sar_model_key, sar_res, mor_model_key, mor_res)
    _plot_load_comparison(sar_res, mor_res, output_dir, dataset_name)

    json_path = os.path.join(output_dir, f'comparison_results_{dataset_name}.json')
    with open(json_path, 'w', encoding='utf-8') as f:
        # numpy scalars/arrays (e.g. an np.int64 model depth) and tensors are
        # not JSON-serializable by default; convert them to plain Python values.
        json.dump(results, f, indent=4, default=_to_json_compatible)
    print(f"Detailed results saved to: {json_path}")


def _print_comparison_table(sar_key, sar_res, mor_key, mor_res):
    """Print the SAR-vs-MoR metrics table to stdout."""
    header = f"{'Metric':<30} | {sar_key:<35} | {mor_key:<35}"
    rule = "=" * len(header)
    print("\n" + rule)
    print(" " * (len(header) // 2 - 14) + "ADAPTIVE STRATEGY COMPARISON")
    print(rule)
    print(header)
    print("-" * len(header))
    print(f"{'Mean Rel2 Error':<30} | {sar_res['avg_error']:<35.4f} | {mor_res['avg_error']:<35.4f}")
    print(f"{'Throughput (samples/s)':<30} | {sar_res['throughput']:<35.2f} | {mor_res['throughput']:<35.2f}")
    print(f"{'Total Avg. Comp. (Tokens)':<30} | {sar_res['total_avg_computation']:<35.0f} | {mor_res['total_avg_computation']:<35.0f}")
    print(f"{'Avg. Load Variance':<30} | {np.mean(sar_res['variance_loads_per_layer']):<35.2e} | {np.mean(mor_res['variance_loads_per_layer']):<35.2e}")
    print(rule)


def _plot_load_comparison(sar_res, mor_res, output_dir, dataset_name):
    """Plot per-layer active-token counts for both models and save the figure as PNG."""
    # 'seaborn-whitegrid' was renamed 'seaborn-v0_8-whitegrid' in matplotlib 3.6
    # and the old name was removed in 3.8; use whichever this install provides.
    for style_name in ('seaborn-v0_8-whitegrid', 'seaborn-whitegrid'):
        if style_name in plt.style.available:
            plt.style.use(style_name)
            break

    fig, ax = plt.subplots(figsize=(12, 7))

    sar_avg = np.asarray(sar_res['avg_loads_per_layer'], dtype=float)
    mor_avg = np.asarray(mor_res['avg_loads_per_layer'], dtype=float)
    mor_std = np.sqrt(np.asarray(mor_res['variance_loads_per_layer'], dtype=float))
    # Index each curve by its own length so that models of different depth
    # (or a depth attribute that disagrees with the recorded layers) still plot.
    sar_layers = np.arange(len(sar_avg))
    mor_layers = np.arange(len(mor_avg))

    # NOTE(review): legend says 'SBR-GNOT' while the rest of the script says
    # 'SAR-GNOT' -- confirm which name is intended.
    ax.step(sar_layers, sar_avg, where='post',
            color='mediumblue', linewidth=2.5, label='SBR-GNOT (Static Capacity)')
    ax.plot(mor_layers, mor_avg, color='firebrick', marker='o', linestyle='--', label='MoR-GNOT (Avg. Load)')
    ax.fill_between(mor_layers, mor_avg - mor_std, mor_avg + mor_std,
                    color='lightcoral', alpha=0.3, label='MoR-GNOT (±1 StdDev)')

    ax.set_title(f'Computation Load Comparison\n(Dataset: {dataset_name})', fontsize=25)
    ax.set_xlabel('Layer Depth', fontsize=25)
    ax.set_ylabel('Number of Active Tokens', fontsize=25)
    ax.set_xticks(np.arange(max(len(sar_avg), len(mor_avg))))
    ax.tick_params(axis='both', labelsize=20)
    ax.legend(fontsize=20)
    ax.grid(True, which='both', linestyle='--', linewidth=0.5)

    save_path = os.path.join(output_dir, f'load_comparison_{dataset_name}.png')
    fig.savefig(save_path, dpi=1000, bbox_inches='tight')
    plt.close(fig)
    print(f"\nLoad comparison plot saved to: {save_path}")


def _to_json_compatible(obj):
    """``json.dump`` fallback: turn numpy values and tensors into Python types."""
    if isinstance(obj, (np.generic, np.ndarray, torch.Tensor)):
        return obj.tolist()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")

def main(args):
    """Analyze the SAR and MoR checkpoints on the test split and write the comparison report.

    Args:
        args: Parsed command-line namespace (seed, device flags, dataset options,
            checkpoint paths, output directory, benchmark batch size).
    """
    get_seed(args.seed)
    use_cuda = torch.cuda.is_available() and not args.no_cuda
    device = torch.device(f'cuda:{args.gpu}' if use_cuda else "cpu")
    print(f"Using device: {device}")

    _, test_dataset = get_dataset(args)
    # batch_size=1 so per-sample errors and loads are collected individually.
    detail_loader = MIODataLoader(test_dataset, batch_size=1, shuffle=False)

    results = {}

    print("--- Analyzing SAR-GNOT (Static Capacity) Model ---")
    name, result = _analyze_checkpoint(args.sar_model_path, detail_loader, device, args)
    results[name] = result

    print("\n--- Analyzing MoR-GNOT (Static Depth) Model ---")
    name, result = _analyze_checkpoint(args.mor_model_path, detail_loader, device, args)
    results[name] = result

    generate_comparison_report(results, args.output_dir, args.dataset)


def _load_checkpoint(path, device):
    """Load a checkpoint dict from *path* onto *device*.

    The checkpoint stores the training ``args`` object next to the weights.
    torch>=2.6 defaults ``torch.load`` to ``weights_only=True``, which refuses
    such objects, so full unpickling is requested explicitly where supported.
    SECURITY: full unpickling can execute arbitrary code -- only load trusted
    checkpoints.
    """
    import inspect  # local import keeps this block self-contained

    load_kwargs = {'map_location': device}
    # Older torch releases do not accept the `weights_only` keyword at all.
    if 'weights_only' in inspect.signature(torch.load).parameters:
        load_kwargs['weights_only'] = False
    return torch.load(path, **load_kwargs)


def _analyze_checkpoint(path, loader, device, args):
    """Rebuild the model saved at *path*, analyze it, and return ``(name, results)``.

    The name is the checkpoint file name without its extension. The model is
    local to this call, so it can be freed before the next checkpoint loads.
    """
    ckpt = _load_checkpoint(path, device)
    model = get_model(ckpt['args']).to(device)
    model.load_state_dict(ckpt['model'])
    model_name = os.path.splitext(os.path.basename(path))[0]
    return model_name, analyze_model(model, model_name, loader, device, args)

def _build_arg_parser():
    """Build the command-line parser for this comparison script."""
    parser = argparse.ArgumentParser(description="Compare SAR-GNOT (Static Capacity) and MoR-GNOT (Static Depth) strategies.")
    parser.add_argument('--gpu', type=int, default=0, help='CUDA device index.')
    parser.add_argument('--no-cuda', action='store_true', help='Force CPU even if CUDA is available.')
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--dataset', type=str, required=True)
    parser.add_argument('--sar_model_path', type=str, required=True, help='Path to the trained SAR-GNOT (SR_GNOT_SS) checkpoint.')
    parser.add_argument('--mor_model_path', type=str, required=True, help='Path to the trained MoR-GNOT (GNOT_StaticDepth) checkpoint.')
    parser.add_argument('--output_dir', type=str, default='./results/comparison', help='Directory to save the plots and results.')
    parser.add_argument('--benchmark_batch_size', type=int, default=16, help='Batch size used for the throughput benchmark.')
    # The options below are not read in this file; they are part of the args
    # namespace handed to get_dataset.
    parser.add_argument('--sort_data', type=int, default=0, help='Data-sorting option for get_dataset.')
    parser.add_argument('--train-num', default='all')
    parser.add_argument('--test-num', default='all')
    parser.add_argument('--use-normalizer', default='none')
    parser.add_argument('--normalize_x', default='none')
    return parser


if __name__ == "__main__":
    args = _build_arg_parser().parse_args()
    main(args)