"""
VERL training entry point with top-k SFT novelty bonus for GSPO.

Extends MetricsTaskRunner (from train_main.py) to inject a novelty bonus
into token_level_rewards before compute_advantage runs. The bonus rewards
correct rollouts whose reasoning is surprising to the SFT reference model,
measured by the mean neg-logprob of the top-k most surprising tokens.

The bonus is z-scored within each prompt group (uid) among correct responses
only, so it measures "novel relative to this prompt's correct paths."

Usage:
    Replace `python3 src/verl_helpers/train_main.py` with
    `python3 src/verl_helpers/train_main_novelty.py` in training scripts.
    All Hydra arguments work identically.
"""

import os

import hydra
import ray

import verl.trainer.config as _verl_config_module
from verl.trainer.main_ppo import run_ppo
from verl.utils.device import auto_set_device

# Import MetricsTaskRunner from train_main (gets per-puzzle metrics for free)
from train_main import MetricsTaskRunner, _apply_metrics_patch


# ---- Novelty bonus patch ----

def _apply_novelty_patch(
    alpha=0.1,
    topk=100,
    z_clip=2.0,
    use_sum=False,
    alpha_decay=False,
    total_steps=312,
    decay_schedule="cosine",
    resume_from_step=0,
):
    """Monkey-patch compute_advantage to inject top-k novelty bonus.

    Must be called inside the Ray actor process (TaskRunner.run).

    For each response, computes the (mean or sum) neg-logprob of the top-k
    most surprising tokens under the SFT reference model. Within each prompt
    group, z-scores these among correct responses and adds current_alpha * z
    to token_level_rewards at the last valid token.

    Args:
        alpha: Max bonus weight (at z=z_clip). 0.1 gives ~20% base-reward at ±2σ.
        topk: Number of most surprising tokens to average. 0 = mean over all.
        z_clip: Absolute clamp on z-score before weighting. Default 2.0.
        use_sum: If True AND topk<=0, use sum of neg-logprob instead of mean
            (disables 1/T length normalization). Verbosity can inflate the
            raw signal, so pair with tight z_clip.
        alpha_decay: If True, scale alpha by a schedule of the current step /
            total_steps.
        total_steps: Denominator for the decay schedule. Set to roughly
            (epochs * train_size / batch_size).
        decay_schedule: "linear" or "cosine". Cosine decays slower early.
    """
    import math
    import verl.trainer.ppo.ray_trainer as rt

    _orig_compute_advantage = rt.compute_advantage
    # Initialize from resume_from_step so cosine α-decay schedule aligns with
    # VERL's actual global_step across resumes (otherwise counter restarts at 0
    # after resume and the schedule is misaligned by however many steps we'd
    # already trained).
    _step_counter = [resume_from_step]

    def _current_alpha():
        if not alpha_decay:
            return alpha
        frac = min(1.0, _step_counter[0] / max(total_steps, 1))
        if decay_schedule == "linear":
            return alpha * (1.0 - frac)
        # cosine: 1 at frac=0, 0 at frac=1, smooth
        return alpha * 0.5 * (1.0 + math.cos(math.pi * frac))

    def _novelty_compute_advantage(data, adv_estimator, **kwargs):
        import torch
        import numpy as np
        from collections import defaultdict

        _step_counter[0] += 1
        cur_alpha = _current_alpha()

        # Detach at source: novelty bonus is reward shaping, treated as a scalar
        # constant w.r.t. the current policy update. Prevents grad from flowing
        # back through ref model forward pass (ref model is frozen anyway, but
        # detaching saves autograd graph memory and is defensive against future
        # VERL changes to how ref_log_prob is stored).
        ref_lp = data.batch["ref_log_prob"].detach()  # (B, seq_len)
        mask = data.batch["response_mask"]            # (B, seq_len)
        rewards = data.batch["token_level_rewards"]   # (B, seq_len)

        # --- Novelty score per response ---
        # neg-logprob: higher = more surprising to SFT
        neg_lp = -ref_lp
        resp_lens = mask.sum(dim=-1)  # (B,)

        if topk <= 0:
            # Mean NLL mode: average neg-logprob over all response tokens
            # If use_sum: raw sum (no 1/T normalization). Verbosity inflates signal.
            if use_sum:
                topk_scores = (neg_lp * mask).sum(dim=-1)
            else:
                topk_scores = (neg_lp * mask).sum(dim=-1) / resp_lens.clamp(min=1)
        else:
            # Top-k mode: average neg-logprob of k most surprising tokens
            neg_lp_masked = neg_lp.clone()
            neg_lp_masked[mask == 0] = float('-inf')
            k = min(topk, int(resp_lens.min().item()))
            if k < 10:
                topk_scores = (neg_lp * mask).sum(dim=-1) / resp_lens.clamp(min=1)
            else:
                topk_vals = torch.topk(neg_lp_masked, k=k, dim=-1).values  # (B, k)
                topk_scores = topk_vals.mean(dim=-1)  # (B,)

        # --- Correctness + grouping ---
        acc = np.array(
            data.non_tensor_batch.get("acc", [0.0] * topk_scores.shape[0]),
            dtype=float,
        )
        uids = data.non_tensor_batch["uid"]

        uid_to_indices = defaultdict(list)
        for i, uid in enumerate(uids):
            uid_to_indices[uid].append(i)

        # --- Z-score novelty bonus for correct-only within each prompt group ---
        bonus = torch.zeros_like(topk_scores)
        n_bonus_applied = 0
        n_prompts_with_bonus = 0

        for uid, indices in uid_to_indices.items():
            correct_idx = [i for i in indices if acc[i] >= 1.0]
            if len(correct_idx) < 2:
                continue  # need ≥2 correct to compute within-group z-score

            scores = topk_scores[correct_idx]
            mu, sigma = scores.mean(), scores.std()
            if sigma < 1e-6:
                continue  # all identical — no signal

            n_prompts_with_bonus += 1
            for i in correct_idx:
                z = (topk_scores[i] - mu) / sigma
                z = z.clamp(-z_clip, z_clip)  # configurable z-clip
                bonus[i] = z
                n_bonus_applied += 1

        # --- Inject bonus at last valid token of each response ---
        for i in range(bonus.shape[0]):
            if bonus[i].abs() > 1e-8:
                valid_len = int(resp_lens[i].item())
                if valid_len > 0:
                    rewards[i, valid_len - 1] += cur_alpha * bonus[i]

        data.batch["token_level_rewards"] = rewards

        # --- Store for WandB metrics ---
        # All non_tensor_batch entries must have length == batch_size
        data.non_tensor_batch["_novelty_bonus"] = bonus.detach().cpu().numpy()
        data.non_tensor_batch["_novelty_topk_score"] = topk_scores.detach().cpu().numpy()
        data.non_tensor_batch["_novelty_current_alpha"] = np.full(
            bonus.shape[0], cur_alpha, dtype=np.float32
        )

        return _orig_compute_advantage(data, adv_estimator, **kwargs)

    rt.compute_advantage = _novelty_compute_advantage


def _apply_novelty_metrics_patch():
    """Extend compute_data_metrics to log novelty bonus statistics.

    Must be called AFTER _apply_metrics_patch() so we wrap the already-patched
    version (which includes per-puzzle metrics).
    """
    import numpy as np

    import verl.trainer.ppo.metric_utils as mu
    import verl.trainer.ppo.ray_trainer as rt

    _patched = mu.compute_data_metrics  # already patched by _apply_metrics_patch

    def _novelty_metrics(batch, use_critic=True):
        metrics = _patched(batch, use_critic=use_critic)

        ntb = batch.non_tensor_batch

        # Novelty metrics (populated by _novelty_compute_advantage)
        if "_novelty_topk_score" in ntb:
            topk_scores = np.array(ntb["_novelty_topk_score"], dtype=float)
            bonus = np.array(ntb["_novelty_bonus"], dtype=float)

            metrics["novelty/mean_topk_score"] = float(np.mean(topk_scores))

            # Correct-only top-k score
            if "acc" in ntb:
                acc = np.array(ntb["acc"], dtype=float)
                correct_mask = acc >= 1.0
                if correct_mask.any():
                    metrics["novelty/mean_topk_score_correct"] = float(
                        np.mean(topk_scores[correct_mask])
                    )

            # Bonus statistics
            active = np.abs(bonus) > 1e-8
            metrics["novelty/mean_bonus"] = float(np.mean(np.abs(bonus[active]))) if active.any() else 0.0
            metrics["novelty/max_bonus"] = float(np.max(np.abs(bonus))) if len(bonus) > 0 else 0.0

            # Fraction of prompts where bonus was applied (compute from per-response bonus)
            if "uid" in ntb:
                from collections import defaultdict as _dd
                uid_has_bonus = _dd(bool)
                for i, uid in enumerate(ntb["uid"]):
                    if np.abs(bonus[i]) > 1e-8:
                        uid_has_bonus[uid] = True
                n_total_prompts = len(set(ntb["uid"]))
                n_with_bonus = sum(uid_has_bonus.values())
                metrics["novelty/frac_prompts_with_bonus"] = float(
                    n_with_bonus / max(n_total_prompts, 1)
                )
                metrics["novelty/n_prompts_with_bonus"] = n_with_bonus

            metrics["novelty/n_bonus_applied"] = int(active.sum())

            # Current alpha (for decay schedule visibility)
            if "_novelty_current_alpha" in ntb:
                cur = np.array(ntb["_novelty_current_alpha"], dtype=float)
                if len(cur) > 0:
                    metrics["novelty/current_alpha"] = float(cur[0])

            # Within-prompt std of top-k scores (the key signal metric)
            if "uid" in ntb:
                from collections import defaultdict
                uid_scores = defaultdict(list)
                for i, uid in enumerate(ntb["uid"]):
                    uid_scores[uid].append(topk_scores[i])
                within_stds = [
                    np.std(v) for v in uid_scores.values() if len(v) >= 2
                ]
                if within_stds:
                    metrics["novelty/topk_within_prompt_std"] = float(np.mean(within_stds))

        return metrics

    mu.compute_data_metrics = _novelty_metrics
    rt.compute_data_metrics = _novelty_metrics


# ---- EMA self-anchored reference patch ----

def _apply_ema_patch(ema_alpha: float, every_n_steps: int):
    """Monkey-patch compute_ref_log_prob on the worker to do in-place EMA
    updates on the ref model before computing ref log_probs.

    EMA rule: ref.p = ema_alpha * ref.p + (1 - ema_alpha) * actor.p

    Rationale: the fixed SFT reference causes a "novelty treadmill" — as the
    actor drifts, everything becomes surprising to the SFT model, which
    rewards arbitrary drift including verbose rambling. Self-anchoring the
    reference on an EMA of the actor bounds the surprise: only genuinely
    new (vs recent self) patterns score high.

    Must be called BEFORE Ray serializes the worker class for remote actors,
    i.e. before verl.trainer.main_ppo.TaskRunner.run() runs init_workers.

    Args:
        ema_alpha: EMA decay. 0.995 tracks ~200-step lag; closer to 1 is slower.
        every_n_steps: Run EMA update every N calls to compute_ref_log_prob.
    """
    from functools import wraps

    import torch
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

    import verl.workers.fsdp_workers as fw
    from verl.single_controller.base.decorator import MAGIC_ATTR

    # main_ppo uses AsyncActorRolloutRefWorker for fsdp strategy; fall back to
    # ActorRolloutRefWorker if async isn't available.
    BaseCls = (
        fw.AsyncActorRolloutRefWorker
        if hasattr(fw, "AsyncActorRolloutRefWorker")
        else fw.ActorRolloutRefWorker
    )
    _orig_compute_ref_log_prob = BaseCls.compute_ref_log_prob

    @wraps(_orig_compute_ref_log_prob)
    def _ema_wrapped_compute_ref_log_prob(self, data):
        # Per-worker call counter; EMA fires at counter > 1 to skip the initial
        # SFT-vs-SFT state.
        cnt = getattr(self, "_ema_step_counter", 0) + 1
        self._ema_step_counter = cnt
        do_ema = (
            cnt > 1
            and (cnt % every_n_steps == 0)
            and getattr(self, "_is_ref", False)
        )
        if do_ema:
            actor_sub = type(self)._siblings.get("actor_rollout")
            if actor_sub is not None:
                with torch.no_grad():
                    # summon_full_params materializes the true named_parameters
                    # (FSDP's top-level named_parameters exposes only opaque
                    # FlatParameter handles with ~29 entries for 7B; we need
                    # the underlying ~339 leaf params).
                    with FSDP.summon_full_params(
                        actor_sub.actor_module_fsdp, writeback=False, recurse=True
                    ):
                        actor_params = dict(
                            actor_sub.actor_module_fsdp.named_parameters()
                        )
                        with FSDP.summon_full_params(
                            self.ref_module_fsdp, writeback=True, recurse=True
                        ):
                            matched = 0
                            diff_sq = 0.0
                            for name, ref_p in self.ref_module_fsdp.named_parameters():
                                ap = actor_params.get(name)
                                if ap is None or ap.shape != ref_p.shape:
                                    continue
                                ap_on_dev = ap.data.to(ref_p.device, non_blocking=True)
                                diff_sq += (
                                    (ref_p.data.float() - ap_on_dev.float())
                                    .pow(2)
                                    .sum()
                                    .item()
                                )
                                ref_p.data.mul_(ema_alpha).add_(
                                    ap_on_dev, alpha=(1.0 - ema_alpha)
                                )
                                matched += 1
                if getattr(self, "rank", 0) == 0:
                    print(
                        f"[EMA] step={cnt} matched={matched} "
                        f"pre_diff_l2={diff_sq ** 0.5:.4f} alpha={ema_alpha}",
                        flush=True,
                    )
        return _orig_compute_ref_log_prob(self, data)

    # Preserve Ray dispatch metadata (@register sets MAGIC_ATTR).
    if hasattr(_orig_compute_ref_log_prob, MAGIC_ATTR):
        setattr(
            _ema_wrapped_compute_ref_log_prob,
            MAGIC_ATTR,
            getattr(_orig_compute_ref_log_prob, MAGIC_ATTR),
        )

    # Build a DYNAMIC subclass via type() so cloudpickle serializes it by
    # value (including our method overrides). If we only set methods on the
    # existing class, Ray pickles the class by module-reference and the worker
    # process re-imports the original (unpatched) version.
    #
    # The subclass carries two things we need at the worker:
    #   * compute_ref_log_prob override: runs the EMA update before the
    #     original ref log-prob call
    #   * __init__ override: registers each sub-worker in a class-level
    #     _siblings dict so the ref sub-worker can look up the actor sub-worker
    #     at EMA time (we can't use VERL's fused_worker_dict because VERL 0.7.0
    #     still uses the deprecated WorkerDict path which doesn't inject it)
    _orig_base_init = BaseCls.__init__

    def _ema_init(self, *args, **kwargs):
        _orig_base_init(self, *args, **kwargs)
        # Self-register in the class-level siblings dict. Both actor-rollout
        # and ref sub-workers are instances of the same EMAWorkerCls in the
        # same Python process (FusedWorker/WorkerDict colocation), so they
        # share this class attribute. The ref sub-worker's compute_ref_log_prob
        # wrapper uses it to find the actor sub-worker at EMA update time.
        role = getattr(self, "role", None)
        if role is not None:
            type(self)._siblings[role] = self

    EMAWorkerCls = type(
        "EMAActorRolloutRefWorker",
        (BaseCls,),
        {
            "_siblings": {},  # class-level; shared across all instances in the process
            "__init__": _ema_init,
            "compute_ref_log_prob": _ema_wrapped_compute_ref_log_prob,
        },
    )
    # Force cloudpickle to use by-value serialization: set __module__ to a
    # name the worker process cannot import, so cloudpickle can't resolve by
    # reference and falls back to by-value.
    EMAWorkerCls.__module__ = "__ema_dynamic__"
    EMAWorkerCls.__qualname__ = "EMAActorRolloutRefWorker"

    # Swap the module attribute so main_ppo's import picks up our subclass.
    if hasattr(fw, "AsyncActorRolloutRefWorker"):
        fw.AsyncActorRolloutRefWorker = EMAWorkerCls
    else:
        fw.ActorRolloutRefWorker = EMAWorkerCls

    print(
        f"[EMA-PATCH] applied alpha={ema_alpha} every_n_steps={every_n_steps} "
        f"base={BaseCls.__name__}",
        flush=True,
    )


# ---- Custom TaskRunner ----

class NoveltyTaskRunner(MetricsTaskRunner):
    """TaskRunner with per-puzzle metrics + top-k SFT novelty bonus + EMA ref.

    All novelty settings are read from ``NOVELTY_*`` environment variables
    inside ``run`` so sweeps need no code or Hydra changes:

        NOVELTY_ALPHA (0.1), NOVELTY_TOPK (100), NOVELTY_Z_CLIP (2.0),
        NOVELTY_USE_SUM (0), NOVELTY_ALPHA_DECAY (0),
        NOVELTY_TOTAL_STEPS (312), NOVELTY_DECAY_SCHEDULE (cosine),
        NOVELTY_EMA_ALPHA (1.0 = EMA off), NOVELTY_EMA_EVERY_N_STEPS (1)
    """

    def run(self, config):
        """Apply the metric/novelty/EMA patches, then run the base TaskRunner.

        Args:
            config: The Hydra config passed through to ``TaskRunner.run``.
                ``config.trainer`` is consulted to work out which step
                training resumes from.
        """
        # Read novelty hyperparameters from environment (allows sweep without code changes)
        alpha = float(os.environ.get("NOVELTY_ALPHA", "0.1"))
        topk = int(os.environ.get("NOVELTY_TOPK", "100"))
        z_clip = float(os.environ.get("NOVELTY_Z_CLIP", "2.0"))
        use_sum = os.environ.get("NOVELTY_USE_SUM", "0") == "1"
        alpha_decay = os.environ.get("NOVELTY_ALPHA_DECAY", "0") == "1"
        total_steps = int(os.environ.get("NOVELTY_TOTAL_STEPS", "312"))
        decay_schedule = os.environ.get("NOVELTY_DECAY_SCHEDULE", "cosine")
        ema_alpha = float(os.environ.get("NOVELTY_EMA_ALPHA", "1.0"))
        ema_every_n_steps = int(os.environ.get("NOVELTY_EMA_EVERY_N_STEPS", "1"))

        # Work out the step training will resume from so the α-decay schedule
        # stays aligned across resumes. This follows trainer.resume_mode the
        # way the VERL trainer is expected to when loading a checkpoint:
        #   * "disable":     always start from scratch, even if a tracker file
        #                    from an earlier run is still in default_local_dir
        #   * "resume_path": the step is the suffix of the ".../global_step_N"
        #                    folder given in trainer.resume_from_path
        #   * otherwise ("auto"): read latest_checkpointed_iteration.txt from
        #                    trainer.default_local_dir
        # Detection is best-effort: any failure is reported and the schedule
        # simply starts at step 0.
        resume_from_step = 0
        try:
            trainer_cfg = config.trainer
            resume_mode = getattr(trainer_cfg, "resume_mode", None) or "auto"
            if resume_mode == "disable":
                print(
                    "[NoveltyTaskRunner] resume_mode=disable, "
                    "alpha schedule starts at step 0"
                )
            elif resume_mode == "resume_path":
                resume_path = str(trainer_cfg.resume_from_path)
                resume_from_step = int(
                    resume_path.rstrip("/").split("global_step_")[-1]
                )
                print(
                    f"[NoveltyTaskRunner] Detected resume from step "
                    f"{resume_from_step} (via {resume_path})"
                )
            else:
                ckpt_dir = trainer_cfg.default_local_dir
                iter_file = os.path.join(ckpt_dir, "latest_checkpointed_iteration.txt")
                if os.path.exists(iter_file):
                    with open(iter_file) as f:
                        resume_from_step = int(f.read().strip())
                    print(
                        f"[NoveltyTaskRunner] Detected resume from step "
                        f"{resume_from_step} (via {iter_file})"
                    )
        except Exception as e:
            print(f"[NoveltyTaskRunner] Resume-step detection failed, starting at 0: {e}")

        print(
            f"[NoveltyTaskRunner] alpha={alpha}, topk={topk}, z_clip={z_clip}, "
            f"use_sum={use_sum}, alpha_decay={alpha_decay}, "
            f"total_steps={total_steps}, decay_schedule={decay_schedule}, "
            f"resume_from_step={resume_from_step}, "
            f"ema_alpha={ema_alpha}, ema_every_n_steps={ema_every_n_steps}"
        )

        # Order matters: metrics patch first, then novelty patches, then EMA patch
        _apply_metrics_patch()
        _apply_novelty_patch(
            alpha=alpha,
            topk=topk,
            z_clip=z_clip,
            use_sum=use_sum,
            alpha_decay=alpha_decay,
            total_steps=total_steps,
            decay_schedule=decay_schedule,
            resume_from_step=resume_from_step,
        )
        _apply_novelty_metrics_patch()
        # EMA patch is off when ema_alpha == 1.0 (ref never updates → original behavior)
        if ema_alpha < 1.0:
            _apply_ema_patch(ema_alpha=ema_alpha, every_n_steps=ema_every_n_steps)
        # Skip MetricsTaskRunner.run's _apply_metrics_patch (already done above)
        # Call TaskRunner.run directly
        from verl.trainer.main_ppo import TaskRunner
        TaskRunner.run(self, config)


# ---- Hydra entry point ----

# Directory that contains the verl.trainer.config package; passed to
# hydra.main below as config_path so the "ppo_trainer" config is looked up
# there.
_verl_config_path = os.path.dirname(_verl_config_module.__file__)


@hydra.main(config_path=_verl_config_path, config_name="ppo_trainer", version_base=None)
def main(config):
    """Hydra entry point: launch PPO training with NoveltyTaskRunner.

    Args:
        config: The composed "ppo_trainer" Hydra config.
    """
    auto_set_device(config)
    # Wrap the runner as a Ray remote class reserving one CPU, then hand it
    # to run_ppo in place of the default task runner.
    remote_runner_cls = ray.remote(num_cpus=1)(NoveltyTaskRunner)
    run_ppo(config, task_runner_class=remote_runner_cls)


# Start training only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
