"""Generate a markdown report from exploration analysis results."""
from __future__ import annotations

import numbers
from pathlib import Path
from typing import Optional

import pandas as pd


def _fmt4(value: object) -> str:
    """Format a real number with 4 decimal places, or return ``"N/A"``.

    Any ``numbers.Real`` (Python int/float, NumPy scalars) is accepted; it is
    converted to ``float`` first so every real type supports the ``.4f`` spec.
    ``bool``, ``None`` and non-numeric placeholders yield ``"N/A"`` instead of
    raising ``ValueError``/``TypeError`` from the format spec.
    """
    if isinstance(value, numbers.Real) and not isinstance(value, bool):
        return f"{float(value):.4f}"
    return "N/A"


def _regression_lines(regression_results: Optional[list[dict]]) -> list[str]:
    """Render the body of the "Regression Results" section as markdown lines.

    Each result dict may carry ``predictors`` (list) or ``predictor`` (single
    name), plus optional ``outcome``, ``r_squared``, ``n``, and either a
    ``coefficients`` mapping or ``slope``/``p_value``. Missing or non-numeric
    statistics are rendered as ``N/A`` rather than raising.
    """
    if not regression_results:
        return ["*No regression results available.*"]

    out: list[str] = []
    for reg in regression_results:
        predictors = reg.get("predictors", [reg.get("predictor", "?")])
        # A bare string would otherwise be joined character by character.
        if isinstance(predictors, str):
            predictors = [predictors]
        out.append(f"### {' + '.join(map(str, predictors))} → {reg.get('outcome', '?')}")
        out.append(f"- R² = {_fmt4(reg.get('r_squared'))}")
        out.append(f"- N = {reg.get('n', '?')}")
        if "coefficients" in reg:
            for name, coef in reg["coefficients"].items():
                out.append(f"  - {name}: β = {_fmt4(coef)}")
        elif "slope" in reg:
            out.append(f"  - slope = {_fmt4(reg['slope'])}, p = {_fmt4(reg.get('p_value'))}")
        out.append("")
    return out


def _plot_lines(output_dir: Path) -> list[str]:
    """Render the "Plots" section body from ``output_dir/plots/*.png`` (sorted by name)."""
    plots_dir = output_dir / "plots"
    pngs = sorted(plots_dir.glob("*.png")) if plots_dir.is_dir() else []
    if not pngs:
        # Keep the section non-empty, consistent with the other sections' placeholders.
        return ["*No plots found.*", ""]

    out: list[str] = []
    for png in pngs:
        out.append(f"### {png.stem.replace('_', ' ').title()}")
        # Links are relative so the report renders when report.md sits next to plots/.
        out.append(f"![{png.stem}](plots/{png.name})")
        out.append("")
    return out


def generate_report(
    output_dir: Path,
    checkpoint_table: pd.DataFrame,
    correlation_table: Optional[pd.DataFrame] = None,
    regression_results: Optional[list[dict]] = None,
    sanity_checks: Optional[dict] = None,
    config_name: str = "exploration_analysis",
) -> str:
    """Generate a markdown report and write to output_dir/report.md.

    Args:
        output_dir: Directory that receives ``report.md``. A ``str`` path is
            accepted, and the directory is created if missing. PNG files in
            ``output_dir/plots`` are embedded in the "Plots" section.
        checkpoint_table: Per-checkpoint summary. Columns whose names contain
            one of a few key substrings are shown; all columns are shown if
            none match.
        correlation_table: Optional correlation table, rendered as-is.
        regression_results: Optional list of regression result dicts; missing
            or non-numeric statistics are shown as ``N/A``.
        sanity_checks: Optional mapping of check name to a dict with ``ok``
            (truthy means PASS) and ``message``.
        config_name: Title suffix for the report heading.

    Returns:
        The report text, exactly as written to disk (UTF-8).

    Note:
        Non-empty tables are rendered with ``DataFrame.to_markdown``, which
        requires the optional ``tabulate`` package.
    """
    output_dir = Path(output_dir)
    lines: list[str] = []
    lines.append(f"# Exploration Analysis Report: {config_name}")
    lines.append("")
    lines.append("## 1. Overview")
    lines.append("")
    lines.append("This report analyzes whether RL exploration metrics on puzzle traces")
    lines.append("correlate with cross-domain math capability gains (OlymMATH Hard pass@32).")
    lines.append("")

    # Metric definitions
    lines.append("## 2. Metric Definitions")
    lines.append("")
    lines.append("### Primitive Metrics")
    lines.append("- `{PRIM}_per_1k`: Count of {PRIM} episodes per 1000 tokens (length-normalized)")
    lines.append("- `primitive_entropy`: Shannon entropy over primitive label distribution")
    lines.append("- `primitive_bigram_entropy`: Entropy over primitive transition bigrams")
    lines.append("- `n_unique_primitives`: Number of distinct primitive types used")
    lines.append("")
    lines.append("### Diversity Metrics")
    lines.append("- `num_clusters`: Number of distinct reasoning path clusters per prompt")
    lines.append("- `cluster_entropy`: Shannon entropy of cluster size distribution")
    lines.append("- `effective_num_paths`: Inverse Simpson index (1/Σp²)")
    lines.append("- `top_cluster_mass`: Fraction of traces in the largest cluster")
    lines.append("")
    lines.append("### Novelty Metrics")
    lines.append("- `novel_cluster_mass_rl`: Fraction of RL traces in novel clusters (SFT mass < τ)")
    lines.append("- `successful_novel_cluster_mass_rl`: Same, restricted to correct RL traces")
    lines.append("")

    # Summary table
    lines.append("## 3. Checkpoint Summary")
    lines.append("")
    if not checkpoint_table.empty:
        # Select key columns for display (substring match on column names).
        key_parts = ("checkpoint", "sft_base", "pass", "gain",
                     "primitive_entropy", "effective_num", "novel_cluster")
        display_cols = [c for c in checkpoint_table.columns
                        if any(k in c for k in key_parts)]
        table = checkpoint_table[display_cols] if display_cols else checkpoint_table
        lines.append(table.to_markdown(index=False))
    else:
        lines.append("*No checkpoint data available.*")
    lines.append("")

    # Correlations
    lines.append("## 4. Correlations with Math Gain")
    lines.append("")
    if correlation_table is not None and not correlation_table.empty:
        lines.append("### Spearman Rank Correlations")
        lines.append("")
        lines.append(correlation_table.to_markdown(index=False))
    else:
        lines.append("*Insufficient data for correlation analysis.*")
    lines.append("")

    # Regression
    lines.append("## 5. Regression Results")
    lines.append("")
    lines.extend(_regression_lines(regression_results))
    lines.append("")

    # Sanity checks
    lines.append("## 6. Sanity Checks")
    lines.append("")
    if sanity_checks:
        for check_name, result in sanity_checks.items():
            status = "PASS" if result.get("ok", False) else "WARN"
            lines.append(f"- **{check_name}**: {status} — {result.get('message', '')}")
    else:
        lines.append("*No sanity checks run.*")
    lines.append("")

    # Plots
    lines.append("## 7. Plots")
    lines.append("")
    lines.extend(_plot_lines(output_dir))

    # Key findings placeholder
    lines.append("## 8. Key Findings")
    lines.append("")
    lines.append("*To be filled in after reviewing the data.*")
    lines.append("")

    report_text = "\n".join(lines)

    # Write. The report always contains non-ASCII characters (², Σ, τ, →, —),
    # so the encoding must not depend on the platform's locale default.
    output_dir.mkdir(parents=True, exist_ok=True)
    report_path = output_dir / "report.md"
    report_path.write_text(report_text, encoding="utf-8")
    print(f"Report written to: {report_path}")

    return report_text
