import copy
import warnings

import numpy as np
from scipy import stats
from joblib import Parallel, delayed

from sklearn.utils.validation import check_random_state

# --- Sampling budget of the planner -----------------------------------------
# Each candidate action sequence is replicated N_PARTICLES times before being
# rolled out in the environment.
N_PARTICLES = 1
# Number of candidate action sequences drawn at every CEM iteration.
N_ACTION_SEQUENCES = 20
# Number of environment steps simulated for every candidate sequence.
PLANNING_HORIZON = 10
# Number of CEM iterations performed for each call to ``Agent.act``.
MAX_ITER = 5

# --- Elite selection and distribution update --------------------------------
# Maximum number of best sequences used to refit the sampling distribution.
N_ELITES = 10
# When fewer than MIN_ELITES sequences have zero cost, the elite set is padded
# with low-cost sequences.
MIN_ELITES = 5
# Weight kept on the previous distribution when blending it with the one
# refitted on the elites (the refitted one gets 1 - CEM_LEARNING_RATE).
CEM_LEARNING_RATE = 0.1

# --- Action space -----------------------------------------------------------
# Number of features of a single action.
N_ACTION_FEATURES = 1
# Number of discrete values that each action feature can take.
N_DISCRETE_ACTIONS = np.array([3])

# A flattened plan is laid out as [a_1, ..., a_p], with p = PLANNING_HORIZON
# and each a_i holding N_ACTION_FEATURES entries.
N_DIST_FEATURES = N_ACTION_FEATURES * PLANNING_HORIZON
ACTION_LOWER_BOUND = np.full(N_DIST_FEATURES, -1)
ACTION_UPPER_BOUND = np.full(N_DIST_FEATURES, 1)


class Agent:
    """Cross entropy method (CEM) planner with a safety-cost constraint.

    At each call to ``act`` the agent samples discrete action sequences from
    a per-step categorical distribution, rolls them out in ``env``, keeps the
    best zero-cost sequences as elites and refits the distribution on them.

    See "A Tutorial on the Cross-Entropy Method" by De Boer et al., 2005.

    We follow the code of the PETS paper, see kchua/handful-of-trials github
    repository, MIT license.

    Parameters
    ----------
    env : gym environment
        Environment used to roll out the candidate action sequences. It must
        provide ``add_observations_to_history(observations, restarts)`` and a
        batched ``step(actions)`` whose ``info`` dict has a ``'cost'`` entry.
    epoch_output_dir : string
        Path of the output directory of the current epoch. Can be used to save
        results.
    epsilon : float
        Value of epsilon for the epsilon-greedy exploration. Set to None if
        epsilon-greedy is not used. Stored on the instance; not used by the
        code of this class.
    gamma : float
        Discount factor applied to the rewards.
    beta : float
        Discount factor applied to the safety costs.
    random_action : bool
        Whether to draw actions at random.
    seed : int
        Seed of the RNG.
    """

    def __init__(self, env, epoch_output_dir,
                 epsilon=None, gamma=1, beta=0.4, random_action=False,
                 seed=None):

        self.seed(seed)
        self.epoch_output_dir = epoch_output_dir
        self.env = env
        self.epsilon = epsilon
        self.gamma = gamma
        self.beta = beta

        self.random_action = random_action
        # minimal size of the elite set, see ``_select_elites``
        self.minimal_elites = MIN_ELITES
        # NOTE(review): not read by ``act``; looks like a leftover of a
        # continuous-action CEM -- confirm before removing.
        self.init_mean = np.zeros(N_DIST_FEATURES)

    def seed(self, seed=None):
        """Set the random number generator and return ``[seed]``."""
        self.np_random = check_random_state(seed)
        return [seed]

    def reset(self):
        """Reset ``init_mean`` to zeros."""
        self.init_mean = np.zeros(N_DIST_FEATURES)

    def act(self, observations, restart):
        """Return the action to take given the observations.

        Parameters
        ----------
        observations : array, shape (1, n_features)
            Observations
        restart : int
            Whether the observation is the first of an episode.

        Returns
        -------
        action : int
            The action to take.
        """
        if self.random_action:
            # use the declared number of discrete actions rather than a
            # hard-coded literal
            return int(self.np_random.randint(N_DISCRETE_ACTIONS[0]))

        n_rollouts = N_ACTION_SEQUENCES * N_PARTICLES
        observation_vec = np.tile(observations, (n_rollouts, 1))
        restart_vec = np.array([restart] * n_rollouts).reshape(-1, 1)

        # initial distribution is discrete uniform distribution for each
        # action feature, repeated for every step of the planning horizon
        dist = 1 / np.repeat(N_DISCRETE_ACTIONS, N_DISCRETE_ACTIONS)
        dists = np.tile(dist, PLANNING_HORIZON)
        slices = self._distribution_slices()

        best_return = -np.inf
        best_action_sequence = None

        for _ in range(MAX_ITER):
            # the history is set again to the current observation before
            # every batch of rollouts
            self.env.add_observations_to_history(observation_vec, restart_vec)

            action_sequences = self._sample_action_sequences(dists, slices)
            returns, cost = self._rollout(action_sequences)
            elite_action_sequences, elite_return = self._select_elites(
                action_sequences, returns, cost)

            new_dists = self._fit_distribution(
                elite_action_sequences, dists, slices)
            dists = (CEM_LEARNING_RATE * dists +
                     (1 - CEM_LEARNING_RATE) * new_dists)

            # The first candidate is always accepted so that an action is
            # returned even when no return compares greater than -inf.
            if best_action_sequence is None or elite_return > best_return:
                best_return = elite_return
                best_action_sequence = elite_action_sequences[0]

        # first feature of the first action of the best sequence; explicit
        # scalar indexing avoids converting a 1-element array to int
        return int(best_action_sequence[0, 0])

    @staticmethod
    def _distribution_slices():
        """Return ``slices[p][d]`` locating, in the flat distribution vector,
        the probabilities of action feature ``d`` at planning step ``p``."""
        offsets = np.cumsum(np.r_[0, N_DISCRETE_ACTIONS])
        step_size = offsets[-1]  # number of probabilities per planning step
        return [
            [slice(p * step_size + offsets[d], p * step_size + offsets[d + 1])
             for d in range(N_ACTION_FEATURES)]
            for p in range(PLANNING_HORIZON)
        ]

    def _sample_action_sequences(self, dists, slices):
        """Draw ``N_ACTION_SEQUENCES`` action sequences from ``dists``."""
        action_sequences = np.zeros(
            (N_ACTION_SEQUENCES, PLANNING_HORIZON, N_ACTION_FEATURES))
        for p in range(PLANNING_HORIZON):
            for d in range(N_ACTION_FEATURES):
                action_sequences[:, p, d] = self.np_random.choice(
                    N_DISCRETE_ACTIONS[d],
                    size=N_ACTION_SEQUENCES,
                    p=dists[slices[p][d]])
        return action_sequences

    def _rollout(self, action_sequences):
        """Roll out the sequences in ``self.env``.

        Returns the discounted return averaged over particles and the
        discounted cost maximized over particles, one value per sequence.
        """
        action_sequences_reps = np.repeat(
            action_sequences, N_PARTICLES, axis=0)
        n_rollouts = N_ACTION_SEQUENCES * N_PARTICLES
        all_returns = np.zeros(n_rollouts)
        safety_costs = np.zeros(n_rollouts)

        for horizon in range(PLANNING_HORIZON):
            actions = action_sequences_reps[:, horizon]
            _, rewards, _, info = self.env.step(actions)
            all_returns += self.gamma ** horizon * rewards
            safety_costs += self.beta ** horizon * info['cost']

        all_returns = all_returns.reshape(N_ACTION_SEQUENCES, N_PARTICLES)
        safety_costs = safety_costs.reshape(N_ACTION_SEQUENCES, N_PARTICLES)

        returns = np.mean(all_returns, axis=1)
        # we have to be safe so we take the max cost among particles
        cost = np.max(safety_costs, axis=1)
        return returns, cost

    def _select_elites(self, action_sequences, returns, cost):
        """Select the elite sequences used to refit the distribution.

        Returns the elite sequences and the return of the first of them.
        """
        feasible_idx = (cost == 0)
        feasible_returns = returns[feasible_idx]
        feasible_action_sequences = action_sequences[feasible_idx]
        feasible_num = feasible_action_sequences.shape[0]

        if feasible_num < self.minimal_elites:
            # not enough zero-cost sequences: pad with the lowest-cost ones
            # NOTE(review): the padding is taken among all the sequences, so
            # zero-cost sequences can appear twice in the elite set -- confirm
            # whether padding with infeasible sequences only was intended.
            n_missing = self.minimal_elites - feasible_num
            order = np.argsort(cost)
            sub_elites = action_sequences[order][:n_missing]
            elites = np.concatenate(
                (sub_elites, feasible_action_sequences), axis=0)
            # here ``order`` indexes the whole population
            elite_return = returns[order[0]]
        else:
            order = np.argsort(feasible_returns)[::-1]
            elites = feasible_action_sequences[order][:N_ELITES]
            # here ``order`` indexes the feasible subset only, so the return
            # must be read from ``feasible_returns`` and not from ``returns``
            elite_return = feasible_returns[order[0]]

        return elites, elite_return

    @staticmethod
    def _fit_distribution(elite_action_sequences, dists, slices):
        """Return the empirical action frequencies of the elite sequences,
        laid out like ``dists``."""
        new_dists = np.zeros_like(dists)
        for p in range(PLANNING_HORIZON):
            for d in range(N_ACTION_FEATURES):
                chosen = elite_action_sequences[:, p, d]
                candidates = np.arange(N_DISCRETE_ACTIONS[d])
                # fraction of elites that picked each discrete value
                new_dists[slices[p][d]] = np.mean(
                    chosen[:, np.newaxis] == candidates, axis=0)
        return new_dists
