#include "nnlp_generator.hpp"
#include "police/ffnn_lp_encoder.hpp"
#include "police/variable_substitution.hpp"

#include <algorithm>
#include <cassert>
#include <cmath>
#include <iostream>
#include <numeric>
#include <utility>

namespace {
// Range from which bounded integer domains are drawn.
constexpr int MIN_VALUE = 0;
constexpr int MAX_VALUE = 10;
// Minimal width (ub - lb) of a generated integer domain.
constexpr int MIN_DOMAIN_SIZE = 3;
// Range from which constraint coefficients are drawn.
constexpr int MIN_COEF = 1;
constexpr int MAX_COEF = 10;
// Threshold compared against rng_() when deciding whether a variable is
// Boolean (presumably rng_() is uniform in [0, 1) -- confirm).
constexpr double BINARY_PROBABILITY = 0.3;
// When true, the constraint-type draw range is widened by one so that a
// third constraint type (presumably EQUAL) can be produced.
constexpr bool EQUALITY = false;

// generate_variables draws lb in [MIN_VALUE, MAX_VALUE - MIN_DOMAIN_SIZE] and
// then ub in [lb + MIN_DOMAIN_SIZE, MAX_VALUE]; both ranges must be non-empty.
static_assert(MIN_DOMAIN_SIZE >= 0, "domain size must be non-negative");
static_assert(
    MIN_VALUE + MIN_DOMAIN_SIZE <= MAX_VALUE,
    "value range too small for MIN_DOMAIN_SIZE");
// Coefficients are redrawn until non-zero, and violated constraints shift
// the rhs by a value in [1, MAX_COEF]; both need a non-zero value in range.
static_assert(MIN_COEF <= MAX_COEF, "empty coefficient range");
static_assert(MAX_COEF >= 1, "coefficient range must contain a positive value");
static_assert(
    BINARY_PROBABILITY >= 0.0 && BINARY_PROBABILITY <= 1.0,
    "BINARY_PROBABILITY must be a probability");
} // namespace

// Writes this instance into `lp`: one LP variable per instance variable, the
// network encoding on top of them, and finally every constraint and
// disjunction with its variable ids remapped to LP variable ids.
void NNLPData::encode(police::NNLP& lp) const
{
    // var_map[k] is the LP variable standing for instance variable k.
    police::vector<police::size_t> var_map;
    for (const auto& variable : variables) {
        var_map.push_back(lp.add_variable(variable.type));
    }

    // The network outputs are taken as consecutive LP ids starting at the
    // index returned by the encoder; they follow the inputs in var_map.
    const size_t first_output = police::encode_ffnn_in_lp(lp, network, var_map);
    const size_t num_outputs = network.layers.back().size();
    for (size_t k = 0; k < num_outputs; ++k) {
        var_map.push_back(first_output + k);
    }

    // Constraints and disjunctions are remapped and added the same way.
    const auto add_remapped = [&](const auto& items) {
        for (const auto& item : items) {
            const auto remapped = police::substitute_vars(item, var_map);
            lp.add_constraint(remapped);
        }
    };
    add_remapped(constraints);
    add_remapped(disjunctions);
}

void NNLPData::dump(std::ostream& out) const
{
    for (const auto& var : variables) {
        out << var.id << ": " << var.type << "\n";
    }
    for (const auto& d : constraints) {
        out << d << "\n";
    }
    out << "\n";
    for (const auto& d : disjunctions) {
        out << d << "\n";
    }
}

// Creates a generator that draws `num_per_nn` instances per network produced
// by `nn_generator`, each with `num_constraints` constraints and
// `num_disjunctions` disjunctions of at most `max_disjunction_size` disjuncts.
// The first instance is generated eagerly by calling next().
// NOTE(review): in_vars_/out_vars_ read nn_generator_, so nn_generator_ must be
// declared before them in the header -- confirm.
NNLPDataGenerator::NNLPDataGenerator(
    size_t num_per_nn,
    size_t num_constraints,
    size_t num_disjunctions,
    size_t max_disjunction_size,
    FFNNGenerator nn_generator,
    int seed)
    : nn_generator_(std::move(nn_generator))
    , rng_(seed)
    , in_vars_(nn_generator_.get().input.size())
    , out_vars_(nn_generator_.get().layers.back().size())
    , left_(num_per_nn)
    , num_per_nn_(num_per_nn)
    , num_constraints_(num_constraints)
    , num_disjunctions_(num_disjunctions)
    , disjunction_size_(max_disjunction_size)
{
    // Variable ids: inputs get 0 .. |in|-1, outputs continue from |in|.
    size_t next_id = 0;
    for (auto& id : in_vars_) {
        id = next_id++;
    }
    for (auto& id : out_vars_) {
        id = next_id++;
    }
    next();
}

// Generates the next instance into cur_. Each network from nn_generator_ is
// reused for num_per_nn_ instances before advancing to the next one.
// Returns false (leaving cur_ untouched) once nn_generator_ is exhausted.
//
// Every instance contains at least one input constraint and one output
// constraint, even if num_constraints_ < 2; the remaining
// num_constraints_ - 2 constraints are split randomly between both kinds.
bool NNLPDataGenerator::next()
{
    if (left_ == 0) {
        if (!nn_generator_.next()) {
            return false;
        }
        left_ = num_per_nn_;
    }
    --left_;
    auto network = nn_generator_.get();
    auto variables = generate_variables(network.input.size());
    auto assignment = generate_assignment(variables);
    // Network output at the sampled input; generated constraints are built so
    // that (assignment, out) satisfies them.
    auto out = network(assignment);
    police::vector<police::LinearConstraint> constraints;
    constraints.push_back(generate_in_constraint(assignment, out));
    constraints.push_back(generate_out_constraint(assignment, out));
    // Unsigned loop: `num_constraints_ - 2` must not be computed when
    // num_constraints_ < 2 (size_t wrap-around + narrowing to int).
    for (size_t i = 2; i < num_constraints_; ++i) {
        if (rng_(0, 1)) {
            constraints.push_back(generate_in_constraint(assignment, out));
        } else {
            constraints.push_back(generate_out_constraint(assignment, out));
        }
    }
    police::vector<police::LinearConstraintDisjunction> disjs;
    for (size_t i = 0; i < num_disjunctions_; ++i) {
        disjs.push_back(generate_disjunction(assignment, out));
    }
    cur_ = {
        std::move(variables),
        std::move(network),
        std::move(constraints),
        std::move(disjs)};
    return true;
}

// Creates `num_vars` anonymous variables. Each becomes Boolean when
// rng_() < BINARY_PROBABILITY; otherwise it becomes a bounded integer whose
// lower bound is drawn from [MIN_VALUE, MAX_VALUE - MIN_DOMAIN_SIZE] and
// whose upper bound is drawn from [lb + MIN_DOMAIN_SIZE, MAX_VALUE].
police::VariableSpace
NNLPDataGenerator::generate_variables(size_t num_vars) const
{
    police::VariableSpace space;
    for (size_t created = 0; created < num_vars; ++created) {
        const bool make_bool = rng_() < BINARY_PROBABILITY;
        if (make_bool) {
            space.add_variable("", police::BoolType());
            continue;
        }
        const int lower = rng_(MIN_VALUE, MAX_VALUE - MIN_DOMAIN_SIZE);
        const int upper = rng_(lower + MIN_DOMAIN_SIZE, MAX_VALUE);
        space.add_variable("", police::BoundedIntType(lower, upper));
    }
    return space;
}

// Samples one value per variable in `vars`: 0 or 1 for variant alternative 0
// (presumably BoolType -- confirm), otherwise a value within the variable's
// BoundedIntType bounds. Variables are visited from last to first; this order
// determines which random draw each variable receives.
police::vector<double>
NNLPDataGenerator::generate_assignment(const police::VariableSpace& vars) const
{
    police::vector<double> values(vars.size());
    for (size_t idx = values.size(); idx-- > 0;) {
        if (vars[idx].type.index() != 0) {
            const auto bounds = std::get<police::BoundedIntType>(vars[idx].type);
            values[idx] = rng_(bounds.lower_bound, bounds.upper_bound);
        } else {
            values[idx] = rng_(0, 1);
        }
    }
    return values;
}

// Builds a linear constraint over the variable ids in [it, end) with random
// coefficients. The rhs is derived from the lhs value at (in, out) so that
// this point satisfies the constraint: ceil for "<=", floor for ">=", round
// for "=". Ids >= in.size() refer to outputs (index id - in.size() in `out`).
// An equality drawn for a constraint that mentions an output is replaced by a
// re-drawn inequality type, so equalities only range over input variables.
police::LinearConstraint NNLPDataGenerator::generate_constraint(
    police::vector<size_t>::const_iterator it,
    police::vector<size_t>::const_iterator end,
    const police::vector<double>& in,
    const police::vector<double>& out) const
{
    using LC = police::LinearConstraint;
    LC res(LC::Type(rng_(0, 1 + EQUALITY)));
    const size_t num_inputs = in.size();
    double lhs_value = 0; // lhs evaluated at (in, out)
    bool mentions_output = false;
    while (it != end) {
        const size_t var = *it;
        ++it;
        int coef = 0;
        do {
            coef = rng_(MIN_COEF, MAX_COEF);
        } while (coef == 0);
        res.insert(var, coef);
        if (var < num_inputs) {
            lhs_value += in[var] * coef;
        } else {
            lhs_value += out[var - num_inputs] * coef;
            mentions_output = true;
        }
    }
    if (mentions_output && res.type == LC::EQUAL) {
        res.type = LC::Type(rng_(0, 1));
    }
    assert(res.size() >= 1u);
    if (res.type == LC::LESS_EQUAL) {
        res.rhs = std::ceil(lhs_value);
    } else if (res.type == LC::GREATER_EQUAL) {
        res.rhs = std::floor(lhs_value);
    } else if (res.type == LC::EQUAL) {
        assert(!mentions_output); // only integer components
        res.rhs = std::round(lhs_value);
    }
    return res;
}

// Builds a constraint, satisfied by `in`, over a random subset of the input
// variables. The subset holds at least one and at most
// min(4, in.size() - 1) inputs; a single-input network is constrained on
// its only input. Requires at least one input.
police::LinearConstraint NNLPDataGenerator::generate_in_constraint(
    const police::vector<double>& in,
    const police::vector<double>& out) const
{
    assert(in_vars_.size() == in.size());
    assert(!in.empty());
    rng_.shuffle(in_vars_.begin(), in_vars_.end());
    // With one input, `in.size() - 1` is 0, which would request the empty
    // range [1, 0]. Also avoid `4ul`: it does not match size_t on every
    // platform, so std::min fails to deduce there.
    const size_t max_size =
        in.size() > 1 ? std::min<size_t>(4, in.size() - 1) : 1;
    const size_t size = rng_(1, max_size);
    return generate_constraint(
        in_vars_.begin(),
        in_vars_.begin() + size,
        in,
        out);
}

// Builds a constraint, satisfied by (in, out), over one random output
// variable, or over two when the rng_(0, 4) draw is non-zero. Two outputs are
// only used when the network has at least two. Requires at least one output.
police::LinearConstraint NNLPDataGenerator::generate_out_constraint(
    const police::vector<double>& in,
    const police::vector<double>& out) const
{
    assert(out_vars_.size() == out.size());
    assert(!out.empty());
    rng_.shuffle(out_vars_.begin(), out_vars_.end());
    // The draw is always made so the random sequence is unchanged. With a
    // single output, `begin() + 2` would run past the end of out_vars_, so no
    // comparison is made in that case.
    const bool compare = rng_(0, 4) != 0 && out.size() > 1;
    return generate_constraint(
        out_vars_.begin(),
        out_vars_.begin() + (compare ? 2 : 1),
        in,
        out);
}

// Builds a constraint that the point (in, out) violates. It starts from a
// satisfied input constraint (when rng_(0, 2) is non-zero) or a satisfied
// output constraint, then moves the rhs strictly past the point.
// - "<=": lower the rhs by an integer >= 1.
// - ">=": raise the rhs by an integer >= 1.
// - "=": shift the rhs by a non-zero integer.
police::LinearConstraint NNLPDataGenerator::generate_violated_constraint(
    const police::vector<double>& in,
    const police::vector<double>& out) const
{
    using LC = police::LinearConstraint;
    const bool from_inputs = rng_(0, 2);
    LC res = from_inputs ? generate_in_constraint(in, out)
                         : generate_out_constraint(in, out);
    if (res.type == LC::LESS_EQUAL) {
        res.rhs -= rng_(1, MAX_COEF);
    } else if (res.type == LC::GREATER_EQUAL) {
        res.rhs += rng_(1, MAX_COEF);
    } else if (res.type == LC::EQUAL) {
        const auto shift = rng_(MIN_COEF, MAX_COEF);
        // Map a zero shift to 1 so the equality is really broken.
        res.rhs += shift + (shift == 0);
    }
    return res;
}

// Builds a disjunction that (in, out) satisfies. The first disjunct is always
// a satisfied input constraint. Up to disjunction_size_ - 1 further disjuncts
// follow, each one of:
// - a violated constraint (when rng_(0, 2) is non-zero),
// - a satisfied input constraint,
// - a satisfied output constraint.
// A disjunction_size_ of 0 is treated like 1 (only the first disjunct).
police::LinearConstraintDisjunction NNLPDataGenerator::generate_disjunction(
    const police::vector<double>& in,
    const police::vector<double>& out) const
{
    // Guard against `0 - 1` wrapping to SIZE_MAX; for disjunction_size_ >= 1
    // the draw (and hence the random sequence) is the same as before.
    size_t extra = 0;
    if (disjunction_size_ > 0) {
        extra = rng_(0, disjunction_size_ - 1);
    }
    police::LinearConstraintDisjunction res;
    res |= generate_in_constraint(in, out);
    for (size_t i = 0; i < extra; ++i) {
        if (rng_(0, 2)) {
            res |= generate_violated_constraint(in, out);
        } else if (rng_(0, 1)) {
            res |= generate_in_constraint(in, out);
        } else {
            res |= generate_out_constraint(in, out);
        }
    }
    return res;
}
