#include "police/nnlp_policy_encoding.hpp"

#include "police/action.hpp"
#include "police/constants.hpp"
#include "police/constraint_factories.hpp"
#include "police/ffnn_lp_encoder.hpp"
#include "police/nnlp_encoders.hpp"
#include "police/nnlp_model_encoding.hpp"
#include "police/storage/id_map.hpp"
#include "police/variable_substitution.hpp"

#include <algorithm>
#include <cassert>
#include <cmath>
#include <iterator>
#include <utility>

namespace police {

/// Encodes the policy's feed-forward network into `lp`, fed by the state
/// input variables recorded in `info`, and returns the LP variables holding
/// the per-action outputs (as computed by get_action_vars).
vector<size_t> add_nn(
    NNLP& lp,
    const EncodingInformation& info,
    const NeuralNetworkPolicy& policy)
{
    // Arrange the state variables in the order the network's input layer
    // expects them.
    const vector<size_t> ordered_inputs =
        get_vars_in_input_order(info.in_vars, policy.get_input());
    // encode_ffnn_in_lp reports an index offset that get_action_vars uses to
    // locate the output variables (presumably the first output-layer
    // variable -- confirm against encode_ffnn_in_lp).
    const size_t output_offset =
        encode_ffnn_in_lp(lp, policy.get_nn(), ordered_inputs);
    return get_action_vars(policy.get_output(), output_offset);
}

namespace {
/// Adds `num` (> 0) boolean variables to `lp` and returns the index of the
/// first one. Callers address the rest as `first + i`, which relies on
/// add_variable handing out consecutive indices.
size_t create_indicator_variables(NNLPLP& lp, size_t num)
{
    assert(num > 0u);
    const size_t first = lp.add_variable(BoolType());
    for (size_t created = 1; created < num; ++created) {
        lp.add_variable(BoolType());
    }
    return first;
}

/// Adds one selection indicator per action and returns the first one.
/// Indicator `first + a` set to 1 implies that action a's output value is
/// maximal over ALL actions (no applicability filtering). Ties go to the
/// smallest index: earlier actions must be beaten strictly, later ones only
/// matched.
size_t add_policy_selection_noapp_indicators(
    NNLPLP& lp,
    const vector<size_t>& action_val_vars)
{
    const size_t num_actions = action_val_vars.size();
    const size_t first_indicator = create_indicator_variables(lp, num_actions);
    LP* const target = lp.get_underlying_lp();
    for (size_t chosen = 0; chosen < num_actions; ++chosen) {
        const size_t indicator = first_indicator + chosen;
        const size_t chosen_value = action_val_vars[chosen];
        for (size_t rival = 0; rival < num_actions; ++rival) {
            if (rival == chosen) {
                continue;
            }
            const size_t rival_value = action_val_vars[rival];
            if (rival < chosen) {
                // Lower-indexed rivals win ties, so they must be strictly
                // smaller.
                target->add_constraint(LP::indicator_constraint_type(
                    indicator,
                    true,
                    less_constraint(rival_value, chosen_value)));
            } else {
                target->add_constraint(LP::indicator_constraint_type(
                    indicator,
                    true,
                    less_equal_constraint(rival_value, chosen_value)));
            }
        }
    }
    return first_indicator;
}

struct ApplicabilityFilterEncoding {
    ApplicabilityFilterEncoding(
        NNLPLP* lp,
        vector<Action>::const_iterator first,
        vector<Action>::const_iterator last,
        const EncodingInformation* info)
        : lp_(lp)
        , first_(first)
        , last_(last)
        , info_(info)
    {
        const auto& variables = lp_->get_variable_space();
        is_integer_.resize(variables.size(), false);
        for (int var = variables.size() - 1; var >= 0; --var) {
            assert(
                variables[var].type.is_int() || variables[var].type.is_real());
            is_integer_[var] = variables[var].type.is_int();
        }
    }

    ApplicabilityFilterEncoding(const ApplicabilityFilterEncoding&) = delete;

    size_t operator()()
    {
        const size_t edge_inapplicability = add_inapplicable_edges_indicators();
        const size_t action_inapplicability =
            add_inapplicable_actions_indicators(edge_inapplicability);
        const size_t base_var =
            create_indicator_variables(*lp_, info_->action_vars.size());
        for (size_t a = 0; a < info_->action_vars.size(); ++a) {
            add_action_selection_constraint(
                base_var + a,
                action_inapplicability,
                a);
        }
        return base_var;
    }

private:
    void add_action_selection_constraint(
        size_t indicator_var,
        size_t inapplicability_var,
        size_t action)
    {
        for (size_t x = 0; x < action; ++x) {
            LinearConstraintDisjunction dis;
            // if indicator_var is 1
            dis.push_back(upper_bound_constraint(indicator_var, 0));
            // then: action value is strictly greater than that of x
            dis.push_back(less_constraint(
                info_->action_vars[x],
                info_->action_vars[action]));
            // or x is not applicable
            dis.push_back(lower_bound_constraint(inapplicability_var + x, 1));
            lp_->add_constraint(dis);
        }
        for (size_t x = action + 1; x < info_->action_vars.size(); ++x) {
            LinearConstraintDisjunction dis;
            // if indicator_var is 1
            dis.push_back(upper_bound_constraint(indicator_var, 0));
            // then: action value is greater or equal than that of x
            dis.push_back(less_equal_constraint(
                info_->action_vars[x],
                info_->action_vars[action]));
            // or x is not applicable
            dis.push_back(lower_bound_constraint(inapplicability_var + x, 1));
            lp_->add_constraint(dis);
        }
    }

    size_t add_inapplicable_edges_indicators()
    {
        const size_t base_var =
            create_indicator_variables(*lp_, std::distance(first_, last_));
        size_t i = 0;
        for (auto it = first_; it != last_; ++it, ++i) {
            add_unsatisfied_guard_constraint(base_var + i, it->guard);
        }
        return base_var;
    }

    void add_unsatisfied_guard_constraint(
        size_t indicator_var,
        const LinearConstraintConjunction& guard)
    {
        LinearConstraint at_least_one(LinearConstraint::GREATER_EQUAL);
        at_least_one.rhs = 1.;
        for (const auto& constraint : guard) {
            const auto inapplicable_indicator = get_unsatisfied_indicator(
                substitute_vars(constraint, info_->in_vars));
            at_least_one.insert(inapplicable_indicator, 1.);
        }
        lp_->get_underlying_lp()->add_constraint(LP::indicator_constraint_type(
            indicator_var,
            true,
            std::move(at_least_one)));
    }

    size_t add_inapplicable_actions_indicators(size_t inapplicable_edge_var)
    {
        const size_t base_var =
            create_indicator_variables(*lp_, info_->action_vars.size());
        size_t edge_idx = 0;
        for (size_t a = 0; a < info_->action_vars.size(); ++a) {
            add_inapplicable_action_constraint(
                base_var + a,
                inapplicable_edge_var,
                a,
                edge_idx);
        }
        return base_var;
    }

    void add_inapplicable_action_constraint(
        size_t indicator_var,
        size_t inapplicable_guard_var,
        size_t action,
        size_t& edge_idx)
    {
        assert(
            edge_idx < std::distance(first_, last_) &&
            (first_ + edge_idx)->label == action);
        LinearConstraint all_constraint(LinearConstraint::GREATER_EQUAL);
        size_t num_edges = 0;
        for (auto it = first_ + edge_idx; it != last_ && it->label == action;
             ++it, ++edge_idx, ++num_edges) {
            all_constraint.insert(inapplicable_guard_var + edge_idx, 1.);
        }
        all_constraint.rhs = num_edges;
        lp_->get_underlying_lp()->add_constraint(LP::indicator_constraint_type(
            indicator_var,
            true,
            std::move(all_constraint)));
    }

    size_t get_unsatisfied_indicator(const LinearConstraint& constraint)
    {
        const auto id = constraint_ids_.insert(constraint);
        if (id.second) {
            LP* ulp = lp_->get_underlying_lp();
            const auto var = create_indicator_variables(*lp_, 1);
            const bool integer = std::all_of(
                constraint.begin(),
                constraint.end(),
                [&](const auto& elem) {
                    assert(elem.first < is_integer_.size());
                    return is_integer_[elem.first];
                });
            const real_t offset = integer ? 1. : LP_PRECISION;
            auto add_greater_indicator = [&](size_t indicator_var) {
                LinearConstraint negation(constraint);
                negation.type = LinearConstraint::GREATER_EQUAL;
                negation.rhs += offset;
                ulp->add_constraint(LP::indicator_constraint_type(
                    indicator_var,
                    true,
                    std::move(negation)));
            };
            auto add_less_indicator = [&](size_t indicator_var) {
                LinearConstraint negation(constraint);
                negation.type = LinearConstraint::LESS_EQUAL;
                negation.rhs -= offset;
                ulp->add_constraint(LP::indicator_constraint_type(
                    indicator_var,
                    true,
                    std::move(negation)));
            };
            switch (constraint.type) {
            case LinearConstraint::LESS_EQUAL: {
                add_greater_indicator(var);
                break;
            }
            case LinearConstraint::GREATER_EQUAL: {
                add_less_indicator(var);
                break;
            }
            case LinearConstraint::EQUAL: {
                if (compute_lb(*lp_, constraint) == constraint.rhs) {
                    add_greater_indicator(var);
                } else if (compute_ub(*lp_, constraint) == constraint.rhs) {
                    add_less_indicator(var);
                } else {
                    const size_t less = create_indicator_variables(*lp_, 1);
                    add_less_indicator(less);
                    const size_t greater = create_indicator_variables(*lp_, 1);
                    add_greater_indicator(greater);
                    LinearConstraint or_constr(LinearConstraint::GREATER_EQUAL);
                    or_constr.rhs = 1.;
                    or_constr.insert(greater, 1.);
                    or_constr.insert(less, 1.);
                    ulp->add_constraint(LP::indicator_constraint_type(
                        var,
                        true,
                        std::move(or_constr)));
                }
                break;
            }
            }
            indicator_var_.push_back(var);
        }
        assert(indicator_var_.size() == constraint_ids_.size());
        assert(id.first->second < indicator_var_.size());
        return indicator_var_[id.first->second];
    }

    id_map<LinearConstraint> constraint_ids_;
    vector<size_t> indicator_var_;

    vector<bool> is_integer_;

    NNLPLP* lp_;
    vector<Action>::const_iterator first_;
    vector<Action>::const_iterator last_;
    const EncodingInformation* info_;
};

/// Ties the edge indicator `result_var` to the edge it represents. If
/// result_var is 1, then:
///   * the policy selection indicator `sel_var` is 1,
///   * every guard conjunct holds on the current-state inputs, and
///   * the next-state assignment of one of the edge's outcomes holds. With a
///     single outcome this is encoded directly; otherwise one auxiliary
///     indicator per outcome is added, and at least one of them must be 1.
void add_edge_indicator_constraints(
    NNLPLP& lp,
    size_t result_var,
    const Action& edge,
    size_t sel_var,
    const EncodingInformation& info)
{
    LP* const ulp = lp.get_underlying_lp();
    // Post `constraint` as implied by `indicator` being 1.
    const auto implied_by = [ulp](size_t indicator, auto&& constraint) {
        ulp->add_constraint(LP::indicator_constraint_type(
            indicator,
            true,
            std::forward<decltype(constraint)>(constraint)));
    };

    // The edge's action must be the one chosen by the policy.
    LinearConstraint selected(LinearConstraint::GREATER_EQUAL);
    selected.rhs = 1.;
    selected.insert(sel_var, 1.);
    implied_by(result_var, selected);

    // The guard has to hold in the current state.
    for (const auto& conjunct : edge.guard) {
        implied_by(result_var, substitute_vars(conjunct, info.in_vars));
    }

    assert(!edge.outcomes.empty());
    const size_t num_outcomes = edge.outcomes.size();
    if (num_outcomes == 1u) {
        // No auxiliary outcome indicator is needed for a single outcome.
        const auto assignment = get_assignment_constraints(info, edge.outcomes[0]);
        for (const auto& cond : assignment) {
            implied_by(result_var, cond);
        }
        return;
    }

    const size_t first_outcome_var =
        create_indicator_variables(lp, num_outcomes);
    LinearConstraint some_outcome(LinearConstraint::GREATER_EQUAL);
    some_outcome.rhs = 1.;
    for (size_t k = 0; k < num_outcomes; ++k) {
        const size_t outcome_var = first_outcome_var + k;
        const auto assignment = get_assignment_constraints(info, edge.outcomes[k]);
        for (const auto& cond : assignment) {
            implied_by(outcome_var, cond);
        }
        some_outcome.insert(outcome_var, 1.);
    }
    implied_by(result_var, some_outcome);
}

/// Adds one indicator per edge in [first, last), constrains each through
/// add_edge_indicator_constraints, and requires at least one edge indicator
/// to be 1. Edge `e` is linked to the selection indicator
/// `action_selection_var + e.label`.
void encode_all_transitions(
    NNLPLP& lp,
    const EncodingInformation& info,
    vector<Action>::const_iterator first,
    vector<Action>::const_iterator last,
    size_t action_selection_var)
{
    const size_t first_edge_var =
        create_indicator_variables(lp, std::distance(first, last));
    LinearConstraint some_edge(LinearConstraint::GREATER_EQUAL);
    some_edge.rhs = 1.;
    size_t edge_var = first_edge_var;
    for (auto edge = first; edge != last; ++edge, ++edge_var) {
        const size_t selection_var = action_selection_var + edge->label;
        add_edge_indicator_constraints(
            lp,
            edge_var,
            *edge,
            selection_var,
            info);
        some_edge.insert(edge_var, 1.);
    }
    // Some transition has to be taken.
    lp.add_constraint(some_edge);
}

} // namespace

/// Encodes the transitions in [first, last), where the policy picks the
/// plain argmax over all action outputs. Applicability of the chosen action
/// is not considered during selection.
void encode_all_transitions_without_app_filter(
    NNLPLP& lp,
    const EncodingInformation& info,
    vector<Action>::const_iterator first,
    vector<Action>::const_iterator last)
{
    const size_t selection_base =
        add_policy_selection_noapp_indicators(lp, info.action_vars);
    encode_all_transitions(lp, info, first, last, selection_base);
}

/// Encodes the transitions in [first, last), where the policy picks the
/// argmax over the actions not flagged inapplicable (see
/// ApplicabilityFilterEncoding).
void encode_all_transitions_with_app_filter(
    NNLPLP& lp,
    const EncodingInformation& info,
    vector<Action>::const_iterator first,
    vector<Action>::const_iterator last)
{
    // The filter object is only needed while it emits its constraints.
    const size_t selection_base =
        ApplicabilityFilterEncoding(&lp, first, last, &info)();
    encode_all_transitions(lp, info, first, last, selection_base);
}

} // namespace police
