import glob
import os
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import argparse

def load_data(log_dir=None, include_mode=False):
    """Load every ``*.npz`` run log in *log_dir* and group the runs by method.

    Each ``.npz`` file is one run (e.g. one seed). Runs that share a display
    name are collected into lists so the plotting code can take mean/std
    across them, e.g.::

        data["FedAvg"]["acc"] == [seed0_accs, seed1_accs, ...]

    A file is only added if all required fields (``rounds``, ``test_acc``,
    ``test_loss``, ``wall_time``) can be read. A file missing any of them is
    reported and skipped as a whole, so the per-metric lists stay aligned
    run-for-run. ``train_loss`` is optional and only appended when present.

    Args:
        log_dir: Directory to scan. Defaults to the module-level ``LOG_DIR``
            set by the ``__main__`` section.
        include_mode: If True, key runs as ``"<method> (<mode>)"`` so that
            ``full`` and ``head`` runs stay separate. The default (False)
            keys by method only, which matches the previous output. Runs of
            one method in different modes are then pooled, and a warning is
            printed.

    Returns:
        dict mapping display name -> {'rounds', 'acc', 'test_loss',
        'train_loss', 'time'}, or None if no ``.npz`` file was found.
        ``rounds`` comes from the first run loaded under that name. Files are
        processed in sorted order, so this is deterministic.
    """
    if log_dir is None:
        log_dir = LOG_DIR

    files = sorted(glob.glob(os.path.join(log_dir, "*.npz")))
    if not files:
        print(f"No .npz files found in {log_dir}")
        return None

    print(f"Found {len(files)} log files.")

    data = {}
    modes_seen = {}  # display name -> set of training modes pooled under it

    for f in files:
        # File names look like "<method>_<mode>_..."; used only as a fallback.
        parts = os.path.splitext(os.path.basename(f))[0].split('_')
        try:
            # NOTE(review): allow_pickle=True lets object arrays in the log run
            # arbitrary unpickling code; only point this at trusted log files.
            with np.load(f, allow_pickle=True) as loaded:
                # str() unwraps values stored as numpy 0-d arrays.
                if 'method' in loaded:
                    method = str(loaded['method'])
                else:
                    method = parts[0]

                # Keeping full vs head runs apart matters; see include_mode.
                if 'training_mode' in loaded:
                    mode = str(loaded['training_mode'])
                elif len(parts) > 1 and parts[1] in ('full', 'head'):
                    mode = parts[1]
                else:
                    mode = "unknown"

                # Read everything before touching `data`. A missing key then
                # skips the whole file instead of half-appending it.
                rounds = loaded['rounds']
                test_acc = loaded['test_acc']
                test_loss = loaded['test_loss']
                wall_time = loaded['wall_time']
                # Older logs may lack train_loss; it is then simply not plotted.
                train_loss = loaded['train_loss'] if 'train_loss' in loaded else None
        except Exception as e:
            # Best effort: report the unreadable file and keep going.
            print(f"Error loading {f}: {e}")
            continue

        display_name = f"{method} ({mode})" if include_mode else method

        entry = data.get(display_name)
        if entry is None:
            entry = data[display_name] = {
                'rounds': rounds,  # assumes all runs of a method share the same rounds
                'acc': [],
                'test_loss': [],
                'train_loss': [],
                'time': [],
            }

        entry['acc'].append(test_acc)
        entry['test_loss'].append(test_loss)
        if train_loss is not None:
            entry['train_loss'].append(train_loss)
        entry['time'].append(wall_time)
        modes_seen.setdefault(display_name, set()).add(mode)

    for name, modes in modes_seen.items():
        if len(modes) > 1:
            print(f"Warning: '{name}' pools runs from training modes {sorted(modes)}; "
                  f"use include_mode=True to separate them.")

    return data

def _stack_runs(runs):
    """Stack per-run 1-D curves into a ``(num_runs, length)`` float array.

    Runs of unequal length (e.g. a seed that stopped early) are truncated to
    the shortest one, so the per-round mean/std stays well defined. A plain
    ``np.array`` on such a ragged list raises ``ValueError`` in recent NumPy.

    Returns:
        The stacked array, or None if *runs* is empty or any run is not 1-D.
    """
    curves = [np.asarray(run, dtype=float) for run in runs]
    if not curves or any(curve.ndim != 1 for curve in curves):
        return None
    shortest = min(len(curve) for curve in curves)
    return np.stack([curve[:shortest] for curve in curves])


def plot_metric(data, metric_key, y_label, title, filename, fig_dir=None):
    """Plot the mean ± std (across runs) of one metric against rounds.

    Args:
        data: mapping as built by ``load_data``:
            name -> {'rounds': ..., metric_key: [run, ...], ...}.
        metric_key: which per-run list to plot ('acc', 'test_loss', ...).
        y_label: y-axis label.
        title: figure title.
        filename: output file name inside the figure directory.
        fig_dir: output directory; defaults to the module-level ``FIG_DIR``.

    Methods without data for *metric_key*, or whose runs are not 1-D, are
    skipped with a message. If no method has the metric, no file is written.
    """
    # Skip the whole figure if no method logged this metric.
    if not any(metrics.get(metric_key) for metrics in data.values()):
        print(f"Skipping {filename}: metric '{metric_key}' not found.")
        return

    # tab10 palette, indexed by the method's position in `data` rather than
    # by the number of plotted lines, so each method keeps its colour across figures.
    colors = plt.cm.tab10(np.linspace(0, 1, 10))
    fig = plt.figure(figsize=(8, 6))
    try:
        for i, (method, metrics) in enumerate(data.items()):
            runs = metrics.get(metric_key)
            if not runs:
                continue

            arr = _stack_runs(runs)  # (num_runs, rounds)
            if arr is None:
                print(f"Warning: Data shape mismatch for {method} - {metric_key}, skipping.")
                continue
            if any(len(run) != arr.shape[1] for run in runs):
                print(f"Warning: runs of {method} - {metric_key} differ in length; "
                      f"truncated to {arr.shape[1]} points.")

            mean = arr.mean(axis=0)
            std = arr.std(axis=0)
            # The x axis comes from the first loaded run; align it with the curve.
            x = np.asarray(metrics['rounds'])
            n = min(len(x), len(mean))
            x, mean, std = x[:n], mean[:n], std[:n]

            color = colors[i % len(colors)]
            plt.plot(x, mean, label=method, linewidth=2, color=color)
            plt.fill_between(x, mean - std, mean + std, color=color, alpha=0.2)

        plt.xlabel("Communication Rounds")
        plt.ylabel(y_label)
        plt.title(title)
        plt.legend()
        plt.grid(True, linestyle='--', alpha=0.6)
        plt.tight_layout()
        save_path = os.path.join(FIG_DIR if fig_dir is None else fig_dir, filename)
        plt.savefig(save_path, dpi=300)
    finally:
        # Release the figure even if plotting fails midway.
        plt.close(fig)
    print(f"Generated {filename}")

def _time_acc_curve(metrics):
    """Average a method's runs into ``(mean_time, mean_acc, std_acc)``.

    Every time and accuracy series is truncated to the shortest one. Runs of
    unequal length, or time/accuracy logs of different length, can therefore
    still be averaged instead of crashing ``np.array`` on a ragged list.

    Returns:
        A tuple of three 1-D arrays, or None if either series is missing
        or any run is not 1-D.
    """
    times = [np.asarray(run, dtype=float) for run in metrics.get('time', [])]
    accs = [np.asarray(run, dtype=float) for run in metrics.get('acc', [])]
    if not times or not accs or any(c.ndim != 1 for c in times + accs):
        return None
    n = min(len(c) for c in times + accs)
    time_arr = np.stack([c[:n] for c in times])
    acc_arr = np.stack([c[:n] for c in accs])
    return time_arr.mean(axis=0), acc_arr.mean(axis=0), acc_arr.std(axis=0)


def plot_time_acc(data, fig_dir=None):
    """Plot mean test accuracy (± std band) against mean wall-clock time.

    Args:
        data: mapping as built by ``load_data``: name -> {'time': [...], 'acc': [...], ...}.
        fig_dir: output directory; defaults to the module-level ``FIG_DIR``.

    Writes ``acc_vs_time.png``. Methods whose time/accuracy data cannot be
    averaged are skipped.
    """
    # Colour indexed by the method's position in `data`, matching plot_metric.
    colors = plt.cm.tab10(np.linspace(0, 1, 10))
    fig = plt.figure(figsize=(8, 6))
    try:
        for i, (method, metrics) in enumerate(data.items()):
            curve = _time_acc_curve(metrics)
            if curve is None:
                continue
            mean_time, mean_acc, std_acc = curve

            color = colors[i % len(colors)]
            plt.plot(mean_time, mean_acc, label=method, linewidth=2, color=color)
            plt.fill_between(mean_time, mean_acc - std_acc, mean_acc + std_acc,
                             color=color, alpha=0.2)

        plt.xlabel("Wall-clock Time (s)")
        plt.ylabel("Test Accuracy")
        plt.title("Accuracy vs Wall-clock Time")
        plt.legend()
        plt.grid(True, linestyle='--', alpha=0.6)
        plt.tight_layout()
        save_path = os.path.join(FIG_DIR if fig_dir is None else fig_dir, "acc_vs_time.png")
        plt.savefig(save_path, dpi=300)
    finally:
        # Release the figure even if plotting fails midway.
        plt.close(fig)
    print("Generated acc_vs_time.png")

def parse_args():
    """Read the command-line options of this analysis script.

    Returns:
        argparse.Namespace with ``log_dir`` (directory holding the ``.npz``
        logs) and ``fig_dir`` (directory for figures and the summary CSV).
    """
    parser = argparse.ArgumentParser(description="Analyze experiment logs")
    option_specs = (
        ("--log_dir", "./logs", "Directory containing log .npz files"),
        ("--fig_dir", "./figs", "Directory to save figures"),
    )
    for flag, default_value, help_text in option_specs:
        parser.add_argument(flag, type=str, default=default_value, help=help_text)
    namespace = parser.parse_args()
    return namespace

def summarize_final_accuracy(data):
    """Tabulate final test accuracy (mean ± std across runs, in percent).

    Each run contributes its own last logged accuracy, so runs of different
    lengths are handled (a ragged ``np.array(...)[:, -1]`` would crash).
    Methods with no non-empty accuracy run are omitted.

    Args:
        data: mapping as built by ``load_data``: name -> {'acc': [run, ...], ...}.

    Returns:
        pandas.DataFrame with columns "Method (Mode)" and "Final Accuracy (%)".
    """
    rows = []
    for method, metrics in data.items():
        finals = [np.ravel(run)[-1] for run in metrics.get('acc', []) if np.size(run)]
        if not finals:
            continue
        final_accs = np.asarray(finals, dtype=float) * 100
        rows.append([method, f"{np.mean(final_accs):.2f} ± {np.std(final_accs):.2f}"])
    return pd.DataFrame(rows, columns=["Method (Mode)", "Final Accuracy (%)"])


if __name__ == "__main__":
    args = parse_args()
    # Module-level globals read by load_data / plot_metric / plot_time_acc.
    LOG_DIR = os.path.abspath(args.log_dir)
    FIG_DIR = os.path.abspath(args.fig_dir)
    os.makedirs(FIG_DIR, exist_ok=True)

    print(f"Using log dir: {LOG_DIR}")
    print(f"Using fig dir: {FIG_DIR}")

    data = load_data()

    if data:
        print("\n--- Generating Plots ---")
        plot_metric(data, 'acc', 'Test Accuracy', 'Test Accuracy vs Rounds', 'acc_vs_rounds.png')
        plot_metric(data, 'test_loss', 'Test Loss', 'Test Loss vs Rounds', 'test_loss_vs_rounds.png')
        plot_metric(data, 'train_loss', 'Train Loss', 'Training Loss vs Rounds', 'train_loss_vs_rounds.png')
        plot_time_acc(data)

        print("\n=== Final Summary (Last Round Accuracy) ===")
        df = summarize_final_accuracy(data)
        print(df)

        # Save the summary table as CSV next to the figures.
        csv_path = os.path.join(FIG_DIR, "summary.csv")
        df.to_csv(csv_path, index=False)
        print(f"\nSummary saved to {csv_path}")
    else:
        print("No data loaded. Check your log directory.")
