#include <algorithm>
#include <cmath>
#include <cstddef>
#include <functional>
#include <vector>

#include "info_state_values.hpp"

namespace open_spiel {
namespace algorithms {

// Builds the list of (action, probability) pairs for this information state:
// the i-th legal action is paired with the i-th entry of current_policy.
ActionsAndProbs InfoStateValues::GetCurrentPolicy() const {
    ActionsAndProbs policy;
    policy.reserve(legal_actions.size());

    // Walk the legal actions in order, keeping a parallel index into the
    // policy vector.
    int slot = 0;
    for (const auto& action : legal_actions) {
        policy.push_back({action, current_policy[slot]});
        ++slot;
    }
    return policy;
}

// Regret-matching policy update. Increments update_count, then sets
// current_policy proportional to the positive part of cumulative_regrets.
// When no regret is positive the policy becomes uniform over legal_actions.
//
// regret_matching_plus: when true, negative cumulative regrets are first
//     clamped to zero in place (the stored regrets are modified).
void InfoStateValues::ApplyRegretMatching(const bool regret_matching_plus) {
    ++update_count;

    const auto action_count = num_actions();

    if (regret_matching_plus) {
        // Floor the stored regrets at zero.
        for (int a = 0; a < action_count; ++a) {
            if (cumulative_regrets[a] < 0.0) {
                cumulative_regrets[a] = 0.0;
            }
        }
    }

    // Normalizer: total of the strictly positive regrets.
    double positive_total = 0.0;
    for (int a = 0; a < action_count; ++a) {
        const double regret = cumulative_regrets[a];
        if (regret > 0) {
            positive_total += regret;
        }
    }

    if (positive_total > 0) {
        for (int a = 0; a < action_count; ++a) {
            const double regret = cumulative_regrets[a];
            if (regret > 0) {
                current_policy[a] = regret / positive_total;
            } else {
                current_policy[a] = 0;
            }
        }
        return;
    }

    // No positive regret anywhere: fall back to the uniform policy.
    for (int a = 0; a < action_count; ++a) {
        current_policy[a] = 1.0 / legal_actions.size();
    }
}

// Predictive regret-matching policy update. Increments update_count, then
// sets current_policy proportional to the positive part of
// (cumulative_regrets + instant_regrets). When that sum has no positive
// entry the policy becomes uniform over legal_actions.
//
// regret_matching_plus: when true, negative cumulative regrets are clamped to
//     zero in place, and the predicted regrets are floored at zero as well.
void InfoStateValues::ApplyPredictiveRegretMatching(const bool regret_matching_plus) {
    ++update_count;

    const auto action_count = num_actions();

    if (regret_matching_plus) {
        // Floor the stored regrets at zero.
        for (int a = 0; a < action_count; ++a) {
            if (cumulative_regrets[a] < 0.0) {
                cumulative_regrets[a] = 0.0;
            }
        }
    }

    // Predicted regrets: the cumulative regret plus the latest instantaneous
    // regret.
    std::vector<double> predicted(action_count, 0.0);
    for (int a = 0; a < action_count; ++a) {
        double guess = cumulative_regrets[a] + instant_regrets[a];
        if (regret_matching_plus && guess < 0.0) {
            guess = 0.0;
        }
        predicted[a] = guess;
    }

    // Normalizer: total of the strictly positive predicted regrets.
    double positive_total = 0.0;
    for (int a = 0; a < action_count; ++a) {
        if (predicted[a] > 0) {
            positive_total += predicted[a];
        }
    }

    if (positive_total > 0) {
        for (int a = 0; a < action_count; ++a) {
            if (predicted[a] > 0) {
                current_policy[a] = predicted[a] / positive_total;
            } else {
                current_policy[a] = 0;
            }
        }
        return;
    }

    // No positive predicted regret: fall back to the uniform policy.
    for (int a = 0; a < action_count; ++a) {
        current_policy[a] = 1.0 / legal_actions.size();
    }
}

// Multiplicative-weights style update. Increments update_count, then replaces
// current_policy[a] with
//     current_policy[a] * exp(eta * instant_regrets[a])
// divided by the sum of those weights over all actions.
//
// eta: step size multiplying the instantaneous regrets inside the exponent.
void InfoStateValues::ApplySoftmax(const double eta) {
    update_count++;

    // Heap-backed buffer: a runtime-sized stack array (`double v[n]`) is a
    // compiler extension and not valid standard C++.
    std::vector<double> weights(num_actions(), 0.0);
    double denominator = 0;

    for (int aidx = 0; aidx < num_actions(); ++aidx) {
        weights[aidx] = current_policy[aidx] * std::exp(eta * instant_regrets[aidx]);
        denominator += weights[aidx];
    }

    for (int aidx = 0; aidx < num_actions(); ++aidx) {
        current_policy[aidx] = weights[aidx] / denominator;
    }
}

// Increments update_count, then sets current_policy to the Euclidean
// projection of (eta * cumulative_regrets) onto the probability simplex,
// computed with the sort-and-threshold method. The projected values are
// written as-is (no renormalization afterwards).
//
// eta: scale applied to every cumulative regret before projecting.
void InfoStateValues::ApplyL2(const double eta) {
    update_count++;

    // Point to project: the scaled cumulative regrets.
    std::vector<double> y;
    for (const auto& r : cumulative_regrets) {
        y.push_back(r * eta);
    }
    if (y.empty()) {
        // Nothing to project; the threshold search below needs >= 1 entry.
        return;
    }

    // Descending-sorted copy of y and its running prefix sums.
    std::vector<double> u(y);
    std::sort(u.begin(), u.end(), std::greater<double>{});
    std::vector<double> cumsum_u;
    cumsum_u.reserve(u.size());
    double running_sum = 0.0;
    for (const double elem : u) {
        running_sum += elem;
        cumsum_u.push_back(running_sum);
    }

    // rho is the last sorted index that remains positive after the shift.
    // It is an integer index initialized to 0 so it is always valid, even if
    // no comparison below succeeds (e.g. NaN input).
    std::size_t rho = 0;
    for (std::size_t j = 0; j < u.size(); ++j) {
        if ((u[j] + (1 - cumsum_u[j]) / (j + 1)) > 0.0) {
            rho = j;
        }
    }

    // Uniform shift that makes the surviving entries sum to one.
    const double lamb = (1.0 / static_cast<double>(rho + 1)) * (1 - cumsum_u[rho]);
    for (int aidx = 0; aidx < num_actions(); ++aidx) {
        current_policy[aidx] = std::max(y[aidx] + lamb, 0.0);
    }
}

// Writes into p the normalized weights q[a] * exp(y[a]) for every action
// index a in [0, num_actions()).
//
// p: output; entries [0, num_actions()) are overwritten.
// q: reference weights; only read.
// y: exponents applied to q; only read.
// All three vectors are indexed up to num_actions() - 1 and must be at least
// that long.
void InfoStateValues::ApplyKLProjection(std::vector<double> &p, std::vector<double> &q, std::vector<double> &y) const {
    // Heap-backed buffer: a runtime-sized stack array (`double v[n]`) is a
    // compiler extension and not valid standard C++.
    std::vector<double> weights(num_actions(), 0.0);
    double denominator = 0;

    for (int aidx = 0; aidx < num_actions(); ++aidx) {
        weights[aidx] = q[aidx] * std::exp(y[aidx]);
        denominator += weights[aidx];
    }

    for (int aidx = 0; aidx < num_actions(); ++aidx) {
        p[aidx] = weights[aidx] / denominator;
    }
}

// Increments update_count, then sets current_policy to the Euclidean
// projection of (current_policy + eta * instant_regrets) onto the probability
// simplex, computed with the sort-and-threshold method. The projected values
// are written as-is (no renormalization afterwards).
//
// eta: step size multiplying the instantaneous regrets.
void InfoStateValues::ApplyL2Projection(const double eta) {
    update_count++;

    // Point to project: the current policy moved along the instant regrets.
    std::vector<double> y;
    for (int aidx = 0; aidx < num_actions(); ++aidx) {
        y.push_back(current_policy[aidx] + instant_regrets[aidx] * eta);
    }
    if (y.empty()) {
        // Nothing to project; the threshold search below needs >= 1 entry.
        return;
    }

    // Descending-sorted copy of y and its running prefix sums.
    std::vector<double> u(y);
    std::sort(u.begin(), u.end(), std::greater<double>{});
    std::vector<double> cumsum_u;
    cumsum_u.reserve(u.size());
    double running_sum = 0.0;
    for (const double elem : u) {
        running_sum += elem;
        cumsum_u.push_back(running_sum);
    }

    // rho is the last sorted index that remains positive after the shift.
    // It is an integer index initialized to 0 so it is always valid, even if
    // no comparison below succeeds (e.g. NaN input).
    std::size_t rho = 0;
    for (std::size_t j = 0; j < u.size(); ++j) {
        if ((u[j] + (1 - cumsum_u[j]) / (j + 1)) > 0.0) {
            rho = j;
        }
    }

    // Uniform shift that makes the surviving entries sum to one.
    const double lamb = (1.0 / static_cast<double>(rho + 1)) * (1 - cumsum_u[rho]);
    for (int aidx = 0; aidx < num_actions(); ++aidx) {
        current_policy[aidx] = std::max(y[aidx] + lamb, 0.0);
    }
}

// Writes into p the Euclidean projection of (q + y) onto the probability
// simplex (sort-and-threshold method), then divides the result by its own sum.
//
// p: output; entries [0, num_actions()) are overwritten.
// q: base point; only read.
// y: offset added to q before projecting; only read.
// All three vectors are indexed up to num_actions() - 1 and must be at least
// that long. With zero actions the function returns without touching p.
void InfoStateValues::ApplyL2Projection(std::vector<double> &p, std::vector<double> &q, std::vector<double> &y) const {
    // Point to project.
    std::vector<double> z;
    for (int aidx = 0; aidx < num_actions(); ++aidx) {
        z.push_back(q[aidx] + y[aidx]);
    }
    if (z.empty()) {
        // Nothing to project; the threshold search below needs >= 1 entry.
        return;
    }

    // Descending-sorted copy of z and its running prefix sums.
    std::vector<double> u(z);
    std::sort(u.begin(), u.end(), std::greater<double>{});
    std::vector<double> cumsum_u;
    cumsum_u.reserve(u.size());
    double running_sum = 0.0;
    for (const double elem : u) {
        running_sum += elem;
        cumsum_u.push_back(running_sum);
    }

    // rho is the last sorted index that remains positive after the shift.
    // Initialized to 0 so it is always a valid index, even if no comparison
    // below succeeds (e.g. NaN input).
    std::size_t rho = 0;
    for (std::size_t j = 0; j < u.size(); ++j) {
        if ((u[j] + (1 - cumsum_u[j]) / (j + 1)) > 0.0) {
            rho = j;
        }
    }

    // Uniform shift that makes the surviving entries sum to one.
    const double lamb = (1.0 / (rho + 1.0)) * (1.0 - cumsum_u[rho]);
    double denominator = 0;
    for (int aidx = 0; aidx < num_actions(); ++aidx) {
        p[aidx] = std::max(z[aidx] + lamb, 0.0);
        denominator += p[aidx];
    }
    // Renormalize to absorb floating-point drift in the projected sum.
    for (int aidx = 0; aidx < num_actions(); ++aidx) {
        p[aidx] /= denominator;
    }
}

}  // namespace algorithms
}  // namespace open_spiel