# Copyright (c) 2022-2025, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md).
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause

# Copyright (c) 2022-2025, The Isaac Lab Project Developers.
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause

"""Launch Isaac Sim Simulator first."""

from isaaclab.app import AppLauncher

# launch the omniverse app in headless mode; as noted above, this is done
# before the remaining imports of this module
simulation_app = AppLauncher(headless=True).app

"""Rest everything follows."""

import torch
from collections import namedtuple

import pytest

from isaaclab.managers import SuccessManager, SuccessTermCfg
from isaaclab.sim import SimulationContext
from isaaclab.utils import configclass


def constraint_out_of_bounds(env, threshold=0.8):
    """Mock constraint term that fails every environment if any sample is out of bounds.

    The "positions" are not read from the environment; they are drawn with
    ``torch.rand`` (one value per environment) to stand in for an x-coordinate.

    Args:
        env: Object exposing ``num_envs``, the number of values to sample.
        threshold: Bound on the absolute sampled value. Defaults to 0.8.

    Returns:
        A tensor of shape ``(env.num_envs,)`` that is all zeros when at least
        one sampled value exceeds ``threshold`` in magnitude, and all ones otherwise.
    """
    # stand-in for the x-coordinate of each environment
    sampled = torch.rand(env.num_envs)
    # a single flag shared by all environments: did anyone leave the bounds?
    any_violation = (sampled.abs() > threshold).any()
    passed = torch.ones_like(sampled)
    failed = torch.zeros_like(sampled)
    return torch.where(any_violation, failed, passed)

def distance_success(env, target=1.0):
    """Mock success-criteria term returning a percentage score per environment.

    The "positions" are drawn with ``torch.rand`` (one value per environment)
    rather than read from the environment.

    Args:
        env: Object exposing ``num_envs``, the number of values to sample.
        target: Value the sampled position is divided by. Defaults to 1.0.

    Returns:
        A tensor of shape ``(env.num_envs,)`` holding ``sample / target``
        clamped to ``[0, 1]`` and scaled by 100.
    """
    sampled = torch.rand(env.num_envs)
    progress = sampled / target
    # saturate the ratio to [0, 1] before converting it to a percentage
    progress = progress.clamp(min=0.0, max=1.0)
    return progress * 100.0

@pytest.fixture
def env():
    """Create a lightweight stand-in environment for the manager tests.

    Returns:
        A ``ManagerBasedRLEnv`` named tuple with ``num_envs=20``, ``dt=0.1``,
        ``device="cpu"`` and ``sim`` set to a freshly constructed
        :class:`SimulationContext`.
    """
    sim_context = SimulationContext()
    # minimal record exposing only these four attributes
    mock_env_type = namedtuple("ManagerBasedRLEnv", ["num_envs", "dt", "device", "sim"])
    return mock_env_type(num_envs=20, dt=0.1, device="cpu", sim=sim_context)


def test_str(env):
    """Test the string representation of the success manager."""
    term_cfgs = dict(
        constraint=SuccessTermCfg(func=constraint_out_of_bounds, weight=10, params={"threshold": 0.8}),
        criteria=SuccessTermCfg(func=distance_success, weight=5, params={"target": 1.0}),
    )
    success_manager = SuccessManager(term_cfgs, env)
    # both configured terms should be reported as active
    num_active_terms = len(success_manager.active_terms)
    assert num_active_terms == 2
    # print the manager on its own line for visual inspection
    print(f"\n{success_manager}")


def test_compute(env):
    """Test that the success manager can compute its terms.

    NOTE: this test makes no assertion on the computed value; it only checks
    that construction and ``compute`` run without raising, and prints the result.
    """
    term_cfgs = dict(
        constraint=SuccessTermCfg(func=constraint_out_of_bounds, weight=10, params={"threshold": 0.8}),
        criteria=SuccessTermCfg(func=distance_success, weight=5, params={"target": 1.0}),
    )
    success_manager = SuccessManager(term_cfgs, env)

    # evaluate all terms using the environment's step size
    result = success_manager.compute(dt=env.dt)
    # TODO: assert on the returned value (e.g. its per-environment shape) once the
    # expected output of the manager is pinned down; for now only display it
    print(result)
    
