#ifndef CNODE_H
#define CNODE_H

#include "../../common_lib/utils.h"

#include <vector>
#include <random>
#include "nonzero_adaptive.h"


namespace tree
{
    /// A single node of the sampled search tree.
    ///
    /// NOTE(review): every member function is only declared in this header; the
    /// descriptions below are read off the identifiers and should be confirmed
    /// against the out-of-line definitions.
    struct CNode
    {
        // ---- integer bookkeeping -------------------------------------------
        int visit_count;
        int num_children;
        int hidden_state_index_x; // presumably an index into a caller-side hidden-state store — TODO confirm

        // ---- scalar estimates attached to this node --------------------------
        float reward;
        float pred_value;
        float prior;
        float pred_prob;
        float beta;
        float beta_hat;

        bool is_root;
        tools::SubTreeValueSet subtree_info;

        // ---- children --------------------------------------------------------
        // Child pointers plus a per-child integer action vector; presumably the
        // two containers are kept index-aligned — TODO confirm in expand().
        std::vector<CNode *> children;
        std::vector<std::vector<int>> children_action;

        // The only data member with an in-class initializer; the others rely on
        // the constructor. Presumably caches the last computed UCB score.
        float ucb_value = 0.0f;

        nonzero::AdaptiveNode adaptive_node;

        // ---- lifetime ----------------------------------------------------------
        CNode(float prior, float pred_prob, float beta, float beta_hat, bool is_root, float rho, float lam);
        ~CNode();

        // ---- value queries -----------------------------------------------------
        bool expanded();
        float value();
        float get_qsa(float discount);
        float get_mean_q(float parent_q, float discount);

        // ---- marginal statistics, written into a caller-supplied 2-D view ------
        // (shape not visible from this header — TODO confirm)
        void get_marginal_visit_count(tools::Array2D<int>);
        void get_marginal_priors(tools::Array2D<float>);

        // ---- per-child statistics, written into caller-owned flat buffers ------
        // (presumably num_children entries each — TODO confirm)
        void get_sampled_visit_count(int *);
        void get_sampled_pred_probs(float *);
        void get_sampled_beta(float *);
        void get_sampled_beta_hat(float *);
        void get_sampled_priors(float *);
        void get_sampled_imp_ratio(float *);
        void get_sampled_pred_values(float *);
        void get_sampled_mcts_values(float *);
        void get_sampled_rewards(float *);
        void get_sampled_qvalues(float *values, float discount);
    };

    /// Record filled in by a selection pass over a CTree (presumably one
    /// record per tree, reused across simulations — TODO confirm).
    ///
    /// NOTE(review): `action` and `search_path` are raw pointers, and the type
    /// declares a destructor but leaves copy/move to the compiler. If
    /// ~SearchResult() frees those buffers, any implicit copy would double-free.
    /// Consider `= delete`-ing the copy operations (or switching to
    /// std::vector) once all users are confirmed not to copy this type or CTree.
    struct SearchResult
    {
        int idx;             // presumably the index of the node/child reached — TODO confirm
        int search_len;      // presumably the number of valid entries in search_path
        int *action;         // buffer; size/ownership defined by the constructor (not visible here)
        CNode *leaf;         // node at which the pass stopped (inferred from name)
        CNode **search_path; // buffer of node pointers; bounded by length_max? — TODO confirm

        SearchResult(int length_max);
        ~SearchResult();
    };

    /// Search tree rooted at a single root node.
    ///
    /// NOTE(review): member functions are only declared here; groupings below
    /// follow the identifiers and should be checked against the definitions.
    struct CTree
    {
    private:
        // Per-tree 32-bit Mersenne Twister; presumably seeded from the
        // constructor's random_seed — TODO confirm.
        std::mt19937 gen;

    public:
        // ---- configuration / bookkeeping -----------------------------------
        int agent_num;
        int action_space_size;
        int sampled_times;
        int tot_nodes;
        float rho;
        float lam;
        bool use_adaptive;

        // ---- node storage --------------------------------------------------
        CNode *node_pool_ptr; // ownership not visible here; presumably borrowed from the batch pool — TODO confirm
        CNode *root;

        tools::CMinMaxStats minmax_stat;
        SearchResult result;

        // ---- lifetime ------------------------------------------------------
        CTree(int agent_num, int action_space_size, int sampled_times, int simulation_num,
              float tree_value_stat_delta_lb, CNode *node_pool_ptr, unsigned int random_seed,
              float rho, float lam);
        ~CTree();

        // ---- root preparation / node expansion -----------------------------
        void prepare(float reward, float value,
                     tools::Array2D<float> policy_probs, tools::Array2D<float> beta,
                     int sampled_times, float noise_eps,
                     tools::Array2D<float> noises, tools::Array2D<float> hypernet_params);
        void expand(CNode *node, int hidden_state_index_x, float reward, float value,
                    tools::Array2D<float> policy_probs, tools::Array2D<float> beta,
                    int sampled_times, float noise_eps,
                    tools::Array2D<float> noises, tools::Array2D<float> hypernet_params);

        // ---- selection -----------------------------------------------------
        float ucb_score(CNode *child, float parent_mean_q, int total_children_visit_counts,
                        float pb_c_base, float pb_c_init, float discount);
        int select_child(CNode *node, float pb_c_base, float pb_c_init, float discount, float mean_q);
        void select_path(float pb_c_base, float pb_c_init, float discount);

        // ---- backup --------------------------------------------------------
        void back_propagate(float value, float discount);
        void expand_and_backprop(int hidden_state_index_x, float discount, int sampled_times,
                                 float reward, float value,
                                 tools::Array2D<float> policy_prob, tools::Array2D<float> beta,
                                 tools::Array2D<float> hypernet_params);

        // ---- root queries (results written into caller-supplied storage) ---
        void get_root_value(float *);

        void get_root_marginal_visit_count(tools::Array2D<int>);
        void get_root_marginal_priors(tools::Array2D<float>);

        void get_root_sampled_actions(tools::Array2D<int>);

        void get_root_sampled_visit_count(int *);
        void get_root_sampled_pred_probs(float *);
        void get_root_sampled_beta(float *);
        void get_root_sampled_beta_hat(float *);
        void get_root_sampled_priors(float *);
        void get_root_sampled_imp_ratio(float *);
        void get_root_sampled_pred_values(float *);
        void get_root_sampled_mcts_values(float *);
        void get_root_sampled_rewards(float *);
        void get_root_sampled_qvalues(float *values, float discount);

        // ---- adaptive component --------------------------------------------
        void adaptive_update(float* act_rewards, float* last_act_rewards, float* last_act_u_rewards,
                             float* last_act_v_rewards, float* last_act_u_v_rewards);

        // ---- diagnostics ---------------------------------------------------
        void print();
    };

    /// A batch of CTree instances that share a node pool (inferred from the
    /// `trees` / `node_pool` members — allocation strategy not visible here).
    ///
    /// NOTE(review): `use_adaptive` exists both here and on each CTree; confirm
    /// that set_use_adaptive() keeps the two in sync.
    struct CTree_batch
    {
        // ---- sizes / configuration -----------------------------------------
        int root_num;
        int pool_size_per_root;
        int agent_num;
        int action_space_size;
        int thread_num;

        // ---- storage (raw pointers; ownership handled by ctor/dtor, not visible here)
        CTree *trees;
        CNode *node_pool;

        bool use_adaptive;

        // ---- lifetime ------------------------------------------------------
        CTree_batch(int root_num, int agent_num, int action_space_size, int sampled_times,
                    int simulation_num, float tree_value_stat_delta_lb, unsigned int random_seed,
                    float rho, float lam);
        ~CTree_batch();

        void set_use_adaptive(bool use);

        // ---- batched search steps (flat caller-owned buffers) -------------
        void prepare(float *rewards, float *values, float *policy_probs, float *beta,
                     int sampled_times, float noise_eps, float* noises, float* hypernet_params);

        void cbatch_selection(float pb_c_base, float pb_c_init, float discount,
                              int *idx_buf, int *idy_buf, int *act_buf);

        void cbatch_expansion_and_backup(int hidden_state_index_x, float discount, int sampled_times,
                                         float *rewards, float *values, float *policy_probs,
                                         float *beta, float *hypernet_params);

        // ---- per-batch root queries ----------------------------------------
        void get_roots_values(float *buf);
        void get_roots_marginal_visit_count(int *buf);
        void get_roots_marginal_priors(float *buf);

        // ---- per-tree root queries (tree_id selects the tree) --------------
        int get_num_children_of_root(int tree_id);
        void get_root_sampled_actions(int tree_id, int *buf);
        void get_root_sampled_visit_count(int tree_id, int *buf);
        void get_root_sampled_pred_probs(int tree_id, float *buf);
        void get_root_sampled_beta(int tree_id, float *buf);
        void get_root_sampled_beta_hat(int tree_id, float *buf);
        void get_root_sampled_priors(int tree_id, float *buf);
        void get_root_sampled_imp_ratio(int tree_id, float *buf);
        void get_root_sampled_pred_values(int tree_id, float *buf);
        void get_root_sampled_mcts_values(int tree_id, float *buf);
        void get_root_sampled_rewards(int tree_id, float* buf);
        void get_root_sampled_qvalues(int tree_id, float* buf, float discount);

        // ---- adaptive component --------------------------------------------
        void get_path_lengths(int *buf);
        void cadaptive_get_batch_inputs(int* idx_ptr, int* idy_ptr, int* act_ptr, int* last_act_ptr,
                                        int* last_act_u_ptr, int* last_act_v_ptr, int* last_act_u_v_ptr);
        void cadaptive_batch_update(float* rewards_buf);

        void get_roots_adaptive_theta(float* buf);

        // ---- diagnostics ---------------------------------------------------
        void print();
    };
}

#endif
