#pragma once

#include "police/base_types.hpp"
#include "police/storage/lin_vector.hpp"
#include "police/storage/matrix.hpp"
#include "police/storage/vector.hpp"

#include <iosfwd>
#include <memory>
#include <string>

namespace police::cg {

class LinearLayer;
class ReluLayer;
class MaxPoolLayer;
class ExpandToConstLayer;

/// Visitor over the concrete layer kinds of a compute graph.
///
/// Each concrete layer's accept() calls the overload matching its own type,
/// so an implementation must handle every layer kind (all overloads are pure).
/// Layers are passed as const pointers; the visitor does not own them.
class NodeVisitor {
public:
    virtual ~NodeVisitor() = default;

    /// Called for a LinearLayer.
    virtual void visit(const LinearLayer* layer) = 0;

    /// Called for a ReluLayer.
    virtual void visit(const ReluLayer* layer) = 0;

    /// Called for a MaxPoolLayer.
    virtual void visit(const MaxPoolLayer* layer) = 0;

    /// Called for an ExpandToConstLayer.
    virtual void visit(const ExpandToConstLayer* layer) = 0;
};

/// Size statistics of a compute graph, as returned by Node::statistics().
///
/// Every counter defaults to zero so that a default-initialized instance
/// (`ComputeGraphSize s;`) is well-defined instead of holding indeterminate
/// values. The struct remains an aggregate, so brace initialization
/// (`ComputeGraphSize{a, b, ...}`) keeps working.
struct ComputeGraphSize {
    size_t inputs = 0;  ///< presumably num_inputs() of the first node
    size_t outputs = 0; ///< presumably num_outputs() of the last node
    size_t layers = 0;  ///< presumably mirrors Node::num_layers()
    size_t neurons = 0; ///< presumably mirrors Node::num_neurons()
    size_t linear = 0;  ///< presumably mirrors Node::num_linears()
    size_t relus = 0;   ///< presumably mirrors Node::num_relus()
    size_t pools = 0;   ///< presumably mirrors Node::num_pools()
};

/// Abstract base class for one layer of a compute graph.
///
/// Nodes are linked through a single owning `successor_` pointer, so a graph
/// is a singly-linked chain of layers (NOTE(review): presumably evaluated
/// front to back, ending at leaf() — confirm in the .cpp).
///
/// The public, non-virtual queries (num_relus(), num_neurons(), copy(), ...)
/// pair with private virtual hooks (get_num_relus(), get_num_neurons(),
/// copy_impl(), ...) that concrete layers override.
///
/// The implicitly generated copy constructor copies `successor_` as a
/// shared_ptr, so a copied node shares its successor chain with the original.
class Node {
public:
    Node() = default;

    virtual ~Node() = default;

    /// Evaluates the graph starting at this node on `input`.
    /// NOTE(review): presumably converts `input`, runs compute() on this node
    /// and then on each successor in turn — confirm.
    [[nodiscard]]
    LinVector operator()(const vector<real_t>& input) const;

    /// Copy of this node. NOTE(review): presumably delegates to copy_impl(),
    /// which NodeClass implements by copy-constructing the concrete layer
    /// (shared_ptr members stay shared).
    [[nodiscard]]
    std::shared_ptr<Node> copy() const;

    /// Copy of this node. NOTE(review): presumably delegates to
    /// deep_copy_impl() so that layer data is duplicated — confirm.
    [[nodiscard]]
    std::shared_ptr<Node> deep_copy() const;

    /// ReLU count. NOTE(review): presumably summed via get_num_relus() over
    /// this node and its successors.
    [[nodiscard]]
    size_t num_relus() const;

    /// Pool count. NOTE(review): presumably summed via get_num_pools().
    [[nodiscard]]
    size_t num_pools() const;

    /// Neuron count. NOTE(review): presumably summed via get_num_neurons().
    [[nodiscard]]
    size_t num_neurons() const;

    /// Linear-unit count. NOTE(review): presumably summed via
    /// get_num_linears().
    [[nodiscard]]
    size_t num_linears() const;

    /// Whether this node is the end of the chain (presumably: no successor).
    [[nodiscard]]
    bool is_leaf() const;

    /// Replaces the node that follows this one; takes shared ownership.
    void set_successor(std::shared_ptr<Node> successor);

    /// Number of layers in the chain starting here (presumably including
    /// this node).
    [[nodiscard]]
    size_t num_layers() const;

    /// The node following this one; may be null (the member defaults to
    /// nullptr).
    [[nodiscard]]
    const std::shared_ptr<Node>& successor() const;

    /// Last node of the chain starting here (presumably `this` when
    /// is_leaf()).
    [[nodiscard]]
    Node* leaf();

    [[nodiscard]]
    const Node* leaf() const;

    /// Double-dispatch hook: calls the NodeVisitor::visit overload for the
    /// concrete layer type (implemented once in NodeClass).
    virtual void accept(NodeVisitor* visitor) const = 0;

    /// Width of the vector this layer consumes.
    [[nodiscard]]
    virtual size_t num_inputs() const = 0;

    /// Width of the vector this layer produces.
    [[nodiscard]]
    virtual size_t num_outputs() const = 0;

    /// Aggregated size statistics for the graph starting at this node.
    [[nodiscard]]
    ComputeGraphSize statistics() const;

    /// Short, human-readable layer kind (e.g. "linear", "relu").
    [[nodiscard]]
    virtual std::string name() const = 0;

private:
    /// Neurons contributed by this single layer.
    [[nodiscard]]
    virtual size_t get_num_neurons() const = 0;

    // The per-kind counters default to 0; only the matching layer kind
    // overrides its counter.
    [[nodiscard]]
    virtual size_t get_num_linears() const
    {
        return 0;
    }

    [[nodiscard]]
    virtual size_t get_num_relus() const
    {
        return 0;
    }

    [[nodiscard]]
    virtual size_t get_num_pools() const
    {
        return 0;
    }

    /// Applies this single layer to `input`.
    [[nodiscard]]
    virtual LinVector compute(const LinVector& input) const = 0;

    /// Shallow copy of the concrete layer (see NodeClass).
    [[nodiscard]]
    virtual std::shared_ptr<Node> copy_impl() const = 0;

    /// Layer-specific deep copy.
    [[nodiscard]]
    virtual std::shared_ptr<Node> deep_copy_impl() const = 0;

    // Next layer in the chain; null for the last node.
    std::shared_ptr<Node> successor_ = nullptr;
};

/// CRTP helper that implements the type-dependent parts of Node for a
/// concrete layer `NodeType` (declared as `class X : public NodeClass<X>`).
///
/// It also derives from `std::enable_shared_from_this<NodeType>`. As usual,
/// shared_from_this() is only valid on objects already owned by a shared_ptr.
template <typename NodeType>
class NodeClass
    : public Node
    , public std::enable_shared_from_this<NodeType> {
public:
    /// Forwards to the NodeVisitor::visit overload for NodeType.
    void accept(NodeVisitor* visitor) const override
    {
        const NodeType* const self = static_cast<const NodeType*>(this);
        visitor->visit(self);
    }

private:
    /// Copy-constructs the concrete layer into a new shared_ptr.
    ///
    /// Members held by shared_ptr (including Node's successor) are shared
    /// with the original rather than duplicated.
    std::shared_ptr<Node> copy_impl() const override
    {
        const NodeType& original = static_cast<const NodeType&>(*this);
        return std::make_shared<NodeType>(original);
    }
};

/// Layer parameterized by a weight matrix and a bias vector.
/// NOTE(review): presumably computes weights * input + biases — confirm
/// against compute() in the .cpp.
///
/// Both parameters are held behind shared_ptr. A shallow copy (copy
/// constructor / Node::copy()) therefore shares them, and mutating through the
/// non-const accessors is visible in every such copy.
class LinearLayer final : public NodeClass<LinearLayer> {
public:
    LinearLayer(Matrix weights, LinVector biases);

    [[nodiscard]]
    std::string name() const override
    {
        return "linear";
    }

    [[nodiscard]]
    size_t num_inputs() const override;

    [[nodiscard]]
    size_t num_outputs() const override;

    /// Weight matrix (read-only view).
    [[nodiscard]]
    const Matrix& get_weights() const
    {
        return *weights_;
    }

    /// Weight matrix (mutable; shared with shallow copies).
    [[nodiscard]]
    Matrix& get_weights()
    {
        return *weights_;
    }

    /// Bias vector (read-only view).
    [[nodiscard]]
    const LinVector& get_biases() const
    {
        return *biases_;
    }

    /// Bias vector (mutable; shared with shallow copies).
    [[nodiscard]]
    LinVector& get_biases()
    {
        return *biases_;
    }

private:
    /// One linear unit per output.
    [[nodiscard]]
    size_t get_num_linears() const override
    {
        return num_outputs();
    }

    [[nodiscard]]
    size_t get_num_neurons() const override;

    [[nodiscard]]
    LinVector compute(const LinVector& input) const override;

    [[nodiscard]]
    std::shared_ptr<Node> deep_copy_impl() const override;

    std::shared_ptr<Matrix> weights_;
    std::shared_ptr<LinVector> biases_;
};

/// Rectifier layer of a fixed width given at construction (`dimension`).
/// NOTE(review): presumably applies max(0, x) element-wise, with `zero_` as
/// the comparison vector — confirm against compute() in the .cpp.
class ReluLayer final : public NodeClass<ReluLayer> {
public:
    explicit ReluLayer(size_t dimension);

    [[nodiscard]]
    std::string name() const override
    {
        return "relu";
    }

    [[nodiscard]]
    size_t num_inputs() const override;

    [[nodiscard]]
    size_t num_outputs() const override;

private:
    [[nodiscard]]
    size_t get_num_relus() const override;

    [[nodiscard]]
    size_t get_num_neurons() const override;

    [[nodiscard]]
    LinVector compute(const LinVector& input) const override;

    [[nodiscard]]
    std::shared_ptr<Node> deep_copy_impl() const override;

    // Held behind a shared_ptr, so shallow copies share it.
    [[maybe_unused]]
    std::shared_ptr<LinVector> zero_;
};

/// Max-pooling layer over `inputs` input values.
///
/// Each entry of `pools` is a list of input indices.
/// NOTE(review): presumably each pool yields one output equal to the maximum
/// of its referenced inputs — confirm against compute() in the .cpp.
class MaxPoolLayer final : public NodeClass<MaxPoolLayer> {
public:
    MaxPoolLayer(vector<vector<size_t>> pools, size_t inputs);

    [[nodiscard]]
    std::string name() const override
    {
        return "maxpool";
    }

    [[nodiscard]]
    size_t num_inputs() const override;

    [[nodiscard]]
    size_t num_outputs() const override;

    /// All pools; shared with shallow copies of this layer.
    [[nodiscard]]
    const vector<vector<size_t>>& get_pools() const
    {
        return *pools_;
    }

    /// Input indices referenced by pool `pool_idx`.
    /// NOTE(review): returning `const vector<size_t>` by value blocks move
    /// semantics. It is kept as is because the out-of-line definition must
    /// match this exact return type; drop the `const` in both places together.
    [[nodiscard]]
    const vector<size_t> get_input_refs(size_t pool_idx) const;

private:
    [[nodiscard]]
    size_t get_num_pools() const override;

    [[nodiscard]]
    size_t get_num_neurons() const override;

    [[nodiscard]]
    LinVector compute(const LinVector& input) const override;

    [[nodiscard]]
    std::shared_ptr<Node> deep_copy_impl() const override;

    std::shared_ptr<vector<vector<size_t>>> pools_;
    [[maybe_unused]]
    size_t inputs_;
};

/// Layer producing `num_outputs` values from its inputs and a constant
/// `scalar`.
/// NOTE(review): presumably input i is placed at output remap_inputs[i] and
/// all remaining outputs are set to `scalar` — confirm against compute() in
/// the .cpp.
class ExpandToConstLayer final : public NodeClass<ExpandToConstLayer> {
public:
    ExpandToConstLayer(
        size_t num_outputs,
        vector<size_t> remap_inputs,
        real_t scalar);

    // Unlike the other layers, name() is defined out of line.
    [[nodiscard]]
    std::string name() const override;

    [[nodiscard]]
    size_t num_inputs() const override;

    [[nodiscard]]
    size_t num_outputs() const override;

    /// Input remapping table; shared with shallow copies of this layer.
    [[nodiscard]]
    const vector<size_t>& get_input_remap() const
    {
        return *remap_inputs_;
    }

    /// The constant supplied at construction.
    [[nodiscard]]
    real_t get_scalar() const
    {
        return scalar_;
    }

private:
    /// One neuron per output.
    [[nodiscard]]
    size_t get_num_neurons() const override
    {
        return num_outputs();
    }

    [[nodiscard]]
    LinVector compute(const LinVector& input) const override;

    [[nodiscard]]
    std::shared_ptr<Node> deep_copy_impl() const override;

    std::shared_ptr<vector<size_t>> remap_inputs_;
    real_t scalar_;
    size_t num_outputs_;
};

} // namespace police::cg

namespace police {
/// Streams `stats` to `out` (format defined out of line) and returns a stream
/// reference (presumably `out`).
///
/// NOTE(review): this overload is declared in `police`, not in `police::cg`
/// where ComputeGraphSize lives. Argument-dependent lookup therefore does not
/// find it:
/// - `out << stats` resolves only from code inside namespace `police`
///   (including `police::cg`), and even there a nearer-scope operator<< can
///   hide it;
/// - code elsewhere needs `using police::operator<<;`.
/// Moving the declaration into `police::cg` (together with its definition)
/// would fix this.
std::ostream& operator<<(std::ostream& out, const cg::ComputeGraphSize& stats);
}
