from pymoo.algorithms.moo.nsga2 import NSGA2
from pymoo.algorithms.moo.rnsga2 import RNSGA2
from pymoo.algorithms.moo.age import AGEMOEA
from pymoo.algorithms.moo.ctaea import CTAEA
from pymoo.algorithms.moo.moead import MOEAD
from pymoo.algorithms.moo.unsga3 import UNSGA3
from pymoo.factory import get_crossover, get_mutation, get_sampling
from pymoo.factory import get_reference_directions
from pymoo.core.problem import Problem
from pymoo.optimize import minimize
import time
import numpy as np
import autograd.numpy as anp



class SearchAlgoManager:
    '''
    Manages the search parameters for multi-objective search (2 objectives).

    Each pymoo ``configure_*`` method builds an algorithm, stores it in
    ``self.algorithm_def``, stores the generation budget in ``self.n_gens``
    and returns the algorithm. ``__init__`` calls the configurator that
    matches ``algorithm`` with its default arguments; call a configurator
    again to override those defaults before calling ``run_search``.

    Parameters
    ----------
    algorithm : string
        Define a multi-objective search algorithm: 'nsga2', 'rnsga2', 'age',
        'ctaea', 'moead' or 'unsga3'. 'random' is accepted but not configured
        yet; 'motpes' raises NotImplementedError.
    seed : int
        Seed value for Pymoo search.
    verbose : Boolean
        Verbosity option
    engine : string
        Support different engine types (e.g. pymoo, optuna, nni, deap)

    Raises
    ------
    NotImplementedError
        If ``algorithm`` is not supported.
    '''

    def __init__(self, algorithm='nsga2', seed=0, verbose=False, engine='pymoo'):
        self.algorithm = algorithm
        self.seed = seed
        self.verbose = verbose
        self.engine = engine

        # pymoo-backed algorithms; each configurator also sets engine='pymoo'.
        pymoo_configurators = {
            'nsga2': self.configure_nsga2,
            'rnsga2': self.configure_rnsga2,
            'age': self.configure_age,
            'ctaea': self.configure_ctaea,
            'moead': self.configure_moead,
            'unsga3': self.configure_unsga3,
        }

        if self.algorithm in pymoo_configurators:
            # Configurators return the algorithm they build (and also store
            # it on self), so this assignment keeps the configured object.
            self.algorithm_def = pymoo_configurators[self.algorithm]()
            self.engine = 'pymoo'
        elif self.algorithm == 'random':
            # TODO: call self.configure_random() once it is implemented.
            pass
        elif self.algorithm == 'motpes':
            self.algorithm_def = self.configure_motpes()
        else:
            print('[Warning] algorithm "{}" not implemented.'.format(self.algorithm))
            raise NotImplementedError(
                'algorithm "{}" not implemented.'.format(self.algorithm))

    @staticmethod
    def _select_sampling(warm_pop):
        '''Return ``warm_pop`` if it is a NumPy array, else integer LHS sampling.'''
        # isinstance() detects real arrays; comparing type() to a string never matches.
        if isinstance(warm_pop, np.ndarray):
            print('[Info] Using warm start population')
            return warm_pop
        return get_sampling("int_lhs")

    @staticmethod
    def _reference_directions(ref_dirs, seed):
        '''Return ``ref_dirs`` as float64, defaulting to 20 2-D "energy" directions.'''
        if ref_dirs is None:
            ref_dirs = get_reference_directions("energy", 2, 20, seed=seed)
        return np.asarray(ref_dirs, dtype='float64')

    @staticmethod
    def _warn_if_not_divisible(num_evals, pop_size):
        '''Warn when the evaluation budget is not a whole number of generations.'''
        # run_search truncates n_gens with int(), so leftover evaluations are dropped.
        if num_evals % pop_size != 0:
            print('[Warning] Number of samples not divisible by population size')

    def configure_random(self, num_evals=1000, seed=0):
        '''Placeholder for random search; always raises NotImplementedError.'''
        self.engine = 'random'
        self.seed = seed
        self.num_evals = num_evals

        raise NotImplementedError

    def configure_motpes(self, num_evals=1000, seed=0):
        '''Placeholder for Optuna MOTPE search; always raises NotImplementedError.'''
        self.engine = 'optuna'
        self.seed = seed
        self.num_evals = num_evals

        raise NotImplementedError

    def configure_nsga2(self, population=50, num_evals=1000, warm_pop=None,
                        crossover_prob=0.9, crossover_eta=15.0,
                        mutation_prob=0.02, mutation_eta=20.0):
        '''
        Configure NSGA-II with integer LHS sampling, SBX crossover and PM mutation.

        Parameters
        ----------
        population : int
            Population size; ``n_gens = num_evals / population``.
        num_evals : int
            Total evaluation budget.
        warm_pop : numpy.ndarray, optional
            Initial population used instead of LHS sampling.
        crossover_prob, crossover_eta, mutation_prob, mutation_eta : float
            Operator parameters forwarded to ``int_sbx`` / ``int_pm``.

        Returns
        -------
        The configured pymoo algorithm (also stored in ``self.algorithm_def``).
        '''
        self.engine = 'pymoo'
        self.n_gens = num_evals/population

        print('[Info] Configuring NSGA-II algorithm.')
        sample_strategy = self._select_sampling(warm_pop)

        self.algorithm_def = NSGA2(
            pop_size=population,
            sampling=sample_strategy,
            crossover=get_crossover("int_sbx", prob=crossover_prob, eta=crossover_eta),
            mutation=get_mutation("int_pm", prob=mutation_prob, eta=mutation_eta),
            eliminate_duplicates=True)

        self._warn_if_not_divisible(num_evals, population)
        return self.algorithm_def

    def configure_unsga3(self, population=50, num_evals=1000,
                        ref_dirs=None, warm_pop=None,
                        crossover_prob=1.0, crossover_eta=30.0,
                        mutation_prob=0.02, mutation_eta=20.0):
        '''
        Configure U-NSGA-III.

        Parameters are as for ``configure_nsga2``, plus ``ref_dirs``: the
        reference directions to use (default: 20 "energy" directions for
        2 objectives, generated with seed 0).

        Returns
        -------
        The configured pymoo algorithm (also stored in ``self.algorithm_def``).
        '''
        self.engine = 'pymoo'
        self.n_gens = num_evals/population

        ref_dirs = self._reference_directions(ref_dirs, seed=0)

        print('[Info] Configuring UNSGA-III algorithm.')
        sample_strategy = self._select_sampling(warm_pop)

        self.algorithm_def = UNSGA3(
            ref_dirs=ref_dirs,
            pop_size=population,
            sampling=sample_strategy,
            crossover=get_crossover("int_sbx", prob=crossover_prob, eta=crossover_eta),
            mutation=get_mutation("int_pm", prob=mutation_prob, eta=mutation_eta),
            eliminate_duplicates=True)

        self._warn_if_not_divisible(num_evals, population)
        return self.algorithm_def

    def configure_rnsga2(self, population=50, num_evals=1000, warm_pop=None,
                         ref_points=None,
                         crossover_prob=0.9, crossover_eta=15.0,
                         mutation_prob=0.02, mutation_eta=20.0):
        '''
        Configure R-NSGA-II (reference-point based NSGA-II).

        Parameters are as for ``configure_nsga2``, plus ``ref_points``:
        reference points, one row per point and one column per objective
        (default ``[[0, 0]]``).

        Returns
        -------
        The configured pymoo algorithm (also stored in ``self.algorithm_def``).
        '''
        # None sentinel instead of a shared mutable list default.
        if ref_points is None:
            ref_points = [[0, 0]]

        self.engine = 'pymoo'
        self.n_gens = num_evals/population

        print('[Info] Configuring RNSGA-II algorithm.')
        sample_strategy = self._select_sampling(warm_pop)

        print('[Info] Reference points for RNSGA-II are: {}'.format(ref_points))
        reference_points = np.array(ref_points)  # lat=0, 1/acc=0

        self.algorithm_def = RNSGA2(
            ref_points=reference_points,
            pop_size=population,
            epsilon=0.01,
            normalization='front',
            extreme_points_as_reference_points=False,
            sampling=sample_strategy,
            crossover=get_crossover("int_sbx", prob=crossover_prob, eta=crossover_eta),
            mutation=get_mutation("int_pm", prob=mutation_prob, eta=mutation_eta),
            weights=np.array([0.5, 0.5]),
            eliminate_duplicates=True)

        self._warn_if_not_divisible(num_evals, population)
        return self.algorithm_def

    def configure_age(self, population=50, num_evals=1000, warm_pop=None,
                        crossover_prob=0.9, crossover_eta=15.0,
                        mutation_prob=0.02, mutation_eta=20.0):
        '''
        Configure AGE-MOEA; parameters are as for ``configure_nsga2``.

        Returns
        -------
        The configured pymoo algorithm (also stored in ``self.algorithm_def``).
        '''
        self.engine = 'pymoo'
        self.n_gens = num_evals/population

        print('[Info] Configuring AGE-MOEA algorithm.')
        sample_strategy = self._select_sampling(warm_pop)

        self.algorithm_def = AGEMOEA(
            pop_size=population,
            sampling=sample_strategy,
            crossover=get_crossover("int_sbx", prob=crossover_prob, eta=crossover_eta),
            mutation=get_mutation("int_pm", prob=mutation_prob, eta=mutation_eta),
            eliminate_duplicates=True)

        self._warn_if_not_divisible(num_evals, population)
        return self.algorithm_def

    def configure_ctaea(self, warm_pop=None, num_evals=1000,
                        ref_dirs=None,
                        crossover_prob=1.0, crossover_eta=30.0,
                        mutation_prob=None, mutation_eta=20.0):
        '''
        Configure C-TAEA.

        C-TAEA takes no ``pop_size`` here; the generation budget is computed
        per reference direction (default: 20 "energy" directions for
        2 objectives, generated with seed 0).

        Returns
        -------
        The configured pymoo algorithm (also stored in ``self.algorithm_def``).
        '''
        self.engine = 'pymoo'
        print('[Info] Configuring C-TAEA algorithm.')
        sample_strategy = self._select_sampling(warm_pop)

        ref_dirs = self._reference_directions(ref_dirs, seed=0)
        # pymoo sizes the C-TAEA population from the reference directions.
        self.n_gens = num_evals/len(ref_dirs)

        self.algorithm_def = CTAEA(
            ref_dirs=ref_dirs,
            sampling=sample_strategy,
            crossover=get_crossover("int_sbx", prob=crossover_prob, eta=crossover_eta),
            mutation=get_mutation("int_pm", prob=mutation_prob, eta=mutation_eta),
            eliminate_duplicates=True)

        self._warn_if_not_divisible(num_evals, len(ref_dirs))
        return self.algorithm_def

    def configure_moead(self, n_neighbors=20, num_evals=1000,
                        warm_pop=None, ref_dirs=None,
                        crossover_prob=1.0, crossover_eta=20.0,
                        mutation_prob=None, mutation_eta=20.0):
        '''
        Configure MOEA/D.

        ``n_neighbors`` is the neighbourhood size. The generation budget is
        computed per reference direction (default: 20 "energy" directions
        for 2 objectives, generated with ``self.seed``).

        Returns
        -------
        The configured pymoo algorithm (also stored in ``self.algorithm_def``).
        '''
        self.engine = 'pymoo'

        print('[Info] Configuring MOEA/D algorithm.')
        sample_strategy = self._select_sampling(warm_pop)

        ref_dirs = self._reference_directions(ref_dirs, seed=self.seed)
        # The MOEA/D population is one individual per reference direction,
        # not n_neighbors, so the budget is divided by len(ref_dirs).
        self.n_gens = num_evals/len(ref_dirs)

        self.algorithm_def = MOEAD(
            ref_dirs=ref_dirs,
            n_neighbors=n_neighbors,
            sampling=sample_strategy,
            crossover=get_crossover("int_sbx", prob=crossover_prob, eta=crossover_eta),
            mutation=get_mutation("int_pm", prob=mutation_prob, eta=mutation_eta),
            )

        self._warn_if_not_divisible(num_evals, len(ref_dirs))
        return self.algorithm_def

    def run_search(self, problem, save_history=False):
        '''
        Run the configured search on ``problem`` and return the pymoo result.

        Parameters
        ----------
        problem : pymoo Problem
            Problem to optimise.
        save_history : bool
            Forwarded to ``pymoo.optimize.minimize``.

        Raises
        ------
        NotImplementedError
            If the configured engine is not 'pymoo'.
        '''
        print('[Info] Running Search ', end='', flush=True)
        start_time = time.time()

        if self.engine == 'pymoo':
            result = minimize(problem,
                            self.algorithm_def,
                            ('n_gen', int(self.n_gens)),
                            seed=self.seed,
                            save_history=save_history,
                            verbose=self.verbose)
        else:
            print('[Error] Invalid algorithm configuration!')
            raise NotImplementedError

        print("Success")
        print('[Info] Search Took {:.3f} seconds.'.format(time.time()-start_time))

        return result



class ProblemSingleObjective(Problem):
    '''
    Single-objective integer pymoo problem backed by an evaluation interface.

    Parameters
    ----------
    evaluation_interface : object
        Must provide ``eval_subnet(individual)`` returning a 2-tuple whose
        second element is the objective value (the first is ignored).
    param_count : int
        Number of decision variables.
    param_upperbound : int or array-like
        Upper bound(s) of the variables; the lower bound is 0.
    '''

    def __init__(self, evaluation_interface, param_count, param_upperbound):
        # Builtin int: the np.int alias was removed in NumPy 1.24.
        super().__init__(n_var=param_count, n_obj=1, n_constr=0,
                         xl=0, xu=param_upperbound, type_var=int)

        self.evaluation_interface = evaluation_interface

    def _evaluate(self, x, out, *args, **kwargs):
        '''Evaluate every individual in ``x`` and store objectives in ``out["F"]``.'''

        # Store results for a given generation for PyMoo
        objective_arr = list()

        # Measure new individuals
        for individual in x:
            _, objective = self.evaluation_interface.eval_subnet(individual)
            objective_arr.append(objective)

        # Progress marker: one dot per _evaluate call.
        print('.', end='', flush=True)

        # Update PyMoo with evaluation data
        out["F"] = anp.column_stack([objective_arr])


class ProblemMultiObjective(Problem):
    '''
    Two-objective integer pymoo problem backed by an evaluation interface.

    Parameters
    ----------
    evaluation_interface : object
        Must provide ``eval_subnet(individual)`` returning a 3-tuple whose
        last two elements are the two objective values (the first is ignored).
    param_count : int
        Number of decision variables.
    param_upperbound : int or array-like
        Upper bound(s) of the variables; the lower bound is 0.
    '''

    def __init__(self, evaluation_interface, param_count, param_upperbound):
        # Builtin int: the np.int alias was removed in NumPy 1.24.
        super().__init__(n_var=param_count, n_obj=2, n_constr=0,
                         xl=0, xu=param_upperbound, type_var=int)

        self.evaluation_interface = evaluation_interface

    def _evaluate(self, x, out, *args, **kwargs):
        '''Evaluate every individual in ``x`` and store objectives in ``out["F"]``.'''

        # Store results for a given generation for PyMoo
        objective_x_arr, objective_y_arr = list(), list()

        # Measure new individuals
        for individual in x:
            _, objective_x, objective_y = self.evaluation_interface.eval_subnet(individual)
            objective_x_arr.append(objective_x)
            objective_y_arr.append(objective_y)

        # Progress marker: one dot per _evaluate call.
        print('.', end='', flush=True)

        # Update PyMoo with evaluation data: one column per objective.
        out["F"] = anp.column_stack([objective_x_arr, objective_y_arr])

class ProblemManyObjective(Problem):
    '''
    Three-objective integer pymoo problem backed by an evaluation interface.

    Parameters
    ----------
    evaluation_interface : object
        Must provide ``eval_subnet(individual)`` returning a 4-tuple whose
        last three elements are the objective values (the first is ignored).
    param_count : int
        Number of decision variables.
    param_upperbound : int or array-like
        Upper bound(s) of the variables; the lower bound is 0.
    '''

    def __init__(self, evaluation_interface, param_count, param_upperbound):
        # Builtin int: the np.int alias was removed in NumPy 1.24.
        super().__init__(n_var=param_count, n_obj=3, n_constr=0,
                         xl=0, xu=param_upperbound, type_var=int)

        self.evaluation_interface = evaluation_interface

    def _evaluate(self, x, out, *args, **kwargs):
        '''Evaluate every individual in ``x`` and store objectives in ``out["F"]``.'''

        # Store results for a given generation for PyMoo
        objective_x_arr, objective_y_arr, objective_z_arr = list(), list(), list()

        # Measure new individuals
        for individual in x:
            _, objective_x, objective_y, objective_z = \
                self.evaluation_interface.eval_subnet(individual)
            objective_x_arr.append(objective_x)
            objective_y_arr.append(objective_y)
            objective_z_arr.append(objective_z)

        # Progress marker: one dot per _evaluate call.
        print('.', end='', flush=True)

        # Update PyMoo with evaluation data: one column per objective.
        out["F"] = anp.column_stack([objective_x_arr, objective_y_arr, objective_z_arr])
