import glob
import os
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import argparse

# --- Only these four backbones are analysed: raw run-folder name -> display label ---
BACKBONE_ALIAS = dict([
    ('resnet18', 'ResNet-18'),
    ('tv_resnet50', 'ResNet-50'),
    ('vit_base_patch14_dinov2.lvd142m', 'ViT-Base (DinoV2)'),
    ('vit_base_patch16_clip_224', 'ViT-Base (CLIP)'),
])

# --- Accuracy (as a fraction) at which curves are truncated / targets measured ---
ACC_CUTOFF_THRESHOLD = 0.95

def load_data(log_dir, aliases=None):
    """Recursively load ``.npz`` training logs under *log_dir*, grouped by backbone.

    The first path component below *log_dir* is taken as the raw backbone name
    and mapped to a display name through *aliases*; files whose backbone is not
    in the mapping are skipped. Files that fail to load are reported and skipped.

    Args:
        log_dir: Root folder searched recursively for ``*.npz`` files.
        aliases: Mapping of raw backbone folder name -> display name.
            Defaults to the module-level ``BACKBONE_ALIAS``.

    Returns:
        ``None`` if no ``.npz`` file exists under *log_dir*; otherwise a dict
        ``{display_name: {'acc': [...], 'test_loss': [...], 'train_loss': [...],
        'time': [...]}}`` holding one array per run. ``acc``, ``test_loss`` and
        ``time`` are appended together, so index ``i`` is the same run in all
        three. ``train_loss`` is only appended when a file contains it, so it
        may be shorter than ``acc``.
    """
    if aliases is None:
        aliases = BACKBONE_ALIAS

    search_pattern = os.path.join(log_dir, "**", "*.npz")
    # Sort so the run order (and thus per-index pairing and plots) is deterministic.
    files = sorted(glob.glob(search_pattern, recursive=True))

    if not files:
        print(f"No .npz files found in {log_dir}")
        return None

    print(f"Found {len(files)} log files.")

    data = {}
    for f in files:
        # str.split always yields at least one element, so [0] is safe.
        raw_backbone = os.path.relpath(f, log_dir).split(os.sep)[0]
        if raw_backbone not in aliases:
            continue

        display_name = aliases[raw_backbone]
        runs = data.setdefault(
            display_name,
            {'acc': [], 'test_loss': [], 'train_loss': [], 'time': []},
        )

        try:
            # NOTE(review): allow_pickle=True can execute code embedded in a
            # crafted file; only run this on trusted log directories.
            # The context manager closes the underlying zip file handle.
            with np.load(f, allow_pickle=True) as loaded:
                if 'test_acc' not in loaded.files:
                    continue
                test_acc = loaded['test_acc']
                if len(test_acc) == 0:
                    continue
                # Read every required array BEFORE mutating `runs`, so a missing
                # key cannot leave acc/test_loss/time with different lengths.
                test_loss = loaded['test_loss']
                wall_time = loaded['wall_time']
                train_loss = loaded['train_loss'] if 'train_loss' in loaded.files else None
        except Exception as e:  # best-effort: report and keep loading other files
            print(f"Error loading {f}: {e}")
            continue

        runs['acc'].append(test_acc)
        runs['test_loss'].append(test_loss)
        runs['time'].append(wall_time)
        if train_loss is not None:
            runs['train_loss'].append(train_loss)

    return data

def _truncate_runs_at_threshold(runs, ref_runs):
    """Cut each run just after its paired accuracy run first reaches the threshold.

    Runs are paired with *ref_runs* by list position. A run is kept whole when
    it has no paired reference or when that reference never reaches
    ``ACC_CUTOFF_THRESHOLD``.
    """
    out = []
    for idx, run in enumerate(runs):
        if idx >= len(ref_runs):
            out.append(run)
            continue
        hits = np.where(np.array(ref_runs[idx]) >= ACC_CUTOFF_THRESHOLD)[0]
        if len(hits) == 0:
            out.append(run)
            continue
        # Keep values up to one round past the first crossing (first_idx + 2
        # entries) so the crossing itself is visible in the plot.
        out.append(run[:min(len(run), hits[0] + 2)])
    return out


def plot_metric(data, metric_key, y_label, title, filename, fig_dir, use_cutoff=False):
    """Plot the per-round mean ± std of *metric_key* for every model and save it.

    Args:
        data: Mapping ``{model_name: {metric_key: [run, ...], ...}}`` as built
            by ``load_data``; each run is a per-round sequence.
        metric_key: Which run list to plot (e.g. ``'acc'``, ``'test_loss'``).
        y_label: Y-axis label.
        title: Plot title; a truncation note is appended when *use_cutoff*.
        filename: Output image file name, written inside *fig_dir*.
        fig_dir: Directory to save into (must already exist).
        use_cutoff: If True, every run of *metric_key* is truncated shortly
            after the run at the same position in the model's ``'acc'`` list
            first reaches ``ACC_CUTOFF_THRESHOLD`` (for ``'acc'`` that is the
            run itself). Models without ``'acc'`` data are not truncated.

    All runs of one model are cut to the shortest run before averaging. If no
    model has usable data, a message is printed and nothing is saved.
    """
    colors = plt.cm.tab10(np.linspace(0, 1, 10))
    markers = ['o', 's', '^', 'D', 'v', '<', '>']

    fig = plt.figure(figsize=(10, 7))
    try:
        has_data = False
        # The style index follows the model's position among ALL sorted models
        # (skipped ones included), so a model keeps its color across plots.
        for i, name in enumerate(sorted(data)):
            metrics = data[name]
            runs = metrics.get(metric_key)
            if not runs:
                continue

            if use_cutoff and metrics.get('acc'):
                # The cut point is always defined by accuracy, whatever is plotted.
                runs = _truncate_runs_at_threshold(runs, metrics['acc'])

            min_len = min(len(run) for run in runs)
            if min_len == 0:
                continue

            arr = np.array([run[:min_len] for run in runs])
            mean = np.mean(arr, axis=0)
            std = np.std(arr, axis=0)
            x = np.arange(1, min_len + 1)
            has_data = True

            color = colors[i % len(colors)]
            marker = markers[i % len(markers)]
            plt.plot(x, mean, label=name, linewidth=2, color=color, marker=marker,
                     markevery=max(1, len(x) // 10), markersize=6)
            plt.fill_between(x, mean - std, mean + std, color=color, alpha=0.15)

        if not has_data:
            print(f"No valid data to plot for {metric_key}")
            return

        final_title = title
        if use_cutoff:
            final_title += f" (Truncated at Acc={ACC_CUTOFF_THRESHOLD*100:.0f}%)"

        plt.xlabel("Communication Rounds", fontsize=12)
        plt.ylabel(y_label, fontsize=12)
        plt.title(final_title, fontsize=14)

        # The threshold line is only meaningful on the accuracy plot. It must be
        # drawn BEFORE plt.legend(), which only collects labels existing at call time.
        if metric_key == 'acc':
            plt.axhline(y=ACC_CUTOFF_THRESHOLD, color='gray', linestyle='--', alpha=0.5,
                        label=f'{ACC_CUTOFF_THRESHOLD*100:.0f}% Threshold')

        plt.legend(fontsize=10, loc='best')
        plt.grid(True, linestyle='--', alpha=0.6)
        plt.tight_layout()

        save_path = os.path.join(fig_dir, filename)
        plt.savefig(save_path, dpi=300)
    finally:
        # Always release the figure, even on early return or an exception.
        plt.close(fig)
    print(f"Generated {filename}")

def _last_wall_time(t):
    """Return the last value of a wall-time record as a float, or ``None``.

    *t* may be a scalar, a 0-d array or a 1-D sequence. For sequences the final
    entry is used (presumably the cumulative time at the end of the run --
    TODO confirm against the training logger). Empty or non-numeric records
    yield ``None``.
    """
    try:
        flat = np.asarray(t, dtype=float).ravel()
    except (TypeError, ValueError):
        return None
    return float(flat[-1]) if flat.size else None


def generate_summary_table(data, fig_dir, threshold=None):
    """Print and save a per-model accuracy/convergence summary as CSV.

    For every model with at least one non-empty accuracy run, this reports:

    * the accuracy at the first round reaching *threshold* (the last round for
      runs that never reach it), as mean ± std in percent;
    * the best accuracy, as mean ± std in percent;
    * the mean number of rounds needed to reach *threshold*, over the runs that
      reached it. The reached/total count is appended when some runs did not
      reach it, and "> Max" is shown when none did;
    * the mean final wall time ("N/A" if no time was recorded).

    Args:
        data: Mapping as returned by ``load_data``.
        fig_dir: Existing directory; ``backbone_comparison.csv`` is written there.
        threshold: Accuracy fraction defining the target. Defaults to
            ``ACC_CUTOFF_THRESHOLD``.
    """
    if threshold is None:
        threshold = ACC_CUTOFF_THRESHOLD
    pct = f"{threshold * 100:.0f}%"

    rows = []
    for name in sorted(data):
        metrics = data[name]
        raw_accs = metrics.get('acc') or []
        times = metrics.get('time') or []

        final_vals, best_vals, time_vals, reached_rounds = [], [], [], []

        for i, run_acc in enumerate(raw_accs):
            run_arr = np.ravel(run_acc)
            if run_arr.size == 0:
                continue

            best_vals.append(np.max(run_arr) * 100)

            hits = np.where(run_arr >= threshold)[0]
            if hits.size:
                first = hits[0]
                reached_rounds.append(first + 1)  # rounds are 1-based
                final_vals.append(run_arr[first] * 100)
            else:
                final_vals.append(run_arr[-1] * 100)

            # time[i] belongs to the same run as acc[i] (appended together by load_data).
            if i < len(times):
                total = _last_wall_time(times[i])
                if total is not None:
                    time_vals.append(total)

        if not final_vals:
            continue

        if reached_rounds:
            str_rounds = f"{np.mean(reached_rounds):.1f}"
            if len(reached_rounds) < len(final_vals):
                # The mean covers only the successful runs; make that visible.
                str_rounds += f" ({len(reached_rounds)}/{len(final_vals)} runs)"
        else:
            str_rounds = "> Max"

        rows.append({
            "Backbone Model": name,
            "Final (Truncated) Acc (%)": f"{np.mean(final_vals):.2f} ± {np.std(final_vals):.2f}",
            "Best Acc (%)": f"{np.mean(best_vals):.2f} ± {np.std(best_vals):.2f}",
            f"Rounds to {pct}": str_rounds,
            "Total Time (s)": f"{np.mean(time_vals):.0f}" if time_vals else "N/A",
        })

    if not rows:
        print("No matching models found to generate table.")
        return

    df = pd.DataFrame(rows)

    print("\n" + "=" * 85)
    print(f" MODEL COMPARISON SUMMARY (Threshold: {pct})")
    print("=" * 85)
    print(df.to_string(index=False))

    csv_path = os.path.join(fig_dir, "backbone_comparison.csv")
    df.to_csv(csv_path, index=False)
    print(f"\nTable saved to {csv_path}")

def parse_args(argv=None):
    """Parse the command-line options of the log-analysis script.

    Args:
        argv: Argument list to parse; ``None`` (default) reads ``sys.argv[1:]``.

    Returns:
        ``argparse.Namespace`` with ``log_dir`` (str) and ``fig_dir`` (str or
        ``None``). ``fig_dir`` is ``None`` when the flag is omitted, meaning
        "use ``<log_dir>/analysis_results``". An explicit value (even "figs") is
        used as given.
    """
    p = argparse.ArgumentParser(description="Analyze experiment logs")
    p.add_argument("--log_dir", type=str, default=".", help="Path to the root run folder")
    p.add_argument("--fig_dir", type=str, default=None,
                   help="Directory to save figures (default: <log_dir>/analysis_results)")
    return p.parse_args(argv)


def main():
    """Load the logs, draw the truncated comparison plots and write the summary table."""
    args = parse_args()
    log_dir = os.path.abspath(args.log_dir)

    # No --fig_dir given -> keep results next to the logs.
    if args.fig_dir is None:
        fig_dir = os.path.join(log_dir, "analysis_results")
    else:
        fig_dir = os.path.abspath(args.fig_dir)

    os.makedirs(fig_dir, exist_ok=True)

    print(f"Scanning logs recursively in: {log_dir}")
    print(f"Saving results to: {fig_dir}")

    data = load_data(log_dir)
    if not data:
        print("No data found.")
        return

    print("\n--- Generating Plots ---")

    # 1. Accuracy, truncated once it reaches the threshold.
    plot_metric(data, 'acc', 'Test Accuracy', 'Test Accuracy (Truncated)', 'compare_acc_truncated.png', fig_dir, use_cutoff=True)

    # 2. Test loss, truncated at the accuracy-based cut points.
    plot_metric(data, 'test_loss', 'Test Loss', 'Test Loss (Truncated)', 'compare_test_loss_truncated.png', fig_dir, use_cutoff=True)

    # 3. Training loss, truncated at the accuracy-based cut points.
    plot_metric(data, 'train_loss', 'Training Loss', 'Training Loss (Truncated)', 'compare_train_loss_truncated.png', fig_dir, use_cutoff=True)

    # 4. Optional full (untruncated) accuracy plot for reference:
    # plot_metric(data, 'acc', 'Test Accuracy', 'Test Accuracy (Full)', 'compare_acc_full.png', fig_dir, use_cutoff=False)

    generate_summary_table(data, fig_dir)


if __name__ == "__main__":
    main()
