#include "police/_bits/z3_optimizer.hpp"

#if POLICE_Z3

#include "police/_bits/z3_env.hpp"
#include "police/base_types.hpp"
#include "police/constants.hpp"
#include "police/lp.hpp"
#include "police/macros.hpp"

#include <limits>
#include <memory>
#include <z3++.h>

namespace police {

/// Creates an optimizer with a default-constructed environment, no cached
/// model, and a new z3::optimize instance built on the environment's context.
///
/// NOTE(review): optimizer_ dereferences env_.context() during construction,
/// so env_ has to be initialized first. Members are initialized in declaration
/// order, not in the order written here -- confirm that env_ is declared
/// before optimizer_ in the class definition.
Z3Optimizer::Z3Optimizer()
    : env_()
    , last_model_(nullptr)
    , optimizer_(std::make_unique<z3::optimize>(*env_.context()))
{
}

/// Releases the cached model and then the optimizer inside the destructor
/// body, i.e. before any data member (env_ included) is destroyed.
///
/// NOTE(review): presumably required because both objects refer to the Z3
/// context owned by env_ and must not outlive it -- confirm.
Z3Optimizer::~Z3Optimizer()
{
    last_model_ = nullptr;
    optimizer_ = nullptr;
}

/// Returns a mutable reference to the environment owned by this optimizer.
Z3Environment& Z3Optimizer::get_environment()
{
    return env_;
}

/// Returns a read-only reference to the environment owned by this optimizer.
const Z3Environment& Z3Optimizer::get_environment() const
{
    return env_;
}

namespace {
/**
 * Asserts the domain bounds of a variable on the optimizer.
 *
 * - BoundedIntType:  adds `lower_bound <= var` and `var <= upper_bound`.
 * - BoundedRealType: adds the same two constraints, with each bound handed to
 *   Z3 as a decimal string. A lower bound of -infinity or an upper bound of
 *   +infinity is skipped: it does not restrict the variable, and its textual
 *   form ("-inf" / "inf") is not a numeral Z3 can parse.
 * - Any other alternative of VariableType adds no constraint.
 *
 * @param solver   optimizer that receives the bound constraints
 * @param var_expr Z3 expression of the variable being bounded
 * @param type     declared type of the variable
 */
void add_variable_bounds(
    z3::optimize& solver,
    z3::expr var_expr,
    const VariableType& type)
{
    const real_t infinity = std::numeric_limits<real_t>::infinity();
    std::visit(
        [&](auto&& t) {
            using T = std::decay_t<decltype(t)>;
            if constexpr (std::is_same_v<T, BoundedIntType>) {
                solver.add(var_expr >= t.lower_bound);
                solver.add(var_expr <= t.upper_bound);
            } else if constexpr (std::is_same_v<T, BoundedRealType>) {
                const real_t lower = static_cast<real_t>(t.lower_bound);
                const real_t upper = static_cast<real_t>(t.upper_bound);
                // NOTE(review): std::to_string renders floating-point values
                // with six fractional digits, so finer-grained bounds are
                // rounded before they reach Z3.
                if (lower != -infinity) {
                    solver.add(
                        var_expr >=
                        solver.ctx().real_val(std::to_string(lower).c_str()));
                }
                if (upper != infinity) {
                    solver.add(
                        var_expr <=
                        solver.ctx().real_val(std::to_string(upper).c_str()));
                }
            }
        },
        type);
}
} // namespace

/// Registers a new variable of the given type in the environment and asserts
/// its domain bounds on the optimizer.
void Z3Optimizer::add_variable(const VariableType& type)
{
    add_variable_bounds(*optimizer_, env_.add_variable(type), type);
}

namespace {
/// Builds the Z3 term `weight * var`.
///
/// Weights of exactly 1 and -1 yield the variable itself and its negation
/// respectively, without introducing a constant. The weight must be non-zero
/// (asserted).
z3::expr weighted_variable(Z3Environment& env, size_t var, real_t weight)
{
    assert(weight != 0.);
    if (weight == 1.) {
        return env.get_variable(var);
    }
    if (weight == -1.) {
        return -env.get_variable(var);
    }
    return env.get_variable(var) * env.add_constant(Value(weight));
}

/// Translates a linear combination into the Z3 sum of its weighted variables.
///
/// Terms are added in iteration order. The combination must contain at least
/// one term (asserted).
z3::expr z3_linear_expression(
    Z3Environment& env,
    const LinearCombination<size_t, real_t>& linexpr)
{
    assert(!linexpr.empty());
    auto term = linexpr.begin();
    z3::expr sum = weighted_variable(env, term->first, term->second);
    for (++term; term != linexpr.end(); ++term) {
        sum = sum + weighted_variable(env, term->first, term->second);
    }
    return sum;
}

/// Translates a linear constraint into the Z3 comparison `lhs (<=|>=|==) rhs`.
///
/// The right-hand side constant is created first, then the left-hand side
/// sum. The constraint must contain at least one term (asserted).
z3::expr
z3_linear_constraint(Z3Environment& env, const LinearConstraint& constraint)
{
    const z3::expr bound = env.add_constant(Value(constraint.rhs));
    assert(!constraint.empty());
    const z3::expr lhs = z3_linear_expression(env, constraint);
    switch (constraint.type) {
    case LinearConstraint::LESS_EQUAL: return lhs <= bound;
    case LinearConstraint::GREATER_EQUAL: return lhs >= bound;
    case LinearConstraint::EQUAL: return lhs == bound;
    }
    POLICE_UNREACHABLE();
}

/// Translates a disjunction of linear constraints into a Z3 `or` chain,
/// combining the disjuncts left to right.
///
/// The disjunction must contain at least one constraint (asserted).
z3::expr
z3_disjunction(Z3Environment& env, const LinearConstraintDisjunction& disj)
{
    assert(!disj.empty());
    z3::expr any_holds = z3_linear_constraint(env, disj[0]);
    const size_t count = disj.size();
    for (size_t idx = 1; idx < count; ++idx) {
        any_holds = any_holds || z3_linear_constraint(env, disj[idx]);
    }
    return any_holds;
}

/// Builds the Z3 expression for the maximum over the given variables by
/// nesting binary z3::max calls from left to right.
///
/// The variable list must not be empty (asserted).
z3::expr z3_max(Z3Environment& env, const vector<size_t>& variables)
{
    assert(!variables.empty());
    z3::expr largest = env.get_variable(variables[0]);
    const size_t count = variables.size();
    for (size_t idx = 1; idx < count; ++idx) {
        largest = z3::max(largest, env.get_variable(variables[idx]));
    }
    return largest;
}

} // namespace

/// Translates the linear constraint into a Z3 expression and asserts it on
/// the optimizer.
void Z3Optimizer::add_constraint(const LinearConstraint& constraint)
{
    optimizer_->add(z3_linear_constraint(env_, constraint));
}

/// Translates the disjunction of linear constraints into a single Z3
/// expression and asserts it on the optimizer.
void Z3Optimizer::add_constraint(const LinearConstraintDisjunction& constraint)
{
    optimizer_->add(z3_disjunction(env_, constraint));
}

void Z3Optimizer::add_constraint(const MaxConstraint& constraint)
{
    z3::expr max_val = z3_max(env_, constraint.elements);
    if (constraint.c != -std::numeric_limits<real_t>::infinity()) {
        max_val = z3::max(max_val, env_.add_constant(Value(constraint.c)));
    }
    optimizer_->add(env_.get_variable(constraint.y) == max_val);
}

void Z3Optimizer::add_constraint(const IndicatorLPConstraint& constraint)
{
    optimizer_->add(z3::implies(
        env_.get_variable(constraint.indicator_var) ==
            env_.add_constant(Value((int_t)constraint.indicator_value)),
        z3_linear_constraint(env_, constraint.constraint)));
}

void Z3Optimizer::add_max_objective(const LinearExpression& expr)
{
    objective_ = optimizer_
                     ->maximize(
                         z3_linear_expression(env_, expr) +
                         env_.add_constant(Value(expr.bias)))
                     .h();
}

void Z3Optimizer::add_min_objective(const LinearExpression& expr)
{
    objective_ = optimizer_
                     ->minimize(
                         z3_linear_expression(env_, expr) +
                         env_.add_constant(Value(expr.bias)))
                     .h();
}

/// Pushes a backtracking point on the Z3 optimizer and then a snapshot on the
/// environment, so that both can be rolled back together by pop_snapshot().
void Z3Optimizer::push_snapshot()
{
    optimizer_->push();
    env_.push_snapshot();
}

/// Pops the most recent backtracking point from the Z3 optimizer and then the
/// most recent snapshot from the environment.
void Z3Optimizer::pop_snapshot()
{
    optimizer_->pop();
    env_.pop_snapshot();
}

/// Runs the optimizer on the currently asserted constraints and objectives.
///
/// @return true if Z3 reports `sat`, false if it reports `unsat`. Any other
///         status is reported through POLICE_RUNTIME_ERROR.
bool Z3Optimizer::check()
{
    const auto result = optimizer_->check();
    if (result == z3::check_result::sat) {
        return true;
    }
    if (result == z3::check_result::unsat) {
        return false;
    }
    POLICE_RUNTIME_ERROR("unknown z3 check status " << result);
    POLICE_UNREACHABLE();
}

bool Z3Optimizer::check(const vector<LinearConstraint>& assumptions)
{
    z3::expr_vector vec(*env_.context());
    for (const auto& expr : assumptions) {
        vec.push_back(z3_linear_constraint(env_, expr));
    }
    const auto result = optimizer_->check(vec);
    switch (result) {
    case z3::check_result::sat: return true;
    case z3::check_result::unsat: return false;
    default: POLICE_RUNTIME_ERROR("unknown z3 check status " << result);
    }
    POLICE_UNREACHABLE();
}

/// Returns the value of the stored objective as the midpoint of the lower and
/// upper bounds reported by Z3.
///
/// Both bounds must convert to constants and must agree up to LP_PRECISION
/// (both asserted).
real_t Z3Optimizer::get_objective_value() const
{
    const z3::optimize::handle objective(objective_);
    auto lower = env_.from_z3_expression(optimizer_->lower(objective));
    auto upper = env_.from_z3_expression(optimizer_->upper(objective));
    assert(lower.is_constant() && upper.is_constant());
    const auto lower_value = static_cast<real_t>(lower.get_value());
    const auto upper_value = static_cast<real_t>(upper.get_value());
    assert(std::abs(upper_value - lower_value) <= LP_PRECISION);
    return (upper_value + lower_value) / 2.;
}

/// Fetches the optimizer's current model, caches it in last_model_ (replacing
/// the previously cached one) and returns a Z3Model built from pointers to
/// the environment and to the cached model.
///
/// NOTE(review): the returned Z3Model presumably keeps those raw pointers, in
/// which case it is only usable until the next get_model() call or until this
/// optimizer is destroyed -- confirm against Z3Model.
Z3Model Z3Optimizer::get_model() const
{
    last_model_ = std::make_unique<z3::model>(optimizer_->get_model());
    return Z3Model(&env_, last_model_.get());
}

/// Returns the optimizer's unsat core, with every core member converted from
/// its Z3 expression back into a LinearConstraint, in the order reported by
/// Z3.
vector<LinearConstraint> Z3Optimizer::get_unsat_core() const
{
    const auto core = optimizer_->unsat_core();
    vector<LinearConstraint> constraints;
    constraints.reserve(core.size());
    for (const z3::expr& member : core) {
        auto converted = env_.from_z3_expression(member);
        constraints.push_back(LinearConstraint::from_expression(converted));
    }
    return constraints;
}

/// Streams the optimizer's textual representation to `out`, followed by a
/// newline; std::endl also flushes the stream.
void Z3Optimizer::dump(std::ostream& out)
{
    out << (*optimizer_) << std::endl;
}

} // namespace police

#endif
