#include "cnode.h"

#include <cmath>
#include <stack>
#include <sys/time.h>
#include <map>
#include <cstdlib>
#include <cstring>
#include <omp.h>
#include <optional>
#include <random>
#include <limits>
#include <unordered_map>


namespace tree
{
    // Builds a fresh node with no visits and no children.
    // The hidden-state index starts at -1 and reward / predicted value start at zero;
    // the probability-like arguments are stored unchanged, and `rho` / `lam` are
    // forwarded to the node's subtree statistics object.
    CNode::CNode(float prior, float pred_prob, float beta, float beta_hat, bool is_root, float rho, float lam)
        : visit_count(0),
          num_children(0),
          hidden_state_index_x(-1),
          reward(0.0f),
          pred_value(0.0f),
          prior(prior),
          pred_prob(pred_prob),
          beta(beta),
          beta_hat(beta_hat),
          is_root(is_root),
          subtree_info(rho, lam),
          children(),
          children_action()
    {
    }

    // No manual cleanup is performed here; member destructors do the work.
    CNode::~CNode() = default;

    // A node is considered expanded once it has at least one child.
    bool CNode::expanded()
    {
        const bool has_children = (this->num_children > 0);
        return has_children;
    }

    // Returns the subtree value estimate, or 0 for a node that has not been expanded yet.
    float CNode::value()
    {
        if (this->expanded())
        {
            return this->subtree_info.value_estimation();
        }
        return 0;
    }

    // Q-value of reaching this node: its reward plus the discounted node value.
    float CNode::get_qsa(float discount)
    {
        const float node_value = this->value();
        return this->reward + discount * node_value;
    }

    // Adds every child's visit count into logits(j, a), where `a` is the j-th component of
    // that child's stored action. Entries are accumulated, not reset, so the caller is
    // expected to provide a suitably initialised buffer.
    void CNode::get_marginal_visit_count(tools::Array2D<int> logits)
    {
        const int child_total = this->num_children;
        for (int c = 0; c < child_total; ++c)
        {
            const auto &child_action = this->children_action[c];
            const auto visits = this->children[c]->visit_count;
            for (size_t row = 0; row < logits.d1; ++row)
            {
                logits(row, child_action[row]) += visits;
            }
        }
    }

    // Adds every child's prior into priors(j, a), where `a` is the j-th component of that
    // child's stored action. Entries are accumulated, not reset.
    void CNode::get_marginal_priors(tools::Array2D<float> priors)
    {
        const int child_total = this->num_children;
        for (int c = 0; c < child_total; ++c)
        {
            const auto &child_action = this->children_action[c];
            const auto child_prior = this->children[c]->prior;
            for (size_t row = 0; row < priors.d1; ++row)
            {
                priors(row, child_action[row]) += child_prior;
            }
        }
    }

    // Writes the visit count of child i into logits[i] for each of the num_children children.
    void CNode::get_sampled_visit_count(int *logits)
    {
        const int child_total = this->num_children;
        int slot = 0;
        while (slot < child_total)
        {
            logits[slot] = this->children[slot]->visit_count;
            ++slot;
        }
    }

    // Writes child i's pred_prob into probs[i] for each child.
    void CNode::get_sampled_pred_probs(float *probs)
    {
        const int child_total = this->num_children;
        int slot = 0;
        while (slot < child_total)
        {
            probs[slot] = this->children[slot]->pred_prob;
            ++slot;
        }
    }

    // Writes child i's beta into probs[i] for each child.
    void CNode::get_sampled_beta(float *probs)
    {
        const int child_total = this->num_children;
        int slot = 0;
        while (slot < child_total)
        {
            probs[slot] = this->children[slot]->beta;
            ++slot;
        }
    }

    // Writes child i's beta_hat into probs[i] for each child.
    void CNode::get_sampled_beta_hat(float *probs)
    {
        const int child_total = this->num_children;
        int slot = 0;
        while (slot < child_total)
        {
            probs[slot] = this->children[slot]->beta_hat;
            ++slot;
        }
    }

    // Writes child i's prior into priors[i] for each child.
    void CNode::get_sampled_priors(float *priors)
    {
        const int child_total = this->num_children;
        int slot = 0;
        while (slot < child_total)
        {
            priors[slot] = this->children[slot]->prior;
            ++slot;
        }
    }

    void CNode::get_sampled_imp_ratio(float *imp_ratio)
    {
        for (int i = 0; i < this->num_children; ++i)
        {
            imp_ratio[i] = this->children[i]->beta_hat / this->children[i]->beta * this->children[i]->pred_prob;
        }
    }

    // Writes child i's pred_value into values[i] for each child.
    void CNode::get_sampled_pred_values(float *values)
    {
        const int child_total = this->num_children;
        int slot = 0;
        while (slot < child_total)
        {
            values[slot] = this->children[slot]->pred_value;
            ++slot;
        }
    }

    void CNode::get_sampled_mcts_values(float *values)
    {
        for (int i = 0; i < this->num_children; ++i)
        {
            values[i] = this->children[i]->value();
        }
    }

    // Writes child i's reward into rewards[i] for each child.
    void CNode::get_sampled_rewards(float *rewards)
    {
        const int child_total = this->num_children;
        int slot = 0;
        while (slot < child_total)
        {
            rewards[slot] = this->children[slot]->reward;
            ++slot;
        }
    }

    void CNode::get_sampled_qvalues(float *values, float discount)
    {
        for (int i = 0; i < this->num_children; ++i)
        {
            values[i] = this->children[i]->get_qsa(discount);
        }
    }

    

    // Allocates storage for a search path of up to `length_max + 2` node pointers and puts
    // the remaining bookkeeping fields into a defined initial state.
    SearchResult::SearchResult(int length_max)
    {
        search_path = static_cast<CNode **>(malloc(sizeof(CNode *) * (length_max + 2)));
        // Report allocation failure the same way the batch constructor does, instead of
        // leaving a null path pointer to be dereferenced later during selection.
        tools::my_assert(search_path != nullptr, "Error in `SearchResult::SearchResult`: `malloc` fails for `search_path`.");

        // Defined defaults so that reading the result before the first selection
        // (e.g. querying the path length) does not touch uninitialised memory.
        search_len = 0;
        leaf = nullptr;
        action = nullptr;
    }
    // Releases the path buffer obtained with malloc in the constructor.
    SearchResult::~SearchResult()
    {
        free(this->search_path);
    }

    

    // Sets up a tree over the caller-provided node pool.
    // The root pointer is the first pool slot, the RNG is seeded with `seed`, adaptive mode
    // is on by default, and the search-result buffer is sized from `simulation_num`.
    CTree::CTree(int agent_num, int action_space_size, int sampled_times, int simulation_num, float tree_value_stat_delta_lb, CNode *node_pool_ptr, unsigned int seed, float rho, float lam)
        : gen(seed),
          agent_num(agent_num),
          action_space_size(action_space_size),
          sampled_times(sampled_times),
          tot_nodes(0),
          rho(rho),
          lam(lam),
          use_adaptive(true),
          node_pool_ptr(node_pool_ptr),
          root(node_pool_ptr),
          minmax_stat(tree_value_stat_delta_lb),
          result(simulation_num)
    {
    }

    // Explicitly destroys the `tot_nodes` nodes that were placement-constructed in the pool.
    // This destructor does not free the pool memory itself.
    CTree::~CTree()
    {
        CNode *const pool = this->node_pool_ptr;
        const int constructed = this->tot_nodes;
        for (int slot = 0; slot < constructed; ++slot)
        {
            pool[slot].~CNode();
        }
    }

    
    void CTree::prepare(float reward, float value, tools::Array2D<float> policy_probs, tools::Array2D<float> beta, int sampled_times, float noise_eps, tools::Array2D<float> noises, tools::Array2D<float> hypernet_params)
    {
        
        new (this->root) CNode(1., 1., 1., 1., true, this->rho, this->lam);

        
        if (this->use_adaptive)
        {
            this->root->adaptive_node.n = this->agent_num;
            this->root->adaptive_node.k = this->action_space_size;
            this->root->adaptive_node.initialize(&hypernet_params(0, 0));
        }

        ++(this->tot_nodes);
        this->expand(this->root, 0, reward, value, policy_probs, beta, sampled_times, noise_eps, noises, hypernet_params);
        this->root->visit_count += 1;
        this->root->subtree_info.update(value, 0);
    }

    // Hash functor for std::vector<int> keys: starts from a fixed seed and folds each element
    // in with an order-dependent shift/xor combine step.
    struct VecHash {
        size_t operator()(const std::vector<int>& v) const noexcept {
            size_t acc = 1469598103934665603ull;
            for (auto it = v.begin(); it != v.end(); ++it) {
                const auto mixed = static_cast<size_t>(*it) + 0x9e3779b97f4a7c15ull + (acc << 6) + (acc >> 2);
                acc ^= mixed;
            }
            return acc;
        }
    };


    
    // Expands `node`: stores the model outputs on it, samples joint actions, and creates one
    // child per distinct sampled joint action.
    //
    // - `hidden_state_index_x`, `reward`, `value` are written onto `node`.
    // - `sampled_times` joint actions are drawn; component i of each draw comes from a discrete
    //   distribution weighted by row i of `beta`.
    // - For every distinct joint action a child is placement-constructed in the node pool with
    //     pred_prob = product over i of policy_probs(i, a_i)
    //     beta      = product over i of beta(i, a_i)
    //     beta_hat  = (times drawn) / sampled_times
    //     prior     = exp(clip(log(noised policy product) + log(beta_hat) - log(beta), -20, 20))
    //   where every factor is floored at 1e-20 before taking its log, and the noised policy is
    //   pi * (1 - noise_eps) + noise * noise_eps when `noise_eps > 0` (plain pi otherwise).
    // - When adaptive mode is on, each child's adaptive state is initialised from
    //   `hypernet_params`.
    //
    // Children are appended in the iteration order of the hash map of samples. The pool is
    // assumed to have room for the new children; no capacity check is done here.
    void CTree::expand(CNode *node, int hidden_state_index_x, float reward, float value, tools::Array2D<float> policy_probs, tools::Array2D<float> beta, int sampled_times, float noise_eps, tools::Array2D<float> noises, tools::Array2D<float> hypernet_params)
    {
        node->hidden_state_index_x = hidden_state_index_x;
        node->reward = reward;
        node->pred_value = value;

        // Distinct sampled joint action -> number of times it was drawn.
        std::unordered_map<std::vector<int>, int, VecHash> beta_hat;
        if (sampled_times > 0) {
            // Skip the reservation for non-positive counts: the cast would wrap to a huge size.
            beta_hat.reserve(static_cast<size_t>(sampled_times) * 2);
        }

        // One sampling distribution per agent. The end of each weight row is derived from the
        // row start plus the row width, so the last agent never indexes a row past the end of
        // `beta` (assumes rows are stored contiguously with `beta.d2` entries each).
        std::vector<std::discrete_distribution<int>> dists;
        dists.reserve(this->agent_num);
        for (int i = 0; i < this->agent_num; ++i) {
            const float *row_begin = &beta(i, 0);
            dists.emplace_back(row_begin, row_begin + beta.d2);
        }

        // Draw the joint actions. The scratch vector is reused across draws; the map copies it
        // only when a new key is inserted.
        std::vector<int> sampled(this->agent_num, 0);
        for (int t = 0; t < sampled_times; ++t) {
            for (int i = 0; i < this->agent_num; ++i) {
                sampled[i] = dists[i](this->gen);
            }
            beta_hat[sampled] += 1;
        }

        node->num_children = static_cast<int>(beta_hat.size());
        node->children.reserve(node->num_children);
        node->children_action.reserve(node->num_children);

        constexpr double kEps = 1e-20;      // floor applied before every log
        constexpr double kLogClip = 20.0;   // symmetric clip on the log-prior
        for (const auto& [a, hits] : beta_hat) {
            const double count = static_cast<double>(hits);
            const double betahat_prob = count / static_cast<double>(sampled_times);

            // Work in log space so products over agents do not underflow prematurely.
            double log_pred_prob = 0.0;
            double log_beta_prob = 0.0;
            double log_pi_for_prior = 0.0;

            for (int i = 0; i < this->agent_num; ++i) {
                const int ai = a[i];

                const double pi = static_cast<double>(policy_probs(i, ai));
                const double b  = static_cast<double>(beta(i, ai));

                log_pred_prob += std::log(std::max(pi, kEps));
                log_beta_prob += std::log(std::max(b,  kEps));

                // Only the prior uses the noise-mixed policy; pred_prob stays noise-free.
                double p = pi;
                if (noise_eps > 0) {
                    const double nz = static_cast<double>(noises(i, ai));
                    p = pi * (1.0 - noise_eps) + nz * noise_eps;
                }
                log_pi_for_prior += std::log(std::max(p, kEps));
            }

            double log_prior = log_pi_for_prior + std::log(std::max(betahat_prob, kEps)) - log_beta_prob;
            if (log_prior >  kLogClip) log_prior =  kLogClip;
            if (log_prior < -kLogClip) log_prior = -kLogClip;

            const float prior     = static_cast<float>(std::exp(log_prior));
            const float pred_prob = static_cast<float>(std::exp(log_pred_prob));
            const float beta_prob = static_cast<float>(std::exp(log_beta_prob));
            const float betahat_f = static_cast<float>(betahat_prob);

            // Construct the child in the next free pool slot.
            CNode *new_child = this->node_pool_ptr + this->tot_nodes;
            new (new_child) CNode(prior, pred_prob, beta_prob, betahat_f, false, this->rho, this->lam);

            if (this->use_adaptive) {
                new_child->adaptive_node.n = this->agent_num;
                new_child->adaptive_node.k = this->action_space_size;
                new_child->adaptive_node.initialize(&hypernet_params(0, 0));
            }

            ++(this->tot_nodes);
            node->children.push_back(new_child);
            node->children_action.push_back(a);
        }
    }

    // Computes the selection score of `child` as exploration term plus value term, stores it
    // on the child, and returns it.
    //   exploration = (log((N + pb_c_base + 1) / pb_c_base) + pb_c_init)
    //                 * sqrt(N) / (child visits + 1) * child prior,   N = total_children_visit_counts
    //   value term  = normalised (Q(child) - parent_q), clamped to [0, 1]; unvisited children
    //                 use a raw value of 0 before normalisation.
    float CTree::ucb_score(CNode *child, float parent_q, int total_children_visit_counts, float pb_c_base, float pb_c_init, float discount)
    {
        float exploration = log((total_children_visit_counts + pb_c_base + 1) / pb_c_base) + pb_c_init;
        exploration *= (sqrt(total_children_visit_counts) / (child->visit_count + 1));
        const float prior_score = exploration * child->prior;

        float value_score = 0;
        if (child->visit_count != 0)
        {
            value_score = child->get_qsa(discount) - parent_q;
        }
        value_score = this->minmax_stat.normalize(value_score);

        // Keep the normalised value inside the unit interval.
        if (value_score < 0)
        {
            value_score = 0;
        }
        else if (value_score > 1)
        {
            value_score = 1;
        }

        const float score = prior_score + value_score;
        child->ucb_value = score;
        return score;
    }

    // Picks the index of the child to descend into.
    // Each child is scored either by the node's adaptive model (adaptive mode) or by
    // ucb_score(). Children whose score is within `epsilon` of the running best are kept as
    // candidates (a strictly better score resets the candidate list), and one candidate is
    // drawn using the tree's RNG. Returns 0 when no candidate was collected.
    int CTree::select_child(CNode *node, float pb_c_base, float pb_c_init, float discount, float parent_q)
    {
        const float epsilon = this->use_adaptive ? 0.01 : 0.000001;
        float best_score = -std::numeric_limits<float>::infinity();
        std::vector<int> candidates;

        const int child_total = node->num_children;
        for (int c = 0; c < child_total; ++c)
        {
            CNode *child = node->children[c];

            float score;
            if (!this->use_adaptive)
            {
                score = ucb_score(child, parent_q, node->visit_count - 1, pb_c_base, pb_c_init, discount);
            }
            else
            {
                score = node->adaptive_node.getTildeEtaWithRegularizer(node->children_action[c]);
            }

            if (best_score < score)
            {
                // New strict maximum: earlier candidates are dropped.
                best_score = score;
                candidates.clear();
                candidates.push_back(c);
            }
            else if (score >= best_score - epsilon)
            {
                // Close enough to the current best to count as a tie.
                candidates.push_back(c);
            }
        }

        if (candidates.size() == 0)
        {
            return 0;
        }
        const auto pick = this->gen() % candidates.size();
        return candidates[pick];
    }

    void CTree::select_path(float pb_c_base, float pb_c_init, float discount)
    {
        
        CNode *node = this->root;
        this->result.search_len = 0;
        this->result.search_path[this->result.search_len] = node;

        while (node->expanded())
        {
            int child_index;
            if (node->is_root && node->visit_count <= node->num_children)
            {
                child_index = node->visit_count - 1;
            }
            else
            {
                child_index = select_child(node, pb_c_base, pb_c_init, discount, node->pred_value);
            }

            
            if (this->use_adaptive)
            {
                node->adaptive_node.last_action = node->adaptive_node.action;
                node->adaptive_node.action = node->children_action[child_index];
                node->adaptive_node.resampleDirectionsAndNeighbors();
            }

            
            this->result.action = &(node->children_action[child_index][0]);
            node = node->children[child_index];
            this->result.search_len += 1;
            this->result.search_path[this->result.search_len] = node;
        }

        CNode *parent = this->result.search_path[this->result.search_len - 1];
        this->result.idx = (parent->hidden_state_index_x);
        this->result.leaf = node;
    }

    void CTree::back_propagate(float value, float discount)
    {
        
        float bootstrap_value = value;
        int path_len = this->result.search_len;

        for (int i = path_len; i >= 0; --i)
        {
            CNode *node = this->result.search_path[i];

            if (i != path_len && i != 0)
            {
                
                CNode *father = this->result.search_path[i - 1];
                this->minmax_stat.remove(node->get_qsa(discount) - father->pred_value);
            }

            node->visit_count += 1;

            node->subtree_info.update(bootstrap_value, path_len - i);

            if (i != 0)
            {
                
                CNode *father = this->result.search_path[i - 1];
                this->minmax_stat.insert(node->get_qsa(discount) - father->pred_value);
            }

            bootstrap_value = node->reward + discount * bootstrap_value;
        }
    }

    
    void CTree::expand_and_backprop(int hidden_state_index_x, float discount, int sampled_times, float reward, float value, tools::Array2D<float> policy_prob, tools::Array2D<float> beta, tools::Array2D<float> hypernet_params)
    {
        
        tools::Array2D<float> _(nullptr, 0, 0); 
        expand(this->result.leaf, hidden_state_index_x, reward, value, policy_prob, beta, sampled_times, 0., _, hypernet_params);
        back_propagate(value, discount);
    }

    void CTree::get_root_value(float *val)
    {
        *val = this->root->value();
    }
    // Forwards to the root's get_marginal_visit_count().
    void CTree::get_root_marginal_visit_count(tools::Array2D<int> logits)
    {
        CNode *const root_node = this->root;
        root_node->get_marginal_visit_count(logits);
    }
    // Forwards to the root's get_marginal_priors().
    void CTree::get_root_marginal_priors(tools::Array2D<float> priors)
    {
        CNode *const root_node = this->root;
        root_node->get_marginal_priors(priors);
    }
    // Copies the joint action of every root child into `actions` (one row per child).
    // The receiving buffer must have exactly one row per child and one column per action
    // component; the shape is checked with the project assertion before copying.
    void CTree::get_root_sampled_actions(tools::Array2D<int> actions)
    {
        const auto &sampled = root->children_action;

        // The column check reads the first child's action, so it is only evaluated when a
        // first child exists; a childless root only needs the row counts to agree.
        const bool rows_match = sampled.size() == actions.d1;
        const bool cols_match = sampled.empty() || sampled[0].size() == actions.d2;
        tools::my_assert(rows_match && cols_match,
                         "Error in `CTree::get_root_sampled_actions`: dimensions of `root->children_action` does not match that of the receiving buffer.");

        for (size_t i = 0; i < actions.d1; ++i)
            for (size_t j = 0; j < actions.d2; ++j)
                actions(i, j) = sampled[i][j];
    }
    void CTree::get_root_sampled_visit_count(int *logits)
    {
        root->get_sampled_visit_count(logits);
    }
    void CTree::get_root_sampled_pred_probs(float *probs)
    {
        root->get_sampled_pred_probs(probs);
    }
    void CTree::get_root_sampled_imp_ratio(float *imp_ratio)
    {
        root->get_sampled_imp_ratio(imp_ratio);
    }
    void CTree::get_root_sampled_beta(float *probs)
    {
        root->get_sampled_beta(probs);
    }
    void CTree::get_root_sampled_beta_hat(float *probs)
    {
        root->get_sampled_beta_hat(probs);
    }
    void CTree::get_root_sampled_priors(float *priors)
    {
        root->get_sampled_priors(priors);
    }
    void CTree::get_root_sampled_pred_values(float *values)
    {
        root->get_sampled_pred_values(values);
    }
    void CTree::get_root_sampled_mcts_values(float *values)
    {
        root->get_sampled_mcts_values(values);
    }
    void CTree::get_root_sampled_rewards(float *rewards)
    {
        root->get_sampled_rewards(rewards);
    }
    void CTree::get_root_sampled_qvalues(float *values, float discount)
    {
        root->get_sampled_qvalues(values, discount);
    }

    void CTree::print()
    {
        for (int i = 0; i < this->tot_nodes; ++i)
        {
            fprintf(stderr, "node %d info:\n", i);
            auto u = this->node_pool_ptr[i];
            fprintf(stderr, "\tvisit count: %d, idx: %d\n", u.visit_count, u.hidden_state_index_x);
            fprintf(stderr, "\treward: %f, prior: %f, estimate_value: %f\n", u.reward, u.prior, u.subtree_info.value_estimation());
            fprintf(stderr, "\tchildren (%d in total):\n", u.num_children);
            for (int j = 0; j < u.num_children; ++j)
            {
                fprintf(stderr, "\t\tid: %ld,\t action: ", u.children[j] - this->node_pool_ptr);
                for (auto it : u.children_action[j])
                    fprintf(stderr, "%d ", it);
                fprintf(stderr, "\n");
            }
        }
    }

    
    // Feeds the five reward arrays into the adaptive state of each node on the last selected
    // path, excluding the leaf, and triggers that node's theta update. Entry i of every array
    // belongs to the i-th node on the path.
    void CTree::adaptive_update(float *act_rewards, float *last_act_rewards, float *last_act_u_rewards, float *last_act_v_rewards, float *last_act_u_v_rewards)
    {
        const int steps = this->result.search_len;
        for (int step = 0; step < steps; ++step)
        {
            auto &adaptive = this->result.search_path[step]->adaptive_node;

            adaptive.act_reward = act_rewards[step];
            adaptive.last_act_reward = last_act_rewards[step];
            adaptive.last_act_u_reward = last_act_u_rewards[step];
            adaptive.last_act_v_reward = last_act_v_rewards[step];
            adaptive.last_act_u_v_reward = last_act_u_v_rewards[step];

            adaptive.updateThetaGivenY();
        }
    }

    // Allocates one shared node pool and one CTree per root.
    // Each tree receives a slice of `sampled_times * (simulation_num + 2)` nodes and its own
    // seed derived from `random_seed` and the tree index. Adaptive mode is on by default and
    // the thread count starts at 1.
    CTree_batch::CTree_batch(int root_num, int agent_num, int action_space_size, int sampled_times, int simulation_num, float tree_value_stat_delta_lb, unsigned int random_seed, float rho, float lam)
    {
        this->root_num = root_num;
        this->agent_num = agent_num;
        this->action_space_size = action_space_size;
        this->pool_size_per_root = sampled_times * (simulation_num + 2);
        this->thread_num = 1;
        this->use_adaptive = true;

        // Raw storage only: nodes and trees are placement-constructed later.
        this->node_pool = static_cast<CNode *>(malloc(sizeof(CNode) * this->root_num * this->pool_size_per_root));
        this->trees = static_cast<CTree *>(malloc(sizeof(CTree) * this->root_num));
        tools::my_assert(this->node_pool && this->trees, "Error in `CTree_batch::CTree_batch`: `malloc` fails for `node_pool` or `trees`.");

        for (int t = 0; t < this->root_num; ++t)
        {
            CNode *pool_slice = this->node_pool + t * this->pool_size_per_root;
            unsigned int tree_seed = random_seed * 2333 + t;
            new (this->trees + t) CTree(agent_num, action_space_size, sampled_times, simulation_num, tree_value_stat_delta_lb, pool_slice, tree_seed, rho, lam);
        }
    }

    
    void CTree_batch::set_use_adaptive(bool use)
    {
        this->use_adaptive = use;
        for (int i = 0; i < this->root_num; ++i)
        {
            this->trees[i].use_adaptive = use;
        }
    }

    // Destroys the placement-constructed trees, then releases both malloc'ed buffers.
    CTree_batch::~CTree_batch()
    {
        const int tree_total = this->root_num;
        for (int t = 0; t < tree_total; ++t)
        {
            CTree &tree = this->trees[t];
            tree.~CTree();
        }

        free(this->trees);
        free(this->node_pool);
    }

    
    void CTree_batch::prepare(float *rewards_buf, float *values_buf, float *policy_probs_buf, float *beta_buf, int sampled_times, float noise_eps, float *noises_buf, float *hypernet_params_buf)
    {
        
        float *rewards = rewards_buf;
        float *values = values_buf;
        tools::Array3D<float> policy_probs(policy_probs_buf, this->root_num, this->agent_num, this->action_space_size);
        tools::Array3D<float> beta(beta_buf, this->root_num, this->agent_num, this->action_space_size);
        tools::Array3D<float> noises(noises_buf, this->root_num, this->agent_num, this->action_space_size);
        tools::Array3D<float> hypernet_params(hypernet_params_buf, this->root_num, this->agent_num, this->action_space_size);

        for (int i = 0; i < this->root_num; ++i)
        {
            tools::Array2D<float> prob_i(&policy_probs(i, 0, 0), policy_probs.d2, policy_probs.d3);
            tools::Array2D<float> beta_i(&beta(i, 0, 0), beta.d2, beta.d3);
            tools::Array2D<float> noise_i(&noises(i, 0, 0), noises.d2, noises.d3);
            tools::Array2D<float> hypernet_i(&hypernet_params(i, 0, 0), hypernet_params.d2, hypernet_params.d3);
            this->trees[i].prepare(rewards[i], values[i], prob_i, beta_i, sampled_times, noise_eps, noise_i, hypernet_i);
        }
    }

    void CTree_batch::cbatch_selection(float pb_c_base, float pb_c_init, float discount, int *idx_buf, int *idy_buf, int *act_buf)
    {
        
        int *idx_arr = idx_buf;
        int *idy_arr = idy_buf;
        tools::Array2D<int> act_arr(act_buf, this->root_num, this->agent_num);

#pragma omp parallel for num_threads(this->thread_num)
        for (int i = 0; i < this->root_num; ++i)
        {
            this->trees[i].select_path(pb_c_base, pb_c_init, discount);
            idx_arr[i] = this->trees[i].result.idx;
            idy_arr[i] = i;
            for (int j = 0; j < this->agent_num; ++j)
                act_arr(i, j) = this->trees[i].result.action[j];
        }
    }

    
    void CTree_batch::cbatch_expansion_and_backup(int hidden_state_index_x, float discount, int sampled_times, float *rewards_buf, float *values_buf, float *policy_probs_buf, float *beta_buf, float *hypernet_params_buf)
    {
        
        float *rewards = rewards_buf;
        float *values = values_buf;
        tools::Array3D<float> policy_probs(policy_probs_buf, this->root_num, this->agent_num, this->action_space_size);
        tools::Array3D<float> beta(beta_buf, this->root_num, this->agent_num, this->action_space_size);
        tools::Array3D<float> hypernet_params(hypernet_params_buf, this->root_num, this->agent_num, this->action_space_size);

#pragma omp parallel for num_threads(this->thread_num)
        for (int i = 0; i < this->root_num; ++i)
        {
            tools::Array2D<float> prob_i(&policy_probs(i, 0, 0), policy_probs.d2, policy_probs.d3);
            tools::Array2D<float> beta_i(&beta(i, 0, 0), beta.d2, beta.d3);
            tools::Array2D<float> hypernet_i(&hypernet_params(i, 0, 0), hypernet_params.d2, hypernet_params.d3);
            this->trees[i].expand_and_backprop(hidden_state_index_x, discount, sampled_times, rewards[i], values[i], prob_i, beta_i, hypernet_i);
        }
    }

    void CTree_batch::get_roots_values(float *buf)
    {
        float *val = buf;
        for (int i = 0; i < this->root_num; ++i)
        {
            this->trees[i].get_root_value(val + i);
        }
    }

    void CTree_batch::get_roots_marginal_visit_count(int *buf)
    {
        memset(buf, 0, sizeof(int) * this->root_num * this->agent_num * this->action_space_size);
        tools::Array3D<int> arr(buf, this->root_num, this->agent_num, this->action_space_size);
        for (int i = 0; i < this->root_num; ++i)
        {
            this->trees[i].get_root_marginal_visit_count(tools::Array2D<int>(&arr(i, 0, 0), this->agent_num, this->action_space_size));
        }
    }

    void CTree_batch::get_roots_marginal_priors(float *buf)
    {
        memset(buf, 0, sizeof(float) * this->root_num * this->agent_num * this->action_space_size);
        tools::Array3D<float> arr(buf, this->root_num, this->agent_num, this->action_space_size);
        for (int i = 0; i < this->root_num; ++i)
        {
            this->trees[i].get_root_marginal_priors(tools::Array2D<float>(&arr(i, 0, 0), this->agent_num, this->action_space_size));
        }
    }

    int CTree_batch::get_num_children_of_root(int tree_id)
    {
        return this->trees[tree_id].root->num_children;
    }

    void CTree_batch::get_root_sampled_actions(int tree_id, int *buf)
    {
        tools::Array2D<int> arr(buf, this->trees[tree_id].root->num_children, this->agent_num);
        this->trees[tree_id].get_root_sampled_actions(arr);
    }

    void CTree_batch::get_root_sampled_visit_count(int tree_id, int *buf)
    {
        memset(buf, 0, sizeof(int) * this->trees[tree_id].root->num_children);
        int *arr = buf;
        this->trees[tree_id].get_root_sampled_visit_count(arr);
    }

    void CTree_batch::get_root_sampled_pred_probs(int tree_id, float *buf)
    {
        memset(buf, 0, sizeof(float) * this->trees[tree_id].root->num_children);
        float *arr = buf;
        this->trees[tree_id].get_root_sampled_pred_probs(arr);
    }

    void CTree_batch::get_root_sampled_imp_ratio(int tree_id, float *buf)
    {
        memset(buf, 0, sizeof(float) * this->trees[tree_id].root->num_children);
        float *arr = buf;
        this->trees[tree_id].get_root_sampled_imp_ratio(arr);
    }

    void CTree_batch::get_root_sampled_beta(int tree_id, float *buf)
    {
        memset(buf, 0, sizeof(float) * this->trees[tree_id].root->num_children);
        float *arr = buf;
        this->trees[tree_id].get_root_sampled_beta(arr);
    }

    void CTree_batch::get_root_sampled_beta_hat(int tree_id, float *buf)
    {
        memset(buf, 0, sizeof(float) * this->trees[tree_id].root->num_children);
        float *arr = buf;
        this->trees[tree_id].get_root_sampled_beta_hat(arr);
    }

    void CTree_batch::get_root_sampled_priors(int tree_id, float *buf)
    {
        memset(buf, 0, sizeof(float) * this->trees[tree_id].root->num_children);
        float *arr = buf;
        this->trees[tree_id].get_root_sampled_priors(arr);
    }

    void CTree_batch::get_root_sampled_rewards(int tree_id, float *buf)
    {
        memset(buf, 0, sizeof(float) * this->trees[tree_id].root->num_children);
        float *arr = buf;
        this->trees[tree_id].get_root_sampled_rewards(arr);
    }

    void CTree_batch::get_root_sampled_pred_values(int tree_id, float *buf)
    {
        memset(buf, 0, sizeof(float) * this->trees[tree_id].root->num_children);
        float *arr = buf;
        this->trees[tree_id].get_root_sampled_pred_values(arr);
    }

    void CTree_batch::get_root_sampled_mcts_values(int tree_id, float *buf)
    {
        memset(buf, 0, sizeof(float) * this->trees[tree_id].root->num_children);
        float *arr = buf;
        this->trees[tree_id].get_root_sampled_mcts_values(arr);
    }

    void CTree_batch::get_root_sampled_qvalues(int tree_id, float *buf, float discount)
    {
        memset(buf, 0, sizeof(float) * this->trees[tree_id].root->num_children);
        float *arr = buf;
        this->trees[tree_id].get_root_sampled_qvalues(arr, discount);
    }

    void CTree_batch::print()
    {
        for (int i = 0; i < root_num; ++i)
        {
            fprintf(stderr, "---------- Tree %d info ----------\n", i);
            this->trees[i].print();
            fprintf(stderr, "\n");
        }
    }

    
    void CTree_batch::get_path_lengths(int *buf)
    {
        memset(buf, 0, sizeof(int) * this->root_num);
        for (int i = 0; i < this->root_num; ++i)
        {
            buf[i] = this->trees[i].result.search_len;
        }
    }

    // Flattens the adaptive-model inputs of all trees into the output buffers.
    // Rows are laid out tree after tree; tree i contributes one row per non-leaf node on its
    // last search path. For each row:
    //   idx_buf - the node's hidden-state index,   idy_buf - the tree index,
    //   act / last_act / last_act_u / last_act_v / last_act_u_v buffers - agent_num entries
    //   copied from the corresponding vectors of the node's adaptive state.
    // Note that this function does not check `use_adaptive`.
    void CTree_batch::cadaptive_get_batch_inputs(int *idx_buf, int *idy_buf, int *act_buf, int *last_act_buf, int *last_act_u_buf, int *last_act_v_buf, int *last_act_u_v_buf)
    {
        // Per-tree path length and the row at which each tree's block starts.
        std::vector<int> path_lengths(this->root_num, 0);
        std::vector<int> offset(root_num, 0);
        int total_path_length = 0;
        for (int t = 0; t < this->root_num; ++t)
        {
            path_lengths[t] = this->trees[t].result.search_len;
            offset[t] = total_path_length;
            total_path_length += path_lengths[t];
        }

        tools::Array2D<int> act_arr(act_buf, total_path_length, this->agent_num);
        tools::Array2D<int> last_act_arr(last_act_buf, total_path_length, this->agent_num);
        tools::Array2D<int> last_act_u_arr(last_act_u_buf, total_path_length, this->agent_num);
        tools::Array2D<int> last_act_v_arr(last_act_v_buf, total_path_length, this->agent_num);
        tools::Array2D<int> last_act_u_v_arr(last_act_u_v_buf, total_path_length, this->agent_num);

#pragma omp parallel for num_threads(this->thread_num)
        for (int t = 0; t < root_num; ++t)
        {
            const int first_row = offset[t];
            for (int step = 0; step < path_lengths[t]; ++step)
            {
                const int row = first_row + step;
                CNode *node = trees[t].result.search_path[step];
                auto &adaptive = node->adaptive_node;

                idx_buf[row] = node->hidden_state_index_x;
                idy_buf[row] = t;
                for (int agent = 0; agent < agent_num; ++agent)
                {
                    act_arr(row, agent) = adaptive.action[agent];
                    last_act_arr(row, agent) = adaptive.last_action[agent];
                    last_act_u_arr(row, agent) = adaptive.last_action_u[agent];
                    last_act_v_arr(row, agent) = adaptive.last_action_v[agent];
                    last_act_u_v_arr(row, agent) = adaptive.last_action_u_v[agent];
                }
            }
        }
    }

    void CTree_batch::cadaptive_batch_update(float *rewards_buf)
    {
        
        if (!this->use_adaptive)
        {
            return;
        }

        int total_path_length = 0;
        std::vector<int> path_lengths(this->root_num); 
        for (int i = 0; i < this->root_num; ++i)
        {
            path_lengths[i] = this->trees[i].result.search_len;
            total_path_length += path_lengths[i];
        }
        int type_offset = total_path_length; 

        
        std::vector<int> offset(root_num, 0);
        for (int i = 1; i < root_num; ++i)
        {
            offset[i] = offset[i - 1] + path_lengths[i - 1];
        }
#pragma omp parallel for num_threads(this->thread_num)
        for (int i = 0; i < this->root_num; ++i)
        {
            int cur_idx = offset[i];
            float *act_rewards = rewards_buf + cur_idx;
            float *last_act_rewards = rewards_buf + type_offset + cur_idx;
            float *last_act_u_rewards = rewards_buf + type_offset * 2 + cur_idx;
            float *last_act_v_rewards = rewards_buf + type_offset * 3 + cur_idx;
            float *last_act_u_v_rewards = rewards_buf + type_offset * 4 + cur_idx;

            this->trees[i].adaptive_update(act_rewards, last_act_rewards, last_act_u_rewards, last_act_v_rewards, last_act_u_v_rewards);
        }
    }

    
    void CTree_batch::get_roots_adaptive_theta(float *buf)
    {
        
        int param_size = this->agent_num * this->action_space_size;
        
        
        if (!this->use_adaptive)
        {
            memset(buf, 0, sizeof(float) * this->root_num * param_size);
            return;
        }
        
#pragma omp parallel for num_threads(this->thread_num)
        for (int i = 0; i < this->root_num; ++i)
        {
            CNode* root = this->trees[i].root;
            float* buf_i = buf + i * param_size;
            
            
            for (int j = 0; j < param_size; ++j)
            {
                buf_i[j] = static_cast<float>(root->adaptive_node.theta_[j]);
            }
        }
    }
}


