import numpy as np
from numpy import random as rand
from collections import deque
from .treexplore import *

# EQUALITY1 = [
#     ([2, 0], []),
#     ([0, 2], []),
#     ([1, 0], [0, 1]),
#     ([1, 2], [2, 1]),
#     ([0, 3], [3, 0]),
#     ([2, 3], [3, 2]),
#     ([3, 1], []),
#     ([], [1, 3]),
# ]

# Equalities between action sequences over two actions (indices 0 and 1).
# Each entry is a pair (lhs, rhs) of action-index lists; presumably meant as
# the `equalities` argument of GraphExplorer / TreExplore, where lhs and rhs
# are treated as leading to the same node (TODO confirm against treexplore).
# An empty list denotes the empty sequence.
EQUALITY1 = [
    ([1, 0], []),
    ([0, 1], []),
    ([0, 0], [1, 1]),
    ([0, 0, 0, 0], []),
    ([1, 1, 1, 1], []),
    # redundant: if the pairs are read as group relations, [0, 1] == [] and
    # [0, 0, 0, 0] == [] already imply these two.
    # ([1], [0, 0, 0]),
    # ([0], [1, 1, 1]),
]


class Explorer:
    """Base class for exploration strategies over a discrete set of actions.

    Subclasses implement ``_add_single_action``, ``get_action`` and ``reset``,
    and ``get_p`` where action probabilities are available.

    Args:
        actions: sequence of available actions.
        depth: exploration depth; its exact use is defined by subclasses.
        reset_on_policy: if True, ``add_action`` calls ``reset()`` after
            recording action(s) flagged as non-exploratory
            (``exploration=False``).
    """

    def __init__(self, actions, depth, reset_on_policy=False):
        self.actions = actions
        self.depth = depth
        self.reset_on_policy = reset_on_policy
        self.n_actions = len(actions)

    def add_action(self, action, exploration=True):
        """Record a single action, or every action in a numpy array.

        Args:
            action: a single action, or an ``np.ndarray`` of actions. The
                array is iterated along its first axis; a 0-d array is
                treated as one action.
            exploration: whether the action(s) came from exploration. It is
                forwarded to ``_add_single_action`` for every recorded action.

        When ``exploration`` is False and ``reset_on_policy`` is set, the
        explorer is reset once after at least one action has been recorded.
        This applies to array input exactly as it does to a single action.
        """
        if isinstance(action, np.ndarray):
            # atleast_1d makes a 0-d array iterable as a one-element batch.
            batch = np.atleast_1d(action)
        else:
            batch = (action,)

        for a in batch:
            self._add_single_action(a, exploration)

        # Batched and single actions share the same on-policy reset rule.
        if len(batch) > 0 and not exploration and self.reset_on_policy:
            self.reset()

    def get_p(self):
        """Return action probabilities for the current state (subclass hook)."""
        raise NotImplementedError

    def _add_single_action(self, action, exploration=True):
        """Record one action (subclass hook)."""
        raise NotImplementedError

    def get_action(self, n_samples=1):
        """Return action(s) to execute next (subclass hook)."""
        raise NotImplementedError

    def reset(self):
        """Reset the explorer's internal state (subclass hook)."""
        raise NotImplementedError


class NaiveExplorer(Explorer):
    """Explorer that draws actions with ``numpy.random.choice`` (uniform).

    It only counts how many actions were recorded and resets after ``depth``
    of them. It is always constructed with ``reset_on_policy=False``.
    """

    def __init__(self, actions, depth):
        super().__init__(actions, depth, reset_on_policy=False)
        self.reset()

    def reset(self):
        """Zero the recorded-action counter and clear the exploring flag."""
        self.count, self.continue_exploring = 0, False

    def _add_single_action(self, action, exploration=True):
        """Count one recorded action (its value is ignored); reset at ``depth``."""
        self.count = self.count + 1
        if self.count < self.depth:
            return
        self.reset()

    def get_action(self, n_samples=1):
        """Return a list of ``n_samples`` independent draws from ``actions``."""
        self.continue_exploring = True
        picks = []
        for _ in range(n_samples):
            picks.append(rand.choice(self.actions))
        return picks


class GraphExplorer(Explorer):
    """Explorer that samples actions from a ``TreExplore`` tree.

    At construction the tree is built from ``equalities`` over the action
    count with ``tree_depth=depth``. Then ``expand_length()`` and
    ``find_opt(objective)`` are called on it.

    Args:
        actions: sequence of available actions.
        depth: tree depth; also the number of recorded actions after which
            the explorer resets itself.
        equalities: passed unchanged to ``TreExplore``.
        objective: objective name given to ``find_opt`` and
            ``action_probabilities``.
        reset_on_policy: forwarded to ``Explorer``.
        verbose: forwarded to ``TreExplore`` and stored on the instance.
    """

    def __init__(
        self,
        actions,
        depth,
        equalities,
        objective="node_balance",
        reset_on_policy=False,
        verbose=True,
    ):
        super().__init__(actions, depth, reset_on_policy)
        # Explorer.__init__ has set n_actions to len(actions).
        tree = TreExplore(equalities, self.n_actions, tree_depth=depth, verbose=verbose)
        tree.expand_length()
        tree.find_opt(objective)
        self.g = tree
        self.objective, self.verbose = objective, verbose

        self.reset()

    def _add_single_action(self, action, exploration=True):
        """Append ``action`` to the history; reset once ``depth`` are stored.

        ``exploration`` is accepted for interface compatibility and unused.
        """
        history = self.prev_actions
        history.append(action)
        if len(history) < self.depth:
            return
        self.reset()

    def reset(self):
        """Clear the action history and restart the tree's trajectory."""
        self.prev_actions, self.continue_exploring = [], False
        self.g.reset_traj()

    def get_p(self):
        """Return the tree's action probabilities at the node reached by the history."""
        tree = self.g
        return tree.action_probabilities(
            tree.node_from_actions(self.prev_actions), self.objective
        )

    def get_action(self, n_samples=1):
        """Sample one action from the tree, wrapped in a one-element numpy array.

        ``n_samples`` is accepted for interface compatibility but is not
        used; exactly one ``sample_next()`` call is made.
        """
        self.continue_exploring = True
        sampled = self.g.sample_next()
        return np.array([sampled])
