from typing import Dict, Optional

import torch
from tensordict.nn import TensorDictModule, TensorDictModuleBase
from torch.linalg import vector_norm
from torchrl.record.loggers import Logger

from src.async_rl.module.trainer import RLTrainer, RLTrainerHook


class LogWeightsHook(RLTrainerHook):
	"""Hook to log the weights of a loss module.

	For every named parameter of the loss module, this hook logs the
	largest absolute value and the vector norm in order to monitor the
	training process.

	Args:
		loss_module (TensorDictModule | TensorDictModuleBase): The loss
			module to log the weights of.
		log_interval (int): Number of optimization steps between each
			logging of the weights. A value less than or equal to zero
			logs on every call. Defaults to `32`.
	"""

	def __init__(
		self,
		loss_module: TensorDictModule | TensorDictModuleBase,
		log_interval: int = 32,
	) -> None:
		self.loss_module = loss_module

		# Optimization step marking the start of the current logging
		# window; updated each time the weights are logged.
		self.last_logged = 0
		self.log_interval = log_interval

	def register(
		self, trainer: RLTrainer, name: str = "LogWeightsHook"
	) -> None:
		"""Register the hook to the trainer.

		The logging callback is attached to the trainer's `"end_traj"`
		hook.

		Args:
			trainer (RLTrainer): The trainer to register the hook to.
			name (str): The name of the hook. Defaults to
				`"LogWeightsHook"`.
		"""
		super().register(trainer, name)
		trainer.register_hook("end_traj", self._log_weights)

	@torch.inference_mode()
	def _log_weights(self, metrics: Dict, logger: Optional[Logger]) -> None:
		"""Log the max absolute value and norm of each parameter.

		Nothing is logged when `logger` is `None` or when fewer than
		`log_interval` optimization steps have passed since the start
		of the current logging window.

		Args:
			metrics (Dict): Metrics of the trainer. Only the
				`"optim_steps"` entry is read.
			logger (Optional[Logger]): The logger to write the scalars
				to.
		"""
		if logger is None:
			return

		step = metrics["optim_steps"]
		if step - self.last_logged < self.log_interval:
			return

		if self.log_interval > 0:
			# Snap to the last multiple of the interval so the logging
			# windows do not drift when the step counter advances by
			# more than one between two calls.
			self.last_logged = step - step % self.log_interval
		else:
			# A non-positive interval means "log on every call"; this
			# also avoids a modulo by zero.
			self.last_logged = step

		for name, param in self.loss_module.named_parameters():
			if param.numel() == 0:
				# `max()` raises on tensors without elements; there is
				# nothing meaningful to report for them anyway.
				continue

			logger.log_scalar(
				f"weights/{name}_max", param.abs().max().item(), step
			)
			logger.log_scalar(
				f"weights/{name}_norm", vector_norm(param).item(), step
			)
