import os
from addict import Dict
from contextlib import contextmanager
import h5py as h5

import numpy as np
from scipy.ndimage.interpolation import map_coordinates

import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, ConcatDataset
from ..rail_env import RailEnvTrajectories
from ...utils import round_to_bins, softmax_logit_map
from .conventions import UnitTypes, ActionTypes


class SC2Trajectories(RailEnvTrajectories):
    NUM_ACTIONS = len(ActionTypes())
    NUM_UNIT_TYPES = len(UnitTypes())

    def __init__(self, view_radius, num_radial_bins, num_angular_bins, sequence_length, path=None,
                 temporal_sample_spacing=1, add_spatial_markers=True, friendlies_only=True, mark_graves=False,
                 split=None):
        Dataset.__init__(self)
        if path is None:
            path = os.environ.get('TRAJ_PATH')
        self.path = path
        self.view_radius = view_radius
        self.num_radial_bins = num_radial_bins
        self.num_angular_bins = num_angular_bins
        self.add_spatial_markers = add_spatial_markers
        self.friendlies_only = friendlies_only
        self.sequence_length = sequence_length
        self.temporal_sample_spacing = temporal_sample_spacing
        self.mark_graves = mark_graves
        self.split = split
        # Privates
        self._no_crop = False
        self._no_target_shift = False
        self._cache = {}
        # Load trajectories
        self._load_trajectories()

    def _load_trajectories(self):
        """
        Read every dataset of the HDF5 file at `self.path` into `self.trajectories`.

        Datasets whose leading dimension equals the number of trajectories (taken from `action_types`)
        are restricted to the slice returned by `self._get_split_slice`; all other datasets are read in full.
        """
        self.trajectories = Dict()
        with h5.File(self.path, 'r') as h5_file:
            # The leading axis of `action_types` defines how many trajectories the file holds.
            total = h5_file['action_types'].shape[0]
            selection = self._get_split_slice(total)
            for name in h5_file.keys():
                entry = h5_file[name]
                if not entry.shape or entry.shape[0] != total:
                    # Scalars and file-wide arrays are shared by all trajectories: keep them whole.
                    values = entry
                else:
                    # Per-trajectory arrays: keep only the requested split.
                    values = entry[selection]
                self.trajectories[name] = np.asarray(values)
        # Check consistency
        assert (self.num_friendlies + self.num_enemies) == self.num_agents

    @property
    def env_shape(self):
        return self.trajectories.map_size

    @property
    def num_friendlies(self):
        return self.trajectories.num_friendlies

    @property
    def num_enemies(self):
        return self.trajectories.num_enemies

    @property
    def num_agents(self):
        # unit_types should be na
        return self.trajectories.unit_types.shape[1]

    @property
    def num_trajectories(self):
        return self.trajectories.unit_types.shape[0]

    @property
    def num_steps(self):
        # action_types should be nta
        return self.trajectories.action_types.shape[1]

    @property
    def friendly_marker(self):
        # `friendly_marker` is the same for all trajectories.
        if 'friendly_marker' in self._cache:
            return self._cache['friendly_marker']
        else:
            friendly_marker = np.zeros(shape=(self.sequence_length, self.num_agents), dtype='float32')
            friendly_marker[:, :self.num_friendlies] = 1.
            self._cache['friendly_marker'] = friendly_marker
            return friendly_marker

    @property
    def spatial_markers(self):
        if 'spatial_marker' in self._cache:
            return self._cache['spatial_marker']
        else:
            rr, tt = np.meshgrid(np.linspace(-np.pi, np.pi, self.num_angular_bins, dtype='float32'),
                                 np.linspace(0, self.view_radius, self.num_radial_bins, dtype='float32'))
            spatial_marker = np.tile(np.array([rr, tt])[None, None], (self.sequence_length, self.num_agents, 1, 1, 1))
            self._cache['spatial_marker'] = spatial_marker
            return spatial_marker

    def _fetch_height_map(self, traj_idx, start_t):
        # height_map should be a TA(1)rΘ tensor, yielding the height at p + (r cosΘ, r sinΘ) where p is the position
        # at time T of agent A, as obtained from the TA2 tensor yielded by self._fetch_positions.
        # spatial_markers.shape = (2)rΘ, where the channels are r and Θ meshgrids. Chunk to two tensors of shape rΘ:
        rr, tt = list(self.spatial_markers[0, 0])
        # Convert rr, tt to delta in xx, yy from the true x, y position, i.e. dxx, dyy
        dxx = rr * np.cos(tt)
        dyy = rr * np.sin(tt)
        # Add deltas to the true positions.
        # positions.shape = TA2, which we split to two TA tensors
        positions = self._fetch_positions(traj_idx, start_t)
        x, y = positions[..., 0], positions[..., 1]
        # Now, make the meshgrid we need by adding the TA tensors to rΘ tensors to obtain TArΘ tensors
        # and fetch from the height map with map_coordinates
        xx, yy = x[:, :, None, None] + dxx[None, None, :, :], y[:, :, None, None] + dyy[None, None, :, :]
        # heights.shape = TArΘ
        heights = map_coordinates(self.trajectories.terrain_height, [xx, yy], order=1)
        # reshape to TA(1)rΘ and return
        return heights[:, :, None]

    def _fetch_onehot_unit_types(self, traj_idx):
        # Fetch the unit-type vector (of shape A)
        unit_type = self.trajectories.unit_types[traj_idx]
        # Convert to one-hot (shape = AC), and expand to TAC
        unit_type = (np.equal(np.arange(self.NUM_UNIT_TYPES)[None, :], unit_type[:, None])[None, :, :]
                     .repeat(self.sequence_length, axis=0))
        return unit_type.astype('float32')

    def _fetch_local_state_and_actions(self, traj_idx, start_t):
        """
        Build the agent-centric polar "images" of states and actions for one sequence.

        For every time-step t in [start_t, start_t + sequence_length) and every observing agent a1, the
        features (resp. actions) of every agent a2 are written to the polar bin (r, Θ) that a2 occupies
        as seen from a1. Agents beyond `view_radius` land in an extra radial buffer bin that is cropped
        before returning.

        Parameters
        ----------
        traj_idx : int
            Index of the trajectory to read from.
        start_t : int
            First time-step of the sequence.

        Returns
        -------
        local_states : numpy.ndarray
            TACrΘ array with C = 5 + NUM_UNIT_TYPES + 1 channels (health, energy, cooldown, shields,
            friendly marker, one-hot unit type, terrain height), plus 2 if `add_spatial_markers` is set.
        local_actions : numpy.ndarray
            TACrΘ array with C = NUM_ACTIONS + 1 channels (one-hot action, target marker), plus 2 if
            `add_spatial_markers` is set.
        """
        # In what follows, the spatial coordinates are polar, comprising the radius r and angle Θ.
        # Local state is a TACrΘ tensor with C = 15 or 17 (if adding spatial markers), where the channels are
        # (the counts below assume NUM_UNIT_TYPES == 9 -- TODO confirm against conventions.UnitTypes):
        #   - (1) health
        #   - (1) energy
        #   - (1) cooldown
        #   - (1) shields
        #   - (1) friendly marker (marks whether the unit is friendly)
        #   - (9) unit_type marker (one-hot vector, marks the type of unit)
        #   - (1) terrain heights
        # Optionally:
        #   - (2) r and Θ markers.
        # In addition, the local_action is another TACrΘ tensor, where the channels are
        # (assuming NUM_ACTIONS == 9 -- TODO confirm against conventions.ActionTypes):
        #   - (9) action marker (one-hot vector, marks the action)
        #   - (1) target marker (marks the target of the current attack)
        # Optionally:
        #   - (2) r and Θ markers
        # We compute both states and actions in this function to avoid having to compute radii and angle
        # (and binning them) twice while still remaining stateless.
        # ---------------
        # Step 1: Get the polar indices (from position)
        # positions.shape = TA2. For each agent, compute the displacement to all other agents
        # (U friendlies and V enemies, where A = U + V)
        # NOTE(review): `_friendly_slice` is not used anywhere below.
        _friendly_slice = slice(0, self.num_friendlies)
        _seq_range = list(range(start_t, start_t + self.sequence_length))
        # vectors.shape = TAA2 where vectors[t, a1, a2] gives the vector going from a1 to a2 at time t
        vectors = (self.trajectories.positions[traj_idx, _seq_range, None, :, :] -
                   self.trajectories.positions[traj_idx, _seq_range, :, None, :])
        # radii.shape = TAA, where radii[t, a1, a2] is the distance between agents a1 and a2
        radii = np.linalg.norm(vectors, axis=-1)
        # angles.shape = TAA, where angles[t, a1, a2] is the angle to a2 as observed from a1.
        angles = np.arctan2(vectors[:, :, :, 1], vectors[:, :, :, 0])
        # Bin radii and angles (i.e. convert them to integer indices via "generalized rounding").
        # But note that we have an extra "buffer-bin" for radii; these store info about units outside the FOV,
        # but they'll be cropped before leaving this function.
        # (`round_to_bins` is a project helper; it is assumed to return integer bin indices -- TODO confirm.)
        binned_radii = round_to_bins(radii, 0, self.view_radius, self.num_radial_bins, additional_buffer_bin=True)
        binned_angles = round_to_bins(angles, -np.pi, np.pi, self.num_angular_bins, additional_buffer_bin=False)
        # ---------------
        # Step 2: Gather the features by stacking them along the last axis, obtaining a TAC tensor
        features = np.stack([self.trajectories.healths[traj_idx, _seq_range],
                             self.trajectories.energies[traj_idx, _seq_range],
                             self.trajectories.cooldowns[traj_idx, _seq_range],
                             self.trajectories.shields[traj_idx, _seq_range],
                             self.friendly_marker], axis=-1)
        features = np.concatenate([features, self._fetch_onehot_unit_types(traj_idx)], axis=-1)
        num_features = features.shape[-1]
        # Reshape to a T1CA tensor: the trailing A axis then lines up with the *observed* agent (a2) axis of
        # the broadcasted index arrays below, while the singleton axis broadcasts over the observers (a1).
        features = np.moveaxis(features, 1, 2)[:, None, :, :]
        # ---------------
        # Step 3: Write it out to the state tensor.
        # For now, local_states gets an extra radial bin, which stores info about agents outside the view_radius.
        local_states = np.zeros(shape=(self.sequence_length, self.num_agents,
                                       num_features, self.num_radial_bins + 1,
                                       self.num_angular_bins),
                                dtype='float32')
        # Scatter via broadcasted fancy indexing: the five index arrays broadcast to TACA, so that
        # local_states[t, a1, c, r(t, a1, a2), Θ(t, a1, a2)] = features[t, 0, c, a2].
        local_states[np.arange(self.sequence_length)[:, None, None, None],
                     np.arange(self.num_agents)[None, :, None, None],
                     np.arange(num_features)[None, None, :, None],
                     binned_radii[:, :, None, :], binned_angles[:, :, None, :]] = features
        # ---------------
        # Step 4: prepare the actions.
        # Actions are images like states, but channels comprising one-hot tensors of actions (of all agents in FOV),
        # with an additional channel specifying the target.
        num_action_features = self.NUM_ACTIONS + 1
        local_actions = np.zeros(shape=(self.sequence_length, self.num_agents,
                                        num_action_features, self.num_radial_bins + 1,
                                        self.num_angular_bins),
                                 dtype='float32')
        # Get the one-hot tensors of shape TAC and TAA
        actions_onehot, targets_onehot = self._fetch_actions(traj_idx, start_t)
        # Permute actions_onehot to a T1CA tensor (like features)
        actions_onehot = np.moveaxis(actions_onehot, 1, 2)[:, None, :, :]
        # Reshape targets_onehot to a TA1A tensor
        targets_onehot = targets_onehot[:, :, None, :]
        # Write out the actions (all channels but the last)
        local_actions[np.arange(self.sequence_length)[:, None, None, None],
                      np.arange(self.num_agents)[None, :, None, None],
                      np.arange(num_action_features - 1)[None, None, :, None],
                      binned_radii[:, :, None, :], binned_angles[:, :, None, :]] = actions_onehot
        # Write out the target (last channel only)
        local_actions[np.arange(self.sequence_length)[:, None, None, None],
                      np.arange(self.num_agents)[None, :, None, None],
                      np.arange(num_action_features - 1, num_action_features)[None, None, :, None],
                      binned_radii[:, :, None, :], binned_angles[:, :, None, :]] = targets_onehot
        # ---------------
        # And now we crop both tensors to TACrΘ by dropping the radial buffer bin
        local_states = local_states[:, :, :, :-1, :]
        local_actions = local_actions[:, :, :, :-1, :]
        if not self.mark_graves:
            # Find out who's still alive. These are units whose health (channel 0) is greater than 0;
            # every channel of a bin without a living unit is zeroed in both states and actions.
            alive_map = np.greater(local_states[:, :, 0:1, :, :], 0.).astype(local_states.dtype)
            local_states = local_states * alive_map
            local_actions = local_actions * alive_map
        # Make a list of extra features, and concatenate only once (to avoid a copy)
        extra_state_features = [self._fetch_height_map(traj_idx, start_t)]
        extra_action_features = []
        # Add spatial markers if required
        if self.add_spatial_markers:
            extra_state_features.append(self.spatial_markers)
            extra_action_features.append(self.spatial_markers)
        # Append the extra channels along the channel axis, if there are any
        if extra_state_features:
            local_states = np.concatenate([local_states] + extra_state_features, axis=2)
        if extra_action_features:
            local_actions = np.concatenate([local_actions] + extra_action_features, axis=2)
        return local_states, local_actions

    def _fetch_actions(self, traj_idx, start_t):
        # We need not only the actions, but also the target tokens.
        # action_types.shape = TA
        action_types = self.trajectories.action_types[traj_idx, start_t:(start_t + self.sequence_length)]
        # Actions are of shape TA, so we expand to a one-hot TAC tensor.
        action_types = np.equal(np.arange(self.NUM_ACTIONS)[None, None, :],
                                action_types[:, :, None]).astype('float32')
        # action_target_ids.shape = TA
        action_target_ids = self.trajectories.action_target_ids[traj_idx, start_t:(start_t + self.sequence_length)]
        # Expand to a pseudo-one-hot tensor, TAA. Pseudo because the sum along the last dimension can sum to 1 or 0.
        action_target_ids = np.equal(np.arange(self.num_agents)[None, None, :],
                                     action_target_ids[:, :, None]).astype('float32')
        return action_types, action_target_ids

    def _process_to_target_state(self, state):
        # See super-class for more detailed comments. The only difference here is that there is no notion of
        # "target channels", since all channels must be predicted by the model.
        if self._no_target_shift:
            return state
        else:
            return state[1:]

    def _fetch_outcome(self, traj_idx):
        # outcome is a binary variable specifying whether the battle was won
        return self.trajectories['battle_wons'][traj_idx, None].astype('float32')

    def get(self, idx: int = None, traj_idx: int = None, start_t: int = None):
        """
        Generates a sequence sample.

        Parameters
        ----------
        idx : int
            Global (ravelled) index. Accepted only when `traj_idx` and `start_t` are None.
        traj_idx : int
            Trajectory index. Required (together with `start_t`) if `idx` is None.
        start_t : int
            Start time-step. Required (together with `traj_idx`) if `idx` is None.

        Notes
        -----
        In the following, T is the sequence length, A is the number of agents, r and Θ are the number of radial and
        angular bins respectively. By a "TA(C)rΘ tensor" or "TACrΘ tensor", we mean a tensor
        with shape (T, A, C, r, Θ), where C can be a numeral. If `self.friendlies_only` is set, the agent axis
        of all returned arrays (except `outcomes`) is cropped to the first `num_friendlies` agents.

        Returns
        -------
        actions : numpy.ndarray
            An array of actions of shape TA(NUM_ACTIONS + 1)rΘ, with two more channels if spatial markers
            are requested. The channels are comprised of:
                NUM_ACTIONS channels specifying the one-hot action of the unit at (r, Θ)
                01 channel marking the target of an (attack) action.
            In addition, the 02 polar meshgrid channels of `self.spatial_markers` are appended if
            `self.add_spatial_markers` is set to True.
        positions : numpy.ndarray
            The TA2 array of positions as returned by `self._fetch_positions`.
        outcomes: numpy.ndarray
            A binary (but of dtype float32) denoting if the battle is (eventually) won.
        states : numpy.ndarray
            An array of states of shape TA(NUM_UNIT_TYPES + 6)rΘ, with two more channels if spatial markers
            are requested. The channels are comprised of:
                01 channel specifying the health of all units in FOV (friendlies and enemies)
                01 channel specifying the energy of all units in FOV (friendlies and enemies)
                01 channel specifying the cooldown of all units in FOV (friendlies and enemies).
                01 channel specifying the shields of all units in FOV (friendlies and enemies).
                01 channel specifying whether the unit at position (r, Θ) is friendly.
                NUM_UNIT_TYPES channels specifying the unit-type (one-hot).
                01 channel specifying the height of the terrain.
            In addition, the 02 polar meshgrid channels of `self.spatial_markers` are appended if
            `self.add_spatial_markers` is set to True.
        target_states : numpy.ndarray
            An array with the same channels as `states`, but of length T-1 along the time axis, because the
            target states are one step to the future. Setting `self._no_target_shift` deactivates this
            time shift, in which case the length is T.
        """
        if idx is None:
            # Explicit addressing: both coordinates must be provided.
            assert traj_idx is not None
            assert start_t is not None
        else:
            # Ravelled addressing: the explicit coordinates must not be provided as well.
            assert traj_idx is None
            assert start_t is None
            traj_idx, start_t = self._parse_index(idx)
        states, actions = self._fetch_local_state_and_actions(traj_idx, start_t)
        outcomes = self._fetch_outcome(traj_idx)
        positions = self._fetch_positions(traj_idx, start_t)
        # Targets are derived from the full (uncropped) states.
        target_states = self._process_to_target_state(states)
        if not self.friendlies_only:
            return actions, positions, outcomes, states, target_states
        # Friendlies occupy the leading entries of the agent axis; drop everyone else.
        friendlies = slice(0, self.num_friendlies)
        return (actions[:, friendlies],
                positions[:, friendlies],
                outcomes,
                states[:, friendlies],
                target_states[:, friendlies])

    def __getitem__(self, item):
        actions, positions, outcomes, states, target_states = self.get(idx=item)
        return (torch.from_numpy(actions).float(),
                torch.from_numpy(positions).float(),
                torch.from_numpy(outcomes).float(),
                torch.from_numpy(states).float(),
                torch.from_numpy(target_states).float())

    @classmethod
    def split_state_tensor(cls, state: torch.Tensor, with_null_unit_channel=False):
        """
        Splits state tensor (NTACrΘ) into constituent components along the channel axis.

        Parameters
        ----------
        state : torch.Tensor
            The state tensor. Can be NTACrΘ or TACrΘ; a 5-dimensional input gets a leading batch axis, so
            the output components are always 6-dimensional.

        with_null_unit_channel : bool
            Whether the tensor carries one extra channel right after the unit-type channels (used by
            reconstructions as a "null-unit" absorber in the softmax). If set, that channel is skipped
            when locating the subsequent components.

        Returns
        -------
        components : Dict
            A dict of channel slices with keys (in channel order):
                - hecs: 4 channels (health, energy, cooldown and shields)
                - friendly_marker: 1 channel
                - unit_types: NUM_UNIT_TYPES channels
                - terrain: 1 channel
                - spatial_markers: 2 channels (the polar meshgrid)
        """
        if state.dim() == 5:
            # add in batch-dimension if not available.
            state = state[None]
        num_types = cls.NUM_UNIT_TYPES
        # FIXME (kept from the original): with `with_null_unit_channel`, the extra channel is consumed but
        # not included in `unit_types` (nor in any other component).
        unit_type_span = num_types + 1 if with_null_unit_channel else num_types
        # Each entry: (key, number of channels kept, number of channels consumed).
        layout = (
            ('hecs', 4, 4),
            ('friendly_marker', 1, 1),
            ('unit_types', num_types, unit_type_span),
            ('terrain', 1, 1),
            ('spatial_markers', 2, 2),
        )
        components = Dict()
        start = 0
        for key, kept, consumed in layout:
            components[key] = state[:, :, :, start:(start + kept)]
            start += consumed
        return components

    @classmethod
    def consolidate_components(cls, components: Dict):
        """
        Merge components to a single tensor by concatenating along the channel axis. The order of concatenation is:
            - hecs (health, energy, cooldown, and shield)
            - friendly_marker
            - unit types
            - terrain
            - spatial_markers
        `components` must be a dict with the above arguments.
        Parameters
        ----------
        components : Dict
            A dict with the above mentioned keys.

        Returns
        -------
        torch.Tensor
        """
        state = torch.cat([components.hec,
                           components.friendly_marker,
                           components.unit_types,
                           components.terrain,
                           components.spatial_markers], dim=3)
        return state

    @classmethod
    def recon_normalizer(cls, recon_or_components, return_components=True):
        """
        Normalize the components of a reconstructed state. The `friendly_marker` component is passed
        through a sigmoid and the `unit_types` component through `softmax_logit_map(..., onehot=False)`
        along the channel axis (presumably a softmax -- TODO confirm). All other components are left
        untouched.

        Parameters
        ----------
        recon_or_components : Dict or torch.Tensor
            State tensor (which is split with `split_state_tensor`) or the components in a Dict
            (which is shallow-copied, so the input dict is not modified).

        return_components : bool
            Whether to return the components or to consolidate them to a single tensor.

        Returns
        -------
        Dict or torch.Tensor
            The normalized components if `return_components` is set, the consolidated tensor otherwise.
        """
        if not isinstance(recon_or_components, Dict):
            parts = cls.split_state_tensor(recon_or_components)
        else:
            parts = Dict(recon_or_components)
        marker_logits = parts.friendly_marker
        parts.friendly_marker = marker_logits.sigmoid()
        type_logits = parts.unit_types
        parts.unit_types = softmax_logit_map(type_logits, onehot=False, dim=3)
        if return_components:
            return parts
        return cls.consolidate_components(parts)

    @classmethod
    def recon_to_state(cls, recon_or_components, return_components=True):
        """
        Convert reconstructions yielded by a model to states that can be fed back to the said model.
        The logit map of `friendly_marker` is binarized (1 where the logit is greater than 0, else 0; the
        dtype is kept) and `unit_types` is passed through `softmax_logit_map(..., onehot=True)` along the
        channel axis (presumably yielding a one-hot map -- TODO confirm).

        Parameters
        ----------
        recon_or_components : Dict or torch.Tensor
            Reconstruction tensor (which is split with `split_state_tensor`) or the components in a Dict
            (which is shallow-copied, so the input dict is not modified).
        return_components : bool
            Whether to return the components or to consolidate them to a single tensor.

        Returns
        -------
        Dict or torch.Tensor
            The converted components if `return_components` is set, the consolidated tensor otherwise.
        """
        if not isinstance(recon_or_components, Dict):
            parts = cls.split_state_tensor(recon_or_components)
        else:
            parts = Dict(recon_or_components)
        # Threshold on a clone, so that the tensor handed in by the caller is not modified in-place.
        marker_logits = parts.friendly_marker
        parts.friendly_marker = marker_logits.clone().gt_(0.)
        type_logits = parts.unit_types
        parts.unit_types = softmax_logit_map(type_logits, onehot=True, dim=3)
        if not return_components:
            return cls.consolidate_components(parts)
        return parts


def sc2_env_trajectories_data_loader(*dataset_args, num_workers=0, batch_size, **dataset_kwargs):
    """
    Build a shuffling, memory-pinning `DataLoader` over a freshly constructed `SC2Trajectories` dataset.

    `dataset_args` and `dataset_kwargs` are forwarded to the dataset constructor; `batch_size`
    (keyword-only, required) and `num_workers` are forwarded to the loader.
    """
    return DataLoader(SC2Trajectories(*dataset_args, **dataset_kwargs),
                      shuffle=True, pin_memory=True,
                      batch_size=batch_size, num_workers=num_workers)


def get_concat_loader(*datasets, num_workers=0, batch_size):
    """
    Build a shuffling, memory-pinning `DataLoader` over the concatenation of `datasets`.

    `batch_size` is keyword-only and required; `num_workers` is forwarded to the loader.
    """
    return DataLoader(ConcatDataset(datasets),
                      shuffle=True, pin_memory=True,
                      batch_size=batch_size, num_workers=num_workers)


if __name__ == '__main__':
    # Smoke test: load one sample from a local trajectory file and print the position range
    # followed by the shape of every returned tensor.
    dataset = SC2Trajectories(view_radius=7, num_radial_bins=14, num_angular_bins=36, sequence_length=50,
                              path='/Users/redacted/Python/mawm/data/sc2_1c3s5z_greed-0.3_1000x_trajs.h5', mark_graves=False,
                              friendlies_only=False)
    actions, positions, outcomes, states, target_states = dataset[0]
    print(positions.max())
    print(positions.min())
    for tensor in (actions, positions, outcomes, states, target_states):
        print(tensor.shape)


