import os, sys
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.append(ROOT_DIR)
import numpy as np
from sklearn.utils.validation import check_random_state
from collections import deque
from scipy.special import expit
from copy import deepcopy

# Planning budget
PLANNING_HORIZON = 10  # Number of steps each candidate controller is rolled out for during planning.
MAX_EVALUATED_SEQUENCES = 1_100  # Total number of policies evaluated at each planning step.

# MAP-Elites population settings
INIT_AS = 100  # Population size: genomes evaluated when the search is (re)seeded.
AS_PER_GEN = 20  # Genomes sampled from the archive (and mutated) at every generation.

# Rollout / action settings
N_PARTICLES = 1  # Factor by which the start observation is tiled per controller.
N_ACTIONS = 2  # Dimensionality of the action vector.


# ====================================================
class Grid:
    """
    MAP-Elites archive: a regular grid over the behavior-descriptor (BD) space in which every cell holds at most
    one agent, the one with the highest reward seen so far for that cell.

    Agents are dicts; this class reads their 'bd', 'id' and 'reward' entries and stores the whole dict.
    Iterating the grid yields the stored agent dicts.
    """

    def __init__(self, bins, dimensions):
        """
        :param bins: number of bins along every BD dimension
        :param dimensions: [[minx, maxx], [miny, maxy], ...] inclusive limits of each BD dimension
        """
        self.data = deque()  # Stored agent dicts; each occupied grid cell holds the 'id' of one of them
        self.bins = bins
        self.dimensions = dimensions

        self.grid = self.init_grid()
        # bins + 1 evenly spaced bin edges per dimension, both limits included
        self.cell_lims = [np.linspace(dim[0], dim[1], num=self.bins + 1) for dim in self.dimensions]

    def reset(self):
        """
        Reset the archive to an empty state: drop all stored agents and clear every cell.
        """
        self.data = deque()
        self.grid = self.init_grid()

    # ---------------------------------
    def __len__(self):
        """
        Number of stored agents.
        """
        return self.size

    def __iter__(self):
        """
        Iterate over the stored agent dicts.
        """
        return iter(self.data)

    def __getitem__(self, item):
        """
        Access the archive either by column (string key) or by position (integer index).

        :param item: a string key, or an integer index into the stored agents
        :return: for a string key, a list with np.array(agent[item]) for every stored agent;
                 otherwise the agent dict at that position
        :raises KeyError: if a string key is missing from a stored agent
        """
        if isinstance(item, str):
            try:
                return [np.array(agent[item]) for agent in self.data]
            except KeyError:
                # Missing dict keys raise KeyError; re-raise with the list of available keys
                raise KeyError('Wrong key given. Available: {} - Given: {}'.format(
                    list(self.data[0].keys()), item)) from None
        return self.data[item]

    @property
    def size(self):
        """
        Number of stored agents.
        """
        return len(self.data)
    # ---------------------------------

    def init_grid(self):
        """
        Create an empty grid: an object array of shape [bins] * n_dimensions filled with None.
        """
        return np.full([self.bins] * len(self.dimensions), fill_value=None)

    def store(self, agent):
        """
        Try to insert an agent in the cell of its BD.

        The agent is stored if its cell is empty, or if its reward is >= the reward of the agent currently
        occupying the cell (which is then removed from the archive).

        :param agent: dict with at least 'bd' (one value per grid dimension), 'id' and 'reward'
        :return: True if the agent was stored, False otherwise
        :raises AssertionError: if the BD has the wrong length or lies outside the grid limits
        """
        assert len(agent['bd']) == len(self.dimensions), \
            'BD of wrong size. Given: {} - Expected: {}'.format(len(agent['bd']), len(self.dimensions))

        cell = self._find_cell(agent['bd'])
        stored_id = self.grid[cell]
        if stored_id is not None:
            stored_idx = self._index_of(stored_id)
            # Keep the incumbent unless the newcomer is at least as good (ties go to the newcomer)
            if not agent['reward'] >= self.data[stored_idx]['reward']:
                return False
            del self.data[stored_idx]

        self.grid[cell] = agent['id']
        self.data.append(agent)
        return True

    def _index_of(self, agent_id):
        """
        Position in self.data of the agent whose 'id' equals agent_id.

        :raises ValueError: if no stored agent has that id
        """
        for idx, stored in enumerate(self.data):
            if stored['id'] == agent_id:
                return idx
        raise ValueError('Agent id {} is not in the archive'.format(agent_id))

    def _find_cell(self, bd):
        """
        Return the grid cell (tuple of per-dimension bin indices) the given BD belongs to.

        Bins are left-open except the first one: a value lying exactly on an interior bin edge falls in the lower
        bin, and the lower grid limit falls in the first bin.

        :raises AssertionError: if the BD lies outside the grid limits
        """
        cell_idx = []
        for dim_idx, (low, high) in enumerate(self.dimensions):
            value = bd[dim_idx]
            assert high >= value >= low, \
                'BD outside of grid. BD: {} - Bottom Limits: {} - Upper Limits: {}'.format(bd, low, high)

            # Index of the first bin edge >= value. max(..., 1) maps the bottom border (edge 0) to the first
            # cell, and -1 turns the upper edge index of a bin into its 0-based cell index.
            edge = np.searchsorted(self.cell_lims[dim_idx], value, side='left')
            cell_idx.append(max(int(edge), 1) - 1)
        return tuple(cell_idx)


class MultiCostGrid(Grid):
    """
    Multi-cost MAP-Elites grid.

    Costs are integers (how many unsafe states have been visited), so a separate Grid is instantiated for each
    cost level. The main grid (this object itself) holds the agents with cost == 0. The other grids are only used
    when selecting the best policy while no cost-0 agent is available.
    """
    def __init__(self, bins, dimensions):
        super().__init__(bins, dimensions)
        # Key 0 is a placeholder: cost-0 agents are stored in this grid itself, not in cost_grids
        self.cost_grids = {0: None}

    def store(self, agent):
        """
        Store the agent in the grid matching its 'cost' (created on first use).

        :return: True if the agent was stored in its grid, False otherwise
        """
        cost = agent['cost']
        if cost == 0:
            return super().store(agent)

        if cost not in self.cost_grids:
            self.cost_grids[cost] = Grid(bins=self.bins, dimensions=self.dimensions)
        return self.cost_grids[cost].store(agent)

    def reset(self):
        """
        Empty the main grid and drop every non-zero cost grid.
        """
        super().reset()
        self.cost_grids = {0: None}

    @staticmethod
    def _best_genome(grid):
        """
        Genome of the highest-reward agent stored in a non-empty grid.
        """
        returns = grid['reward']
        return grid['genome'][np.argmax(returns)]

    def get_best_policy(self):
        """
        Return the genome of the highest-reward agent at the lowest available cost level.

        The cost-0 grid is checked first, then the other cost grids in increasing cost order.

        :return: the best genome, or None if no agent is stored at any cost level
        """
        if self.size > 0:
            return self._best_genome(self)

        # Skip the cost-0 placeholder explicitly instead of relying on it being the smallest key
        for cost in sorted(c for c in self.cost_grids if c != 0):
            grid = self.cost_grids[cost]
            if grid.size > 0:
                return self._best_genome(grid)
        return None


class NNController:
    """
    Batch of fully connected feed-forward controllers evolved by QD and used to generate action sequences.

    All controllers share one architecture. Every row of the weight matrix given to load_weights is the flattened
    genome of one controller. The genome lists, layer after layer from input to output, the layer's weight matrix
    (fan_in x fan_out, row-major) followed by its bias vector (fan_out values).
    """
    def __init__(self, input_size, output_size, hidden_layers, hidden_layer_size):
        self.input_size = input_size
        self.output_size = output_size
        self.hidden_layer_size = hidden_layer_size
        self.hidden_layers = hidden_layers
        self.bias_size = self.hidden_layer_size  # Bias length of every hidden layer

        # Each layer contributes fan_in * fan_out weights plus fan_out biases
        self.genome_size = sum(n_in * n_out + n_out for n_in, n_out in self._layer_shapes())

    def _layer_shapes(self):
        """
        (fan_in, fan_out) of every layer, from input to output.
        """
        if self.hidden_layers > 0:
            sizes = [self.input_size] + [self.hidden_layer_size] * self.hidden_layers + [self.output_size]
        else:
            sizes = [self.input_size, self.output_size]
        return list(zip(sizes[:-1], sizes[1:]))

    def load_weights(self, weights):
        """
        Load the weights of all the controllers.

        :param weights: array of shape [num_controllers, genome_size] (a single 1-D genome is also accepted)
        :raises ValueError: if the genome length does not match genome_size
        """
        weights = np.atleast_2d(weights)
        if weights.shape[1] != self.genome_size:
            raise ValueError('Wrong genome size. Given: {} - Expected: {}'.format(weights.shape[1],
                                                                                 self.genome_size))

        self.layers = []
        self.bias = []
        idx = 0
        for n_in, n_out in self._layer_shapes():
            end = idx + n_in * n_out
            self.layers.append(np.reshape(weights[:, idx:end], (-1, n_in, n_out)))
            # The bias of a layer has one entry per output unit of that layer
            self.bias.append(weights[:, end:end + n_out])
            idx = end + n_out

        assert len(self.bias) == len(self.layers), f'Not enough bias or layers. Bias {len(self.bias)} - ' \
                                                   f'Layers {len(self.layers)}'

    def get_action(self, observations):
        """
        Compute the actions of all the controllers, vectorized.

        :param observations: array of shape [num_controllers, input_size]; row b is fed to controller b
        :return: array of shape [num_controllers, output_size] with values in (-1, 1)
        """
        data = observations
        for layer, bias in zip(self.layers, self.bias):
            # Batched product, equivalent to [d.dot(l) for d, l in zip(data, layer)];
            # expit is the sigmoid activation, range (0, 1)
            data = expit(np.einsum('Bi,BiN->BN', data, layer) + bias)

        return (data - 0.5) * 2  # Rescale the sigmoid output from (0, 1) to (-1, 1)
# ====================================================


# noinspection PyMethodMayBeStatic
class Agent:
    """
    Quality Diversity based agent. It maximizes exploration of a given behavior space and the performance of the
    found solutions.

    At every call to act() it runs a MAP-Elites search over the weights of small neural-network controllers,
    rolling each candidate out for PLANNING_HORIZON steps, and then returns the action of the best controller.

    See: Quality Diversity: A New Frontier for Evolutionary Computation, Pugh et al. 2016

    Parameters
    ----------
    env : gym environment
        Environment used for planning. If it has a `model_env` attribute, all candidate controllers are stepped
        through it at once with a batch of actions; otherwise it is deep-copied once per candidate and every copy
        is stepped individually.
    epoch_output_dir : string
        Path of the output directory of the current epoch. Stored; not used by this class.
    epsilon : float
        Value of epsilon for the epsilon-greedy exploration. Set to None if epsilon-greedy is not used.
        Stored; not used by this class.
    gamma : float
        Discount factor applied to the rewards collected during planning.
    random_action : bool
        Whether to draw actions uniformly at random instead of planning.
    seed : int
        Seed of the RNG used for all the random draws of the agent.
    """

    def __init__(self, env, epoch_output_dir, epsilon=None, gamma=1,
                 random_action=False, seed=None):
        self.seed(seed)
        self.epoch_output_dir = epoch_output_dir
        self.env = env
        self.epsilon = epsilon
        self.gamma = gamma
        self.random_action = random_action
        self.qd_mut_prob = 0.8  # Per-gene probability of being mutated
        self.qd_mut_sigma = 0.05  # Std of the Gaussian mutation noise

        self.grid_dimensions = [[-2, 2] for _ in range(2)]  # goal_compass

        self.archive = MultiCostGrid(bins=20, dimensions=self.grid_dimensions)
        self.policy_id = 0  # Next unique id assigned to an evaluated controller
        self.controllers = NNController(input_size=22, output_size=N_ACTIONS,
                                        hidden_layers=1, hidden_layer_size=3)

    def seed(self, seed):
        """
        (Re)create the agent's RandomState from `seed`.
        """
        self.np_random = check_random_state(seed)
        return [seed]

    def bd_extractor(self, obs):
        """
        Use goal compass times goal distance to estimate the position relative to the goal.

        Reads column 5 as the goal distance and columns 3-4 as the goal compass (assumes that observation layout —
        confirm against the environment). The result is clipped to the [-2, 2] grid limits.
        """
        goal_dist = obs[:, 5:6]
        goal_compass = obs[:, 3:5]
        bds = goal_dist * goal_compass
        return np.clip(bds, -2, 2)

    def _sample_from_archive(self, num_as):
        """
        Return the genomes of the cost-0 archive to evaluate.

        If num_as is >= the archive size, all archived genomes are returned. Otherwise num_as genomes are drawn with
        replacement, with probability proportional to (reward - min reward), or uniformly if all rewards are equal.
        Always returns a new array, so mutating it does not touch the archive.
        """
        rewards = np.array(self.archive['reward'])
        genomes = np.array(self.archive['genome'])

        if num_as >= self.archive.size:
            return genomes

        shifted = rewards - np.min(rewards)
        total = np.sum(shifted)
        probs = shifted / total if total > 0 else None  # None -> uniform sampling
        idx = self.np_random.choice(self.archive.size, size=num_as, p=probs)
        return genomes[idx]

    def act(self, observations, restart):
        """Return the action to take given the observations.

        Parameters
        ----------
        observations : array
            Current observation. Presumably a 1-D feature vector: the final action is computed on
            np.array([observations]), which the controllers expect to be 2-D.
        restart : int
            Whether the observation is the first of an episode.

        Returns
        -------
        action : array, shape (N_ACTIONS,)
            The action to take, with values in [-1, 1].
        """
        if hasattr(self.env, 'real_states_history'):
            self.env.real_states_history.append(observations)

        if self.random_action:
            return self.np_random.uniform(low=-1, high=1, size=(N_ACTIONS,))

        self._search_policies(observations, restart)

        # Select next action with the best controller found
        best_policy = self.archive.get_best_policy()
        self.controllers.load_weights(np.array([best_policy]))
        contr_inputs = self.extract_controller_input(np.array([observations]))
        return self.controllers.get_action(contr_inputs)[0]

    def _search_policies(self, observations, restart):
        """
        Run MAP-Elites from `observations` until MAX_EVALUATED_SEQUENCES controllers have been evaluated, storing
        every evaluated controller in self.archive.
        """
        evaluated_action_seq = 0
        while evaluated_action_seq < MAX_EVALUATED_SEQUENCES:
            weights = self.get_pop_to_eval(evaluated_action_seq)
            self.controllers.load_weights(weights)

            all_returns, safety_costs, final_obs = self._rollout(len(weights), observations, restart)

            # The behavior descriptor is computed on the last observation of each rollout
            bds = self.bd_extractor(final_obs)
            ids = np.arange(self.policy_id, self.policy_id + len(bds))
            self.policy_id += len(bds)

            # Unsafe policies go to the per-cost grids of the archive
            for gen, bd, aid, rew, cost in zip(weights, bds, ids, all_returns, safety_costs):
                self.archive.store({'genome': gen, 'bd': bd, 'id': aid, 'reward': rew, 'cost': cost})
            evaluated_action_seq += len(weights)

    def _rollout(self, n_policies, observations, restart):
        """
        Roll the currently loaded controllers out for PLANNING_HORIZON steps starting from `observations`.

        :return: (discounted returns, summed safety costs, last observations), one entry per rollout
        """
        # Without a model env, every controller gets its own copy of the environment
        envs = None
        if not hasattr(self.env, 'model_env'):
            envs = [deepcopy(self.env) for _ in range(n_policies)]

        # Duplicate observations and restart flags to leverage the vectorized sampling of the model
        n_rollouts = n_policies * N_PARTICLES
        observation_vec = np.tile(observations, (n_rollouts, 1))
        restart_vec = np.array([restart] * n_rollouts).reshape(-1, 1)
        if hasattr(self.env, 'history'):
            self.env.add_observations_to_history(observation_vec, restart_vec)

        all_returns = np.zeros(n_rollouts)
        safety_costs = np.zeros(n_rollouts)
        obs = observation_vec
        for horizon in range(PLANNING_HORIZON):
            actions = self.controllers.get_action(self.extract_controller_input(obs))
            if envs is None:
                obs, rewards, _, info = self.env.step(actions)
                costs = info['cost']  # assumes one cost per rollout (or a scalar) — confirm with the model env
            else:
                obs, rewards, costs = self._step_real_envs(envs, actions)
            all_returns += self.gamma ** horizon * rewards
            safety_costs += costs
        return all_returns, safety_costs, obs

    def _step_real_envs(self, envs, actions):
        """
        Step every environment copy with the action of its own controller.

        :return: stacked next observations, rewards and per-environment safety costs (info['cost'])
        """
        obs, rewards, costs = [], [], []
        for action, env in zip(actions, envs):  # Each env is on its own planning horizon path
            # NOTE(review): only the first action component is sent here, while the model env and the returned
            # action use the full action vector — confirm this is intended.
            o, r, _, info = env.step(action[0])
            obs.append(o)
            rewards.append(r)
            costs.append(info['cost'])  # Cost of THIS env, not of the last one stepped
        return np.array(obs), np.array(rewards), np.array(costs)

    def _random_genomes(self, n_genomes):
        """
        n_genomes random genomes drawn from N(0, 1) with the agent's RNG.
        """
        return self.np_random.normal(0, 1, size=(n_genomes, self.controllers.genome_size))

    def get_pop_to_eval(self, evaluated_action_seq):
        """
        Return the weights (one genome per row) of the controllers to evaluate next.

        - Cost-0 archive empty: at most INIT_AS random genomes, never more than the remaining evaluation budget.
          At the start of a planning step the archive is also reset, so unsafe policies evaluated from the
          previous observation are not reused.
        - Start of a planning step (evaluated_action_seq == 0): the whole archive is re-evaluated, topped up with
          random genomes to at least INIT_AS, and the archive is reset.
        - Otherwise: up to AS_PER_GEN genomes sampled from the archive and mutated, topped up with random genomes.
        """
        if self.archive.size == 0:
            if evaluated_action_seq == 0:
                self.archive.reset()
            return self._random_genomes(min(INIT_AS, MAX_EVALUATED_SEQUENCES - evaluated_action_seq))

        # At the new time step we reevaluate the archive first
        if evaluated_action_seq == 0:
            genomes = self._sample_from_archive(MAX_EVALUATED_SEQUENCES)
            if len(genomes) < INIT_AS:  # In case the archive got too small we add new genomes
                genomes = np.r_[genomes, self._random_genomes(INIT_AS - len(genomes))]
            self.archive.reset()
            return genomes

        n_genomes = min(AS_PER_GEN, MAX_EVALUATED_SEQUENCES - evaluated_action_seq)
        genomes = self._sample_from_archive(n_genomes)

        # Each gene is perturbed with probability qd_mut_prob by N(0, qd_mut_sigma) noise
        which_mutate = self.np_random.uniform(low=0, high=1, size=genomes.shape) < self.qd_mut_prob
        mutations = self.np_random.normal(0, self.qd_mut_sigma, size=genomes.shape)
        genomes[which_mutate] += mutations[which_mutate]

        if len(genomes) < n_genomes:  # In case the archive got too small we add new genomes
            genomes = np.r_[genomes, self._random_genomes(n_genomes - len(genomes))]
        return genomes

    def extract_controller_input(self, obs):
        """
        Extract the controller inputs from the observations (currently the identity).
        """
        return obs