import numpy as np
from sklearn.cluster import KMeans
import pdb

class Node(object):
    """A tree node holding a running value estimate and visit statistics.

    Nodes are linked upwards through ``father`` and downwards through
    ``children``, which is a list of ``(action, Node)`` pairs.
    """

    def __init__(self, id, value):
        """Create a detached node.

        Args:
            id: Identifier stored as ``node_id``.  The parameter name shadows
                the builtin but is kept so existing callers keep working.
            value: Initial value estimate.
        """
        self.node_id = id
        self.children = []   # list of (action, Node) pairs
        self.value = value   # running mean of the returns seen through this node
        self.father = None   # parent node, or None when the node has no parent
        self.visit = 0       # number of backprop passes that touched this node
        self.depth = 0

    def cal_uct(self, c, test_mode=False):
        """Return the UCT score of this node.

        Args:
            c: Weight of the exploration term.
            test_mode: When true, return the plain value estimate without any
                exploration bonus.

        Returns:
            ``0`` for a node without a father; otherwise
            ``value + c * sqrt(log(father.visit) / (1 + visit))``, or just
            ``value`` in test mode or when the father has not been visited.
        """
        if self.father is None:
            return 0
        if test_mode:
            return self.value

        f_visit = self.father.visit
        if f_visit < 1:
            # log(0) is -inf and its square root is NaN, which would poison
            # any ranking built on this score; without parent statistics there
            # is no exploration bonus to add.
            return self.value
        return self.value + c * np.sqrt(np.log(f_visit) / (1 + self.visit))

    def backprop(self, value):
        """Record ``value`` at this node and refresh every ancestor.

        This node's value becomes the running mean including ``value``.  Each
        ancestor gets one more visit and its value is recomputed as the
        visit-weighted sum of its children's values divided by its own visit
        count.

        Args:
            value: The return to record at this node.
        """
        self.value = (self.value * self.visit + value) / (self.visit + 1)
        self.visit += 1

        node = self.father
        while node is not None:
            node.visit += 1
            # Only returns recorded through children enter the numerator; the
            # denominator is the ancestor's own visit count.
            weighted = sum(child.visit * child.value for _, child in node.children)
            node.value = weighted / node.visit
            node = node.father

    def __repr__(self):
        return str(self.value)

class Tree(object):
    """Tree of clustered states built from recorded episodes.

    States are mapped to integer cluster ids by a k-means model fitted in
    ``generate_predictor``.  Every tree node carries one cluster id, and every
    edge is labelled with the tuple of action entries taken at that step.
    Episode returns are pushed into the tree with ``Node.backprop``.
    """

    def __init__(self, max_seq_length, predictor_sample, tree_sample):
        """Store the configuration and create an empty tree.

        Args:
            max_seq_length: Stored as-is; not read by the methods of this class.
            predictor_sample: Stored as-is; not read by the methods of this class.
            tree_sample: Stored as-is; not read by the methods of this class.
        """
        # Presumably the UCT exploration weight used by callers; this class
        # itself never reads it.
        self.c = 0.1
        self.roots = []
        self.max_seq_length = max_seq_length
        self.max_depth = 20      # maximum number of transitions inserted per episode
        self.hyper_node = Node(-1, -1)  # common father of all roots
        self.candidate_nodes = []
        self.cluster_num = 400   # number of k-means clusters

        self.predictor_sample = predictor_sample
        self.tree_sample = tree_sample
        self.warm_up = True

    def _cluster_id(self, state):
        """Return the k-means cluster index of one state as a Python int."""
        # predict() returns a 1-element array; take the element explicitly
        # instead of relying on the deprecated int(array) conversion.
        return int(self.kmeans.predict(state.reshape(1, -1))[0])

    def _get_or_create_root(self, root_id):
        """Return the root with ``root_id``, creating and registering it if absent."""
        for root in self.roots:
            if root.node_id == root_id:
                return root
        root = Node(root_id, 0)
        root.father = self.hyper_node
        self.roots.append(root)
        return root

    @staticmethod
    def _get_or_create_child(parent, action, node_id):
        """Return the child of ``parent`` reached by ``action`` with ``node_id``.

        A new child one level deeper is created and attached when no existing
        child matches both the action and the cluster id.
        """
        for child_action, child in parent.children:
            if child_action == action and child.node_id == node_id:
                return child
        child = Node(node_id, 0)
        child.father = parent
        child.depth = parent.depth + 1
        parent.children.append((action, child))
        return child

    def generate_predictor(self, replay_buffer):
        """Fit the k-means state predictor on every sample in the buffer.

        The states of each episode are collected up to its first terminated
        step, at most ``max_depth`` per episode, and used to fit a ``KMeans``
        model with ``cluster_num`` clusters, stored as ``self.kmeans``.

        Args:
            replay_buffer: Object whose ``sample_all()`` returns
                ``(samples, index)``; ``samples`` is indexed by ``'state'`` and
                ``'terminated'`` and ``index`` is the number of episodes.
        """
        samples, index = replay_buffer.sample_all()
        states = samples['state'][:, :self.max_depth]
        terminated = samples['terminated'][:, :]

        nodes = []
        for episode_idx in range(index):
            for tran_idx in range(self.max_depth):
                if terminated[episode_idx, tran_idx] == 1:
                    break
                nodes.append(states[episode_idx, tran_idx].tolist())
        nodes = np.array(nodes)

        self.kmeans = KMeans(self.cluster_num, max_iter=100, n_init=5).fit(nodes)

    def generate_tree_from_buffer(self, replay_buffer):
        """Refit the predictor and rebuild the whole tree from the buffer.

        Any existing tree is discarded.  Each episode is inserted as a path
        starting at the root matching its first state, and the episode return
        (sum of the first reward component over the whole episode) is
        back-propagated from the last node reached.

        Args:
            replay_buffer: Object whose ``sample_all()`` returns
                ``(samples, index)``; ``samples`` is indexed by ``'state'``,
                ``'terminated'``, ``'reward'`` and ``'actions'``.
        """
        self.generate_predictor(replay_buffer)

        self.hyper_node = Node(-1, -1)
        self.roots = []

        samples, index = replay_buffer.sample_all()

        states = samples['state'][:, :]
        terminated = samples['terminated'][:, :]
        rewards = samples['reward'][:, :]
        actions = samples['actions'][:, :]

        for episode_idx in range(index):
            n = self._get_or_create_root(self._cluster_id(states[episode_idx, 0]))

            for tran_idx in range(self.max_depth):
                if terminated[episode_idx, tran_idx] == 1:
                    break
                a = tuple(inner_a.item() for inner_a in actions[episode_idx, tran_idx])
                node_id = self._cluster_id(states[episode_idx, tran_idx + 1])
                n = self._get_or_create_child(n, a, node_id)

            # Builtin sum: np.sum over a generator is deprecated.
            episode_return = sum(r[0] for r in rewards[episode_idx].tolist())
            n.backprop(episode_return)
        print('Finish Generating')

    def update_tree(self, ep_batch):
        """Insert a batch of episodes into the existing tree.

        Requires ``self.kmeans`` to have been fitted already.

        Args:
            ep_batch: Mapping with ``'state'``, ``'reward'``, ``'actions'`` and
                ``'terminated'`` entries, indexed as ``[episode, step]``.  The
                state and action entries must provide ``.cpu()``.
        """
        states = ep_batch['state']
        rewards = ep_batch['reward']
        actions = ep_batch['actions']
        terminated = ep_batch['terminated']

        for tra_idx in range(len(states)):
            n = self._get_or_create_root(self._cluster_id(states[tra_idx, 0].cpu()))

            for tran_idx in range(self.max_depth):
                if terminated[tra_idx][tran_idx] == 1:
                    break

                a = tuple(inner_a.item() for inner_a in actions[tra_idx, tran_idx].cpu())
                node_id = self._cluster_id(states[tra_idx, tran_idx + 1].cpu())
                # The match must be against the candidate child's id (not the
                # current node's id), otherwise existing branches are never
                # reused and duplicates pile up.
                n = self._get_or_create_child(n, a, node_id)

            # Back-propagate the episode return from the last node reached.
            episode_return = sum(r[0] for r in rewards[tra_idx].tolist())
            n.backprop(episode_return)

    def generate_templates(self, states):
        """Compare each state sequence with the tree's greedy path.

        Starting at the root matching a sequence's first state, the child with
        the highest ``cal_uct(0)`` is followed while its cluster id equals the
        cluster id of the next state in the sequence.

        Args:
            states: Iterable of state sequences; each element is indexed by
                step and its entries must provide ``.cpu()``.

        Returns:
            A list with one int per sequence: ``idx - 1`` where ``idx`` is the
            step at which the greedy path first disagreed with the sequence,
            or ``0`` when no root matches or the path reaches a leaf without
            disagreement.
        """
        relay = [0] * len(states)
        for state_idx, s in enumerate(states):
            root_id = self.predict(s[0])
            node = next((r for r in self.roots if r.node_id == root_id), None)
            if node is None:
                continue

            idx = 1
            while node.children:
                # max() returns the first maximal child, matching a stable
                # descending sort followed by taking element 0.
                best = max((child for _, child in node.children),
                           key=lambda child: child.cal_uct(0))
                # NOTE(review): s[idx] is not bounds-checked; a tree path
                # longer than the sequence raises here -- confirm callers
                # always pass sequences at least as long as the tree is deep.
                if best.node_id != self.predict(s[idx]):
                    relay[state_idx] = idx - 1
                    break
                node = best
                idx += 1

        return relay

    def predict(self, state):
        """Return the cluster index of ``state`` (which must provide ``.cpu()``)."""
        return self._cluster_id(state.cpu())

    def nodes_in_tree(self):
        """Print the total number of nodes reachable from all roots."""
        node_num = 0
        for r in self.roots:
            node_num += self.dfs_cal_nodes(r)
        print(node_num)

    def dfs_cal_nodes(self, node):
        """Return the number of nodes in the subtree rooted at ``node``."""
        return 1 + sum(self.dfs_cal_nodes(child) for _, child in node.children)



        