#include "binary_flip.hpp"

#include "police/action.hpp"
#include "police/model.hpp"
#include "police/storage/variable_space.hpp"
#include "police/successor_generator/successor_generator.hpp"
#include "police/verifiers/epic3/cube.hpp"

#include <algorithm>

namespace {
using namespace police;
police::VariableSpace create_variables(int bits)
{
    VariableSpace vars;
    for (int i = 0; i < bits; ++i) {
        vars.add_variable("", BoundedIntType(0, 1));
    }
    return vars;
}

Action create_flip_action(int bit)
{
    LinearConstraintConjunction guard;

    LinearConstraint c(LinearConstraint::EQUAL);
    c.insert(bit, 1.);
    c.rhs = 0.;
    guard &= std::move(c);

    vector<Assignment> effs;
    LinearExpression e;
    e.bias = 1.;
    effs.emplace_back(bit, e);

    return Action(2 * bit, std::move(guard), {Outcome(std::move(effs))});
}

Action create_unflip_action(int bit, int num_bits)
{
    LinearConstraintConjunction guard;

    LinearConstraint c(LinearConstraint::EQUAL);
    c.insert(bit, 1.);
    c.rhs = 1.;
    guard &= std::move(c);

    vector<Outcome> out;

    vector<Assignment> effs;
    LinearExpression e;
    e.bias = 0.;
    effs.emplace_back(bit, e);
    out.emplace_back(effs);

    e.bias = 1.;
    for (int bit2 = bit + 1; bit2 < num_bits; bit2 += 2) {
        effs.emplace_back(bit2, e);
        out.emplace_back(effs);
        effs.pop_back();
    }

    return Action(2 * bit + 1, std::move(guard), std::move(out));
}
} // namespace

/// Builds the binary-flip model over `bits` binary variables.
/// For every bit b it adds two actions, in this order: the flip action
/// (id 2b) and the unflip action (id 2b + 1). The final Model argument is
/// passed as an empty braced initializer, exactly as before.
Model binary_flip_model(int bits)
{
    vector<Action> actions;
    for (int b = 0; b < bits; ++b) {
        actions.push_back(create_flip_action(b));
        actions.push_back(create_unflip_action(b, bits));
    }
    return Model(create_variables(bits), std::move(actions), {});
}

/// Initial-state condition: a single conjunction requiring
///   - the even-indexed bits to sum to at most 1, and
///   - the odd-indexed bits to sum to at most 0.
police::LinearCondition binary_flip_initial_state(int bits)
{
    // Builds sum_{k = first, first + 2, ... < bits} x_k <= bound.
    const auto strided_sum_at_most = [bits](int first, double bound) {
        LinearConstraint sum(LinearConstraint::LESS_EQUAL);
        for (int var = first; var < bits; var += 2) {
            sum.insert(var, 1.);
        }
        sum.rhs = bound;
        return sum;
    };

    LinearConstraintConjunction init;
    init &= strided_sum_at_most(0, 1.); // at most one even bit set
    init &= strided_sum_at_most(1, 0.); // no odd bit set

    LinearCondition result;
    result |= init;
    return result;
}

/// Negation of the goal as a disjunction: some even-indexed bit satisfies
/// x_i <= 0. Yields an empty disjunction when `bits <= 0`.
police::LinearConstraintDisjunction binary_flip_negated_goal(int bits)
{
    police::LinearConstraintDisjunction any_even_unset;
    for (int even = 0; even < bits; even += 2) {
        LinearConstraint unset(LinearConstraint::LESS_EQUAL);
        unset.rhs = 0.;
        unset.insert(even, 1.);
        any_even_unset |= unset;
    }
    return any_even_unset;
}

/// Goal condition: one conjunction requiring x_i >= 1 for every
/// even-indexed bit i < bits.
police::LinearCondition binary_flip_goal(int bits)
{
    police::LinearConstraintConjunction all_even_set;
    for (int even = 0; even < bits; even += 2) {
        LinearConstraint is_set(LinearConstraint::GREATER_EQUAL);
        is_set.rhs = 1.;
        is_set.insert(even, 1.);
        all_even_set &= is_set;
    }

    police::LinearCondition cond;
    cond |= all_even_set;
    return cond;
}

/// Avoid condition: a disjunction of x_i >= 1 over every odd-indexed
/// bit i < bits, i.e. some odd bit is set.
police::LinearCondition binary_flip_avoid(int bits)
{
    police::LinearCondition any_odd_set;
    for (int odd = 1; odd < bits; odd += 2) {
        LinearConstraint is_set(LinearConstraint::GREATER_EQUAL);
        is_set.rhs = 1.;
        is_set.insert(odd, 1.);
        any_odd_set |= is_set;
    }
    return any_odd_set;
}

/// Computes the successors selected by the network policy in `state`.
///
/// - If every even-indexed bit is > 0, no successors are returned.
/// - Otherwise the state is converted to doubles and passed to `net`. Action
///   indices are ranked by descending network output; ties go to the lower
///   index.
/// - Action 2k flips bit k (applicable when bit k is 0). Action 2k + 1
///   unflips bit k (applicable when bit k is non-zero).
/// - An unflip yields "bit k cleared", plus one extra successor for each
///   currently-zero bit at k + 1, k + 3, ... that is additionally set to 1.
/// - With `appfilter`, inapplicable actions are skipped until the first
///   applicable one. Without it, only the top-ranked action is tried.
police::vector<successor_generator::Successor>
BFlipSuccessorGenerator::operator()(const police::flat_state& state) const
{
    assert(state.size() == bits);

    bool all_even_set = true;
    for (auto idx = 0u; all_even_set && idx < state.size(); idx += 2) {
        all_even_set = state[idx] > 0.;
    }
    if (all_even_set) {
        return {};
    }

    vector<double> input(state.size());
    for (police::size_t i = 0; i < state.size(); ++i) {
        const real_t value = state[i];
        input[i] = value;
    }
    const vector<double> scores = (*net)(input);

    // Negating the score makes the ascending pair sort rank highest first,
    // with the index as tie-breaker.
    vector<std::pair<double, police::size_t>> order(scores.size());
    for (police::size_t i = 0; i < scores.size(); ++i) {
        order[i] = {-scores[i], i};
    }
    std::sort(order.begin(), order.end());

    vector<successor_generator::Successor> successors;
    flat_state next(state);
    for (const auto& entry : order) {
        const police::size_t action = entry.second;
        const police::size_t bit = action / 2;
        const bool bit_set = static_cast<int_t>(state[bit]) != 0;

        if (action % 2 == 0) {
            if (!bit_set) {
                next[bit] = Value(1);
                successors.emplace_back(next, action);
                break;
            }
        } else if (bit_set) {
            next[bit] = Value(0);
            successors.emplace_back(next, action);
            for (police::size_t other = bit + 1; other < bits; other += 2) {
                if (static_cast<int_t>(state[other]) == 0) {
                    // Temporarily set, record, then restore for the next one.
                    next[other] = Value(1);
                    successors.emplace_back(next, action);
                    next[other] = Value(0);
                }
            }
            break;
        }

        if (!appfilter) {
            break;
        }
    }
    return successors;
}

/// Returns false as soon as the cube bounds some even-indexed variable to an
/// upper bound <= 0, since the goal requires every even bit to be >= 1.
/// Returns true otherwise.
bool BFlipCheckGoal::operator()(const police::epic3::Cube& cube) const
{
    for (const auto& entry : cube) {
        const bool even_var = entry.first % 2 == 0;
        if (even_var && entry.second.ub <= 0) {
            return false;
        }
    }
    return true;
}

/// Returns false as soon as the cube bounds some odd-indexed variable to an
/// upper bound <= 0. Returns true otherwise.
///
/// NOTE(review): binary_flip_avoid() is a *disjunction* ("some odd bit is
/// >= 1"). Pinning a single odd bit to 0 therefore does not by itself exclude
/// the avoid set if another odd bit may still be 1. Confirm whether this
/// conjunctive-style check is the intended semantics. A precise intersection
/// test would also have to account for odd variables the cube leaves
/// unconstrained.
bool BFlipCheckAvoid::operator()(const police::epic3::Cube& cube) const
{
    for (const auto& entry : cube) {
        const bool odd_var = entry.first % 2 == 1;
        if (odd_var && entry.second.ub <= 0) {
            return false;
        }
    }
    return true;
}
