
#include "nonzero_adaptive.h"

namespace nonzero {


/// Initializes the node from externally supplied parameters.
///
/// Seeds `rng`, normalizes `link_scale`, loads `n * k` parameters into
/// `theta_`, resets the per-parameter optimizer state in `opt_`, and resets
/// every action vector and cached reward.
///
/// @param hypernet_params Pointer to at least `n * k` floats; each value is
///        widened to double and copied into `theta_`. When null, `theta_` is
///        zero-filled instead of dereferencing the pointer.
/// @param seed RNG seed. A value of 0 requests a non-deterministic seed mixed
///        from std::random_device and this object's address.
void AdaptiveNode::initialize(const float* hypernet_params, std::uint64_t seed) {
    if (seed == 0) {
        std::random_device rd;
        const std::uint64_t entropy = static_cast<std::uint64_t>(rd()) << 1;
        const std::uint64_t address =
            static_cast<std::uint64_t>(reinterpret_cast<std::uintptr_t>(this));
        rng.seed(entropy ^ address);
    } else {
        rng.seed(seed);
    }

    // Negated comparison so that a NaN scale is also replaced by the default.
    if (!(link_scale > 0.0)) link_scale = 1.0;

    // Clamp the parameter count the same way the action vectors are clamped
    // below: a non-positive dimension means "no parameters", not a negative
    // count that would turn into a huge unsigned size.
    const int total_params = (n > 0 && k > 0) ? n * k : 0;

    if (hypernet_params != nullptr) {
        // Iterator-range assign performs the float -> double widening.
        theta_.assign(hypernet_params, hypernet_params + total_params);
    } else {
        theta_.assign(total_params, 0.0);
    }
    opt_.assign(total_params, Momentum{});

    // NOTE(review): every action starts at index 1, whereas
    // resampleDirectionsAndNeighbors() resets last_action to 0 on a size
    // mismatch; confirm that 1 is a valid action index when k == 1.
    const int sz = std::max(0, n);
    action.assign(sz, 1);
    last_action.assign(sz, 1);
    last_action_u.assign(sz, 1);
    last_action_v.assign(sz, 1);
    last_action_u_v.assign(sz, 1);

    act_reward = last_act_reward = last_act_u_reward =
                 last_act_v_reward = last_act_u_v_reward = 0.0;
}

/// Initializes the node with all parameters set to zero.
///
/// Seeds `rng`, normalizes `link_scale`, zero-fills the `n * k` entries of
/// `theta_`, resets the per-parameter optimizer state in `opt_`, and resets
/// every action vector and cached reward.
///
/// @param seed RNG seed. A value of 0 requests a non-deterministic seed mixed
///        from std::random_device and this object's address.
void AdaptiveNode::initialize(std::uint64_t seed) {
    if (seed == 0) {
        std::random_device rd;
        const std::uint64_t entropy = static_cast<std::uint64_t>(rd()) << 1;
        const std::uint64_t address =
            static_cast<std::uint64_t>(reinterpret_cast<std::uintptr_t>(this));
        rng.seed(entropy ^ address);
    } else {
        rng.seed(seed);
    }

    // Negated comparison so that a NaN scale is also replaced by the default.
    if (!(link_scale > 0.0)) link_scale = 1.0;

    // Clamp the parameter count the same way the action vectors are clamped
    // below: a non-positive dimension means "no parameters", not a negative
    // count that would turn into a huge unsigned size.
    const int total_params = (n > 0 && k > 0) ? n * k : 0;
    theta_.assign(total_params, 0.0);
    opt_.assign(total_params, Momentum{});

    // NOTE(review): every action starts at index 1, whereas
    // resampleDirectionsAndNeighbors() resets last_action to 0 on a size
    // mismatch; confirm that 1 is a valid action index when k == 1.
    const int sz = std::max(0, n);
    action.assign(sz, 1);
    last_action.assign(sz, 1);
    last_action_u.assign(sz, 1);
    last_action_v.assign(sz, 1);
    last_action_u_v.assign(sz, 1);

    act_reward = last_act_reward = last_act_u_reward =
                 last_act_v_reward = last_act_u_v_reward = 0.0;
}


/// Link function: inverse hyperbolic sine of the argument.
///
/// @param z_over_s Value to transform; callers in this file pass a parameter
///        already divided by `link_scale`.
/// @return asinh(z_over_s).
double AdaptiveNode::link_asinh(double z_over_s) {
    const double linked = std::asinh(z_over_s);
    return linked;
}

/// Derivative of the asinh link with respect to its argument.
///
/// @param z_over_s Point at which the derivative is evaluated.
/// @return 1 / sqrt(1 + z_over_s^2).
double AdaptiveNode::link_asinh_prime(double z_over_s) {
    const double radicand = 1.0 + z_over_s * z_over_s;
    return 1.0 / std::sqrt(radicand);
}


/// Predicted value of a joint action: the sum, over the first `n` positions,
/// of the asinh link applied to the selected parameter divided by
/// `link_scale`.
///
/// @param a Joint action; `a[i]` is read for every i in [0, n), so it must
///        hold at least `n` entries (not checked here).
/// @return The accumulated link values.
double AdaptiveNode::eta_hat(const std::vector<int>& a) const {
    double total = 0.0;
    for (int agent = 0; agent < n; ++agent) {
        const int slot = getIndex(agent, a[agent]);
        total += link_asinh(theta_[slot] / link_scale);
    }
    return total;
}

/// Partial derivative of one eta_hat() summand with respect to its parameter.
///
/// By the chain rule on asinh(theta / link_scale) this is
/// (1 / link_scale) * asinh'(theta / link_scale).
///
/// @param agent_idx  First argument forwarded to getIndex().
/// @param action_idx Second argument forwarded to getIndex().
/// @return The derivative evaluated at the current parameter value.
double AdaptiveNode::d_eta_d_theta(int agent_idx, int action_idx) const {
    const int slot = getIndex(agent_idx, action_idx);
    const double scaled = theta_[slot] / link_scale;
    const double slope = link_asinh_prime(scaled);
    return (1.0 / link_scale) * slope;
}


/// eta_hat() of a joint action minus a quadratic penalty on the parameters
/// that the action selects.
///
/// When `lambda_reg <= 0` the plain eta_hat() value is returned. Otherwise
/// the result is eta_hat(a) - 0.5 * lambda_reg * sum_i (theta_i / link_scale)^2,
/// where the sum runs over the parameters selected by `a`.
///
/// @param a Joint action; `a[i]` is read for every i in [0, n).
/// @return The (optionally) regularized prediction.
double AdaptiveNode::getTildeEtaWithRegularizer(const std::vector<int>& a) const {
    const double fit = eta_hat(a);
    if (lambda_reg <= 0.0) {
        return fit;
    }

    // The squared scale is the same for every term, so compute it once.
    const auto scale_sq = link_scale * link_scale;

    double penalty = 0.0;
    for (int agent = 0; agent < n; ++agent) {
        const double z = theta_[getIndex(agent, a[agent])];
        penalty += (z * z) / scale_sq;
    }
    return fit - 0.5 * lambda_reg * penalty;
}


/// Rebuilds the three perturbed copies of `last_action`.
///
/// Two positions are drawn from `rng` (distinct whenever n >= 2). Then:
///   - `last_action_u`   is `last_action` with the first position changed,
///   - `last_action_v`   is `last_action` with the second position changed,
///   - `last_action_u_v` is `last_action` with both positions changed.
/// If `last_action` does not have `n` entries it is first reset to
/// max(0, n) zeros. With n <= 0 or k <= 0 the three vectors are plain copies
/// of `last_action`; with k == 1 no entry can change, so they also stay
/// copies.
void AdaptiveNode::resampleDirectionsAndNeighbors() {
    if (static_cast<int>(last_action.size()) != n) {
        last_action.assign(std::max(0, n), 0);
    }

    // Every neighbour starts out as an unmodified copy of the base action.
    last_action_u = last_action;
    last_action_v = last_action;
    last_action_u_v = last_action;

    if (n <= 0 || k <= 0) {
        return;
    }

    std::uniform_int_distribution<int> pick_agent(0, n - 1);
    std::uniform_int_distribution<int> pick_action(0, k - 1);

    // Draw the two positions; redraw the second until it differs from the
    // first, which is only possible when there are at least two positions.
    const int first = pick_agent(rng);
    int second = first;
    if (n >= 2) {
        do {
            second = pick_agent(rng);
        } while (second == first);
    }

    // Replaces vec[pos] with a different action value: up to eight random
    // draws, then a deterministic "next value modulo k" fallback.
    auto perturb = [&](std::vector<int>& vec, int pos) {
        if (k <= 1) {
            return;
        }
        const int previous = vec[pos];
        int candidate = previous;
        int attempts = 0;
        while (attempts < 8 && candidate == previous) {
            candidate = pick_action(rng);
            ++attempts;
        }
        if (candidate == previous) {
            candidate = (previous + 1) % k;
        }
        vec[pos] = candidate;
    };

    perturb(last_action_u, first);
    perturb(last_action_v, second);

    // NOTE(review): the combined vector receives its own fresh draws at both
    // positions; its values there are not copied from last_action_u /
    // last_action_v, so for k > 2 they may differ from those vectors. When
    // n == 1 both perturbations hit the same position. Confirm this is
    // intended.
    perturb(last_action_u_v, first);
    perturb(last_action_u_v, second);
}


/// Performs one gradient step on `theta_` from the currently stored rewards.
///
/// Four squared-error terms are differentiated, each comparing a quantity
/// built from eta_hat() with the matching quantity built from rewards:
///   - eta_hat(action)                         vs act_reward
///   - eta_hat(last_action)                    vs last_act_reward
///   - first difference  (u - last)            vs the same difference of rewards
///   - second difference (u_v - v) - (u - last) vs the same difference of rewards
/// The summed gradient, plus `lambda_wd * theta` when `lambda_wd > 0`, is
/// applied with an Adam-style update using the per-parameter state in `opt_`.
///
/// Does nothing when the node has no parameters (`n <= 0` or `k <= 0`).
void AdaptiveNode::updateThetaGivenY() {
    // Guard both dimensions, matching resampleDirectionsAndNeighbors(): with
    // k == 0 the gradient buffer below would be empty while still being
    // indexed, and a negative n * k would become a huge vector size.
    if (n <= 0 || k <= 0) return;

    const int total_params = n * k;

    // Regression targets taken from the stored rewards.
    const double y1 = act_reward;
    const double y2 = last_act_reward;
    const double g_target = last_act_u_reward - last_act_reward;
    const double h_target = (last_act_u_v_reward - last_act_v_reward)
                          - (last_act_u_reward   - last_act_reward);

    // Model predictions for the five joint actions involved.
    const double e_a      = eta_hat(action);
    const double e_last   = eta_hat(last_action);
    const double e_last_u = eta_hat(last_action_u);
    const double e_last_v = eta_hat(last_action_v);
    const double e_last_uv= eta_hat(last_action_u_v);

    const double g_hat = (e_last_u - e_last);
    const double h_hat = (e_last_uv - e_last_v) - (e_last_u - e_last);

    // Derivative of each squared error with respect to its prediction.
    const double l1_grad = 2.0 * (e_a    - y1);
    const double l2_grad = 2.0 * (e_last - y2);
    const double lg_grad = 2.0 * (g_hat  - g_target);
    const double lh_grad = 2.0 * (h_hat  - h_target);

    std::vector<double> grads(total_params, 0.0);

    // Adds coeff * d(eta_hat(a)) / d(theta) into the gradient buffer; only
    // the parameters selected by `a` are touched.
    auto accum = [&](const std::vector<int>& a, double coeff) {
        if (coeff == 0.0) return;
        for (int i = 0; i < n; ++i) {
            int idx = getIndex(i, a[i]);
            grads[idx] += coeff * d_eta_d_theta(i, a[i]);
        }
    };

    // Value terms.
    accum(action, l1_grad);
    accum(last_action, l2_grad);

    // First-difference term: g_hat = eta(u) - eta(last).
    accum(last_action_u, +lg_grad);
    accum(last_action, -lg_grad);

    // Second-difference term: h_hat = eta(u_v) - eta(v) - eta(u) + eta(last).
    accum(last_action_u_v, +lh_grad);
    accum(last_action_v, -lh_grad);
    accum(last_action_u, -lh_grad);
    accum(last_action, +lh_grad);

    for (int idx = 0; idx < total_params; ++idx) {
        double gtheta = grads[idx];
        // Parameters with no gradient and no weight decay are left untouched,
        // including their optimizer step counter.
        if (gtheta == 0.0 && lambda_wd <= 0.0) continue;

        // Weight decay is folded into the gradient before the moment updates.
        if (lambda_wd > 0.0) {
            gtheta += lambda_wd * theta_[idx];
        }

        Momentum& st = opt_[idx];
        st.t += 1;
        st.m = beta1 * st.m + (1.0 - beta1) * gtheta;
        st.v = beta2 * st.v + (1.0 - beta2) * (gtheta * gtheta);

        // Bias-corrected first and second moment estimates.
        const double mhat = st.m / (1.0 - std::pow(beta1, st.t));
        const double vhat = st.v / (1.0 - std::pow(beta2, st.t));

        theta_[idx] -= lr * (mhat / (std::sqrt(vhat) + eps));
    }
}

} 
