from collections import deque

from common.imports import *
from common.logger import Logger
from common.utils import to_list_tensor, to_agent_shape
from .utils import make_env

class Evaluator:
    """Run evaluation episodes for a set of per-agent recurrent Q-networks.

    A single environment is wrapped in a vectorized environment, episodes are
    rolled out with the actions returned by each agent's ``get_action``, and
    the mean episode return and length are reported.

    Attributes:
        env (gym.vector.AsyncVectorEnv): Vectorized environment holding the one
            environment built by ``make_env(args, 0)``.
        max_steps (int): Value of ``args.max_steps``; stored here but not read
            by ``evaluate``.
        h_size (int): Width of each agent's hidden-state tensor.
        logger (Logger): Logger for storing evaluation metrics; a falsy value
            disables metric storage.
        device (th.device): Device to run the model on (e.g., 'cpu' or 'cuda').
    """

    def __init__(self, args: Any, logger: Logger, device: th.device) -> None:
        """Initialize the Evaluator with the given arguments, logger, and device.

        Args:
            args: Configuration object. ``max_steps`` and ``h_size`` are read
                as attributes, and the whole object is forwarded to
                ``make_env``.
            logger: Logger for storing evaluation metrics.
            device: Device to run the model on.
        """
        # AsyncVectorEnv (not the synchronous variant) is used here; it holds
        # external resources, so call ``close()`` when evaluation is finished.
        self.env = gym.vector.AsyncVectorEnv([make_env(args, 0)])
        self.max_steps = args.max_steps  # Max episode duration from the config
        self.h_size = args.h_size  # Hidden-state width per agent
        self.logger = logger  # Logger for evaluation metrics
        self.device = device  # Device for model inference

    def close(self) -> None:
        """Close the vectorized evaluation environment."""
        self.env.close()

    @th.no_grad()
    def evaluate(self, glob_step: int, qnet: Dict[int, Any], eval_ep: int = 1) -> None:
        """Evaluate the Q-networks over a specified number of episodes.

        Args:
            glob_step: Global step for logging purposes.
            qnet: Mapping from agent index (``0 .. len(qnet) - 1``) to an
                object exposing ``get_action(obs, hidden)`` that returns an
                ``(action, next_hidden)`` pair.
            eval_ep: Number of episodes for evaluation; must be at least 1.

        Raises:
            ValueError: If ``eval_ep`` is less than 1 (there would be no
                episode to average over).
        """
        # Validate before touching the environment so a bad call is cheap and
        # never logs NaN metrics.
        if eval_ep < 1:
            raise ValueError(f"eval_ep must be at least 1, got {eval_ep}")

        obs, _ = self.env.reset()
        obs = to_list_tensor(obs, self.device)

        agents = range(len(qnet))

        returns_q: Deque[float] = deque(maxlen=eval_ep)  # Returns of finished episodes
        lengths_q: Deque[int] = deque(maxlen=eval_ep)  # Lengths of finished episodes
        ep_reward, ep_length = 0.0, 0

        # Per-agent buffers. Only entries selected by the agent's validity mask
        # are refreshed each step; the other entries keep their previous value.
        # NOTE(review): m_act and the initial/reset m_valid masks live on CPU,
        # while the masks rebuilt after each step are requested on self.device
        # -- confirm this mix is intended when running on a non-CPU device.
        m_obs = {a: th.empty(obs[a].shape, device=self.device) for a in agents}
        m_act = {a: th.empty(1, dtype=int) for a in agents}
        m_valid = {a: th.ones(1, dtype=bool) for a in agents}
        m_h = {a: th.zeros([1, self.h_size], device=self.device) for a in agents}
        m_next_h = {a: h.clone() for a, h in m_h.items()}

        while len(returns_q) < eval_ep:
            for a in agents:
                valid = m_valid[a]
                m_obs[a][valid] = obs[a][valid]
                m_h[a][valid] = m_next_h[a][valid]
                m_act[a][valid], m_next_h[a][valid] = qnet[a].get_action(
                    m_obs[a][valid], m_h[a][valid]
                )

            next_obs, reward, _, _, info = self.env.step(m_act)
            next_obs = to_list_tensor(next_obs, self.device)
            # The next validity masks come from info['mac_done'] (presumably a
            # per-agent "macro-action finished" flag -- verify against the env).
            m_valid = dict(enumerate(to_agent_shape(info['mac_done'], self.device, bool)))
            ep_reward += np.mean(reward)
            ep_length += 1

            # An episode just ended: record its statistics and start the next
            # one with zeroed hidden states and every agent marked valid.
            if "final_info" in info:
                returns_q.append(ep_reward)
                lengths_q.append(ep_length)
                ep_reward, ep_length = 0.0, 0
                m_next_h = {a: th.zeros([1, self.h_size], device=self.device) for a in agents}
                m_valid = {a: th.ones(1, dtype=bool) for a in agents}

            obs = next_obs

        # Average return and episode length over the evaluated episodes
        avg_return = np.mean(returns_q)
        avg_length = np.mean(lengths_q)

        # Log the metrics if logger is available
        if self.logger:
            self.logger.store_metrics(glob_step, avg_return, avg_length)

        print(f"Eval at step {glob_step}, return={avg_return:.3f}")

