import gym
import numpy as np
from .base import MorphologyEnv

class DictObsWrapper(gym.Wrapper):
    '''
    Expose only selected keys of the wrapped env's dict observation.

    This is useful, for example, for training on only the edge features.

    With a single key, that key's observation (and sub-space of
    ``observation_space``) is passed through unchanged. With several keys,
    each selected array is flattened and the results are concatenated into one
    1-D vector; ``observation_space`` becomes a ``gym.spaces.Box`` whose bounds
    are the flattened, concatenated bounds of the selected sub-spaces.

    Args:
        env: environment whose ``observation_space`` is indexable by key and
            whose observations are dicts keyed the same way.
        keys: a single key (str) or a sequence of keys to keep.

    Raises:
        ValueError: if ``keys`` is None or empty.
    '''
    def __init__(self, env, keys=None):
        if keys is None:
            raise ValueError("Must provide a key to wrapper")
        # A bare string would otherwise be iterated character by character
        # (e.g. 'edge_attr' -> 'e', 'd', ...), so treat it as one key.
        keys = [keys] if isinstance(keys, str) else list(keys)
        if not keys:
            raise ValueError("Must provide a key to wrapper")
        super().__init__(env)
        self.keys = keys
        if len(self.keys) == 1:
            self.observation_space = self.env.observation_space[self.keys[0]]
        else:
            spaces = [self.env.observation_space[key] for key in self.keys]
            low = np.concatenate([space.low.flatten() for space in spaces], axis=0)
            high = np.concatenate([space.high.flatten() for space in spaces], axis=0)
            self.observation_space = gym.spaces.Box(low=low, high=high)

    def _wrap_obs(self, obs):
        '''Select (and, for several keys, flatten and concatenate) the configured keys of ``obs``.'''
        if len(self.keys) == 1:
            return obs[self.keys[0]]
        return np.concatenate([obs[key].flatten() for key in self.keys], axis=0)

    def step(self, action):
        obs, reward, done, info = self.env.step(action)
        return self._wrap_obs(obs), reward, done, info

    def reset(self, **kwargs):
        obs = self.env.reset(**kwargs)
        return self._wrap_obs(obs)

class NodeWrapper(gym.Wrapper):
    '''
    Present the wrapped env's graph observation as a node-centric graph.

    Each node's feature row is its own ``x`` row, followed by the
    ``edge_attr`` row at index ``node - 1`` (node 0 gets an all-zero row),
    followed by the global features ``u`` tiled once per node when present.
    The returned ``edge_index`` is the original one stacked on top of a copy
    rolled by one column.

    The wrapped env is switched to ``pad_actions = True``; ``step`` drops the
    first entry of the incoming action before forwarding it.
    '''
    def __init__(self, env):
        super().__init__(env)
        if not env.pad_actions:
            self.env.pad_actions = True
            self.env.set_action_space()
        self.action_space = self.env.action_space
        # Hand a wrapped sample observation to the env (presumably so it can
        # derive the observation space from its shapes), then mirror it here.
        sample = self._wrap_obs(self.env._get_obs())
        self.env.set_observation_space(sample)
        self.observation_space = self.env.observation_space

    @staticmethod
    def _wrap_obs(obs):
        '''
        Convert a raw graph observation into ``dict(x=..., edge_index=...)``.

        Reads ``obs['x']``, ``obs['edge_index']``, ``obs['edge_attr']`` and,
        if present and not None, ``obs['u']`` (which must be 2-D).
        '''
        forward = obs['edge_index']
        # For a two-column edge list, rolling by one column swaps each
        # (a, b) pair into (b, a), adding the reverse direction.
        edge_index = np.concatenate((forward, np.roll(forward, 1, axis=1)), axis=0)
        # Edges features: num_edges x F. A leading zero row aligns row i with
        # node i, so node 0 has no edge_attr of its own.
        feat_width = obs['edge_attr'].shape[1]
        edge_rows = np.concatenate((np.zeros((1, feat_width)), obs['edge_attr']), axis=0)
        columns = [obs['x'], edge_rows]
        if 'u' in obs and obs['u'] is not None:
            assert len(obs['u'].shape) == 2, "Global features must include batch dim"
            columns.append(np.tile(obs['u'], (len(obs['x']), 1)))
        return dict(x=np.concatenate(columns, axis=1), edge_index=edge_index)

    @staticmethod
    def get_morphology_graph(morphology, include_segments=False):
        '''Build the base env's morphology graph and convert it with ``_wrap_obs``.'''
        raw_graph = MorphologyEnv.get_morphology_graph(morphology, include_segments=include_segments)
        return NodeWrapper._wrap_obs(raw_graph)

    def get_morphology_obs(self, include_segments):
        '''Return the wrapped env's morphology observation converted with ``_wrap_obs``.'''
        raw_graph = self.env.get_morphology_obs(include_segments=include_segments)
        return self._wrap_obs(raw_graph)

    def step(self, action):
        # The leading action entry is discarded; presumably it is the padding
        # slot for node 0 introduced by pad_actions — confirm against the env.
        trimmed = action[1:]
        obs, reward, done, info = self.env.step(trimmed)
        return self._wrap_obs(obs), reward, done, info

    def reset(self, **kwargs):
        return self._wrap_obs(self.env.reset(**kwargs))

class NodeMorphologyWrapper(gym.Wrapper):
    '''
    Node-centric graph observation with per-node morphology features appended.

    Every observation is converted with the same transform as
    ``NodeWrapper._wrap_obs``. The ``x`` of
    ``get_morphology_obs(include_segments=False)``, converted the same way, is
    then concatenated column-wise onto the node features. The morphology
    observation is re-queried from the env on every call.

    The wrapped env is switched to ``pad_actions = True``; ``step`` drops the
    first entry of the incoming action before forwarding it.
    '''
    def __init__(self, env):
        super().__init__(env)
        if not env.pad_actions:
            self.env.pad_actions = True
            self.env.set_action_space()
        self.action_space = self.env.action_space
        obs = self._morphology_wrap_obs(self.env._get_obs())
        self.env.set_observation_space(obs)
        self.observation_space = self.env.observation_space

    @staticmethod
    def _wrap_obs(obs):
        '''Convert a raw graph observation into ``dict(x=..., edge_index=...)``.'''
        # Identical transform to NodeWrapper; delegate rather than keep a
        # duplicated copy that could silently drift out of sync.
        return NodeWrapper._wrap_obs(obs)

    def _morphology_wrap_obs(self, obs):
        '''Wrap ``obs`` and append the wrapped morphology node features to ``x``.'''
        obs = self._wrap_obs(obs)
        morphology_obs = self.get_morphology_obs(include_segments=False)
        # Row counts must agree: one row per node in both observations.
        obs['x'] = np.concatenate((obs['x'], morphology_obs['x']), axis=1)
        return obs

    def get_morphology_obs(self, include_segments):
        '''Return the wrapped env's morphology observation converted with ``_wrap_obs``.'''
        return self._wrap_obs(self.env.get_morphology_obs(include_segments=include_segments))

    def step(self, action):
        # Leading entry dropped, as in NodeWrapper.step.
        obs, reward, done, info = self.env.step(action[1:])
        obs = self._morphology_wrap_obs(obs)
        return obs, reward, done, info

    def reset(self, **kwargs):
        obs = self._morphology_wrap_obs(self.env.reset(**kwargs))
        return obs

class LineGraphWrapper(gym.Wrapper):
    '''
    Present the wrapped env's graph observation as its line graph.

    Each row of ``edge_index`` (parent, child) becomes one node. Its features
    are the parent's ``x``, the child's ``x`` and ``edge_attr[child - 1]``,
    plus ``u`` tiled per node when present. Two such nodes are linked in both
    directions when their original edges share an endpoint. That connectivity
    (``self.line_edges``) is computed once in ``__init__`` and returned
    unchanged with every observation.

    NOTE(review): ``line_edges`` is not recomputed on ``reset``; confirm the
    morphology cannot change between episodes.
    NOTE(review): ``line_edges`` indexes nodes by ``child - 1``, while ``x``
    rows follow ``edge_index`` row order. These agree only if row k has
    child k + 1 — presumably guaranteed by the env; confirm.
    '''
    def __init__(self, env):
        super().__init__(env)
        if not env.pad_actions:
            self.env.pad_actions = True
            self.env.set_action_space()
        self.action_space = self.env.action_space
        first_obs = self.env._get_obs()
        # Connectivity must exist before _wrap_obs, which reads self.line_edges.
        self._set_edges(first_obs)
        self.env.set_observation_space(self._wrap_obs(first_obs))
        self.observation_space = self.env.observation_space

    def _wrap_obs(self, obs):
        '''Build line-graph node features from ``obs`` and pair them with ``self.line_edges``.'''
        pairs = obs['edge_index']
        parent_feats = obs['x'][pairs[:, 0]]
        child_feats = obs['x'][pairs[:, 1]]
        joint_feats = obs['edge_attr'][pairs[:, 1] - 1]
        x = np.concatenate((parent_feats, child_feats, joint_feats), axis=1)
        if 'u' in obs and obs['u'] is not None:
            assert len(obs['u'].shape) == 2, "Global features must include batch dim"
            x = np.concatenate((x, np.tile(obs['u'], (len(x), 1))), axis=1)
        return dict(x=x, edge_index=self.line_edges)

    def step(self, action):
        # NOTE(review): unlike the node wrappers, the full action is forwarded
        # even though pad_actions was forced on — confirm the action layout.
        obs, reward, done, info = self.env.step(action)
        return self._wrap_obs(obs), reward, done, info

    def _set_edges(self, raw_obs):
        '''
        Compute ``self.line_edges`` from ``raw_obs['edge_index']``.

        For every pair of edges (i < j) sharing at least one endpoint, append
        [child_i - 1, child_j - 1] and its reverse.

        Raises:
            ValueError: if no pair of edges shares an endpoint.
        '''
        edges = raw_obs['edge_index']
        endpoint_sets = [set(edge) for edge in edges]
        connections = []
        for i, nodes_i in enumerate(endpoint_sets):
            for j in range(i + 1, len(endpoint_sets)):
                if not nodes_i & endpoint_sets[j]:
                    continue
                # These "body parts" share a connection; link their joints.
                joint_a = edges[i][1] - 1
                joint_b = edges[j][1] - 1
                connections.extend(([joint_a, joint_b], [joint_b, joint_a]))
        if not connections:
            raise ValueError("Got zero line edges.")
        self.line_edges = np.array(connections)

    def reset(self, **kwargs):
        return self._wrap_obs(self.env.reset(**kwargs))
