"""
Plot timing breakdown from wandb logs or exported CSV.

Usage:
    # From wandb run export (CSV)
    python scripts/analysis/plot_timing_breakdown.py --csv_path /path/to/wandb_export.csv

    # From wandb API (requires wandb login)
    python scripts/analysis/plot_timing_breakdown.py --wandb_run user/project/run_id
"""
import argparse
import os
import numpy as np
import matplotlib.pyplot as plt


def load_timing_from_csv(csv_path):
    """Load per-component timing averages from an exported wandb CSV.

    Only columns whose name starts with ``timing/`` are considered:

    * ``timing/<component>_frac`` columns are averaged over their non-missing
      rows and multiplied by 100 (fraction -> percentage).
    * ``timing/<component>_ms`` columns are averaged as-is and are used only
      for components that have no usable ``_frac`` column.

    Because of the fallback, the result can mix percentages and raw ``_ms``
    means when the export provides fractions for some components only.
    Timing columns with any other suffix, or with no non-missing rows, are
    ignored.

    Args:
        csv_path: path to the CSV file.

    Returns:
        dict mapping component name -> mean value.

    Raises:
        ValueError: if no column name starts with ``timing/``.
    """
    import pandas as pd

    prefix = "timing/"
    frac_suffix = "_frac"
    ms_suffix = "_ms"

    df = pd.read_csv(csv_path)

    timing_cols = [col for col in df.columns if col.startswith(prefix)]
    if not timing_cols:
        raise ValueError(f"No timing columns found in {csv_path}")

    frac_means = {}
    ms_means = {}
    for col in timing_cols:
        # Anchor the match to the end of the name so that "_frac"/"_ms"
        # appearing in the middle of a metric name is not misinterpreted.
        name = col[len(prefix):]
        if name.endswith(frac_suffix):
            target, key, scale = frac_means, name[: -len(frac_suffix)], 100
        elif name.endswith(ms_suffix):
            target, key, scale = ms_means, name[: -len(ms_suffix)], 1
        else:
            continue

        samples = df[col].dropna()
        if samples.empty:
            # An empty column would average to NaN; skipping it also lets the
            # other unit's column supply the value for this component.
            continue
        target[key] = samples.mean() * scale

    # Fractions take precedence over ms values for the same component.
    return {**ms_means, **frac_means}


def load_timing_from_wandb(run_path):
    """Fetch timing metrics for one run through the wandb API.

    For each of the components forward/backward/spectral/optimizer/other the
    run summary is checked for ``timing/<component>_frac`` first (stored as a
    percentage, i.e. multiplied by 100) and otherwise for
    ``timing/<component>_ms`` (stored unchanged). Components with neither
    key are left out of the result.

    Args:
        run_path: wandb run path, passed straight to ``wandb.Api().run``.

    Returns:
        dict mapping component name -> value.
    """
    import wandb

    run_summary = wandb.Api().run(run_path).summary

    timings = {}
    for component in ("forward", "backward", "spectral", "optimizer", "other"):
        fraction_key = f"timing/{component}_frac"
        if fraction_key in run_summary:
            # Fractions are reported as percentages.
            timings[component] = run_summary[fraction_key] * 100
            continue

        millis_key = f"timing/{component}_ms"
        if millis_key in run_summary:
            # No fraction logged: keep the raw ms value instead.
            timings[component] = run_summary[millis_key]

    return timings


def plot_stacked_bar(data_dict, output_path, title="Training Step Time Breakdown"):
    """Create a stacked bar chart of the timing breakdown, one bar per config.

    The figure is written twice: once as a PDF next to ``output_path`` (same
    base name, ``.pdf`` extension) and once to ``output_path`` itself. When
    ``output_path`` already ends in ``.pdf`` it is written only once.

    Args:
        data_dict: dict mapping config name -> {component: percentage}.
            Components missing from a config are drawn with height 0; keys
            other than forward/backward/spectral/optimizer/other are ignored.
        output_path: path to save the plot.
        title: plot title.
    """
    # NOTE: this updates matplotlib's global rcParams. The "figure.figsize"
    # entry is overridden by the explicit figsize passed to subplots() below.
    plt.rcParams.update(
        {
            "font.size": 11,
            "font.family": "serif",
            "axes.labelsize": 12,
            "axes.titlesize": 12,
            "xtick.labelsize": 10,
            "ytick.labelsize": 10,
            "legend.fontsize": 10,
            "figure.figsize": (8, 5),
            "figure.dpi": 150,
        }
    )

    components = ["forward", "backward", "spectral", "optimizer", "other"]
    colors = ["#2ecc71", "#3498db", "#e74c3c", "#9b59b6", "#95a5a6"]

    fig, ax = plt.subplots(figsize=(10, 6))

    configs = list(data_dict)
    x = np.arange(len(configs))
    width = 0.6

    # Running top of each stack; every component is drawn on top of it.
    bottom = np.zeros(len(configs))

    for component, color in zip(components, colors):
        values = np.array(
            [data_dict[cfg].get(component, 0) for cfg in configs], dtype=float
        )
        ax.bar(x, values, width, bottom=bottom, label=component.capitalize(), color=color)
        bottom = bottom + values

    ax.set_ylabel("Percentage of Step Time (%)")
    ax.set_title(title)
    ax.set_xticks(x)
    ax.set_xticklabels(configs, rotation=15, ha="right")
    ax.legend(loc="upper right")

    # Keep the usual 0-105 range for percentage data, but grow the axis when a
    # stack is taller than that (e.g. the loaders fell back to ms values) so
    # bars are never clipped.
    y_top = 105
    if configs and bottom.max() > y_top:
        y_top = bottom.max() * 1.05
    ax.set_ylim(0, y_top)
    ax.grid(True, alpha=0.3, axis="y")

    fig.tight_layout()

    # Derive the PDF name from the extension only; a plain str.replace would
    # also rewrite ".png" inside directory names and do nothing for other
    # extensions.
    pdf_path = os.path.splitext(output_path)[0] + ".pdf"
    fig.savefig(pdf_path)
    if output_path != pdf_path:
        fig.savefig(output_path)
    plt.close(fig)

    print(f"Saved stacked bar chart to {output_path}")


def plot_comparison_bars(data_baseline, data_treatment, output_path,
                         label_baseline="AdamW", label_treatment="AdamW + Spectral Clip",
                         title="Training Step Time Breakdown Comparison"):
    """Create a side-by-side bar chart comparing two timing breakdowns.

    For each of the components forward/backward/spectral/optimizer/other the
    baseline and treatment values are drawn next to each other; bars taller
    than 2 get a ``"<value>%"`` label. The figure is written twice: once as a
    PDF next to ``output_path`` (same base name, ``.pdf`` extension) and once
    to ``output_path`` itself. When ``output_path`` already ends in ``.pdf``
    it is written only once.

    Args:
        data_baseline: dict mapping component -> percentage for the baseline.
        data_treatment: dict mapping component -> percentage for the treatment.
        output_path: path to save the plot.
        label_baseline: legend label for the baseline bars.
        label_treatment: legend label for the treatment bars.
        title: plot title.
    """
    # NOTE: this updates matplotlib's global rcParams. The "figure.figsize"
    # entry is overridden by the explicit figsize passed to subplots() below.
    plt.rcParams.update(
        {
            "font.size": 11,
            "font.family": "serif",
            "axes.labelsize": 12,
            "axes.titlesize": 12,
            "xtick.labelsize": 10,
            "ytick.labelsize": 10,
            "legend.fontsize": 10,
            "figure.figsize": (8, 5),
            "figure.dpi": 150,
        }
    )

    components = ["forward", "backward", "spectral", "optimizer", "other"]

    fig, ax = plt.subplots(figsize=(10, 5))

    x = np.arange(len(components))
    width = 0.35

    # Components missing from either dict are drawn with height 0.
    baseline_values = [data_baseline.get(c, 0) for c in components]
    treatment_values = [data_treatment.get(c, 0) for c in components]

    bars1 = ax.bar(x - width / 2, baseline_values, width, label=label_baseline)
    bars2 = ax.bar(x + width / 2, treatment_values, width, label=label_treatment)

    ax.set_ylabel("Percentage of Step Time (%)")
    ax.set_title(title)
    ax.set_xticks(x)
    ax.set_xticklabels([c.capitalize() for c in components])
    ax.legend()
    ax.grid(True, alpha=0.3, axis="y")

    # Add value labels on bars
    for bars in (bars1, bars2):
        for bar in bars:
            height = bar.get_height()
            if height > 2:  # Only label if > 2%
                ax.annotate(
                    f"{height:.1f}%",
                    xy=(bar.get_x() + bar.get_width() / 2, height),
                    xytext=(0, 3),
                    textcoords="offset points",
                    ha="center",
                    va="bottom",
                    fontsize=8,
                )

    fig.tight_layout()

    # Derive the PDF name from the extension only; a plain str.replace would
    # also rewrite ".png" inside directory names and do nothing for other
    # extensions.
    pdf_path = os.path.splitext(output_path)[0] + ".pdf"
    fig.savefig(pdf_path)
    if output_path != pdf_path:
        fig.savefig(output_path)
    plt.close(fig)

    print(f"Saved comparison chart to {output_path}")


def _run_demo(output_dir):
    """Plot the built-in example data into ``output_dir`` and print a summary."""
    # Demo data based on expected results
    data = {
        "AdamW (baseline)": {
            "forward": 42.0,
            "backward": 45.0,
            "spectral": 0.0,
            "optimizer": 10.0,
            "other": 3.0,
        },
        "AdamW + Spectral Clip": {
            "forward": 40.0,
            "backward": 43.0,
            "spectral": 4.0,
            "optimizer": 10.0,
            "other": 3.0,
        },
        "Muon": {
            "forward": 40.0,
            "backward": 43.0,
            "spectral": 5.0,  # Muon's orthogonalization is similar
            "optimizer": 9.0,
            "other": 3.0,
        },
    }

    plot_stacked_bar(
        data,
        os.path.join(output_dir, "timing_breakdown_stacked.png"),
        title="Training Step Time Breakdown",
    )

    plot_comparison_bars(
        data["AdamW (baseline)"],
        data["AdamW + Spectral Clip"],
        os.path.join(output_dir, "timing_comparison.png"),
    )

    # Print summary table. The demo values are percentages only; there is no
    # absolute step time to derive millisecond figures from, so none are shown.
    print("\n" + "=" * 60)
    print("TIMING BREAKDOWN SUMMARY (Demo Data)")
    print("=" * 60)
    for config, values in data.items():
        print(f"\n{config}:")
        for component, pct in values.items():
            print(f"  {component:12s}: {pct:5.1f}%")
        print(f"  {'total':12s}: {sum(values.values()):5.1f}%")


def main():
    """Command-line entry point: parse arguments and produce the plots.

    Data source precedence is ``--demo``, then ``--csv_path``, then
    ``--wandb_run``. When none is given, a usage hint is printed and nothing
    is written (the output directory is not created either).
    """
    parser = argparse.ArgumentParser(description="Plot timing breakdown from training logs")
    parser.add_argument(
        "--csv_path",
        type=str,
        default=None,
        help="Path to exported wandb CSV file",
    )
    parser.add_argument(
        "--wandb_run",
        type=str,
        default=None,
        help="Wandb run path (user/project/run_id)",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="./timing_plots",
        help="Directory to save plots",
    )
    parser.add_argument(
        "--demo",
        action="store_true",
        help="Generate demo plots with example data",
    )
    args = parser.parse_args()

    # Bail out before touching the filesystem when there is nothing to plot.
    if not (args.demo or args.csv_path or args.wandb_run):
        print("No data source specified. Use --csv_path, --wandb_run, or --demo")
        print("\nExample usage:")
        print("  python scripts/analysis/plot_timing_breakdown.py --demo")
        print("  python scripts/analysis/plot_timing_breakdown.py --csv_path /path/to/export.csv")
        return

    os.makedirs(args.output_dir, exist_ok=True)

    if args.demo:
        _run_demo(args.output_dir)
        return

    if args.csv_path:
        timings = load_timing_from_csv(args.csv_path)
    else:
        timings = load_timing_from_wandb(args.wandb_run)

    plot_stacked_bar(
        {"Run": timings},
        os.path.join(args.output_dir, "timing_breakdown.png"),
    )


# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
