from .mcts import SearchDomain
from .program import IncrementalGEProgram
import random
import geppy as gep
import numpy as np


class StateTransition:
    """Count how often each token appears as an argument of each function.

    Counts are stored per ``(function, token)`` pair and start at 1 for every
    pair, so no pair ever has a zero weight. ``build`` adds one to a pair for
    each occurrence found in ``programs``.
    """

    def __init__(self, programs):
        """Store *programs* and create the all-ones count table.

        Args:
            programs: iterable of programs; each must support ``len()`` and
                integer indexing, and its items must expose ``.arity``.
        """
        self.programs = programs
        self.pset = IncrementalGEProgram.pset
        self.tokens = self.pset.functions + self.pset.terminals
        # One row per function, one column per token, every cell starts at 1.
        self._prob = {}
        for func in self.pset.functions:
            self._prob[func] = dict.fromkeys(self.tokens, 1)

    def build(self):
        """Accumulate counts from every stored program and return ``self``."""
        for program in self.programs:
            self._build_single(program)
        return self

    def _build_single(self, prog):
        """Record the parent/child pairs of one program.

        Both cursors start at the last token. The parent cursor walks towards
        index 0; each token with positive arity takes the next ``arity``
        tokens under the child cursor (also moving backwards) as its
        arguments.
        """
        # NOTE(review): this pairing presumably relies on `prog` being a
        # complete breadth-first (K-expression style) sequence with no unused
        # trailing tokens -- confirm against the program representation.
        child_pos = len(prog) - 1
        for parent_pos in range(len(prog) - 1, -1, -1):
            parent = prog[parent_pos]
            if parent.arity > 0:
                for _ in range(parent.arity):
                    self._add_relationship(parent, prog[child_pos])
                    child_pos -= 1

    def _add_relationship(self, parent, child):
        """Increment the count for *child* occurring under *parent*."""
        self._prob[parent][child] += 1

    def probability(self, parent, child):
        """Return the stored count for the pair (an unnormalised weight)."""
        return self._prob[parent][child]


class SymDomain(SearchDomain):
    """Search domain whose states are incrementally built programs.

    Terminal states are scored through ``opt.evaluate``; raw scores are
    memoised in ``cache`` and ``num_eval`` counts the uncached evaluations.
    """

    def __init__(self, pset, opt, ST, func_length):
        """Store the collaborators used by the search.

        Args:
            pset: primitive set handed to every new ``IncrementalGEProgram``.
            opt: object exposing ``evaluate(state)``.
            ST: object exposing ``probability(parent, action)``; used to
                weight rollout actions.
            func_length: second argument passed to ``IncrementalGEProgram``.
        """
        self.pset = pset
        self.cache = dict()  # state -> raw score returned by the model
        self.num_eval = 0    # number of model evaluations actually performed

        self.model = opt
        self.ST = ST

        self.func_length = func_length

    def get_initial_state(self):
        """Return a fresh, empty program state."""
        return IncrementalGEProgram(self.pset, self.func_length)

    def get_valid_actions(self, state: IncrementalGEProgram):
        """Return the tokens that may be appended to *state*, in random order.

        The list returned by ``state.next_available_tokens()`` is shuffled in
        place.
        """
        actions = state.next_available_tokens()
        random.shuffle(actions)
        return actions

    def apply_action(self, state: IncrementalGEProgram, action: gep.Primitive):
        """Return the state obtained by appending *action* to *state*."""
        return state.increment(action, new_instance=True)

    def is_terminal(self, state: IncrementalGEProgram):
        """Return whether *state* reports itself as finished."""
        return state.finished()

    def evaluate(self, state: IncrementalGEProgram):
        """Return the reciprocal of the model's score for *state*.

        The raw score is cached, so the model is called (and ``num_eval``
        incremented) only when the cache has no non-``None`` entry for
        *state*.
        """
        reward = self.cache.get(state)
        if reward is None:
            reward = self.model.evaluate(state)
            self.num_eval += 1
            self.cache[state] = reward
        # NOTE(review): a raw score of exactly 0 makes this division raise
        # ZeroDivisionError (or yield inf for NumPy scalars) -- confirm the
        # model can never return 0.
        return 1 / reward

    def rollout_action(self, state: IncrementalGEProgram):
        """Sample one valid action, weighted by the transition counts.

        Weights come from ``ST.probability(current_parent, action)`` and are
        normalised to sum to 1 before sampling.
        """
        cur_parent = state.current_parent
        actions = self.get_valid_actions(state)
        weights = np.array(
            [self.ST.probability(cur_parent, action) for action in actions]
        )
        # Sample an index rather than the tokens themselves: handing the list
        # of primitive objects to np.random.choice would first coerce them
        # into an ndarray, which breaks for sequence-like tokens and can
        # change the type of string-like ones.
        idx = np.random.choice(len(actions), p=weights / np.sum(weights))
        return actions[idx]


def get_pareto_front(domain):
    """Return the non-dominated entries of ``domain.cache``.

    Each cache entry is judged on two objectives, both treated as
    lower-is-better: ``len(key)`` and the cached value. An entry is dropped
    when some other entry is no worse on both objectives and strictly better
    on at least one.

    Args:
        domain: object with a ``cache`` mapping; keys must support ``len()``.

    Returns:
        A list of ``(key, value)`` pairs sorted by ascending value (ties keep
        cache iteration order).
    """

    def dominates(other, cand):
        # `other` must be at least as good as `cand` on both objectives ...
        if not (other[1] <= cand[1] and other[2] <= cand[2]):
            return False
        # ... and strictly better on at least one of them.
        return other[1] < cand[1] or other[2] < cand[2]

    records = [(key, len(key), score) for key, score in domain.cache.items()]

    front = [
        cand
        for cand in records
        if not any(dominates(other, cand) for other in records)
    ]

    front.sort(key=lambda rec: rec[2])
    return [(key, score) for key, _, score in front]