import math
import numpy as np
from .algorithm import Algorithm
from .utils import common
from .utils import tree


# Module-wide switch: when True, the functions and methods in this file
# print verbose trace output via `if DEBUG: print(...)`.
DEBUG = False


def update_single_obs(tree_list, action, reward, effort_levels, inherit_flag):
    """
    Record a single (location, effort) observation in the matching tree.

    The effort in `action` is snapped to the nearest entry of `effort_levels`;
    the index of that entry selects which tree in `tree_list` is updated.
    The reward is added to the active node containing the location, and the
    node is split if it reports that it is ready.

    Same function procedure used to update online tree as dataset tree

    Args:
        tree_list: sequence of trees, indexed in the same order as `effort_levels`
        action: (loc, effort) pair
        reward: observed reward for this action
        effort_levels: array of the discrete effort values
        inherit_flag: forwarded to `tree_split_node` when a split happens
    """
    location, chosen_effort = action

    # distance from the played effort to every discrete level; the closest one wins
    gaps = np.abs(effort_levels - np.asarray(chosen_effort))
    level_tree = tree_list[np.argmin(gaps)]

    # active node covering this location (the accompanying value is not needed here)
    active_node, _ = level_tree.get_active_ball(location)

    # add_obs reports whether the node should now be split
    if active_node.add_obs(location, reward):
        level_tree.tree_split_node(active_node, inherit_flag)


def update_obs(tree_list, actions, rewards, effort_levels, inherit_flag):
    """
    Loop through each location + effort value in the combinatorial action
    to update reward estimates and num visits

    Same function procedure used to update online tree as dataset tree

    Args:
        tree_list: sequence of trees, one per effort level
        actions: sequence of (loc, effort) pairs
        rewards: indexable of rewards, where rewards[i] belongs to actions[i]
        effort_levels: array of the discrete effort values
        inherit_flag: forwarded to update_single_obs for node splits
    """
    if DEBUG: print(f'Observed reward: {rewards} and action: {actions}')

    for i, action in enumerate(actions):
        if DEBUG:
            # unpack here so the trace has names to print; previously `loc`
            # and `effort` were undefined and this line raised NameError
            loc, effort = action
            print(f'Updating observation for {loc} and effort {effort}')
        update_single_obs(tree_list, action, rewards[i], effort_levels, inherit_flag)


class AdaptiveDiscretization(Algorithm):
    """
    Adaptive discretization algorithm implemented for green-security environments
    with continuous domain and finite effort levels, with a discretization induced by l_inf metric

    One tree over the location space is kept per effort level. To pick an
    action the per-level trees are merged into a single tree, a UCB index is
    computed for every (leaf, effort level) pair, and the indices are passed
    to `common.solve_exploit`.

    Extra Attributes:
        T: (int) number of timesteps

        epsilon: value used to compute the optimal action; also caps tree depth

        max_depth: (int) floor(log_{1/2}(epsilon)), passed to every tree

        inherit_flag: boolean of whether to inherit when making children nodes

        tree_list: list of tree.Tree, one per entry of self.effort_levels

    NOTE(review): self.effort_levels, self.N, self.budget and self.env are not
    set in this class; presumably they come from Algorithm.__init__ -- confirm.
    """


    def __init__(self, env, T, epsilon, inherit_flag):
        super().__init__(env)

        self.T = T
        self.inherit_flag = inherit_flag

        # store the epsilon value used to compute the optimal action, and ensure that this tree doesn't discretize beyond that
        self.epsilon = epsilon
        self.max_depth = math.floor(math.log(epsilon, 0.5))
        print(f'  epsilon {epsilon}, max depth {self.max_depth}')

        # Create a list of discretizations of the location space for each possible effort level
        self.tree_list = self._build_trees()


    def _build_trees(self):
        """ Returns a list with one freshly constructed tree per effort level """
        return [tree.Tree(effort, 2, self.max_depth) for effort in self.effort_levels]


    def reset(self):
        """ Resets the agent: zeroes the regret counters and rebuilds every tree from scratch """
        self.regret = 0 # resets the estimates back to zero
        self.regret_iterations = 0
        self.tree_list = self._build_trees()


    def update_obs(self, action, reward):
        """ Feeds a combinatorial action and its rewards into this agent's trees (see module-level update_obs) """
        update_obs(self.tree_list, action, reward, self.env.effort_levels, self.inherit_flag)


    def update_single_obs(self, loc, effort, reward):
        """ Feeds one (loc, effort) observation into this agent's trees (see module-level update_single_obs) """
        update_single_obs(self.tree_list, (loc, effort), reward, self.env.effort_levels, self.inherit_flag)


    def get_tree_size(self):
        ''' return the number of nodes in the whole tree (summed across effort levels) '''
        # tree_list always holds exactly one tree per effort level (see _build_trees)
        return sum(level_tree.get_tree_size() for level_tree in self.tree_list)


    def pick_action(self, t):
        """
        Formulates the ILP for the knapsack problem
        and solves for the optimal action vector

        Args:
            t: current timestep, forwarded to common.conf_r

        Returns:
            list of (location, effort) pairs, one per entry of the action
            returned by common.solve_exploit
        """
        if DEBUG: print('Picking an action!')

        # Step 1: Generate a tree that we will use in order to optimize over
        # which is the "finest" discretization generated by all of the different trees for the separate effort levels
        # This is required in order to encode the constraint that each "location" can only have one "effort level" selected

        if DEBUG: print('Creating merged tree')
        merged_tree = tree.Tree(0, 2, self.max_depth) # make placeholder empty tree for merging
        self.merge_tree(merged_tree, merged_tree.head, [level_tree.head for level_tree in self.tree_list])

        if DEBUG: print('Finished merging tree')
        if DEBUG: print('Generating the estimates to solve the knapsack IP')

        # Step 2: one row of UCB indices per merged leaf, one entry per effort level.
        # merge_tree stores mean_val / num_visits / prev_ucb on merged nodes as
        # arrays with one entry per tree, so they can be indexed by effort level here.
        #
        # NOTE(review): prev_ucb on a merged node is an np.copy made in merge_tree,
        # and merged_tree is local to this call, so the write to node.prev_ucb[i]
        # below does not appear to reach the per-effort trees -- confirm whether
        # the monotone clamp is meant to persist across calls.
        region_ucbs = []
        for node in merged_tree.leaves:
            leaf_ucbs = []
            for i, mean in enumerate(node.mean_val):
                if DEBUG: print(f'Node bounds: {node.bounds}, estimates: {node.mean_val}')
                visits = node.num_visits[i]
                if visits >= 1:
                    # calculate UCB
                    ucb_estimate = mean + common.conf_r(self.T, t, visits)
                    # ensure monotone: never report more than the stored previous UCB
                    if ucb_estimate > node.prev_ucb[i]:
                        ucb_estimate = node.prev_ucb[i]
                    else:
                        node.prev_ucb[i] = ucb_estimate
                    leaf_ucbs.append(ucb_estimate)
                else:
                    # unvisited (leaf, effort) pairs get a fixed index of 1
                    leaf_ucbs.append(1)
            region_ucbs.append(leaf_ucbs)

        if DEBUG: print(f'UCBs of each region: {region_ucbs}')

        # Step 3: Solves for the optimal action based on these reward index values
        action, _ = common.solve_exploit(self.effort_levels, self.N, np.array(region_ucbs, dtype=object), self.budget)

        if DEBUG: print(f'Chosen action: {action}')
        # action is returned as a vector of length at most N of (loc index, effort level) pairs
        true_action = []
        for index, effort in action:
            effort_index = np.argmin(np.abs(self.env.effort_levels - np.array(effort)))

            # lower bounds + radius of the chosen merged leaf, used to look up
            # the active node in the tree of the chosen effort level
            leaf = merged_tree.leaves[index]
            loc = leaf.bounds[:,0] + leaf.radius
            node, _ = self.tree_list[effort_index].get_active_ball(loc)

            # sample the played location from within that node's region
            true_action.append((node.sample_point(), effort))
        if DEBUG: print(f'Final action: {true_action}')
        return true_action


    def merge_tree(self, merge_tree, new_node, node_list):
        """
        Recursively builds the merged tree below new_node from one node per effort-level tree

        new_node receives the stacked estimates of node_list. If any node in
        node_list has children, new_node is split and the procedure recurses
        into each child; nodes without children are represented one level down
        by a placeholder node carrying the same mean_val and num_visits.

        Args:
            merge_tree: the tree being built (new_node belongs to it)
            new_node: node of merge_tree to fill in
            node_list: one node per effort-level tree, in tree_list order
        """
        # NOTE: this is making mean_val, num_visits, and prev_ucb lists... whereas the default is to have them as floats
        new_node.mean_val   = np.copy([node.mean_val   for node in node_list])
        new_node.num_visits = np.copy([node.num_visits for node in node_list]) # updates current estimates for the node based on the list of nodes
        new_node.prev_ucb   = np.copy([node.prev_ucb   for node in node_list]) # updates current estimates for the node based on the list of nodes
        if DEBUG: print(f'Updating current node estimates at location {new_node.bounds, new_node.radius}: {new_node.mean_val}, {new_node.num_visits}')

        if all(len(node.children) == 0 for node in node_list): # current set of nodes have no more children
            if DEBUG: print('No more children, done merging this branch')
            return # finished merging this branch

        if DEBUG: print('Recursive call!')
        # NOTE(review): split is requested without an inherit_flag argument here,
        # unlike the call in update_single_obs -- confirm the default is intended.
        merge_tree.tree_split_node(new_node) # splits the new node in the merged tree
        if DEBUG: print(f'Number of children: {len(new_node.children)}')

        for index, child in enumerate(new_node.children):
            # creates potentially fake list of node with no children and same estimate
            # NOTE(review): the placeholder is built from mean_val and num_visits
            # only; its prev_ucb is whatever tree.Node initialises it to.
            new_node_list = [
                tree.Node(0, child.bounds, child.depth, self.max_depth, node.mean_val, node.num_visits)
                if len(node.children) == 0
                else node.children[index]
                for node in node_list
            ]
            if DEBUG: print(f'Length of new_node_list: {len(new_node_list)}')
            self.merge_tree(merge_tree, child, new_node_list) # recurse on the subtree
