import logging
from typing import Any

import gymnasium as gym

from tianshou.data import Collector, CollectStats
from tianshou.data import (
    ReplayBuffer,
)
from tianshou.env import BaseVectorEnv
from tianshou.policy import BasePolicy

# Module-level logger, named after this module.
log = logging.getLogger(__name__)


class CustomCollector(Collector):
    """Collector that additionally records return statistics of every collect call.

    After each call to :meth:`_collect`, a ``[mean, std]`` pair computed from the
    returns of the episodes collected in that call is appended to
    ``policy_eval_results``.
    """

    # NAMING CONVENTION (mostly suffixes):
    # episode - An episode means a rollout until done (terminated or truncated). After an episode is completed,
    # the corresponding env is either reset or removed from the ready envs.
    # N - number of envs, always fixed and >= R.
    # R - number of ready env ids. Note that this might change when envs get idle.
    #     This can only happen in n_episode case, see explanation in the corresponding block.
    #     For n_step, we always use all envs to collect the data, while for n_episode,
    #     R will be at most n_episode at the beginning, but can decrease during the collection.
    # O - dimension(s) of observations
    # A - dimension(s) of actions
    # H - dimension(s) of hidden state
    # D - number of envs that reached done in the current collect iteration. Only relevant in n_episode case.
    # S - number of surplus envs, i.e. envs that are ready but won't be used in the next iteration.
    #     Only used in n_episode case. Then, R becomes R-S.
    def __init__(
        self,
        policy: BasePolicy,
        env: gym.Env | BaseVectorEnv,
        buffer: ReplayBuffer | None = None,
        exploration_noise: bool = False,
        *args: Any,
        **kwargs: Any,
    ) -> None:
        """:param policy: an instance of the :class:`~tianshou.policy.BasePolicy` class.
        :param env: a ``gym.Env`` environment or an instance of the
            :class:`~tianshou.env.BaseVectorEnv` class.
        :param buffer: an instance of the :class:`~tianshou.data.ReplayBuffer` class.
            If set to None, will instantiate a :class:`~tianshou.data.VectorReplayBuffer`
            as the default buffer.
        :param exploration_noise: determine whether the action needs to be modified
            with the corresponding policy's exploration noise. If so, "policy.
            exploration_noise(act, batch)" will be called automatically to add the
            exploration noise into action. Default to False.
        :param args: additional positional arguments, forwarded unchanged to the
            base collector's constructor after ``exploration_noise``.
        :param kwargs: additional keyword arguments, forwarded unchanged to the
            base collector's constructor.
        """
        # Forward the extra arguments instead of silently discarding them, so that
        # options understood by the base collector actually take effect.
        # NOTE(review): this assumes the base constructor takes
        # (policy, env, buffer, exploration_noise) as its first four parameters,
        # mirroring this signature - confirm against the installed tianshou version.
        super().__init__(policy, env, buffer, exploration_noise, *args, **kwargs)

        # One [returns_mean, returns_std] entry per completed _collect call.
        self.policy_eval_results: list[list[float]] = []
        # Not written to anywhere in this class; kept for external users.
        self.policy_debug_results: list[Any] = []

    def _collect(
        self,
        n_step: int | None = None,
        n_episode: int | None = None,
        random: bool = False,
        render: float | None = None,
        gym_reset_kwargs: dict[str, Any] | None = None,
    ) -> CollectStats:
        """Run the base collection and record the mean/std of the collected returns.

        All parameters are passed through unchanged to the parent implementation.

        :return: the statistics object produced by the parent ``_collect``,
            unmodified.
        """
        result = super()._collect(
            n_step,
            n_episode,
            random,
            render,
            gym_reset_kwargs,
        )

        if result.n_collected_episodes > 0:
            # Stored as plain floats so both branches yield the same value type.
            returns_mean = float(result.returns.mean())
            returns_std = float(result.returns.std())
        else:
            # No episode finished in this call, so there are no returns to
            # aggregate; record zeros as a placeholder entry.
            returns_mean = returns_std = 0.0

        self.policy_eval_results.append([returns_mean, returns_std])

        return result
