from collections import OrderedDict
from typing import Iterator, Optional, cast

import torch
from tensordict import TensorDictBase
from tensordict.nn import InteractionType, TensorDictModule
from torchrl.collectors import (
	DataCollectorBase,
	RandomPolicy,
	SyncDataCollector,
)
from torchrl.envs import EnvBase
from torchrl.envs.batched_envs import BatchedEnvBase
from torchrl.envs.transforms import TransformedEnv
from torchrl.envs.utils import set_exploration_type


class CollectorInfoReader:
	"""Uniform progress view over a supported data collector.

	Training loops use this adapter to query a collector's progress
	(for instance to drive a progress bar) without caring which kind of
	collector they hold. For a `RolloutCollector` the values are read
	from its rollout counters; for a `SyncDataCollector` they are read
	from its frame counters.

	Args:
		collector (DataCollectorBase): Collector to inspect. Only
			`RolloutCollector` and `SyncDataCollector` instances are
			accepted.
	"""

	def __init__(self, collector: DataCollectorBase) -> None:
		supported_types = (RolloutCollector, SyncDataCollector)
		assert isinstance(
			collector, supported_types
		), "Collector must be either RolloutCollector or SyncDataCollector"
		self.collector = collector

	@property
	def total_iterations(self) -> int:
		"""Overall amount of work the collector is configured for.

		Returns:
			int: `total_iterations` of a `RolloutCollector`, or
				`total_frames` of a `SyncDataCollector`.

		Raises:
			TypeError: If the wrapped collector is of neither type.
		"""
		source = self.collector
		if isinstance(source, RolloutCollector):
			return source.total_iterations
		if isinstance(source, SyncDataCollector):
			return source.total_frames
		raise TypeError("Invalid collector type")

	@property
	def random_iterations(self) -> int:
		"""Amount of work configured for the initial random phase.

		Returns:
			int: `random_iterations` of a `RolloutCollector`, or
				`init_random_frames` of a `SyncDataCollector`.

		Raises:
			TypeError: If the wrapped collector is of neither type.
		"""
		source = self.collector
		if isinstance(source, RolloutCollector):
			return source.random_iterations
		if isinstance(source, SyncDataCollector):
			return source.init_random_frames
		raise TypeError("Invalid collector type")

	@property
	def rem_iterations(self) -> int:
		"""Amount of work the collector still has to do.

		Returns:
			int: `rem_rollouts` of a `RolloutCollector`, or
				`total_frames - _frames` of a `SyncDataCollector`.

		Raises:
			TypeError: If the wrapped collector is of neither type.
		"""
		source = self.collector
		if isinstance(source, RolloutCollector):
			return source.rem_rollouts
		if isinstance(source, SyncDataCollector):
			# `_frames` is a private counter of the TorchRL collector.
			return source.total_frames - source._frames
		raise TypeError("Invalid collector type")

	def random_phase(self) -> bool:
		"""Tell whether the collector is still in its random phase.

		Returns:
			bool: `True` while a `RolloutCollector` has random rollouts
				left, or while a `SyncDataCollector` has collected
				fewer frames than `init_random_frames`.

		Raises:
			TypeError: If the wrapped collector is of neither type.
		"""
		source = self.collector
		if isinstance(source, RolloutCollector):
			return source.rem_random_rollouts > 0
		if isinstance(source, SyncDataCollector):
			return source._frames < source.init_random_frames
		raise TypeError("Invalid collector type")

	def update_bar(self, traj: TensorDictBase) -> int:
		"""Progress increment corresponding to one collected batch.

		Args:
			traj (TensorDictBase): Data that was just yielded by the
				collector.

		Returns:
			int: `1` for a `RolloutCollector` (one rollout per batch),
				or `traj.numel()` for a `SyncDataCollector`.

		Raises:
			TypeError: If the wrapped collector is of neither type.
		"""
		source = self.collector
		if isinstance(source, RolloutCollector):
			return 1
		if isinstance(source, SyncDataCollector):
			return traj.numel()
		raise TypeError("Invalid collector type")


class RolloutCollector(DataCollectorBase):
	"""Data collector for collecting rollouts from the environment.

	Compared to the `SyncDataCollector` from TorchRL, this collector
	allows performing whole rollouts instead of individual steps. In
	particular, this means that it does not define a batch size, but
	collects a complete trajectory at a time, regardless of the number
	of steps it contains.

	The first `init_random_rollouts` rollouts are performed without a
	policy (``policy=None`` is passed to `env.rollout`); the remaining
	ones use the configured policy.

	Args:
		create_env_fn (EnvBase): Environment to collect data from. It
			is moved to `device`.
		policy (Optional[TensorDictModule], optional): Policy to use for
			rollouts. Defaults to `None`. If `None`, a random policy is
			used.
		total_rollouts (int): Number of rollouts to collect in addition
			to the initial random ones.
		init_random_rollouts (int): Number of initial random rollouts.
		max_frames_per_traj (int): Maximum number of frames per
			trajectory.
		exploration_type (str, optional): Type of exploration to use,
			either `"random"` or `"deterministic"`. Defaults to
			`"random"`.
		device (torch.device): Device the environment is moved to.
	"""

	def __init__(
		self,
		create_env_fn: EnvBase,
		policy: Optional[TensorDictModule] = None,
		*,
		total_rollouts: int,
		init_random_rollouts: int,
		max_frames_per_traj: int,
		exploration_type: str = "random",
		device: torch.device,
	) -> None:
		assert exploration_type in [
			"random",
			"deterministic",
		], "Invalid exploration type"

		self.env = create_env_fn.to(device)

		if policy is None:
			self.policy = RandomPolicy(self.env.full_action_spec)
		else:
			self.policy = policy

		# Random rollouts are counted on top of the requested rollouts.
		self.rem_rollouts = total_rollouts + init_random_rollouts
		self.rem_random_rollouts = init_random_rollouts
		self.max_frames_per_traj = max_frames_per_traj
		self.total_iterations = total_rollouts + init_random_rollouts
		self.random_iterations = init_random_rollouts
		self.exploration_type = InteractionType.from_str(exploration_type)

	def iterator(self) -> Iterator[TensorDictBase]:
		"""Returns an iterator over the collected rollouts.

		Yields:
			TensorDictBase: Collected rollout.
		"""
		while self.rem_rollouts > 0:
			# The exploration context is entered inside
			# `_perform_rollout`. It must not wrap the `yield`: a
			# generator suspended inside the context would keep the
			# exploration type overridden while the consumer runs its
			# own code between two rollouts.
			yield self._perform_rollout()

	def set_seed(self, seed: int, static_seed: bool = False) -> int:
		"""Sets the seed for the environment.

		Args:
			seed (int): Seed to set.
			static_seed (bool, optional): Whether to use a static seed.
				Defaults to `False`.

		Returns:
			int: Value returned by the environment's `set_seed`.
		"""
		return cast(int, self.env.set_seed(seed, static_seed=static_seed))

	def shutdown(self) -> None:
		"""Closes the environment unless it is already closed."""
		if not self.env.is_closed:
			self.env.close()

	def state_dict(self) -> OrderedDict:
		"""Returns the state dictionary.

		The dictionary contains `env_state_dict`, `policy_state_dict`
		(only when the policy exposes `state_dict`), `metrics` (the
		rollout counters and limits) and `settings` (the exploration
		type, stored as the enum member name).

		Returns:
			OrderedDict: State dictionary.
		"""
		if isinstance(self.env, TransformedEnv):
			env_state_dict = self.env.transform.state_dict()
		elif isinstance(self.env, BatchedEnvBase):
			env_state_dict = self.env.state_dict()
		else:
			env_state_dict = OrderedDict()

		if hasattr(self.policy, "state_dict"):
			policy_state_dict = self.policy.state_dict()  # type: ignore
			state_dict = OrderedDict(
				policy_state_dict=policy_state_dict,
				env_state_dict=env_state_dict,
			)
		else:
			state_dict = OrderedDict(env_state_dict=env_state_dict)

		state_dict.update(
			metrics={
				"rem_rollouts": self.rem_rollouts,
				"rem_random_rollouts": self.rem_random_rollouts,
				"max_frames_per_traj": self.max_frames_per_traj,
				"total_iterations": self.total_iterations,
				"random_iterations": self.random_iterations,
			},
			settings={
				# Store the bare member name (e.g. "RANDOM"): the
				# `str()` form of an enum member carries the class name
				# as a prefix, which `InteractionType.from_str` is not
				# expected to accept when the state is loaded back.
				"exploration_type": self.exploration_type.name,
			},
		)

		return state_dict

	def load_state_dict(self, state_dict: OrderedDict, **kwargs) -> None:  # noqa: ANN003
		"""Loads the state dictionary.

		Args:
			state_dict (OrderedDict): State dictionary to load. It must
				be a dictionary that was returned by `state_dict`.
			**kwargs: Additional keyword arguments. Currently unused.
		"""
		self.env.load_state_dict(state_dict["env_state_dict"])

		if "policy_state_dict" in state_dict:
			assert not isinstance(
				self.policy, RandomPolicy
			), "Trying to load state dict to random policy"
			self.policy.load_state_dict(state_dict["policy_state_dict"])
		else:
			# No policy weights were saved: fall back to a random policy.
			self.policy = RandomPolicy(self.env.full_action_spec)

		metrics = state_dict["metrics"]
		self.rem_rollouts = metrics["rem_rollouts"]
		self.rem_random_rollouts = metrics["rem_random_rollouts"]
		self.max_frames_per_traj = metrics["max_frames_per_traj"]
		self.total_iterations = metrics["total_iterations"]
		self.random_iterations = metrics["random_iterations"]

		# Checkpoints written by earlier versions stored
		# `str(exploration_type)` ("InteractionType.RANDOM"). Keep only
		# the part after the last dot so both that form and the bare
		# member name ("RANDOM") can be loaded.
		stored_type = str(state_dict["settings"]["exploration_type"])
		self.exploration_type = InteractionType.from_str(
			stored_type.rpartition(".")[2]
		)

	@torch.no_grad()
	def _perform_rollout(self) -> TensorDictBase:
		"""Performs a rollout.

		If the policy is a `torch.nn.Module`, it is switched to eval
		mode for the duration of the rollout and put back into the mode
		it was in before, even if the rollout raises. The rollout
		counters are only decremented when the rollout succeeds.

		Returns:
			TensorDictBase: Clone of the collected rollout.
		"""
		is_module = isinstance(self.policy, torch.nn.Module)
		# Remember the current mode so it can be restored afterwards.
		was_training = self.policy.training if is_module else False  # type: ignore

		if is_module:
			self.policy.eval()  # type: ignore

		try:
			with set_exploration_type(self.exploration_type):
				# While random rollouts remain, no policy is passed to
				# the environment.
				policy = self.policy if self.rem_random_rollouts <= 0 else None
				rollout = self.env.rollout(
					policy=policy,
					max_steps=self.max_frames_per_traj,
					auto_reset=True,
					auto_cast_to_device=True,
				)
		finally:
			if is_module:
				self.policy.train(was_training)  # type: ignore

		self.rem_random_rollouts -= 1
		self.rem_rollouts -= 1

		return rollout.clone()  # type: ignore
