import torch
import math
import numpy as np
import random
import copy

dev = torch.device('cpu')

"""
    Tree Search Algorithm:
        takes as input a state and an actor-critic
         outputs improved decision probabilities with a probabilistic tree-search over simulated trajectories
"""
def tree_search(actor_critic, simulation_state, n_actions, n_steps, T=1., C=1.):
    """Improve the policy at ``simulation_state`` with a guided tree search.

    Each of the ``n_steps`` iterations selects a path from the root down to
    an action that has no child yet, expands that action (unless the node
    reached is terminal), and backs the undiscounted return up along the path.

    Args:
        actor_critic: Callable returning a ``(p, v)`` pair for a state tensor;
            ``p`` holds the action probabilities and ``v`` a single-element
            value estimate. ``actor_critic.eval()`` is called on it.
        simulation_state: Root state. This function calls its
            ``observe_state()`` and ``finished()`` methods; expansion is
            delegated to ``Search_Node.expand``.
        n_actions: Number of actions available in every node.
        n_steps: Number of select/expand/backup iterations.
        T: Temperature applied to the root visit counts.
        C: Weight of the exploration term used during selection.

    Returns:
        A tensor of ``n_actions`` probabilities proportional to
        ``N(root, a) ** (1 / T)``.

    Raises:
        ZeroDivisionError: If ``T`` is 0 or no iteration was run
            (``n_steps == 0``), since the visit counts then sum to zero.
    """
    actor_critic.eval()

    # The search only needs forward passes; skipping autograd avoids building
    # (and retaining) a graph for every expanded node.
    with torch.no_grad():
        p, _ = actor_critic(torch.tensor(simulation_state.observe_state(), device=dev))
        root_node = Search_Node(p, simulation_state, reward=0, n_actions=n_actions)

        for _ in range(n_steps):
            # Select
            Q_min, Q_max = root_node.min_max_Q(0., 1.)
            trace = root_node.select(Q_min, Q_max, C)
            _, last_action, last_node = trace[-1]

            # Expand
            if last_node.simulation_state.finished():
                # Terminal node: nothing follows it, so both the bootstrap
                # value and the "next" reward are zero. Its own reward is
                # credited to its parent further down the backup loop.
                G = 0.
                reward_ = 0.
            else:
                child, v = last_node.expand(last_action, actor_critic)
                # Keep plain floats in the tree statistics.
                G = float(v)
                reward_ = child.R

            # Backup: walk the path leaf-to-root, accumulating the return.
            for reward, action, node in reversed(trace):
                G = reward_ + G
                node.update(action, G)
                reward_ = reward

    # Compute root node probabilities
    weights = [n ** (1 / T) for n in root_node.N]
    sum_N = sum(weights)
    return torch.tensor([w / sum_N for w in weights])


class Search_Node:
    """One node of the search tree: a simulated state and per-action statistics.

    Attributes (every list is indexed by action):
        Q: Running mean of the returns backed up through each action.
        N: Number of times each action was backed up.
        P: Prior action probabilities taken from the policy output ``p``.
        R: Reward stored for this node at construction time.
        simulation_state: The simulated state this node represents.
        children: Child node per action, ``None`` until that action is expanded.
        n_actions: Number of actions.
    """

    def __init__(self, p, simulation_state, reward, n_actions):
        """Create a node with zeroed statistics.

        Args:
            p: Action probabilities; must provide ``tolist()`` and hold at
                least ``n_actions`` entries.
            simulation_state: State represented by this node.
            reward: Reward associated with this node.
            n_actions: Number of actions.
        """
        priors = p.tolist()
        self.Q = [0.] * n_actions
        self.N = [0] * n_actions
        self.P = [priors[i] for i in range(n_actions)]
        self.R = reward
        self.simulation_state = simulation_state
        self.children = [None] * n_actions
        self.n_actions = n_actions

    def select(self, Q_min, Q_max, C):
        """Descend from this node until an unexpanded action is chosen.

        At every node the action maximising the normalised Q value plus the
        prior-weighted exploration bonus is taken; ties are broken at random.

        Args:
            Q_min: Smallest estimated Q value in the tree.
            Q_max: Largest estimated Q value in the tree (must differ from
                ``Q_min``).
            C: Weight of the exploration bonus.

        Returns:
            The visited path as a list of ``(node.R, action, node)`` tuples,
            root first; the last action has no child yet.
        """
        spread = Q_max - Q_min
        trace = []
        node = self
        # Iterative descent: no recursion limit and no repeated list copies.
        while node is not None:
            sqrt_plays = math.sqrt(sum(node.N))
            ucb = [
                (node.Q[a] - Q_min) / spread
                + C * node.P[a] * sqrt_plays / (1 + node.N[a])
                for a in range(node.n_actions)
            ]
            a = random.choice(argmax(ucb))
            trace.append((node.R, a, node))
            node = node.children[a]
        return trace

    def expand(self, a, actor_critic):
        """Create the child reached by taking action ``a`` from this node.

        The state is deep-copied before stepping so this node's state is left
        untouched.

        Args:
            a: Action to simulate.
            actor_critic: Callable returning ``(p, v)`` for a state tensor.

        Returns:
            ``(new_node, v)``: the new child and the critic's value estimate
            for the state it represents.
        """
        new_simulation_state = copy.deepcopy(self.simulation_state)
        r, s = new_simulation_state.step(a)

        p, v = actor_critic(torch.tensor(s, device=dev))
        new_node = Search_Node(p, new_simulation_state, r, self.n_actions)

        self.children[a] = new_node

        return new_node, v

    def update(self, a, G):
        """Fold the return ``G`` into the running mean for action ``a``.

        ``G`` is converted to a plain float first, so the stored statistics
        never hold tensors (or the autograd history attached to them).
        """
        G = float(G)
        visits = self.N[a]
        self.Q[a] = (visits * self.Q[a] + G) / (visits + 1)
        self.N[a] = visits + 1

    def min_max_Q(self, current_Q_min, current_Q_max):
        """Return the smallest and largest Q value in this subtree.

        Args:
            current_Q_min: Lower bound found so far.
            current_Q_max: Upper bound found so far.

        Returns:
            ``(Q_min, Q_max)`` widened by this node and all expanded
            descendants.
        """
        Q_min = min(current_Q_min, *self.Q)
        Q_max = max(current_Q_max, *self.Q)
        for c in self.children:
            if c is not None:
                Q_min, Q_max = c.min_max_Q(Q_min, Q_max)

        return Q_min, Q_max

    def print(self, depth):
        """Print ``str(self)`` and, recursively, the expanded children.

        Each level is indented with ``depth`` tab characters.
        """
        pad = '\t' * depth
        print(pad + str(self))
        children = [c for c in self.children if c is not None]
        if children:
            print(pad + "children : ")
            for c in children:
                c.print(depth + 1)
                print()


def argmax(tab):
    """Return the indices of every occurrence of the largest value in ``tab``.

    Args:
        tab: Iterable of mutually comparable numbers.

    Returns:
        The list of indices holding the maximum, in increasing order, or
        ``None`` when no index qualifies (e.g. ``tab`` is empty).
    """
    best = -math.inf
    indices = []
    for i, v in enumerate(tab):
        if v > best:
            best = v
            indices = [i]
        elif v == best:
            # Also reached when v is -inf before any larger value was seen,
            # so the index list must already exist at this point.
            indices.append(i)

    return indices or None
