import time
import numpy as np
from models.high_level_dynamics.trajectory.search import make_prefix, beam_plan_one_step, beam_plan_one_play

import torch
# Run the models on the GPU when CUDA is available, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

class WILLOW_BERRY():
    """One node ("berry") of the search tree.

    A berry stores the state it represents, its current value estimate and,
    for each of its ``n_child`` possible children, a visit count, the last
    value reported for that child and the child's index in the owner's node
    list (``-1`` until ``link_child`` stores a real index).
    """

    def __init__(self, b_value, b_state, n_child, parent_idx=-1):
        """Create a berry.

        Args:
            b_value: Initial value estimate; also kept as a lower bound for
                every later value update.
            b_state: State associated with this berry.
            n_child: Number of child slots to allocate.
            parent_idx: Index of the parent berry (``-1`` by default).
        """
        self.parent_idx = parent_idx
        self.n_child = n_child
        self.berry_state = b_state

        # The value starts at the initial estimate; no rollout result yet.
        self.init_value = b_value
        self.berry_value = b_value
        self.rollout_value = 0.0

        child_tables = self.make_child_info()
        self.child_count, self.child_value, self.child_idx = child_tables

    def make_child_info(self):
        """Return fresh (visit counts, values, linked indices) child arrays."""
        visit_counts = np.zeros(self.n_child)
        last_values = np.zeros(self.n_child)
        linked_indices = np.full(self.n_child, -1.0)  # -1 means "not linked"
        return visit_counts, last_values, linked_indices

    def get_value(self):
        """Return the current value estimate of this berry."""
        return self.berry_value

    def get_state(self):
        """Return a copy of the stored state so callers cannot mutate it."""
        return np.copy(self.berry_state)

    def update_value(self):
        """Recompute the value as the best of three candidates.

        The candidates are the initial estimate, the best child value
        discounted by 0.95, and the stored rollout value.
        """
        discounted_best_child = 0.95 * self.child_value.max()
        candidates = (self.init_value, discounted_best_child, self.rollout_value)
        self.berry_value = max(candidates)

    def update_rollout_value(self, reward):
        """Store ``reward`` as this berry's rollout value."""
        self.rollout_value = reward

    def get_info(self):
        """Return the (visit counts, values, linked indices) child arrays."""
        return self.child_count, self.child_value, self.child_idx

    def get_child_index(self, child_idx):
        """Return the linked node index of child slot ``child_idx`` as an int."""
        linked = self.child_idx[child_idx]
        return int(linked)

    def get_child_values(self):
        """Return a UCB-style score for every child slot.

        The score is ten times the stored child value plus an exploration
        bonus that shrinks as the slot's visit count grows.
        """
        visits = self.child_count.copy()
        values = self.child_value.copy()
        # Clamp the total to at least 1 so the logarithm is never negative.
        total_visits = max(visits.sum(), 1)

        exploration = np.sqrt(2.0*np.log(total_visits)/(visits+1e-3))
        return values*10. + exploration

    def link_child(self, child_idx, berry_idx):
        """Record that child slot ``child_idx`` is the node ``berry_idx``."""
        self.child_idx[child_idx] = berry_idx

    def update_child(self, child_idx, child_value):
        """Count one visit of a child, overwrite its value and refresh ours."""
        self.child_value[child_idx] = child_value
        self.child_count[child_idx] += 1
        self.update_value()
        
    
              
class WILLOW():
    """Tree-search planner over a playbook of candidate plays.

    The plays are grouped by "dynamics": group ``d`` owns
    ``dynamics_indicators[d]`` consecutive child slots of every berry, its own
    sequence model (``dynamics_predictors[d]``), discretizer
    (``discretizers[d]``), value model (``distance_metrics[d]``) and play
    table (``playbook[d]``, indexed as a 2-D array of plays).

    A *global* play index addresses a child slot across all groups; a *local*
    play index addresses a row of one group's play table.
    """

    # Context length handed to the sequence-model planners.
    MAX_CONTEXT_TRANSITIONS = 2

    def __init__(self, init_state, goal_state, playbook,
                 dynamics_predictors, dynamics_indicators, discretizers, distance_metrics, max_depth):
        """Store the planning problem and create the root berry.

        Args:
            init_state: Start state; converted to a numpy array.
            goal_state: Goal state; converted to a numpy array.
            playbook: Per-dynamics tables of candidate plays.
            dynamics_predictors: Per-dynamics sequence models.
            dynamics_indicators: Number of plays owned by each dynamics group.
            discretizers: Per-dynamics discretizers.
            distance_metrics: Per-dynamics value models exposing ``get_qvalue2``.
            max_depth: Maximum depth a single search iteration may reach.
        """
        self.init_state = np.array(init_state)
        self.goal_state = np.array(goal_state)
        self.playbook = playbook
        self.dynamics_indicators = np.array(dynamics_indicators)
        self.dynamics_predictors = dynamics_predictors
        self.discretizers = discretizers
        self.distance_metrics = distance_metrics
        # A plain int keeps the child-array allocation valid even when the
        # indicators were supplied as floats.
        self.num_plays = int(np.sum(self.dynamics_indicators))
        self.max_depth = max_depth
        self.searched_depth = 0
        # Defined up front so get_search_time()/print_mcts_result() are safe
        # to call before the first search.
        self.search_time = 0.0

        # make root node (from the array form, so list inputs work as well)
        root_q = self.measure_distance(self.init_state, 0)
        root_berry = WILLOW_BERRY(root_q, self.init_state, self.num_plays)
        # make node list
        self.berry_list = [root_berry]

    def do_mcts(self, n_tree_iter, n_k_iter=1, print_info=True):
        """Run ``int(n_tree_iter)`` search iterations from the root.

        Each iteration descends through already-visited children until it
        either creates a new berry (followed by a rollout) or exceeds
        ``max_depth``, then backpropagates values along the visited path.
        A progress line is rewritten in place after every iteration.

        Args:
            n_tree_iter: Number of iterations (truncated with ``int``).
            n_k_iter: Currently unused; kept for interface compatibility.
            print_info: If true, leave the final progress line on screen;
                otherwise blank the line.
        """
        total_iters = int(n_tree_iter)
        num_bars = 50
        # Stays empty when no iteration runs, so the final print is skipped
        # instead of failing on an undefined name.
        print_line = ''
        search_time_s = time.time()

        for n in range(total_iters):
            current_berry, current_idx = self.berry_list[0], 0
            current_state = np.copy(self.init_state)
            current_depth = 1

            b_history = []
            while True:
                # choose action
                play_idx, new_one = self.choose_dynamics(current_berry, deterministic=False)
                b_history.append([current_berry, play_idx])

                # get next state
                if new_one:  # expansion: predict the state and make a node
                    next_state = self.update_state(current_state, play_idx)
                    next_berry, next_idx = self.add_berry(next_state, current_berry, current_idx, play_idx)
                else:  # selection: follow the existing child
                    next_idx = current_berry.get_child_index(play_idx)
                    assert next_idx >= 0
                    next_berry = self.berry_list[next_idx]
                    next_state = next_berry.get_state()

                # update information
                current_berry, current_idx, current_state = next_berry, next_idx, next_state
                current_depth += 1

                if current_depth > self.max_depth:
                    break

                # MCTS: ROLLOUT (only right after a new berry was created)
                if new_one:
                    reward = self.perform_rollout(current_state, current_depth, play_idx)
                    current_berry.update_rollout_value(reward)
                    current_berry.update_value()
                    break

            # MCTS: BACKPROPAGATION (the leaf is appended with a dummy play)
            b_history.append([current_berry, -1])
            self.berry_backpropagation(b_history)

            root_value = self.berry_list[0].get_value()
            self.searched_depth = max(self.searched_depth, current_depth)

            # Progress is measured against the iterations actually executed.
            progress_ = int((n+1)/total_iters*num_bars)
            percent_ = int((n+1)/total_iters*100)

            print_line = '  [MCTS][Progress {}{}:{:3d}%] Root-Berry {:.4f} | Max-Depth {}     '\
                .format('█'*progress_, ' '*(num_bars-progress_), percent_, root_value, self.searched_depth)
            print(print_line, end='\r')

        if print_info:
            if print_line:
                print(print_line)
        else:
            print(" "*120, end='\r')
        self.search_time = time.time() - search_time_s

    def optimal_selection(self):
        """Return the best root play and the dynamics group it belongs to.

        Returns:
            ``(play, dynamics_idx)`` where ``play`` is the playbook row of the
            root child with the highest stored value.
        """
        _, r_values, _ = self.berry_list[0].get_info()
        dynamics_idx, local_idx = self._split_play_index(np.argmax(r_values))
        play = self.playbook[dynamics_idx][local_idx]
        return play, dynamics_idx

    def add_berry(self, state, parent_berry, parent_index, play_index):
        """Create a berry for ``state`` and link it under ``parent_berry``.

        Args:
            state: State of the new berry.
            parent_berry: Berry the new one hangs below.
            parent_index: Index of ``parent_berry`` in ``berry_list``.
            play_index: Global play index (child slot) leading to the berry.

        Returns:
            ``(berry, berry_index)`` with ``berry_index`` its position in
            ``berry_list``.
        """
        q_val = self.measure_distance(state, play_index)
        berry = WILLOW_BERRY(q_val, state, self.num_plays, parent_index)
        # make node list
        berry_index = len(self.berry_list)
        self.berry_list.append(berry)
        # update parent node's information
        parent_berry.link_child(play_index, berry_index)
        return berry, berry_index

    def choose_dynamics(self, berry, deterministic=False):
        """Pick a dynamics group for ``berry`` and the play to try within it.

        The group is chosen by the best child value per group
        (``deterministic``) or by that value plus a UCB-style exploration
        bonus; ties are broken uniformly at random. The play is the playbook
        entry closest to the play proposed by the group's sequence model.

        Returns:
            ``(berry_idx, new_one)``: the global play index and whether that
            child slot has never been visited.
        """
        child_counts, child_vals, _ = berry.get_info()
        offsets = self._play_offsets()
        group_bounds = list(zip(offsets[:-1], offsets[1:]))

        # Aggregate the per-child statistics over each dynamics group.
        cum_counts = np.array([np.sum(child_counts[lo:hi]) for lo, hi in group_bounds])
        cum_vals = np.array([np.amax(child_vals[lo:hi]) for lo, hi in group_bounds])

        if deterministic:
            berry_vals = cum_vals
        else:
            # Clamp to 1 so the logarithm is never negative.
            count_sum = max(np.sum(cum_counts), 1)
            berry_vals = cum_vals*0.2 + np.sqrt(2.0*np.log(float(count_sum))/(cum_counts+1e-3))

        max_is = np.flatnonzero(berry_vals == np.amax(berry_vals))
        assert len(max_is) > 0
        dynamics_idx = np.random.choice(max_is)

        local_idx = self._nearest_play_index(dynamics_idx, berry.get_state(), k_act=None)
        berry_idx = local_idx + offsets[dynamics_idx]

        new_one = bool(child_counts[berry_idx] == 0)
        return berry_idx, new_one

    def update_state(self, current_state, play_index):
        """Predict the state reached by applying a play to ``current_state``.

        Args:
            current_state: Current state; it is concatenated with the play
                along axis 1, so presumably shaped ``(1, obs_dim)`` -- confirm
                against the caller.
            play_index: Global play index of the play to apply.

        Returns:
            The discretizer's reconstruction of the predicted observation.
        """
        # select dynamics predictor and play
        dynamics_idx, local_idx = self._split_play_index(play_index)
        dynamics_model = self.dynamics_predictors[dynamics_idx]
        discretizer = self.discretizers[dynamics_idx]
        play = self.playbook[dynamics_idx][local_idx]

        # predict next state
        obs_dim, act_dim = current_state.shape[-1], self.playbook[dynamics_idx].shape[-1]
        x = np.concatenate((current_state, np.expand_dims(play, axis=0)), axis=1)
        x_disc = make_prefix(discretizer, [], x, False)

        sequence = beam_plan_one_step(
            dynamics_model, x_disc,
            obs_dim, act_dim, self.MAX_CONTEXT_TRANSITIONS,
            k_obs=1, cdf_obs=None, device=device,
        )
        return discretizer.reconstruct(sequence[-obs_dim:], [0, obs_dim])

    def measure_distance(self, current_state, prev_play):
        """Return the value model's estimate for ``current_state``.

        The dynamics group is the one owning ``prev_play``. That group's
        sequence model proposes a play, the closest playbook entry is looked
        up, and the group's value model is queried with the state, the goal
        and that play.

        Args:
            current_state: State to evaluate.
            prev_play: Global play index used to select the dynamics group.

        Returns:
            First element of the value model's output as a numpy value.
        """
        dynamics_idx, _ = self._split_play_index(prev_play)
        local_idx = self._nearest_play_index(dynamics_idx, current_state, k_act=1)
        play = np.expand_dims(self.playbook[dynamics_idx][local_idx], axis=0)

        rl_model = self.distance_metrics[dynamics_idx]
        V_t = rl_model.get_qvalue2(current_state, self.goal_state, play)
        return V_t.detach().cpu().numpy()[0]

    def perform_rollout(self, current_state, current_depth, prev_play):
        """Return the rollout reward; currently a placeholder that is always 0.0."""
        return 0.0

    def berry_backpropagation(self, berry_list):
        """Propagate values from the leaf back to the root.

        Args:
            berry_list: Visited ``[berry, play_index]`` pairs in root-to-leaf
                order; the last entry is the leaf (its play index is ignored).
        """
        prev_berry, _ = berry_list[-1]
        for berry, c_idx in reversed(berry_list[:-1]):
            berry.update_child(c_idx, prev_berry.get_value())
            prev_berry = berry

    def get_search_time(self):
        """Return the duration of the last search in seconds (0.0 before any)."""
        return self.search_time

    def print_mcts_result(self):
        """Print the best root child, its value / visit count and the search time."""
        c_values, r_values, _ = self.berry_list[0].get_info()
        child_idx = np.argmax(r_values)
        print("  [MCTS-Finished] Willow chooses {}-{} Berry. :D".format(
            child_idx, self._ordinal_suffix(child_idx)))
        print("                  BEST: {:.4f} / {} ".format(r_values[child_idx],int(c_values[child_idx])))
        print("                  Planning-Time: {:.3f}s".format(self.search_time))
        print()

    def _play_offsets(self):
        """Return integer start offsets of every dynamics group, plus the total."""
        cum_indicator = np.cumsum(self.dynamics_indicators)
        return np.concatenate(([0], cum_indicator)).astype(np.int64)

    def _split_play_index(self, play_index):
        """Map a global play index to ``(dynamics_idx, local play index)``."""
        offsets = self._play_offsets()
        # First group whose cumulative play count exceeds the index.
        dynamics_idx = np.argwhere(offsets[1:] > play_index)[0][0]
        return dynamics_idx, play_index - offsets[dynamics_idx]

    def _nearest_play_index(self, dynamics_idx, state, k_act):
        """Return the local index of the playbook entry closest to the model's play.

        The group's sequence model plans one play from ``state``; the result
        is reconstructed by the discretizer and matched to the playbook by
        squared Euclidean distance.
        """
        discretizer = self.discretizers[dynamics_idx]
        dynamics_model = self.dynamics_predictors[dynamics_idx]
        plays = self.playbook[dynamics_idx]
        obs_dim, act_dim = state.shape[-1], plays.shape[-1]

        state_disc = make_prefix(discretizer, [], state, False)
        sequence = beam_plan_one_play(
            dynamics_model, state_disc,
            obs_dim, act_dim, self.MAX_CONTEXT_TRANSITIONS,
            k_act=k_act, cdf_act=None, device=device,
        )
        play_recon = discretizer.reconstruct(sequence[-act_dim:], [obs_dim, obs_dim+act_dim])

        dist_w = np.sum(np.square(plays - play_recon), axis=1)
        return np.argmin(dist_w)

    @staticmethod
    def _ordinal_suffix(number):
        """Return the English ordinal suffix ("st", "nd", "rd" or "th")."""
        number = int(number)
        # 10..20 (and 110..120, ...) always take "th", e.g. 11-th, 12-th.
        if 10 <= number % 100 <= 20:
            return "th"
        return {1: "st", 2: "nd", 3: "rd"}.get(number % 10, "th")
    
    
    
