import os, sys
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.append(ROOT_DIR)
import numpy as np
from sklearn.utils.validation import check_random_state
from collections import deque
from scipy.special import expit

# --- Planning / QD search hyper-parameters ("AS" = candidate action sequence / controller) ---
PLANNING_HORIZON = 10  # Number of simulated steps each candidate controller is rolled out for
INIT_AS = 25  # Population size evaluated first at every planning step
AS_PER_GEN = 5  # Candidates sampled + mutated per subsequent QD generation
MAX_EVALUATED_SEQUENCES = 100  # Evaluation budget (number of candidates) per planning step
N_PARTICLES = 1  # Rollouts per candidate; NOTE(review): values other than 1 look unsupported by the batching - verify
N_ACTIONS = 3  # Size of the discrete action space


class Grid:
    """
    MAP-Elites style archive: a regular grid over the behavior-descriptor (BD) space in which
    every cell holds at most one agent - the one with the highest reward stored for that cell.

    Agents are dicts that must provide at least the keys 'bd', 'id' and 'reward'; any other
    keys (e.g. 'genome') are stored untouched. Agents live in ``self.data`` in insertion order,
    while ``self.grid`` maps each cell index to the id of its occupant (or None when empty).
    """

    def __init__(self, bins, dimensions):
        """
        :param bins: number of bins along every BD dimension
        :param dimensions: per-dimension limits, [[minx, maxx], [miny, maxy], ...]
        """
        self.data = deque()  # Stored agent dicts, in insertion order
        self.bins = bins
        self.dimensions = dimensions

        self.grid = self.init_grid()
        # Bin edges per dimension: bins + 1 evenly spaced values from min to max (inclusive)
        self.cell_lims = [np.linspace(dim[0], dim[1], num=self.bins + 1) for dim in self.dimensions]

    def reset(self):
        """
        Empty the archive: drops all stored agents and clears cell occupancy. Bin edges are kept.
        """
        self.data = deque()
        self.grid = self.init_grid()

    # ---------------------------------
    def __len__(self):
        """
        Number of agents currently stored
        """
        return self.size

    def __iter__(self):
        """
        Iterate over the stored agent dicts, in insertion order.
        """
        return iter(self.data)

    def __getitem__(self, item):
        """
        Index the archive.
        :param item: a key name (str) or a position in the archive
        :return: for a str key, the list ``[np.array(agent[key]) for agent in archive]``;
                 otherwise the agent stored at that position
        :raises KeyError: if a str key is missing from a stored agent
        """
        if isinstance(item, str):
            try:
                return [np.array(agent[item]) for agent in self.data]
            except KeyError:
                # Dict lookups raise KeyError; self.data is non-empty here since a lookup failed
                raise KeyError('Wrong key given. Available: {} - Given: {}'.format(
                    list(self.data[0].keys()), item)) from None
        return self.data[item]

    @property
    def size(self):
        """
        Number of agents currently stored
        """
        return len(self.data)
    # ---------------------------------

    def init_grid(self):
        """
        Create an empty grid: an object array of shape [bins] * n_dims filled with None
        """
        return np.full([self.bins] * len(self.dimensions), fill_value=None)

    def store(self, agent):
        """
        Try to insert an agent dict (keys used here: 'bd', 'id', 'reward').

        If the agent's cell is empty the agent is stored. If it is occupied, the occupant is
        replaced only when the new agent's reward is greater than or equal to the occupant's.

        :param agent: agent dict to store
        :return: True if the agent was stored, False if it was rejected
        """
        assert len(agent['bd']) == len(self.dimensions), \
            'BD of wrong size. Given: {} - Expected: {}'.format(len(agent['bd']), len(self.dimensions))
        cell = self._find_cell(agent['bd'])
        occupant_id = self.grid[cell]

        if occupant_id is None:
            self.grid[cell] = agent['id']
            self.data.append(agent)
            return True

        occupant_idx = self._index_of(occupant_id)
        # Ties favour the newcomer
        if agent['reward'] >= self.data[occupant_idx]['reward']:
            del self.data[occupant_idx]
            self.grid[cell] = agent['id']
            self.data.append(agent)
            return True
        return False

    def _index_of(self, agent_id):
        """
        Position in self.data of the agent whose 'id' equals agent_id.
        :raises ValueError: if no stored agent has that id
        """
        for idx, stored in enumerate(self.data):
            if stored['id'] == agent_id:
                return idx
        raise ValueError('No stored agent with id {}'.format(agent_id))

    def _find_cell(self, bd):
        """
        Map a BD to the index tuple of the grid cell it falls into.
        Bins are right-closed (edge[k-1], edge[k]], except the first bin, which also includes the minimum.
        :param bd: sequence with one value per grid dimension
        :return: tuple of per-dimension bin indices
        """
        cell_idx = []
        for dim_idx, dim in enumerate(self.dimensions):
            low, high = dim[0], dim[1]
            value = bd[dim_idx]
            assert high >= value >= low, \
                "BD outside of grid. BD: {} - Bottom Limits: {} - Upper Limits: {}".format(bd, low, high)

            # Index of the first edge >= value. max(..., 1) puts value == low into the first bin,
            # and -1 converts an edge index into a 0-based bin index.
            edge = np.argmax(self.cell_lims[dim_idx] >= value)
            cell_idx.append(max(edge, 1) - 1)
        return tuple(cell_idx)


class NNController:
    """
    This class contains the NNs evolved by QD that are then used to generate action sequences.

    A population of fully-connected feed-forward networks with sigmoid activations is stored and
    evaluated in a vectorized way. Each network is encoded as a flat genome. For every layer, from
    input to output, the genome holds the (fan_in * fan_out) weight matrix in row-major order,
    followed by fan_out biases.
    """
    def __init__(self, input_size, output_size, hidden_layers, hidden_layer_size):
        self.input_size = input_size
        self.output_size = output_size
        self.hidden_layer_size = hidden_layer_size
        self.hidden_layers = hidden_layers
        self.bias_size = self.hidden_layer_size  # Bias length of every hidden layer
        self.genome_size = sum(n_in * n_out + n_out for n_in, n_out in self._layer_shapes())

    def _layer_shapes(self):
        """
        (fan_in, fan_out) of every layer, from input to output
        """
        if self.hidden_layers > 0:
            return ([(self.input_size, self.hidden_layer_size)]
                    + [(self.hidden_layer_size, self.hidden_layer_size)] * (self.hidden_layers - 1)
                    + [(self.hidden_layer_size, self.output_size)])
        return [(self.input_size, self.output_size)]

    def load_weights(self, weights):
        """
        Load the weights of all the controllers.

        weights: array of shape [num_controllers, genome_size]; a single 1-D genome is also accepted.
        Fills self.layers (arrays of shape [num_controllers, fan_in, fan_out]) and self.bias
        (arrays of shape [num_controllers, fan_out]).
        Raises ValueError if the genome length does not match self.genome_size.
        """
        weights = np.atleast_2d(np.asarray(weights))
        if weights.shape[1] != self.genome_size:
            raise ValueError('Wrong genome length. Given: {} - Expected: {}'.format(
                weights.shape[1], self.genome_size))

        self.layers = []
        self.bias = []
        idx = 0
        for n_in, n_out in self._layer_shapes():
            end = idx + n_in * n_out
            self.layers.append(np.reshape(weights[:, idx:end], (-1, n_in, n_out)))
            # Each layer's bias has fan_out entries. For the output layer this is output_size,
            # not the hidden bias size.
            self.bias.append(weights[:, end:end + n_out])
            idx = end + n_out

    def get_action(self, observations):
        """
        Get the actions of the whole population for the given observations, in vectorized form.
        observations: array of shape [num_controllers, input_size]; row b is fed to controller b.
        Returns an int array of shape [num_controllers] with the argmax output of each network.
        """
        data = observations
        for layer, bias in zip(self.layers, self.bias):
            # Batched matrix-vector product, equivalent to [d.dot(l) for d, l in zip(data, layer)];
            # expit is the sigmoid activation, with range (0, 1)
            data = expit(np.einsum('Bi,BiN->BN', data, layer) + bias)

        return np.argmax(data, axis=1)  # Use argmax because the actions are discrete


# noinspection PyMethodMayBeStatic
class Agent:
    """
    Quality Diversity based agent. It maximizes exploration of a given behavior space and the performance of the
    found solutions.

    See: Quality Diversity: A New Frontier for Evolutionary Computation, Pugh et al. 2016

    At every call to `act`, a population of NN controllers is rolled out in `env` for PLANNING_HORIZON steps.
    The controllers are stored in a MAP-Elites grid keyed by a behavior descriptor (BD) and evolved by
    sampling + Gaussian mutation until MAX_EVALUATED_SEQUENCES controllers have been evaluated. The
    highest-return controller is then used to pick the action.

    Parameters
    ----------
    env : gym environment
        Environment (model) in which the candidate controllers are rolled out. Must provide
        `add_observations_to_history` and a batched `step`.
    epoch_output_dir : string
        Path of the output directory of the current epoch. Can be used to save
        results.
    epsilon : float
        Value of epsilon for the epsilon-greedy exploration. Set to None if
        epsilon-greedy is not used. NOTE(review): stored but not used by this class.
    gamma : float
        Discount factor.
    random_action : bool
        Whether to draw actions at random.
    seed : int
        Seed of the RNG.
    """

    def __init__(self, env, epoch_output_dir, epsilon=None, gamma=1,
                 random_action=False, seed=None):
        self.seed(seed)
        self.epoch_output_dir = epoch_output_dir
        self.env = env
        self.epsilon = epsilon
        self.gamma = gamma
        self.random_action = random_action
        # Limits of the 4 BD dimensions, in the order produced by bd_extractor
        self.grid_dimensions = [[-2, 2],  # x_{T/2}
                                [-2, 2],  # x_T
                                [-2, 2],  # y_{T/2}
                                [-2, 2]]  # y_T

        self.qd_mut_sigma = 0.05  # Sigma used by the mutation operators of QD
        self.qd_mut_prob = 0.8  # Probability of mutation for each gene in the genome

        self.archive = Grid(bins=50, dimensions=self.grid_dimensions)
        self.policy_id = 0  # Next unique id given to an evaluated controller
        self.controllers = NNController(input_size=6, output_size=N_ACTIONS,
                                        hidden_layers=2, hidden_layer_size=5)
        self.genome_size = self.controllers.genome_size

    def seed(self, seed):
        """Re-seed the agent's RNG (used for all random draws of this class)."""
        self.np_random = check_random_state(seed)
        return [seed]

    def bd_extractor(self, obs):
        """
        Build the behavior descriptors from the observations sampled during the rollouts.

        obs: array indexed as [policy, sampled_step, feature], with 2 sampled steps (T/2 and T).
        Returns an array of shape [n_policies, 4] ordered as [x_{T/2}, x_T, y_{T/2}, y_T], where
        x = obs[..., 1] + obs[..., 3] and y = -obs[..., 0] - obs[..., 2].
        NOTE(review): presumably an end-effector position for the target environment - confirm.
        """
        return np.array([obs[:, :, 1] + obs[:, :, 3], -obs[:, :, 0] - obs[:, :, 2]]).swapaxes(0, 1).reshape(-1, 4)

    def _sample_from_archive(self, num_as):
        """
        Returns the genomes to evaluate.
        If num_as is at least the archive size, all stored genomes are returned (archive order).
        Otherwise num_as genomes are drawn with replacement, with probability proportional to
        (reward - min reward), or uniformly if all rewards are equal.
        Uses the agent's seeded RNG so results are reproducible.
        """
        rewards = np.array(self.archive['reward'])
        genomes = np.array(self.archive['genome'])  # Copies, so callers may mutate in place

        # If we need more than the ones in the archive, return the whole archive
        if num_as >= self.archive.size:
            return genomes

        # Otherwise favour the best ones (the worst gets zero probability after the shift)
        shifted = rewards - np.min(rewards)
        total = np.sum(shifted)
        if total > 0:
            idx = self.np_random.choice(self.archive.size, size=num_as, p=shifted / total)
        else:
            idx = self.np_random.choice(self.archive.size, size=num_as)
        return genomes[idx]

    def _pad_with_random(self, genomes, n_genomes):
        """Append fresh N(0, 1) genomes until `genomes` has `n_genomes` rows."""
        missing = n_genomes - len(genomes)
        if missing <= 0:
            return genomes
        fresh = self.np_random.normal(0, 1, size=(missing, self.genome_size))
        return np.r_[genomes, fresh]

    def _rollout(self, observations, restart, n_controllers):
        """
        Roll out the currently loaded controllers in the env, starting from `observations`.
        Returns (discounted returns of shape [n_controllers * N_PARTICLES], BDs from bd_extractor).
        """
        n_rollouts = n_controllers * N_PARTICLES
        # Duplicate observations and restart flags to use the vectorized sampling of the model
        observation_vec = np.tile(observations, (n_rollouts, 1))
        restart_vec = np.array([restart] * n_rollouts).reshape(-1, 1)
        self.env.add_observations_to_history(observation_vec, restart_vec)

        all_returns = np.zeros(n_rollouts)
        sample_steps = (PLANNING_HORIZON // 2, PLANNING_HORIZON - 1)  # Steps whose obs feed the BD
        sampled_obs = []
        obs = observation_vec
        for horizon in range(PLANNING_HORIZON):
            actions = self.controllers.get_action(obs).reshape(-1, 1)
            obs, rewards, _, _ = self.env.step(actions)
            if horizon in sample_steps:
                sampled_obs.append(obs)
            all_returns += self.gamma ** horizon * rewards

        # [step, policy, feature] -> [policy, step, feature]
        bds = self.bd_extractor(np.swapaxes(sampled_obs, 1, 0))
        return all_returns, bds

    def act(self, observations, restart):
        """Return the action to take given the observations.

        Parameters
        ----------
        observations : array, shape (n_features,) or (1, n_features)
            Observations
        restart : int
            Whether the observation is the first of an episode.

        Returns
        -------
        action : int
            The action to take.
        """
        if self.random_action:
            return self.np_random.randint(N_ACTIONS)

        evaluated_action_seq = 0
        while evaluated_action_seq < MAX_EVALUATED_SEQUENCES:
            # Get controllers to evaluate
            weights = self.get_pop_to_eval(evaluated_action_seq)
            self.controllers.load_weights(weights)

            all_returns, bds = self._rollout(observations, restart, len(weights))
            ids = np.arange(self.policy_id, self.policy_id + len(bds))
            self.policy_id = ids[-1] + 1

            # Store controllers parameters in the MAP-Elites grid
            for genome, bd, agent_id, ret in zip(weights, bds, ids, all_returns):
                self.archive.store({'genome': genome, 'bd': bd, 'id': agent_id, 'reward': ret})
            evaluated_action_seq += len(weights)

        # Select next action with the highest-return controller
        returns = self.archive['reward']  # TODO change this to calc pareto fronts between novelty and rew!
        best_policy = self.archive['genome'][int(np.argmax(returns))]
        self.controllers.load_weights(np.array([best_policy]))
        # A single controller needs a single observation row, whatever shape the caller passed
        single_obs = np.reshape(observations, (1, -1))
        return self.controllers.get_action(single_obs)[0]

    def get_pop_to_eval(self, evaluated_action_seq):
        """
        Return the genomes (shape [n, genome_size]) of the controllers to evaluate next.

        - Empty archive: draw fresh N(0, 1) genomes.
        - First generation of a planning step (evaluated_action_seq == 0): re-sample the previous
          archive, top up with at least one fresh genome, then reset the archive.
        - Later generations: sample from the archive, apply Gaussian mutations, and top up with
          fresh genomes if the archive was too small.
        # TODO check good value of sigma for mutations (from the paper  https://arxiv.org/abs/1706.01905)
        """
        # At the beginning we have nothing, so we generate fresh random genomes
        if self.archive.size == 0:
            # Do not exceed the budget in case MAX_EVALUATED_SEQUENCES is less than INIT_AS
            n_genomes = min(INIT_AS, MAX_EVALUATED_SEQUENCES)
            return self.np_random.normal(0, 1, size=(n_genomes, self.genome_size))

        # At a new time step we re-evaluate the archive first
        if evaluated_action_seq == 0:
            n_genomes = min(INIT_AS, MAX_EVALUATED_SEQUENCES)
            # Sample one fewer so that at least one fresh genome is always injected
            genomes = self._sample_from_archive(n_genomes - 1)
            genomes = self._pad_with_random(genomes, n_genomes)
            self.archive.reset()
            return genomes

        # Otherwise sample from the archive and mutate, topping up with fresh genomes if needed.
        # Do not exceed the remaining budget in case AS_PER_GEN is larger.
        n_genomes = min(AS_PER_GEN, MAX_EVALUATED_SEQUENCES - evaluated_action_seq)
        genomes = self._sample_from_archive(n_genomes)
        which_mutate = self.np_random.uniform(low=0, high=1, size=genomes.shape) < self.qd_mut_prob
        mutations = self.np_random.normal(0, self.qd_mut_sigma, size=genomes.shape)
        genomes[which_mutate] += mutations[which_mutate]
        return self._pad_with_random(genomes, n_genomes)


