import os, sys
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.append(ROOT_DIR)
import numpy as np
from sklearn.utils.validation import check_random_state
from collections import deque
from scipy.special import expit
from copy import deepcopy

# MAP-Elites planning hyper-parameters
PLANNING_HORIZON = 10  # Number of steps every controller is rolled out for at each planning step
INIT_AS = 100  # Population size: random genomes seeding an empty archive, and minimum re-evaluated population
AS_PER_GEN = 20  # Genomes selected from the archive (then mutated) at each generation
MAX_EVALUATED_SEQUENCES = 1_100  # Total number of policies evaluated at each planning step
N_PARTICLES = 1  # Copies of the observation per controller when tiling the rollout batch
N_ACTIONS = 2  # Dimension of the action vector (output size of the controllers)


# ====================================================
class Grid:
    """
    Grid archive used by MAP-Elites.

    The behavior space is split into ``bins`` equal-width cells along each dimension and every cell holds at most
    one agent. Agents are dicts that must at least contain the keys 'bd' (behavior descriptor), 'id', 'cost' and
    'reward'; any other key (e.g. 'genome') can be read column-wise with ``grid['key']``.
    """

    def __init__(self, bins, dimensions):
        """
        :param bins: number of bins along each dimension
        :param dimensions: limits of each dimension: [[minx, maxx], [miny, maxy], ...]
        """
        # Stored agent dicts. The agent id tells apart agents sharing the same BD and links grid cells to entries
        self.data = deque()
        self.bins = bins
        self.dimensions = dimensions

        # Object array of shape [bins] * len(dimensions) holding the id of the agent in each cell (None if empty)
        self.grid = self.init_grid()
        # Bin edges (bins + 1 values from min to max) of each dimension
        self.cell_lims = [np.linspace(dim[0], dim[1], num=self.bins + 1) for dim in self.dimensions]

    def reset(self):
        """Reset the archive to an empty state."""
        self.data = deque()
        self.grid = self.init_grid()

    # ---------------------------------
    def __len__(self):
        """Number of agents in the archive."""
        return self.size

    def __iter__(self):
        """Iterate over the stored agent dicts."""
        # Grid is iterable, not an iterator: the iteration state lives in the deque iterator returned here
        return iter(self.data)

    def __getitem__(self, item):
        """
        Column or row access to the archive.

        :param item: a key of the agent dicts, or an integer position in the archive
        :return: for a key, the list of np.array(agent[key]) over all stored agents; otherwise the stored agent
        :raises KeyError: if a stored agent lacks the requested key
        """
        if isinstance(item, str):
            try:
                return [np.array(agent[item]) for agent in self.data]
            except KeyError as err:
                # Missing dict keys raise KeyError; keep that type but explain which keys are available
                raise KeyError('Wrong key given. Available: {} - Given: {}'.format(
                    list(self.data[0].keys()), item)) from err
        return self.data[item]

    @property
    def size(self):
        """Number of agents in the archive."""
        return len(self.data)
    # ---------------------------------

    def init_grid(self):
        """Return an empty grid: an object array of shape [bins] * len(dimensions) filled with None."""
        return np.full([self.bins] * len(self.dimensions), fill_value=None)

    def store(self, agent):
        """
        Try to insert an agent into the cell its BD falls in.

        An empty cell always accepts the agent. An occupied cell is taken over only if the new agent has a strictly
        lower 'cost', or the same 'cost' and a 'reward' at least as high as the occupant's; the replaced agent is
        removed from the archive.

        :param agent: dict with at least the keys 'bd', 'id', 'cost' and 'reward'
        :return: True if the agent was stored, False otherwise
        """
        assert len(agent['bd']) == len(self.dimensions), \
            'BD of wrong size. Given: {} - Expected: {}'.format(len(agent['bd']), len(self.dimensions))

        cell = self._find_cell(agent['bd'])
        occupant_id = self.grid[cell]
        if occupant_id is not None:
            stored_agent_idx = self._index_of(occupant_id)
            stored_agent = self.data[stored_agent_idx]
            cost, stored_cost = agent['cost'], stored_agent['cost']
            # Lowest cost wins; on equal cost the highest (or equal) reward wins
            is_better = cost < stored_cost or (cost == stored_cost and agent['reward'] >= stored_agent['reward'])
            if not is_better:
                return False
            del self.data[stored_agent_idx]

        self.grid[cell] = agent['id']
        self.data.append(agent)
        return True

    def _index_of(self, agent_id):
        """Return the position in self.data of the agent whose 'id' equals agent_id (ValueError if absent)."""
        for idx, stored in enumerate(self.data):
            if stored['id'] == agent_id:
                return idx
        raise ValueError('Agent id {} not found in the archive'.format(agent_id))

    def _find_cell(self, bd):
        """
        Find the cell the given BD belongs to.

        :param bd: behavior descriptor, one value per dimension
        :return: tuple with the cell index along each dimension
        """
        cell_idx = []
        for dim_idx, dim in enumerate(self.dimensions):
            low, high = dim[0], dim[1]
            value = bd[dim_idx]
            assert high >= value >= low, \
                "BD outside of grid. BD: {} - Bottom Limits: {} - Upper Limits: {}".format(bd, low, high)

            # argmax gives the first bin edge >= value. A value on the bottom border gives edge 0, hence the max();
            # removing 1 turns the edge index into a 0-based cell index (values on an interior edge go to the lower
            # cell).
            cell_idx.append(max(np.argmax(self.cell_lims[dim_idx] >= value), 1) - 1)
        return tuple(cell_idx)


class NNController:
    """
    This class contains the NNs evolved by QD that is then used to generate action sequences.

    Every row of the weight matrix given to ``load_weights`` is the flat genome of one network; ``get_action`` then
    runs all the networks at once, network ``b`` receiving observation row ``b``.
    """
    def __init__(self, input_size, output_size, hidden_layers, hidden_layer_size):
        """
        :param input_size: size of the controller input
        :param output_size: size of the controller output
        :param hidden_layers: number of hidden layers (0 means a single input-to-output layer)
        :param hidden_layer_size: number of units in each hidden layer
        """
        self.input_size = input_size
        self.output_size = output_size
        self.hidden_layer_size = hidden_layer_size
        self.hidden_layers = hidden_layers
        self.bias_size = self.hidden_layer_size  # Bias size of the hidden layers

        if self.hidden_layers > 0:
            self.genome_size = self.input_size * self.hidden_layer_size + self.bias_size + \
                               (self.hidden_layer_size * self.hidden_layer_size + self.bias_size) * (
                                       self.hidden_layers - 1) + \
                               self.hidden_layer_size * self.output_size + self.output_size
        else:
            self.genome_size = self.input_size * self.output_size + self.output_size

    def _layer_shapes(self):
        """Return the (in_size, out_size) of every layer, from input to output."""
        if self.hidden_layers > 0:
            return [(self.input_size, self.hidden_layer_size)] + \
                   [(self.hidden_layer_size, self.hidden_layer_size)] * (self.hidden_layers - 1) + \
                   [(self.hidden_layer_size, self.output_size)]
        return [(self.input_size, self.output_size)]

    def load_weights(self, weights):
        """
        Loads the weights of all the controllers.

        Genome layout: for each layer, the flattened (in_size x out_size) weight matrix followed by its bias of
        size out_size.

        :param weights: array of shape [num_controllers, genome_len]
        """
        self.layers = []
        self.bias = []
        idx = 0
        for in_size, out_size in self._layer_shapes():
            end = idx + in_size * out_size
            self.layers.append(np.reshape(weights[:, idx:end], (-1, in_size, out_size)))
            # Each layer's bias has the size of that layer's output (the output layer uses output_size)
            self.bias.append(weights[:, end:end + out_size])
            idx = end + out_size

    def get_action(self, observations):
        """
        Get the actions from all the population due to the observations, in a vectorized form.

        :param observations: array of shape [num_controllers, input_size]
        :return: array of shape [num_controllers, output_size] with values in [-1, 1]
        """
        data = observations
        for layer, bias in zip(self.layers, self.bias):
            # Batched product, equivalent to [d.dot(l) for d, l in zip(data, layer)]; expit is the sigmoid
            data = expit(np.einsum('Bi,BiN->BN', data, layer) + bias)

        return (data - 0.5) * 2  # Go from the sigmoid [0, 1] range to [-1, 1]
# ====================================================


# noinspection PyMethodMayBeStatic
class Agent:
    """
    Quality Diversity based agent. It maximizes exploration of a given behavior space and the performance of the
    found solutions.

    At every call of ``act`` a MAP-Elites search over neural-network controllers is run: the controllers are rolled
    out for PLANNING_HORIZON steps (through ``env.step`` if ``env`` has a ``model_env`` attribute, otherwise in deep
    copies of ``env``), stored in a grid indexed by the behavior descriptor of their last observation, and the best
    stored controller computes the returned action.

    See: Quality Diversity: A New Frontier for Evolutionary Computation, Pugh et al. 2016

    Parameters
    ----------
    env : gym environment
        Environment used to evaluate the controllers.
    epoch_output_dir : string
        Path of the output directory of the current epoch. Can be used to save
        results.
    epsilon : float
        Value of epsilon for the epsilon-greedy exploration. Set to None if
        epsilon-greedy is not used. (Stored, but not read by this class.)
    gamma : float
        Discount factor.
    random_action : bool
        Whether to draw actions at random.
    seed : int
        Seed of the RNG used for all the sampling done by the agent.
    """

    def __init__(self, env, epoch_output_dir, epsilon=None, gamma=1,
                 random_action=False, seed=None):
        self.seed(seed)
        self.epoch_output_dir = epoch_output_dir
        self.env = env
        self.epsilon = epsilon
        self.gamma = gamma
        self.random_action = random_action
        self.qd_mut_prob = 0.8  # Probability of mutating each gene
        self.qd_mut_sigma = 0.05  # Std of the Gaussian mutation

        self.grid_dimensions = [[-2, 2]] * 2  # goal_compass

        self.archive = Grid(bins=20, dimensions=self.grid_dimensions)
        self.policy_id = 0  # Next id given to an evaluated controller
        self.controllers = NNController(input_size=22, output_size=N_ACTIONS,
                                        hidden_layers=1, hidden_layer_size=3)

    def seed(self, seed):
        self.np_random = check_random_state(seed)
        return [seed]

    def bd_extractor(self, obs):
        """
        Use goal compass times goal distance to estimate relative position to the goal.

        :param obs: array of observations, one per row (columns 3:5 are read as goal compass, column 5 as distance)
        :return: array of shape [n_obs, 2] clipped to the grid limits [-2, 2]
        """
        goal_dist = obs[:, 5:6]
        goal_compass = obs[:, 3:5]
        bds = goal_dist * goal_compass
        return np.clip(bds, -2, 2)

    def _sample_from_archive(self, num_as):
        """
        Select genomes from the archive.

        - If num_as is at least the archive size, all the genomes are returned, safe or not.
        - Otherwise, if at least num_as genomes are safe (cost == 0), num_as of them are sampled with replacement,
          with probability proportional to their reward shifted so that the minimum is 0 (uniformly if all the
          rewards are equal).
        - Otherwise all the safe genomes are returned, completed with unsafe ones taken front by front from the
          Pareto ranking of (reward, -cost), and truncated to num_as.

        :param num_as: number of genomes to select
        :return: array of shape [n, genome_size]; n is num_as, or the archive size in the first case
        """
        if num_as >= self.archive.size:
            return np.array(self.archive['genome'])

        rewards = np.array(self.archive['reward'])
        costs = -np.array(self.archive['cost'])  # Negated so that both objectives are maximized
        all_idx = np.arange(self.archive.size)
        safe_idx = all_idx[costs == 0]

        if len(safe_idx) >= num_as:
            safe_rewards = rewards[safe_idx] - np.min(rewards[safe_idx])
            reward_sum = np.sum(safe_rewards)
            probs = safe_rewards / reward_sum if reward_sum > 0 else None  # None -> uniform sampling
            chosen_idx = self.np_random.choice(safe_idx, size=num_as, p=probs)
        else:
            chosen_idx = safe_idx
            unsafe_idx = all_idx[costs != 0]
            # Each front contains positions inside unsafe_idx
            fronts = self.fast_non_dominated_sort(rewards[unsafe_idx], costs[unsafe_idx])
            for front in fronts:
                chosen_idx = np.r_[chosen_idx, unsafe_idx[front]]
                if len(chosen_idx) >= num_as:
                    break
            chosen_idx = chosen_idx[:num_as]

        return np.array(self.archive['genome'])[chosen_idx]

    def act(self, observations, restart):
        """Return the action to take given the observations.

        Parameters
        ----------
        observations : array, shape (n_features,) or (1, n_features)
            Observations
        restart : int
            Whether the observation is the first of an episode.

        Returns
        -------
        action : array, shape (N_ACTIONS,)
            The action to take, in [-1, 1].
        """
        if hasattr(self.env, 'real_states_history'):
            self.env.real_states_history.append(observations)

        if self.random_action:
            return self.np_random.uniform(low=-1, high=1, size=(N_ACTIONS,))

        # Search for policies through MAP-Elites
        evaluated_action_seq = 0
        while evaluated_action_seq < MAX_EVALUATED_SEQUENCES:
            weights = self.get_pop_to_eval(evaluated_action_seq)
            self.controllers.load_weights(weights)
            final_obs, returns, costs = self._evaluate_controllers(observations, restart, len(weights))

            bds = self.bd_extractor(final_obs)
            ids = np.arange(self.policy_id, self.policy_id + len(bds))
            self.policy_id += len(bds)

            for genome, bd, agent_id, ret, cost in zip(weights, bds, ids, returns, costs):
                self.archive.store({'genome': genome, 'bd': bd, 'id': agent_id, 'reward': ret, 'cost': cost})
            evaluated_action_seq += len(weights)

        return self._best_action(observations)

    def _evaluate_controllers(self, observations, restart, n_controllers):
        """
        Roll out the currently loaded controllers for PLANNING_HORIZON steps starting from ``observations``.

        :return: (last observations, discounted returns, discounted safety costs), one entry per rollout
        """
        # Real environment: one independent copy per controller so each follows its own trajectory
        envs = None if hasattr(self.env, 'model_env') else [deepcopy(self.env) for _ in range(n_controllers)]

        # Duplicate observations and restart to leverage the vectorized sampling of the model
        n_rollouts = n_controllers * N_PARTICLES
        obs = np.tile(observations, (n_rollouts, 1))
        restart_vec = np.array([restart] * n_rollouts).reshape(-1, 1)
        if hasattr(self.env, 'history'):
            self.env.add_observations_to_history(obs, restart_vec)

        returns = np.zeros(n_rollouts)
        costs = np.zeros(n_rollouts)
        for step in range(PLANNING_HORIZON):
            actions = self.controllers.get_action(self.extract_controller_input(obs))
            if envs is None:
                obs, rewards, _, info = self.env.step(actions)
                step_costs = info['cost']
            else:
                obs, rewards, step_costs = self._step_real_envs(envs, actions)
            discount = self.gamma ** step
            returns += discount * rewards
            costs += discount * step_costs
        return obs, returns, costs

    def _step_real_envs(self, envs, actions):
        """Step every environment copy with its own action; return the stacked (obs, rewards, costs)."""
        obs, rewards, costs = [], [], []
        for action, env in zip(actions, envs):
            o, r, _, info = env.step(action)
            obs.append(o)
            rewards.append(r)
            costs.append(info['cost'])  # Each rollout keeps the cost of its own environment
        return np.array(obs), np.array(rewards), np.array(costs)

    def _best_action(self, observations):
        """
        Return the action of the best controller in the archive for ``observations``.

        The best controller is taken from the first Pareto front of (-cost, reward); within that front the one
        with the lowest cost is chosen.
        """
        returns = np.array(self.archive['reward'])
        costs = -np.array(self.archive['cost'])
        front = self.fast_non_dominated_sort(costs, returns, best=True)
        best_policy_idx = front[int(np.argmax(costs[front]))]

        best_policy = self.archive['genome'][best_policy_idx]
        self.controllers.load_weights(np.array([best_policy]))
        # A single row, whether observations came as (n_features,) or (1, n_features)
        contr_inputs = self.extract_controller_input(np.reshape(observations, (1, -1)))
        return self.controllers.get_action(contr_inputs)[0]

    def get_pop_to_eval(self, evaluated_action_seq):
        """
        Return the weights of the controllers to evaluate in the next MAP-Elites generation.

        - Empty archive: min(INIT_AS, MAX_EVALUATED_SEQUENCES) random genomes.
        - First generation of a planning step (evaluated_action_seq == 0): the genomes returned by
          _sample_from_archive(MAX_EVALUATED_SEQUENCES), completed with random ones up to INIT_AS, so they are
          re-evaluated from the new state; the archive is then emptied.
        - Otherwise: min(AS_PER_GEN, remaining budget) genomes selected from the archive and mutated, completed with
          random ones if the archive is too small.

        :param evaluated_action_seq: number of genomes already evaluated in the current planning step
        :return: array of shape [n_genomes, genome_size]
        """
        genome_size = self.controllers.genome_size

        if self.archive.size == 0:
            return self.np_random.normal(0, 1, size=(min(INIT_AS, MAX_EVALUATED_SEQUENCES), genome_size))

        if evaluated_action_seq == 0:
            genomes = self._sample_from_archive(MAX_EVALUATED_SEQUENCES)
            if len(genomes) < INIT_AS:  # In case the archive got too small we add new AS
                genomes = np.r_[genomes, self.np_random.normal(0, 1, size=(INIT_AS - len(genomes), genome_size))]
            self.archive.reset()
            return genomes

        n_genomes = min(AS_PER_GEN, MAX_EVALUATED_SEQUENCES - evaluated_action_seq)
        genomes = self._sample_from_archive(n_genomes)

        # Gaussian mutation applied independently to each gene with probability qd_mut_prob
        which_mutate = self.np_random.uniform(low=0, high=1, size=genomes.shape) < self.qd_mut_prob
        mutations = self.np_random.normal(0, self.qd_mut_sigma, size=genomes.shape)
        genomes[which_mutate] += mutations[which_mutate]

        if len(genomes) < n_genomes:  # In case the archive got too small we add new AS
            genomes = np.r_[genomes, self.np_random.normal(0, 1, size=(n_genomes - len(genomes), genome_size))]
        return genomes

    def extract_controller_input(self, obs):
        """
        This function extracts the controller inputs from the obs (currently the obs themselves)
        """
        return obs

    def fast_non_dominated_sort(self, values1, values2, best=False):
        """
        Sort the elements into Pareto fronts, maximizing both objectives (NSGA-II fast non-dominated sort).
        Adapted from https://github.com/haris989/NSGA-II

        :param values1: values of the first objective
        :param values2: values of the second objective
        :param best: if True only return the first (non-dominated) front without computing the others
        :return: the first front (list of indexes) if best, otherwise the list of fronts, best first
        """
        n_elems = len(values1)

        def dominates(p, q):
            # p is at least as good as q on both objectives and strictly better on at least one
            return (values1[p] >= values1[q] and values2[p] >= values2[q]) and \
                   (values1[p] > values1[q] or values2[p] > values2[q])

        dominated_by = [[] for _ in range(n_elems)]  # dominated_by[p]: elements that p dominates
        domination_count = [0] * n_elems  # domination_count[p]: number of elements dominating p
        first_front = []
        for p in range(n_elems):
            for q in range(n_elems):
                if dominates(p, q):
                    dominated_by[p].append(q)
                elif dominates(q, p):
                    domination_count[p] += 1
            if domination_count[p] == 0:
                first_front.append(p)

        if best:
            return first_front

        # Peel off fronts: an element joins the next front once all the elements dominating it are placed
        fronts = []
        current = first_front
        while current:
            fronts.append(current)
            next_front = []
            for p in current:
                for q in dominated_by[p]:
                    domination_count[q] -= 1
                    if domination_count[q] == 0:
                        next_front.append(q)
            current = next_front
        return fronts