"""
Custom VERL training entry point with per-puzzle reward metric logging.

Subclasses VERL's TaskRunner to monkey-patch compute_data_metrics inside
the Ray actor process (where it actually runs). This adds:

  train/{acc,partial_score,format_xmlcount}/{mean,max,min}
  train/{puzzle}/count
  train/{puzzle}/{acc,partial_score,format_xmlcount,score}/mean
  grpo/per_prompt_reward_var/mean
  grpo/zero_var_fraction
  grpo/{puzzle}/zero_var_fraction

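The reward decomposition is read from batch.non_tensor_batch, which VERL
populates from dict-returning reward functions via
NaiveRewardManager.reward_extra_info. The per-puzzle "score" is the
per-sequence sum of batch.batch["token_level_scores"]. Puzzle names are
resolved by substring-matching a list of known puzzles against data_source.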

Usage:
    Replace `python3 -m verl.trainer.main_ppo` with
    `python3 src/verl_helpers/train_main.py` in training scripts.
    All Hydra arguments work identically.
"""

import os

import hydra
import ray

import verl.trainer.config as _verl_config_module
from verl.trainer.main_ppo import TaskRunner, run_ppo
from verl.utils.device import auto_set_device


# ---- The actual metrics patch (applied inside Ray actor) ----

def _apply_metrics_patch():
    """Monkey-patch compute_data_metrics to add per-puzzle reward decomposition.

    Must be called inside the Ray actor process (TaskRunner.run), NOT in the driver.
    Patches both the module attribute and the local import in ray_trainer.

    The wrapper first calls VERL's original ``compute_data_metrics`` and then adds:

    * ``train/{key}/{mean,max,min}`` for each of ``acc``, ``format_xmlcount`` and
      ``partial_score`` present in ``batch.non_tensor_batch``;
    * ``train/{puzzle}/count``, ``train/{puzzle}/score/mean`` (per-sequence sum of
      ``batch.batch["token_level_scores"]``) and ``train/{puzzle}/{key}/mean``
      when ``data_source`` is present;
    * ``grpo/per_prompt_reward_var/mean``, ``grpo/zero_var_fraction`` and
      ``grpo/{puzzle}/zero_var_fraction`` when both ``uid`` and ``acc`` are present.

    All extra metrics are best-effort: a field that cannot be interpreted is
    skipped rather than failing the training step (conversion failures are
    reported through ``warnings``). Calling this function again in the same
    process is a no-op, so the wrapper is never wrapped twice.
    """
    import functools
    import warnings
    from collections import defaultdict

    import numpy as np

    import verl.trainer.ppo.metric_utils as mu
    import verl.trainer.ppo.ray_trainer as rt

    _original = mu.compute_data_metrics
    if getattr(_original, "_per_puzzle_metrics_patched", False):
        # Already patched in this process: wrapping again would compute every
        # extra metric twice. Just make sure ray_trainer sees the same function.
        rt.compute_data_metrics = _original
        return

    _KNOWN_PUZZLES = (
        "bridges", "galaxies", "pattern", "undead", "loopy",
        "cryptarithm", "sudoku", "nonogram", "graph",
        "towers", "slant", "solo", "filling", "tents", "singles", "lightup",
    )
    _REWARD_KEYS = ("acc", "format_xmlcount", "partial_score")

    def _resolve_puzzle(ds):
        """Return the first known puzzle name contained in ``ds`` (case-insensitive), else ``str(ds)``."""
        ds_lower = str(ds).lower()
        for p in _KNOWN_PUZZLES:
            if p in ds_lower:
                return p
        return str(ds)

    def _reward_arrays(ntb):
        """Convert each present reward key of ``ntb`` to a float array, once per batch."""
        arrays = {}
        for key in _REWARD_KEYS:
            if key not in ntb:
                continue
            try:
                arrays[key] = np.array(ntb[key], dtype=float)
            except (ValueError, TypeError):
                # Under Python's default warning filters each distinct message is
                # shown only once, so this does not repeat on every step.
                warnings.warn(
                    f"per-puzzle metrics: non_tensor_batch[{key!r}] is not numeric; skipping it"
                )
        return arrays

    def _add_batch_wide(metrics, arrays):
        """Add train/{key}/{mean,max,min} over the whole batch for each converted key."""
        for key, arr in arrays.items():
            if arr.size == 0:
                continue
            metrics[f"train/{key}/mean"] = float(np.mean(arr))
            metrics[f"train/{key}/max"] = float(np.max(arr))
            metrics[f"train/{key}/min"] = float(np.min(arr))

    def _add_per_puzzle(metrics, batch, puzzle_names, arrays):
        """Add train/{puzzle}/count, score/mean and {key}/mean for each puzzle."""
        puzzle_indices = defaultdict(list)
        for i, pname in enumerate(puzzle_names):
            puzzle_indices[pname].append(i)

        # One summed score per sequence, copied to host memory once for all puzzles.
        # .float() lets reduced-precision dtypes (e.g. bfloat16) convert to NumPy.
        sequence_score = batch.batch["token_level_scores"].sum(-1)
        try:
            scores = sequence_score.detach().float().cpu().numpy()
        except (RuntimeError, TypeError):
            scores = None

        for pname, indices in puzzle_indices.items():
            idx = np.asarray(indices)
            metrics[f"train/{pname}/count"] = len(idx)
            if scores is not None:
                try:
                    metrics[f"train/{pname}/score/mean"] = float(np.mean(scores[idx]))
                except IndexError:
                    pass  # data_source longer than the score tensor; skip the score
            for key, arr in arrays.items():
                try:
                    metrics[f"train/{pname}/{key}/mean"] = float(np.mean(arr[idx]))
                except IndexError:
                    pass  # length mismatch between data_source and this reward key

    def _add_grpo_signal(metrics, uids, acc, puzzle_names):
        """Add per-prompt ``acc`` variance and zero-variance fractions, overall and per puzzle.

        Samples are grouped by ``uid``; only groups with at least two samples count.
        Groups without a matching ``puzzle_names`` entry are reported as "unknown".
        """
        uid_to_rewards = defaultdict(list)
        uid_to_puzzle = {}
        for i, (uid, reward) in enumerate(zip(uids, acc)):
            uid_to_rewards[uid].append(reward)
            if i < len(puzzle_names):
                uid_to_puzzle[uid] = puzzle_names[i]

        prompt_vars = []
        zero_var_count = 0
        puzzle_zero_var = defaultdict(lambda: {"total": 0, "zero": 0})
        for uid, rewards in uid_to_rewards.items():
            if len(rewards) < 2:
                continue
            group = np.asarray(rewards)
            prompt_vars.append(np.var(group))
            stats = puzzle_zero_var[uid_to_puzzle.get(uid, "unknown")]
            stats["total"] += 1
            # Test max == min rather than var == 0.0: the variance of identical
            # floats can round to a tiny positive number (e.g. three 0.1s).
            if group.max() == group.min():
                zero_var_count += 1
                stats["zero"] += 1

        if not prompt_vars:
            return
        metrics["grpo/per_prompt_reward_var/mean"] = float(np.mean(prompt_vars))
        metrics["grpo/zero_var_fraction"] = float(zero_var_count / len(prompt_vars))
        for pname, stats in puzzle_zero_var.items():
            metrics[f"grpo/{pname}/zero_var_fraction"] = float(stats["zero"] / stats["total"])

    @functools.wraps(_original)
    def _patched_compute_data_metrics(batch, *args, **kwargs):
        # Forward every argument (e.g. use_critic) untouched so the wrapper does
        # not depend on the exact signature of VERL's function.
        metrics = _original(batch, *args, **kwargs)

        ntb = batch.non_tensor_batch
        arrays = _reward_arrays(ntb)
        _add_batch_wide(metrics, arrays)

        puzzle_names = []
        if "data_source" in ntb:
            # Resolve once; reused by the GRPO section below.
            puzzle_names = [_resolve_puzzle(ds) for ds in ntb["data_source"]]
            _add_per_puzzle(metrics, batch, puzzle_names, arrays)

        if "uid" in ntb and "acc" in arrays:
            try:
                _add_grpo_signal(metrics, ntb["uid"], arrays["acc"], puzzle_names)
            except (ValueError, TypeError):
                warnings.warn("per-puzzle metrics: could not compute GRPO signal metrics; skipping")

        return metrics

    # Tag the wrapper so a repeated call can detect it (see the guard above).
    _patched_compute_data_metrics._per_puzzle_metrics_patched = True

    # Patch both the module attribute and ray_trainer's local import
    mu.compute_data_metrics = _patched_compute_data_metrics
    rt.compute_data_metrics = _patched_compute_data_metrics


# ---- Custom TaskRunner ----

class MetricsTaskRunner(TaskRunner):
    """TaskRunner that adds per-puzzle reward metric logging.

    The metrics patch is installed at the start of ``run`` so it takes effect in
    the process that executes ``run`` (the Ray actor when this class is wrapped
    with ``ray.remote``), not in the driver.
    """

    def run(self, config):
        """Install the metrics patch, then delegate to ``TaskRunner.run``.

        The parent's return value is forwarded so the override stays transparent.
        """
        _apply_metrics_patch()
        return super().run(config)


# ---- Hydra entry point ----

# Directory of the verl.trainer.config package, where Hydra looks up "ppo_trainer".
# Hydra treats a relative config_path as relative to this script's directory,
# so force an absolute path regardless of how the package was imported.
_verl_config_path = os.path.abspath(os.path.dirname(_verl_config_module.__file__))


@hydra.main(config_path=_verl_config_path, config_name="ppo_trainer", version_base=None)
def main(config):
    """Hydra entry point: apply ``auto_set_device`` and run PPO with MetricsTaskRunner.

    The runner is wrapped as a Ray remote class with one CPU, so its ``run``
    method (and therefore the metrics patch) executes inside the Ray actor.
    """
    auto_set_device(config)
    remote_runner_cls = ray.remote(num_cpus=1)(MetricsTaskRunner)
    run_ppo(config, task_runner_class=remote_runner_cls)


if __name__ == "__main__":
    # Hydra parses command-line overrides and supplies ``config``.
    main()
