import os
from typing import Dict, List, Optional, Tuple, Union

import pandas as pd
import matplotlib.pyplot as plt
# Install dependencies: pip install pandas matplotlib wandb
# If you encounter SSL errors, you can try: pip install wandb --no-cache-dir
from wandb.apis.public import Api


def fetch_run_history(api: Api, run_path: str, keys: Optional[List[str]] = None) -> pd.DataFrame:
    """
    Download the logged history of one W&B run and return it as a DataFrame.

    Args:
        api: Client object used to look the run up via ``api.run(run_path)``.
        run_path: Identifier of the run, passed unchanged to ``api.run``.
        keys: Optional list of history keys forwarded to ``scan_history``.

    Returns:
        One row per history record. An empty DataFrame is returned when the
        run has no records or when anything goes wrong while fetching; in both
        cases a message is printed instead of raising.
    """
    try:
        # Materialise the iterator inside the try block: errors may surface
        # lazily while the rows are being streamed, not only on lookup.
        history_rows = [dict(entry) for entry in api.run(run_path).scan_history(keys=keys)]
        if history_rows:
            return pd.DataFrame.from_records(history_rows)
        print(f"Warning: No data returned for run '{run_path}'.")
    except Exception as e:
        print(f"Error fetching W&B run '{run_path}': {e}")
    # Shared fallback for "no data" and "fetch failed".
    return pd.DataFrame()


def resolve_step_column(df: pd.DataFrame) -> Optional[str]:
    """
    Pick the column of ``df`` that holds the training step.

    Well-known names are tried first, in priority order. If none is present,
    the first column whose lower-cased name ends with ``"step"`` is used.
    Returns ``None`` when no column qualifies.
    """
    available = df.columns
    well_known = ("trainer/global_step", "global_step", "step", "steps", "_step")

    exact_match = next((name for name in well_known if name in available), None)
    if exact_match is not None:
        return exact_match

    # Fallback: any column that merely looks like a step counter.
    return next(
        (name for name in available if str(name).lower().endswith("step")),
        None,
    )


def collect_metric_columns(
    dfs: List[pd.DataFrame],
    target_cols: Optional[List[str]] = None,
) -> List[str]:
    """
    Return the requested metric columns that carry numeric data in any frame.

    Args:
        dfs: DataFrames to inspect.
        target_cols: Metric names to look for. When omitted, the script's
            default set of four metrics is used. Duplicates are ignored.

    Returns:
        The alphabetically sorted subset of ``target_cols`` for which at least
        one DataFrame has that column with at least one value that can be
        converted to a number. Columns outside ``target_cols`` are never
        returned, even if they are numeric.
    """
    if target_cols is None:
        target_cols = [
            "actor/entropy",
            "avg_score/16384",
            "response_length/clip_ratio",
            "response_length/mean",
        ]

    def has_numeric_values(df: pd.DataFrame, col: str) -> bool:
        # Non-numeric entries are coerced to NaN, so a column of pure strings
        # or nulls does not count as plottable.
        return col in df.columns and bool(
            pd.to_numeric(df[col], errors="coerce").notna().any()
        )

    return [
        col
        for col in sorted(set(target_cols))
        if any(has_numeric_values(df, col) for df in dfs)
    ]


def slice_and_remap_steps(
    df: pd.DataFrame,
    step_col: str,
    start_inclusive: int,
    end_exclusive: int,
    step_offset: int,
    experiment_id: int,
) -> pd.DataFrame:
    """
    Cut the half-open step window ``[start_inclusive, end_exclusive)`` out of
    ``df`` and shift it onto a shared x-axis.

    Args:
        df: Run history. It is not modified; a copy is processed.
        step_col: Name of the column holding the step counter.
        start_inclusive: First step to keep.
        end_exclusive: First step that is NOT kept.
        step_offset: Position on the shared axis that ``start_inclusive`` is
            mapped to.
        experiment_id: Value written to the ``experiment_id`` column of every
            returned row.

    Returns:
        The rows inside the window, sorted by step, with one row per step
        (the last one in input order wins) and three added columns:
        ``remapped_step``, ``experiment_id`` and ``original_step``.
        A column-less empty DataFrame is returned when ``df`` is empty or has
        no ``step_col``. When no row falls inside the window, the empty slice
        is returned without the three added columns.
    """
    if df.empty or step_col not in df.columns:
        return pd.DataFrame()

    work = df.copy()
    # Non-numeric and missing steps both become NaN and are dropped together.
    work[step_col] = pd.to_numeric(work[step_col], errors='coerce')
    work = work.dropna(subset=[step_col])

    # The upper bound is exclusive, as the parameter name states. With an
    # inclusive bound the last step of one stage would be remapped onto the
    # same x position as the first step of the following stage.
    in_window = (work[step_col] >= start_inclusive) & (work[step_col] < end_exclusive)
    work = work[in_window]
    if work.empty:
        return work

    # A stable sort keeps equal steps in input order, so keep="last" reliably
    # selects the most recently logged row for a duplicated step.
    work = work.sort_values(step_col, kind="mergesort").drop_duplicates(subset=[step_col], keep="last")

    original_steps = work[step_col].astype(int)
    work["remapped_step"] = original_steps - start_inclusive + step_offset
    work["experiment_id"] = experiment_id
    work["original_step"] = original_steps
    return work


def plot_metrics(
    segments: List[pd.DataFrame],
    metrics: List[str],
    segment_labels: List[str],
    out_path: str,
    stage_boundaries: List[int],
    stage_names: List[str],
    step_col: str,
    baseline_dfs: Optional[List[pd.DataFrame]] = None,
    baseline_label: Union[str, List[str]] = "Baseline",
) -> None:
    """
    Draw one subplot per metric, add a shared legend, and save the figure.

    Args:
        segments: Staged run slices. Each is plotted against its
            ``remapped_step`` column; empty slices and slices without data for
            a metric are skipped.
        metrics: Metric column names, one subplot each (two per row). Nothing
            is drawn or saved when this is empty.
        segment_labels: Legend label for each entry of ``segments``.
        out_path: File the figure is written to.
        stage_boundaries: Cumulative stage end positions on the remapped axis.
            The last value sets the x-axis limit; every value except the first
            gets a dashed vertical line.
        stage_names: Accepted for interface compatibility; not used for
            drawing at the moment.
        step_col: Step column used as the x-axis of the baseline curves.
            Baselines are drawn on their original, un-remapped steps.
        baseline_dfs: Optional baseline histories. A baseline that is empty or
            lacks ``step_col`` or the current metric is skipped for that
            subplot without affecting the others.
        baseline_label: Either one label per baseline, or a single string.
            A single string is used as-is for one baseline and numbered
            ("<label> 1", "<label> 2", ...) for several.
    """
    if not metrics:
        print("No metrics available for plotting.")
        return

    from matplotlib.lines import Line2D

    # Metric that gets markers and the two horizontal reference lines.
    reference_metric = "avg_score/16384"

    baselines = baseline_dfs or []
    if isinstance(baseline_label, str):
        # Indexing a plain string would yield single characters as labels.
        if len(baselines) > 1:
            baseline_names = [f"{baseline_label} {i + 1}" for i in range(len(baselines))]
        else:
            baseline_names = [baseline_label] * len(baselines)
    else:
        baseline_names = [str(name) for name in baseline_label]
        # Pad so that a too-short label list cannot raise IndexError.
        baseline_names += [f"Baseline {i + 1}" for i in range(len(baseline_names), len(baselines))]

    xmax = stage_boundaries[-1] if stage_boundaries else 1
    n_metrics = len(metrics)
    ncols = 2
    nrows = (n_metrics + ncols - 1) // ncols

    fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(16, 6 * nrows), dpi=150)
    axes = axes.flatten()

    color_cycle = plt.get_cmap("tab10").colors

    for idx, metric in enumerate(metrics):
        ax = axes[idx]
        is_sparse_metric = metric == reference_metric

        # Staged experiment segments.
        for i, seg in enumerate(segments):
            if seg.empty or metric not in seg.columns:
                continue
            valid_data = seg.dropna(subset=[metric])
            if valid_data.empty:
                continue
            color = color_cycle[i % len(color_cycle)]
            if is_sparse_metric:
                plot_args = {"marker": "o", "markersize": 4, "linestyle": "-", "linewidth": 3.0}
            else:
                plot_args = {"linewidth": 2.0}
            ax.plot(valid_data["remapped_step"], valid_data[metric], label=segment_labels[i], color=color, **plot_args)

        # Complete baseline curves. Each baseline is checked on its own, so a
        # failed download (empty frame) or a missing column in one of them
        # neither raises nor hides the others.
        for i, baseline_df in enumerate(baselines):
            if baseline_df.empty or metric not in baseline_df.columns or step_col not in baseline_df.columns:
                continue
            valid_baseline_data = baseline_df.dropna(subset=[metric, step_col]).sort_values(by=step_col)
            if valid_baseline_data.empty:
                continue
            ax.plot(
                valid_baseline_data[step_col],  # original steps on the x-axis
                valid_baseline_data[metric],
                label=baseline_names[i],
                color='gray' if i == 0 else 'purple',
                linestyle='-.',
                linewidth=2.0,
                zorder=0,  # keep baselines behind the staged runs
            )

        ax.set_title(metric, fontsize=16, fontweight='bold', pad=10)
        ax.set_xlim(left=0, right=max(1, xmax))
        ax.grid(True, linestyle=":", linewidth=0.7)
        ax.set_xlabel("Remapped Step", fontsize=12)
        ax.set_ylabel("Value", fontsize=12)

        if is_sparse_metric:
            ax.axhline(y=0.5222, color='red', linestyle='--', linewidth=1.5, zorder=0)
            ax.axhline(y=0.5413, color='red', linestyle='-.', linewidth=1.5, zorder=0)

        # The first boundary is the origin and gets no separator line.
        for boundary in stage_boundaries[1:]:
            ax.axvline(boundary, color="grey", linestyle="--", linewidth=1.2)

    # Build one shared legend from every used subplot, so a curve that only
    # has data for a later metric still gets an entry. Each label is listed once.
    handles, labels = [], []
    for ax in axes[:n_metrics]:
        for handle, label in zip(*ax.get_legend_handles_labels()):
            if label not in labels:
                handles.append(handle)
                labels.append(label)

    # Legend entries for the reference lines, only when they were drawn.
    if reference_metric in metrics:
        handles.append(Line2D([0], [0], color='red', linestyle='--', linewidth=1.5))
        labels.append('8k Baseline (Clipped)')
        handles.append(Line2D([0], [0], color='red', linestyle='-.', linewidth=1.5))
        labels.append('16k Baseline (Clipped)')

    # Hide the unused subplot when the metric count is odd.
    for idx in range(n_metrics, len(axes)):
        axes[idx].set_visible(False)

    # Adjust the subplot layout and leave room at the top for the legend.
    plt.subplots_adjust(left=0.08, right=0.95, top=0.90, bottom=0.08, hspace=0.35, wspace=0.25)

    if handles:
        ncol = min(4, len(labels))  # at most four legend entries per row
        fig.legend(handles, labels, loc='upper center', bbox_to_anchor=(0.5, 0.98),
                   ncol=ncol, fontsize=10, frameon=True, fancybox=True, shadow=True,
                   columnspacing=1.5, handletextpad=0.8)

    try:
        fig.savefig(out_path, bbox_inches='tight')
    finally:
        # Release the figure even if saving fails; pyplot otherwise keeps it
        # registered for the lifetime of the process.
        plt.close(fig)
    print(f"Plot saved successfully to: {out_path}")


def main() -> None:
    """
    Download the configured W&B runs, stitch the staged runs onto one shared
    step axis, and save a comparison plot in the current working directory.

    Each entry of ``experiment_stages`` is one stage: a list of
    ``(run_path, start_step, end_step, label, branch_id)`` tuples that are
    drawn side by side. A stage is as long as its longest run
    (``end_step - start_step``), and every following stage is shifted right
    by the accumulated length of the stages before it.
    """
    # --- Configuration Area ---
    experiment_stages = [
        [("astrid_tuning_llm/verl-qwen3-4b-base/wvop58l0", 0, 60, "add1k-round1", 0)],
        [("astrid_tuning_llm/verl-qwen3-4b-base/xw82ja08", 0, 250, "dapo 8k", 0),
         ("astrid_tuning_llm/verl-qwen3-4b-base/1o6qud6h", 0, 250, "dapo 12k", 1),
         ("astrid_tuning_llm/verl-qwen3-4b-base/map1dxdn", 0, 250, "dapo 16k", 2)],
        # [("astrid_tuning_llm/verl-qwen3-4b-base/m0ur9f39", 0, 20, "add1k-round2", 0)],
        # [("astrid_tuning_llm/verl-qwen3-4b-base/85jv43oc", 0, 110, "dapo", 0),
        #  ("astrid_tuning_llm/verl-qwen3-4b-base/qovm84x1", 0, 110, "dapo overlong filter", 1)],
    ]

    # Baseline runs, drawn in full on their original step axis.
    baseline_run_paths = ["astrid_tuning_llm/verl-qwen3-4b-base/3fsbgafr", "astrid_tuning_llm/verl-qwen3-4b-base/dye4avkn"]
    baseline_legend_labels = ["8k Baseline", "16k Baseline"]

    stage_names = ["Stage 1: add1k", "Stage 2"]

    # --- Data Processing ---
    # Flatten the stages into per-run specs and record where each stage ends
    # on the shared axis.
    processed_runs, stage_boundaries, current_offset = [], [0], 0
    for stage_runs in experiment_stages:
        stage_duration = max(end - start for _, start, end, _, _ in stage_runs)
        for run_path, start, end, label, branch_id in stage_runs:
            processed_runs.append({"run_path": run_path, "start": start, "end": end, "label": f"{label}-[{branch_id}]", "offset": current_offset})
        current_offset += stage_duration
        stage_boundaries.append(current_offset)

    # --- Fetch data for both staged runs and the baselines ---
    api = Api()
    staged_dfs = [fetch_run_history(api, run["run_path"]) for run in processed_runs]
    baseline_dfs = [fetch_run_history(api, run_path) for run_path in baseline_run_paths]

    # Use all dataframes to determine the metric columns and the step column.
    all_dfs_for_metadata = staged_dfs + baseline_dfs

    # Take the step column from the first frame that actually has one. A frame
    # without any step-like column resolves to None, which must not be used as
    # the column name; "_step" is the fallback when nothing resolves.
    step_col = "_step"
    for candidate_df in all_dfs_for_metadata:
        if candidate_df.empty:
            continue
        resolved = resolve_step_column(candidate_df)
        if resolved is not None:
            step_col = resolved
            break

    metric_cols = collect_metric_columns(all_dfs_for_metadata)
    print(f"Identified step column: '{step_col}'")
    print(f"Metrics to be plotted: {metric_cols}")

    # Process staged segments
    segments, segment_legend_labels = [], []
    for i, run_spec in enumerate(processed_runs):
        df = staged_dfs[i]
        # Give every run the full set of metric columns so later column
        # lookups on the segment do not fail.
        for col in metric_cols:
            if col not in df.columns:
                df[col] = pd.NA
        seg = slice_and_remap_steps(df, step_col, run_spec["start"], run_spec["end"], run_spec["offset"], i)
        segments.append(seg)
        segment_legend_labels.append(run_spec["label"])

    # --- Plotting ---
    out_dir = os.getcwd()
    out_path = os.path.join(out_dir, "merged_runs_with_baseline_16k.png")
    plot_metrics(
        segments, metric_cols, segment_legend_labels, out_path,
        stage_boundaries, stage_names, step_col=step_col,
        baseline_dfs=baseline_dfs, baseline_label=baseline_legend_labels
    )


# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()