#pragma once

#include "ffnn_generator.hpp"
#include "police/nnlp.hpp"
#include "police/storage/ffnn.hpp"
#include "police/storage/variable_space.hpp"

#include <catch2/catch.hpp>

#include <cstddef>
#include <iosfwd>

/// A single generated NNLP test instance: a variable space, a feed-forward
/// network, and linear side constraints / disjunctions over those variables.
struct NNLPData {
    // Variables the instance ranges over.
    police::VariableSpace variables;

    // The network under test.
    police::FeedForwardNeuralNetwork<> network;

    // Plain linear constraints (presumably conjoined — confirm in encode()).
    police::vector<police::LinearConstraint> constraints;

    // Disjunctions of linear constraints.
    police::vector<police::LinearConstraintDisjunction> disjunctions;

    /// Encodes this instance into the given NNLP; `lp` is modified in place.
    auto encode(police::NNLP& lp) const -> void;

    /// Writes a textual dump of this instance to `out` (format is defined by
    /// the implementation).
    auto dump(std::ostream& out) const -> void;
};

/// Catch2 generator producing a sequence of `NNLPData` instances.
///
/// Networks come from the wrapped `FFNNGenerator`. The public generate_*
/// helpers build the variables, assignments and random linear constraints /
/// disjunctions that make up an instance. The constraint helpers take a
/// concrete input/output assignment (`in`, `out`).
///
/// Not thread-safe: `in_vars_` / `out_vars_` are `mutable`, so even the
/// `const` helpers may modify shared state. Do not call them concurrently on
/// one instance.
class NNLPDataGenerator final : public Catch::Generators::IGenerator<NNLPData> {
public:
    /// @param num_per_nn           per the name, how many instances are
    ///                             generated per network (TODO confirm)
    /// @param num_constraints      number of linear constraints per instance
    ///                             (per the name — confirm in next())
    /// @param num_disjunctions     number of disjunctions per instance
    ///                             (per the name — confirm in next())
    /// @param max_disjunction_size upper bound on disjuncts per disjunction
    ///                             (stored as `disjunction_size_`)
    /// @param nn_generator         source of networks; taken by value and
    ///                             owned by this generator
    /// @param seed                 seed for the internal RNG
    NNLPDataGenerator(
        std::size_t num_per_nn,
        std::size_t num_constraints,
        std::size_t num_disjunctions,
        std::size_t max_disjunction_size,
        FFNNGenerator nn_generator,
        int seed = 1734);

    /// Catch2 IGenerator contract: advance to the next value and return
    /// whether one is available (false once the generator is exhausted).
    bool next() override;

    /// Returns the current instance. The reference points into this
    /// generator (`cur_`), so copy it if it must outlive the next `next()`.
    [[nodiscard]]
    const NNLPData& get() const override
    {
        return cur_;
    }

    /// Builds a variable space with `num_vars` variables.
    [[nodiscard]]
    police::VariableSpace generate_variables(std::size_t num_vars) const;

    /// Produces one value per variable in `vars` (presumably a random point in
    /// their domains — confirm in the implementation).
    [[nodiscard]]
    police::vector<double>
    generate_assignment(const police::VariableSpace& vars) const;

    /// Builds a linear constraint over the variable indices in
    /// [var_begin, var_end), using the assignment (`in`, `out`).
    [[nodiscard]]
    police::LinearConstraint generate_constraint(
        police::vector<std::size_t>::const_iterator var_begin,
        police::vector<std::size_t>::const_iterator var_end,
        const police::vector<double>& in,
        const police::vector<double>& out) const;

    /// Builds a linear constraint over input variables (per the name).
    [[nodiscard]]
    police::LinearConstraint generate_in_constraint(
        const police::vector<double>& in,
        const police::vector<double>& out) const;

    /// Builds a linear constraint over output variables (per the name).
    [[nodiscard]]
    police::LinearConstraint generate_out_constraint(
        const police::vector<double>& in,
        const police::vector<double>& out) const;

    /// Per the name, builds a constraint that the assignment (`in`, `out`)
    /// violates — NOTE(review): confirm against the implementation.
    [[nodiscard]]
    police::LinearConstraint generate_violated_constraint(
        const police::vector<double>& in,
        const police::vector<double>& out) const;

    /// Builds a disjunction of linear constraints, using the assignment
    /// (`in`, `out`).
    [[nodiscard]]
    police::LinearConstraintDisjunction generate_disjunction(
        const police::vector<double>& in,
        const police::vector<double>& out) const;

private:
    FFNNGenerator nn_generator_;
    NNLPData cur_;
    police::RNG rng_;

    // Scratch index buffers. They are `mutable` so the const generate_*
    // helpers can write to them, which makes those helpers non-reentrant.
    mutable police::vector<std::size_t> in_vars_;
    mutable police::vector<std::size_t> out_vars_;

    // Default member initializers guarantee these counters are never read
    // while indeterminate. Constructor initializers still take precedence.
    std::size_t left_ = 0;
    std::size_t num_per_nn_ = 0;
    std::size_t num_constraints_ = 0;
    std::size_t num_disjunctions_ = 0;
    std::size_t disjunction_size_ = 0;
};
