# Copyright 2025 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import random
import unittest

import numpy as np
import pytest
import torch

import verl.trainer.ppo.core_algos
from verl.trainer.ppo.core_algos import (
    compute_gae_advantage_return,
    compute_grpo_outcome_advantage,
    compute_grpo_verk_step_reward_global_norm_invcount_nostd_advantage,
    compute_grpo_verk_step_reward_global_norm_nostd_advantage,
    compute_grpo_verk_step_reward_global_norm_reweight_future_only_nostd_advantage,
    compute_grpo_verk_step_reward_step_norm_advantage,
    compute_grpo_verk_step_reward_step_norm_reweight_advantage,
    compute_grpo_verk_step_reward_step_norm_reweight_future_only_advantage,
    compute_grpo_verk_step_reward_step_norm_reweight_rms_advantage,
    compute_grpo_vectorized_outcome_advantage,
    compute_rloo_outcome_advantage,
    compute_rloo_vectorized_outcome_advantage,
    get_adv_estimator_fn,
    register_adv_est,
)


def mock_test_fn():
    """No-op module-level function registered twice by the duplicate-registration test."""
    return None


class TestRegisterAdvEst(unittest.TestCase):
    """Unit tests for ``register_adv_est`` and ``get_adv_estimator_fn``.

    Each test runs against a small throwaway registry bound to
    ``core_algos.ADV_ESTIMATOR_REGISTRY``. The real registry is put back in
    ``tearDown``, so the other tests in the session still see the production estimators.
    """

    def setUp(self):
        """Swap in a fake two-entry registry and remember the real one."""
        # Do NOT clear the real registry. Clearing it mutates the shared dict, and every
        # estimator registered at import time would be lost for later tests.
        self._original_registry = verl.trainer.ppo.core_algos.ADV_ESTIMATOR_REGISTRY
        verl.trainer.ppo.core_algos.ADV_ESTIMATOR_REGISTRY = {
            "gae": lambda x: x * 2,
            "vtrace": lambda x: x + 1,
        }
        self.ADV_ESTIMATOR_REGISTRY = verl.trainer.ppo.core_algos.ADV_ESTIMATOR_REGISTRY

    def tearDown(self) -> None:
        """Rebind the module attribute to the registry captured in setUp."""
        verl.trainer.ppo.core_algos.ADV_ESTIMATOR_REGISTRY = self._original_registry
        return super().tearDown()

    def test_register_new_function(self):
        """Test registering a new function with a string name"""

        @register_adv_est("test_estimator")
        def test_fn():
            pass

        self.assertIn("test_estimator", self.ADV_ESTIMATOR_REGISTRY)
        self.assertEqual(self.ADV_ESTIMATOR_REGISTRY["test_estimator"], test_fn)

    def test_register_with_enum(self):
        """Test registering with an enum value (assuming AdvantageEstimator exists)"""
        from enum import Enum

        class AdvantageEstimator(Enum):
            TEST = "test_enum_estimator"

        @register_adv_est(AdvantageEstimator.TEST)
        def test_fn():
            pass

        self.assertIn("test_enum_estimator", self.ADV_ESTIMATOR_REGISTRY)
        self.assertEqual(self.ADV_ESTIMATOR_REGISTRY["test_enum_estimator"], test_fn)

    def test_duplicate_registration_same_function(self):
        """Test that registering the same function twice doesn't raise an error"""
        register_adv_est("duplicate_test")(mock_test_fn)
        register_adv_est("duplicate_test")(mock_test_fn)

        self.assertEqual(self.ADV_ESTIMATOR_REGISTRY["duplicate_test"], mock_test_fn)

    def test_duplicate_registration_different_function(self):
        """Test that registering different functions with same name raises ValueError"""

        @register_adv_est("conflict_test")
        def test_fn1():
            pass

        with self.assertRaises(ValueError):

            @register_adv_est("conflict_test")
            def test_fn2():
                pass

    def test_decorator_preserves_function(self):
        """Test that the decorator returns the original function"""

        def test_fn():
            return "original"

        decorated = register_adv_est("preserve_test")(test_fn)
        self.assertEqual(decorated(), "original")

    def test_multiple_registrations(self):
        """Test registering multiple different functions"""
        init_adv_count = len(self.ADV_ESTIMATOR_REGISTRY)

        @register_adv_est("estimator1")
        def fn1():
            pass

        @register_adv_est("estimator2")
        def fn2():
            pass

        self.assertEqual(len(self.ADV_ESTIMATOR_REGISTRY), 2 + init_adv_count)
        self.assertEqual(self.ADV_ESTIMATOR_REGISTRY["estimator1"], fn1)
        self.assertEqual(self.ADV_ESTIMATOR_REGISTRY["estimator2"], fn2)

    def test_get_adv_estimator_fn_valid_names(self):
        """Test that valid names return the correct function from registry."""
        # Test GAE
        gae_fn = get_adv_estimator_fn("gae")
        assert gae_fn(5) == 10  # 5 * 2 = 10

        # Test Vtrace
        vtrace_fn = get_adv_estimator_fn("vtrace")
        assert vtrace_fn(5) == 6  # 5 + 1 = 6

    def test_get_adv_estimator_fn_invalid_name(self):
        """Test that invalid names raise ValueError."""
        with pytest.raises(ValueError) as excinfo:
            get_adv_estimator_fn("invalid_name")
        assert "Unknown advantage estimator simply: invalid_name" in str(excinfo.value)

    def test_get_adv_estimator_fn_case_sensitive(self):
        """Test that name lookup is case-sensitive."""
        with pytest.raises(ValueError):
            get_adv_estimator_fn("GAE")  # Different case


def test_multi_turn_compute_gae_advantage_return():
    """Test that multi-turn GAE skips observation tokens.

    ``values1`` and ``values2`` agree on every position where ``response_mask`` is 1.
    They hold different random values at non-response positions 0, 1, 5 and 6. If
    ``compute_gae_advantage_return`` ignores non-response tokens, both runs must give
    bit-identical advantages (whole tensor) and bit-identical returns on response tokens.
    """
    # NOTE(review): gamma, lam and the filler values come from the unseeded stdlib RNG,
    # so a failing run cannot be reproduced from the log alone.
    gamma = random.uniform(0.0, 1.0)
    lam = random.uniform(0.0, 1.0)

    # Non-zero rewards sit only on response positions (2, 3, 4, 7, 8).
    rewards = torch.tensor([[0.0, 0.0, 0.1, 0.1, 0.1, 0.0, 0.0, 0.1, 1.0, 0.0, 0.0]], dtype=torch.float)

    # Random filler at non-response positions 0, 1, 5, 6; fixed values elsewhere.
    values1 = torch.tensor(
        [
            [
                random.uniform(-100.0, 100.0),
                random.random(),
                4.0,
                5.0,
                6.0,
                random.uniform(-100.0, 0),
                random.random(),
                7.0,
                9.0,
                0.0,
                0.0,
            ]
        ],
        dtype=torch.float,
    )

    # Same fixed values as values1; different random filler at positions 0, 1, 5, 6.
    values2 = torch.tensor(
        [
            [
                random.random(),
                random.uniform(-100.0, 100.0),
                4.0,
                5.0,
                6.0,
                random.random(),
                random.uniform(0.0, 100.0),
                7.0,
                9.0,
                0.0,
                0.0,
            ]
        ],
        dtype=torch.float,
    )

    # 1 = response (assistant) token, 0 = observation / padding token.
    response_mask = torch.tensor([[0, 0, 1, 1, 1, 0, 0, 1, 1, 0, 0]], dtype=torch.float)

    adv1, ret1 = compute_gae_advantage_return(rewards, values1, response_mask, gamma, lam)
    adv2, ret2 = compute_gae_advantage_return(rewards, values2, response_mask, gamma, lam)

    # Returns are only compared on response tokens; advantages are compared over the
    # full tensor. torch.equal demands exact (bitwise) equality, not a tolerance.
    ret1 *= response_mask
    ret2 *= response_mask
    assert torch.equal(adv1, adv2), f"{adv1=}, {adv2=}"
    assert torch.equal(ret1, ret2), f"{ret1=}, {ret2=}"
    print(f" [CORRECT] \n\n{adv1=}, \n\n{ret1=}")


def _make_group_index(batch_size: int, num_groups: int) -> np.ndarray:
    """Return a shuffled int64 array of ``batch_size`` group ids.

    Every id in ``range(num_groups)`` appears at least twice. Each of the
    ``batch_size - 2 * num_groups`` surplus slots goes to a group picked with
    ``random.randrange``. Seed the stdlib ``random`` module for reproducibility.
    """
    assert num_groups * 2 <= batch_size, "batch_size must allow >=2 samples per group"
    group_sizes = [2 for _ in range(num_groups)]
    for _ in range(batch_size - 2 * num_groups):
        group_sizes[random.randrange(num_groups)] += 1
    ids = [group_id for group_id, size in enumerate(group_sizes) for _ in range(size)]
    random.shuffle(ids)
    return np.array(ids, dtype=np.int64)


def _rand_mask(batch_size: int, seq_len: int) -> torch.Tensor:
    """Return a random float {0, 1} mask of shape (batch_size, seq_len).

    Any row that came out all zeros gets its last column set to 1, so every row
    has at least one valid token.
    """
    bits = torch.randint(0, 2, (batch_size, seq_len), dtype=torch.int64)
    result = bits.float()
    empty_rows = result.sum(dim=-1) == 0
    if bool(empty_rows.any()):
        result[empty_rows, -1] = 1.0
    return result


def _expected_step_norm_adv(values: torch.Tensor, reached: torch.Tensor, index: np.ndarray, eps: float = 1e-6):
    """Loop-based oracle: z-score ``values`` per (step column, group id).

    For each step column and each group id in ``index``, the rows of that group
    with ``reached`` set at that step are centred on their mean and divided by
    ``std + eps``. Here ``std`` is the unbiased (n - 1) standard deviation, with the
    variance floored at ``eps``. A group with exactly one reached row uses mean 0 and
    std 1, so its raw value passes through. Cells that are not reached stay 0.
    """
    device = values.device
    row_group = torch.tensor(index, dtype=torch.long, device=device)
    group_ids = sorted(set(index.tolist()))
    result = torch.zeros_like(values, dtype=torch.float32)

    for step in range(values.shape[1]):
        for gid in group_ids:
            selected = (row_group == gid) & reached[:, step]
            column = values[selected, step]
            n = column.numel()
            if n == 0:
                continue
            if n == 1:
                center = torch.tensor(0.0, device=device)
                spread = torch.tensor(1.0, device=device)
            else:
                center = column.mean()
                total = column.sum()
                total_sq = (column * column).sum()
                count = torch.tensor(float(n), device=device)
                variance = (total_sq - (total * total) / count) / (count - 1.0).clamp_min(1.0)
                spread = torch.sqrt(torch.clamp(variance, min=eps))
            result[selected, step] = (column - center) / (spread + eps)
    return result


def _expected_allstep_mean_center(values: torch.Tensor, reached: torch.Tensor, index: np.ndarray) -> torch.Tensor:
    """Centre every reached cell on its group's mean over all reached (row, step) cells.

    The group mean is pooled across all steps of all rows sharing an id in ``index``.
    Cells where ``reached`` is False become 0.
    """
    dev = values.device
    row_group = torch.tensor(index, dtype=torch.long, device=dev)
    weight = reached.to(dtype=torch.float32)
    num_groups = int(row_group.max().item()) + 1 if row_group.numel() > 0 else 0

    group_sum = torch.zeros(num_groups, dtype=torch.float32, device=dev)
    group_sum.index_add_(0, row_group, (values * weight).sum(dim=1))
    group_cnt = torch.zeros(num_groups, dtype=torch.float32, device=dev)
    group_cnt.index_add_(0, row_group, weight.sum(dim=1))

    row_mean = (group_sum / group_cnt.clamp_min(1.0))[row_group]
    return (values - row_mean[:, None]) * weight


def _compute_verk_weights(turn_successes: torch.Tensor, index: np.ndarray, eps: float = 1e-8):
    """Reference leave-one-out step weights per group.

    ``p_hat[g, j]`` is the success rate (``turn_successes == 1``) among group ``g``'s
    rows that reached step ``j`` (``!= -1``). It is 0 when no row reached that step.
    With ``q = clamp(1 - p_hat, eps, 1)``, the weight is ``w[g, k] = prod_{j != k} q[g, j]``.

    Returns ``(w, mask)``: both have shape (num_groups, num_steps), and ``mask`` is
    True where at least one row of the group reached the step.
    """
    dev = turn_successes.device
    row_group = torch.tensor(index, dtype=torch.long, device=dev)
    num_groups = int(row_group.max().item()) + 1 if row_group.numel() > 0 else 0
    num_steps = turn_successes.shape[1]

    def _group_totals(flags: torch.Tensor) -> torch.Tensor:
        # Sum a per-row boolean flag into a (num_groups, num_steps) count table.
        table = torch.zeros((num_groups, num_steps), dtype=torch.float32, device=dev)
        return table.index_add_(0, row_group, flags.float())

    reached_cnt = _group_totals(turn_successes != -1)
    success_cnt = _group_totals(turn_successes == 1)
    seen = reached_cnt > 0
    success_rate = torch.zeros_like(success_cnt)
    success_rate[seen] = success_cnt[seen] / reached_cnt[seen]
    fail_prob = (1.0 - success_rate).clamp(min=eps, max=1.0)

    if num_steps == 1:
        return torch.ones((num_groups, 1), dtype=torch.float32, device=dev), seen

    # Exclusive prefix/suffix products: before[k] = prod_{j<k} q_j, after[k] = prod_{j>k} q_j.
    ones_col = torch.ones_like(fail_prob[:, :1])
    before = torch.cat([ones_col, torch.cumprod(fail_prob, dim=1)[:, :-1]], dim=1)
    after = torch.cat([torch.cumprod(fail_prob.flip(1), dim=1).flip(1)[:, 1:], ones_col], dim=1)
    return before * after, seen


def _compute_future_only_weights(turn_successes: torch.Tensor, index: np.ndarray, eps: float = 1e-8):
    """Reference future-only step weights per group.

    ``p_hat[g, j]`` is the success rate (``turn_successes == 1``) among group ``g``'s
    rows that reached step ``j`` (``!= -1``), or 0 if none did. ``q`` is
    ``clamp(1 - p_hat, eps, 1)`` on reached cells and 1 on unreached cells. The weight is
    ``w[g, k] = prod_{j > k} q[g, j]``, so the last step always gets weight 1.

    Returns ``(w_group, step_mask, p_hat)``, each of shape (num_groups, num_steps).
    """
    dev = turn_successes.device
    row_group = torch.tensor(index, dtype=torch.long, device=dev)
    num_groups = int(row_group.max().item()) + 1 if row_group.numel() > 0 else 0
    num_steps = turn_successes.shape[1]

    reached_cnt = torch.zeros((num_groups, num_steps), dtype=torch.float32, device=dev)
    reached_cnt.index_add_(0, row_group, (turn_successes != -1).float())
    success_cnt = torch.zeros((num_groups, num_steps), dtype=torch.float32, device=dev)
    success_cnt.index_add_(0, row_group, (turn_successes == 1).float())

    seen = reached_cnt > 0
    success_rate = torch.zeros_like(success_cnt)
    success_rate[seen] = success_cnt[seen] / reached_cnt[seen]
    fail_prob = torch.ones_like(success_rate)
    fail_prob[seen] = (1.0 - success_rate[seen]).clamp(min=eps, max=1.0)

    if num_steps == 1:
        weights = torch.ones((num_groups, 1), dtype=torch.float32, device=dev)
    else:
        # tail[k] = prod_{j >= k} q_j; shifting left by one drops q_k itself.
        tail = torch.cumprod(fail_prob.flip(1), dim=1).flip(1)
        weights = torch.cat([tail[:, 1:], torch.ones_like(fail_prob[:, :1])], dim=1)
    return weights, seen, success_rate


def _mean_normalize_weights(w: torch.Tensor, mask: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """Divide each row of ``w`` by its mean over the cells where ``mask`` is set.

    A row with an empty mask has mean 0, and the divisor is floored at ``eps``.
    Unmasked cells are scaled too; they just don't contribute to the mean.
    """
    masked_mean = (w * mask).sum(dim=1) / mask.sum(dim=1).clamp_min(1.0)
    return w / masked_mean.clamp_min(eps)[:, None]


def _rms_normalize_weights(w: torch.Tensor, mask: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """Divide each row of ``w`` by its root-mean-square over the cells where ``mask`` is set.

    The divisor is floored at ``eps``. Unmasked cells are scaled but do not
    contribute to the RMS.
    """
    root_mean_sq = ((w * w * mask).sum(dim=1) / mask.sum(dim=1).clamp_min(1.0)).sqrt()
    return w / root_mean_sq.clamp_min(eps)[:, None]


@pytest.mark.parametrize(
    "batch_size,seq_len,num_groups,seed",
    [
        (64, 128, 5, 0),
        (128, 256, 8, 1),
        (512, 512, 10, 2),
    ],
)
def test_rloo_and_vectorized_equivalence(batch_size: int, seq_len: int, num_groups: int, seed: int):
    """The loop-based and vectorized RLOO estimators agree on random grouped batches."""
    for seeder in (torch.manual_seed, random.seed, np.random.seed):
        seeder(seed)

    # Draw order matters for reproducibility: group ids, then mask, then rewards.
    groups = _make_group_index(batch_size, num_groups)
    mask = _rand_mask(batch_size, seq_len)
    rewards = torch.randn(batch_size, seq_len, dtype=torch.float32) * mask
    inputs = dict(token_level_rewards=rewards, response_mask=mask, index=groups)

    adv_loop, ret_loop = compute_rloo_outcome_advantage(**inputs)
    adv_vec, ret_vec = compute_rloo_vectorized_outcome_advantage(**inputs)

    # Concise diagnostics, visible with `pytest -s`.
    adv_gap = (adv_loop - adv_vec).abs().max().item()
    ret_gap = (ret_loop - ret_vec).abs().max().item()
    n_valid = int(mask.sum().item())
    print(
        f"[RLOO] seed={seed} groups={num_groups} shape={adv_loop.shape} "
        f"mask_tokens={n_valid} adv_max_diff={adv_gap:.3e} ret_max_diff={ret_gap:.3e}"
    )

    expected_shape = (batch_size, seq_len)
    assert adv_loop.shape == adv_vec.shape == expected_shape
    assert ret_loop.shape == ret_vec.shape == expected_shape
    assert torch.allclose(adv_loop, adv_vec, rtol=1e-5, atol=1e-6)
    assert torch.allclose(ret_loop, ret_vec, rtol=1e-5, atol=1e-6)


@pytest.mark.parametrize(
    "batch_size,seq_len,num_groups,seed",
    [
        (64, 128, 5, 0),
        (128, 256, 8, 1),
        (512, 512, 10, 2),
    ],
)
def test_grpo_and_vectorized_equivalence(batch_size: int, seq_len: int, num_groups: int, seed: int):
    """The loop-based and vectorized GRPO outcome estimators agree on random grouped batches."""
    for seeder in (torch.manual_seed, random.seed, np.random.seed):
        seeder(seed)

    # Draw order matters for reproducibility: group ids, then mask, then rewards.
    groups = _make_group_index(batch_size, num_groups)
    mask = _rand_mask(batch_size, seq_len)  # every row keeps at least one valid token
    rewards = torch.randn(batch_size, seq_len, dtype=torch.float32) * mask
    inputs = dict(token_level_rewards=rewards, response_mask=mask, index=groups)

    adv_loop, ret_loop = compute_grpo_outcome_advantage(**inputs)
    adv_vec, ret_vec = compute_grpo_vectorized_outcome_advantage(**inputs)

    # Concise diagnostics, visible with `pytest -s`.
    adv_gap = (adv_loop - adv_vec).abs().max().item()
    ret_gap = (ret_loop - ret_vec).abs().max().item()
    n_valid = int(mask.sum().item())
    print(
        f"[GRPO] seed={seed} groups={num_groups} shape={adv_loop.shape} "
        f"mask_tokens={n_valid} adv_max_diff={adv_gap:.3e} ret_max_diff={ret_gap:.3e}"
    )

    expected_shape = (batch_size, seq_len)
    assert adv_loop.shape == adv_vec.shape == expected_shape
    assert ret_loop.shape == ret_vec.shape == expected_shape
    assert torch.allclose(adv_loop, adv_vec, rtol=1e-5, atol=1e-6)
    assert torch.allclose(ret_loop, ret_vec, rtol=1e-5, atol=1e-6)


def test_grpo_verk_step_reward_step_norm_advantage():
    """Step-norm estimator equals the per-step, per-group z-score oracle broadcast onto tokens."""
    n_rows, n_tokens = 6, 6
    groups = np.array([0, 0, 0, 1, 1, 1], dtype=np.int64)
    token_mask = torch.ones(n_rows, n_tokens, dtype=torch.float32)
    # Three assistant turns per row, two tokens per turn.
    token_turn = torch.tensor([[0, 0, 1, 1, 2, 2]] * n_rows, dtype=torch.long)

    step_rewards = torch.tensor(
        [
            [0.0, 2.0, 9.0],
            [1.0, 4.0, 5.0],
            [2.0, 8.0, 9.0],
            [1.0, 0.0, 7.0],
            [3.0, 9.0, 9.0],
            [5.0, 2.0, 9.0],
        ],
        dtype=torch.float32,
    )
    # -1 marks a turn the row never reached.
    outcomes = torch.tensor(
        [
            [1, 0, -1],
            [0, 1, 1],
            [0, -1, -1],
            [1, 0, 1],
            [0, -1, -1],
            [0, 0, -1],
        ],
        dtype=torch.int64,
    )

    adv, ret = compute_grpo_verk_step_reward_step_norm_advantage(
        token_level_rewards=torch.zeros(n_rows, n_tokens, dtype=torch.float32),
        response_mask=token_mask,
        index=groups,
        assistant_turn_ids=token_turn,
        turn_successes=outcomes,
        assistant_turn_rewards=step_rewards,
    )

    per_step = _expected_step_norm_adv(step_rewards, outcomes.ne(-1), groups)
    expected = per_step.gather(1, token_turn) * token_mask

    for produced in (adv, ret):
        assert torch.allclose(produced, expected, rtol=1e-5, atol=1e-6)


def test_grpo_verk_step_reward_global_norm_nostd_advantage():
    """Global-norm (no std) estimator equals the all-step group-mean-centering oracle on tokens."""
    n_rows, n_tokens = 6, 6
    groups = np.array([0, 0, 0, 1, 1, 1], dtype=np.int64)
    token_mask = torch.ones(n_rows, n_tokens, dtype=torch.float32)
    # Three assistant turns per row, two tokens per turn.
    token_turn = torch.tensor([[0, 0, 1, 1, 2, 2]] * n_rows, dtype=torch.long)

    step_rewards = torch.tensor(
        [
            [0.0, 2.0, 9.0],
            [1.0, 4.0, 5.0],
            [2.0, 8.0, 9.0],
            [1.0, 0.0, 7.0],
            [3.0, 9.0, 9.0],
            [5.0, 2.0, 9.0],
        ],
        dtype=torch.float32,
    )
    # -1 marks a turn the row never reached.
    outcomes = torch.tensor(
        [
            [1, 0, -1],
            [0, 1, 1],
            [0, -1, -1],
            [1, 0, 1],
            [0, -1, -1],
            [0, 0, -1],
        ],
        dtype=torch.int64,
    )

    adv, ret = compute_grpo_verk_step_reward_global_norm_nostd_advantage(
        token_level_rewards=torch.zeros(n_rows, n_tokens, dtype=torch.float32),
        response_mask=token_mask,
        index=groups,
        assistant_turn_ids=token_turn,
        turn_successes=outcomes,
        assistant_turn_rewards=step_rewards,
    )

    centered = _expected_allstep_mean_center(step_rewards, outcomes.ne(-1), groups)
    expected = centered.gather(1, token_turn) * token_mask

    for produced in (adv, ret):
        assert torch.allclose(produced, expected, rtol=1e-5, atol=1e-6)


def test_grpo_verk_step_reward_global_norm_invcount_nostd_advantage():
    """Inverse-count estimator equals mean-centred rewards times mean-normalized 1/count weights."""
    batch_size = 4
    seq_len = 4
    index = np.array([0, 0, 1, 1], dtype=np.int64)

    response_mask = torch.ones(batch_size, seq_len, dtype=torch.float32)
    assistant_turn_ids = torch.tensor([[0, 0, 1, 1]] * batch_size, dtype=torch.long)

    assistant_turn_rewards = torch.tensor(
        [
            [1.0, 2.0],
            [3.0, 9.0],
            [4.0, 5.0],
            [6.0, 9.0],
        ],
        dtype=torch.float32,
    )
    turn_successes = torch.tensor(
        [
            [1, 1],
            [0, -1],
            [1, 0],
            [0, -1],
        ],
        dtype=torch.int64,
    )

    adv, ret = compute_grpo_verk_step_reward_global_norm_invcount_nostd_advantage(
        token_level_rewards=torch.zeros(batch_size, seq_len, dtype=torch.float32),
        response_mask=response_mask,
        index=index,
        assistant_turn_ids=assistant_turn_ids,
        turn_successes=turn_successes,
        assistant_turn_rewards=assistant_turn_rewards,
    )

    reached = turn_successes != -1
    scalars = _expected_allstep_mean_center(assistant_turn_rewards, reached, index)

    # Reference weights: each (group, step) cell gets 1 / (number of the group's rows that
    # reached that step). Cells nobody reached get 0 and are left out of the mean normalization.
    index_t = torch.tensor(index, dtype=torch.long)
    G = int(index_t.max().item()) + 1 if index_t.numel() > 0 else 0
    K = turn_successes.shape[1]
    count_reached = torch.zeros((G, K), dtype=torch.float32).index_add_(0, index_t, reached.to(dtype=torch.float32))
    step_mask = count_reached > 0
    w = torch.zeros_like(count_reached)
    w[step_mask] = 1.0 / count_reached[step_mask]
    w_norm = _mean_normalize_weights(w, step_mask)

    expected = (
        torch.gather(scalars, 1, assistant_turn_ids)
        * torch.gather(w_norm[index_t], 1, assistant_turn_ids)
        * response_mask
    )

    assert torch.allclose(adv, expected, rtol=1e-5, atol=1e-6)
    assert torch.allclose(ret, expected, rtol=1e-5, atol=1e-6)


def test_grpo_verk_step_reward_global_norm_reweight_future_only_nostd_advantage():
    """Global-norm future-only reweighted estimator matches centred rewards times normalized weights."""
    n_rows, n_tokens = 3, 4
    groups = np.array([0, 0, 0], dtype=np.int64)
    token_mask = torch.ones(n_rows, n_tokens, dtype=torch.float32)
    # Two assistant turns per row, two tokens per turn.
    token_turn = torch.tensor([[0, 0, 1, 1]] * n_rows, dtype=torch.long)

    step_rewards = torch.tensor(
        [
            [1.0, 2.0],
            [3.0, 4.0],
            [5.0, 6.0],
        ],
        dtype=torch.float32,
    )
    # -1 marks a turn the row never reached.
    outcomes = torch.tensor(
        [
            [1, -1],
            [0, 1],
            [0, 0],
        ],
        dtype=torch.int64,
    )

    adv, ret = compute_grpo_verk_step_reward_global_norm_reweight_future_only_nostd_advantage(
        token_level_rewards=torch.zeros(n_rows, n_tokens, dtype=torch.float32),
        response_mask=token_mask,
        index=groups,
        assistant_turn_ids=token_turn,
        turn_successes=outcomes,
        assistant_turn_rewards=step_rewards,
    )

    centered = _expected_allstep_mean_center(step_rewards, outcomes.ne(-1), groups)
    group_weights, seen, _ = _compute_future_only_weights(outcomes, groups)
    row_weights = _mean_normalize_weights(group_weights, seen)[torch.as_tensor(groups, dtype=torch.long)]
    expected = centered.gather(1, token_turn) * row_weights.gather(1, token_turn) * token_mask

    for produced in (adv, ret):
        assert torch.allclose(produced, expected, rtol=1e-5, atol=1e-6)


def test_grpo_verk_step_reward_step_norm_reweight_advantage():
    """Step-norm reweighted estimator equals z-scores times mean-normalized leave-one-out weights."""
    n_rows, n_tokens = 4, 4
    groups = np.array([0, 0, 1, 1], dtype=np.int64)
    token_mask = torch.ones(n_rows, n_tokens, dtype=torch.float32)
    # Two assistant turns per row, two tokens per turn.
    token_turn = torch.tensor([[0, 0, 1, 1]] * n_rows, dtype=torch.long)

    step_rewards = torch.tensor(
        [
            [1.0, 2.0],
            [3.0, 9.0],
            [4.0, 5.0],
            [6.0, 9.0],
        ],
        dtype=torch.float32,
    )
    # -1 marks a turn the row never reached.
    outcomes = torch.tensor(
        [
            [1, 1],
            [0, -1],
            [1, 0],
            [0, -1],
        ],
        dtype=torch.int64,
    )

    adv, ret = compute_grpo_verk_step_reward_step_norm_reweight_advantage(
        token_level_rewards=torch.zeros(n_rows, n_tokens, dtype=torch.float32),
        response_mask=token_mask,
        index=groups,
        assistant_turn_ids=token_turn,
        turn_successes=outcomes,
        assistant_turn_rewards=step_rewards,
    )

    group_weights, seen = _compute_verk_weights(outcomes, groups)
    row_weights = _mean_normalize_weights(group_weights, seen)[torch.as_tensor(groups, dtype=torch.long)]
    per_step = _expected_step_norm_adv(step_rewards, outcomes.ne(-1), groups)
    expected = per_step.gather(1, token_turn) * row_weights.gather(1, token_turn) * token_mask

    for produced in (adv, ret):
        assert torch.allclose(produced, expected, rtol=1e-5, atol=1e-6)


def test_grpo_verk_step_reward_step_norm_reweight_future_only_advantage():
    """Step-norm future-only estimator matches z-scores times normalized future-only weights."""
    n_rows, n_tokens = 3, 4
    groups = np.array([0, 0, 0], dtype=np.int64)
    token_mask = torch.ones(n_rows, n_tokens, dtype=torch.float32)
    # Two assistant turns per row, two tokens per turn.
    token_turn = torch.tensor([[0, 0, 1, 1]] * n_rows, dtype=torch.long)

    step_rewards = torch.tensor(
        [
            [1.0, 2.0],
            [3.0, 4.0],
            [5.0, 6.0],
        ],
        dtype=torch.float32,
    )
    # -1 marks a turn the row never reached.
    outcomes = torch.tensor(
        [
            [1, -1],
            [0, 1],
            [0, 0],
        ],
        dtype=torch.int64,
    )

    adv, ret = compute_grpo_verk_step_reward_step_norm_reweight_future_only_advantage(
        token_level_rewards=torch.zeros(n_rows, n_tokens, dtype=torch.float32),
        response_mask=token_mask,
        index=groups,
        assistant_turn_ids=token_turn,
        turn_successes=outcomes,
        assistant_turn_rewards=step_rewards,
    )

    # Sanity-check the oracle itself: the last step has weight 1, and step 0 is
    # weighted by the failure probability of the (only) later step.
    group_weights, seen, success_rate = _compute_future_only_weights(outcomes, groups)
    assert seen[0, 1]
    assert torch.allclose(group_weights[0, 1], torch.tensor(1.0), rtol=1e-6, atol=1e-6)
    assert torch.allclose(group_weights[0, 0], 1.0 - success_rate[0, 1], rtol=1e-6, atol=1e-6)

    per_step = _expected_step_norm_adv(step_rewards, outcomes.ne(-1), groups)
    row_weights = _mean_normalize_weights(group_weights, seen)[torch.as_tensor(groups, dtype=torch.long)]
    expected = per_step.gather(1, token_turn) * row_weights.gather(1, token_turn) * token_mask

    for produced in (adv, ret):
        assert torch.allclose(produced, expected, rtol=1e-5, atol=1e-6)


def test_grpo_verk_step_reward_step_norm_reweight_rms_advantage():
    """Step-norm RMS-reweighted estimator equals z-scores times RMS-normalized leave-one-out weights."""
    n_rows, n_tokens = 4, 4
    groups = np.array([0, 0, 1, 1], dtype=np.int64)
    token_mask = torch.ones(n_rows, n_tokens, dtype=torch.float32)
    # Two assistant turns per row, two tokens per turn.
    token_turn = torch.tensor([[0, 0, 1, 1]] * n_rows, dtype=torch.long)

    step_rewards = torch.tensor(
        [
            [1.0, 2.0],
            [3.0, 9.0],
            [4.0, 5.0],
            [6.0, 9.0],
        ],
        dtype=torch.float32,
    )
    # -1 marks a turn the row never reached.
    outcomes = torch.tensor(
        [
            [1, 1],
            [0, -1],
            [1, 0],
            [0, -1],
        ],
        dtype=torch.int64,
    )

    adv, ret = compute_grpo_verk_step_reward_step_norm_reweight_rms_advantage(
        token_level_rewards=torch.zeros(n_rows, n_tokens, dtype=torch.float32),
        response_mask=token_mask,
        index=groups,
        assistant_turn_ids=token_turn,
        turn_successes=outcomes,
        assistant_turn_rewards=step_rewards,
    )

    group_weights, seen = _compute_verk_weights(outcomes, groups)
    row_weights = _rms_normalize_weights(group_weights, seen)[torch.as_tensor(groups, dtype=torch.long)]
    per_step = _expected_step_norm_adv(step_rewards, outcomes.ne(-1), groups)
    expected = per_step.gather(1, token_turn) * row_weights.gather(1, token_turn) * token_mask

    for produced in (adv, ret):
        assert torch.allclose(produced, expected, rtol=1e-5, atol=1e-6)


def test_grpo_verk_step_reward_step_norm_no_reweight():
    """Without reweighting, the step-norm estimator is exactly the z-score oracle on tokens."""
    n_rows, n_tokens = 4, 4
    groups = np.array([0, 0, 1, 1], dtype=np.int64)
    token_mask = torch.ones(n_rows, n_tokens, dtype=torch.float32)
    # Two assistant turns per row, two tokens per turn.
    token_turn = torch.tensor([[0, 0, 1, 1]] * n_rows, dtype=torch.long)

    step_rewards = torch.tensor(
        [
            [1.0, 2.0],
            [3.0, 9.0],
            [4.0, 5.0],
            [6.0, 9.0],
        ],
        dtype=torch.float32,
    )
    # -1 marks a turn the row never reached.
    outcomes = torch.tensor(
        [
            [1, 1],
            [0, -1],
            [1, 0],
            [0, -1],
        ],
        dtype=torch.int64,
    )

    adv, ret = compute_grpo_verk_step_reward_step_norm_advantage(
        token_level_rewards=torch.zeros(n_rows, n_tokens, dtype=torch.float32),
        response_mask=token_mask,
        index=groups,
        assistant_turn_ids=token_turn,
        turn_successes=outcomes,
        assistant_turn_rewards=step_rewards,
    )

    per_step = _expected_step_norm_adv(step_rewards, outcomes.ne(-1), groups)
    expected = per_step.gather(1, token_turn) * token_mask

    for produced in (adv, ret):
        assert torch.allclose(produced, expected, rtol=1e-5, atol=1e-6)


def test_grpo_verk_step_reward_step_norm_reweight_changes_values():
    """Reweighting must actually change the step-norm advantages and returns for this batch."""
    n_rows, n_tokens = 6, 4
    groups = np.array([0, 0, 0, 1, 1, 1], dtype=np.int64)
    token_mask = torch.ones(n_rows, n_tokens, dtype=torch.float32)
    # Two assistant turns per row, two tokens per turn.
    token_turn = torch.tensor([[0, 0, 1, 1]] * n_rows, dtype=torch.long)

    step_rewards = torch.tensor(
        [
            [1.0, 4.0],
            [2.0, 6.0],
            [3.0, 5.0],
            [7.0, 1.0],
            [9.0, 2.0],
            [8.0, 3.0],
        ],
        dtype=torch.float32,
    )
    outcomes = torch.tensor(
        [
            [1, 0],
            [1, 0],
            [0, 0],
            [0, 1],
            [0, 1],
            [0, 0],
        ],
        dtype=torch.int64,
    )

    shared = dict(
        response_mask=token_mask,
        index=groups,
        assistant_turn_ids=token_turn,
        turn_successes=outcomes,
        assistant_turn_rewards=step_rewards,
    )
    plain_adv, plain_ret = compute_grpo_verk_step_reward_step_norm_advantage(
        token_level_rewards=torch.zeros(n_rows, n_tokens, dtype=torch.float32), **shared
    )
    weighted_adv, weighted_ret = compute_grpo_verk_step_reward_step_norm_reweight_advantage(
        token_level_rewards=torch.zeros(n_rows, n_tokens, dtype=torch.float32), **shared
    )

    for weighted, plain in ((weighted_adv, plain_adv), (weighted_ret, plain_ret)):
        assert not torch.allclose(weighted, plain, rtol=1e-5, atol=1e-6)


if __name__ == "__main__":
    # unittest.main() only discovers the TestCase class and silently skips every
    # pytest-style test function (including the parametrized ones); pytest runs both.
    raise SystemExit(pytest.main([__file__]))
