import numpy as np
import torch
from torch.utils.data import Dataset


class PolicyDataset(Dataset):
    def __init__(
        self,
        weights: np.ndarray,
        states: np.ndarray,
        actions: np.ndarray,
        n_state_samples: int | None = None,
        rng: np.random.Generator | None = None,
        device: str | torch.device | None = None,
    ) -> None:
        """
        Args:
            weights: NumPy array of all policy weights (shape: [P, W]).
            states: NumPy array of all states (shape: [P, S, O]).
            actions: NumPy array of all actions (shape: [P, S, A]).
            n_state_samples: Number of state samples per policy. If None,
                all states are used.
            rng: Random number generator for reproducibility. If None,
                uses the default NumPy RNG.
            device: Device to which the tensors should be moved. If None,
                uses the default device.

        P is the total number of policies.
        W is the number of weights for a policy.
        S is the fixed number of steps/states per policy.
        O is the size of the state space.
        A is the size of the action space.
        """
        assert weights.shape[0] == states.shape[0] == actions.shape[0]
        assert states.shape[1] == actions.shape[1]

        self.n_state_samples = n_state_samples

        self.weights = weights
        self.states = states
        self.actions = actions

        self.rng = rng if rng is not None else np.random.default_rng()
        self.device = device

    def __len__(self) -> int:
        return self.weights.shape[0]

    def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
        policy_weights = torch.from_numpy(self.weights[idx]).float()
        policy_states_np = self.states[idx]
        policy_actions_np = self.actions[idx]

        if (
            self.n_state_samples is None
            or self.n_state_samples >= policy_states_np.shape[0]
        ):
            policy_states = torch.from_numpy(policy_states_np).float()
            policy_actions = torch.from_numpy(policy_actions_np).float()
        else:
            idxes = self.rng.choice(
                policy_states_np.shape[0], self.n_state_samples, replace=False
            )
            policy_states = torch.from_numpy(policy_states_np[idxes]).float()
            policy_actions = torch.from_numpy(policy_actions_np[idxes]).float()

        return {
            "weights": policy_weights.to(self.device),
            "states": policy_states.to(self.device),
            "actions": policy_actions.to(self.device),
        }


def custom_collate_fn(
    batch: list[dict[str, torch.Tensor]],
) -> dict[str, torch.Tensor]:
    """Combine per-sample dicts into one dict of batched tensors.

    For each of the keys ``"weights"``, ``"states"`` and ``"actions"``, the
    per-sample tensors are stacked along a new leading batch dimension. Any
    other keys present in the samples are ignored. Every sample must provide
    all three keys with matching tensor shapes; an empty ``batch`` makes
    ``torch.stack`` raise.
    """
    collated: dict[str, torch.Tensor] = {}
    for field in ("weights", "states", "actions"):
        collated[field] = torch.stack([sample[field] for sample in batch])
    return collated
