import torch

import numpy as np
import core.ctree.cytree as tree
from core.utils import profile

from torch.cuda.amp import autocast as autocast

import time

class MCTS(object):
    """Batched Monte-Carlo tree search driven by a learned recurrent model.

    Tree bookkeeping (selection, expansion, back-propagation and min-max
    statistics) is delegated to the compiled ``tree`` module. This class
    collects the latent states of the leaves that ``tree`` selects in each
    simulation, evaluates them with ``model.recurrent_inference`` as one batch
    and feeds the outputs back to ``tree``.

    Node states live in pools addressed by ``(x, y)``: ``x`` is the simulation
    slot (slot 0 holds the roots, slot ``s + 1`` the nodes expanded during
    simulation ``s``) and ``y`` is the position inside that slot's batch.
    """

    def __init__(self, config, num_simulations, device=None):
        """
        Parameters
        ----------
        config: Any
            search configuration; this class reads ``pb_c_base``,
            ``pb_c_init``, ``discount``, ``value_delta_max``,
            ``lstm_horizon_len``, ``amp_type`` and (as a fallback) ``device``.
        num_simulations: int
            number of simulations (one batched model call each) per search.
        device: Any
            torch device for the model inputs; ``config.device`` is used when
            this is falsy.
        """
        self.config = config
        self.device = device or config.device
        self.num_simulations = num_simulations

    def reanalyze_search(self, roots, model, hidden_state_roots, reward_hidden_roots, future_actions, future_value_prefixes, to_numpy=True):
        """Do MCTS for the roots (a batch of root nodes in parallel), trusting stored future data where the search follows it.

        Runs the same pooled simulation loop as ``gpu_search``. In addition,
        every newly expanded node that is reached by following the stored
        action sequence from its root gets its predicted value prefix replaced
        by the stored value prefix before back-propagation.

        Parameters
        ----------
        roots: Any
            a batch of expanded root nodes
        model: Any
            provides ``recurrent_inference(hidden_states, (c, h), actions, to_numpy=...)``
        hidden_state_roots: torch.Tensor
            the hidden states of the roots, batch-first; its dtype is handed
            to ``torch.zeros``, so it must be a tensor
        reward_hidden_roots: tuple
            the (c, h) value-prefix LSTM states of the roots, tensors with the
            batch on dim 1
        future_actions: array-like
            2-D, indexed ``[batch entry, step]``: the stored action sequence to
            follow from each root; an entry equal to -1 matches any action
        future_value_prefixes: array-like
            2-D, indexed ``[batch entry, step]``: the value prefix assigned to
            a new node reached by following stored step ``step``
        to_numpy: bool
            forwarded to ``model.recurrent_inference``; when False the value
            prefix, value and policy outputs are treated as torch tensors and
            copied to host memory

        Returns
        -------
        tuple
            ``(cpu_to_gpu_time, gpu_to_cpu_time, model_inference_time)``,
            wall-clock seconds accumulated over all simulations
        """
        num = roots.num
        # np.shape reads ``.shape`` when present, so an empty (0, L) batch does
        # not fail the way indexing row 0 would
        future_length = np.shape(future_actions)[1]
        # future_id[x, y]: number of stored actions the node in pool slot x of
        # batch entry y has followed from its root, or -1 when it is not on
        # (or extends beyond) the stored trajectory
        future_id = np.full((self.num_simulations + 1, num), -1, dtype=np.int32)
        future_id[0, :] = 0
        # 1.0 replaces the model's value prefix with the stored one entirely
        truth_ratio = 1.

        def follow_future(parent_slots, new_slot, last_actions_np, value_prefix_pool):
            for iy, ix in enumerate(parent_slots):
                last_future_id = future_id[ix, iy]
                if not 0 <= last_future_id < future_length:
                    continue
                stored_action = future_actions[iy, last_future_id]
                if last_actions_np[iy] == stored_action or stored_action == -1:
                    # matched actions: the new node stays on the stored trajectory
                    future_id[new_slot, iy] = last_future_id + 1
                    # set to ground-truth value prefix if the branch is in replay buffer
                    value_prefix_pool[iy] = value_prefix_pool[iy] * (1 - truth_ratio) + future_value_prefixes[iy, last_future_id] * truth_ratio

        return self._pooled_search(roots, model, hidden_state_roots, reward_hidden_roots, to_numpy,
                                   before_backup=follow_future)

    def gpu_search(self, roots, model, hidden_state_roots, reward_hidden_roots, to_numpy=True):
        """Do MCTS for the roots (a batch of root nodes in parallel), keeping node states on ``self.device``.

        The latent and LSTM states of every node are written into tensors
        preallocated for all simulations, and the leaf inputs are read back
        with tensor indexing instead of per-leaf Python loops.

        NOTE(review): unlike ``search`` this does not call ``model.eval()``;
        presumably the caller sets the mode — confirm.

        Parameters
        ----------
        roots: Any
            a batch of expanded root nodes
        model: Any
            provides ``recurrent_inference(hidden_states, (c, h), actions, to_numpy=...)``
        hidden_state_roots: torch.Tensor
            the hidden states of the roots, batch-first; its dtype is handed
            to ``torch.zeros``, so it must be a tensor
        reward_hidden_roots: tuple
            the (c, h) value-prefix LSTM states of the roots, tensors with the
            batch on dim 1
        to_numpy: bool
            forwarded to ``model.recurrent_inference``; when False the value
            prefix, value and policy outputs are treated as torch tensors and
            copied to host memory

        Returns
        -------
        tuple
            ``(cpu_to_gpu_time, gpu_to_cpu_time, model_inference_time)``,
            wall-clock seconds accumulated over all simulations
        """
        return self._pooled_search(roots, model, hidden_state_roots, reward_hidden_roots, to_numpy)

    def search(self, roots, model, hidden_state_roots, reward_hidden_roots, to_numpy=True):
        """Do MCTS for the roots (a batch of root nodes in parallel). Parallel in model inference.

        Node states are kept in Python lists (one entry per simulation slot)
        and the leaf inputs are assembled per leaf. Puts ``model`` in eval mode.

        Parameters
        ----------
        roots: Any
            a batch of expanded root nodes
        model: Any
            provides ``recurrent_inference(hidden_states, (c, h), actions, to_numpy=...)``
        hidden_state_roots: Any
            the hidden states of the roots, indexable as ``hidden_state_roots[y]``
        reward_hidden_roots: tuple
            the (c, h) value-prefix LSTM states of the roots, each indexable as
            ``state[0][y]``
        to_numpy: bool
            when True the gathered leaf states are converted with
            ``np.asarray``/``torch.from_numpy`` and the model outputs are
            treated as numpy arrays; when False the states are stacked as
            tensors and the outputs copied to host memory

        Returns
        -------
        tuple
            ``(cpu_to_gpu_time, gpu_to_cpu_time, model_inference_time)``,
            wall-clock seconds accumulated over all simulations
        """
        gpu_to_cpu_time = 0
        cpu_to_gpu_time = 0
        model_inference_time = 0
        with torch.no_grad():
            model.eval()

            num = roots.num
            device = self.device
            pb_c_base, pb_c_init, discount = self.config.pb_c_base, self.config.pb_c_init, self.config.discount
            # one entry per simulation slot; slot 0 holds the roots
            hidden_state_pool = [hidden_state_roots]
            reward_hidden_c_pool = [reward_hidden_roots[0]]
            reward_hidden_h_pool = [reward_hidden_roots[1]]
            min_max_stats_lst = tree.MinMaxStatsList(num)
            min_max_stats_lst.set_delta(self.config.value_delta_max)
            horizons = self.config.lstm_horizon_len

            for index_simulation in range(self.num_simulations):
                # result wrapper transporting data between the python and c++ parts
                results = tree.ResultsWrapper(num)
                index_x_lst, index_y_lst, last_actions = tree.batch_traverse(
                    roots, pb_c_base, pb_c_init, discount, min_max_stats_lst, results)
                search_lens = results.get_search_len()

                # the state of each selected leaf's parent lives at pool[x][y]
                hidden_states = []
                hidden_states_c_reward = []
                hidden_states_h_reward = []
                for ix, iy in zip(index_x_lst, index_y_lst):
                    hidden_states.append(hidden_state_pool[ix][iy])
                    hidden_states_c_reward.append(reward_hidden_c_pool[ix][0][iy])
                    hidden_states_h_reward.append(reward_hidden_h_pool[ix][0][iy])

                time0 = time.time()
                if to_numpy:
                    hidden_states = torch.from_numpy(np.asarray(hidden_states)).to(device).float()
                    hidden_states_c_reward = torch.from_numpy(np.asarray(hidden_states_c_reward)).to(device).unsqueeze(0)
                    hidden_states_h_reward = torch.from_numpy(np.asarray(hidden_states_h_reward)).to(device).unsqueeze(0)
                else:
                    hidden_states = torch.stack(hidden_states, dim=0)
                    hidden_states_c_reward = torch.stack(hidden_states_c_reward, dim=0).unsqueeze(0)
                    hidden_states_h_reward = torch.stack(hidden_states_h_reward, dim=0).unsqueeze(0)
                actions = torch.from_numpy(np.asarray(last_actions)).long().to(device).unsqueeze(1)
                cpu_to_gpu_time += time.time() - time0

                time0 = time.time()
                network_output = self._recurrent_inference(
                    model, hidden_states, (hidden_states_c_reward, hidden_states_h_reward), actions, to_numpy)
                model_inference_time += time.time() - time0

                time0 = time.time()
                (hidden_state_nodes, reward_hidden_nodes,
                 value_prefix_pool, value_pool, policy_pool) = self._unpack_network_output(network_output, to_numpy)
                gpu_to_cpu_time += time.time() - time0

                hidden_state_pool.append(hidden_state_nodes)
                is_reset_lst = self._reset_reward_hidden(reward_hidden_nodes, search_lens, horizons, num)
                reward_hidden_c_pool.append(reward_hidden_nodes[0])
                reward_hidden_h_pool.append(reward_hidden_nodes[1])

                # the new nodes were stored in slot index_simulation + 1
                tree.batch_back_propagate(index_simulation + 1, discount,
                                          value_prefix_pool, value_pool, policy_pool,
                                          min_max_stats_lst, results, is_reset_lst)
        return cpu_to_gpu_time, gpu_to_cpu_time, model_inference_time

    def _pooled_search(self, roots, model, hidden_state_roots, reward_hidden_roots, to_numpy, before_backup=None):
        """Shared simulation loop of ``gpu_search`` and ``reanalyze_search`` over preallocated tensor pools.

        ``before_backup``, when given, is called as
        ``before_backup(parent_slots, new_slot, last_actions_np, value_prefix_pool)``
        after the new nodes are stored and before back-propagation; it may
        overwrite entries of ``value_prefix_pool`` in place.

        Returns ``(cpu_to_gpu_time, gpu_to_cpu_time, model_inference_time)``.
        """
        gpu_to_cpu_time = 0
        cpu_to_gpu_time = 0
        model_inference_time = 0
        with torch.no_grad():
            num = roots.num
            device = self.device
            pb_c_base, pb_c_init, discount = self.config.pb_c_base, self.config.pb_c_init, self.config.discount
            pools = self._allocate_pools(hidden_state_roots, reward_hidden_roots, device)
            hidden_state_pool, reward_hidden_c_pool, reward_hidden_h_pool = pools
            min_max_stats_lst = tree.MinMaxStatsList(num)
            min_max_stats_lst.set_delta(self.config.value_delta_max)
            horizons = self.config.lstm_horizon_len

            for index_simulation in range(self.num_simulations):
                # result wrapper transporting data between the python and c++ parts
                results = tree.ResultsWrapper(num)
                index_x_lst, index_y_lst, last_actions = tree.batch_traverse(
                    roots, pb_c_base, pb_c_init, discount, min_max_stats_lst, results)
                search_lens = results.get_search_len()

                time0 = time.time()
                index_x_np = np.asarray(index_x_lst)
                last_actions_np = np.asarray(last_actions)
                hidden_states, reward_hidden, actions = self._gather_leaf_inputs(
                    pools, index_x_np, np.asarray(index_y_lst), last_actions_np, device)
                cpu_to_gpu_time += time.time() - time0

                time0 = time.time()
                network_output = self._recurrent_inference(model, hidden_states, reward_hidden, actions, to_numpy)
                model_inference_time += time.time() - time0

                time0 = time.time()
                (hidden_state_nodes, reward_hidden_nodes,
                 value_prefix_pool, value_pool, policy_pool) = self._unpack_network_output(network_output, to_numpy)
                gpu_to_cpu_time += time.time() - time0

                new_slot = index_simulation + 1
                hidden_state_pool[new_slot] = hidden_state_nodes
                is_reset_lst = self._reset_reward_hidden(reward_hidden_nodes, search_lens, horizons, num)
                reward_hidden_c_pool[new_slot] = reward_hidden_nodes[0]
                reward_hidden_h_pool[new_slot] = reward_hidden_nodes[1]

                if before_backup is not None:
                    before_backup(index_x_np, new_slot, last_actions_np, value_prefix_pool)

                tree.batch_back_propagate(new_slot, discount,
                                          value_prefix_pool, value_pool, policy_pool,
                                          min_max_stats_lst, results, is_reset_lst)

        return cpu_to_gpu_time, gpu_to_cpu_time, model_inference_time

    def _allocate_pools(self, hidden_state_roots, reward_hidden_roots, device):
        """Preallocate one slot per simulation plus the roots for the latent and (c, h) LSTM states, roots in slot 0."""
        def make_pool(root):
            # torch.zeros needs a torch dtype, so the roots must be tensors here
            pool = torch.zeros((self.num_simulations + 1, *root.shape), dtype=root.dtype, device=device)
            pool[0] = root
            return pool

        return (make_pool(hidden_state_roots),
                make_pool(reward_hidden_roots[0]),
                make_pool(reward_hidden_roots[1]))

    @staticmethod
    def _gather_leaf_inputs(pools, index_x_np, index_y_np, last_actions_np, device):
        """Read the model inputs of the selected leaves out of the pools; returns ``(hidden_states, (c, h), actions)``."""
        hidden_state_pool, reward_hidden_c_pool, reward_hidden_h_pool = pools
        index_x = torch.from_numpy(index_x_np).to(device).long().reshape(-1)
        index_y = torch.from_numpy(index_y_np).to(device).long().reshape(-1)
        # index with the tree's (x, y) pair directly: honours y and works for
        # latent states of any rank
        hidden_states = hidden_state_pool[index_x, index_y]
        # reward pools are indexed (slot, layer, batch, ...): read layer 0 and
        # restore the leading dim, giving the (1, batch, ...) layout the model receives
        c_reward = reward_hidden_c_pool[index_x, 0, index_y].unsqueeze(0)
        h_reward = reward_hidden_h_pool[index_x, 0, index_y].unsqueeze(0)
        actions = torch.from_numpy(last_actions_np).long().to(device).unsqueeze(1)
        return hidden_states, (c_reward, h_reward), actions

    def _recurrent_inference(self, model, hidden_states, reward_hidden, last_actions, to_numpy):
        """Evaluate one batch of leaves, under ``autocast`` when ``config.amp_type == 'torch_amp'``."""
        if self.config.amp_type == 'torch_amp':
            with autocast():
                return model.recurrent_inference(hidden_states, reward_hidden, last_actions, to_numpy=to_numpy)
        return model.recurrent_inference(hidden_states, reward_hidden, last_actions, to_numpy=to_numpy)

    @staticmethod
    def _unpack_network_output(network_output, to_numpy):
        """Convert a recurrent-inference output into the inputs ``tree`` expects.

        Returns ``(hidden_state_nodes, reward_hidden_nodes, value_prefix_pool,
        value_pool, policy_pool)``: the first two exactly as produced by the
        model, the value prefixes and values as flat Python lists, and the
        policy as per-leaf lists of softmax probabilities.
        """
        value_prefix = network_output.value_prefix
        value = network_output.value
        policy_logits = network_output.policy_logits
        if not to_numpy:
            value_prefix = value_prefix.detach().cpu().numpy()
            value = value.detach().cpu().numpy()
            policy_logits = policy_logits.detach().cpu().numpy()
        # softmax over actions; subtracting the row max keeps np.exp from overflowing
        probs = np.exp(policy_logits - policy_logits.max(axis=-1).reshape(-1, 1))
        probs = probs / probs.sum(axis=-1).reshape(-1, 1)
        return (network_output.hidden_state, network_output.reward_hidden,
                value_prefix.reshape(-1).tolist(), value.reshape(-1).tolist(), probs.tolist())

    @staticmethod
    def _reset_reward_hidden(reward_hidden_nodes, search_lens, horizons, num):
        """Zero, in place, the (c, h) LSTM states of leaves whose search length is a multiple of ``horizons``.

        The value prefix only needs to be predicted within a window of
        ``horizons`` steps (eg: s0 -> s5), so the recurrent state restarts at
        each window boundary. Returns the per-leaf 0/1 reset flags for ``tree``.
        """
        assert horizons > 0
        reset_idx = (np.array(search_lens) % horizons == 0)
        assert len(reset_idx) == num
        reward_hidden_nodes[0][:, reset_idx, :] = 0
        reward_hidden_nodes[1][:, reset_idx, :] = 0
        return reset_idx.astype(np.int32).tolist()
