"""
report.py — Generate Markdown report (.pplx.md) for the consistency loss experiment.

Includes:
  - Summary of run configuration (smoke vs full)
  - Per-variant metrics table at final epoch
  - Statistical test (t-test) comparing consistency_loss vs no_consistency_loss
  - At least 10 qualitative examples: epoch-1 vs epoch-final generations
  - Honest reporting: smoke results labeled as such
"""

import os
import math
import datetime
import numpy as np
import pandas as pd
from typing import List, Dict, Optional
from scipy import stats as scipy_stats


# ──────────────────────────────────────────────────────────────────────────────
# 1. Statistical test helpers
# ──────────────────────────────────────────────────────────────────────────────

def welch_t_test(df: pd.DataFrame, metric: str = "val_bleu1") -> dict:
    """
    Welch's t-test comparing consistency_loss vs no_consistency_loss
    on the specified metric over the second half of training.

    Since we only have one measurement per variant per epoch (not a distribution
    over seeds), the epoch trajectory is used as a proxy for variance: every
    epoch strictly greater than ``max_epoch // 2`` is treated as one observation.

    Parameters
    ----------
    df : DataFrame with columns "variant", "epoch" and `metric`
    metric : name of the metric column to compare

    Returns
    -------
    dict that always contains "metric" and "available". When the test cannot be
    run, "available" is False and "note" gives the reason. Otherwise it also
    contains "mean_consistency_loss", "mean_no_consistency_loss",
    "t_statistic", "p_value", "significant_at_0.05" (plain bool),
    "n_epochs_a" and "n_epochs_b".
    """
    result = {"metric": metric, "available": False}

    # Report missing inputs as "unavailable" instead of raising KeyError, so the
    # report can still be rendered with an explanatory note.
    missing = [col for col in ("variant", "epoch", metric) if col not in df.columns]
    if missing:
        result["note"] = f"Missing column(s) for t-test: {', '.join(missing)}"
        return result

    variants = set(df["variant"].unique().tolist())
    if "consistency_loss" not in variants or "no_consistency_loss" not in variants:
        result["note"] = ("Both consistency_loss and no_consistency_loss variants "
                          "are required for the t-test")
        return result

    df_sorted = df.sort_values("epoch")
    cutoff = df_sorted["epoch"].max() // 2  # use second half of training

    def _second_half(variant: str) -> np.ndarray:
        # Non-numeric / missing measurements are dropped: a single NaN would
        # otherwise make the t-statistic and p-value NaN.
        mask = (df_sorted["variant"] == variant) & (df_sorted["epoch"] > cutoff)
        values = pd.to_numeric(df_sorted.loc[mask, metric], errors="coerce")
        return values.dropna().to_numpy(dtype=float)

    a = _second_half("consistency_loss")
    b = _second_half("no_consistency_loss")

    if len(a) < 2 or len(b) < 2:
        result["note"] = ("Insufficient epochs for t-test "
                          "(need at least 2 epochs per variant in second half)")
        return result

    t_stat, p_val = scipy_stats.ttest_ind(a, b, equal_var=False)
    p_val = float(p_val)
    result.update({
        "available": True,
        "mean_consistency_loss": float(np.mean(a)),
        "mean_no_consistency_loss": float(np.mean(b)),
        "t_statistic": float(t_stat),
        "p_value": p_val,
        # Cast to a plain bool (the comparison may yield numpy.bool_).
        "significant_at_0.05": bool(p_val < 0.05),
        "n_epochs_a": len(a),
        "n_epochs_b": len(b),
    })
    if math.isnan(p_val):
        result["note"] = ("t-statistic is undefined "
                          "(e.g. zero variance in both groups)")
    return result


# ──────────────────────────────────────────────────────────────────────────────
# 2. Report builder
# ──────────────────────────────────────────────────────────────────────────────

# Metric columns shown in the report tables, as (DataFrame column, display label).
_METRIC_COLS = [
    ("val_coupling_strength", "Coupling Strength"),
    ("val_bleu1",             "BLEU-1"),
    ("val_rouge_l",           "ROUGE-L"),
    ("val_swap_influence",    "Swap Influence"),
    ("val_claim_accuracy",    "Claim Accuracy"),
    ("val_lm_loss",           "Val LM Loss"),
]

# Order in which variants appear in per-variant tables and sections.
_VARIANT_ORDER = [
    # V1 original
    "consistency_loss", "no_consistency_loss",
    "claim_only_pooling", "random_label_consistency",
    # V2 stronger ablations
    "no_claim_to_claim_attention", "claims_from_explanation_only",
    "surface_bottleneck_consistency", "surface_bottleneck_no_expl_lm",
]


def _format_metric(value) -> str:
    """Format a metric value with 4 decimals, or "N/A" if it is not numeric."""
    try:
        return f"{float(value):.4f}"
    except (TypeError, ValueError):
        return "N/A"


def _header_section(now: str, is_smoke: bool) -> List[str]:
    """Title, generation timestamp and (for smoke runs) the smoke-run banner."""
    lines = [
        "# Consistency Loss Experiment — Results Report",
        "",
        f"**Generated:** {now}",
        "",
    ]
    if is_smoke:
        lines += [
            "> **NOTE: This report is based on a SMOKE RUN.**",
            "> A smoke run uses a smaller dataset, fewer epochs, and a tiny model",
            "> to verify that all components work end-to-end. The full configured",
            "> run uses 3,000 examples, 20 epochs, and a larger model (see §1).",
            "> Results should be interpreted as proof-of-mechanism, not as final",
            "> experimental conclusions. Metric values from the smoke run may not",
            "> reflect the trends expected at full scale.",
            "",
        ]
    return lines


def _config_section(cfg_dict: dict, is_smoke: bool) -> List[str]:
    """§1: target (full) configuration table, plus the actual smoke config if any."""
    full_cfg = {
        "Dataset size (full config)":    "3,000 examples",
        "Validation set (full config)":  "500 examples",
        "Epochs (full config)":          "20",
        "Batch size (full config)":      "32",
        "Learning rate":                 "5e-5",
        "Lambda (consistency weight)":   "1.0",
        "Model (full config)":           "GPT-2-style Transformer, `small` config (~10M params); `gpt2_small` documented for ~117M",
        "Optimizer":                     "AdamW (weight decay=0.01, grad clip=1.0)",
        "Checkpoint interval":           "every 5 epochs",
    }

    lines = ["## 1. Experiment Configuration", ""]
    lines += ["### Full Configuration (target)"]
    lines += ["| Parameter | Value |", "|---|---|"]
    lines += [f"| {k} | {v} |" for k, v in full_cfg.items()]
    lines += [""]

    if is_smoke:
        # Values actually used for this run; "?" marks keys absent from cfg_dict.
        smoke_cfg = {
            "Dataset size (smoke run)":  str(cfg_dict.get("n_examples", "?")),
            "Validation set (smoke run)": str(cfg_dict.get("val_size", "?")),
            "Epochs (smoke run)":        str(cfg_dict.get("n_epochs", "?")),
            "Batch size (smoke run)":    str(cfg_dict.get("batch_size", "?")),
            "Model (smoke run)":         cfg_dict.get("model_size", "?"),
            "Max sequence length":       str(cfg_dict.get("max_seq_len", "?")),
        }
        lines += ["### Smoke Run Configuration (actual)"]
        lines += ["| Parameter | Value |", "|---|---|"]
        lines += [f"| {k} | {v} |" for k, v in smoke_cfg.items()]
        lines += [""]
    return lines


def _architecture_section() -> List[str]:
    """§2: static description of the model architecture."""
    return [
        "## 2. Model Architecture",
        "",
        "The model is a **GPT-2-style causal Transformer** implemented from scratch",
        "in PyTorch with identical causal masking semantics to GPT-2.",
        "",
        "**Key design choices:**",
        "",
        "- **Causal masking**: Standard lower-triangular attention mask.",
        "  Explanation tokens appear *before* claim tokens in the sequence,",
        "  so causal attention already prevents explanation tokens from attending",
        "  to future claim tokens. An explicit additive attention bias further",
        "  enforces this structural constraint.",
        "",
        "- **Sequence format**: `<bos> [code] <sep> [explanation] <claim>time_complexity=X</claim>`",
        "  `<claim>space_complexity=Y</claim> <claim>correctness=Z</claim> <eos>`",
        "",
        "- **LM head**: Tied to token embeddings. Loss computed over full sequence",
        "  (next-token prediction).",
        "",
        "- **Consistency head**: Three linear classifiers (time complexity, space",
        "  complexity, correctness) applied to *mean-pooled hidden states* of",
        "  explanation tokens from the final Transformer layer.",
        "",
        "- **Note on GPT-2 weights**: This implementation is a from-scratch",
        "  GPT-2-compatible architecture. Loading pretrained GPT-2 weights would",
        "  require the `transformers` library. The `gpt2_small` config (768 dim,",
        "  12 heads, 12 layers, ~117M params) is provided but requires GPU with",
        "  ≥8GB VRAM.",
        "",
    ]


def _variants_section() -> List[str]:
    """§3: static description of the V1 / V2 experimental variants."""
    return [
        "## 3. Experimental Variants",
        "",
        "### V1 — Original Ablation Ladder",
        "",
        "| Variant | Description | Ablation axis |",
        "|---|---|---|",
        "| `consistency_loss` | Full mechanism: LM loss + consistency loss on explanation token pooling | Baseline reference |",
        "| `no_consistency_loss` | LM loss only; no gradient through consistency head | Isolates LM-only training |",
        "| `claim_only_pooling` | Negative control: pool *claim* tokens instead of explanation tokens | Tests pooling location |",
        "| `random_label_consistency` | Negative control: consistency loss with shuffled ground-truth labels | Tests label signal |",
        "",
        "### V2 — Stronger Ablation Ladder (strict flow + surface bottleneck)",
        "",
        "| Variant | Description | Ablation axis |",
        "|---|---|---|",
        "| `no_claim_to_claim_attention` | Like `consistency_loss` but claim tokens **cannot attend other claim tokens**; claim queries see code + explanation + self only | Tests cross-claim information flow |",
        "| `claims_from_explanation_only` | **Strict flow bottleneck**: claim tokens can only attend explanation tokens (not code, not BOS/SEP, not other claims). | Tests whether code-to-claim path can be forced through explanation |",
        "| `surface_bottleneck_consistency` | Consistency signal derived from **softmax distributions** (LM logit probs) at explanation positions, not hidden states. Gradients flow through LM outputs. | Tests whether surface-form explanation must encode claim info |",
        "| `surface_bottleneck_no_expl_lm` | Surface bottleneck + **LM loss disabled on mismatched explanation tokens**. Only code and claim positions contribute to LM loss. | Most extreme: removes incentive to fit mismatched explanation text |",
        "",
        "**Key predictions for V2:**",
        "",
        "- `no_claim_to_claim_attention`: similar coupling to V1 `consistency_loss` but tests cross-claim span flow.",
        "- `claims_from_explanation_only`: forces code-to-explanation-to-claim information path; explanation hidden states should develop stronger semantic structure.",
        "- `surface_bottleneck_consistency`: tests whether consistency pressure propagates to explanation *token choices* (surface form).",
        "- `surface_bottleneck_no_expl_lm`: most extreme — sole pressure on explanation logits is the surface bottleneck consistency signal.",
        "",
    ]


def _final_metrics_section(df: pd.DataFrame, cfg_dict: dict, is_smoke: bool) -> List[str]:
    """§4: one table row per variant with its metrics at the last logged epoch."""
    lines = ["## 4. Final-Epoch Validation Metrics", ""]

    max_epoch = df["epoch"].max()
    final = df[df["epoch"] == max_epoch]

    if is_smoke:
        lines.append(f"*Metrics at epoch {max_epoch} (smoke run final epoch; "
                     f"corresponds to epoch {max_epoch}/{cfg_dict.get('n_epochs','?')} of smoke config).*")
    else:
        lines.append(f"*Metrics at epoch {max_epoch} (final epoch).*")
    lines.append("")

    header_cols = ["Variant"] + [label for _, label in _METRIC_COLS]
    lines.append("| " + " | ".join(header_cols) + " |")
    lines.append("|" + "|".join(["---"] * len(header_cols)) + "|")

    for variant in _VARIANT_ORDER:
        rows = final[final["variant"] == variant]
        if rows.empty:
            continue  # variant was not part of this run
        first = rows.iloc[0]
        # Series.get returns None for a missing column, which formats as "N/A".
        cells = [variant.replace("_", " ").title()]
        cells += [_format_metric(first.get(col)) for col, _ in _METRIC_COLS]
        lines.append("| " + " | ".join(cells) + " |")

    lines += [""]
    return lines


def _stat_test_section(df: pd.DataFrame, cfg_dict: dict, is_smoke: bool) -> List[str]:
    """§5: Welch's t-test on BLEU-1 between the two main variants."""
    lines = ["## 5. Statistical Test: consistency_loss vs no_consistency_loss", ""]

    t_result = welch_t_test(df, metric="val_bleu1")

    if not t_result.get("available"):
        lines += [
            "Welch's t-test could not be computed (insufficient epoch data).",
            f"Note: {t_result.get('note', 'Unknown reason.')}",
            "",
            "**Interpretation**: With more epochs (full 20-epoch run), the test would",
            "compare BLEU-1 scores in the second half of training across variants.",
            "",
        ]
        return lines

    max_epoch = df["epoch"].max()
    sig = "**significant** (p < 0.05)" if t_result["significant_at_0.05"] else "not significant (p ≥ 0.05)"
    direction = ("higher" if t_result["mean_consistency_loss"] > t_result["mean_no_consistency_loss"]
                 else "lower")

    lines += [
        "**Metric**: BLEU-1 (explanation correctness proxy)",
        f"**Test**: Welch's independent-samples t-test (epochs > {max_epoch // 2})",
        "",
        "| | consistency\\_loss | no\\_consistency\\_loss |",
        "|---|---|---|",
        f"| Mean BLEU-1 | {t_result['mean_consistency_loss']:.4f} | {t_result['mean_no_consistency_loss']:.4f} |",
        f"| n epochs | {t_result['n_epochs_a']} | {t_result['n_epochs_b']} |",
        "",
        f"**t = {t_result['t_statistic']:.3f}, p = {t_result['p_value']:.4f}** — {sig}",
        "",
        f"`consistency_loss` has **{direction}** mean BLEU-1 than `no_consistency_loss`.",
        "",
    ]

    if is_smoke:
        lines += [
            "> **Smoke-run caveat**: With only {} epochs and {} training examples,".format(
                cfg_dict.get('n_epochs', '?'), cfg_dict.get('n_examples', '?')),
            "> statistical power is very low. The t-test result should not be",
            "> interpreted as a strong conclusion. The full 20-epoch run on 3,000",
            "> examples would provide more reliable evidence.",
            "",
        ]
    return lines


def _trajectory_section(df: pd.DataFrame) -> List[str]:
    """§6: per-variant first-epoch vs last-epoch comparison of each metric."""
    lines = ["## 6. Metric Trajectory Summary", ""]

    for variant in _VARIANT_ORDER:
        sub = df[df["variant"] == variant].sort_values("epoch")
        if sub.empty:
            continue
        ep_first = sub.iloc[0]
        ep_last = sub.iloc[-1]
        lines += [
            f"### {variant.replace('_', ' ').title()}",
            "",
            # Label the columns with the epochs actually logged for this variant.
            f"| Metric | Epoch {int(ep_first['epoch'])} | Epoch {int(ep_last['epoch'])} | Δ |",
            "|---|---|---|---|",
        ]
        for col, label in _METRIC_COLS[:5]:  # skip lm_loss for brevity
            try:
                start = float(ep_first[col])
                end = float(ep_last[col])
            except (KeyError, TypeError, ValueError):
                lines.append(f"| {label} | N/A | N/A | N/A |")
                continue
            delta = end - start
            sign = "+" if delta >= 0 else ""
            lines.append(f"| {label} | {start:.4f} | {end:.4f} | {sign}{delta:.4f} |")
        lines += [""]
    return lines


def _qualitative_section(qualitative_examples: List[dict], is_smoke: bool) -> List[str]:
    """§7: up to 15 epoch-1 vs final-epoch generation examples."""
    lines = [
        "## 7. Qualitative Examples: Epoch-1 vs Final Epoch Generations",
        "",
        "The following examples are drawn from the `consistency_loss` variant.",
        "They show the model's generated explanation at epoch 1 (essentially random,",
        "as the model just started training) versus the final epoch.",
        "",
    ]

    # The caveat below describes smoke runs only; emitting it for a full run
    # would mislabel the results.
    if is_smoke:
        lines += [
            "> **Smoke-run note**: With a tiny model and few epochs, generations are",
            "> short and may not yet form coherent prose. The progression from epoch 1",
            "> to the final epoch demonstrates that training is occurring and the",
            "> model is adapting, even if fluency is limited.",
            "",
        ]

    for i, ex in enumerate(qualitative_examples[:15], 1):
        code_snippet = ex["code"].strip()[:300]
        lines += [
            f"### Example {i}: `{ex.get('template_name', ex.get('code', '')[:40].strip()[:20])}...`",
            "",
            f"**Ground-truth claims:** time={ex['time_complexity']}, "
            f"space={ex['space_complexity']}, correct={ex['correctness']}",
            "",
            "**Mismatched explanation (training input):**",
            f"> {ex['mismatched_expl'][:200]}",
            "",
            "**True explanation (reference):**",
            f"> {ex['true_expl'][:200]}",
            "",
            "**Code snippet:**",
            "```python",
            code_snippet,
            "```",
            "",
            "**Epoch-1 generation:**",
            f"> `{ex['epoch_1_gen'][:300]}`",
            "",
            "**Final-epoch generation:**",
            f"> `{ex['epoch_final_gen'][:300]}`",
            "",
            "---",
            "",
        ]

    if len(qualitative_examples) < 10:
        lines += [
            f"> Only {len(qualitative_examples)} qualitative examples were collected",
            "> (fewer than 10 examples available in this run configuration).",
            "> The full run with 3,000 examples would provide richer qualitative analysis.",
            "",
        ]
    return lines


def _limitations_section() -> List[str]:
    """§8: static list of known limitations."""
    return [
        "## 8. Limitations and Interpretation",
        "",
        "1. **Smoke run constraints**: The smoke configuration uses a tiny Transformer",
        "   (~0.5–2M params), a reduced dataset, and few epochs. These constraints",
        "   prevent the model from reaching the performance levels expected in the full run.",
        "",
        "2. **Tokenizer**: A simple whitespace tokenizer is used (no BPE/SentencePiece).",
        "   This means token sequences are longer than with a subword tokenizer, and",
        "   the vocabulary may not generalize as well.",
        "",
        "3. **BLEU/ROUGE proxy**: BLEU-1 and ROUGE-L are computed against ground-truth",
        "   explanation templates, not against diverse human references. They measure",
        "   whether the model recovers the training-set language, not open-ended quality.",
        "",
        "4. **Claim emission accuracy**: Measured by string-matching in greedy-decoded",
        "   output. A model could emit the correct claim token by memorizing without",
        "   true generalization.",
        "",
        "5. **Coupling vs causality**: Classifier accuracy on explanation hidden states",
        "   measures *correlation*, not causal coupling. The full experiment with",
        "   multiple random seeds and probing experiments would provide stronger evidence.",
        "",
        "6. **Full 20-epoch run**: The full configuration (3,000 examples, 20 epochs,",
        "   small model) requires approximately 2–4 hours on a modern GPU. The",
        "   `gpt2_small` config (~117M params) would require ≥8GB GPU VRAM and",
        "   significantly more compute.",
        "",
    ]


def _run_instructions_section() -> List[str]:
    """§9: static command-line usage examples."""
    return [
        "## 9. Run Instructions",
        "",
        "See `README.md` for full setup and run instructions.",
        "",
        "```bash",
        "# Smoke run (fast, ~2-5 min on CPU):",
        "python run_experiment.py --smoke",
        "",
        "# Full run (20 epochs, 3000 examples):",
        "python run_experiment.py --full",
        "",
        "# Small model, custom config:",
        "python run_experiment.py --full --model small --epochs 20 --batch 32",
        "",
        "# GPT-2-style config (GPU required):",
        "python run_experiment.py --full --model gpt2_small",
        "```",
        "",
    ]


def build_report(
    df: pd.DataFrame,
    qualitative_examples: List[dict],
    cfg_dict: dict,
    is_smoke: bool,
    output_path: str,
) -> str:
    """
    Builds and writes the Markdown report.

    Parameters
    ----------
    df : DataFrame with all epoch metrics for all variants
        (columns "variant", "epoch" and the ``val_*`` metric columns)
    qualitative_examples : list of dicts from trainer.collect_qualitative_examples
    cfg_dict : dict of effective training config
    is_smoke : whether this was a smoke run
    output_path : where to write the .pplx.md file; missing parent
        directories are created

    Returns
    -------
    The full report text that was written to `output_path`.
    """
    now = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")

    lines: List[str] = []
    lines += _header_section(now, is_smoke)
    lines += _config_section(cfg_dict, is_smoke)
    lines += _architecture_section()
    lines += _variants_section()
    lines += _final_metrics_section(df, cfg_dict, is_smoke)
    lines += _stat_test_section(df, cfg_dict, is_smoke)
    lines += _trajectory_section(df)
    lines += _qualitative_section(qualitative_examples, is_smoke)
    lines += _limitations_section()
    lines += _run_instructions_section()

    # Write report
    content = "\n".join(lines)
    parent = os.path.dirname(output_path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    # The report contains non-ASCII characters (—, ≥, Δ, §); an explicit UTF-8
    # encoding avoids UnicodeEncodeError on platforms with a different default.
    with open(output_path, "w", encoding="utf-8") as f:
        f.write(content)
    print(f"  Report saved: {output_path}")
    return content


if __name__ == "__main__":
    # Self-test: build a report from synthetic data when this file is run directly.
    import tempfile  # local import: only needed by this self-test

    # Three epochs for each of the two variants compared by the t-test.
    rows = [
        {
            "variant": v, "epoch": ep,
            "val_coupling_strength": 0.4 + ep * 0.05,
            "val_bleu1": 0.1 + ep * 0.02,
            "val_rouge_l": 0.12 + ep * 0.02,
            "val_swap_influence": 0.05,
            "val_claim_accuracy": 0.2,
            "val_lm_loss": 4.0 - ep * 0.2,
        }
        for v in ["consistency_loss", "no_consistency_loss"]
        for ep in range(1, 4)
    ]
    df = pd.DataFrame(rows)

    # Twelve identical qualitative examples (more than the 10-example threshold).
    qual = [{
        "idx": i, "code": "def foo(): pass",
        "true_expl": "A trivial function.", "mismatched_expl": "Sorts a list.",
        "epoch_1_gen": "<bos> foo bar", "epoch_final_gen": "A trivial function in O(1) time.",
        "time_complexity": "O(1)", "space_complexity": "O(1)", "correctness": 1,
    } for i in range(12)]

    # Use the platform's temp directory rather than a hard-coded "/tmp",
    # which does not exist on every OS.
    out_path = os.path.join(tempfile.gettempdir(), "test_report.pplx.md")
    build_report(df, qual, {"n_examples": 100, "n_epochs": 3}, is_smoke=True,
                 output_path=out_path)
