from __future__ import annotations

from typing import Dict, List, Optional, Tuple

import numpy as np
import torch
from tensordict import TensorDictBase
from tensordict.nn import TensorDictModule
from torch.optim.optimizer import Optimizer
from torchrl.envs import EnvBase, ExplorationType, set_exploration_type
from torchrl.record.loggers import Logger

from src.async_rl.module.trainer import RLTrainer, RLTrainerHook

# Alias for the kinds of objects handled as a "module": a plain torch
# module, an optimizer, or a TensorDict module. This is a runtime expression
# (not an annotation), so the `|` union between classes needs Python 3.10+.
# NOTE(review): not referenced in the visible part of this file - confirm it
# is used by importers before removing.
Module = torch.nn.Module | Optimizer | TensorDictModule


class EvalRolloutHook(RLTrainerHook):
	"""Trainer hook for evaluating the agent's performance.

	This is the analogue of the `Recorder` hook provided by TorchRL. It
	performs one or more rollouts of the agent in an environment and
	logs the aggregated results.

	Args:
		env (EnvBase): Environment to evaluate the agent in.
		actor (TensorDictModule): Actor network to evaluate.
		max_frames (int, optional): Maximum number of frames to evaluate
			the agent for. Defaults to `1_024`.
		num_rollouts (int, optional): Number of rollouts to perform.
			Defaults to `1`. If `num_rollouts` is greater than `1`, the
			average of the metrics is logged.
		exploration_type (ExplorationType, optional): Exploration type
			to use during the rollout. Defaults to
			`ExplorationType.DETERMINISTIC`.
		log_interval (int, optional): Interval at which to log the
			results, based on number of optimization steps. Defaults to
			`32`.
		metrics (Optional[List[Tuple[str, str]]], optional): Metrics to
			log. Each metric is a tuple with the key to read from the
			`"next"` entry of the rollout and the aggregation mode. The
			aggregation mode can be one of `"mean"`, `"sum"`, or
			`"last"`. Each metric is logged under the name
			`"<key>_<mode>"`. If `None` (or empty), the average of the
			rewards is logged. Defaults to `None`.
	"""

	def __init__(
		self,
		env: EnvBase,
		actor: TensorDictModule,
		max_frames: int = 1_024,
		num_rollouts: int = 1,
		exploration_type: ExplorationType = ExplorationType.DETERMINISTIC,  # type: ignore
		log_interval: int = 32,
		metrics: Optional[List[Tuple[str, str]]] = None,
	) -> None:
		self.env = env
		self.actor = actor
		self.max_frames = max_frames
		self.num_rollouts = num_rollouts
		self.exploration_type = exploration_type

		# Optimization step (snapped to a multiple of `log_interval`) at
		# which the last evaluation was logged.
		self.last_logged = 0
		self.log_interval = log_interval

		# Both `None` and an empty list fall back to the mean reward.
		self.metrics = metrics if metrics else [("reward", "mean")]

	def register(
		self, trainer: RLTrainer, name: str = "EvalRolloutHook"
	) -> None:
		"""Registers the evaluation rollout hook to the trainer.

		The evaluation is attached to the trainer's `"end_traj"` hook
		point.

		Args:
			trainer (RLTrainer): Trainer to register the hook to.
			name (str, optional): Name of the hook. Defaults to
				"EvalRolloutHook".
		"""
		super().register(trainer, name)
		trainer.register_hook("end_traj", self._eval_rollout)

	@staticmethod
	def _metric_name(key_mode: Tuple[str, str]) -> str:
		"""Builds the name under which a `(key, mode)` metric is logged."""
		return f"{key_mode[0]}_{key_mode[1]}"

	@torch.inference_mode()
	def _eval_rollout(self, metrics: Dict, logger: Optional[Logger]) -> None:
		"""Runs the evaluation rollouts and logs the aggregated metrics.

		Does nothing when `logger` is `None` or when fewer than
		`log_interval` optimization steps have passed since the last
		logged evaluation.

		Args:
			metrics (Dict): Trainer metrics; `metrics["optim_steps"]` is
				read as the current optimization step and used as the
				logging step.
			logger (Optional[Logger]): Logger receiving the scalars.
		"""
		if logger is None:
			return

		optim_steps = metrics["optim_steps"]
		if optim_steps - self.last_logged < self.log_interval:
			return
		# Snap to the interval grid so that late calls do not shift the
		# evaluation schedule.
		self.last_logged = optim_steps - (optim_steps % self.log_interval)

		names = [self._metric_name(key_mode) for key_mode in self.metrics]
		rollout_metrics: Dict[str, List[float]] = {name: [] for name in names}

		# Remember the actor's mode so that it is restored exactly,
		# instead of unconditionally forcing training mode afterwards.
		is_module = isinstance(self.actor, torch.nn.Module)
		was_training = self.actor.training if is_module else False

		rollout: TensorDictBase
		try:
			with set_exploration_type(self.exploration_type):
				if is_module:
					self.actor.eval()

				for _ in range(self.num_rollouts):
					rollout = self.env.rollout(
						policy=self.actor,
						max_steps=self.max_frames,
						auto_reset=True,
						auto_cast_to_device=True,
					).clone()["next"]  # type: ignore
					for name, key_mode in zip(names, self.metrics):
						rollout_metrics[name].append(
							self._retrieve_metric(rollout, key_mode)
						)
		finally:
			# Restore the mode even if a rollout raised, so a failed
			# evaluation cannot leave the actor stuck in eval mode.
			if is_module:
				self.actor.train(was_training)

		for name in names:
			values = rollout_metrics[name]
			if not values:
				# No rollout was performed (`num_rollouts <= 0`): there
				# is nothing meaningful to log, avoid logging NaN.
				continue
			logger.log_scalar(name, float(np.mean(values)), optim_steps)

	def _retrieve_metric(
		self, rollout: TensorDictBase, key_mode: Tuple[str, str]
	) -> float:
		"""Aggregates one entry of a rollout into a single scalar.

		Args:
			rollout (TensorDictBase): Rollout data to read from.
			key_mode (Tuple[str, str]): Key of the entry to read and
				aggregation mode (`"mean"`, `"sum"`, or `"last"`).

		Returns:
			float: The aggregated value.

		Raises:
			ValueError: If the aggregation mode is unknown.
		"""
		key, mode = key_mode[0], key_mode[1]
		value_not_aggr = rollout[key]
		if mode == "mean":
			return value_not_aggr.mean().item()
		if mode == "sum":
			return value_not_aggr.sum().item()
		if mode == "last":
			# NOTE(review): indexes the first dimension and requires the
			# result to hold a single element; presumably the first
			# dimension is time (unbatched env) - confirm for batched
			# environments.
			return value_not_aggr[-1].item()
		raise ValueError(f"Unknown aggregation mode: {mode}")
