
#ifndef NONZERO_ADAPTIVE_H
#define NONZERO_ADAPTIVE_H

#include <vector>
#include <random>
#include <cmath>
#include <cstdint>
#include <algorithm>

namespace nonzero {


/// Per-parameter optimizer state; all fields start at zero.
///
/// NOTE(review): presumably Adam-style bookkeeping — `t` = step count,
/// `m` / `v` = first / second moment estimates. Confirm against the update
/// code before relying on this.
///
/// This is an aggregate, so the member order (t, m, v) is part of its
/// brace-initialization interface: `Momentum{step, m, v}`.
struct Momentum {
    int t{0};
    double m{0.0};
    double v{0.0};
};

/// Per-node learner state: a flat parameter vector with per-entry optimizer
/// state, the action profiles and rewards used to form updates, and the
/// hyperparameters that control those updates.
///
/// All data members are public and no constructors are declared, so this type
/// is an aggregate. The data-member order below is therefore part of its
/// (aggregate-initialization) interface and must not be rearranged.
class AdaptiveNode {
public:
    // ---- Parameters and optimizer state ---------------------------------
    int n = 0;                     // outer extent of theta_; presumably the number of agents (TODO confirm)
    int k = 0;                     // row stride used by getIndex(); presumably actions per agent
    std::vector<double> theta_;    // parameters; presumably n*k entries addressed via getIndex() (TODO confirm)
    std::vector<Momentum> opt_;    // optimizer state; presumably one entry per theta_ element

    // ---- Action profiles --------------------------------------------------
    // NOTE(review): the _u / _v / _u_v suffixes presumably denote profiles
    // perturbed along directions drawn by resampleDirectionsAndNeighbors();
    // confirm against the implementation.
    std::vector<int> action;
    std::vector<int> last_action;
    std::vector<int> last_action_u;
    std::vector<int> last_action_v;
    std::vector<int> last_action_u_v;

    // ---- Rewards paired with the action profiles above -------------------
    double act_reward = 0.0;
    double last_act_reward = 0.0;
    double last_act_u_reward = 0.0;
    double last_act_v_reward = 0.0;
    double last_act_u_v_reward = 0.0;

    // ---- Hyperparameters --------------------------------------------------
    double link_scale = 1.0;       // scale s for the asinh link (the link helpers take z / s)
    double lambda_reg = 0.0;       // regularizer weight; presumably used by getTildeEtaWithRegularizer()
    double lambda_wd = 0.0;        // presumably a weight-decay coefficient (TODO confirm)
    double lr = 0.001;             // learning rate
    double beta1 = 0.9;            // default matches Adam's usual first-moment decay
    double beta2 = 0.999;          // default matches Adam's usual second-moment decay
    double eps = 1e-8;             // default matches Adam's usual denominator epsilon

    // Random engine. Until something reseeds it, it holds std::mt19937_64's
    // default seed; presumably initialize() reseeds it (confirm).
    std::mt19937_64 rng;

    // ---- Initialization ----------------------------------------------------
    /// Initialize from an external parameter buffer.
    /// NOTE(review): the expected length/layout of hypernet_params is not
    /// visible here — presumably n*k floats; confirm in the implementation.
    void initialize(const float* hypernet_params, std::uint64_t seed = 0);

    /// Initialize without an external parameter buffer.
    /// NOTE: a call written as `initialize(0)` does not compile (ambiguous
    /// overload), because a literal 0 is both a null pointer constant and an
    /// integer. Call `initialize()` or pass `std::uint64_t{0}` instead.
    void initialize(std::uint64_t seed = 0);

    // ---- Queries -------------------------------------------------------------
    /// Estimated eta for action profile `a` (definition lives in the implementation).
    double eta_hat(const std::vector<int>& a) const;

    /// Eta for `a` including a regularizer term; presumably weighted by
    /// lambda_reg (confirm in the implementation).
    double getTildeEtaWithRegularizer(const std::vector<int>& a) const;

    /// Presumably the partial derivative of eta with respect to the theta
    /// entry at (agent_idx, action_idx).
    double d_eta_d_theta(int agent_idx, int action_idx) const;

    // ---- Updates ---------------------------------------------------------------
    /// Redraw perturbation directions and neighbouring action profiles
    /// (presumably refreshing the last_action_u / _v / _u_v members — confirm).
    void resampleDirectionsAndNeighbors();

    /// Update theta_ from the observed outcome(s) "Y" (semantics in the implementation).
    void updateThetaGivenY();

private:
    /// Row-major flat index of (agent_idx, action_idx) using row stride k.
    /// No bounds checking is done, and the arithmetic is performed in int.
    int getIndex(int agent_idx, int action_idx) const {
        const int row_offset = agent_idx * k;
        return row_offset + action_idx;
    }

    /// Link function evaluated at z / s, and (presumably) its derivative.
    static double link_asinh(double z_over_s);
    static double link_asinh_prime(double z_over_s);
};

} 

#endif 

