import deap, numpy as np, random
from deap import algorithms, base, creator, tools, gp
import math, random, logging, sys, torch
import operator


def protectedDiv(left, right):
    """Divide ``left`` by ``right``, replacing non-finite results with 1.

    For array results, every inf/NaN entry of the freshly computed quotient is
    overwritten with 1 in place. For scalar results, an inf/NaN quotient is
    replaced by the int ``1``; otherwise the NumPy scalar is returned.
    Division warnings are suppressed while computing.
    """
    with np.errstate(divide='ignore', invalid='ignore'):
        quotient = np.divide(left, right)
        if not isinstance(quotient, np.ndarray):
            # Scalar path: swap a bad value for 1, keep a good one as-is.
            if np.isinf(quotient) or np.isnan(quotient):
                return 1
            return quotient
        bad_entries = np.isinf(quotient) | np.isnan(quotient)
        quotient[bad_entries] = 1
    return quotient

def passGreater(left, right):
    """Return ``left`` if it is strictly greater than ``right``, else ``right``.

    On ties (or incomparable values such as NaN) ``right`` is returned.
    """
    return left if left > right else right

def passSmaller(left, right):
    """Return ``left`` if it is strictly smaller than ``right``, else ``right``.

    On ties (or incomparable values such as NaN) ``right`` is returned.
    """
    return left if left < right else right

def equalTo(left, right):
    """Return 1.0 when ``left == right`` holds, otherwise 0.0."""
    return 1.0 if left == right else 0.0

def Gate(a, b, condtion):
    """Select ``a`` when ``condtion <= 0``, otherwise ``b``.

    The parameter name ``condtion`` is misspelt but kept so keyword callers
    keep working.
    """
    return a if condtion <= 0 else b

def square(a):
    """Return ``a`` multiplied by itself."""
    product = a * a
    return product

def isNegative(a):
    """Return 1.0 when ``a`` is strictly below zero, otherwise 0.0."""
    return 1.0 if a < 0 else 0.0

def div100(a):
    """Return ``a`` divided by 100 (true division)."""
    scaled = a / 100.0
    return scaled

def div10(a):
    """Return ``a`` divided by 10 (true division)."""
    scaled = a / 10.0
    return scaled

def l1_1d(a, b):
    """Return the negated 1-D L1 distance ``-|a - b|``."""
    distance = abs(a - b)
    return -distance

def l1_2d(a, b, c, d):
    """Return the negated 2-D L1 (Manhattan) distance between (a, b) and (c, d)."""
    dx = abs(a - c)
    dy = abs(b - d)
    return -dx - dy



class GP_Population:
    """Population of DEAP GP expression trees evaluated as reward functions.

    Each individual is a ``gp.PrimitiveTree`` over ``feature_dim`` arguments,
    built from the primitives registered in ``__init__``. Before calling
    :meth:`evolve`, callers must set ``individual.fitness.values`` on every
    member of ``self.popn``. ``evolve`` folds those values into a per-slot
    ``lineage`` score and rewrites ``self.popn`` in place.
    """

    def __init__(self, feature_dim, output_dim, popn_size, elite_ratio, kill_ratio, alpha=1.0):
        """Build the primitive set, the DEAP toolbox and a random initial population.

        Args:
            feature_dim: Number of input features, i.e. tree arguments.
            output_dim: Accepted for interface compatibility; not used here.
            popn_size: Number of individuals to create.
            elite_ratio: Fraction of ``popn_size`` kept as elites (at least 2).
            kill_ratio: Fraction of ``popn_size`` stored as ``num_kills``
                (at least 2). ``num_kills`` is not read in this class.
            alpha: Weight of the newest fitness in the lineage moving average.
        """
        self.num_elites = max(2, int(elite_ratio * popn_size))
        self.num_kills = max(2, int(kill_ratio * popn_size))

        ####### Create Primitives ############
        self.pset = gp.PrimitiveSet("MAIN", feature_dim)
        self.pset.addPrimitive(np.add, 2, name="vadd")
        self.pset.addPrimitive(np.subtract, 2, name="vsub")
        self.pset.addPrimitive(np.multiply, 2, name="vmul")
        self.pset.addPrimitive(np.negative, 1, name="vneg")
        self.pset.addPrimitive(np.cos, 1, name="vcos")
        self.pset.addPrimitive(np.sin, 1, name="vsin")
        self.pset.addPrimitive(np.abs, 1, name="abs")
        self.pset.addPrimitive(np.tan, 1, name="vtan")
        self.pset.addPrimitive(protectedDiv, 2)
        self.pset.addPrimitive(passGreater, 2)
        self.pset.addPrimitive(passSmaller, 2)
        self.pset.addPrimitive(equalTo, 2)
        self.pset.addPrimitive(Gate, 3)
        self.pset.addPrimitive(max, 5)
        self.pset.addPrimitive(min, 5)
        self.pset.addPrimitive(isNegative, 1)

        # Terminals
        self.pset.addTerminal(1)
        self.pset.addTerminal(0)

        # Custom Ops
        self.pset.addPrimitive(l1_1d, 2)
        self.pset.addPrimitive(l1_2d, 4)
        self.pset.addPrimitive(div100, 1)
        self.pset.addPrimitive(div10, 1)
        self.pset.addPrimitive(square, 1)

        # No ephemeral constants are registered, so the mutEphemeral operator
        # registered below leaves trees unchanged.

        # Create Individual and Population classes. creator.create defines
        # them globally inside deap.creator. Constructing a second
        # GP_Population re-creates them, and DEAP warns about the overwrite.
        creator.create("FitnessMax", base.Fitness, weights=(1.0,))
        creator.create("Individual", gp.PrimitiveTree, fitness=creator.FitnessMax, pset=self.pset)

        # Create Toolbox
        self.toolbox = base.Toolbox()
        self.toolbox.register("expr", gp.genHalfAndHalf, pset=self.pset, min_=1, max_=2)
        self.toolbox.register("individual", tools.initIterate, creator.Individual, self.toolbox.expr)
        self.toolbox.register("population", tools.initRepeat, list, self.toolbox.individual)
        self.toolbox.register("mate", gp.cxOnePoint)
        self.toolbox.register("expr_mut", gp.genHalfAndHalf, min_=0, max_=2)
        self.toolbox.register('mutUniform', gp.mutUniform, expr=self.toolbox.expr_mut, pset=self.pset)
        self.toolbox.register('mutNode', gp.mutNodeReplacement, pset=self.pset)
        # gp.mutEphemeral's signature is (individual, mode) and has no pset
        # parameter, so 'mode' must be bound here for the call to succeed.
        self.toolbox.register('mutEphemeral', gp.mutEphemeral, mode="one")
        self.toolbox.register('mutInsert', gp.mutInsert, pset=self.pset)
        self.toolbox.register('mutShrink', gp.mutShrink, pset=self.pset)
        self.toolbox.register("compile", gp.compile, pset=self.pset)
        # Not used by evolve(), which selects via ParsimonySelection instead.
        self.toolbox.register("select", tools.selDoubleTournament, fitness_size=3, parsimony_size=2, fitness_first=True)

        # Create Population
        self.popn = self.toolbox.population(popn_size)
        # Per-slot smoothed fitness. -1.0 marks fresh or discarded slots.
        self.lineage = [-1.0] * popn_size
        self.alpha = alpha
        self.champ_ind = 0  # not updated anywhere in this class

        # Trackers: mean |reward| from each individual's last compute_reward call.
        self.avg_response = [0 for _ in range(popn_size)]

    def ParsimonySelection(self, individuals, selection_num, fitness_tsize, parsimony_tsize):
        """Two-stage tournament selection: lineage score first, then tree size.

        Stage 1 runs ``2 * selection_num`` tournaments. Each draws
        ``fitness_tsize`` indices uniformly, with replacement, from
        ``range(len(individuals))`` and keeps the one with the highest
        ``self.lineage`` value (the first such aspirant on ties).

        Stage 2 runs ``selection_num`` tournaments. Each draws
        ``parsimony_tsize`` stage-1 winners and keeps the one whose tree has
        the fewest nodes (the first such aspirant on ties).

        Args:
            individuals: Sequence of trees indexed in parallel with ``self.lineage``.
            selection_num: Number of indices to return.
            fitness_tsize: Tournament size for the fitness stage.
            parsimony_tsize: Tournament size for the size stage.

        Returns:
            list[int]: ``selection_num`` indices into ``individuals``; may repeat.
        """
        num_inds = len(individuals)
        lineage = self.lineage

        ###### FITNESS BASED SELECTION ########
        fittest_inds = []
        for _ in range(selection_num * 2):
            aspirants = [random.randrange(num_inds) for _ in range(fitness_tsize)]
            # max() returns the first maximal aspirant, matching list.index().
            fittest_inds.append(max(aspirants, key=lambda idx: lineage[idx]))

        ###### PARSIMONY BASED SELECTION #########
        chosen = []
        for _ in range(selection_num):
            aspirants = [random.choice(fittest_inds) for _ in range(parsimony_tsize)]
            chosen.append(min(aspirants, key=lambda idx: len(individuals[idx])))

        return chosen

    def compute_reward(self, s, popn_id):
        """Evaluate individual ``popn_id`` on ``s`` and squash the output to [-5, 5].

        Args:
            s: Iterable of feature rows. Each row is unpacked as positional
                arguments to the compiled tree. Presumably shaped
                (batch, feature_dim); confirm with callers.
            popn_id: Index of the individual in ``self.popn``.

        Returns:
            torch.Tensor of shape (len(s), 1) with values in [-5, 5].
            Also records the mean absolute reward in ``self.avg_response[popn_id]``.

        Raises:
            RuntimeError: If the reward tensor contains NaN. This is a safety
                net: ``np.nan_to_num`` already maps NaN to 0 above it.
        """
        r = self.forward(s, popn_id)

        # Clip before exp() so the sigmoid cannot overflow. Then map the
        # sigmoid's (0, 1) output onto (-1, 1) and scale it to (-5, 5).
        r = np.clip(r, -50, 50)
        r = 2 * (1.0 / (1 + np.exp(-r))) - 1.0
        r *= 5.0
        r = np.nan_to_num(r)

        self.avg_response[popn_id] = np.abs(r).mean()

        r = torch.Tensor(np.expand_dims(r, 1))

        if torch.isnan(r).any():
            raise RuntimeError('NAN encountered in auto-reward compute')

        return r

    def forward(self, s, popn_id):
        """Evaluate individual ``popn_id`` row by row on ``s``.

        Args:
            s: Iterable of feature rows, each unpacked into the compiled tree.
            popn_id: Index of the individual in ``self.popn``.

        Returns:
            1-D ``np.ndarray`` of floats, one entry per row of ``s``.
        """
        func = self.toolbox.compile(expr=self.popn[popn_id])
        return np.array([float(func(*row)) for row in s])

    def evolve(self, crossover_prob, mut_prob, verbose=False):
        """Produce the next generation in place.

        Steps:
        1. Update ``lineage`` with an exponential moving average of the
           current fitness values.
        2. Take the ``num_elites`` highest-lineage slots as elites.
        3. Pick ``len(popn) - num_elites`` offspring slots with
           :meth:`ParsimonySelection`.
        4. Treat every slot that is neither elite nor offspring as
           "unselected" and reset its lineage to -1.0. The first unselected
           slots receive clones of the elites. The remaining ones are refilled
           in pairs from elite copies and offspring, then varied randomly.
        5. Cross over and mutate the offspring slots in place.

        Args:
            crossover_prob: Probability of crossing each consecutive offspring pair.
            mut_prob: Probability of mutating each offspring slot.
            verbose: Print diagnostics when True.

        Returns:
            None when the population has at most one individual. Otherwise a
            tuple ``(unselect_inds, elite_inds, new_elite_inds)``:

            - ``unselect_inds``: the refilled unselected slots, excluding
              elite copies.
            - ``elite_inds``: the original elite slots.
            - ``new_elite_inds``: the slots now holding elite clones.
        """
        if len(self.popn) <= 1:  # Nothing to evolve.
            return

        ############ SELECTION ##############
        # Update lineage scores. lineage is initialised and reset to -1.0, so
        # the None branch only fires if a caller stores None explicitly.
        for i, ind in enumerate(self.popn):
            fitness = ind.fitness.values[0]
            if self.lineage[i] is None:
                self.lineage[i] = fitness
            else:
                self.lineage[i] = (self.alpha * fitness) + (1 - self.alpha) * self.lineage[i]

        # Select elites. Sorting ascending then reversing puts later indices
        # first among equal scores.
        elite_inds = sorted(range(len(self.lineage)), key=self.lineage.__getitem__)
        elite_inds.reverse()
        elite_inds = elite_inds[0:self.num_elites]

        # Select offspring (may contain repeats and may overlap elite_inds).
        offspring_inds = self.ParsimonySelection(self.popn, len(self.popn) - self.num_elites, fitness_tsize=2, parsimony_tsize=3)

        # Process unselects: every slot that is neither elite nor offspring.
        kept = set(elite_inds).union(offspring_inds)
        unselect_inds = [ind for ind in range(len(self.popn)) if ind not in kept]
        for ind in unselect_inds:
            self.lineage[ind] = -1.0

        # Save elites into the first unselected slots so they survive any
        # in-place variation applied to the original elite slots below.
        new_elite_inds = []
        for src, dst in zip(elite_inds, unselect_inds):
            self.popn[dst] = self.toolbox.clone(self.popn[src])
            new_elite_inds.append(dst)
            self.lineage[dst] = self.lineage[src]

        # The elite copies occupy a prefix of unselect_inds; drop that prefix.
        unselect_inds = unselect_inds[len(new_elite_inds):]

        # Refill the remaining unselected slots in pairs. child1 starts as a
        # copy of an elite clone, child2 as a copy of an offspring; both are
        # then varied randomly.
        # NOTE(review): with an odd count, the last unselected slot keeps its
        # old tree (its lineage has already been reset to -1.0).
        for child1, child2 in zip(unselect_inds[::2], unselect_inds[1::2]):

            parent_a = random.choice(new_elite_inds)
            parent_b = random.choice(offspring_inds)
            self.popn[child1] = self.toolbox.clone(self.popn[parent_a])
            self.popn[child2] = self.toolbox.clone(self.popn[parent_b])

            if random.random() < 0.5:
                try:
                    self.popn[child1], self.popn[child2] = self.toolbox.mate(self.popn[child1], self.popn[child2])
                except Exception:
                    if verbose:
                        print('Crossover for Unselects failed')
            if random.random() < 0.2:
                try:
                    self.popn[child1] = self.toolbox.individual()
                except Exception:
                    if verbose:
                        print('Reinit for Unselects failed')

            if random.random() < 0.2:
                self.popn[child2] = self.toolbox.individual()

            if random.random() < 0.1:
                try:
                    self.toolbox.mutUniform(self.popn[child1])
                except Exception:
                    if verbose:
                        print('MutUniform for Unselects failed')

            if random.random() < 0.25:
                try:
                    self.toolbox.mutNode(self.popn[child2])
                except Exception:
                    if verbose:
                        print('MutNode for Unselects failed')

            if random.random() < 0.25:
                try:
                    self.toolbox.mutInsert(self.popn[child1])
                except Exception:
                    if verbose:
                        print('MutInsert for Unselects failed')

            if random.random() < 0.5:
                try:
                    self.toolbox.mutShrink(self.popn[child2])
                except Exception:
                    pass  # deliberately silent, even in verbose mode

        # Apply crossover to consecutive offspring pairs. Crossed pairs share
        # the mean of their lineage scores.
        for child1, child2 in zip(offspring_inds[::2], offspring_inds[1::2]):
            if random.random() < crossover_prob:
                try:
                    self.popn[child1], self.popn[child2] = self.toolbox.mate(self.popn[child1], self.popn[child2])
                    combined_lineage = (self.lineage[child1] + self.lineage[child2]) / 2.0
                    self.lineage[child1], self.lineage[child2] = combined_lineage, combined_lineage
                except Exception:
                    if verbose:
                        print('Crossover failed')

        # Apply one randomly chosen mutation operator to each selected offspring.
        for mutant in offspring_inds:
            if random.random() < mut_prob:
                try:
                    roll = random.random()
                    if roll < 0.2:
                        self.toolbox.mutUniform(self.popn[mutant])
                    elif roll < 0.4:
                        self.toolbox.mutNode(self.popn[mutant])
                    elif roll < 0.6:
                        self.toolbox.mutInsert(self.popn[mutant])
                    elif roll < 0.8:
                        self.toolbox.mutShrink(self.popn[mutant])
                    else:
                        self.toolbox.mutEphemeral(self.popn[mutant])
                except Exception:
                    if verbose:
                        print('Mutation for Offsprings failed')

        if verbose: print('Elites', len(new_elite_inds), 'Offsprings', len(offspring_inds), 'Unselects', len(unselect_inds), 'Redundancy', len(new_elite_inds)+len(offspring_inds)+len(unselect_inds)-len(self.popn))

        return unselect_inds, elite_inds, new_elite_inds
