import torch
from collections.abc import Sequence
from prettytable import PrettyTable

from .manager_base import ManagerBase, ManagerTermBase
from .manager_term_cfg import SuccessTermCfg  # 类似 RewardTermCfg

class SuccessManager(ManagerBase):
    """Manager for computing the episodic success score.

    Two term names are treated specially; every other configured term is
    registered (and listed by ``__str__``) but never evaluated here:

    * ``"constraint"`` is evaluated on every :meth:`compute` call. Its value is
      multiplied into a per-environment "safe" flag that starts at 1.0.
    * ``"criteria"`` is evaluated in :meth:`reset`. Its value is multiplied by
      the accumulated "safe" flag to give the reported success score.
    """

    def __init__(self, cfg: object, env):
        """Initialize the manager.

        Args:
            cfg: Either a dict or an object whose attributes map term names to
                ``SuccessTermCfg`` instances (``None`` entries are skipped).
            env: The environment instance; passed as the first argument to
                every term function.
        """
        # These containers must exist before the base-class constructor runs,
        # since ``_prepare_terms`` fills them.
        # NOTE(review): presumably ``ManagerBase.__init__`` calls
        # ``_prepare_terms`` - confirm against the base class.
        self._term_names: list[str] = []
        self._term_cfgs: list[SuccessTermCfg] = []
        self._class_term_cfgs: list[SuccessTermCfg] = []

        super().__init__(cfg, env)

        # 1.0 while the running product of constraint values is still 1.0.
        self._episode_safe = torch.ones(self.num_envs, dtype=torch.float, device=self.device)

    def __str__(self) -> str:
        """Return a table listing the active success terms."""
        msg = f"<SuccessManager> contains {len(self._term_names)} active terms.\n"
        table = PrettyTable()
        table.title = "Active Success Terms"
        table.field_names = ["Index", "Name"]
        for index, name in enumerate(self._term_names):
            table.add_row([index, name])
        msg += table.get_string()
        return msg

    def reset(self, env_ids: Sequence[int] | None = None) -> dict[str, torch.Tensor]:
        """Settle and return the episodic success score, then reset the state.

        Args:
            env_ids: Environments whose "safe" flag is reset to 1.0.
                ``None`` resets all environments.

        Returns:
            A dict with the single key ``"success"``. The value is computed for
            all environments (it is not indexed by ``env_ids``): the
            ``"criteria"`` term value multiplied by the accumulated "safe"
            flag. If no ``"criteria"`` term is configured, the "safe" flag
            itself is reported.
        """
        # Keep the caller's original value (possibly None) for the term resets.
        term_env_ids = env_ids
        if env_ids is None:
            env_ids = slice(None)

        # Determine the final task outcome.
        success = None
        for name, term_cfg in zip(self._term_names, self._term_cfgs):
            if name != "criteria":
                continue
            term_value = term_cfg.func(self._env, **term_cfg.params)
            success = term_value * self._episode_safe

        if success is None:
            # No criteria term configured: previously this raised
            # UnboundLocalError. Report the constraint flag alone. Clone it so
            # the in-place reset below does not modify the returned tensor.
            success = self._episode_safe.clone()

        extras = {"success": success}

        # Reset the flag for the next episode.
        self._episode_safe[env_ids] = 1.0

        # Reset class-based terms only after the score has been settled, so a
        # stateful criteria term is evaluated on this episode's state.
        # NOTE(review): assumes ManagerTermBase exposes ``reset(env_ids=...)``
        # - confirm against the base class.
        for term_cfg in self._class_term_cfgs:
            term_cfg.func.reset(env_ids=term_env_ids)

        return extras

    def compute(self, dt: float):
        """Update the per-environment "safe" flag from the constraint term.

        During stepping only constraint violations are considered; the success
        criteria are evaluated in :meth:`reset`.

        Args:
            dt: Step time delta. Currently unused by this manager.

        Returns:
            A tuple ``(episode_safe, infos)``: the accumulated "safe" flag and
            the info object returned by the ``"constraint"`` term. ``infos`` is
            an empty dict when no ``"constraint"`` term is configured.
        """
        # Default so that a configuration without a constraint term does not
        # fail with UnboundLocalError at the return statement.
        infos = {}

        for name, term_cfg in zip(self._term_names, self._term_cfgs):
            if name != "constraint":
                continue

            # The constraint term returns a (value, infos) pair.
            term_value, infos = term_cfg.func(self._env, **term_cfg.params)
            # Multiplicative accumulation: a zero value keeps the flag at zero
            # until the next reset.
            self._episode_safe = term_value * self._episode_safe

        return self._episode_safe, infos

    def _prepare_terms(self):
        """Parse the configuration and register every non-``None`` term.

        Raises:
            TypeError: If a term configuration is not a ``SuccessTermCfg``.
        """
        if isinstance(self.cfg, dict):
            cfg_items = self.cfg.items()
        else:
            cfg_items = self.cfg.__dict__.items()
        for term_name, term_cfg in cfg_items:
            # A term set to None is treated as disabled.
            if term_cfg is None:
                continue
            if not isinstance(term_cfg, SuccessTermCfg):
                raise TypeError(
                    f"Configuration for the term '{term_name}' is not of type SuccessTermCfg."
                    f" Received: '{type(term_cfg)}'."
                )
            self._resolve_common_term_cfg(term_name, term_cfg, min_argc=1)
            self._term_names.append(term_name)
            self._term_cfgs.append(term_cfg)
            # Track class-based terms separately so they can be reset.
            if isinstance(term_cfg.func, ManagerTermBase):
                self._class_term_cfgs.append(term_cfg)

    @property
    def active_terms(self) -> list[str]:
        """Name of active success terms."""
        return self._term_names
