from __future__ import annotations

import os

import torch
from tensordict.nn import TensorDictModule
from torch.optim.optimizer import Optimizer

from src.async_rl.module.trainer import RLTrainer, RLTrainerHook

# Type alias for the union of torch modules, optimizers and TensorDict modules.
# The ``|`` here is evaluated at runtime (``from __future__ import annotations``
# only defers annotations), so this line requires Python 3.10+.
# NOTE(review): not referenced in this file's visible code; presumably used by
# importers elsewhere — confirm before removing.
Module = torch.nn.Module | Optimizer | TensorDictModule


class SaveHook(RLTrainerHook):
	"""Trainer hook that periodically checkpoints the trainer's state.

	Once registered, the hook runs after every optimization step and,
	every ``save_interval`` steps, writes a checkpoint containing the
	trainer's metrics, collected frame count and buffer keys, the
	``state_dict()`` of every module in ``trainer._modules``, and the
	collector's ``state_dict()``.

	Args:
		save_dir (str): Directory in which checkpoint files are written.
			It is created (including parents) if it does not exist.
		save_interval (int, optional): Number of optimization steps
			between two consecutive checkpoints. Must be at least ``1``.
			Defaults to ``1000``.
		override (bool, optional): If ``True``, every checkpoint
			overwrites the single file ``checkpoint.pth``; otherwise each
			checkpoint is written to ``checkpoint_<optim_steps>.pth``.
			Defaults to ``False``.

	Raises:
		ValueError: If ``save_interval`` is smaller than ``1``.
	"""

	def __init__(
		self,
		save_dir: str,
		save_interval: int = 1_000,
		override: bool = False,
	) -> None:
		# Validate eagerly: a zero interval would otherwise only surface as a
		# ZeroDivisionError in the middle of training (``steps % 0``), and a
		# negative period is meaningless.
		if save_interval < 1:
			raise ValueError(
				f"save_interval must be a positive integer, got {save_interval}"
			)

		self.save_dir = save_dir
		self.save_interval = save_interval
		# Optimization step (rounded down to a multiple of ``save_interval``)
		# at which the most recent checkpoint was written.
		self.last_saved = 0
		self.override = override

		os.makedirs(save_dir, exist_ok=True)

	def register(self, trainer: RLTrainer, name: str = "SaveHook") -> None:
		"""Registers the save hook to the trainer.

		The checkpointing routine is attached to the trainer's
		``"end_optim_step"`` event.

		Args:
			trainer (RLTrainer): Trainer to register the hook to.
			name (str, optional): Name of the hook. Defaults to
				`"SaveHook"`.
		"""
		super().register(trainer, name)
		trainer.register_hook("end_optim_step", self._save_state)

	def _save_state(self) -> None:
		"""Write a checkpoint if at least ``save_interval`` steps have elapsed.

		Does nothing when the hook has not been registered, or when the
		trainer reference resolves to ``None``.
		"""
		# ``self.trainer`` is set by ``RLTrainerHook.register``; it is called
		# to obtain the trainer, so presumably it is a weak reference —
		# confirm. ``getattr`` guards against the hook firing before
		# registration, since ``__init__`` does not set the attribute.
		trainer_ref = getattr(self, "trainer", None)
		if trainer_ref is None:
			return
		trainer = trainer_ref()
		if trainer is None:
			return

		metrics = trainer.metrics
		optim_steps = metrics["optim_steps"]
		if optim_steps - self.last_saved < self.save_interval:
			return
		# Snap to the interval grid so the next save happens at the next
		# multiple of ``save_interval``, even if ``optim_steps`` advanced by
		# more than one step between calls.
		self.last_saved = optim_steps - (optim_steps % self.save_interval)

		state = {
			"trainer_state": {
				"metrics": metrics,
				"collected_frames": trainer.collected_frames,
				"buffer_keys": trainer.buffer_keys,
			},
			"modules": {
				name: module.state_dict()
				for name, module in trainer._modules.items()
			},
			"collector": trainer.collector.state_dict(),
		}

		file_name = (
			"checkpoint.pth"
			if self.override
			else f"checkpoint_{optim_steps}.pth"
		)
		path = os.path.join(self.save_dir, file_name)
		# Write to a temporary file and atomically move it into place, so an
		# interruption mid-write never leaves a truncated checkpoint behind
		# (critical with ``override=True``, where it would destroy the only
		# existing copy).
		tmp_path = f"{path}.tmp"
		torch.save(state, tmp_path)
		os.replace(tmp_path, path)
