from typing import Dict, List, Union

import ecole

from rl.environments.search_tree import SearchTree


class RewardAgent:
    """Track the B&B tree during an episode and emit per-step transitions at the end.

    The class exposes a ``before_reset`` / ``extract`` pair taking an
    ``ecole.scip.Model`` (presumably the Ecole reward-function protocol --
    confirm against the environment that consumes it).

    Every processed node costs ``gamma**k`` where ``k`` is the number of nodes
    already processed in the sub-tree rooted at the transition's start node,
    nodes being processed in the order they were visited during the episode.
    """

    def __init__(self, n_step: int = 3, gamma: float = None):
        """
        Helpful class to track B&B tree and postprocess episode
        into Tree MDP.

        Args:
            n_step: Number of processed nodes after which the n-step
                transition is cut off. With ``n_step == 1`` only the 1-step
                transition is produced.
            gamma: Discount factor. ``None`` is treated as ``1.0`` (no
                discounting) when rewards are computed.
        """
        self.n_step = n_step
        self.off_policy = True
        self.gamma = gamma
        # Defined here so that `extract` cannot fail with AttributeError when
        # it is called before `before_reset`.
        self.started = False

    def synchronize(self, n_step: int = 3, gamma: float = None):
        """Overwrite ``n_step`` and ``gamma`` (same meaning as in ``__init__``)."""
        self.n_step = n_step
        self.off_policy = True
        self.gamma = gamma

    def before_reset(self, model: ecole.scip.Model):
        """Mark the start of a new episode; the next ``extract`` rebuilds the tree."""
        self.started = False

    def extract(
        self, model: ecole.scip.Model, done: bool, action_set=None
    ) -> Union[None, List[Dict], Dict[int, Dict]]:
        """Update the tracked tree and, once the episode is done, build the transitions.

        Args:
            model: Solver model, forwarded untouched to ``SearchTree``.
            done: Whether the episode has finished.
            action_set: Forwarded untouched to ``SearchTree``.

        Returns:
            ``None`` while the episode is running (including the very first
            call, which only creates the ``SearchTree``).
            ``[{0: 0}]`` when the tree has no root node (instance pre-solved).
            Otherwise a dict mapping each step index of the episode to a
            transition dict with keys ``"next_states"`` and ``"reward"``, plus
            ``"n_step_next_states"`` and ``"n_step_reward"`` when
            ``n_step > 1``. Next states are step indices, ordered by visit.
        """
        if not self.started:
            self.started = True
            self.search_tree = SearchTree(model, action_set)
            return None

        self.search_tree.update_tree(model, action_set)

        if not done:
            # only update B&B tree
            return None

        # instance solved, retrospectively build Tree MDP
        tree = self.search_tree.tree
        if tree.graph["root_node"] is None:
            # instance was pre-solved
            return [{0: 0}]

        # keep track of which nodes have been added to a sub-tree
        self.nodes_added = set()

        # map which nodes were visited at which step in episode
        visited_node_ids = tree.graph["visited_node_ids"]
        self.visited_nodes_to_step_idx = {node: idx for idx, node in enumerate(visited_node_ids)}
        self.step_idx_to_visited_nodes = dict(enumerate(visited_node_ids))
        node_step_idx_to_child_nodes = {step_idx: None for step_idx in self.step_idx_to_visited_nodes}

        # A missing discount factor means "no discounting"; without this the
        # exponentiation below raises TypeError for a default-constructed agent.
        gamma = 1.0 if self.gamma is None else self.gamma

        # Current time complexity is: O(|T|*2**n_step)
        # It should be possible to do it in O(|T|).
        for node in tree.nodes:
            if node not in self.visited_nodes_to_step_idx:
                continue
            node_step_idx = self.visited_nodes_to_step_idx[node]
            node_step_idx_to_child_nodes[node_step_idx] = self._collect_transition(node, gamma)
            self.nodes_added.add(node)

        assert len(self.nodes_added) == len(self.visited_nodes_to_step_idx)

        return node_step_idx_to_child_nodes

    def _collect_transition(self, node, gamma) -> Dict:
        """Build the 1-step (and, if ``n_step > 1``, n-step) transition starting at ``node``."""
        tree = self.search_tree.tree
        step_of = self.visited_nodes_to_step_idx

        reward, step = 0, 0
        one_step_children, n_step_children = [], []
        open_nodes = [node]
        while open_nodes:
            # open_nodes is kept sorted by visit step, so the head is the
            # open node that was visited earliest in the episode.
            current_node = open_nodes.pop(0)
            reward -= 1 * (gamma**step)
            step += 1

            # only children that were actually visited take part in the MDP
            open_nodes.extend(child for child in tree.successors(current_node) if child in step_of)
            # sort open nodes according to their step visit
            open_nodes.sort(key=step_of.__getitem__)

            # store 1-step transition
            if step == 1:
                one_step_reward = reward
                one_step_children = list(open_nodes)
                if self.n_step == 1:
                    break
            # update n-step transition
            if step <= self.n_step:
                n_step_reward = reward
                n_step_children = list(open_nodes)
                if step == self.n_step:
                    break

        tree_transition = {
            "next_states": [step_of[child] for child in one_step_children],
            "reward": one_step_reward,
        }
        if self.n_step > 1:
            # store n-step transition
            tree_transition.update(
                {
                    "n_step_next_states": [step_of[child] for child in n_step_children],
                    "n_step_reward": n_step_reward,
                }
            )
        return tree_transition


class TreeMDPRewardAgent:
    """Track the B&B tree during an episode and emit tree-shaped transitions at the end.

    The class exposes a ``before_reset`` / ``extract`` pair taking an
    ``ecole.scip.Model`` (presumably the Ecole reward-function protocol --
    confirm against the environment that consumes it).

    Transitions are built level by level below the start node: every node on
    level ``d`` (the start node itself being level 1) costs
    ``(3 - number_of_visited_children) * gamma**d``.
    """

    def __init__(self, n_step: int = 3, gamma: float = None):
        """
        Helpful class to track B&B tree and postprocess episode
        into Tree MDP.

        Args:
            n_step: Number of tree levels after which the n-step transition
                is cut off. With ``n_step == 1`` only the 1-step transition
                is produced.
            gamma: Discount factor. ``None`` is treated as ``1.0`` (no
                discounting) when rewards are computed.
        """
        self.n_step = n_step
        self.off_policy = True
        self.gamma = gamma
        # Defined here so that `extract` cannot fail with AttributeError when
        # it is called before `before_reset`.
        self.started = False

    def synchronize(self, n_step: int = 3, gamma: float = None):
        """Overwrite ``n_step`` and ``gamma`` (same meaning as in ``__init__``)."""
        self.n_step = n_step
        self.off_policy = True
        self.gamma = gamma

    def before_reset(self, model: ecole.scip.Model):
        """Mark the start of a new episode; the next ``extract`` rebuilds the tree."""
        self.started = False

    def extract(
        self, model: ecole.scip.Model, done: bool, action_set=None
    ) -> Union[None, List[Dict], Dict[int, Dict]]:
        """Update the tracked tree and, once the episode is done, build the transitions.

        Args:
            model: Solver model, forwarded untouched to ``SearchTree``.
            done: Whether the episode has finished.
            action_set: Forwarded untouched to ``SearchTree``.

        Returns:
            ``None`` while the episode is running (including the very first
            call, which only creates the ``SearchTree``).
            ``[{0: 0}]`` when the tree has no root node (instance pre-solved).
            Otherwise a dict mapping each step index of the episode to a
            transition dict with keys ``"next_states"`` and ``"reward"``, plus
            ``"n_step_next_states"`` and ``"n_step_reward"`` when
            ``n_step > 1``. Next states are step indices, in the order the
            tree yields the children.
        """
        if not self.started:
            self.started = True
            self.search_tree = SearchTree(model, action_set)
            return None

        self.search_tree.update_tree(model, action_set)

        if not done:
            # only update B&B tree
            return None

        # instance solved, retrospectively build Tree MDP
        tree = self.search_tree.tree
        if tree.graph["root_node"] is None:
            # instance was pre-solved
            return [{0: 0}]

        # keep track of which nodes have been added to a sub-tree
        self.nodes_added = set()

        # map which nodes were visited at which step in episode
        visited_node_ids = tree.graph["visited_node_ids"]
        self.visited_nodes_to_step_idx = {node: idx for idx, node in enumerate(visited_node_ids)}
        self.step_idx_to_visited_nodes = dict(enumerate(visited_node_ids))
        node_step_idx_to_child_nodes = {step_idx: None for step_idx in self.step_idx_to_visited_nodes}

        # A missing discount factor means "no discounting"; without this the
        # exponentiation below raises TypeError for a default-constructed agent.
        gamma = 1.0 if self.gamma is None else self.gamma

        # Current time complexity is: O(|T|*2**n_step)
        # It should be possible to do it in O(|T|).
        for node in tree.nodes:
            if node not in self.visited_nodes_to_step_idx:
                continue
            node_step_idx = self.visited_nodes_to_step_idx[node]
            node_step_idx_to_child_nodes[node_step_idx] = self._collect_transition(node, gamma)
            self.nodes_added.add(node)

        assert len(self.nodes_added) == len(self.visited_nodes_to_step_idx)

        return node_step_idx_to_child_nodes

    def _collect_transition(self, node, gamma) -> Dict:
        """Build the 1-step (and, if ``n_step > 1``, n-step) transition starting at ``node``."""
        tree = self.search_tree.tree
        step_of = self.visited_nodes_to_step_idx

        reward, depth = 0, 0
        one_step_children, n_step_children = [], []
        parents = [node]
        while parents:
            depth += 1
            next_parents = []
            for parent in parents:
                # only children that were actually visited take part in the MDP
                children = [child for child in tree.successors(parent) if child in step_of]
                # a node costs more the fewer visited children it has
                # (3 for a leaf, 2 with one child, 1 with two children)
                reward -= (3 - len(children)) * (gamma**depth)
                next_parents.extend(children)
            parents = next_parents

            # store 1-step transition
            if depth == 1:
                one_step_reward = reward
                one_step_children = parents
                if self.n_step == 1:
                    break
            # update n-step transition
            if depth <= self.n_step:
                n_step_reward = reward
                n_step_children = parents
                if depth == self.n_step:
                    break

        tree_transition = {
            "next_states": [step_of[child] for child in one_step_children],
            "reward": one_step_reward,
        }
        if self.n_step > 1:
            # store n-step transition
            tree_transition.update(
                {
                    "n_step_next_states": [step_of[child] for child in n_step_children],
                    "n_step_reward": n_step_reward,
                }
            )
        return tree_transition
