#include "police/static_variable_remover.hpp"
#include "police/addtree_policy.hpp"
#include "police/base_types.hpp"
#include "police/cg_constant_compressor.hpp"
#include "police/cg_policy.hpp"
#include "police/expressions/const_expression_folder.hpp"
#include "police/expressions/constants.hpp"
#include "police/expressions/expression.hpp"
#include "police/expressions/expression_transformer.hpp"
#include "police/expressions/variable.hpp"
#include "police/linear_condition.hpp"
#include "police/macros.hpp"
#include "police/model.hpp"
#include "police/nn_policy.hpp"
#include "police/smt_factory.hpp"
#include "police/storage/value.hpp"
#include "police/storage/variable_space.hpp"
#include "police/verification_property.hpp"

#include <algorithm>

namespace police {

namespace {

// Threshold on (upper bound - lower bound) of a bounded-int variable: wider
// domains are not enumerated with the SMT solver when searching for a unique
// initial value, and such variables are kept (treated as non-static).
// Internal linkage already follows from the enclosing anonymous namespace.
constexpr int_t MAX_DOMAIN_SIZE = 10;

// Clears the static flag of every variable that is the target (var_id) of an
// assignment in any outcome of `action`: such a variable may change value and
// therefore cannot be replaced by a constant.
void mark_non_static_variables(vector<bool>& is_static, const Action& action)
{
    for (const auto& out : action.outcomes) {
        for (const auto& asg : out.assignments) {
            is_static[asg.var_id] = false;
        }
    }
}

// Classification returned by the remove_static_conditions overloads after the
// static variables of a condition have been substituted away.
enum SimplifiedConditionResult {
    SIMPLIFIED,   // condition still depends on non-static variables
    ALWAYS_TRUE,  // condition reduced to a tautology
    ALWAYS_FALSE, // condition reduced to a contradiction
};

// Rewrites a single linear constraint in place.
//
// Each (variable, coefficient) term whose variable is static is folded into
// the right-hand side as coefficient * value; every remaining term has its
// variable renumbered through `new_id`. Terms are compacted in order.
//
// Returns SIMPLIFIED if any variable term remains. Otherwise the constraint
// reads "0 <type> rhs" and is classified as ALWAYS_TRUE / ALWAYS_FALSE.
SimplifiedConditionResult remove_static_conditions(
    LinearConstraint& constraint,
    const vector<int_t>& value,
    const vector<bool>& is_static,
    const vector<size_t>& new_id)
{
    auto write = constraint.begin();
    for (auto read = constraint.begin(); read != constraint.end(); ++read) {
        const auto var = read->first;
        if (is_static[var]) {
            constraint.rhs -= read->second * value[var];
            continue;
        }
        read->first = new_id[var];
        if (write != read) {
            *write = *read;
        }
        ++write;
    }
    constraint.erase(write, constraint.end());

    if (!constraint.empty()) {
        return SIMPLIFIED;
    }

    // Only the constant comparison "0 <type> rhs" is left.
    switch (constraint.type) {
    case LinearConstraint::EQUAL:
        return constraint.rhs != 0. ? ALWAYS_FALSE : ALWAYS_TRUE;
    case LinearConstraint::GREATER_EQUAL:
        return constraint.rhs > 0. ? ALWAYS_FALSE : ALWAYS_TRUE;
    case LinearConstraint::LESS_EQUAL:
        return constraint.rhs < 0. ? ALWAYS_FALSE : ALWAYS_TRUE;
    }
    return ALWAYS_TRUE;
}

// Rewrites a conjunction of linear constraints in place.
//
// Tautological constraints are dropped and the rest are compacted in order.
// As soon as one constraint is a contradiction, ALWAYS_FALSE is returned
// immediately. The conjunction is then only partially rewritten and its
// contents should not be used further.
// An empty result conjunction is reported as ALWAYS_TRUE.
SimplifiedConditionResult remove_static_conditions(
    LinearConstraintConjunction& condition,
    const vector<int_t>& value,
    const vector<bool>& is_static,
    const vector<size_t>& new_id)
{
    size_t kept = 0;
    for (size_t idx = 0; idx < condition.size(); ++idx) {
        switch (remove_static_conditions(
            condition[idx], value, is_static, new_id)) {
        case ALWAYS_FALSE:
            return ALWAYS_FALSE;
        case ALWAYS_TRUE:
            // Tautology: contributes nothing to the conjunction.
            break;
        case SIMPLIFIED:
            if (kept != idx) {
                condition[kept] = std::move(condition[idx]);
            }
            ++kept;
            break;
        }
    }
    condition.erase(condition.begin() + kept, condition.end());
    return condition.empty() ? ALWAYS_TRUE : SIMPLIFIED;
}

// Rewrites a disjunction of conjunctions in place.
//
// Every disjunct is simplified with the conjunction overload:
//   - unsatisfiable disjuncts are dropped;
//   - the remaining disjuncts are compacted in order;
//   - if some disjunct becomes a tautology, the whole condition is collapsed to
//     that single (now empty) conjunction and ALWAYS_TRUE is returned.
//
// Returns ALWAYS_FALSE when no disjunct survives (the condition is then
// empty), SIMPLIFIED otherwise. In every case the condition is left
// well-formed, because some callers ignore the return value.
SimplifiedConditionResult remove_static_conditions(
    LinearCondition& condition,
    const vector<int_t>& value,
    const vector<bool>& is_static,
    const vector<size_t>& new_id)
{
    size_t j = 0;
    for (size_t i = 0; i < condition.size(); ++i) {
        const auto res =
            remove_static_conditions(condition[i], value, is_static, new_id);
        if (res == ALWAYS_TRUE) {
            // Returning right away would leave disjuncts after i with their
            // old variable ids and slots j..i-1 in a moved-from state.
            // Collapse instead to the tautological disjunct, which the
            // conjunction overload leaves empty.
            if (i != 0) {
                condition[0] = std::move(condition[i]);
            }
            condition.erase(condition.begin() + 1, condition.end());
            return ALWAYS_TRUE;
        }
        if (res == SIMPLIFIED) {
            if (j != i) {
                condition[j] = std::move(condition[i]);
            }
            ++j;
        }
        // ALWAYS_FALSE disjuncts contribute nothing and are dropped.
    }
    condition.erase(condition.begin() + j, condition.end());
    if (condition.empty()) {
        return ALWAYS_FALSE;
    }
    return SIMPLIFIED;
}

// Rewrites a general expression in place.
//
// Every variable reference is replaced:
//   - static variables become a constant holding their fixed value;
//   - all other variables are renumbered through `new_id`.
// The result is then passed through the constant-expression folder.
//
// The folded expression is not inspected afterwards, so this overload always
// reports SIMPLIFIED.
SimplifiedConditionResult remove_static_conditions(
    expressions::Expression& condition,
    const vector<int_t>& value,
    const vector<bool>& is_static,
    const vector<size_t>& new_id)
{
    class StaticSubstituter : public expressions::ExpressionTransformer {
    public:
        StaticSubstituter(
            const vector<bool>& static_flags,
            const vector<int_t>& fixed_values,
            const vector<size_t>& id_map)
            : static_flags_(static_flags)
            , fixed_values_(fixed_values)
            , id_map_(id_map)
        {
        }

        void visit(expressions::Expression& ptr, expressions::Variable& expr)
            override
        {
            const auto id = expr.var_id;
            if (!static_flags_[id]) {
                ptr = expressions::Variable(id_map_[id]);
                return;
            }
            ptr = expressions::Constant(Value(fixed_values_[id]));
        }

    private:
        const vector<bool>& static_flags_;
        const vector<int_t>& fixed_values_;
        const vector<size_t>& id_map_;
    };

    StaticSubstituter substituter(is_static, value, new_id);
    condition.transform(substituter);

    // Collapse sub-expressions that became constant after substitution.
    expressions::ConstExpressionFolder folder;
    condition.transform(folder);

    return SIMPLIFIED;
}

// Narrows `is_static` to the variables whose value is uniquely determined by
// the initial condition `expr`, and returns those values.
//
// Each variable still flagged static is handled as follows:
//   - not a bounded integer: it cannot be enumerated, so it is conservatively
//     marked non-static;
//   - (upper bound - lower bound) > MAX_DOMAIN_SIZE: also conservatively
//     marked non-static;
//   - otherwise: each domain value v is tested by checking
//     `expr && var == v` with the SMT solver. The variable stays static only
//     if exactly one value is satisfiable.
//
// Returns a vector indexed by variable id: the unique value for variables
// that remain static, -1 for all others.
// Raises POLICE_RUNTIME_ERROR("unsatisfiable start condition") if no value in
// a candidate variable's domain is consistent with `expr`.
vector<int_t> mark_variable_with_unique_value(
    vector<bool>& is_static,
    const expressions::Expression& expr,
    const VariableSpace& variables,
    std::shared_ptr<SMTFactory> factory)
{
    vector<int_t> value(is_static.size(), -1);
    auto smt = factory->make_unique();
    smt->add_variables(variables);
    for (size_t var = 0; var < is_static.size(); ++var) {
        if (!is_static[var]) {
            continue;
        }
        const auto& t = variables.get_type(var);
        if (!t.is_bounded_int()) {
            // Only bounded integer domains can be enumerated. A variable we
            // cannot prove constant must be kept; otherwise it would be
            // replaced by the placeholder -1 everywhere.
            is_static[var] = false;
            continue;
        }
        const int_t lb = t.get_lower_bound();
        const int_t ub = t.get_upper_bound();
        if (ub - lb > MAX_DOMAIN_SIZE) {
            is_static[var] = false;
            continue;
        }
        bool sat = false;
        int_t val = 0;
        // `i <= ub` guards against an empty domain (lb > ub). The explicit
        // break at ub avoids overflowing ++i when ub is the largest int_t.
        for (int_t i = lb; i <= ub; ++i) {
            expressions::Expression aux =
                expr && expressions::equal(
                            expressions::Variable(var),
                            expressions::Constant(Value(i)));
            smt->push_snapshot();
            smt->add_constraint(aux);
            const auto solvable = smt->solve();
            smt->pop_snapshot();
            if (solvable == SMT::Status::SAT) {
                if (sat) {
                    // A second feasible value: the variable is not determined.
                    is_static[var] = false;
                    break;
                }
                sat = true;
                val = i;
            }
            if (i == ub) {
                break;
            }
        }
        if (is_static[var]) {
            if (!sat) {
                POLICE_RUNTIME_ERROR("unsatisfiable start condition");
            }
            value[var] = val;
        }
    }
    return value;
}

} // namespace

// Determines which variables of `model` can be removed and assigns new ids.
//
// Steps:
//   1. Start with every variable flagged static.
//   2. Clear the flag of every variable assigned by some action.
//   3. Narrow the remaining flags with the initial condition `init` (via the
//      SMT factory) and record the fixed values in value_.
//   4. Number the surviving non-static variables densely from 0 in ascending
//      original id; static variables keep new_id_ == -1.
StaticVariableRemover::StaticVariableRemover(
    const Model& model,
    const expressions::Expression& init,
    std::shared_ptr<SMTFactory> factory)
    : is_static_(model.variables.size(), true)
    , new_id_(model.variables.size(), -1)
{
    for (const Action& action : model.actions) {
        mark_non_static_variables(is_static_, action);
    }

    value_ = mark_variable_with_unique_value(
        is_static_,
        init,
        model.variables,
        factory);

    size_t next_id = 0;
    for (size_t var = 0; var < is_static_.size(); ++var) {
        if (is_static_[var]) {
            continue;
        }
        new_id_[var] = next_id;
        ++next_id;
    }
}

bool StaticVariableRemover::has_static_variables() const
{
    return num_static_variables() > 0u;
}

// Number of variables flagged static, i.e. the number apply(Model&) drops
// from the variable space.
size_t StaticVariableRemover::num_static_variables() const
{
    const auto flagged = std::count_if(
        is_static_.begin(),
        is_static_.end(),
        [](bool is_static) { return is_static; });
    return static_cast<size_t>(flagged);
}

// Compacts `var_ids` in place: ids of static variables are removed and the
// remaining ones are replaced by their new ids, preserving order.
void StaticVariableRemover::substitute_variable_ids(vector<size_t>& var_ids)
{
    auto write = var_ids.begin();
    for (const size_t old_id : var_ids) {
        if (is_static_[old_id]) {
            continue;
        }
        *write = new_id_[old_id];
        ++write;
    }
    var_ids.erase(write, var_ids.end());
}

// Removes all static variables from `model`.
//
// First, every action guard is simplified:
//   - actions whose guard became ALWAYS_FALSE are dropped;
//   - surviving actions have their assignment target ids renumbered.
// Then the variable space and the value-name table are rebuilt with only the
// non-static variables, iterating in ascending original id.
void StaticVariableRemover::apply(Model& model)
{
    size_t kept = 0;
    for (size_t idx = 0; idx < model.actions.size(); ++idx) {
        auto& action = model.actions[idx];
        const auto result = remove_static_conditions(
            action.guard,
            value_,
            is_static_,
            new_id_);
        if (result == ALWAYS_FALSE) {
            continue;
        }
        // NOTE(review): only assignment targets (var_id) are renumbered here.
        // If an assignment's right-hand side can reference variables, those
        // ids may need remapping too — confirm against Assignment's
        // definition.
        for (auto& outcome : action.outcomes) {
            for (auto& assignment : outcome.assignments) {
                assignment.var_id = new_id_[assignment.var_id];
            }
        }
        if (kept != idx) {
            model.actions[kept] = std::move(action);
        }
        ++kept;
    }
    model.actions.erase(model.actions.begin() + kept, model.actions.end());

    VariableSpace remaining;
    unordered_map<size_t, vector<identifier_name_t>> remaining_names;
    size_t next_id = 0;
    for (size_t var = 0; var < model.variables.size(); ++var) {
        if (is_static_[var]) {
            continue;
        }
        remaining.add_variable(
            model.variables.get_name(var),
            model.variables.get_type(var));
        auto named = model.value_names.find(var);
        if (named != model.value_names.end()) {
            remaining_names[next_id] = std::move(named->second);
        }
        ++next_id;
    }
    model.variables = std::move(remaining);
    model.value_names = std::move(remaining_names);
}

// Removes static variables from the start, avoid and reach conditions of a
// verification property. Each member is dispatched to the apply() overload
// matching its type.
void StaticVariableRemover::apply(VerificationProperty& prop)
{
    apply(prop.start);
    apply(prop.avoid);
    apply(prop.reach);
}

// Substitutes static variables in `expr` by their fixed values, renumbers the
// rest and constant-folds the result. The rewritten expression is the only
// output; the classification is intentionally discarded.
void StaticVariableRemover::apply(expressions::Expression& expr)
{
    static_cast<void>(
        remove_static_conditions(expr, value_, is_static_, new_id_));
}

// Removes static variables from a linear condition in place. The rewritten
// condition is the only output; the classification is intentionally discarded.
void StaticVariableRemover::apply(LinearCondition& cond)
{
    static_cast<void>(
        remove_static_conditions(cond, value_, is_static_, new_id_));
}

// Replaces `cg` with a policy whose compute graph has every static input fixed
// to its constant value.
//
// The constants are given to the compressor as (input position, value) pairs.
// The new policy's input list keeps only the non-static variables, renumbered
// through new_id_.
void StaticVariableRemover::apply(std::shared_ptr<CGPolicy>& cg)
{
    const auto& input_vars = cg->get_input();

    vector<std::pair<size_t, real_t>> constants;
    for (size_t pos = 0; pos < input_vars.size(); ++pos) {
        const auto var = input_vars[pos];
        if (!is_static_[var]) {
            continue;
        }
        constants.emplace_back(pos, value_[var]);
    }

    auto compressed =
        cg::ConstantCompressor()(cg->get_compute_graph(), constants);

    auto inputs = input_vars;
    substitute_variable_ids(inputs);

    // All arguments are evaluated before `cg` is reassigned, so reading from
    // the old policy here is safe.
    cg = std::make_shared<CGPolicy>(
        std::move(compressed),
        std::move(inputs),
        cg->get_output(),
        cg->get_action_indices().size());
}

// Tree-ensemble policies are not supported by this pass; always raises.
void StaticVariableRemover::apply(std::shared_ptr<AddTreePolicy>& /*tree*/)
{
    POLICE_RUNTIME_ERROR(
        "static variable removal has not been implemented for tree ensembles");
}

// Neural-network policies are not supported by this pass; always raises.
void StaticVariableRemover::apply(std::shared_ptr<NeuralNetworkPolicy>& /*nn*/)
{
    POLICE_RUNTIME_ERROR(
        "static variable removal has not been implemented for NN ensembles");
}

} // namespace police
