from __future__ import annotations

import weakref
from abc import ABC, abstractmethod
from collections import OrderedDict
from typing import Callable, Dict, List, Optional, Tuple

import rich.progress as rp
import torch
from tensordict import TensorDictBase
from tensordict.nn import TensorDictModule
from torch.optim.optimizer import Optimizer
from torchrl.collectors import DataCollectorBase
from torchrl.data.replay_buffers import ReplayBuffer
from torchrl.record.loggers import Logger

from src.async_rl.module.collector import CollectorInfoReader

# Types accepted by ``RLTrainer.register_module``: optimizers are stored
# as-is, every other kind is moved to the trainer's device first, and all
# of them have their state restored by ``RLTrainer.load_state_dict``.
Module = (
	torch.nn.Module
	| Optimizer
	| TensorDictModule
)

# Default keys copied from every collected trajectory into the replay
# buffer. Nested entries are spelled with "." as separator, matching the
# flattened key names that ``RLTrainer._read_buffer_keys`` compares against.
INCLUDE_KEYS = [
	# entries of the current step
	"observation", "action", "is_init", "reward",
	# entries of the following step
	"next.observation", "next.done", "next.reward", "next.is_init",
	# recurrent-state entries (current and following step)
	"qv_recurrent_state", "next.qv_recurrent_state",
]


class RLTrainerHook(ABC):
	"""Abstract base for pluggable extensions of ``RLTrainer``.

	Concrete hooks must implement :meth:`register`. The owning trainer is
	remembered only through a weak reference, so a hook never keeps its
	trainer alive on its own.
	"""

	def __init__(self) -> None:
		# Weak reference to the owning trainer; stays ``None`` until the
		# hook has been registered.
		self.trainer: Optional[weakref.ReferenceType[RLTrainer]] = None

	@abstractmethod
	def register(self, trainer: RLTrainer, name: str) -> None:
		"""Attach this hook to ``trainer``.

		Subclasses must override this method and may call
		``super().register(trainer, name)`` to store the weak trainer
		reference.

		Args:
			trainer (RLTrainer): Trainer the hook is attached to.
			name (str): Identifier of the hook. Names are expected to be
				unique per trainer (this base class does not enforce it).
		"""
		trainer_ref = weakref.ref(trainer)
		self.trainer = trainer_ref


class RLTrainer:
	"""Trainer for reinforcement learning agents.

	This class replaces the trainer provided by TorchRL in order to train
	agents whose losses are decoupled, something the original
	implementation does not support. Custom behavior is injected through
	hooks (see :meth:`register_hook`) and modules registered with
	:meth:`register_module`.

	Args:
		collector (DataCollectorBase): Data collector for the agent.
		buffer (ReplayBuffer): Replay buffer for the agent.
		optim_steps_per_batch (int | str): Number of optimization steps
			to perform per batch or `"traj"` to automatically match it
			with the number of steps in the trajectory. Any non-integer
			value is currently treated like `"traj"`.
		include_additional_keys (Optional[List[str]], optional):
			Additional keys to include in the buffer, if needed. Nested
			keys use `"."` as separator (e.g. `"next.observation"`).
			Defaults to `None`, which means the included keys are only
			the ones in `INCLUDE_KEYS`.
		logger (Optional[Logger], optional): Logger to use. Defaults to
			`None`.
		progress_bar (bool, optional): Whether to show a progress bar.
			Defaults to `True`.
		device (Optional[torch.device], optional): Device to use.
			Defaults to `None`. When `None`, `torch.device("cpu")` is
			used.
	"""

	def __init__(
		self,
		collector: DataCollectorBase,
		buffer: ReplayBuffer,
		optim_steps_per_batch: int | str,
		include_additional_keys: Optional[List[str]] = None,
		logger: Optional[Logger] = None,
		progress_bar: bool = True,
		device: Optional[torch.device] = None,
	) -> None:
		self.collector = collector
		self.collector_info = CollectorInfoReader(collector)
		self.buffer = buffer
		self.logger = logger

		self.optim_steps_per_batch = optim_steps_per_batch
		# Copy the module-level default: extending ``INCLUDE_KEYS`` in place
		# would leak the additional keys into every other trainer.
		self.include_keys: List[str] = list(INCLUDE_KEYS)
		if include_additional_keys is not None:
			self.include_keys.extend(include_additional_keys)
		self.progress_bar = progress_bar

		self.device = device or torch.device("cpu")

		# Filled lazily from the first collected trajectory (see
		# ``_read_buffer_keys``) or restored by ``load_state_dict``.
		self.buffer_keys: List[str | Tuple[str, ...]] = []

		# Hook registries, one per hook type (see ``register_hook``).
		self._start_train_hooks: List[Callable] = []
		self._sample_hooks: List[Callable] = []
		self._optim_hooks: List[Callable] = []
		self._end_optim_hooks: List[Callable] = []
		self._end_traj_hooks: List[Callable] = []

		self.metrics = {"optim_steps": 0}
		self.collected_frames = 0

		self._modules: Dict[str, Module] = {}
		# NOTE(review): not read anywhere in this class; kept so external
		# code relying on the attribute keeps working.
		self._batch_queue = []

	def train(self) -> None:
		"""Trains the agent.

		Runs the `"start_train"` hooks, then consumes the collector until
		it is exhausted. Each collected batch is stored in the replay
		buffer; once the collector leaves its random phase, every batch is
		also followed by a round of optimization steps, metric logging and
		the `"end_traj"` hooks.

		The progress bar is always stopped and the collector always shut
		down, even if an exception escapes the training loop.
		"""
		for func in self._start_train_hooks:
			func()

		bar = self._make_progress_bar()
		bar.start()
		try:
			self._collect_and_optimize(bar)
		finally:
			bar.stop()
			self.collector.shutdown()

	def _make_progress_bar(self) -> rp.Progress:
		"""Build the rich progress display used by ``train``."""
		return rp.Progress(
			rp.TextColumn("{task.description}"),
			rp.BarColumn(),
			rp.TextColumn("{task.completed} of {task.total}"),
			rp.TaskProgressColumn(),
			rp.TimeElapsedColumn(),
			rp.TimeRemainingColumn(),
			disable=not self.progress_bar,
		)

	def _collect_and_optimize(self, bar: rp.Progress) -> None:
		"""Consume the collector, filling the buffer and optimizing.

		Args:
			bar (rp.Progress): Started progress display to report on.
		"""
		collector_task = bar.add_task(
			"[blue]Collecting Steps",
			total=self.collector_info.rem_iterations,
		)
		random_task = bar.add_task(
			"[red]Random Phase", total=self.collector_info.random_iterations
		)
		random_task_removed = False

		for traj in self.collector:
			# Flatten all batch dimensions into one before storing.
			traj = traj.reshape(-1)
			traj = traj.select(*self._read_buffer_keys(traj))
			self.buffer.extend(traj.cpu().detach())
			self.collected_frames += traj.numel()

			if self.collector_info.random_phase():
				# Random phase: data is stored but no optimization happens.
				# NOTE(review): ``update_bar`` is called once per task; if it
				# is stateful, confirm two calls per batch are intended.
				bar.update(
					collector_task,
					advance=self.collector_info.update_bar(traj),
				)
				bar.update(
					random_task, advance=self.collector_info.update_bar(traj)
				)
				continue

			# The random-phase task is removed once, on the first batch
			# after the random phase ends.
			if not random_task_removed:
				bar.remove_task(random_task)
				random_task_removed = True

			optim_steps = (
				self.optim_steps_per_batch
				if isinstance(self.optim_steps_per_batch, int)
				else traj.numel()
			)

			training_task = bar.add_task("[green]Training", total=optim_steps)
			for _ in range(optim_steps):
				self._optim_step()
				self.metrics["optim_steps"] += 1
				for func in self._end_optim_hooks:
					func()
				bar.update(training_task, advance=1)
			bar.remove_task(training_task)

			self._log_metrics()

			for func in self._end_traj_hooks:
				func(self.metrics, self.logger)

			bar.update(
				collector_task,
				advance=self.collector_info.update_bar(traj),
			)

	def register_hook(self, hook_type: str, func: Callable) -> None:
		"""Registers a hook to the trainer.

		This is the main function that enables custom behavior tailored
		to the user's and agent's needs. Hooks of the same type run in
		registration order.

		Hook types and how they are called:
		- `"start_train"`: called once, without arguments, at the start
			of :meth:`train`.
		- `"sample"`: called with every sampled batch at each
			optimization step; must return the (possibly modified)
			batch.
		- `"optim_step"`: called with the batch at every optimization
			step, mainly for loss computation and backpropagation; must
			return a `(batch, losses)` pair where `losses` maps metric
			names to tensors whose `.item()` is stored in `metrics`.
		- `"end_optim_step"`: called without arguments at the end of
			every optimization step, _e.g._, updating target networks.
		- `"end_traj"`: called with `(metrics, logger)` after the
			optimization steps of every collected batch, _e.g._,
			logging metrics.

		Args:
			hook_type (str): Type of hook to register. One of
				`"start_train"`, `"sample"`, `"optim_step"`,
				`"end_optim_step"` or `"end_traj"`.
			func (Callable): Function to register.

		Raises:
			ValueError: If `hook_type` is not one of the allowed values.
		"""
		registries = {
			"optim_step": self._optim_hooks,
			"end_optim_step": self._end_optim_hooks,
			"end_traj": self._end_traj_hooks,
			"sample": self._sample_hooks,
			"start_train": self._start_train_hooks,
		}
		try:
			registry = registries[hook_type]
		except KeyError:
			raise ValueError(f"Unknown hook type: {hook_type}") from None
		registry.append(func)

	def register_module(self, module_name: str, module: Module) -> None:
		"""Registers a module to the trainer.

		A module is one of the following:
		- `torch.nn.Module`
		- `torch.optim.optimizer.Optimizer`
		- `tensordict.nn.TensorDictModule`

		Every module except optimizers is moved to the trainer's device
		before being stored.

		Args:
			module_name (str): Name of the module. No two modules can
				have the same name.
			module (Module): Module to register.

		Raises:
			RuntimeError: If a module with the same name already exists.
		"""
		if module_name in self._modules:
			raise RuntimeError(f"Module name {module_name} already exists.")
		# Optimizers have no ``.to``; their state follows the parameters.
		if not isinstance(module, Optimizer):
			module = module.to(self.device)
		self._modules[module_name] = module

	def load_state_dict(self, state_dict: OrderedDict) -> None:
		"""Loads the state of the trainer.

		Args:
			state_dict (OrderedDict): Nested mapping with three entries:
				`"trainer_state"` (holding `"metrics"`,
				`"collected_frames"` and `"buffer_keys"`), `"modules"`
				(one state dict per registered module name) and
				`"collector"` (the collector's state dict).

		Raises:
			KeyError: If an expected entry, or the state of a registered
				module, is missing from `state_dict`.
		"""
		trainer_state = state_dict["trainer_state"]
		self.metrics = trainer_state["metrics"]
		self.collected_frames = trainer_state["collected_frames"]
		self.buffer_keys = trainer_state["buffer_keys"]

		for name, module in self._modules.items():
			module.load_state_dict(state_dict["modules"][name])

		self.collector.load_state_dict(state_dict["collector"])

	def register_modules(self, modules: Dict[str, Module]) -> None:
		"""Registers multiple modules to the trainer.

		Args:
			modules (Dict[str, Module]): Modules to register, keyed by
				name. See :meth:`register_module`.
		"""
		for name, module in modules.items():
			self.register_module(name, module)

	def _log_metrics(self) -> None:
		"""Log every metric except the step counter, indexed by it."""
		if self.logger is None:
			return
		step = self.metrics["optim_steps"]
		for key, item in self.metrics.items():
			if key == "optim_steps":
				continue
			self.logger.log_scalar(key, item, step)

	def _optim_step(self) -> None:
		"""Run one optimization step.

		Samples a batch, moves it to the trainer's device, then threads it
		through the `"sample"` hooks followed by the `"optim_step"` hooks,
		recording the losses each `"optim_step"` hook returns.
		"""
		batch = self.buffer.sample().clone()
		if batch.device != self.device:
			batch = batch.to(self.device)

		for func in self._sample_hooks:
			batch = func(batch)

		for func in self._optim_hooks:
			batch, losses = func(batch)
			self._update_losses_in_metrics(losses)

	def _update_losses_in_metrics(
		self, losses: Dict[str, torch.Tensor]
	) -> None:
		"""Store the Python scalar of every loss tensor in `metrics`."""
		for key, value in losses.items():
			self.metrics[key] = value.item()  # type: ignore

	def _read_buffer_keys(
		self, traj: TensorDictBase
	) -> List[str | Tuple[str, ...]]:
		"""Return the keys to store in the buffer, computing them once.

		On the first call, the flattened keys of `traj` that appear in
		`include_keys` are collected; dotted names are converted into
		nested-key tuples (e.g. `"next.obs"` -> `("next", "obs")`). Later
		calls return the cached list unchanged.

		Args:
			traj (TensorDictBase): Collected trajectory.

		Returns:
			List[str | Tuple[str, ...]]: Keys to select from trajectories.
		"""
		if not self.buffer_keys:
			self.buffer_keys = [
				key if "." not in key else tuple(key.split("."))
				for key in traj.flatten_keys().sorted_keys
				if key in self.include_keys
			]
		return self.buffer_keys
