import os
import argparse
import matplotlib.pyplot as plt
import numpy as np
import glob
import sys

# =================================================================
# Helper: Data Loading
# =================================================================

def _load_npz_record(path):
    """Read one ``.npz`` log file into a dict of metric arrays.

    ``test_acc`` and ``test_loss`` are required; ``train_loss``, ``wall_time``
    and ``rounds`` are copied only when present. Every array is read before the
    dict is returned, so a file that fails part-way contributes nothing.

    Returns:
        The dict of arrays, or ``None`` (after printing the reason) when the
        file cannot be read or a required key is missing.
    """
    try:
        with np.load(path) as data:
            if 'test_acc' not in data:
                print(f"Warning: 'test_acc' missing in {os.path.basename(path)}")
                return None
            # A missing 'test_loss' raises KeyError, which is reported below.
            # atleast_1d lets scalar (0-d) entries be measured with len().
            record = {key: np.atleast_1d(data[key]) for key in ("test_acc", "test_loss")}
            for key in ("train_loss", "wall_time", "rounds"):
                if key in data:
                    record[key] = np.atleast_1d(data[key])
    except Exception as e:  # best-effort: one unreadable file must not abort the folder
        print(f"Error loading {path}: {e}")
        return None
    return record


def process_subfolder_npz(folder_path):
    """
    Reads all .npz files directly inside a folder and aggregates their metrics.

    Every run is truncated to the length of the shortest run (so an early crash
    in one seed shortens the curve for all seeds) and then averaged over runs.

    Args:
        folder_path: Directory holding the ``.npz`` files (not searched
            recursively).

    Returns:
        ``None`` when no usable run, or no ``rounds`` array, is found.
        Otherwise, a dict containing:

        - ``rounds``: taken from the first file, in sorted order, that has it.
        - ``<metric>_mean`` and ``<metric>_std``: arrays over runs, one pair
          for each of ``test_acc``, ``test_loss``, ``train_loss`` and
          ``wall_time`` that is present.

        ``test_acc`` is scaled to percent when all its finite values are <= 1.
    """
    # Escape the folder so names like "run[lr=0.1]" are not treated as glob
    # patterns; sort so the file that supplies "rounds" is deterministic.
    npz_files = sorted(glob.glob(os.path.join(glob.escape(folder_path), "*.npz")))
    if not npz_files:
        return None

    records = []
    run_lengths = []
    for path in npz_files:
        record = _load_npz_record(path)
        if record is None:
            continue
        # Length of the per-round series that must stay aligned for plotting.
        run_len = min(len(record[key]) for key in ("test_acc", "test_loss", "rounds")
                      if key in record)
        if run_len == 0:
            print(f"Warning: empty run in {os.path.basename(path)}")
            continue
        records.append(record)
        run_lengths.append(run_len)

    if not records:
        return None

    rounds_runs = [record["rounds"] for record in records if "rounds" in record]
    if not rounds_runs:
        # Without a rounds array there is no x-axis for the per-round plots.
        return None

    # Stop at the earliest crash so all runs stack into a rectangular array.
    min_len = min(run_lengths)

    metrics = {}
    for key in ("test_acc", "test_loss", "train_loss", "wall_time"):
        # Optional series shorter than min_len cannot be aligned and are dropped.
        runs = [record[key][:min_len] for record in records
                if key in record and len(record[key]) >= min_len]
        if not runs:
            continue

        arr = np.array(runs)

        if key == "test_acc":
            # Convert a 0-1 fraction to percent. NaN/inf values are ignored so
            # a single failed evaluation does not block conversion of the curve.
            finite = arr[np.isfinite(arr)]
            if finite.size and finite.max() <= 1.0:
                arr = arr * 100.0

        metrics[f"{key}_mean"] = np.mean(arr, axis=0)
        metrics[f"{key}_std"] = np.std(arr, axis=0)

    metrics["rounds"] = rounds_runs[0][:min_len]
    return metrics

# =================================================================
# Helper: Plotting
# =================================================================

def _legend_label(subdir_name):
    """Turn a log folder name into a legend label.

    Example: ``"log_FedNewton_lr0.1"`` -> ``"Fed-Newton, LR=0.1"``. The
    flat-layout placeholder ``"."`` is shown as ``"Default"``.
    """
    if subdir_name == ".":
        return "Default"
    label = subdir_name.replace("log_", "").replace("Fed", "Fed-")
    return label.replace("_lr", ", LR=").replace("_nlr", ", NLR=")


def plot_single_metric(save_name, title, x_data_key, y_data_key, y_std_key,
                       xlabel, ylabel, all_metrics, subdirs, colors, output_dir):
    """
    Plot one metric for every method and save the figure as a PNG.

    Args:
        save_name: File name of the figure, written inside ``output_dir``.
        title: Figure title.
        x_data_key: Key into each method's metrics dict for the x series
            (e.g. ``'rounds'`` or ``'wall_time_mean'``).
        y_data_key: Key for the y series (e.g. ``'test_acc_mean'``).
        y_std_key: Key of the std series drawn as a shaded band, or ``None``.
            The band is never drawn when the x-axis is ``'wall_time_mean'``.
        xlabel: X-axis label.
        ylabel: Y-axis label.
        all_metrics: Maps names in ``subdirs`` to a metrics dict, or ``None``
            when no data was loaded for that method.
        subdirs: Method names in plotting order; also the source of legend labels.
        colors: Sequence of colors, cycled by position in ``subdirs``.
        output_dir: Directory the figure is saved into.

    Methods lacking the x or the y series are skipped. If nothing remains to
    draw, no file is written and a "Skipped Plot" message is printed.
    """
    fig = plt.figure(figsize=(10, 6))
    try:
        has_valid_plot = False

        for idx, subdir_name in enumerate(subdirs):
            metrics = all_metrics.get(subdir_name)
            # Check both axes *before* indexing: train_loss / wall_time are
            # optional, so a method may lack either series.
            if not metrics or x_data_key not in metrics or y_data_key not in metrics:
                continue

            has_valid_plot = True
            c = colors[idx % len(colors)]  # cycle colors if runs > colors
            x = metrics[x_data_key]
            y = metrics[y_data_key]

            plt.plot(x, y, label=_legend_label(subdir_name), color=c, linewidth=2.5, alpha=0.9)

            # Shaded error band (skipped on time axes, which differ per method).
            if y_std_key and y_std_key in metrics and x_data_key != "wall_time_mean":
                y_std = metrics[y_std_key]
                plt.fill_between(x, y - y_std, y + y_std, color=c, alpha=0.1)

        if not has_valid_plot:
            print(f"⚠️  Skipped Plot: {save_name} (No valid data found)")
            return

        plt.title(title, fontsize=16, fontweight='bold')
        plt.xlabel(xlabel, fontsize=14)
        plt.ylabel(ylabel, fontsize=14)
        plt.grid(True, linestyle='--', alpha=0.6)
        plt.legend(fontsize=10, loc='best', framealpha=0.9)
        plt.tight_layout()

        save_path = os.path.join(output_dir, save_name)
        plt.savefig(save_path, dpi=300)
        print(f"✅ Saved Plot: {save_path}")
    finally:
        # Always release the figure (also on the skip path and on errors) so
        # repeated calls do not accumulate open figures.
        plt.close(fig)

# =================================================================
# Main Execution
# =================================================================

def _default_fig_dir(log_dir):
    """Return the default figure directory for ``log_dir``.

    A final path component named ``logs`` (or ``logs<suffix>``) maps to the
    sibling ``figs`` (``figs<suffix>``), e.g. ``runs/cifar10/logs`` ->
    ``runs/cifar10/figs``. Any other directory gets a ``figs`` subfolder.
    Only the last component is rewritten; ``logs`` elsewhere in the path is
    left untouched.
    """
    normalized = os.path.normpath(log_dir)
    parent, leaf = os.path.split(normalized)
    if leaf.startswith("logs"):
        return os.path.join(parent, "figs" + leaf[len("logs"):])
    return os.path.join(normalized, "figs")


def _find_run_dirs(log_dir):
    """Return the sorted folders in ``log_dir`` that contain ``.npz`` files.

    Each folder represents one method/solver. Folders without ``.npz`` files
    (e.g. a ``figs`` output folder placed inside ``log_dir``) are ignored. If no
    such folder exists but ``log_dir`` itself holds ``.npz`` files, returns
    ``["."]`` so the flat layout is treated as a single run.
    """
    def has_npz(path):
        # Escape so folder names with [ ] * ? are matched literally.
        return bool(glob.glob(os.path.join(glob.escape(path), "*.npz")))

    run_dirs = [
        name for name in sorted(os.listdir(log_dir))
        if os.path.isdir(os.path.join(log_dir, name)) and has_npz(os.path.join(log_dir, name))
    ]
    if not run_dirs and has_npz(log_dir):
        run_dirs = ["."]
    return run_dirs


def main():
    """
    Command-line entry point: summarize and plot every method under --log_dir.

    Prints a table of final accuracy, wall time and number of rounds per method,
    then writes accuracy-vs-rounds, accuracy-vs-time and loss-vs-rounds figures
    to --fig_dir. Exits with status 1 when --log_dir is not a directory.
    """
    parser = argparse.ArgumentParser(
        description="Aggregate .npz training logs and plot accuracy/loss curves.")
    parser.add_argument('--log_dir', type=str, required=True, help='Path containing subfolders of logs (e.g., runs/cifar10/logs)')
    parser.add_argument('--fig_dir', type=str, default=None, help='Path to save figures. Defaults to sibling "figs" directory.')
    args = parser.parse_args()

    # isdir rather than exists: a regular file here would make os.listdir fail.
    if not os.path.isdir(args.log_dir):
        print(f"Error: Log directory {args.log_dir} not found.")
        sys.exit(1)

    if args.fig_dir is None:
        args.fig_dir = _default_fig_dir(args.log_dir)

    # Discover runs *before* creating fig_dir-independent state; folders are
    # selected by content, so an output "figs" folder inside log_dir cannot
    # mask the flat-layout fallback.
    subdirs = _find_run_dirs(args.log_dir)
    if not subdirs:
        print(f"No subdirectories or .npz files found in {args.log_dir}")
        return

    # Qualitative palettes: tab10, or tab20 once there are more than 10 methods.
    if len(subdirs) > 10:
        colors = plt.cm.tab20(np.linspace(0, 1, 20))
    else:
        colors = plt.cm.tab10(np.linspace(0, 1, 10))

    # 1. Process data and print the summary table.
    print("\n" + "=" * 80)
    print(f"{'Experiment / Method':<40} | {'Final Acc':<10} | {'Time (s)':<10} | {'Rounds':<6}")
    print("=" * 80)

    all_metrics = {}
    for subdir in subdirs:
        m = process_subfolder_npz(os.path.join(args.log_dir, subdir))
        all_metrics[subdir] = m
        if not m:
            continue

        # Guard optional series instead of assuming they exist.
        acc_text = f"{m['test_acc_mean'][-1]:.2f}%" if 'test_acc_mean' in m else "N/A"
        time_text = f"{m['wall_time_mean'][-1]:.0f}s" if 'wall_time_mean' in m else "N/A"
        display_name = subdir if subdir != "." else "Default"
        print(f"{display_name:<40} | {acc_text:<10} | {time_text:<10} | {len(m['rounds'])}")

    if not any(all_metrics.values()):
        print("\n❌ No valid data found to plot.")
        return

    os.makedirs(args.fig_dir, exist_ok=True)
    print("-" * 80)
    print(f"Saving figures to: {args.fig_dir}")

    # 2. Generate plots.
    plot_specs = [
        dict(save_name='acc_vs_rounds.png', title='Test Accuracy vs. Rounds',
             x_data_key='rounds', y_data_key='test_acc_mean', y_std_key='test_acc_std',
             xlabel='Communication Rounds', ylabel='Test Accuracy (%)'),
        # Standard-deviation fills look messy on unaligned time axes.
        dict(save_name='acc_vs_time.png', title='Test Accuracy vs. Wall Time',
             x_data_key='wall_time_mean', y_data_key='test_acc_mean', y_std_key=None,
             xlabel='Wall Clock Time (seconds)', ylabel='Test Accuracy (%)'),
        dict(save_name='loss_vs_rounds.png', title='Test Loss vs. Rounds',
             x_data_key='rounds', y_data_key='test_loss_mean', y_std_key='test_loss_std',
             xlabel='Communication Rounds', ylabel='Test Loss'),
    ]
    for spec in plot_specs:
        plot_single_metric(**spec, all_metrics=all_metrics, subdirs=subdirs,
                           colors=colors, output_dir=args.fig_dir)

    print("=" * 80 + "\n")

if __name__ == '__main__':
    # Run the command-line tool only when executed as a script, not on import.
    main()