#pragma once

#include "police/base_types.hpp"
#include "police/macros.hpp"
#include "police/storage/value.hpp"
#include "police/storage/vector.hpp"
#include "police/utils/type_traits.hpp"

#include <cassert>
#include <cstddef>
#include <cstdint>
#include <iosfwd>
#include <iterator>
#include <limits>
#include <string_view>
#include <type_traits>
#include <utility>
#include <variant>

namespace police {

class VariableSpace;

/// Tag alternative of VariableType for Boolean variables.
struct BoolType {
};

/// Tag alternative of VariableType for integer variables without bounds.
struct IntegerType {
};

/// Tag alternative of VariableType for real-valued variables without bounds.
struct RealType {
};

struct BoundedIntType {
    [[nodiscard]]
    bool is_lower_bounded() const
    {
        return lower_bound != std::numeric_limits<int_t>::min();
    }

    [[nodiscard]]
    bool is_upper_bounded() const
    {
        return upper_bound != std::numeric_limits<int_t>::max();
    }

    int_t lower_bound = std::numeric_limits<int_t>::min();
    int_t upper_bound = std::numeric_limits<int_t>::max();
};

struct BoundedRealType {
    [[nodiscard]]
    bool is_lower_bounded() const
    {
        return lower_bound != -std::numeric_limits<real_t>::infinity();
    }

    [[nodiscard]]
    bool is_upper_bounded() const
    {
        return upper_bound != std::numeric_limits<real_t>::infinity();
    }

    real_t lower_bound = -std::numeric_limits<real_t>::infinity();
    real_t upper_bound = std::numeric_limits<real_t>::infinity();
};

/// Returns `t.upper_bound - t.lower_bound` as an unsigned 32-bit value.
///
/// Both bounds are widened to `std::int64_t` before subtracting, so the
/// difference cannot overflow when the bounds fit in 32 bits; the result is
/// then converted (modulo 2^32) to `std::uint32_t`.
///
/// NOTE(review): this is the width of the interval, not the number of
/// integers it contains (that would be one more), and a reversed interval
/// (upper < lower) wraps to a large value — confirm callers expect this.
inline std::uint32_t get_interval_size(const BoundedIntType& t)
{
    const std::int64_t lo = t.lower_bound;
    const std::int64_t hi = t.upper_bound;
    const std::int64_t width = hi - lo;
    return static_cast<std::uint32_t>(width);
}

struct VariableType
    : public std::variant<
          BoolType,
          BoundedIntType,
          IntegerType,
          BoundedRealType,
          RealType> {
    using variant<
        BoolType,
        BoundedIntType,
        IntegerType,
        BoundedRealType,
        RealType>::variant;

    [[nodiscard]]
    constexpr bool is_bool() const
    {
        return index() == 0;
    }

    [[nodiscard]]
    constexpr bool is_int() const
    {
        return index() == 2;
    }

    [[nodiscard]]
    constexpr bool is_real() const
    {
        return index() == 4;
    }

    [[nodiscard]]
    constexpr bool is_bounded_int() const
    {
        return index() == 1;
    }

    [[nodiscard]]
    constexpr bool is_bounded_real() const
    {
        return index() == 3;
    }

    [[nodiscard]]
    constexpr bool is_bounded() const
    {
        return is_bool() || is_bounded_int() || is_bounded_real();
    }

    [[nodiscard]]
    Value get_lower_bound() const
    {
        assert(is_bounded());
        return std::visit(
            [](auto&& t) -> Value {
                using T = std::decay_t<decltype(t)>;
                if constexpr (
                    std::is_same_v<T, BoundedIntType> ||
                    std::is_same_v<T, BoundedRealType>) {
                    return Value(t.lower_bound);
                }
                return Value(0);
            },
            *this);
    }

    [[nodiscard]]
    Value get_upper_bound() const
    {
        assert(is_bounded());
        return std::visit(
            [](auto&& t) -> Value {
                using T = std::decay_t<decltype(t)>;
                if constexpr (
                    std::is_same_v<T, BoundedIntType> ||
                    std::is_same_v<T, BoundedRealType>) {
                    return Value(t.upper_bound);
                }
                return Value(1);
            },
            *this);
    }

    [[nodiscard]]
    constexpr VariableType unbounded() const
    {
        return std::visit(
            [](auto&& t) -> VariableType {
                using T = std::decay_t<decltype(t)>;
                if constexpr (
                    std::is_same_v<T, BoundedIntType> ||
                    std::is_same_v<T, BoolType>) {
                    return IntegerType();
                } else if constexpr (std::is_same_v<T, BoundedRealType>) {
                    return RealType();
                } else {
                    return t;
                }
            },
            *this);
    }

    [[nodiscard]]
    constexpr VariableType relax() const
    {
        return std::visit(
            [](auto&& t) -> VariableType {
                using T = std::decay_t<decltype(t)>;
                if constexpr (std::is_same_v<T, BoundedIntType>) {
                    auto res = BoundedRealType();
                    if (t.is_lower_bounded()) {
                        res.lower_bound = static_cast<real_t>(t.lower_bound);
                    }
                    if (t.is_upper_bounded()) {
                        res.upper_bound = static_cast<real_t>(t.upper_bound);
                    }
                    return res;
                } else if constexpr (std::is_same_v<T, BoolType>) {
                    return BoundedRealType(0, 1);
                } else if constexpr (std::is_same_v<T, IntegerType>) {
                    return RealType();
                } else {
                    return t;
                }
            },
            *this);
    }

    [[nodiscard]]
    constexpr Value::Type value_type() const
    {
        return std::visit(
            [](auto&& t) -> Value::Type {
                using T = std::decay_t<decltype(t)>;
                if constexpr (std::is_same_v<T, BoolType>) {
                    return Value::Type::BOOL;
                } else if constexpr (
                    std::is_same_v<T, IntegerType> ||
                    std::is_same_v<T, BoundedIntType>) {
                    return Value::Type::INT;
                } else {
                    return Value::Type::REAL;
                }
            },
            *this);
    }

    [[nodiscard]]
    static constexpr VariableType from_value_type(const Value::Type& t)
    {
        switch (t) {
        case Value::Type::BOOL: return BoolType();
        case Value::Type::INT: return IntegerType();
        case Value::Type::REAL: return RealType();
        }
        POLICE_UNREACHABLE();
    }
};

/// Proxy for one variable: its index `id` plus references to the stored
/// type and name. With `Constant == true`, `id`, `type` and `name` are
/// read-only.
template <bool Constant>
struct VariableReference {
    VariableReference(
        size_t id,
        police::make_const_t<Constant, VariableType>& type,
        police::make_const_t<Constant, identifier_name_t>& name)
        : id(id)
        , type(type)
        , name(name)
    {
    }

    /// Creates a proxy bound to the same variable as `other`.
    ///
    /// Reference members can only be (re)bound, never moved. `std::move` is
    /// deliberately not used here: a non-const lvalue-reference member
    /// (`Constant == false`) cannot bind to the resulting xvalue.
    VariableReference(VariableReference&& other) noexcept
        : id(other.id)
        , type(other.type)
        , name(other.name)
    {
    }

    /// Implicit conversion from a mutable proxy to a read-only proxy
    /// (enabled only for `B == false`, `Constant == true`).
    template <bool B, typename = std::enable_if_t<!B && B != Constant>>
    VariableReference(const VariableReference<B>& other)
        : id(other.id)
        , type(other.type)
        , name(other.name)
    {
    }

    /// Assigns *through* the proxy: the type and name this proxy refers to
    /// are move-assigned from the ones `other` refers to, and `id` is
    /// overwritten with `other.id`.
    ///
    /// NOTE(review): afterwards `id` is `other`'s index while `type`/`name`
    /// still refer to this proxy's original slot — confirm this is intended.
    VariableReference& operator=(VariableReference&& other)
    {
        id = other.id;
        type = std::move(other.type);
        name = std::move(other.name);
        return *this;
    }

    /// Exchanges `id` and the referenced type and name values with `other`.
    /// The references themselves stay bound to their original slots.
    ///
    /// NOTE(review): as with operator=, `id` ends up naming the other slot;
    /// confirm this is intended.
    void swap(VariableReference& other)
    {
        std::swap(id, other.id);
        std::swap(type, other.type);
        std::swap(name, other.name);
    }

    police::make_const_t<Constant, size_t> id;
    police::make_const_t<Constant, VariableType>& type;
    police::make_const_t<Constant, identifier_name_t>& name;
};

template <bool Constant>
struct VariablePointer {
    VariablePointer(
        size_t id,
        police::make_const_t<Constant, VariableType>& type,
        police::make_const_t<Constant, identifier_name_t>& name)
        : ref(id, type, name)
    {
    }

    VariableReference<Constant>* operator->() { return &ref; }

    const VariableReference<Constant>* operator->() const { return &ref; }

    VariableReference<Constant> ref;
};

namespace detail::variable_space {

template <bool Constant>
class Iterator {
public:
    using value_type = VariableReference<Constant>;
    using difference_type = int;
    using pointer_type = VariablePointer<Constant>;

    Iterator() = default;

    template <bool T, typename = std::enable_if_t<T != Constant && !T>>
    Iterator(const Iterator<T>& other)
        : vspace_(other.vspace_)
        , idx_(other.idx_)
    {
    }

    [[nodiscard]]
    VariableReference<Constant> operator*() const
    {
        return {idx_, vspace_->get_type(idx_), vspace_->get_name(idx_)};
    }

    [[nodiscard]]
    VariablePointer<Constant> operator->() const
    {
        return {idx_, vspace_->get_type(idx_), vspace_->get_name(idx_)};
    }

    [[nodiscard]]
    VariableReference<Constant> operator[](difference_type n) const
    {
        return *(*this + n);
    }

    [[nodiscard]]
    constexpr bool operator==(const Iterator& other) const
    {
        return idx_ == other.idx_;
    }

    [[nodiscard]]
    constexpr auto operator<=>(const Iterator& other) const
    {
        return idx_ <=> other.idx_;
    }

    constexpr Iterator& operator++()
    {
        ++idx_;
        return *this;
    }

    constexpr Iterator operator++(int)
    {
        auto copy(*this);
        ++*this;
        return copy;
    }

    constexpr Iterator& operator--()
    {
        --idx_;
        return *this;
    }

    constexpr Iterator operator--(int)
    {
        auto copy(*this);
        --*this;
        return copy;
    }

    constexpr Iterator& operator+=(difference_type n)
    {
        idx_ += n;
        return *this;
    }

    constexpr Iterator& operator-=(difference_type n)
    {
        idx_ -= n;
        return *this;
    }

    [[nodiscard]]
    constexpr Iterator operator+(difference_type n) const
    {
        auto copy(*this);
        copy += n;
        return copy;
    }

    [[nodiscard]]
    constexpr Iterator operator-(difference_type n) const
    {
        auto copy(*this);
        copy -= n;
        return copy;
    }

    [[nodiscard]]
    difference_type operator-(const Iterator& other) const
    {
        return idx_ - other.idx_;
    }

    friend constexpr Iterator
    operator+(difference_type lhs, const Iterator& rhs)
    {
        return rhs + lhs;
    }

    friend constexpr Iterator
    operator-(difference_type lhs, const Iterator& rhs)
    {
        return rhs - lhs;
    }

private:
    using vspace_ptr = police::make_const_t<Constant, VariableSpace>*;

    friend class police::VariableSpace;
    constexpr Iterator(vspace_ptr vspace, size_t idx)
        : vspace_(vspace)
        , idx_(idx)
    {
    }

    vspace_ptr vspace_ = nullptr;
    size_t idx_{};
};

static_assert(std::random_access_iterator<Iterator<true>>);
static_assert(std::random_access_iterator<Iterator<false>>);

} // namespace detail::variable_space

/// Collection of variables addressed by index. Each variable has a
/// VariableType and an identifier name, held in the two vectors `types_` and
/// `names_` (presumably kept index-aligned by the out-of-line definitions).
/// Element access hands out VariableReference proxies.
class VariableSpace {
public:
    using iterator = detail::variable_space::Iterator<false>;
    using const_iterator = detail::variable_space::Iterator<true>;
    using reference = VariableReference<false>;
    using const_reference = VariableReference<true>;
    using value = const_reference;

    VariableSpace() = default;

    // ----- modification (defined out of line) -----------------------------

    /// Adds a variable; presumably returns its new index — confirm.
    size_t add_variable(identifier_name_t name, VariableType type);

    void erase(iterator pos);
    void erase(iterator begin, iterator end);
    iterator insert(const_iterator begin, const_iterator end);
    void clear();

    // ----- size and element access -----------------------------------------

    [[nodiscard]] size_t size() const;

    [[nodiscard]] reference operator[](size_t var_id);
    [[nodiscard]] const_reference operator[](size_t var_id) const;

    /// Forwards to operator[]; no extra bounds checking is done here.
    [[nodiscard]] reference at(size_t var_id) { return operator[](var_id); }

    /// Forwards to operator[]; no extra bounds checking is done here.
    [[nodiscard]] const_reference at(size_t var_id) const
    {
        return operator[](var_id);
    }

    [[nodiscard]] identifier_name_t& get_name(size_t var_id);
    [[nodiscard]] const identifier_name_t& get_name(size_t var_id) const;

    [[nodiscard]] VariableType& get_type(size_t var_id);
    [[nodiscard]] const VariableType& get_type(size_t var_id) const;

    /// Index of the variable called `name`; behaviour for unknown names is
    /// defined out of line — TODO confirm.
    [[nodiscard]] size_t get_variable_id(std::string_view name) const;

    // ----- iteration -------------------------------------------------------

    [[nodiscard]] iterator begin();
    [[nodiscard]] iterator end();

    [[nodiscard]] const_iterator begin() const;
    [[nodiscard]] const_iterator end() const;

    [[nodiscard]] const_iterator cbegin() const;
    [[nodiscard]] const_iterator cend() const;

private:
    vector<VariableType> types_;
    vector<identifier_name_t> names_;
};

/// Stream-insertion operator for VariableType (the definition is not in this
/// header).
auto operator<<(std::ostream& out, const police::VariableType& type)
    -> std::ostream&;

} // namespace police
