#include "police/verifiers/ic3/sat_based/sat_interface/smt.hpp"
#include "police/variable_substitution.hpp"
#include "police/expressions/expression.hpp"
#include "police/expressions/variable.hpp"
#include "police/storage/variable_space.hpp"
#include "police/verifiers/ic3/cube_utils.hpp"
#include <algorithm>
#include <limits>

namespace police::ic3 {

SatInterfaceSMT::SatInterfaceSMT(
    SMT* base_smt,
    size_t frame_var,
    vector<size_t> in_vars,
    vector<size_t> out_vars)
    : input_vars_(std::move(in_vars))
    , output_vars_(std::move(out_vars))
    , frame_var_(frame_var)
    , unsolvable_frame_(std::numeric_limits<size_t>::max())
    , smt_(base_smt)
{
    smt_->push_snapshot();
    assert(std::is_sorted(input_vars_.begin(), input_vars_.end()));
    smt_var_refs_.resize(input_vars_.back() + 1);
    for (int i = input_vars_.size() - 1; i >= 0; --i) {
        smt_var_refs_[input_vars_[i]] = i;
    }
}

SatInterfaceSMT::SatInterfaceSMT(
    SMT* base_smt,
    vector<size_t> in_vars,
    vector<size_t> out_vars)
    : input_vars_(std::move(in_vars))
    , output_vars_(std::move(out_vars))
    , frame_var_(base_smt->add_variable("frame_var", IntegerType()))
    , smt_(base_smt)
{
    smt_->push_snapshot();
}

namespace {
expressions::Expression block_expression(size_t var_id, const Interval& iv)
{
    // The +/-1 shift below is only meaningful for integer bounds.
    assert(iv.lb.get_type() == Value::Type::INT);
    expressions::Variable x(var_id);
    if (!(iv.lb == iv.ub)) {
        // x outside [lb, ub]  <=>  x <= lb - 1  or  x >= ub + 1  (integers).
        return expressions::less_equal(x, iv.lb - Value(1)) ||
               expressions::greater_equal(x, iv.ub + Value(1));
    }
    // Point interval [v, v]: forbid exactly that value.
    return expressions::not_equal(x, iv.lb);
}

// expressions::Expression
// block_expression_set(size_t var_id, const IntervalSet& iset)
// {
//     assert(iset.size() > 0u);
//     expressions::Expression result = block_expression(var_id, iset[0]);
//     for (auto i = 1u; i < iset.size(); ++i) {
//         result = result && block_expression(var_id, iset[i]);
//     }
//     return result;
// }

// Builds the clause excluding `cube`: a disjunction, over all cube entries,
// of "variable vars[index] lies outside its interval". The cube must be
// non-empty.
expressions::Expression
block_expression(const Cube& cube, const vector<size_t>& vars)
{
    assert(cube.size() > 0u);
    auto lit = cube.begin();
    const auto last = cube.end();
    expressions::Expression clause =
        block_expression(vars[lit->first], lit->second);
    while (++lit != last) {
        clause = clause || block_expression(vars[lit->first], lit->second);
    }
    return clause;
}

expressions::Expression enforce_expression(size_t var_id, const Interval& iv)
{
    expressions::Variable x(var_id);
    if (!(iv.lb == iv.ub)) {
        // Proper interval: lb <= x <= ub.
        return expressions::greater_equal(x, iv.lb) &&
               expressions::less_equal(x, iv.ub);
    }
    // Point interval: pin x to the single value.
    return expressions::equal(x, iv.lb);
}

// Builds the conjunction forcing every variable vars[index] of `cube` into its
// interval. The cube must be non-empty.
expressions::Expression
enforce_expression(const Cube& cube, const vector<size_t>& vars)
{
    assert(cube.size() > 0u);
    auto it = cube.begin();
    expressions::Expression result =
        enforce_expression(vars[it->first], it->second);
    // Advance past the literal used to seed `result` so it is not conjoined a
    // second time (mirrors block_expression for cubes).
    ++it;
    for (; it != cube.end(); ++it) {
        result = result && enforce_expression(vars[it->first], it->second);
    }
    return result;
}

// Converts a list of expressions (e.g. an unsat core) into a conjunction of
// linear constraints. Each expression must translate to exactly one linear
// condition (checked in debug builds).
LinearConstraintConjunction
to_linco(const vector<expressions::Expression>& exprs)
{
    LinearConstraintConjunction conjunction;
    for (const expressions::Expression& e : exprs) {
        auto conditions = LinearCondition::from_expression(e);
        assert(conditions.size() == 1u);
        conjunction &= std::move(conditions[0]);
    }
    return conjunction;
}

} // namespace

/// Checks whether `cube` is blocked relative to frame `frame_id`.
///
/// The frame selector variable is fixed to `frame_id` and the cube is passed
/// to the solver as an assumption over the input variables. If the query is
/// UNSAT, the unsat core (mapped back to cube positions via smt_var_refs_) is
/// stored in unsat_core_. An empty core is recorded in unsolvable_frame_.
///
/// @return {true, frame_id} if the query is UNSAT (cube blocked),
///         {false, frame_id} otherwise; the second component is always
///         frame_id.
/// NOTE(review): on SAT, unsat_core_ keeps whatever the previous call left in
/// it; presumably callers only read it after an UNSAT result — confirm.
std::pair<bool, size_t>
SatInterfaceSMT::is_blocked(const Cube& cube, size_t frame_id)
{
    // Cached shortcut: any frame at or beyond the recorded unsolvable frame is
    // reported as blocked, with an empty core, without querying the solver.
    // NOTE(review): this assumes unsolvability is upward-closed in frame_id;
    // confirm against the frame encoding used by set_blocked.
    if (unsolvable_frame_ <= frame_id) {
        unsat_core_.clear();
        return {true, frame_id};
    }
    // smt_->dump();
    // Temporary scope for the frame selection; popped below.
    smt_->push_snapshot();
    smt_->add_constraint(expressions::equal(
        expressions::Variable(frame_var_),
        Value(static_cast<int_t>(frame_id))));
    // smt_->add_constraint(block_expression(cube, output_vars_));
    // The cube is an assumption (not a constraint) so an unsat core over it
    // can be extracted.
    const auto result = smt_->solve({enforce_expression(cube, input_vars_)});
    if (result == SMT::Status::UNSAT) {
        auto core = to_linco(smt_->get_unsat_core());
        unsat_core_.clear();
        if (core.empty()) {
            // No cube literal was needed for UNSAT: the frame itself is
            // unsolvable (double-checked in the debug block below).
            unsolvable_frame_ = frame_id;
        } else {
            // Rename SMT variable ids in the core back to cube positions.
            insert_into_cube(
                unsat_core_,
                substitute_vars(std::move(core), smt_var_refs_));
        }
    }
    smt_->pop_snapshot();
#ifndef NDEBUG
    // Debug-only sanity check: if this call just marked frame_id as
    // unsolvable, the solver must be UNSAT for that frame without any cube
    // assumption.
    if (unsolvable_frame_ == frame_id) {
        smt_->push_snapshot();
        smt_->add_constraint(expressions::equal(
            expressions::Variable(frame_var_),
            Value(static_cast<int_t>(frame_id))));
        assert(smt_->solve() == SMT::Status::UNSAT);
        smt_->pop_snapshot();
    }
#endif
    return {result == SMT::Status::UNSAT, frame_id};
}

void SatInterfaceSMT::set_blocked(const Cube& cube, size_t frame_id)
{
    auto not_cube = block_expression(cube, output_vars_);
    smt_->add_constraint(
        not_cube || expressions::greater_equal(
                        expressions::Variable(frame_var_),
                        Value(static_cast<int_t>(frame_id) + 2)));
}

/// Intentionally a no-op. Frames are distinguished only by the value of
/// frame_var_: is_blocked() fixes it per query, and set_blocked() guards each
/// clause with a bound on it. As far as this file shows, no per-frame solver
/// state needs to be created.
void SatInterfaceSMT::add_frame()
{
}

void SatInterfaceSMT::clear_frames()
{
    smt_->pop_snapshot();
    smt_->push_snapshot();
}

} // namespace police::ic3
