#pragma once

#include "police/argument_store.hpp"
#include "police/utils/rng.hpp"

#include <cassert>
#include <memory>
#include <string>
#include <vector>

namespace police {

struct Model;
struct VerificationProperty;
template <typename>
class FeedForwardNeuralNetwork;
class NeuralNetworkPolicy;
class AddTreePolicy;
class CGPolicy;
namespace jani {
class Model;
}
class ExecutionUnit;

struct GlobalArguments {
public:
    std::vector<std::string> arguments;
    std::shared_ptr<ArgumentStore> storage = nullptr;
    std::shared_ptr<Model> model = nullptr;
    std::shared_ptr<VerificationProperty> property = nullptr;
    std::shared_ptr<NeuralNetworkPolicy> nn_policy = nullptr;
    std::shared_ptr<AddTreePolicy> tree_policy = nullptr;
    std::shared_ptr<CGPolicy> cg_policy = nullptr;
    std::shared_ptr<RNG> rng = nullptr;
    std::shared_ptr<jani::Model> jani_model = nullptr;
    std::shared_ptr<ExecutionUnit> unit = nullptr;
    std::string model_path = "";
    std::string policy_path = "";
    std::string policy_adapter_path = "";
    int rng_seed = 1734;
    bool applicability_masking = false;

    GlobalArguments()
        : storage(std::make_shared<ArgumentStore>())
        , rng(std::make_shared<RNG>(rng_seed))
    {
    }

    [[nodiscard]]
    const Model& get_model() const
    {
        assert(model != nullptr);
        return *model;
    }

    [[nodiscard]]
    const VerificationProperty& get_property() const
    {
        assert(property != nullptr);
        return *property;
    }

    [[nodiscard]]
    bool has_policy() const
    {
        return nn_policy != nullptr || tree_policy != nullptr ||
               cg_policy != nullptr;
    }

    [[nodiscard]]
    bool has_nn_policy() const
    {
        return nn_policy != nullptr;
    }

    [[nodiscard]]
    const CGPolicy& get_cg_policy() const
    {
        assert(cg_policy != nullptr);
        return *cg_policy;
    }

    [[nodiscard]]
    const NeuralNetworkPolicy& get_nn_policy() const
    {
        assert(nn_policy != nullptr);
        return *nn_policy;
    }

    [[nodiscard]]
    bool has_addtree_policy() const
    {
        return tree_policy != nullptr;
    }

    [[nodiscard]]
    bool has_cg_policy() const
    {
        return cg_policy != nullptr;
    }

    [[nodiscard]]
    const AddTreePolicy& get_addtree_policy() const
    {
        assert(tree_policy != nullptr);
        return *tree_policy;
    }

    [[nodiscard]]
    int get_seed() const
    {
        return rng_seed;
    }

    [[nodiscard]]
    std::shared_ptr<RNG> get_rng() const
    {
        return rng;
    }

    [[nodiscard]]
    bool should_mask_applicable() const
    {
        return applicability_masking;
    }

    void set_rng_seed(int seed)
    {
        rng_seed = seed;
        rng = std::make_shared<RNG>(seed);
    }
};

} // namespace police
