import random
from copy import deepcopy
from collections import deque
import geppy as gep  # for reuse compile program


class IncrementalGEProgram(list):
    """A GEP gene expression (K-expression) built one token at a time.

    Tokens are appended in breadth-first (K-expression) order.  The object
    tracks how many argument slots are still open (``remain``) and which
    function will receive the next token (``current_parent``), so the
    expression is complete as soon as ``remain`` reaches 0.  A finished
    program can be padded into a fixed-length ``gep.Gene`` via ``to_gene``.

    NOTE(review): ``pset`` is stored on the *class*, not on the instance, so
    constructing an instance with a different primitive set silently changes
    the set seen by every existing instance.  Presumably this is deliberate
    (``deepcopy`` then does not clone the primitive set); do not mix psets.
    """

    # Shared primitive set; overwritten by every ``__init__`` call (see NOTE).
    pset = None
    # Token separator used by ``__repr__`` and ``from_string``.
    SEP = '|'

    def __init__(self, pset: 'gep.PrimitiveSet', head_length: int):
        """Create an empty program.

        Args:
            pset: primitive set exposing ``functions``, ``terminals`` and
                ``max_arity``; it is assigned to the class attribute ``pset``.
            head_length: GEP head length, i.e. the number of leading
                positions in which function tokens may appear.
        """
        super().__init__()
        self.__class__.pset = pset
        self.func_length = head_length
        # Full GEP gene length: head h plus tail h*(n-1)+1, n = max arity.
        self.max_length = self.func_length * self.pset.max_arity + 1
        # Argument slots still to be filled; the root itself needs one token.
        self.remain = 1
        # FIFO of [function, open_argument_slots] for functions still waiting
        # for arguments, in the breadth-first order they consume them.
        self._cur_funcs = deque()

    def increment(self, token=None, new_instance=False):
        """Append one token and update the bookkeeping.

        Args:
            token: token to append.  When ``None`` a random token is drawn:
                inside the head a function with probability 0.5 (only if the
                primitive set has any functions), otherwise a terminal.
            new_instance: when true, a deep copy is extended and returned and
                ``self`` is left unchanged; otherwise ``self`` is modified in
                place and returned.

        Returns:
            The program that received the token.

        Raises:
            RuntimeError: if the program is already finished.
        """
        if self.finished():
            raise RuntimeError('Already!')
        if token is None:
            # The functions check comes last so random.random() is consumed
            # exactly as before; it prevents random.choice([]) from raising
            # IndexError when the primitive set has no functions.
            use_function = (
                len(self) < self.func_length
                and random.random() < 0.5
                and bool(self.pset.functions)
            )
            pool = self.pset.functions if use_function else self.pset.terminals
            token = random.choice(pool)
        prog = deepcopy(self) if new_instance else self
        prog.append(token)

        # The token fills one open slot and opens ``arity`` new ones.
        prog.remain += token.arity - 1

        # The new token is an argument of the oldest function still waiting.
        if prog._cur_funcs:
            prog._cur_funcs[0][1] -= 1
            if prog._cur_funcs[0][1] <= 0:
                prog._cur_funcs.popleft()
        # A function token collects its own arguments later, after every
        # function queued before it (breadth-first order).
        if token.arity > 0:
            prog._cur_funcs.append([token, token.arity])

        return prog

    @property
    def current_parent(self):
        """Function that will receive the next token, or ``None`` if no
        function is waiting (empty or finished program)."""
        return self._cur_funcs[0][0] if self._cur_funcs else None

    def next_available_tokens(self):
        """Return deep copies of the tokens that may be appended next.

        Functions are offered only inside the head; a finished program
        accepts nothing and gets an empty list.
        """
        if self.finished():
            return []
        if len(self) < self.func_length:
            tokens = self.pset.functions + self.pset.terminals
        else:
            tokens = self.pset.terminals
        return deepcopy(tokens)

    def __call__(self, *args):
        """Compile the finished program with geppy and evaluate it on *args*.

        Raises:
            RuntimeError: if the program is not finished (via ``to_gene``).
        """
        chrm = self.to_chromosome()
        func = gep.compile_(chrm, self.pset)
        return func(*args)

    def finished(self) -> bool:
        """Return True once every argument slot has been filled."""
        return self.remain <= 0

    def to_gene(self):
        """Return a ``gep.Gene`` padded to ``max_length``.

        The expression is followed by copies of the first terminal so that
        the gene has its full head+tail length.  ``self`` is not modified.

        Raises:
            RuntimeError: if the program is not finished.
        """
        if not self.finished():
            raise RuntimeError('unfinished geprogram!')
        prog = deepcopy(self)
        # ``remain`` drops by at most one per token, so a finished program
        # has remain == 0 exactly and is always padded.
        missing = prog.max_length - len(prog)
        if missing > 0:
            prog.extend([prog.pset.terminals[0]] * missing)
        return gep.Gene.from_genome(prog, prog.func_length)

    @classmethod
    def _from_tokens(cls, tokens, pset, head_length):
        """Feed ``tokens[0], tokens[1], ...`` until the program is finished.

        Tokens beyond the end of the expression are ignored; running out of
        tokens before the expression is complete raises ``IndexError``.
        """
        prog = cls(pset, head_length)
        ind = 0
        while not prog.finished():
            prog.increment(tokens[ind], new_instance=False)
            ind += 1
        return prog

    @classmethod
    def from_gene(cls, gene, pset):
        """Rebuild the coding region of *gene* (its non-coding tail is
        ignored), using ``gene.head_length`` as the head length."""
        return cls._from_tokens(gene, pset, gene.head_length)

    @classmethod
    def from_string(cls, string, pset, head_length):
        """Parse a ``SEP``-joined list of token names (e.g. ``repr`` output).

        Raises:
            KeyError: if a name is not a function or terminal of *pset*.
            IndexError: if the names end before the expression is complete.
        """
        token_mapping = {
            token.name: token for token in pset.functions + pset.terminals
        }
        names = string.strip().split(cls.SEP)
        tokens = [token_mapping[name] for name in names]
        return cls._from_tokens(tokens, pset, head_length)

    def to_chromosome(self):
        """Wrap ``to_gene()`` in a single-gene chromosome with no linker."""
        return gep.Chromosome.from_genes([self.to_gene()], None)

    def __repr__(self):
        return self.SEP.join([i.name for i in self])

    def __hash__(self):
        # Hash by token names so programs can key dicts/sets.  The object is
        # mutable (``increment`` appends in place): do not grow it while it
        # is used as a key.
        return hash(self.__repr__())


if __name__ == '__main__':
    import operator

    # Small arithmetic primitive set over the inputs x and y.
    pset = gep.PrimitiveSet('Main', input_names=['x', 'y'])
    for binary_op in (operator.add, operator.sub, operator.mul):
        pset.add_function(binary_op, 2)

    # Grow one random program (head length 5) token by token, printing each
    # prefix together with the function awaiting the next argument.
    program = IncrementalGEProgram(pset, 5)
    while not program.finished():
        program.increment()
        print(program, program.current_parent)

    # Disabled round-trip example: decode a random geppy genome.
    # genome = gep.generate_genome(pset, 8)
    # gene = gep.Gene.from_genome(genome, 8)
    # print([g.name for g in genome])
    # prog = IncrementalGEProgram.from_gene(
    #     gene, pset
    # )
    # print(prog)
