import os
from contextlib import contextmanager
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader, ConcatDataset
#from ..rail_env import RailEnvTrajectories


class BBTrajectories(Dataset):
    """
    Dataset of binary image trajectories, served as randomly placed partial views.

    Each item is one stored trajectory from the ``'images'`` array of an ``.npz`` file. For every
    timestep, ``num_of_views`` window centers are sampled uniformly at random such that the whole
    window lies inside the image, and the corresponding crops are returned as per-view states.
    Zero-width dummy actions and goals are returned alongside so that samples match the
    (actions, positions, goals, states, target_states) tuple layout the trainer consumes.
    """

    def __init__(self, window_size, sequence_length, num_of_views, path=None, temporal_sample_spacing=1):
        """
        Parameters
        ----------
        window_size : sequence of int
            (height, width) of the window visible to each view. Both entries must be odd so the
            window has a well-defined center pixel.
        sequence_length : int
            Length of the input sequence. NOTE(review): stored but not used by this class; every
            item is the full stored trajectory.
        num_of_views : int
            Number of partial views to be taken of the global env. This is equivalent to
            num_agents in SC2.
        path : str, optional
            Path of the ``.npz`` file to load the trajectories from. Defaults to the ``TRAJ_PATH``
            environment variable.
        temporal_sample_spacing : int
            Temporal spacing between sequence samples. If set to N, every N-th timestep marks the
            possible start of a new sequence sample. NOTE(review): stored but not used by this class.
        """
        if path is None:
            path = os.environ.get('TRAJ_PATH')
        assert path is not None, "Please specify the trajecotry location."
        self.path = path
        assert all([(_ws % 2) == 1 for _ws in window_size]), "Window sizes must be odd."
        self.window_size = list(window_size)
        self.sequence_length = sequence_length
        self.temporal_sample_spacing = temporal_sample_spacing
        self.num_of_views = num_of_views
        # Mode flags toggled by the `no_target_shift` / `no_crop` context managers.
        self._no_target_shift = False
        self._no_crop = False
        # Load up
        self._load_trajectories()

    def _load_trajectories(self):
        """Load ``'images'`` from ``self.path`` into ``self.trajectories`` as an (N, T, 1, H, W) bool array."""
        # NOTE(review): allow_pickle=True permits unpickling object arrays, which can execute
        # arbitrary code from a crafted file -- only load trusted trajectory files.
        # The NpzFile is used as a context manager so its file handle is closed after reading.
        with np.load(self.path, allow_pickle=True) as npzfile:
            # `np.bool` was removed in NumPy 1.24; the builtin `bool` maps to the same dtype.
            trajectories = npzfile['images'].astype(bool)
        # Add a trailing channel axis: (N, T, H, W) -> (N, T, H, W, 1).
        trajectories = np.expand_dims(trajectories, axis=4)
        assert len(trajectories.shape) == 5, \
            "Trajectories dims {} must be N,T,H,W,C".format(trajectories.shape)
        # Channels-first per frame: (N, T, H, W, C) -> (N, T, C, H, W).
        self.trajectories = np.transpose(trajectories, (0, 1, 4, 2, 3))
        self.limit = self.trajectories.shape[0]

    @contextmanager
    def no_crop(self):
        """Temporarily return the full global frame (repeated once per view) instead of crops."""
        previous = self._no_crop
        self._no_crop = True
        try:
            yield
        finally:
            # Restore even if the body raised, so the dataset is not left in no-crop mode.
            self._no_crop = previous

    @contextmanager
    def no_target_shift(self):
        """Temporarily make ``target_states`` equal to ``states`` instead of shifted one step ahead."""
        previous = self._no_target_shift
        self._no_target_shift = True
        try:
            yield
        finally:
            self._no_target_shift = previous

    def _generate_centers(self, seq, num_of_views, window_size):
        """
        Generate the valid center points for the partial random views in the global
        view along the sequence. Since we don't have specifically defined agents,
        we define random centers within the image plane as the position of the views.

        Parameters
        -----------
        seq : numpy.ndarray
            The sequence, shape (T, C, H, W).
        num_of_views: int
            The number of partial views to be taken.
        window_size: sequence of int
            (height, width) of the views. Used to make sure that every sampled window lies
            entirely within the image.

        Notes
        ------
        Rows are sampled against the window/image height and columns against the window/image
        width, so non-square windows and images are supported. Two views may coincidentally
        share the same center; this is not checked.

        Returns
        --------
        partial_views_centers: numpy.ndarray (T, A, 2)
            Integer (row, column) centers.
        """
        num_steps = seq.shape[0]
        img_h, img_w = seq.shape[-2:]
        h, w = window_size[0], window_size[1]
        # randint's upper bound is exclusive, so the largest center is img - k//2 - 1 and the
        # last window row/column (center + k//2) is still inside the image.
        row_low, row_high = h // 2, img_h - (h // 2)
        col_low, col_high = w // 2, img_w - (w // 2)
        centers = np.empty((num_steps, num_of_views, 2), dtype=int)
        for t in range(num_steps):
            # One draw for rows, then one for columns, per timestep.
            centers[t, :, 0] = np.random.randint(row_low, row_high, num_of_views)
            centers[t, :, 1] = np.random.randint(col_low, col_high, num_of_views)
        return centers

    def _crop_partial_views(self, seq, centers, window_size):
        """
        Crop the partial views from global views centered on the random positions
        sampled before.

        Parameters
        -----------
        seq: numpy.ndarray
            The sequences. (T, C, H, W)
        centers: numpy.ndarray
            Integer (row, column) positions of the chosen centers within the global view. (T, A, 2)
        window_size: sequence of int
            (height, width) of the partial views, e.g. [5, 5]. Both are expected to be odd.

        Returns
        --------
        partial_views: numpy.ndarray (T, A, C, h, w), dtype uint8
        """
        T, A = centers.shape[0], centers.shape[1]
        C = seq.shape[1]
        h, w = window_size[0], window_size[1]
        half_h, half_w = h // 2, w // 2
        partial_views = np.zeros(shape=(T, A, C, h, w), dtype=np.uint8)
        for t in range(T):
            for a in range(A):
                center_i, center_j = centers[t, a, 0], centers[t, a, 1]
                # Slice straight from the frame; no per-view copy of the sequence is needed.
                partial_views[t, a] = seq[t, :,
                                          center_i - half_h:center_i + half_h + 1,
                                          center_j - half_w:center_j + half_w + 1]
        return partial_views

    def _process_to_target_state(self, state):
        """Return the targets for ``state`` (T, A, C, h, w): ``state[1:]``, or ``state`` itself in no-target-shift mode."""
        if self._no_target_shift:
            return state
        else:
            return state[1:, ...]

    def __len__(self):
        """Number of stored trajectories."""
        return self.limit

    def get(self, idx=None):
        """
        Generates a sequence sample.

        Parameters
        ----------
        idx : int
            Global index for specifying the temporal sequence to be sampled. Required.

        Notes
        -----
        In the following, T is the sequence length, A is the number of views, h and w are height and width
        of the window (i.e. h, w = self.window_size). By a "TA(C)hw tensor" or "TAChw tensor", we mean a tensor
        with shape (T, A, C, h, w), where C can be a numeral. Inside the `no_crop` context manager, h and w are
        the full image height and width instead.

        Returns
        -------
        actions : numpy.ndarray
            Dummy zero-width array of shape (T-1), A, 0.
        positions : numpy.ndarray
            An array of integer (i,j)-positions of shape T,A,2 (converted to float32 by `__getitem__`).
        goals : numpy.ndarray
            Dummy zero-width array of shape A, 0.
        states : numpy.ndarray
            An array of states of shape T,A,1,h,w.
        target_states : numpy.ndarray
            An array of states of shape (T-1), A, 1, h, w.
            The T-1 is due to the fact that the target states are one step to the future. Use the context manager
            `no_target_shift` to temporarily deactivate this time shift.

        Raises
        ------
        ValueError
            If `idx` is None.
        """
        if idx is None:
            # Indexing with None would silently add an axis instead of selecting a trajectory.
            raise ValueError("An index into the stored trajectories is required.")
        sequence = self.trajectories[idx]
        # noinspection PyTypeChecker
        positions = self._generate_centers(seq=sequence, num_of_views=self.num_of_views,
                                           window_size=self.window_size)
        if self._no_crop:
            # Global View: To make the states sizes consistent, repeat the sequence dims
            # acc to self.num_of_views for A.
            states = np.expand_dims(sequence, 1).repeat(self.num_of_views, axis=1)
        else:
            # noinspection PyTypeChecker
            states = self._crop_partial_views(seq=sequence, centers=positions, window_size=self.window_size)
        target_states = self._process_to_target_state(states)
        # Make dummy variables to make it work with the trainer without much effort
        num_steps, num_views = positions.shape[0], positions.shape[1]
        actions = np.zeros(shape=(max(num_steps - 1, 0), num_views, 0))
        goals = np.zeros(shape=(num_views, 0))
        return actions, positions, goals, states, target_states

    def __getitem__(self, item):
        """Return the sample for trajectory ``item`` as float32 torch tensors (see `get`)."""
        # noinspection PyTupleAssignmentBalance
        actions, positions, goals, states, target_states = self.get(idx=item)
        # Convert to torch tensors
        return (torch.from_numpy(actions).float(),
                torch.from_numpy(positions).float(),
                torch.from_numpy(goals).float(),
                torch.from_numpy(states).float(),
                torch.from_numpy(target_states).float())


def bb_trajectories_data_loader(*dataset_args, num_workers=0, batch_size, **dataset_kwargs):
    """
    Construct a ``BBTrajectories`` dataset and wrap it in a shuffling ``DataLoader``.

    Positional and extra keyword arguments are forwarded unchanged to ``BBTrajectories``.
    ``batch_size`` is keyword-only and required; ``num_workers`` defaults to 0 (load in the
    main process). Batches are drawn in shuffled order with ``pin_memory=True``.
    """
    return DataLoader(
        BBTrajectories(*dataset_args, **dataset_kwargs),
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=True,
    )


if __name__ == '__main__':
    # Smoke test: load a trajectory file and print the shapes of one sample.
    # Usage: python <this file> [path/to/trajectories.npz]
    # Without an argument, $TRAJ_PATH is used, then a hard-coded development path.
    import sys

    default_path = '/Users/redacted/Python/mawm/data/bb_first_seq_train.npz'
    path = sys.argv[1] if len(sys.argv) > 1 else os.environ.get('TRAJ_PATH', default_path)
    dataset = BBTrajectories(window_size=(5, 5), sequence_length=100, num_of_views=13,
                             path=path, temporal_sample_spacing=1)

    actions, positions, goals, states, target_states = dataset[0]
    print("Actions: ", actions.shape)
    print("Positions: ", positions.shape)
    print("Goals: ", goals.shape)
    print("States: ", states.shape)
    print("Target States: ", target_states.shape)



