import itertools
from typing import Dict, List, Optional, Tuple

import numpy as np
from ray.rllib.policy.sample_batch import SampleBatch
import torch

from offline_rl.data.rllib_data_utils import load_sample_batches


class PairedTrajectorySegmentDataset(torch.utils.data.Dataset):
    """A torch dataset that returns demonstrations / trajectory segments in pairs.

    This is useful for training preference-based reward models. Here's what it does:
    1. Loads rllib trajectories into memory.
        This dataset assumes that it is loading rllib output (SampleBatch).
    2. Splits those trajectories into `segment_length` segments.
        Currently, this class ignores terminal states. As such it does not make sense
        to use this class with a discount factor when computing returns on segments.
        The segments are defined such that no timestep of any trajectory is shared between segments.
        Trailing timesteps that do not fill a complete segment are dropped.
    3. Returns samples where each sample is a pair of segments.
        Specifically, each sample is a dict mapping each key in `SAMPLE_BATCH_KEYS`
        (observations, actions, next observations) to an array produced by stacking the two
        segments along a new leading axis, i.e. of shape (2, segment_length, ...), where the
        leading `2` corresponds to the first and second segments.

    Args:
        inputs: List of files or file patterns from which to load the sample batches.
            This is the same as rllib's JsonReader input.
        segment_length: The number of timesteps in each segment. Must be positive.
        max_num_pairs: The maximum number of pairs to generate for each segment.
            If None, generates a dataset with all unordered pairs of distinct segments.
            If an int, generates this number of pairs per segment (capped at the number of
            segments) randomly upon initialization, using numpy's global random state.
        debug_size: The number of unique timesteps to load. Note that this is _not_ the
            number of _samples_ (i.e., paired trajectory segments).
    """
    # The SampleBatch keys to return paired segments of.
    SAMPLE_BATCH_KEYS = [SampleBatch.OBS, SampleBatch.ACTIONS, SampleBatch.NEXT_OBS]

    def __init__(
            self,
            inputs: List[str],
            segment_length: int,
            max_num_pairs: Optional[int] = None,
            debug_size: Optional[int] = None,
    ):
        # Validate before loading: 0 would divide by zero below and a negative value would
        # silently produce an empty dataset.
        assert segment_length > 0, f"segment_length must be positive, got {segment_length}"
        # The batch must be loaded in order because each segment is a contiguous slice of it.
        self.batch = load_sample_batches(inputs, debug_size, debug_size_mode="ordered")
        # Two full segments are the minimum needed to form a pair.
        assert len(self.batch) >= segment_length * 2, \
            f"At least two segments of length {segment_length} must exist. Length of data: {len(self.batch)}"
        self.segment_length = segment_length
        # Split into segments completely ignoring terminal states; the remainder is dropped.
        self.num_segments = len(self.batch) // self.segment_length
        self.sample_index_to_segment_indices = self._compute_segment_indices(max_num_pairs)

    def _compute_segment_indices(self, max_num_pairs: Optional[int]) -> Dict[int, Tuple[int, int]]:
        """Precomputes mapping from sample index to segment indices.

        Args:
            max_num_pairs: See __init__ docs.

        Returns:
            A dict mapping each sample index in [0, num_samples) to a pair of segment indices.
        """
        if max_num_pairs is None:
            # All unordered pairs (i, j) with i < j, in lexicographic order. This uses
            # O(num_segments ** 2) memory, so it is only suitable for small datasets.
            return dict(enumerate(itertools.combinations(range(self.num_segments), 2)))

        assert max_num_pairs > 0, f"max_num_pairs must be positive, got {max_num_pairs}"
        # Cap max_num_pairs at self.num_segments o/w you end up with repeated pairs for no reason.
        max_num_pairs = min(max_num_pairs, self.num_segments)
        num_samples = self.num_segments * max_num_pairs
        # Each segment appears as the first element of exactly `max_num_pairs` pairs.
        segment_a_indices = np.repeat(np.arange(self.num_segments), max_num_pairs)
        # Sample in a manner potentially allowing pairs of identical segments as well as repeats.
        # Neither should matter with a large dataset.
        segment_b_indices = np.random.randint(self.num_segments, size=num_samples)
        # Convert to plain Python ints so keys and values are not numpy scalar types.
        return dict(enumerate(zip(segment_a_indices.tolist(), segment_b_indices.tolist())))

    def __len__(self) -> int:
        return len(self.sample_index_to_segment_indices)

    def _convert_sample_index_to_segment_indices(self, sample_index: int) -> Tuple[int, int]:
        """Converts a sample index into two indices, one for each relevant segment.

        Negative indices count from the end, as for a Python sequence.

        Raises:
            IndexError: If `sample_index` is out of range. IndexError (rather than KeyError)
                lets iteration via Python's sequence protocol terminate cleanly.
        """
        num_samples = len(self)
        index = sample_index + num_samples if sample_index < 0 else sample_index
        try:
            return self.sample_index_to_segment_indices[index]
        except KeyError:
            raise IndexError(
                f"sample_index {sample_index} out of range for dataset of size {num_samples}"
            ) from None

    def _get_segment(self, segment_index: int) -> Dict:
        """Gets the values associated with a segment.

        Args:
            segment_index: Index of the segment to select, in [0, num_segments).

        Returns:
            Dictionary mapping each key in `SAMPLE_BATCH_KEYS` to the
            [segment_index * segment_length, (segment_index + 1) * segment_length) slice
            of the corresponding batch column.
        """
        start = segment_index * self.segment_length
        end = start + self.segment_length
        return {k: self.batch[k][start:end] for k in self.SAMPLE_BATCH_KEYS}

    def __getitem__(self, sample_index: int) -> Dict:
        """Gets a sample from this dataset.

        Args:
            sample_index: Index of the sample, in [-len(self), len(self)).

        Returns:
            Dict mapping each key in `SAMPLE_BATCH_KEYS` to the two segments' values stacked
            along a new leading axis of size 2 (first segment at index 0, second at index 1).

        Raises:
            IndexError: If `sample_index` is out of range.
        """
        segment_1_index, segment_2_index = self._convert_sample_index_to_segment_indices(sample_index)

        segment_1_tensors = self._get_segment(segment_1_index)
        segment_2_tensors = self._get_segment(segment_2_index)

        sample = dict()
        for key, segment_1_tensor in segment_1_tensors.items():
            segment_2_tensor = segment_2_tensors[key]
            # Every segment spans exactly `segment_length` timesteps (num_segments is computed
            # with floor division), so the shapes are expected to always match.
            assert segment_2_tensor.shape == segment_1_tensor.shape, \
                f"key: {key}, seg 2 shape: {segment_2_tensor.shape}, seg 1 shape: {segment_1_tensor.shape}"
            sample[key] = np.stack((segment_1_tensor, segment_2_tensor))

        return sample
