# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""
Script to download and plot gradient norms and training loss from wandb runs.
This script reads the grad_norm_attn_layer_<n> and grad_norm_ffn_layer_<n>
metrics (plus loss_metrics/global_avg_loss) from the specified wandb runs and
creates per-run plots and side-by-side comparison plots.
"""

import os
from typing import Dict, Tuple

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import wandb


def download_wandb_data(
    entity: str, project: str, run_name: str
) -> "wandb.apis.public.Run":
    """
    Connect to wandb and look up a run in ``entity/project`` by its name.

    Args:
        entity: wandb entity (username/organization)
        project: wandb project name
        run_name: run name to match against each run's ``name`` attribute

    Returns:
        The wandb run object (not the ``wandb.Api`` client). If several runs
        share the same name, the first one yielded by ``api.runs`` is returned.

    Raises:
        ValueError: if no run in the project has the requested name.
    """
    api = wandb.Api()

    # Linear scan over every run in the project; stops at the first match.
    runs = api.runs(f"{entity}/{project}")
    target_run = next((run for run in runs if run.name == run_name), None)

    if target_run is None:
        raise ValueError(f"Run '{run_name}' not found in {entity}/{project}")

    return target_run


def extract_grad_norm_data(
    run: "wandb.apis.public.Run", num_layers: int = 15, *, samples: int = 10000
) -> Tuple[Dict[str, pd.Series], Dict[str, pd.Series], pd.Series, pd.Series]:
    """
    Extract gradient norm data for attention and FFN layers, plus loss data.

    Args:
        run: wandb run object (e.g. as returned by ``download_wandb_data``)
        num_layers: number of layers to look for (indices 0 to num_layers-1)
        samples: maximum number of history rows requested via
            ``run.history(samples=...)``. That call returns a *sampled* view of
            the history, so runs that logged more rows than this are
            downsampled rather than fetched in full.

    Returns:
        Tuple of (attn_data, ffn_data, steps, loss_data):
            attn_data / ffn_data: dicts mapping "Layer {i}" to that layer's
                logged values with NaN rows dropped; layers whose metric
                column is absent from the history are omitted.
            steps: the history's ``_step`` column, or the history index if
                that column is missing.
            loss_data: ``loss_metrics/global_avg_loss`` with NaNs dropped, or
                an empty float series if the metric was not logged.
    """
    print(
        f"Downloading history from wandb run '{run.name}' "
        f"(up to {samples} rows, this may take a moment)..."
    )
    history = run.history(samples=samples)

    print(f"Downloaded {len(history)} total data points for run '{run.name}'")
    if len(history) >= samples:
        # Hitting the cap means the returned history may be a downsample.
        print(
            f"Warning: history for '{run.name}' reached the {samples}-row limit; "
            "data may be downsampled. Increase `samples` or use "
            "run.scan_history() for the full history."
        )

    attn_data: Dict[str, pd.Series] = {}
    ffn_data: Dict[str, pd.Series] = {}
    metric_targets = (
        ("grad_norm_attn_layer", attn_data),
        ("grad_norm_ffn_layer", ffn_data),
    )

    # Extract data for each layer (attention first, then FFN, per layer)
    for layer_idx in range(num_layers):
        label = f"Layer {layer_idx}"
        for prefix, target in metric_targets:
            col = f"{prefix}_{layer_idx}"
            if col in history.columns:
                target[label] = history[col].dropna()
                print(f"Found {len(target[label])} data points for {col}")
            else:
                print(f"Warning: {col} not found in data")

    # Extract loss data
    loss_col = "loss_metrics/global_avg_loss"
    if loss_col in history.columns:
        loss_data = history[loss_col].dropna()
        print(f"Found {len(loss_data)} data points for {loss_col}")
    else:
        print(f"Warning: {loss_col} not found in data")
        loss_data = pd.Series(dtype=float)

    # Fall back to the row index when wandb's step column is absent.
    steps = history.get("_step", history.index)

    return attn_data, ffn_data, steps, loss_data


def plot_grad_norms_single_run(
    data: Dict,
    category: str,
    run_name: str,
    display_name: str,
    output_dir: str = "outputs/grad_norm_plots",
    step_interval: int = 10,
):
    """
    Plot gradient norms for all layers of a single run on one set of axes.

    The x-axis assumes sample ``i`` of every series was logged at training
    step ``i * step_interval``.

    Args:
        data: Dictionary with layer names as keys and series of norms as values
        category: Category name (e.g., 'attention' or 'ffn')
        run_name: Name of the run (not used in the output; the filename and
            title are built from ``display_name``)
        display_name: Name used in the plot title and output filename
        output_dir: Directory to save plots (created if missing)
        step_interval: Number of training steps between consecutive samples
    """
    if not data:
        print(f"No data available for {category}")
        return

    os.makedirs(output_dir, exist_ok=True)

    fig = plt.figure(figsize=(14, 8))

    # One line per layer; layers with no samples are skipped.
    for layer_name, values in data.items():
        if len(values) == 0:
            continue
        steps = np.arange(len(values)) * step_interval
        plt.plot(steps, values, label=layer_name, alpha=0.7, linewidth=2)

    plt.xlabel("Training Step")
    plt.ylabel(f"Gradient Norm ({category})")
    plt.title(f"Gradient Norms - {category.upper()} Layers ({display_name})")
    plt.legend(bbox_to_anchor=(1.05, 1), loc="upper left")
    plt.grid(True, alpha=0.3)
    plt.tight_layout()

    filename = f"grad_norm_{category.lower()}_all_layers_{display_name.lower()}.png"
    filepath = os.path.join(output_dir, filename)
    plt.savefig(filepath, dpi=300, bbox_inches="tight")
    print(f"Saved plot: {filepath}")

    plt.show()
    # Release the figure so repeated calls don't accumulate open figures
    # (plt.show() is a no-op under non-interactive backends).
    plt.close(fig)


def plot_loss_comparison(
    loss1: pd.Series,
    loss2: pd.Series,
    display1_name: str,
    display2_name: str,
    output_dir: str = "outputs/grad_norm_plots",
    step_interval: int = 10,
):
    """
    Plot the training loss of two runs on the same axes.

    The x-axis assumes sample ``i`` of each series was logged at training
    step ``i * step_interval``. A run with no loss samples is simply omitted;
    nothing is plotted if both are empty.

    Args:
        loss1: Loss data for first run (drawn in blue)
        loss2: Loss data for second run (drawn in red)
        display1_name: Display name for first run
        display2_name: Display name for second run
        output_dir: Directory to save plots (created if missing)
        step_interval: Number of training steps between consecutive samples
    """
    if len(loss1) == 0 and len(loss2) == 0:
        print("No loss data available for comparison")
        return

    os.makedirs(output_dir, exist_ok=True)

    fig = plt.figure(figsize=(12, 6))

    curves = ((loss1, display1_name, "blue"), (loss2, display2_name, "red"))
    for loss, label, color in curves:
        if len(loss) == 0:
            continue
        steps = np.arange(len(loss)) * step_interval
        plt.plot(steps, loss, label=label, alpha=0.8, linewidth=2.5, color=color)

    plt.xlabel("Training Step", fontsize=12)
    plt.ylabel("Global Average Loss", fontsize=12)
    plt.title(
        f"Training Loss Comparison: {display1_name} vs {display2_name}",
        fontsize=14,
        fontweight="bold",
    )
    plt.legend(fontsize=12)
    plt.grid(True, alpha=0.3)
    plt.tight_layout()

    filename = f"loss_comparison_{display1_name.lower()}_vs_{display2_name.lower()}.png"
    filepath = os.path.join(output_dir, filename)
    plt.savefig(filepath, dpi=300, bbox_inches="tight")
    print(f"Saved loss comparison plot: {filepath}")

    plt.show()
    # Release the figure so repeated calls don't accumulate open figures.
    plt.close(fig)


def plot_comparison_subplots(
    data1: Dict,
    data2: Dict,
    run1_name: str,
    run2_name: str,
    display1_name: str,
    display2_name: str,
    category: str,
    output_dir: str = "outputs/grad_norm_plots",
    step_interval: int = 10,
):
    """
    Create a grid of subplots, one per layer present in both runs, each
    overlaying the two runs' gradient norms for that layer.

    Layer keys are expected to look like "Layer <int>" (as produced by
    ``extract_grad_norm_data``); they are ordered by that integer. The x-axis
    assumes sample ``i`` of each series was logged at step ``i * step_interval``.

    Args:
        data1: First run data dictionary (drawn in blue)
        data2: Second run data dictionary (drawn in red)
        run1_name: Name of the first run (not used in the output; the filename
            is built from the display names)
        run2_name: Name of the second run (not used in the output)
        display1_name: Display name for first run
        display2_name: Display name for second run
        category: Category name (e.g., 'attention' or 'ffn')
        output_dir: Directory to save plots (created only if something is plotted)
        step_interval: Number of training steps between consecutive samples
    """
    if not data1 or not data2:
        print(f"Insufficient data for comparison plots ({category})")
        return

    # Layers that exist in both runs, in numeric layer order.
    common_layers = sorted(
        set(data1) & set(data2), key=lambda name: int(name.split()[1])
    )

    if not common_layers:
        print(f"No common layers found for {category} comparison")
        return

    os.makedirs(output_dir, exist_ok=True)

    # Roughly square grid of subplots.
    n_layers = len(common_layers)
    cols = int(np.ceil(np.sqrt(n_layers)))
    rows = int(np.ceil(n_layers / cols))

    # squeeze=False always returns a 2-D array of Axes, so a single flatten()
    # handles the 1x1, single-row/column and full-grid cases alike.
    fig, axes_grid = plt.subplots(
        rows, cols, figsize=(4 * cols, 3 * rows), squeeze=False
    )
    axes = axes_grid.flatten()
    fig.suptitle(
        f"{category.upper()} Gradient Norms Comparison: {display1_name} vs {display2_name}",
        fontsize=16,
        y=0.98,
    )

    for ax, layer_name in zip(axes, common_layers):
        curves = (
            (data1[layer_name], display1_name, "blue"),
            (data2[layer_name], display2_name, "red"),
        )
        for values, label, color in curves:
            if len(values) == 0:
                continue
            steps = np.arange(len(values)) * step_interval
            ax.plot(
                steps, values, label=label, alpha=0.8, linewidth=2.5, color=color
            )

        ax.set_title(f"{layer_name}", fontsize=10, fontweight="bold")
        ax.set_xlabel("Step", fontsize=8)
        ax.set_ylabel("Grad Norm", fontsize=8)
        ax.grid(True, alpha=0.3)
        ax.legend(fontsize=8)
        ax.tick_params(labelsize=8)

    # Hide grid cells left over after the last layer.
    for ax in axes[n_layers:]:
        ax.set_visible(False)

    plt.tight_layout()
    plt.subplots_adjust(top=0.93)  # leave room for the suptitle

    filename = f"grad_norm_{category.lower()}_comparison_{display1_name.lower()}_vs_{display2_name.lower()}.png"
    filepath = os.path.join(output_dir, filename)
    plt.savefig(filepath, dpi=300, bbox_inches="tight")
    print(f"Saved comparison plot: {filepath}")

    plt.show()
    # Release the figure so repeated calls don't accumulate open figures.
    plt.close(fig)


def print_summary_statistics(data: Dict, run_name: str, category: str):
    """
    Print a per-layer summary (count, mean, std, min, max) for one run.

    Layers whose series is empty are skipped silently.

    Args:
        data: mapping of layer label to its series of values
        run_name: label shown in the section header
        category: category name, upper-cased in the section header
    """
    print(f"\n=== {category.upper()} SUMMARY for {run_name} ===")
    populated = ((layer, series) for layer, series in data.items() if len(series) > 0)
    for layer, series in populated:
        count = len(series)
        stats = (
            f"mean={series.mean():.4f}, std={series.std():.4f}, "
            f"min={series.min():.4f}, max={series.max():.4f}"
        )
        print(f"  {layer}: {count} steps, {stats}")


def _load_run_data(
    entity: str, project: str, run_name: str, display_name: str, num_layers: int
) -> Dict:
    """
    Fetch one wandb run and bundle its extracted metrics.

    Returns:
        Dict with keys "attn", "ffn", "steps", "loss" (from
        ``extract_grad_norm_data``), "run_obj" (the wandb run) and
        "display_name".
    """
    print(f"\n--- Processing run: {run_name} ---")
    run = download_wandb_data(entity, project, run_name)
    print(f"Successfully connected to run: {run.name}")
    print(f"Run URL: {run.url}")

    # Extract gradient norm data and loss data
    attn_data, ffn_data, steps, loss_data = extract_grad_norm_data(run, num_layers)

    print(f"Found attention data for {len(attn_data)} layers")
    print(f"Found FFN data for {len(ffn_data)} layers")

    return {
        "attn": attn_data,
        "ffn": ffn_data,
        "steps": steps,
        "loss": loss_data,
        "run_obj": run,
        "display_name": display_name,
    }


def _print_run_statistics(run_name: str, data: Dict) -> None:
    """Print grad-norm, loss and step-range summaries for one loaded run."""
    display_name = data["display_name"]
    label = f"{run_name} ({display_name})"
    print_summary_statistics(data["attn"], label, "Attention")
    print_summary_statistics(data["ffn"], label, "FFN")

    loss_data = data["loss"]
    if len(loss_data) > 0:
        print(f"\n=== LOSS SUMMARY for {label} ===")
        print(
            f"  Loss: {len(loss_data)} steps, "
            f"mean={loss_data.mean():.4f}, std={loss_data.std():.4f}, "
            f"min={loss_data.min():.4f}, max={loss_data.max():.4f}"
        )
        print(f"  Final loss: {loss_data.iloc[-1]:.4f}")

    steps = data["steps"]
    print(f"\nRun {run_name}: Total data points = {len(steps)}")
    # min()/max() of an empty Series/Index is NaN, so only report a range
    # when there is at least one step.
    if len(steps) > 0:
        print(f"Step range: {steps.min()} to {steps.max()}")


def main():
    """
    Download the configured wandb runs, plot their gradient norms and loss,
    and print summary statistics.

    On any failure the error and traceback are printed and the process exits
    with status 1.
    """
    # Configuration
    ENTITY = "ajanthan-pluralis-research"
    PROJECT = "torchtitan"
    RUN_NAMES = ["whole-capybara-74", "lilac-rain-73"]  # Both runs to analyze
    DISPLAY_NAMES = ["compressed", "vanilla"]  # One display name per run, in order
    NUM_LAYERS = 16  # Layers 0 to 15

    print(f"Connecting to wandb runs: {ENTITY}/{PROJECT}")
    print(f"Runs to analyze: {RUN_NAMES}")

    try:
        # Validate the config up front rather than failing mid-download.
        if len(DISPLAY_NAMES) != len(RUN_NAMES):
            raise ValueError(
                f"Expected one display name per run, got {len(DISPLAY_NAMES)} "
                f"display names for {len(RUN_NAMES)} runs"
            )

        # Download data from all runs
        runs_data = {}
        for run_name, display_name in zip(RUN_NAMES, DISPLAY_NAMES):
            runs_data[run_name] = _load_run_data(
                ENTITY, PROJECT, run_name, display_name, NUM_LAYERS
            )

        # Create individual run plots
        print("\n=== Creating individual run plots ===")
        for run_name, data in runs_data.items():
            display_name = data["display_name"]
            print(f"\nPlotting data for {run_name} ({display_name})...")
            plot_grad_norms_single_run(
                data["attn"], "Attention", run_name, display_name
            )
            plot_grad_norms_single_run(data["ffn"], "FFN", run_name, display_name)

        # Create comparison plots if we have exactly 2 runs
        if len(RUN_NAMES) == 2:
            print("\n=== Creating comparison subplot plots ===")
            run1_name, run2_name = RUN_NAMES
            run1_data = runs_data[run1_name]
            run2_data = runs_data[run2_name]

            print("Creating attention layers comparison...")
            plot_comparison_subplots(
                run1_data["attn"],
                run2_data["attn"],
                run1_name,
                run2_name,
                run1_data["display_name"],
                run2_data["display_name"],
                "Attention",
            )

            print("Creating FFN layers comparison...")
            plot_comparison_subplots(
                run1_data["ffn"],
                run2_data["ffn"],
                run1_name,
                run2_name,
                run1_data["display_name"],
                run2_data["display_name"],
                "FFN",
            )

            print("Creating loss comparison...")
            plot_loss_comparison(
                run1_data["loss"],
                run2_data["loss"],
                run1_data["display_name"],
                run2_data["display_name"],
            )

        # Print summary statistics for all runs
        print("\n" + "=" * 60)
        print("DETAILED STATISTICS")
        print("=" * 60)

        for run_name, data in runs_data.items():
            _print_run_statistics(run_name, data)

    except Exception as e:
        print(f"Error: {e}")
        import traceback

        traceback.print_exc()
        # Exit non-zero so shell scripts / CI can detect that the run failed.
        raise SystemExit(1) from e


if __name__ == "__main__":
    main()
