import numpy as np


class GVF:
    def __init__(self, aux_type, num_aux_tasks, config):
        self.aux_type = aux_type
        self.num_aux_tasks = num_aux_tasks
        self.obs_size = config['obs_size']
        self.feature_size = config['feature_size']
        self.batch_size = config['buffer_batch_size']
        self.gamma_list = np.array([0, 0.5, 0.75, 0.8, 0.9, 0.99])
        self.terminal_threshold = np.array([0.1, 0.4, 0.7, 1, 1.3, 1.6])
        self.gvf_max_or_min = np.array([1, -1])
        self.cumulant_net = config['cumulant_net']
        self.gammas = np.random.choice(self.gamma_list, self.num_aux_tasks)
        self.c = 100

        # gen and test params
        self.ordered = True
        self.random_replacement = False
        # 'random_goals', 'random_net', 'visited_goals', 'feature_attainment'
        self.generator = 'random_goals'
        self.env_name = config['env_name']
        self.pinball_random_goal_radius = config['pinball_random_goal_radius']
        self.pinball_random_goal_increment = config['pinball_random_goal_increment']

        self.step = 0
        self.test_start = config['test_freq']
        self.test_freq = config['test_freq']
        self.scores = np.zeros(self.num_aux_tasks)
        self.trace_param = -1
        self.replace_rate = config['replace_rate']
        self.replace_number = int(np.ceil(self.replace_rate * self.num_aux_tasks))
        self.age = np.zeros(self.num_aux_tasks)
        self.age_threshod = config['age_threshold']

        # === four rooms hand designed auxiliary tasks ===
        self.hallways = np.zeros((3, 49))
        self.hallways[0, 10] = 1
        self.hallways[1, 19] = 1
        self.hallways[2, 22] = 1

        self.corners = np.zeros((4, 49))
        self.corners[0, 42] = 1
        self.corners[1, 44] = 1
        self.corners[2, 28] = 1
        self.corners[3, 30] = 1

        # === maze hand designed auxiliary tasks ===
        # self.maze_hallways = np.zeros((4, 81))
        # self.maze_hallways[0, 43] = 1
        # self.maze_hallways[1, 39] = 1
        # self.maze_hallways[2, 8] = 1
        # self.maze_hallways[3, 80] = 1
        # self.maze_hallways[0, 38] = 1
        # self.maze_hallways[1, 23] = 1
        # self.maze_hallways[2, 4] = 1
        # self.maze_hallways[3, 58] = 1
        self.maze_hallways = np.zeros((2, 81))
        self.maze_hallways[0, 43] = 1
        self.maze_hallways[1, 39] = 1

        self.maze_corners = np.zeros((3, 81))
        self.maze_corners[0, 8] = 1
        self.maze_corners[1, 72] = 1
        self.maze_corners[2, 80] = 1

        # === pinball hand-designed auxiliary tasks ===
        self.pinball_bottleneck = np.array([[0.65,  0.35],
                                              [0.36, 0.35],
                                              [0.3, 0.65],
                                              [0.55, 0.65],
                                            ])

        self.pinball_corner = np.array([
                                                [0.85, 0.05],
                                                [0.9, 0.05],
                                                [0.95, 0.05],
                                                [0.9, 0.1],
                                                [0.95, 0.1]
                                              ])

        self.puddle_paths = np.array([
            [0.8, 0.6],
            [0.9, 0.6],
            [0.9, 0.8],
            [0.9, 0.9]
        ])


        self.visited_goals = np.zeros((self.num_aux_tasks, self.obs_size))
        self.random_goals = np.zeros((self.num_aux_tasks, self.obs_size))
        self.target_features = np.zeros((self.num_aux_tasks, self.feature_size))

        # === logging params ===
        self.retain_threshold = 50000
        if self.env_name == 'four_rooms':
            self.retain_threshold = 10000
        if self.env_name == 'maze':
            self.retain_threshold = 20000
        self.retained_auxiliary_tasks = np.arange(self.num_aux_tasks)
        self.auxiliary_tasks_ordered = np.arange(self.num_aux_tasks)

        if self.aux_type == 'random_goal' or self.generator == 'random_goals' or self.generator == 'visited_goals':
            if self.env_name in ['four_rooms', 'maze']:
                self.cells = np.arange(self.obs_size)
                if self.env_name == 'four_rooms':
                    walls = [3, 17, 18, 20, 21, 23, 24, 31, 38, 45]
                if self.env_name == 'maze':
                    walls = [10, 11, 12, 13, 14, 15, 16, 19, 25, 28, 30, 31, 32, 34, 37, 41,
                             46, 48, 49, 50, 52, 55, 61, 64, 65, 66, 67, 68, 69, 70]
                self.cells = np.setdiff1d(self.cells, walls)
            if self.env_name == 'pinball':
                X, Y = np.mgrid[0.1:1:self.pinball_random_goal_increment, 0.1:1:self.pinball_random_goal_increment]
                self.positions = np.vstack([X.ravel(), Y.ravel()])
                self.positions = self.positions.transpose()
                self.positions = self.remove_pinball_obstacle()
            if self.aux_type == 'random_goal' or self.generator == 'random_goals':
                self.random_goals = self.make_random_goals()
                # === four_rooms tester evaluation start ===
                # self.random_goals = np.zeros((4, 49))
                # self.random_goals[0, 10] = 1
                # self.random_goals[1, 19] = 1
                # self.random_goals[2, 42] = 1
                # self.random_goals[3, 44] = 1
                # === four_rooms tester evaluation end ===
                # self.random_goals = np.array([
                #                                 [0.85, 0.05],
                #                                 [0.9, 0.05],
                #                                 [0.95, 0.05],
                #                                 # [0.85, 0.1],
                #                                 [0.9, 0.1],
                #                                 [0.95, 0.1],
                #                                 # [0.85, 0.15],
                #                                 # [0.9, 0.15],
                #                                 # [0.95, 0.15],
                #                               ])
            if self.generator == 'visited_goals':
                self.visited_goals = self.make_random_goals()

        if self.aux_type == 'discovered':
            self.goal_pool = np.load('{}/retained_auxiliary_tasks.npy'.format(self.env_name))
            indices = np.random.choice(self.goal_pool.shape[0], self.num_aux_tasks, replace=False)
            self.discovered_goals = self.goal_pool[indices]
            print(self.discovered_goals)
            # self.print_goals()

    def cumulant(self, aux_id, batch_obs_tp1, **kwargs):
        output = None
        if self.aux_type == 'goal_reaching':
            output = np.ones(batch_obs_tp1.shape[0]) * -1
        elif self.aux_type == 'obs_diff_prediction' or self.aux_type == 'obs_diff_control':
            output = batch_obs_tp1[:, aux_id % self.obs_size] - kwargs['batch_obs_t'][:, aux_id % self.obs_size]
        elif self.aux_type == 'obs_max_min':
            output = batch_obs_tp1[:, aux_id % self.obs_size] * self.gvf_max_or_min[
                np.floor(aux_id / self.obs_size).astype('int') % self.gvf_max_or_min.shape[0]]
        elif self.aux_type == 'pixel_diff_control' or self.aux_type == 'pixel_diff_prediction':
            x = 4 * (aux_id // 6)
            y = 4 * (aux_id % 6)
            offset = 8
            output = np.sum(np.abs(batch_obs_tp1[:, x + offset:x + offset + 4, y + offset:y + offset + 4, :] -
                                   kwargs['batch_obs_t'][:, x + offset:x + offset + 4, y + offset:y + offset + 4, :]),
                            axis=(1, 2, 3)).astype(float)
        elif self.aux_type == 'pixel_control' or self.aux_type == 'pixel_prediction':
            x = 4 * (aux_id // 6)
            y = 4 * (aux_id % 6)
            offset = 8
            output = np.sum(np.abs(batch_obs_tp1[:, x + offset:x + offset + 4, y + offset:y + offset + 4, :]),
                            axis=(1, 2, 3)).astype(float)
        elif self.aux_type == 'pixel_random_noise':
            random_pixels = np.random.randint(256, size=(batch_obs_tp1.shape[0], 4, 4, 3))
            output = np.sum(np.abs(random_pixels),
                            axis=(1, 2, 3)).astype(float)
        elif self.aux_type == 'obs_prediction' or self.aux_type == 'obs_control':
            output = batch_obs_tp1[:, aux_id % self.obs_size]
        elif self.aux_type == 'random_noise':
            output = np.ones(batch_obs_tp1.shape[0]) * np.random.normal(0, aux_id + 1)
        elif self.aux_type == 'random_obs' or self.aux_type == 'random_pixel':
            f_tp1 = self.cumulant_net(batch_obs_tp1).cpu().detach().numpy()
            f_t = self.cumulant_net(kwargs['batch_obs_t']).cpu().detach().numpy()
            output = np.tanh(self.c * (f_tp1[:, aux_id] - f_t[:, aux_id]))
        elif self.aux_type == "random_goal":
            if self.env_name == 'puddle_world':
                output = kwargs['reward']
            else:
                output = np.ones(batch_obs_tp1.shape[0]) * -1
        elif self.aux_type == 'MSGT' or self.aux_type == 'GT':
            if self.generator == 'random_goals':
                if self.env_name == 'puddle_world':
                    output = kwargs['reward']
                else:
                    output = np.ones(batch_obs_tp1.shape[0]) * -1
            elif self.generator == 'visited_goals' or self.generator == 'feature_attainment':
                output = np.ones(batch_obs_tp1.shape[0]) * -1
            elif self.generator == 'random_net':
                f_tp1 = self.cumulant_net(batch_obs_tp1).cpu().detach().numpy()
                f_t = self.cumulant_net(kwargs['batch_obs_t']).cpu().detach().numpy()
                output = np.tanh(self.c * (f_tp1[:, aux_id] - f_t[:, aux_id]))
        elif self.aux_type == 'hallway' or self.aux_type == 'corner' or \
                self.aux_type == 'maze_hallway' or self.aux_type == 'maze_corner':
            output = np.ones(batch_obs_tp1.shape[0]) * -1
        elif self.aux_type == 'constant':
            output = np.ones(batch_obs_tp1.shape[0]) * aux_id
        elif self.aux_type == 'pinball_bottleneck':
            output = np.ones(batch_obs_tp1.shape[0]) * -1
        elif self.aux_type == 'pinball_corner':
            output = np.ones(batch_obs_tp1.shape[0]) * -1
        elif self.aux_type == 'puddle_path':
            output = kwargs['reward']
        elif self.aux_type == 'discovered':
            output = np.ones(batch_obs_tp1.shape[0]) * -1
        return output

    def continuation_function(self, aux_id, batch_obs_tp1, **kwargs):
        output = None
        if self.aux_type == 'goal_reaching':
            continuation = (-np.cos(batch_obs_tp1[:, 0]) - np.cos(batch_obs_tp1[:, 1] + batch_obs_tp1[:, 0])) <= \
                           self.terminal_threshold[aux_id]
            output = continuation.astype(float)
        elif self.aux_type == 'obs_max_min':
            output = self.gamma_list[np.floor(aux_id / (self.obs_size * self.gvf_max_or_min.shape[0])).astype('int')]
        elif self.aux_type == 'pixel_diff_control' or self.aux_type == 'pixel_diff_prediction':
            output = np.ones(batch_obs_tp1.shape[0]) * 0.9
        elif self.aux_type == 'pixel_control' or self.aux_type == 'pixel_prediction' or \
                self.aux_type == 'pixel_random_noise':
            output = np.ones(batch_obs_tp1.shape[0]) * 0.9
        elif self.aux_type == 'random_obs' or self.aux_type == 'random_pixel':
            output = np.ones(batch_obs_tp1.shape[0]) * 0.9  # self.gammas[aux_id]
        elif self.aux_type == "random_goal":
            is_random_goal = self.is_goal(batch_obs_tp1, aux_id)
            not_random_goal = 1 - is_random_goal
            output = np.ones(batch_obs_tp1.shape[0])
            output = output * not_random_goal
        elif self.aux_type == 'MSGT' or self.aux_type == 'GT':
            if self.generator == 'random_goals':
                is_goal = self.is_goal(batch_obs_tp1, aux_id)
                not_goal = 1 - is_goal
                output = np.ones(batch_obs_tp1.shape[0])
                output = output * not_goal
            if self.generator == 'visited_goals':
                is_goal = self.is_goal(batch_obs_tp1, aux_id)
                not_goal = 1 - is_goal
                output = np.ones(batch_obs_tp1.shape[0])
                output = output * not_goal
            if self.generator == 'feature_attainment':
                is_target_feature = self.is_target_feature(kwargs['batch_feature_tp1'], aux_id)
                not_target_feature = 1 - is_target_feature
                output = np.ones(batch_obs_tp1.shape[0])
                output = output * not_target_feature
            elif self.generator == 'random_net':
                output = np.ones(batch_obs_tp1.shape[0]) * 0.9  # self.gammas[aux_id]
        elif self.aux_type == 'discovered':
            is_goal = self.is_goal(batch_obs_tp1, aux_id)
            not_goal = 1 - is_goal
            output = np.ones(batch_obs_tp1.shape[0])
            output = output * not_goal
        elif self.aux_type == 'obs_prediction' or self.aux_type == 'obs_control':
            output = np.ones(batch_obs_tp1.shape[0]) * self.gamma_list[np.floor(aux_id / self.obs_size).astype('int')]
        elif self.aux_type == 'random_noise':
            output = np.zeros(batch_obs_tp1.shape[0])
        elif self.aux_type == 'obs_diff_prediction' or self.aux_type == 'obs_diff_control':
            output = np.ones(batch_obs_tp1.shape[0]) * self.gamma_list[np.floor(aux_id / self.obs_size).astype('int')]
        elif self.aux_type == 'hallway':
            is_hallway = np.dot(batch_obs_tp1, self.hallways[aux_id, :])
            not_hallway = 1 - is_hallway
            output = np.ones(batch_obs_tp1.shape[0]) * 1
            output = output * not_hallway
        elif self.aux_type == 'corner':
            is_corner = np.dot(batch_obs_tp1, self.corners[aux_id, :])
            not_corner = 1 - is_corner
            output = np.ones(batch_obs_tp1.shape[0]) * 1
            output = output * not_corner
        elif self.aux_type == 'constant':
            output = np.zeros(batch_obs_tp1.shape[0])
        elif self.aux_type == 'maze_hallway':
            is_maze_hallway = np.dot(batch_obs_tp1, self.maze_hallways[aux_id, :])
            not_maze_hallway = 1 - is_maze_hallway
            output = np.ones(batch_obs_tp1.shape[0]) * 1
            output = output * not_maze_hallway
        elif self.aux_type == 'maze_corner':
            is_maze_corner = np.dot(batch_obs_tp1, self.maze_corners[aux_id, :])
            not_maze_corner = 1 - is_maze_corner
            output = np.ones(batch_obs_tp1.shape[0]) * 1
            output = output * not_maze_corner
        elif self.aux_type == 'pinball_bottleneck':
            is_pinball_buttleneck = np.linalg.norm(batch_obs_tp1[:, :2] - self.pinball_bottleneck[aux_id, :],
                                     axis=1) < self.pinball_random_goal_radius
            not_pinball_buttleneck = 1 - is_pinball_buttleneck
            output = np.ones(batch_obs_tp1.shape[0])
            output = output * not_pinball_buttleneck
        elif self.aux_type == 'pinball_corner':
            is_pinball_corner = np.linalg.norm(batch_obs_tp1[:, :2] - self.pinball_corner[aux_id, :],
                                     axis=1) < self.pinball_random_goal_radius
            not_pinball_corner = 1 - is_pinball_corner
            output = np.ones(batch_obs_tp1.shape[0]) * 1
            output = output * not_pinball_corner
        elif self.aux_type == 'puddle_path':
            is_puddle_path = np.linalg.norm(batch_obs_tp1[:, :2] - self.puddle_paths[aux_id, :],
                                               axis=1) < 0.1
            not_puddle_path = 1 - is_puddle_path
            output = np.ones(batch_obs_tp1.shape[0]) * 1
            output = output * not_puddle_path
        return output

    def gen_and_test(self, eval_metric, direct=True):

        self.step += 1
        if direct:
            self.scores = self.trace_param * eval_metric + (1 - self.trace_param) * self.scores
        else:
            self.scores = eval_metric

        self.age = self.age + 1


        if self.step < self.test_start or self.step % self.test_freq != 0:
            return

        mature_ind = np.where(self.age > self.age_threshod)[0]
        if mature_ind.shape[0] == 0:
            return

        if self.random_replacement is False:
            if self.ordered:
                # self.print_goals()
                # print(self.random_goals)
                # indices = np.argpartition(self.scores[mature_ind],
                #                           self.replace_number)[
                #           :self.replace_number]

                self.auxiliary_tasks_ordered = np.argpartition(self.scores[mature_ind],
                                          self.replace_number)

                indices = self.auxiliary_tasks_ordered[
                          :self.replace_number]

                i_replace = mature_ind[indices]
                if self.generator == 'random_goals':
                    self.change_random_goals(i_replace)
                elif self.generator == 'random_net':
                    self.cumulant_net.reset_output_weights(i_replace)
                    self.gammas[i_replace] = np.random.choice(self.gamma_list, i_replace.shape[0])

                print(self.scores)
                print(i_replace)
                print(self.step)
                # print('')
                if self.step > self.retain_threshold:
                    self.retained_auxiliary_tasks = np.setdiff1d(self.retained_auxiliary_tasks, i_replace)
                    print('retained:')
                    print(self.retained_auxiliary_tasks)
                self.age[i_replace] = 0
            else:
                i_replace_candid = np.argpartition(self.scores, self.scores.shape[0] // 2)[:self.scores.shape[0] // 2]
                i_replace = np.random.choice(i_replace_candid, self.replace_number,
                                             replace=False)
                i_preserved = np.setdiff1d(np.arange(self.num_aux_tasks), i_replace)
                self.cumulant_net.reset_output_weights(i_replace)
                self.gammas[i_replace] = np.random.choice(self.gamma_list, i_replace.shape[0])
                self.scores[i_replace] = np.median(self.scores[i_preserved])
        else:
            i_replace = np.random.choice(self.num_aux_tasks, self.replace_number, replace=False)
            if self.generator == 'random_goals':
                self.change_random_goals(i_replace)
            elif self.generator == 'random_net':
                self.cumulant_net.reset_output_weights(i_replace)
                self.gammas[i_replace] = np.random.choice(self.gamma_list, i_replace.shape[0])
            self.age[i_replace] = 0

        return i_replace

    def make_random_goals(self):
        if self.env_name in ['four_rooms', 'maze']:
            random_goals = np.zeros((self.num_aux_tasks, self.obs_size))
            goal_indices = np.random.choice(self.cells, self.num_aux_tasks, replace=False)
            for i in np.arange(self.num_aux_tasks):
                random_goals[i, goal_indices[i]] = 1
            return random_goals
        if self.env_name == 'pinball':
            goal_indices = np.random.choice(self.positions.shape[0], self.num_aux_tasks, replace=False)
            return self.positions[goal_indices, :]
        if self.env_name == 'puddle_world':
            return np.random.random((self.num_aux_tasks, 2))


    def change_random_goals(self, indices):
        if self.env_name in ['four_rooms', 'maze']:
            self.random_goals[indices] = np.zeros((indices.shape[0], self.obs_size))
            goal_indices = np.random.choice(self.cells, indices.shape[0], replace=False)
            for i in np.arange(indices.shape[0]):
                self.random_goals[indices[i], goal_indices[i]] = 1
        if self.env_name == 'pinball':
            self.random_goals[indices] = np.zeros((indices.shape[0], 2))
            goal_indices = np.random.choice(self.positions.shape[0], indices.shape[0], replace=False)
            self.random_goals[indices] = self.positions[goal_indices, :]
        if self.env_name == 'puddle_world':
            self.random_goals[indices] = np.random.random((indices.shape[0], 2))

    def print_goals(self):
        goals = []
        if self.aux_type == 'MSGT':
            if self.generator == 'random_goals':
                for i in np.arange(self.num_aux_tasks):
                    goals.append(np.where(self.random_goals[i, :] == 1)[0][0])
            elif self.generator == 'visited_goals':
                for i in np.arange(self.num_aux_tasks):
                    goals.append(np.where(self.visited_goals[i, :] == 1)[0][0])
        elif self.aux_type == 'discovered':
            for i in np.arange(self.num_aux_tasks):
                goals.append(np.where(self.discovered_goals[i, :] == 1)[0][0])
        print(goals)

    def reset_visited_goals(self, indices, visited_goals, num_goals):
        self.visited_goals[indices[:num_goals]] = visited_goals

    def reset_target_features(self, indices, target_features, num_target_fearues):
        self.target_features[indices[:num_target_fearues]] = target_features

    def is_goal(self, batch_obs, aux_id):
        if self.env_name in ['four_rooms', 'maze']:
            tiled_goals = None
            if self.aux_type == 'MSGT' or 'GT':
                if self.generator == 'random_goals':
                    tiled_goals = np.tile(self.random_goals[aux_id, :], (self.batch_size, 1))
                elif self.generator == 'visited_goals':
                    tiled_goals = np.tile(self.visited_goals[aux_id, :], (self.batch_size, 1))
            elif self.aux_type == 'discovered':
                tiled_goals = np.tile(self.discovered_goals[aux_id, :], (self.batch_size, 1))
            remainder = batch_obs.reshape(self.batch_size, self.obs_size) - tiled_goals
            is_goal = np.all(np.isclose(remainder, 0), axis=1)
            is_goal = is_goal.astype(int)
            return is_goal
        if self.env_name == 'pinball':
            if self.aux_type == 'MSGT':
                if self.generator == 'random_goals':
                    is_goal = np.linalg.norm(batch_obs[:, :2] - self.random_goals[aux_id, :],
                                             axis=1) < self.pinball_random_goal_radius
                    return is_goal
                if self.generator == 'visited_goals':
                    is_goal = np.linalg.norm(batch_obs[:, :2] - self.visited_goals[aux_id, :],
                                             axis=1) < self.pinball_random_goal_radius
                    return is_goal
            elif self.aux_type == 'random_goal':
                is_goal = np.linalg.norm(batch_obs[:, :2] - self.random_goals[aux_id, :],
                                         axis=1) < self.pinball_random_goal_radius
                return is_goal
            elif self.aux_type == 'discovered':
                is_goal = np.linalg.norm(batch_obs[:, :2] - self.discovered_goals[aux_id, :],
                                         axis=1) < self.pinball_random_goal_radius
                return is_goal
        if self.env_name == 'puddle_world':
            if self.aux_type == 'MSGT':
                is_goal = np.linalg.norm(batch_obs[:, :2] - self.random_goals[aux_id, :],
                                         axis=1) < 0.1
                return is_goal
            elif self.aux_type == 'random_goal':
                is_goal = np.linalg.norm(batch_obs[:, :2] - self.random_goals[aux_id, :],
                                         axis=1) < 0.1
                return is_goal


    def is_target_feature(self, batch_feature, aux_id):
        tiled_target_features = np.tile(self.target_features[aux_id, :], (self.batch_size, 1))
        remainder = batch_feature.reshape(self.batch_size, self.feature_size) - tiled_target_features
        is_target_feature = np.all(np.isclose(remainder, 0), axis=1)
        is_target_feature = is_target_feature.astype(int)
        return is_target_feature

    def remove_pinball_obstacle(self):
        # min_x_obstacles = np.array([0.3, 0.0, 0.3, 0.63, 0.15, 0.7])
        # max_x_obstacles = np.array([0.5, 0.25, 0.8, 0.975, 0.6, 0.8])
        # min_y_obstacles = np.array([0.35, 0.27, 0.75, 0.3, 0.0, 0.025])
        # max_y_obstacles = np.array([0.7, 0.6, 0.9, 0.7, 0.3, 0.27])

        min_x_obstacles = np.array([0.35, 0.3, 0.0, 0.0, 0.3, 0.63, 0.3, 0.15, 0.7])
        max_x_obstacles = np.array([0.5, 0.45, 0.2, 0.25, 0.8, 0.975, 0.6, 0.3, 0.8])
        min_y_obstacles = np.array([0.35, 0.65, 0.27, 0.5, 0.75, 0.3, 0.0, 0.0, 0.025])
        max_y_obstacles = np.array([0.7, 0.7, 0.55, 0.6, 0.9, 0.7, 0.3, 0.2, 0.27])

        non_obstacle_positions = []
        for (x, y) in self.positions:
            is_obstacle = False
            for i in np.arange(min_y_obstacles.shape[0]):
                if self.is_obstacle(x, y, min_x_obstacles[i],
                                    max_x_obstacles[i], min_y_obstacles[i], max_y_obstacles[i]):
                    is_obstacle = True
            if is_obstacle == False:
                non_obstacle_positions.append((x, y))
        return np.asarray(non_obstacle_positions)


    def is_obstacle(self, x, y, x_min, x_max, y_min, y_max):
        if x < x_max and x > x_min and y < y_max and y > y_min:
            return True
        else:
            return False

    def set_discovered_goals(self, discovered_goals):
        """Store ``discovered_goals`` as ``self.discovered_goals`` and print the goals.

        Args:
            discovered_goals: Goal array to store; it is indexed per task
                (``[i, :]``) by ``print_goals`` when ``aux_type`` is
                'discovered'.
        """
        self.discovered_goals = discovered_goals
        # Printed immediately via print_goals().
        self.print_goals()

    def remove_best_aux(self):
        best_aux = np.argpartition(-self.scores, 2)[:2]
        print('best aux')
        print(self.random_goals)
        print(best_aux)
        self.change_random_goals(best_aux)
        print(self.random_goals)
        self.age[best_aux] = 0
