import glob
import os
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import argparse

def _read_run(path):
    """
    Read one .npz run log and return ``(display_name, record)``.

    ``display_name`` is "Method (mode)". ``record`` holds the arrays
    'rounds', 'acc', 'test_loss', 'time' and 'train_loss' (None when the
    file lacks it). Every required key is read before returning, so a file
    missing e.g. 'wall_time' raises KeyError instead of leaving a
    half-appended run behind in the caller's aggregate.
    """
    stem = os.path.splitext(os.path.basename(path))[0]
    parts = stem.split('_')

    # NOTE(review): allow_pickle=True lets np.load unpickle object arrays,
    # which can execute arbitrary code -- only point this at trusted logs.
    # The context manager closes the underlying zip file; arrays indexed out
    # of an NpzFile are fully read into memory, so they outlive the `with`.
    with np.load(path, allow_pickle=True) as loaded:
        # --- Method name: stored key first, else the filename prefix ---
        if 'method' in loaded:
            method = str(loaded['method'])
        else:
            method = parts[0]

        # --- Training mode: 'full' (all layers) vs 'head' (classifier only) ---
        if 'training_mode' in loaded:
            mode = str(loaded['training_mode'])
        elif len(parts) > 1 and parts[1] in ('full', 'head'):
            mode = parts[1]
        else:
            mode = "unknown"

        record = {
            'rounds': loaded['rounds'],
            'acc': loaded['test_acc'],
            'test_loss': loaded['test_loss'],
            'time': loaded['wall_time'],
            # Older logs may lack train_loss; this key alone is optional.
            'train_loss': loaded['train_loss'] if 'train_loss' in loaded else None,
        }

    return f"{method} ({mode})", record


def load_data(log_dir=None):
    """
    Scan a log directory for .npz files and aggregate them by method and mode.

    Args:
        log_dir: Directory to scan. Defaults to the module-level ``LOG_DIR``
            (assigned in the ``__main__`` block) for backward compatibility.

    Returns:
        None if no .npz files are found. Otherwise a dict keyed by
        "MethodName (mode)" (e.g. "FedAvg (head)") whose values hold:
            'rounds'     : rounds array from the first file seen for that key
                           (assumed consistent across seeds)
            'acc'        : list of per-seed test-accuracy arrays
            'test_loss'  : list of per-seed test-loss arrays
            'train_loss' : list of per-seed train-loss arrays; may be shorter
                           than 'acc' when older logs lack the key
            'time'       : list of per-seed wall-time arrays
        Files that fail to load are reported and skipped as a whole.
    """
    if log_dir is None:
        log_dir = LOG_DIR

    # Sorted so legend order and colour assignment are reproducible across runs.
    files = sorted(glob.glob(os.path.join(log_dir, "*.npz")))
    if not files:
        print(f"No .npz files found in {log_dir}")
        return None

    print(f"Found {len(files)} log files.")

    data = {}
    for f in files:
        try:
            display_name, run = _read_run(f)
        except Exception as e:  # per-file boundary: report and keep going
            print(f"Error loading {f}: {e}")
            continue

        if display_name not in data:
            data[display_name] = {
                'rounds': run['rounds'],
                'acc': [],
                'test_loss': [],
                'train_loss': [],
                'time': [],
            }
        entry = data[display_name]

        # Appended only after the whole file was read successfully, so the
        # required per-seed lists always stay the same length.
        entry['acc'].append(run['acc'])
        entry['test_loss'].append(run['test_loss'])
        entry['time'].append(run['time'])
        if run['train_loss'] is not None:
            entry['train_loss'].append(run['train_loss'])

    return data

def plot_metric(data, metric_key, y_label, title, filename, fig_dir=None):
    """
    Plot the across-seed mean of one metric vs communication rounds.

    Each method is drawn as its mean curve with a shaded +/-1 standard
    deviation band (population std, i.e. ``np.std`` with ``ddof=0``).

    Args:
        data: Mapping display name -> dict holding a 'rounds' array and a
            list of per-seed arrays under ``metric_key``.
        metric_key: Which per-seed series to plot ('acc', 'test_loss', ...).
        y_label: Y-axis label.
        title: Figure title.
        filename: Output file name, written inside ``fig_dir``.
        fig_dir: Output directory. Defaults to the module-level ``FIG_DIR``
            (assigned in the ``__main__`` block) for backward compatibility.

    Seeds of unequal length (e.g. early-stopped runs) are truncated to the
    shortest seed before averaging, instead of letting ``np.array`` fail on a
    ragged list. Methods whose per-seed arrays are not 1-D are skipped with a
    warning. If nothing is plottable, no figure is written.
    """
    if not any(metrics.get(metric_key) for metrics in data.values()):
        print(f"Skipping {filename}: metric '{metric_key}' not found.")
        return

    # Compute every curve before touching matplotlib, so an all-invalid input
    # does not produce an empty figure.
    series = []
    for i, (method, metrics) in enumerate(data.items()):
        runs = metrics.get(metric_key)
        if not runs:
            continue
        runs = [np.asarray(r) for r in runs]
        if any(r.ndim != 1 for r in runs):
            print(f"Warning: Data shape mismatch for {method} - {metric_key}, skipping.")
            continue

        n_rounds = min(len(r) for r in runs)
        if any(len(r) != n_rounds for r in runs):
            print(f"Note: {method} - {metric_key} seeds differ in length; "
                  f"truncating to {n_rounds} rounds.")
        arr = np.stack([r[:n_rounds] for r in runs])  # (num_seeds, rounds)

        mean = arr.mean(axis=0)
        std = arr.std(axis=0)
        x = metrics['rounds']
        # Align x with y when logged rounds and metric lengths differ.
        n = min(len(x), len(mean))
        series.append((i, method, x[:n], mean[:n], std[:n]))

    if not series:
        print(f"Skipping {filename}: no plottable data for '{metric_key}'.")
        return

    if fig_dir is None:
        fig_dir = FIG_DIR

    # 10 distinct tab10 colours. Indexing by the method's position in `data`
    # (skipped methods included) keeps each method's colour stable across plots.
    colors = plt.cm.tab10(np.linspace(0, 1, 10))

    fig = plt.figure(figsize=(8, 6))
    try:
        for i, method, x, mean, std in series:
            color = colors[i % len(colors)]
            plt.plot(x, mean, label=method, linewidth=2, color=color)
            plt.fill_between(x, mean - std, mean + std, color=color, alpha=0.2)

        plt.xlabel("Communication Rounds")
        plt.ylabel(y_label)
        plt.title(title)
        plt.legend()
        plt.grid(True, linestyle='--', alpha=0.6)
        plt.tight_layout()

        save_path = os.path.join(fig_dir, filename)
        plt.savefig(save_path, dpi=300)
    finally:
        # Always release the figure, even if plotting or saving fails.
        plt.close(fig)
    print(f"Generated {filename}")

def plot_time_acc(data, fig_dir=None):
    """
    Plot across-seed mean test accuracy against mean wall-clock time.

    For each method, the per-seed 'time' and 'acc' arrays are averaged
    element-wise over seeds. Accuracy gets a shaded +/-1 std band
    (``np.std``, ``ddof=0``). The figure is saved as "acc_vs_time.png".

    Args:
        data: Mapping display name -> dict with lists of per-seed arrays
            under 'time' and 'acc'.
        fig_dir: Output directory. Defaults to the module-level ``FIG_DIR``
            (assigned in the ``__main__`` block) for backward compatibility.

    All time and accuracy seeds of a method are truncated to the shortest one,
    so ragged (early-stopped) runs are averaged instead of crashing
    ``np.array``. Methods with missing or non-1-D data are skipped. If nothing
    is plottable, a message is printed and no figure is written, mirroring
    ``plot_metric``.
    """
    series = []
    for i, (method, metrics) in enumerate(data.items()):
        time_runs = [np.asarray(r) for r in metrics.get('time', [])]
        acc_runs = [np.asarray(r) for r in metrics.get('acc', [])]
        if not time_runs or not acc_runs:
            continue
        if any(r.ndim != 1 for r in time_runs + acc_runs):
            continue

        # One common length keeps time and accuracy aligned point-for-point.
        n = min(len(r) for r in time_runs + acc_runs)
        time_arr = np.stack([r[:n] for r in time_runs])  # (num_seeds, rounds)
        acc_arr = np.stack([r[:n] for r in acc_runs])    # (num_seeds, rounds)
        series.append((i, method, time_arr.mean(axis=0),
                       acc_arr.mean(axis=0), acc_arr.std(axis=0)))

    if not series:
        print("Skipping acc_vs_time.png: no plottable time/accuracy data.")
        return

    if fig_dir is None:
        fig_dir = FIG_DIR

    # Same palette/indexing as the per-round plots so colours match across figures.
    colors = plt.cm.tab10(np.linspace(0, 1, 10))

    fig = plt.figure(figsize=(8, 6))
    try:
        for i, method, mean_time, mean_acc, std_acc in series:
            color = colors[i % len(colors)]
            plt.plot(mean_time, mean_acc, label=method, linewidth=2, color=color)
            plt.fill_between(mean_time, mean_acc - std_acc, mean_acc + std_acc,
                             color=color, alpha=0.2)

        plt.xlabel("Wall-clock Time (s)")
        plt.ylabel("Test Accuracy")
        plt.title("Accuracy vs Wall-clock Time")
        plt.legend()
        plt.grid(True, linestyle='--', alpha=0.6)
        plt.tight_layout()

        save_path = os.path.join(fig_dir, "acc_vs_time.png")
        plt.savefig(save_path, dpi=300)
    finally:
        # Always release the figure, even if plotting or saving fails.
        plt.close(fig)
    print("Generated acc_vs_time.png")

def parse_args(argv=None):
    """
    Parse command-line options for the log-analysis script.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, exactly as before; passing a list allows
            programmatic use and testing.

    Returns:
        argparse.Namespace with string attributes ``log_dir`` (default
        "./logs") and ``fig_dir`` (default "./figs").
    """
    p = argparse.ArgumentParser(description="Analyze experiment logs")
    p.add_argument("--log_dir", type=str, default="./logs", help="Directory containing log .npz files")
    p.add_argument("--fig_dir", type=str, default="./figs", help="Directory to save figures")
    return p.parse_args(argv)

if __name__ == "__main__":
    # Script entry point. LOG_DIR / FIG_DIR are module-level globals assigned
    # here because the loader and plotting functions above read them.
    args = parse_args()
    LOG_DIR = os.path.abspath(args.log_dir)
    FIG_DIR = os.path.abspath(args.fig_dir)
    os.makedirs(FIG_DIR, exist_ok=True)

    print(f"Using log dir: {LOG_DIR}")
    print(f"Using fig dir: {FIG_DIR}")

    data = load_data()

    if data:
        print("\n--- Generating Plots ---")
        plot_metric(data, 'acc', 'Test Accuracy', 'Test Accuracy vs Rounds', 'acc_vs_rounds.png')
        plot_metric(data, 'test_loss', 'Test Loss', 'Test Loss vs Rounds', 'test_loss_vs_rounds.png')
        plot_metric(data, 'train_loss', 'Train Loss', 'Training Loss vs Rounds', 'train_loss_vs_rounds.png')
        plot_time_acc(data)

        print("\n=== Final Summary (Last Round Accuracy) ===")
        rows = []
        for method, metrics in data.items():
            # Final accuracy = last logged value of EACH seed. Seeds may differ
            # in length (early stopping), so index per seed instead of slicing
            # a stacked (seeds, rounds) array, which breaks on ragged input.
            finals = [np.asarray(run)[-1] for run in metrics['acc'] if np.size(run) > 0]
            if not finals:
                print(f"Warning: no accuracy values for {method}, omitted from summary.")
                continue

            # Assumes test_acc is logged as a fraction in [0, 1] -- TODO confirm.
            final_accs = np.array(finals) * 100

            mean_val = np.mean(final_accs)
            std_val = np.std(final_accs)  # population std (ddof=0)

            rows.append([method, f"{mean_val:.2f} ± {std_val:.2f}"])

        df = pd.DataFrame(rows, columns=["Method (Mode)", "Final Accuracy (%)"])
        print(df)

        # Save Summary to CSV
        csv_path = os.path.join(FIG_DIR, "summary.csv")
        df.to_csv(csv_path, index=False)
        print(f"\nSummary saved to {csv_path}")
    else:
        print("No data loaded. Check your log directory.")
