from typing import Dict, List, Optional, Tuple

import numpy as np
from ray.rllib.policy.sample_batch import SampleBatch
import torch

from offline_rl.data.rllib_data_utils import load_sample_batches


class TransitionShufflingDataset(torch.utils.data.Dataset):
    """A torch dataset that shuffles (state, action, next state) triples.

    This is useful for training models with respect to a distribution over triples
    that doesn't reflect that of the transition model that generated the data. For example,
    the distribution used by EPIC (see reward evaluation directory).

    The shuffling pairs each state from each sample with `num_pairs` (action, next_state) pairs
    drawn uniformly at random (with replacement) from the loaded batch, so each time this dataset
    is created the dataset is different (unless `seed` is given).

    Args:
        inputs: List of files or file patterns from which to load the sample batches.
            This is the same as rllib's JsonReader input.
        num_pairs: Number of (action, next_state) pairs to associate with each original state.
            The dataset contains `num_pairs` samples per loaded transition.
        maintain_original_pairing: If True, the first pair of each state is its original
            (action, next_state) pairing. Requires `num_pairs >= 2`.
        debug_size: Forwarded to `load_sample_batches`; presumably caps the number of loaded
            transitions (the dataset then has `num_pairs` times that many samples).
        debug_size_mode: Mode for loading debug_size. See `load_sample_batches` documentation.
        seed: Optional seed for the random pairing. If None, numpy's global random state is
            used; otherwise a dedicated generator seeded with this value is used, which makes
            the pairing reproducible.

    Raises:
        ValueError: If `num_pairs < 1`, or if `maintain_original_pairing` is True and
            `num_pairs < 2`.
    """
    def __init__(
            self,
            inputs: List[str],
            num_pairs: int,
            maintain_original_pairing: bool = True,
            debug_size: Optional[int] = None,
            debug_size_mode: str = "ordered",
            seed: Optional[int] = None,
    ):
        self.batch = load_sample_batches(inputs, debug_size, debug_size_mode)
        # Maps each sample index to (state index, action/next-state index) into self.batch.
        self.sample_index_to_batch_indices = self._compute_sample_indices(
            num_pairs, maintain_original_pairing, seed=seed)

    def _compute_sample_indices(
            self,
            num_pairs: int,
            maintain_original_pairing: bool,
            seed: Optional[int] = None,
    ) -> Dict[int, Tuple[int, int]]:
        """Builds the mapping from sample indices to pairs of indices into the original batch.

        Args:
            num_pairs: The number of pairs of (action, next_state) to associate with each original state.
            maintain_original_pairing: If True, the first pair of each state is the original pairing.
            seed: If None, draw the pairing from numpy's global random state. Otherwise use a
                dedicated generator seeded with this value.

        Returns:
            A mapping from sample index (0 .. len(self.batch) * num_pairs - 1) to a tuple of indices.
            The first element of the tuple is an index into self.batch where to get the state.
            The second element of the tuple is an index into self.batch where to get the action
            and next state. Sample `i` uses state `i // num_pairs`.

        Raises:
            ValueError: If `num_pairs < 1`, or if `maintain_original_pairing` is True and
                `num_pairs < 2`.
        """
        # Validate before touching any random state so a bad call has no side effects.
        if num_pairs < 1:
            raise ValueError(f"num_pairs must be positive, got {num_pairs}.")
        if maintain_original_pairing and num_pairs < 2:
            raise ValueError(
                "If maintaining original pairing, at least two pairs should be used "
                f"(got num_pairs={num_pairs})."
            )

        num_states = len(self.batch)
        if num_states == 0:
            # Sampling from an empty range raises in numpy, so an empty batch yields an empty dataset.
            return {}

        # Each state index is repeated num_pairs times, so sample i belongs to state i // num_pairs.
        state_indices = np.repeat(np.arange(num_states), num_pairs)
        num_samples = len(state_indices)
        if seed is None:
            action_next_state_indices = np.random.randint(num_states, size=num_samples)
        else:
            rng = np.random.default_rng(seed)
            action_next_state_indices = rng.integers(num_states, size=num_samples)

        if maintain_original_pairing:
            # This sets the index of the first (action, next_state) pair associated with each state
            # to be the index of the original state.
            action_next_state_indices[::num_pairs] = state_indices[::num_pairs]

        return dict(enumerate(zip(state_indices.tolist(), action_next_state_indices.tolist())))

    def __len__(self) -> int:
        """Returns the number of samples, i.e. len(self.batch) * num_pairs."""
        return len(self.sample_index_to_batch_indices)

    def __getitem__(self, sample_index: int) -> Dict:
        """Returns one sample: a state combined with (possibly) another transition's action.

        The observation comes from one batch entry. The action, next observation and done flag
        come from the batch entry it was paired with.

        Raises:
            IndexError: If `sample_index` is not a valid sample index. IndexError (rather than the
                underlying dict's KeyError) is what Python's `__getitem__`-based iteration protocol
                uses to stop, so `for sample in dataset` terminates cleanly.
        """
        try:
            state_index, action_next_state_index = self.sample_index_to_batch_indices[sample_index]
        except KeyError:
            raise IndexError(
                f"Sample index {sample_index} out of range for dataset of size {len(self)}."
            ) from None
        # Probably doesn't make sense to make use of terminals downstream, but include to adhere to interfaces.
        return {
            SampleBatch.OBS: np.atleast_1d(self.batch[SampleBatch.OBS][state_index]),
            SampleBatch.ACTIONS: np.atleast_1d(self.batch[SampleBatch.ACTIONS][action_next_state_index]),
            SampleBatch.NEXT_OBS: np.atleast_1d(self.batch[SampleBatch.NEXT_OBS][action_next_state_index]),
            SampleBatch.DONES: np.atleast_1d(self.batch[SampleBatch.DONES][action_next_state_index]),
        }
