import os
import numpy as np
import h5py as h5
from addict import Dict
from contextlib import contextmanager

import torch
from torch.utils.data import Dataset, DataLoader
from mawm.utils import jagged_slice


class RailEnvTrajectories(Dataset):
    """
    Dataset of fixed-length sequence samples cut out of pre-recorded RailEnv trajectories.

    All arrays stored in the HDF5 file at `path` are read into memory on construction (optionally
    restricted to a subset of the trajectories, see `split`). The methods of this class read the
    following entries: `actions`, `positions`, `directions`, `map_spec`, `dones` and `targets`.
    Every sample is a window of `sequence_length` consecutive timesteps of a single trajectory.
    """
    # Globals (specific to the environment irrespective of the settings)
    NUM_ACTIONS = 5
    # Channels of the 23-channel state that make up the prediction target: 16 map channels and
    # 4 direction channels (0..19) plus the valid-marker (21). Ego (20) and done (22) markers are dropped.
    TARGET_CHANNELS = list(range(16 + 4)) + [21]

    def __init__(self, window_size, sequence_length, path=None, temporal_sample_spacing=1, split=None):
        """
        Parameters
        ----------
        window_size : list of int
            Size of the window visible to the agent. All entries must be odd.
        sequence_length : int
            Length of the input sequence.
        path : str
            Path to load the trajectories from. If None, the environment variable `TRAJ_PATH` is used.
        temporal_sample_spacing : int
            Temporal spacing between sequence samples. If set to N, every N-th timestep marks the
            possible start of a new sequence sample.
        split : float
            What fraction of dataset to use. The default behaviour (None) is to use everything.
            If set to -0.8 and there are 100 trajectories in the dataset, we use the first 80 = 0.8 * 100.
            If set to +0.8, we use the last 20.
        """
        if path is None:
            path = os.environ.get('TRAJ_PATH')
        assert path is not None, "Please specify the trajecotry location."
        self.path = path
        # Materialize once, so that one-shot iterables are not exhausted by the check below.
        window_size = list(window_size)
        assert all((_ws % 2) == 1 for _ws in window_size), "Window sizes must be odd."
        self.window_size = window_size
        self.sequence_length = sequence_length
        self.temporal_sample_spacing = temporal_sample_spacing
        self.split = split
        # Privates
        self._no_crop = False
        self._no_target_shift = False
        # Load up
        self._load_trajectories()

    def _load_trajectories(self):
        """Read every entry of the HDF5 file into memory, restricted to the configured split."""
        self.trajectories = Dict()
        with h5.File(self.path, 'r') as f:
            # Read number of trajectories
            num_trajectories = f['actions'].shape[0]
            split_slice = self._get_split_slice(num_trajectories)
            # The same slice is applied along the first axis of every entry.
            for key in f.keys():
                self.trajectories[key] = np.asarray(f[key][split_slice])

    def _get_split_slice(self, num_trajectories):
        """
        Translate `self.split` to a slice over the trajectory axis.

        Returns
        -------
        slice
            Everything if `split` is None, the first `|split|` fraction if `split` is negative and
            everything after the first `split` fraction if `split` is positive.

        Raises
        ------
        ValueError
            If `split` is neither None, negative nor positive (e.g. if it is zero).
        """
        if self.split is None:
            return slice(None)
        if self.split < 0:
            return slice(0, int(num_trajectories * (-self.split)))
        if self.split > 0:
            return slice(int(num_trajectories * self.split), None)
        raise ValueError("`split` must be None or a non-zero number, got {!r}.".format(self.split))

    @contextmanager
    def no_crop(self):
        """Context manager inside which states are returned for the full environment (no local crop)."""
        previous = self._no_crop
        self._no_crop = True
        try:
            yield
        finally:
            # Restore even if the body raised, and do not clobber an enclosing `no_crop` context.
            self._no_crop = previous

    @contextmanager
    def no_target_shift(self):
        """Context manager inside which target states are not shifted one step to the future."""
        previous = self._no_target_shift
        self._no_target_shift = True
        try:
            yield
        finally:
            # Restore even if the body raised, and do not clobber an enclosing context.
            self._no_target_shift = previous

    @property
    def env_shape(self):
        """Spatial shape (H, W) of the environment, read off the NHWC `map_spec` array."""
        return self.trajectories.map_spec.shape[1:3]

    @property
    def num_agents(self):
        """Number of agents, i.e. the size of the last axis of `actions`."""
        return self.trajectories.actions.shape[-1]

    @property
    def num_directions(self):
        """Number of directions an agent can point towards."""
        return 4

    @property
    def num_steps(self):
        """Number of timesteps per trajectory."""
        return self.trajectories.actions.shape[1]

    @property
    def num_trajectories(self):
        """Number of loaded trajectories (after applying the split)."""
        return self.trajectories.actions.shape[0]

    @staticmethod
    def _r(size):
        """Return `[0, 1, ..., size - 1]` as a list (usable as an advanced index)."""
        return list(range(size))

    def _compute_global_markers(self, traj_idx, start_t):
        """
        Rasterize agent positions and directions onto the environment grid.

        Returns
        -------
        ego_markers : numpy.ndarray
            float32 array of shape TA1HW which is one at the position of the respective agent.
        dir_markers : numpy.ndarray
            float32 array of shape TA4HW, one-hot (over the channel axis) in the direction of the
            respective agent at the position of that agent and zero everywhere else.
        universal_dir_markers : numpy.ndarray
            float32 array of shape T14HW, the sum of `dir_markers` over all agents.
        """
        seq_len = self.sequence_length
        num_agents = self.num_agents
        positions = self.trajectories.positions
        directions = self.trajectories.directions
        # The trajectories are indexed with absolute timesteps, whereas the marker buffers only span
        # the sequence and must therefore be indexed with timesteps relative to start_t.
        absolute_t = list(range(start_t, start_t + seq_len))
        local_t = self._r(seq_len)
        marker_shape = (seq_len, num_agents, 1) + tuple(self.env_shape)
        # One-hot ego-marker for all agents (shape: TACHW with C = 1, i.e. TA1HW)
        ego_markers = np.zeros(marker_shape, dtype='float32')
        # Integer direction markers for all agents (TA1HW). Zero means "no agent here", hence the
        # directions are stored with an offset of one.
        integer_dir_markers = np.zeros(marker_shape, dtype='uint8')
        for agent_idx in range(num_agents):
            rows = positions[traj_idx, absolute_t, agent_idx, 0]
            cols = positions[traj_idx, absolute_t, agent_idx, 1]
            ego_markers[local_t, agent_idx, 0, rows, cols] = 1
            integer_dir_markers[local_t, agent_idx, 0, rows, cols] = \
                directions[traj_idx, absolute_t, agent_idx] + 1
        # Now one-hotify to TA4HW
        dir_markers = np.equal(np.arange(1, self.num_directions + 1)[None, None, :, None, None],
                               integer_dir_markers).astype('float32')
        # Summing over agents yields a map containing the direction of every agent.
        universal_dir_markers = dir_markers.sum(axis=1, keepdims=True)
        return ego_markers, dir_markers, universal_dir_markers

    def _fetch_global_state(self, traj_idx, start_t):
        """
        Fetch the agent markers of a sequence on the full environment grid.

        Returns
        -------
        addict.Dict
            With keys `ego_markers` (TA1HW), `dir_markers` (TA4HW) and `universal_dir_markers` (T14HW).
        """
        ego_markers, dir_markers, universal_dir_markers = self._compute_global_markers(traj_idx, start_t)
        storage = Dict()
        storage.ego_markers = ego_markers
        storage.dir_markers = dir_markers
        storage.universal_dir_markers = universal_dir_markers
        return storage

    def _fetch_and_crop_global_state(self, traj_idx, start_t, globals_=None):
        """
        Assemble the state of every agent and (unless `no_crop` is active) crop it to the agent's window.

        Parameters
        ----------
        traj_idx : int
            Trajectory index.
        start_t : int
            First timestep of the sequence.
        globals_ : object
            Optional pre-computed result of `_fetch_global_state`; only its attributes `dir_markers`
            and `ego_markers` are read.

        Returns
        -------
        numpy.ndarray
            A TA(22)hw array made of 16 map-spec channels, 4 direction channels, 1 ego-marker channel
            and 1 valid-marker channel (in that order).
        """
        # First things first, get global state if not available
        if globals_ is None:
            globals_ = self._fetch_global_state(traj_idx, start_t)
        # map_spec is originally NHWC. Slice to HWC, permute to CHW, and expand to TACHW. A broadcast
        # view suffices because the concatenation below copies the data anyway.
        map_chw = np.moveaxis(self.trajectories.map_spec[traj_idx], 2, 0)
        map_spec = np.broadcast_to(map_chw, (self.sequence_length, self.num_agents) + map_chw.shape)
        # NOTE(review): the per-agent `dir_markers` are used here, so the direction channels of an agent
        # only contain that agent's own direction. If every agent is supposed to see the directions of
        # all agents in its FOV, `universal_dir_markers` might have been intended - confirm.
        global_state = np.concatenate([map_spec, globals_.dir_markers, globals_.ego_markers], axis=2).astype('float32')
        if not self._no_crop:
            # Slice with a TA2 position tensor to obtain a local crop (TA(21)hw) and valid markers (TAhw)
            local_state, valid_positions = \
                jagged_slice(global_state,
                             self.trajectories.positions[traj_idx, start_t:start_t+self.sequence_length],
                             np.array(self.window_size))
            # Concatenate local_state with the valid marker
            local_state = np.concatenate([local_state, valid_positions[:, :, None, :, :]], axis=2)
        else:
            # We actually want the global state (TA(21)HW), but in the same format as the local state, so
            # we cat an array of ones denoting valid states, since all positions are valid.
            local_state = np.concatenate([global_state, np.ones_like(global_state[:, :, 0:1, :, :])], axis=2)
        # local_state is now the TA(22)hw tensor we're after.
        return local_state

    def _fetch_done_markers(self, traj_idx, start_t, bloat=False):
        """
        Fetch binary markers which are one for all timesteps strictly greater than the respective
        entry of `dones`.

        Parameters
        ----------
        traj_idx : int
            Trajectory index.
        start_t : int
            First timestep of the sequence.
        bloat : bool
            Whether to tile the TA markers to TA1hw, where (h, w) is the window size (or the shape
            of the environment if `no_crop` is active).

        Returns
        -------
        numpy.ndarray
            float32 array of shape TA, or TA1hw if `bloat` is set.
        """
        # dones[traj_idx] is a (A,) tensor. We compare it against all timesteps to obtain a TA tensor.
        timesteps = np.arange(start_t, start_t + self.sequence_length)
        agent_dones = np.greater(timesteps[:, None], self.trajectories.dones[traj_idx][None, :]).astype('float32')
        # Repeat to TA1hw if need be
        if bloat:
            h, w = self.window_size if not self._no_crop else self.env_shape
            # The copy makes the result a regular (writable) array instead of a broadcast view.
            agent_dones = np.broadcast_to(agent_dones[:, :, None, None, None],
                                          agent_dones.shape + (1, h, w)).copy()
        return agent_dones

    def _fetch_onehot_actions(self, traj_idx, start_t):
        """Fetch the actions of a sequence as a one-hot float32 array of shape TA5."""
        # actions is a TA tensor
        actions = self.trajectories.actions[traj_idx, start_t:(start_t + self.sequence_length)]
        # We convert it to TA5 tensor, because there are 5 actions in the environment
        return np.equal(np.arange(self.NUM_ACTIONS)[None, None, :], actions[:, :, None]).astype('float32')

    def _fetch_positions(self, traj_idx, start_t):
        """Fetch the positions of a sequence as a float32 array of shape TA2."""
        return self.trajectories.positions[traj_idx, start_t:(start_t + self.sequence_length)].astype('float32')

    def _fetch_goals(self, traj_idx):
        """Fetch the goals of a trajectory as a float32 array of shape A2."""
        # The goals are positions, but constant over time. This makes them A2 tensors.
        return self.trajectories.targets[traj_idx].astype('float32')

    @property
    def num_valid_starts(self):
        """
        Number of timesteps per trajectory at which a sequence sample may start.

        This is zero if the trajectories are shorter than `sequence_length`.
        """
        # Last timestep at which a full sequence still fits in the trajectory
        max_start = self.num_steps - self.sequence_length
        # Starts are 0, spacing, 2 * spacing, ... up to max_start. Clamp, such that trajectories that
        # are too short result in an empty dataset instead of a negative length.
        return max(0, (max_start // self.temporal_sample_spacing) + 1)

    def _parse_index(self, idx):
        """Convert a global (ravelled) sample index to a trajectory index and a start timestep."""
        traj_idx, start_valid_idx = np.unravel_index(idx, (self.num_trajectories, self.num_valid_starts))
        start_t = start_valid_idx * self.temporal_sample_spacing
        return traj_idx, start_t

    def _process_to_target_state(self, state):
        """
        Convert a TA(23)hw state to the target of the world-model's generative model.

        The state contains agent specific channels which we do not expect the world-model to care
        about. We only keep `TARGET_CHANNELS`, i.e. the 16 channels of map-spec, the 4 channels of
        agent directions and the channel containing the valid-marker (the ego and done markers are
        dropped). Moreover, at time t we want the model to predict the state at timestep t + 1, so
        the target is shifted one step to the future unless `no_target_shift` is active.

        Returns
        -------
        numpy.ndarray
            A (T-1)A(21)hw array, or a TA(21)hw array if `no_target_shift` is active.
        """
        first_step = 0 if self._no_target_shift else 1
        return state[first_step:, :, self.TARGET_CHANNELS, :, :]

    def __len__(self):
        return self.num_trajectories * self.num_valid_starts

    def get(self, idx=None, traj_idx=None, start_t=None):
        """
        Generates a sequence sample.

        Parameters
        ----------
        idx : int
            Global (ravelled) index. Accepted only when `traj_idx` and `start_t` are None.
        traj_idx : int
            Trajectory index. Accepted only if `idx` is None.
        start_t : int
            Start time-step. Accepted only if `idx` is None.

        Notes
        -----
        In the following, T is the sequence length, A is the number of agents, h and w are height and width
        of the window (i.e. h, w = self.window_size). By a "TA(C)hw tensor" or "TAChw tensor", we mean a tensor
        with shape (T, A, C, h, w), where C can be a numeral. Note however that if the no_crop context manager
        is active, (h, w) is not the window size but the env_shape.

        Returns
        -------
        actions : numpy.ndarray
            An array of one-hot actions of shape TA5 (there are 5 actions in the environment).
        positions : numpy.ndarray
            An array of integer (i,j)-positions (but of dtype float32) of shape TA2.
        goals : numpy.ndarray
            An array of integer (i,j)-goals (but of dtype float32) of shape A2.
        states : numpy.ndarray
            An array of states of shape TA(23)hw. The 23 channels are comprised of:
                16 channels specifying the map (these are the transition maps in RailEnv)
                04 channels specifying the one-hot direction markers.
                01 channels specifying the ego-marker.
                01 channels specifying the valid-marker, i.e. pixels of the FOV that are valid (not outside env bounds).
                01 channels specifying whether the agent is done (i.e. a binary constant over h,w).
        target_states : numpy.ndarray
            An array of states of shape (T-1)A(21)hw. These 21 channels are comprised of:
                16 channels specifying the map (like `states`).
                04 channels specifying the one-hot direction markers (like `states`).
                01 channels specifying the valid-marker.
            The T-1 is due to the fact that the target states are one step to the future. Use the context manager
            `no_target_shift` to temporarily deactivate this time shift.
        """
        if idx is not None:
            assert traj_idx is None
            assert start_t is None
            # Convert idx to traj_idx and a start_t
            traj_idx, start_t = self._parse_index(idx)
        else:
            assert traj_idx is not None
            assert start_t is not None
        # Fetch onehot actions (a TA5 tensor)
        actions = self._fetch_onehot_actions(traj_idx, start_t)
        # Fetch positions (a TA2 tensor)
        positions = self._fetch_positions(traj_idx, start_t)
        # Fetch goals (a A2 tensor)
        goals = self._fetch_goals(traj_idx)
        # Fetch states (a TA(22)hw tensor)
        states_without_dones = self._fetch_and_crop_global_state(traj_idx, start_t)
        # Fetch done markers (a TA1hw tensor) and append them as the last channel
        dones = self._fetch_done_markers(traj_idx, start_t, bloat=True)
        states = np.concatenate([states_without_dones, dones], axis=2)
        # Get target states
        target_states = self._process_to_target_state(states)
        # Return 'em
        return actions, positions, goals, states, target_states

    def __getitem__(self, item):
        """Return the sample with global index `item` (see `get`) as a tuple of torch tensors."""
        return tuple(torch.from_numpy(array) for array in self.get(idx=item))

def rail_env_trajectories_data_loader(*dataset_args, num_workers=0, batch_size, shuffle=True, pin_memory=True,
                                      **dataset_kwargs):
    """
    Build a `DataLoader` over a freshly constructed `RailEnvTrajectories` dataset.

    Parameters
    ----------
    dataset_args
        Positional arguments forwarded to `RailEnvTrajectories`.
    num_workers : int
        Number of worker processes of the loader.
    batch_size : int
        Number of samples per batch (required, keyword-only).
    shuffle : bool
        Whether the loader shuffles the samples. Defaults to True; disable for e.g. validation.
    pin_memory : bool
        Whether the loader pins the memory of the batches. Defaults to True.
    dataset_kwargs
        Keyword arguments forwarded to `RailEnvTrajectories`.

    Returns
    -------
    torch.utils.data.DataLoader
    """
    dataset = RailEnvTrajectories(*dataset_args, **dataset_kwargs)
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers,
                      pin_memory=pin_memory)


