import os
import argparse
import matplotlib.pyplot as plt
import numpy as np
import glob

def process_subfolder_npz(folder_path):
    """
    Load every ``*.npz`` file in ``folder_path`` and aggregate the runs.

    Each file is expected to contain the arrays ``test_acc``, ``test_loss``,
    ``train_loss``, ``wall_time`` and ``rounds``. Files without ``test_acc``
    are skipped silently; files that fail to load or lack any other required
    array are reported on stdout and skipped as a whole, so the per-metric
    lists always stay aligned run-for-run.

    All series are truncated to the shortest length found across every kept
    array, then the mean and standard deviation are taken across runs.

    Args:
        folder_path: Directory to scan (non-recursively) for ``*.npz`` files.

    Returns:
        ``None`` when no usable run (or only zero-length series) is found.
        Otherwise a dict with ``<metric>_mean`` and ``<metric>_std`` arrays
        for ``test_acc``, ``test_loss``, ``train_loss`` and ``wall_time``,
        plus ``rounds`` taken from the first file in sorted order.
        ``test_acc`` is scaled by 100 when its maximum is <= 1.0 (treated as
        a 0-1 fraction and converted to a percentage).
    """
    metric_keys = ("test_acc", "test_loss", "train_loss", "wall_time")
    required_keys = metric_keys + ("rounds",)

    # Sorted so that the run supplying "rounds" does not depend on the
    # filesystem's listing order.
    npz_files = sorted(glob.glob(os.path.join(folder_path, "*.npz")))
    if not npz_files:
        return None

    runs = []
    for f in npz_files:
        try:
            with np.load(f) as data:
                if 'test_acc' not in data:
                    continue
                # Read every required array before keeping anything: a
                # missing key raises here and rejects the file atomically.
                run = {key: data[key] for key in required_keys}
        except Exception as e:
            # Best-effort aggregation: one unreadable file must not abort
            # the whole folder.
            print(f"Error loading {f}: {e}")
            continue
        runs.append(run)

    if not runs:
        return None

    # Truncate to the shortest series over *all* arrays, not only "rounds",
    # so every returned array has the same length.
    min_len = min(len(run[key]) for run in runs for key in required_keys)
    if min_len == 0:
        return None

    metrics = {}
    for key in metric_keys:
        arr = np.array([run[key][:min_len] for run in runs])

        # Accuracy stored as a 0-1 fraction is converted to a percentage.
        if key == "test_acc" and np.max(arr) <= 1.0:
            arr = arr * 100.0

        metrics[f"{key}_mean"] = np.mean(arr, axis=0)
        metrics[f"{key}_std"] = np.std(arr, axis=0)

    metrics["rounds"] = runs[0]["rounds"][:min_len]
    return metrics

def plot_single_metric(save_name, title, x_data_key, y_data_key, y_std_key, 
                       xlabel, ylabel, all_metrics, subdirs, colors, root_dir):
    """
    Plot one metric for every experiment and save it as its own image file.

    One line is drawn per entry of ``subdirs`` whose metrics are not ``None``.
    When no experiment has metrics, nothing is drawn and no file is written.

    Args:
        save_name: File name of the image, created inside ``root_dir``.
        title: Figure title.
        x_data_key: Key in each metrics dict holding the x values.
        y_data_key: Key in each metrics dict holding the y values.
        y_std_key: Key holding the y standard deviation, or a falsy value to
            skip the shaded band.
        xlabel: X-axis label.
        ylabel: Y-axis label.
        all_metrics: Mapping of sub-folder name to its metrics dict or ``None``.
        subdirs: Ordered sub-folder names; the position selects the colour.
        colors: Indexable collection of colours, indexed by position in
            ``subdirs``.
        root_dir: Directory the image is written to.
    """
    # Resolve the drawable experiments first so that no figure is created
    # when there is nothing to plot.
    drawable = [
        (idx, subdir, all_metrics[subdir])
        for idx, subdir in enumerate(subdirs)
        if all_metrics[subdir] is not None
    ]
    if not drawable:
        return

    fig = plt.figure(figsize=(10, 6))
    try:
        for idx, subdir, metrics in drawable:
            # Shorten the folder name into a readable legend entry.
            label_name = subdir.replace("lr", "LR=").replace("_damp", ", Damp=").replace("_nlr", ", NLR=")
            c = colors[idx]

            x = metrics[x_data_key]
            y = metrics[y_data_key]

            plt.plot(x, y, label=label_name, color=c, linewidth=2)

            # Draw the +/- one std band, except on the wall-time axis where
            # the x values are not aligned across runs.
            if y_std_key and x_data_key != "wall_time_mean":
                y_std = metrics[y_std_key]
                plt.fill_between(x, y - y_std, y + y_std, color=c, alpha=0.15)

        plt.title(title, fontsize=16, fontweight='bold')
        plt.xlabel(xlabel, fontsize=14)
        plt.ylabel(ylabel, fontsize=14)
        plt.grid(True, linestyle='--', alpha=0.6)
        plt.legend(fontsize=10, loc='best', framealpha=0.9)
        plt.tight_layout()

        save_path = os.path.join(root_dir, save_name)
        plt.savefig(save_path, dpi=300)
        print(f"✅ Saved: {save_path}")
    finally:
        # Always release the figure, including when plotting or saving fails.
        plt.close(fig)

def main():
    """
    Command-line entry point for summarising a grid-search log directory.

    Reads ``--root_dir`` (default ``./logs``), aggregates every immediate
    sub-folder with ``process_subfolder_npz``, prints a table with the final
    mean accuracy and wall time of each experiment, and writes four plots
    into ``root_dir``: accuracy vs. rounds, accuracy vs. wall time, test loss
    vs. rounds and train loss vs. rounds.

    Returns early with a message when ``root_dir`` is not a directory or when
    no sub-folder yields any data.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--root_dir', type=str, default='./logs', help='Path to the grid search root directory')
    args = parser.parse_args()

    # isdir rather than exists: a regular file at this path would otherwise
    # pass the check and make os.listdir raise.
    if not os.path.isdir(args.root_dir):
        print(f"Error: Directory {args.root_dir} not found.")
        return

    subdirs = sorted(
        d for d in os.listdir(args.root_dir)
        if os.path.isdir(os.path.join(args.root_dir, d))
    )

    # One colour per sub-folder, sampled from tab10; max(..., 1) keeps at
    # least one sample when there are no sub-folders.
    colors = plt.cm.tab10(np.linspace(0, 1, max(len(subdirs), 1)))

    # 1. Aggregate every experiment up front and print the summary table.
    all_metrics = {}
    print(f"{'Experiment ID':<40} | {'Final Acc':<10} | {'Final Time (s)':<15}")
    print("-" * 75)

    has_any_data = False
    for subdir in subdirs:
        m = process_subfolder_npz(os.path.join(args.root_dir, subdir))
        all_metrics[subdir] = m

        if m:
            has_any_data = True
            print(f"{subdir:<40} | {m['test_acc_mean'][-1]:.2f}%     | {m['wall_time_mean'][-1]:.1f}")

    if not has_any_data:
        print("No valid .npz files found.")
        return

    print("-" * 75)
    print("Generating plots...")

    # 2. Draw the four figures, one file each.
    # Columns: save_name, title, x key, y key, y-std key, xlabel, ylabel.
    plot_specs = [
        ('acc_vs_rounds.png', 'Test Accuracy vs. Rounds',
         'rounds', 'test_acc_mean', 'test_acc_std',
         'Communication Rounds', 'Test Accuracy (%)'),
        # Wall time is the x axis here; no std band because the time axis is
        # usually not aligned across runs.
        ('acc_vs_time.png', 'Test Accuracy vs. Wall Time',
         'wall_time_mean', 'test_acc_mean', None,
         'Wall Clock Time (seconds)', 'Test Accuracy (%)'),
        ('test_loss_vs_rounds.png', 'Test Loss vs. Rounds',
         'rounds', 'test_loss_mean', 'test_loss_std',
         'Communication Rounds', 'Test Loss'),
        ('train_loss_vs_rounds.png', 'Train Loss vs. Rounds',
         'rounds', 'train_loss_mean', 'train_loss_std',
         'Communication Rounds', 'Train Loss'),
    ]

    for save_name, title, x_key, y_key, std_key, xlabel, ylabel in plot_specs:
        plot_single_metric(
            save_name=save_name,
            title=title,
            x_data_key=x_key,
            y_data_key=y_key,
            y_std_key=std_key,
            xlabel=xlabel,
            ylabel=ylabel,
            all_metrics=all_metrics, subdirs=subdirs, colors=colors, root_dir=args.root_dir
        )

# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()
