import operator
from deap import gp, tools, algorithms, base
from functools import partial
import numpy as np
import contextlib
import pickle
import random
import numpy as np
import time
import json


class Redux:
    """Shared hyper-parameters read by the toolbox builder and the evolution loop."""

    # Mutation probability handed to ``algorithms.varAnd``.
    MUT_PB: float = 0.1
    # Crossover probability handed to ``algorithms.varAnd``.
    CX_PB: float = 0.9
    # Number of best individuals copied into the next generation.
    N_ELITES: int = 1
    # Capacity of the hall of fame.
    HOF_NUM: int = 3
    # Height limit enforced on trees produced by mate/mutate.
    MAX_TREE_HEIGHT: int = 5


@contextlib.contextmanager
def seeded_random_state(seed):
    """Temporarily seed ``random`` and ``numpy.random`` with *seed*.

    The global state of both generators is captured on entry and restored on
    exit, whether the managed block finishes normally or raises.

    Args:
        seed: Value passed to both ``random.seed`` and ``np.random.seed``.

    Raises:
        Whatever ``random.seed`` / ``np.random.seed`` raise for an
        unsupported seed; the previous states are restored first.
    """
    saved_py_state = random.getstate()
    saved_np_state = np.random.get_state()

    try:
        # Seeding lives inside the ``try`` so that a seed accepted by
        # ``random`` but rejected by numpy does not leave ``random``
        # permanently re-seeded.
        random.seed(seed)
        np.random.seed(seed)
        yield
    finally:
        random.setstate(saved_py_state)
        np.random.set_state(saved_np_state)


# Declared with class statements (rather than ``type('', ...)``) so each class
# has a ``__name__`` matching its module-level binding: that keeps the repr
# readable and lets ``pickle`` locate the class by name.
class fitnessMin(base.Fitness):
    """Single-objective fitness with weight -1.0 (smaller raw values are better)."""

    weights = (-1.0,)


class fitnessMax(base.Fitness):
    """Single-objective fitness with weight +1.0 (larger raw values are better)."""

    weights = (1.0,)


class IndividualMin(gp.PrimitiveTree):
    """GP expression tree that carries a ``fitnessMin`` instance as ``fitness``."""

    def __init__(self, gene_gen):
        # Build the underlying tree from the supplied nodes, then attach a
        # fresh fitness object of the minimising type.
        super(IndividualMin, self).__init__(gene_gen)
        self.fitness = fitnessMin()


class IndividualMax(gp.PrimitiveTree):
    """GP expression tree that carries a ``fitnessMax`` instance as ``fitness``."""

    def __init__(self, gene_gen):
        # Build the underlying tree from the supplied nodes, then attach a
        # fresh fitness object of the maximising type.
        super(IndividualMax, self).__init__(gene_gen)
        self.fitness = fitnessMax()


def div(x1, x2):
    """Protected division: return ``x1 / x2``, or ``1`` when ``|x2| < 1e-6``."""
    # Near-zero denominators short-circuit to the constant 1 so evolved
    # expressions never divide by (almost) zero.
    return 1 if abs(x2) < 1e-6 else x1 / x2


def std_stats():
    """Create a ``tools.Statistics`` over the first fitness value.

    Registers the ``avg``, ``std``, ``min`` and ``max`` fields (in that
    order), computed with the corresponding numpy reductions.
    """
    def first_fitness_value(ind):
        return ind.fitness.values[0]

    stats = tools.Statistics(first_fitness_value)
    reducers = (
        ("avg", np.mean),
        ("std", np.std),
        ("min", np.min),
        ("max", np.max),
    )
    for field_name, reducer in reducers:
        stats.register(field_name, reducer)
    return stats


def _uniform_constant():
    """Draw one ephemeral constant from ``np.random.uniform(-1, 1)``."""
    return np.random.uniform(-1, 1)


def std_pset(input_dim):
    """Build the default primitive set for ``input_dim`` inputs.

    Inputs are named ``x0 .. x{input_dim-1}``. The set holds ``add``, ``mul``,
    protected ``div``, ``neg``, ``max`` and ``min``, plus an ephemeral constant
    ``c`` drawn uniformly between -1 and 1.

    Args:
        input_dim: Number of input arguments of the evolved expression.

    Returns:
        The configured ``gp.PrimitiveSet``.
    """
    pset = gp.PrimitiveSet("MAIN", arity=input_dim, prefix='x')
    pset.addPrimitive(operator.add, 2)
    # pset.addPrimitive(operator.sub, 2)
    pset.addPrimitive(operator.mul, 2)
    pset.addPrimitive(div, 2)
    pset.addPrimitive(operator.neg, 1)
    pset.addPrimitive(max, 2)
    pset.addPrimitive(min, 2)

    # One shared module-level generator is registered instead of a fresh
    # ``partial`` per call: some DEAP releases compare the generator by
    # identity when the ephemeral name is reused and raise if it differs,
    # which broke a second call to this function in the same process.
    pset.addEphemeralConstant('c', _uniform_constant)
    return pset


def std_toolbox(pset, minimize=True):
    """Assemble a DEAP toolbox wired for tree-based GP over *pset*.

    Args:
        pset: Primitive set used for generation, compilation and mutation.
        minimize: Pick ``IndividualMin`` when true, ``IndividualMax`` otherwise.

    Returns:
        A ``base.Toolbox`` with ``expr``, ``individual``, ``population``,
        ``compile``, ``select``, ``mate`` and ``mutate`` registered. No
        ``evaluate`` is registered here.
    """
    if minimize:
        individual_cls = IndividualMin
    else:
        individual_cls = IndividualMax

    toolbox = base.Toolbox()

    # Construction: half-and-half expressions of depth 1..3.
    toolbox.register("expr", gp.genHalfAndHalf, pset=pset, min_=1, max_=3)
    toolbox.register("individual", tools.initIterate, individual_cls, toolbox.expr)
    toolbox.register("population", tools.initRepeat, list, toolbox.individual)
    toolbox.register("compile", gp.compile, pset=pset)

    # Selection and variation operators.
    toolbox.register("select", tools.selTournament, tournsize=3)
    toolbox.register("mate", gp.cxOnePoint)
    toolbox.register("mutate", gp.mutUniform, expr=toolbox.expr, pset=pset)

    # Both variation operators get the same height limit.
    height_limit = gp.staticLimit(
        key=operator.attrgetter("height"),
        max_value=Redux.MAX_TREE_HEIGHT,
    )
    for operator_name in ("mate", "mutate"):
        toolbox.decorate(operator_name, height_limit)
    return toolbox


class Evolution(object):
    """Generational evolution loop with elitism, a hall of fame and a logbook.

    Each generation evaluates individuals whose fitness is invalid, records
    statistics, and (except after the last generation) builds the next
    population from ``Redux.N_ELITES`` best individuals plus offspring
    produced by ``toolbox.select`` and ``algorithms.varAnd``.
    """

    def __init__(self, population, ngen, toolbox, pset, seed: int = 42, save_path=None, stats=None):
        """Store the run configuration and create the bookkeeping objects.

        Args:
            population: Initial list of individuals.
            ngen: Number of generations executed by ``run``.
            toolbox: Toolbox providing ``evaluate``, ``map`` and ``select``,
                plus whatever ``algorithms.varAnd`` requires.
            pset: Primitive set the individuals were built from.
            seed: Seed applied to the global RNGs for the duration of ``run``.
            save_path: Stored on the instance; not used by this class itself.
            stats: Statistics object; ``std_stats()`` is used when ``None``.

        Raises:
            AttributeError: If *toolbox* has no ``evaluate`` attribute.
        """
        if not hasattr(toolbox, 'evaluate'):
            raise AttributeError("Toolbox should have attribute 'evaluate'")

        # Keep the caller's value (it used to be overwritten with None).
        self.save_path = save_path
        self.seed = seed

        self.population = population
        self.toolbox = toolbox
        self.pset = pset
        self.n_gen = ngen
        self.n_pop = len(population)

        self.hof = tools.HallOfFame(Redux.HOF_NUM)

        self.stats = std_stats() if stats is None else stats
        self.logbook = tools.Logbook()
        self.logbook.header = ['time', 'gen', 'nevals'] + self.stats.fields

    def run(self):
        """Execute ``n_gen`` generations under a temporarily seeded RNG state."""
        with seeded_random_state(seed=self.seed):
            for gen in range(self.n_gen):
                self.step(gen)

    def step(self, gen: int):
        """Evaluate and log generation *gen*, then breed the next one.

        The final generation is evaluated and logged only; no offspring are
        produced after it.
        """
        _, invalid_ind = self._eval()

        self._log(gen, len(invalid_ind))
        if gen == self.n_gen - 1:
            return

        elites = tools.selBest(self.population, k=Redux.N_ELITES)
        offspring = self.toolbox.select(self.population, self.n_pop - Redux.N_ELITES)
        offspring = algorithms.varAnd(offspring, self.toolbox, Redux.CX_PB, Redux.MUT_PB)

        self.population = elites + offspring

    def _eval(self):
        """Evaluate every individual whose fitness is not valid.

        Returns:
            ``(fitnesses, invalid_ind)``: the list of computed fitness values
            and the list of individuals they were assigned to, in matching
            order.
        """
        invalid_ind = [ind for ind in self.population if not ind.fitness.valid]
        # Materialise the result: ``toolbox.map`` may be lazy (e.g. the
        # builtin ``map``), in which case the loop below would exhaust it and
        # the caller would get back an empty iterator.
        fitnesses = list(self.toolbox.map(self.toolbox.evaluate, invalid_ind))

        for ind, fit in zip(invalid_ind, fitnesses):
            ind.fitness.values = fit
        return fitnesses, invalid_ind

    def _log(self, gen: int, n_eval: int):
        """Update the hall of fame, record statistics and print the new log lines."""
        self.hof.update(self.population)
        record = self.stats.compile(self.population)
        # "%m/%d/%y" is the portable spelling of the platform-specific "%D".
        timestamp = time.strftime("%m/%d/%y %H:%M:%S")
        self.logbook.record(time=timestamp, gen=gen, nevals=n_eval, **record)
        print(self.logbook.stream, flush=True)


from pathos.multiprocessing import ProcessPool


def my_map(f, args):
    """Eager, serial replacement for ``map``: apply *f* to each item and return a list."""
    results = []
    for item in args:
        results.append(f(item))
    return results


def gplearn(target, dim, popsize=25, iterations=20):
    """Evolve an expression of *dim* inputs that scores best under *target*.

    Args:
        target: Callable that receives the compiled expression and returns
            its fitness; the result is assigned to ``fitness.values`` by the
            evolution loop. The toolbox is built with the default
            ``minimize=True``.
        dim: Number of inputs of the evolved expression.
        popsize: Size of the initial population.
        iterations: Number of generations to run.

    Returns:
        ``(expression, func)``: the string form of the first hall-of-fame
        entry and that entry compiled into a callable.
    """
    primitives = std_pset(dim)
    toolbox = std_toolbox(primitives)

    def evaluate(individual, pset):
        # Compile the tree into a callable and let the caller's target score it.
        return target(gp.compile(expr=individual, pset=pset))

    toolbox.register('evaluate', evaluate, pset=primitives)
    # Evaluation runs serially via my_map; a process pool's ``map`` could be
    # registered under the same name instead.
    toolbox.register('map', my_map)

    initial_population = toolbox.population(n=popsize)
    evolution = Evolution(
        population=initial_population,
        ngen=iterations,
        toolbox=toolbox,
        pset=primitives,
        seed=42,
        save_path=None,
    )
    evolution.run()

    champion = evolution.hof[0]
    expression = str(champion)
    return expression, gp.compile(champion, primitives)


if __name__ == '__main__':
    # Smoke run: score each candidate by its value at the point (2, 3, 4, 5).
    def test_func(func):
        score = func(2, 3, 4, 5)
        # The fitness is returned as a one-element tuple.
        return (score,)


    gplearn(test_func, 4)
