"""Metrics tracking for asynchronous RL experiments."""

import logging
from collections.abc import Mapping
from typing import Any, Dict, List

import numpy as np


class AsyncMetricsTracker:
	"""Track asynchronicity-specific metrics during training.

	Two kinds of data are accumulated:

	* per-step signal-reception flags (see ``record_step``), summarized by
	  ``get_signal_density_stats``;
	* per-episode total rewards (see ``record_episode_reward``), summarized
	  by ``get_performance_stats``.
	"""

	# Observation keys starting with one of these prefixes are treated as
	# signal channels by ``record_step``.
	_SIGNAL_KEY_PREFIXES = ("front_", "back_", "torso", "thigh", "leg")

	def __init__(self, track_signal_density: bool = True) -> None:
		"""Create an empty tracker.

		Args:
			track_signal_density (bool): If False, ``record_step`` only
				advances the step counter and records no reception flags.
		"""
		self.track_signal_density = track_signal_density
		# One dict per recorded step: signal key -> "received" flag.
		self.signal_reception_history: List[Dict[str, bool]] = []
		self.episode_rewards: List[float] = []
		self.step_count = 0

	def record_step(self, tensordict: Any, env_output: Any = None) -> None:
		"""Record metrics for a single environment step.

		For every string key of ``tensordict`` that starts with one of
		``_SIGNAL_KEY_PREFIXES``, and whose value has at least one
		dimension, the flag ``value[..., -1] == 0.0`` is stored as that
		signal's "received" flag for this step.

		Args:
			tensordict (Any): TensorDict, or any mapping with ``keys()`` and
				item access, containing environment observations. When None,
				only the step counter advances.
			env_output (Any): Optional environment output information;
				currently unused.
		"""
		self.step_count += 1

		if not self.track_signal_density or tensordict is None:
			return

		signal_reception: Dict[str, bool] = {}
		for key in tensordict.keys():
			# Nested TensorDict keys may be tuples. Only string keys can
			# carry a signal-name prefix, so skip the rest instead of
			# crashing on ``startswith``.
			if not isinstance(key, str) or not key.startswith(
				self._SIGNAL_KEY_PREFIXES
			):
				continue
			value = tensordict[key]
			if len(value.shape) == 0:
				continue
			last = value[..., -1]
			# NOTE(review): ``.item()`` needs ``last`` to hold exactly one
			# element (an unbatched observation) — confirm for vector envs.
			if hasattr(last, "item"):
				signal_reception[key] = last.item() == 0.0
			else:
				signal_reception[key] = float(last) == 0.0

		self.signal_reception_history.append(signal_reception)

	def record_episode_reward(self, reward: float) -> None:
		"""Record the total reward for a completed episode."""
		self.episode_rewards.append(reward)

	def get_signal_density_stats(
		self, window_size: int = 1000
	) -> Dict[str, Any]:
		"""Calculate signal reception density statistics.

		Args:
			window_size (int): Number of most recent steps to analyze. A
				non-positive value analyzes no steps.

		Returns:
			Dict[str, Any]: Always contains ``mean_density`` (received flags
			over all flags in the window), ``total_signals`` (number of
			distinct signal keys seen), ``signal_densities`` (per-key
			density) and ``total_steps_analyzed``.
		"""
		# ``history[-0:]`` is the *whole* list, so a non-positive window
		# must be handled explicitly rather than sliced.
		if window_size <= 0:
			recent_history: List[Dict[str, bool]] = []
		else:
			recent_history = self.signal_reception_history[-window_size:]

		if not recent_history:
			return {
				"mean_density": 0.0,
				"total_signals": 0,
				"signal_densities": {},
				"total_steps_analyzed": 0,
			}

		# Per-signal counts; dict insertion order = first appearance.
		received: Dict[str, int] = {}
		totals: Dict[str, int] = {}
		for step_signals in recent_history:
			for signal_name, was_received in step_signals.items():
				totals[signal_name] = totals.get(signal_name, 0) + 1
				received[signal_name] = received.get(signal_name, 0) + (
					1 if was_received else 0
				)

		total_possible = sum(totals.values())
		total_received = sum(received.values())
		overall_density = (
			total_received / total_possible if total_possible > 0 else 0.0
		)

		return {
			"mean_density": overall_density,
			"total_signals": len(totals),
			# Every key in ``totals`` was seen at least once, so no
			# division by zero is possible here.
			"signal_densities": {
				name: received[name] / count for name, count in totals.items()
			},
			"total_steps_analyzed": len(recent_history),
		}

	def get_performance_stats(self) -> Dict[str, float]:
		"""Get episode-reward statistics.

		Returns:
			Dict[str, float]: ``mean_reward``, ``std_reward`` (population
			standard deviation), ``min_reward``, ``max_reward``,
			``final_reward`` (most recent episode) and ``num_episodes``.
			All reward entries are 0.0 when no episode has been recorded.
		"""
		if not self.episode_rewards:
			return {
				"mean_reward": 0.0,
				"std_reward": 0.0,
				"min_reward": 0.0,
				"max_reward": 0.0,
				"final_reward": 0.0,
				"num_episodes": 0,
			}

		rewards = np.array(self.episode_rewards)
		return {
			"mean_reward": float(np.mean(rewards)),
			"std_reward": float(np.std(rewards)),
			"min_reward": float(np.min(rewards)),
			"max_reward": float(np.max(rewards)),
			"final_reward": float(rewards[-1]),
			"num_episodes": len(rewards),
		}

	def get_all_metrics(self, window_size: int = 1000) -> Dict[str, Any]:
		"""Get all tracked metrics in one flat dictionary.

		Args:
			window_size (int): Forwarded to ``get_signal_density_stats``.

		Returns:
			Dict[str, Any]: Union of the signal-density and performance
			statistics (their key sets do not overlap).
		"""
		metrics: Dict[str, Any] = {}
		metrics.update(self.get_signal_density_stats(window_size))
		metrics.update(self.get_performance_stats())
		return metrics

	def reset(self) -> None:
		"""Reset all tracking data."""
		self.signal_reception_history.clear()
		self.episode_rewards.clear()
		self.step_count = 0


def _config_field(obj: Any, name: str) -> Any:
	"""Return ``obj[name]`` for mappings, else ``obj.name``; None if absent."""
	if isinstance(obj, Mapping):
		return obj.get(name)
	return getattr(obj, name, None)


def _is_nonempty(value: Any) -> bool:
	"""Return True if ``value`` is not None and has at least one element."""
	# ``len`` instead of truthiness so numpy arrays do not raise.
	return value is not None and len(value) > 0


def _rate_from_period_range(period_range: Any) -> float:
	"""Convert a ``(low, high)`` period range into a rate capped at 1.0.

	The rate is ``1 / mean(low, high)``; 0.0 when that mean is not positive.
	"""
	avg_period = (period_range[0] + period_range[1]) / 2.0
	return float(min(1.0, 1.0 / avg_period)) if avg_period > 0 else 0.0


def extract_asynchronicity_from_config(config: Any) -> float:
	"""Extract expected asynchronicity rate from configuration.

	Looks for ``env.period_range`` first and falls back to
	``env.train_make_kwargs.period_range``. Each level may be accessed by
	attribute (e.g. namespace/config objects) or by key (mappings).

	Args:
		config (Any): Experiment configuration.

	Returns:
		float: ``min(1.0, 1 / mean(period_range))``, or 0.0 when no usable
		period range is found or the config is malformed.
	"""
	try:
		env_cfg = _config_field(config, "env")
		if env_cfg is None:
			return 0.0
		period_range = _config_field(env_cfg, "period_range")
		if not _is_nonempty(period_range):
			kwargs = _config_field(env_cfg, "train_make_kwargs")
			period_range = _config_field(kwargs, "period_range")
		if _is_nonempty(period_range):
			return _rate_from_period_range(period_range)
	except Exception:
		# Deliberate best-effort: a malformed config must not abort the
		# experiment, but leave a trace for debugging instead of hiding it.
		logging.getLogger(__name__).debug(
			"Could not extract period_range from config", exc_info=True
		)

	return 0.0
