import random
from typing import List

import torch
from tensordict import TensorDictBase

from src.async_rl.module.utils import Key


@torch.no_grad()
@torch.jit.script
def sample_oneseq(seq: torch.Tensor, ratio: float) -> torch.Tensor:
	"""Sample signals' occurrences in the input sequence.

	Given a sequence of size `(L, D + 1)` with observations for each
	timestep, this function keeps only a random subset of the timesteps
	at which a signal was observed, to simulate higher latency in the
	signal observations. Between two kept observations the signal is
	held at the last kept value and the time column counts up from 0.
	Row 0 always anchors the first segment. Its time column keeps
	counting from the original value of `seq[0, -1]`.

	Args:
		seq (torch.Tensor): The input sequence of shape `(L, D + 1)`,
			where `L` is the sequence length and `D` is the signal
			dimension. The last column of the sequence indicates the
			time passed since the last signal observation; a value of
			`0` marks a timestep at which a signal was observed.
		ratio (float): The ratio of signals to be kept over the ones
			that are already observed. At least one observation is
			always kept.

	Returns:
		torch.Tensor: A tensor with the same shape and dtype as `seq`
			simulating higher latency. `seq` itself is returned
			unchanged when it contains no observation.
	"""
	# Timesteps at which a signal was actually observed.
	seen_idxs = torch.where(seq[:, -1] == 0)[0]
	if len(seen_idxs) == 0:
		return seq
	# Time elapsed at row 0 since the previous observation. The segment
	# anchored at row 0 keeps counting from this offset.
	init_time = seq[0, -1].to(torch.long)
	num = max(1, int(len(seen_idxs) * ratio))

	keep_idxs, _ = seen_idxs[torch.randperm(len(seen_idxs))[:num]].sort()
	# Row 0 must anchor the first segment: nothing before the sequence is
	# available to fill the timesteps preceding the first kept observation.
	if keep_idxs[0] != 0:
		keep_idxs = torch.cat(
			[
				torch.tensor([0], device=seq.device),
				keep_idxs,
			]
		)

	# Number of rows covered by each kept observation (sums to L).
	diffs = torch.diff(
		keep_idxs,
		append=torch.tensor([seq.size(0)], device=seq.device),
	)
	sub = torch.repeat_interleave(keep_idxs, diffs, dim=0)
	new_last_seen = torch.arange(0, len(seq), device=seq.device) - sub
	# `keep_idxs[0]` is always 0, so `diffs[0]` is the length of the first
	# segment. This also holds when only row 0 is kept; indexing
	# `keep_idxs[1]` would then be out of bounds.
	first_len = int(diffs[0])
	new_last_seen[:first_len] += init_time

	new_signals = seq[keep_idxs, :-1].repeat_interleave(diffs, dim=0)

	# Cast explicitly so the output dtype matches the input dtype instead
	# of depending on `torch.cat` type promotion with an int64 column.
	return torch.cat(
		[
			new_signals,
			new_last_seen.unsqueeze(-1).to(seq.dtype),
		],
		dim=1,
	)


@torch.no_grad()
def sample_signals(
	td: TensorDictBase,
	signal_keys: List[Key],
	min_ratio: float = 0.1,
	per_batch: bool = True,
	aggregate: bool = False,
) -> TensorDictBase:
	"""Randomly thin out the signal observations stored in a tensordict.

	For every signal key, the stored sequence is extended with the last
	step of its `"next"` counterpart. Each batch element is then passed
	through `sample_oneseq`, and the result is split back into the
	current view and the shifted-by-one `"next"` view. This simulates
	higher latency in the signal observations.

	Args:
		td (TensorDictBase): Tensor dictionary holding the signal
			observations. It is cloned first; the input is not modified.
		signal_keys (List[Key]): Keys under which the signals are stored,
			both at the root and under `"next"`.
		min_ratio (float): Lower bound of the uniformly drawn keep ratio;
			the upper bound is `1.0`. Defaults to `0.1`.
		per_batch (bool): If `True`, draw one ratio per batch element;
			otherwise a single ratio is shared by the whole batch. The
			same ratios are reused for every signal key. Defaults to
			`True`.
		aggregate (bool): If `True`, leave the signal keys untouched and
			instead concatenate the sampled signals with
			`"action_history"` along the last dimension, storing them
			under `"observation"` and `("next", "observation")`.
			Defaults to `False`.

	Returns:
		TensorDictBase: The cloned tensor dictionary with the sampled
			signals.
	"""
	out = td.clone()

	if not per_batch:
		# A single draw shared by every batch element.
		shared_ratio = random.uniform(min_ratio, 1.0)
		ratios = [shared_ratio] * len(td)
	else:
		ratios = [random.uniform(min_ratio, 1.0) for _ in range(len(td))]

	obs_parts: List[torch.Tensor] = []
	next_obs_parts: List[torch.Tensor] = []

	for key in signal_keys:
		# Append the final "next" step so that one sequence covers both
		# the current view and the shifted-by-one "next" view.
		head = out[key]
		tail = out["next"][key][:, -1].unsqueeze(1)
		full = torch.cat([head, tail], dim=1)

		thinned = torch.stack(
			[
				sample_oneseq(row, ratio)
				for row, ratio in zip(full, ratios, strict=True)
			],
			dim=0,
		)
		current, upcoming = thinned[:, :-1], thinned[:, 1:]

		if aggregate:
			obs_parts.append(current)
			next_obs_parts.append(upcoming)
			continue

		out.set(key, current)
		out.set(("next", key), upcoming)

	if not aggregate:
		return out

	obs_parts.append(out["action_history"])
	next_obs_parts.append(out["next"]["action_history"])
	out.set("observation", torch.cat(obs_parts, dim=-1))
	out.set(("next", "observation"), torch.cat(next_obs_parts, dim=-1))
	return out
