"""
Analyze and visualize experiment runs.
Generates comparison plots and summary tables from logged experiments.

Output is saved to the run_analysis/ directory by default, in one timestamped
subdirectory per analysis session.
"""

import argparse
import json
from datetime import datetime
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np

from logger import load_all_runs, filter_runs, RUNS_DIR

# Root directory for analysis artifacts (plots, LaTeX tables). The path is
# relative, so it resolves against the current working directory; each analysis
# session writes into its own subdirectory (see get_analysis_dir).
ANALYSIS_DIR = Path("run_analysis")


def generate_analysis_id() -> str:
    """
    Build an identifier for the current analysis session.

    Returns:
        The local wall-clock time formatted as ``YYYY-MM-DD_HH.MM.SS``.
        The resolution is one second, so two calls within the same second
        return the same ID.
    """
    now = datetime.now()
    return f"{now:%Y-%m-%d_%H.%M.%S}"


def get_analysis_dir(timestamp: str) -> Path:
    """
    Return the output directory for one analysis session, creating it if needed.

    The directory is ``ANALYSIS_DIR / timestamp``. Missing parent directories
    are created as well, and calling this again for an existing directory is
    not an error.

    Args:
        timestamp: Name of the session subdirectory (normally the analysis ID).

    Returns:
        Path to the (now existing) session directory.
    """
    session_dir = ANALYSIS_DIR.joinpath(timestamp)
    session_dir.mkdir(parents=True, exist_ok=True)
    return session_dir


def build_descriptive_filename(
    plot_type: str,
    group_by: str,
    objective_filter: str | None = None,
    n_arms_filter: int | None = None,
    **kwargs,
) -> str:
    """
    Build a descriptive filename for a plot based on its parameters.
    
    Follows the naming convention from ucb_expt.py:
    {plot_type}_{group_by}[_objective_{obj}][_arms_{n}].png
    
    Args:
        plot_type: Type of plot (e.g., "cumulative_regret", "perstep_regret").
        group_by: The config key used for grouping.
        objective_filter: If filtered by objective, include in name.
        n_arms_filter: If filtered by n_arms, include in name.
        **kwargs: Additional filters to include in filename.
        
    Returns:
        Descriptive filename string (without directory path).
    """
    parts = [plot_type, f"by_{group_by}"]
    
    # Add filter information to filename for clarity
    if objective_filter is not None:
        parts.append(f"objective_{objective_filter}")
    if n_arms_filter is not None:
        parts.append(f"arms_{n_arms_filter}")
    
    # Handle any additional filters passed via kwargs
    for key, value in kwargs.items():
        if value is not None:
            parts.append(f"{key}_{value}")
    
    return "_".join(parts) + ".png"


def extract_run_ids(runs: list[dict]) -> list[str]:
    """
    Collect the ``run_id`` field of every run record, in input order.

    Args:
        runs: List of run records (dicts).

    Returns:
        One entry per record; records without a ``run_id`` key contribute
        the placeholder "unknown".
    """
    run_ids = []
    for record in runs:
        run_ids.append(record.get("run_id", "unknown"))
    return run_ids


def add_plot_metadata(
    fig: plt.Figure,
    run_ids: list[str],
    analysis_timestamp: str,
    max_ids_per_line: int = 3,
) -> None:
    """
    Annotate the bottom edge of a figure with provenance information.

    Two small gray text labels are added in figure coordinates:
    1. Bottom left: "Run IDs: ..." listing every run ID, wrapped so that each
       line holds at most ``max_ids_per_line`` IDs.
    2. Bottom right: "created: <analysis_timestamp>" in italics.

    Args:
        fig: Matplotlib figure to annotate.
        run_ids: Run IDs to list; they are joined with ", " so they must be strings.
        analysis_timestamp: Timestamp string for when the analysis was created.
        max_ids_per_line: Maximum number of run IDs per line before wrapping.
    """
    # Settings shared by both labels: small and gray so they do not compete
    # with the plot, anchored at their bottom edge.
    shared_style = {
        "fontsize": 7,
        "color": "gray",
        "verticalalignment": "bottom",
    }

    # One text row per chunk of at most max_ids_per_line IDs.
    rows = []
    for start in range(0, len(run_ids), max_ids_per_line):
        rows.append(", ".join(run_ids[start:start + max_ids_per_line]))

    # Continuation rows are indented to line up under the first ID.
    ids_label = "Run IDs: " + "\n         ".join(rows)

    # Bottom-left label; monospace keeps the wrapped rows aligned.
    fig.text(
        0.02, 0.02,
        ids_label,
        horizontalalignment="left",
        family="monospace",
        **shared_style,
    )

    # Bottom-right label with the analysis timestamp.
    fig.text(
        0.98, 0.02,
        f"created: {analysis_timestamp}",
        horizontalalignment="right",
        style="italic",
        **shared_style,
    )


def summarize_runs(runs: list[dict]) -> None:
    """
    Print a fixed-width summary table of all runs to stdout.

    One row is printed per run with its ID, objective, number of arms,
    allocation count, number of rounds and final cumulative regret
    (``avg_final_cumulative_regret`` if present, otherwise
    ``final_cumulative_regret``). Numeric regrets are shown with four
    decimals. Missing or ``None`` values are shown as "N/A", and a run
    without a ``run_id`` is listed as "unknown".

    Args:
        runs: List of run records; each is expected to carry "config" and
            "results" dicts, but absent ones are treated as empty.
    """
    if not runs:
        print("No runs found.")
        return

    # (header, width) per column. "Objective" is nine characters long, so the
    # column must be wider than that for headers and rows to stay aligned.
    columns = [
        ("Run ID", 28),
        ("Objective", 10),
        ("Arms", 6),
        ("Alloc", 6),
        ("Rounds", 8),
        ("Final Regret", 12),
    ]
    rule = "=" * 80

    def as_text(value) -> str:
        # Format specs such as "<6" are rejected by None, lists, etc., so every
        # cell is converted to a plain string before padding.
        return "N/A" if value is None else str(value)

    def format_row(cells) -> str:
        return " ".join(
            f"{as_text(cell):<{width}}" for cell, (_, width) in zip(cells, columns)
        )

    print(f"\n{rule}")
    print(format_row([header for header, _ in columns]))
    print(rule)

    for run in runs:
        cfg = run.get("config") or {}
        res = run.get("results") or {}

        # Prefer the averaged value when the run logged one.
        final_regret = res.get(
            "avg_final_cumulative_regret", res.get("final_cumulative_regret")
        )
        if isinstance(final_regret, (int, float)):
            final_regret = f"{final_regret:.4f}"

        print(
            format_row(
                [
                    run.get("run_id", "unknown"),
                    cfg.get("objective"),
                    cfg.get("n_arms"),
                    cfg.get("num_alloc"),
                    cfg.get("n_rounds"),
                    final_regret,
                ]
            )
        )

    print(f"{rule}\n")


def plot_cumulative_regret_comparison(
    runs: list[dict],
    group_by: str = "objective",
    title: str | None = None,
    save_path: Path | None = None,
    show: bool = True,
    analysis_timestamp: str | None = None,
) -> None:
    """
    Plot cumulative regret curves for multiple runs, grouped by a config key.

    For every group the per-round regrets of each run ("regrets", falling back
    to "avg_regret") are cumulatively summed, truncated to the shortest run in
    the group and averaged. The mean is drawn as a line with a shaded band of
    +/- 1.96 standard errors around it.

    Args:
        runs: List of run records.
        group_by: Config key to group/color by (e.g., "objective", "n_arms").
            Runs lacking the key are grouped under "unknown".
        title: Optional plot title.
        save_path: If provided, save plot to this path (parent directories are
            created as needed).
        show: Whether to display plot interactively.
        analysis_timestamp: Timestamp for this analysis session; when given,
            run IDs and the timestamp are written at the bottom of the figure.
    """
    if not runs:
        print("No runs to plot.")
        return

    # Group the runs before creating any figure, so that a malformed record
    # raises without leaving an open figure behind.
    groups: dict = {}
    for run in runs:
        groups.setdefault(run["config"].get(group_by, "unknown"), []).append(run)

    # Group keys can be of mixed types (e.g. integer n_arms next to the
    # "unknown" placeholder), which cannot be compared directly. Keep the
    # natural order when it exists and fall back to ordering by string form.
    try:
        ordered_groups = sorted(groups.items(), key=lambda item: item[0])
    except TypeError:
        ordered_groups = sorted(groups.items(), key=lambda item: str(item[0]))

    colors = plt.cm.tab10(np.linspace(0, 1, len(groups)))

    # Slightly taller figure to accommodate metadata text at bottom
    fig, ax = plt.subplots(figsize=(10, 6.8))
    plotted_any = False

    for (group_name, group_runs), color in zip(ordered_groups, colors):
        # Collect the cumulative regret curve of every run in this group
        curves = []
        for run in group_runs:
            regrets = run["results"].get("regrets") or run["results"].get("avg_regret")
            if regrets is not None:
                curves.append(np.cumsum(regrets))

        if not curves:
            continue

        # Align lengths (in case of different n_rounds)
        min_len = min(len(curve) for curve in curves)
        stacked = np.array([curve[:min_len] for curve in curves])

        mean_regret = np.mean(stacked, axis=0)
        # NOTE(review): np.std defaults to the population standard deviation
        # (ddof=0); consider ddof=1 if a sample-based standard error is intended.
        std_regret = np.std(stacked, axis=0)
        n_runs = len(stacked)
        stderr = std_regret / np.sqrt(n_runs) if n_runs > 1 else std_regret

        x = np.arange(len(mean_regret))
        ax.plot(x, mean_regret, label=f"{group_by}={group_name} (n={n_runs})", color=color)
        ax.fill_between(
            x,
            mean_regret - 1.96 * stderr,
            mean_regret + 1.96 * stderr,
            alpha=0.2,
            color=color,
        )
        plotted_any = True

    ax.set_xlabel("Round")
    ax.set_ylabel("Cumulative Regret")
    ax.set_title(title or f"Cumulative Regret by {group_by}")
    if plotted_any:
        # A legend without any labelled artist only produces a warning.
        ax.legend()
    ax.grid(True, alpha=0.3)

    if analysis_timestamp:
        run_ids = extract_run_ids(runs)
        add_plot_metadata(fig, run_ids, analysis_timestamp)

    # Leave room at bottom for metadata. tight_layout recomputes the subplot
    # margins itself, so no separate subplots_adjust call is needed.
    fig.tight_layout(rect=[0, 0.08, 1, 1])

    if save_path:
        save_path.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(save_path, dpi=150)
        print(f"Saved plot to {save_path}")

    if show:
        plt.show()
    # Close this specific figure rather than whichever one is current.
    plt.close(fig)


def plot_per_step_regret_comparison(
    runs: list[dict],
    group_by: str = "objective",
    window: int = 100,
    title: str | None = None,
    save_path: Path | None = None,
    show: bool = True,
    analysis_timestamp: str | None = None,
) -> None:
    """
    Plot smoothed per-step regret for multiple runs, grouped by a config key.

    Each run's per-round regrets ("regrets", falling back to "avg_regret") are
    smoothed with a moving average, truncated to the shortest run in the group
    and averaged. The mean is drawn as a line with a shaded band of +/- 1.96
    standard errors around it.

    Args:
        runs: List of run records.
        group_by: Config key to group/color by. Runs lacking the key are
            grouped under "unknown".
        window: Smoothing window size for moving average (at least 1). Runs
            with fewer points than the window are left unsmoothed.
        title: Optional plot title.
        save_path: If provided, save plot to this path (parent directories are
            created as needed).
        show: Whether to display plot interactively.
        analysis_timestamp: Timestamp for this analysis session; when given,
            run IDs and the timestamp are written at the bottom of the figure.

    Raises:
        ValueError: If ``window`` is smaller than 1.
    """
    if not runs:
        print("No runs to plot.")
        return

    # A zero-length kernel makes np.convolve fail with an unhelpful message,
    # so reject it up front.
    if window < 1:
        raise ValueError(f"window must be >= 1, got {window}")

    # Group the runs before creating any figure, so that a malformed record
    # raises without leaving an open figure behind.
    groups: dict = {}
    for run in runs:
        groups.setdefault(run["config"].get(group_by, "unknown"), []).append(run)

    # Group keys can be of mixed types (e.g. integer n_arms next to the
    # "unknown" placeholder), which cannot be compared directly. Keep the
    # natural order when it exists and fall back to ordering by string form.
    try:
        ordered_groups = sorted(groups.items(), key=lambda item: item[0])
    except TypeError:
        ordered_groups = sorted(groups.items(), key=lambda item: str(item[0]))

    colors = plt.cm.tab10(np.linspace(0, 1, len(groups)))

    def moving_average(arr, w):
        # Too short to fill a single window: return the data unsmoothed.
        if len(arr) < w:
            return arr
        # mode="valid" keeps only fully overlapping windows (len(arr) - w + 1 points).
        return np.convolve(arr, np.ones(w) / w, mode="valid")

    # Slightly taller figure to accommodate metadata text at bottom
    fig, ax = plt.subplots(figsize=(10, 6.8))
    plotted_any = False

    for (group_name, group_runs), color in zip(ordered_groups, colors):
        # Collect the smoothed regret curve of every run in this group
        curves = []
        for run in group_runs:
            regrets = run["results"].get("regrets") or run["results"].get("avg_regret")
            if regrets is not None:
                curves.append(moving_average(np.array(regrets), window))

        if not curves:
            continue

        # Align lengths (in case of different n_rounds)
        min_len = min(len(curve) for curve in curves)
        stacked = np.array([curve[:min_len] for curve in curves])

        mean_regret = np.mean(stacked, axis=0)
        # NOTE(review): np.std defaults to the population standard deviation
        # (ddof=0); consider ddof=1 if a sample-based standard error is intended.
        std_regret = np.std(stacked, axis=0)
        n_runs = len(stacked)
        stderr = std_regret / np.sqrt(n_runs) if n_runs > 1 else std_regret

        x = np.arange(len(mean_regret))
        ax.plot(x, mean_regret, label=f"{group_by}={group_name} (n={n_runs})", color=color)
        ax.fill_between(
            x,
            mean_regret - 1.96 * stderr,
            mean_regret + 1.96 * stderr,
            alpha=0.2,
            color=color,
        )
        plotted_any = True

    ax.set_xlabel("Round")
    ax.set_ylabel(f"Per-Step Regret (smoothed, window={window})")
    ax.set_title(title or f"Per-Step Regret by {group_by}")
    if plotted_any:
        # A legend without any labelled artist only produces a warning.
        ax.legend()
    ax.grid(True, alpha=0.3)

    if analysis_timestamp:
        run_ids = extract_run_ids(runs)
        add_plot_metadata(fig, run_ids, analysis_timestamp)

    # Leave room at bottom for metadata. tight_layout recomputes the subplot
    # margins itself, so no separate subplots_adjust call is needed.
    fig.tight_layout(rect=[0, 0.08, 1, 1])

    if save_path:
        save_path.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(save_path, dpi=150)
        print(f"Saved plot to {save_path}")

    if show:
        plt.show()
    # Close this specific figure rather than whichever one is current.
    plt.close(fig)


# Characters with special meaning in LaTeX text mode and their escaped forms.
_LATEX_ESCAPES = str.maketrans({
    "\\": r"\textbackslash{}",
    "&": r"\&",
    "%": r"\%",
    "$": r"\$",
    "#": r"\#",
    "_": r"\_",
    "{": r"\{",
    "}": r"\}",
    "~": r"\textasciitilde{}",
    "^": r"\textasciicircum{}",
})


def _latex_escape(value) -> str:
    """Return ``str(value)`` with LaTeX special characters escaped for text mode."""
    # translate() works in a single pass, so inserted backslashes are not re-escaped.
    return str(value).translate(_LATEX_ESCAPES)


def generate_latex_table(runs: list[dict], group_by: str = "objective") -> str:
    """
    Generate a LaTeX table summarizing runs grouped by a config key.
    Useful for ICML paper.

    One row is emitted per group, showing the mean and (population) standard
    deviation of the final cumulative regret over the group's runs
    (``avg_final_cumulative_regret`` if present, otherwise
    ``final_cumulative_regret``), or "N/A" when no run reports one. The Arms
    and Alloc columns are taken from the first run of each group only.

    Text cells and the column header are escaped, so config keys or values
    containing characters such as ``_`` or ``&`` produce valid LaTeX. The
    table uses booktabs rules (``\\toprule`` etc.).

    Args:
        runs: List of run records.
        group_by: Config key to group rows by. Runs lacking the key are
            grouped under "unknown".

    Returns:
        The complete ``table`` environment as a single string.
    """
    groups: dict = {}
    for run in runs:
        groups.setdefault(run["config"].get(group_by, "unknown"), []).append(run)

    # Group keys can be of mixed types (e.g. integers next to the "unknown"
    # placeholder); keep the natural order when it exists and fall back to
    # ordering by string form otherwise.
    try:
        ordered_names = sorted(groups)
    except TypeError:
        ordered_names = sorted(groups, key=str)

    lines = [
        r"\begin{table}[h]",
        r"\centering",
        r"\begin{tabular}{lccc}",
        r"\toprule",
        f"{_latex_escape(group_by.capitalize())} & Arms & Alloc & Final Cumulative Regret \\\\",
        r"\midrule",
    ]

    for group_name in ordered_names:
        group_runs = groups[group_name]

        final_regrets = []
        for run in group_runs:
            res = run["results"]
            fr = res.get("avg_final_cumulative_regret", res.get("final_cumulative_regret"))
            if fr is not None:
                final_regrets.append(fr)

        if final_regrets:
            mean_fr = np.mean(final_regrets)
            std_fr = np.std(final_regrets)
            regret_str = f"${mean_fr:.2f} \\pm {std_fr:.2f}$"
        else:
            regret_str = "N/A"

        # Representative config values: only the first run of the group is consulted.
        cfg = group_runs[0]["config"]
        n_arms = _latex_escape(cfg.get("n_arms", "N/A"))
        num_alloc = _latex_escape(cfg.get("num_alloc", "N/A"))

        lines.append(
            f"{_latex_escape(group_name)} & {n_arms} & {num_alloc} & {regret_str} \\\\"
        )

    # Keep the established caption for the default grouping; describe the
    # actual grouping key otherwise.
    if group_by == "objective":
        caption = "Cumulative regret comparison across objectives."
    else:
        caption = f"Cumulative regret comparison grouped by {_latex_escape(group_by)}."

    lines.extend([
        r"\bottomrule",
        r"\end{tabular}",
        "\\caption{" + caption + "}",
        r"\label{tab:regret-comparison}",
        r"\end{table}",
    ])

    return "\n".join(lines)


def save_latex_table(table: str, save_path: Path) -> None:
    """
    Save a LaTeX table to a file and report the location on stdout.

    Missing parent directories are created, and an existing file is
    overwritten. The file is written as UTF-8 so the result does not depend
    on the platform's default encoding.

    Args:
        table: LaTeX source to write.
        save_path: Destination file (conventionally ending in ".tex").
    """
    save_path.parent.mkdir(parents=True, exist_ok=True)
    with open(save_path, "w", encoding="utf-8") as f:
        f.write(table)
    print(f"Saved LaTeX table to {save_path}")


def main(argv: list[str] | None = None) -> None:
    """
    Command-line entry point: load, filter and analyze logged runs.

    Depending on the flags, prints a summary table, renders the cumulative
    and/or per-step regret plots, and generates a LaTeX table. The summary is
    also printed when no plot or table flag is given. File outputs go to a
    timestamped session directory unless --no-save is passed; the directory
    is only created when something is actually going to be written.

    Args:
        argv: Argument list to parse; defaults to the process arguments
            (``sys.argv[1:]``) when None.
    """
    parser = argparse.ArgumentParser(description="Analyze experiment runs")
    parser.add_argument("--objective", type=str, help="Filter by objective (wpm, kolm, gini)")
    parser.add_argument("--n-arms", type=int, help="Filter by number of arms")
    parser.add_argument("--group-by", type=str, default="objective", help="Group plots by this config key")
    parser.add_argument("--summary", action="store_true", help="Print summary table")
    parser.add_argument("--plot-cumulative", action="store_true", help="Plot cumulative regret")
    parser.add_argument("--plot-perstep", action="store_true", help="Plot per-step regret")
    parser.add_argument("--latex", action="store_true", help="Generate LaTeX table")
    # Saving is the default; --no-save disables it.
    parser.add_argument("--no-save", action="store_true", help="Don't save outputs to files")
    parser.add_argument("--no-show", action="store_true", help="Don't display plots (just save)")

    args = parser.parse_args(argv)

    # Load runs; distinguish "nothing logged yet" from "filters matched nothing"
    all_runs = load_all_runs()
    if not all_runs:
        print(f"No runs found in {RUNS_DIR}/")
        print("Run some experiments first with ucb_expt.py")
        return

    runs = filter_runs(all_runs, objective=args.objective, n_arms=args.n_arms)
    if not runs:
        print(
            f"No runs in {RUNS_DIR}/ match the requested filters "
            f"(objective={args.objective}, n_arms={args.n_arms})"
        )
        return

    print(f"Found {len(runs)} runs")

    # Timestamp identifying this analysis session (output directory name and
    # plot annotation)
    timestamp = generate_analysis_id()
    show_plots = not args.no_show

    # Only the plots and the LaTeX table write files; the summary is stdout-only.
    produces_files = args.plot_cumulative or args.plot_perstep or args.latex

    # Create the session directory once, and only if something will be saved,
    # so that summary-only invocations do not leave empty directories behind.
    output_dir = None
    if produces_files and not args.no_save:
        output_dir = get_analysis_dir(timestamp)
        print(f"Saving outputs to {output_dir}/")

    if args.summary or not produces_files:
        summarize_runs(runs)

    # Naming inputs shared by every output of this session
    naming = {
        "group_by": args.group_by,
        "objective_filter": args.objective,
        "n_arms_filter": args.n_arms,
    }

    if args.plot_cumulative:
        filename = build_descriptive_filename(plot_type="cumulative_regret", **naming)
        plot_cumulative_regret_comparison(
            runs,
            group_by=args.group_by,
            save_path=output_dir / filename if output_dir is not None else None,
            show=show_plots,
            analysis_timestamp=timestamp,
        )

    if args.plot_perstep:
        filename = build_descriptive_filename(plot_type="perstep_regret", **naming)
        plot_per_step_regret_comparison(
            runs,
            group_by=args.group_by,
            save_path=output_dir / filename if output_dir is not None else None,
            show=show_plots,
            analysis_timestamp=timestamp,
        )

    if args.latex:
        table = generate_latex_table(runs, group_by=args.group_by)
        print("\nLaTeX Table:\n")
        print(table)

        # Also save to file if saving is enabled
        if output_dir is not None:
            png_name = build_descriptive_filename(plot_type="table", **naming)
            # Swap only the trailing extension; a ".png" occurring inside a
            # filter value must stay untouched.
            tex_name = png_name.removesuffix(".png") + ".tex"
            save_latex_table(table, output_dir / tex_name)


# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
