#pragma once

#include "police/storage/flat_state.hpp"
#include "police/storage/value.hpp"
#include "police/storage/variable_space.hpp"
#include "police/storage/vector.hpp"
#include "police/utils/hash.hpp"

#include <algorithm>
#include <cassert>
#include <cstddef>
#include <iosfwd>
#include <limits>
#include <tuple>
#include <utility>

namespace police {
// Forward declaration only: Model is used in this header solely through
// pointers and references (Cube::JsonDumpModel and Cube::json(const Model&)),
// so its full definition is not needed here.
struct Model;
} // namespace police

namespace police::ic3 {

/// A closed interval [lb, ub] of `Value`s; both bounds are inclusive (see
/// contains()).
///
/// `MIN` and `MAX` act as "unbounded" sentinels: a default-constructed
/// Interval is [MIN, MAX], and has_lb()/has_ub() report whether a bound
/// differs from its sentinel.
///
/// The relational operators implement set inclusion, not a total order:
/// `a <= b` holds iff `a` is a subset of `b`, and `a < b` iff it is a strict
/// subset. Two intervals can be incomparable (neither `a <= b` nor
/// `b <= a`), and incomparability is not transitive, so `operator<` is NOT a
/// strict weak ordering and must not be used as the comparator of
/// std::sort, std::set or std::map.
///
/// NOTE(review): nothing here prevents lb > ub (e.g. after intersecting
/// disjoint intervals with `&=`), and the inclusion operators do not treat
/// such empty intervals specially -- callers must check for emptiness.
struct Interval {
    /// "No lower bound" sentinel: the smallest representable int_t.
    constexpr static Value MIN = Value(std::numeric_limits<int_t>::min());
    /// "No upper bound" sentinel: the largest representable int_t.
    constexpr static Value MAX = Value(std::numeric_limits<int_t>::max());

    /// The unbounded interval [MIN, MAX].
    Interval() = default;
    /// Defined out of line; presumably the singleton [value, value].
    explicit Interval(Value value);
    /// Defined out of line; presumably the interval [lb, ub].
    Interval(Value lb, Value ub);

    /// Shorthand for Interval(lb, MAX): bounded from below only.
    [[nodiscard]]
    static Interval make_lb(Value lb)
    {
        return Interval(std::move(lb), MAX);
    }

    /// Shorthand for Interval(MIN, ub): bounded from above only.
    [[nodiscard]]
    static Interval make_ub(Value ub)
    {
        return Interval(MIN, std::move(ub));
    }

    /// True iff both bounds are equal.
    [[nodiscard]]
    bool operator==(const Interval& other) const
    {
        return lb == other.lb && ub == other.ub;
    }

    /// Negation of operator==. Spelled out explicitly because C++17 does not
    /// derive `!=` from `==`.
    [[nodiscard]]
    bool operator!=(const Interval& other) const
    {
        return !(*this == other);
    }

    /// Inclusion: true iff *this is a subset of `other`.
    [[nodiscard]]
    bool operator<=(const Interval& other) const
    {
        return (lb >= other.lb) && (ub <= other.ub);
    }

    /// Strict inclusion: *this is a subset of `other` and differs from it in
    /// at least one bound.
    [[nodiscard]]
    bool operator<(const Interval& other) const
    {
        return ((lb > other.lb) && (ub <= other.ub)) ||
               ((lb >= other.lb) && (ub < other.ub));
    }

    /// Reverse inclusion: true iff *this is a superset of `other`.
    [[nodiscard]]
    bool operator>=(const Interval& other) const
    {
        return other <= *this;
    }

    /// Strict reverse inclusion: *this is a strict superset of `other`.
    [[nodiscard]]
    bool operator>(const Interval& other) const
    {
        return other < *this;
    }

    /// Intersects in place by raising lb and lowering ub. The result is
    /// empty (lb > ub) when the intervals are disjoint; this is not checked.
    Interval& operator&=(const Interval& other)
    {
        lb = std::max(lb, other.lb);
        ub = std::min(ub, other.ub);
        return *this;
    }

    /// Returns the intersection of *this and `other` (see operator&=).
    /// This `&&` is a set operation, not a logical one: both operands are
    /// always evaluated.
    [[nodiscard]]
    Interval operator&&(const Interval& other) const
    {
        Interval result(*this);
        result &= other;
        return result;
    }

    /// True iff lb <= value <= ub.
    [[nodiscard]]
    bool contains(const Value& value) const
    {
        return value >= lb && value <= ub;
    }

    /// Intersects in place with the bounds `lb` and `ub` without building an
    /// Interval (the parameters shadow the members, hence `this->`).
    void tighten(Value lb, Value ub)
    {
        this->lb = std::max(this->lb, lb);
        this->ub = std::min(this->ub, ub);
    }

    /// Exchanges both bounds with `other`.
    void swap(Interval& other)
    {
        std::swap(lb, other.lb);
        std::swap(ub, other.ub);
    }

    /// Makes *this the singleton interval [value, value].
    Interval& operator=(const Value& value)
    {
        lb = value;
        ub = value;
        return *this;
    }

    /// True iff the lower bound is not the MIN sentinel.
    [[nodiscard]]
    bool has_lb() const
    {
        return lb != MIN;
    }

    /// True iff the upper bound is not the MAX sentinel.
    [[nodiscard]]
    bool has_ub() const
    {
        return ub != MAX;
    }

    /// Inclusive lower bound; MIN means unbounded below.
    Value lb = MIN;
    /// Inclusive upper bound; MAX means unbounded above.
    Value ub = MAX;
};

/// A set of `Value`s represented as a sequence of `Interval`s, stored in a
/// privately inherited vector<Interval>.
///
/// Read access, size() and erase() are exposed directly; all other mutation
/// goes through insert(), remove(), `&=` and swap(), which are defined out
/// of line.
///
/// NOTE(review): the representation invariants (presumably intervals kept
/// sorted and disjoint) live in those out-of-line definitions and cannot be
/// confirmed here. operator== compares the stored intervals element-wise,
/// so it agrees with set equality only if that representation is canonical.
///
/// NOTE(review): on a non-const IntervalSet the private non-const
/// begin()/end() overloads win overload resolution, so e.g. a range-for over
/// a non-const object fails to compile with an access error. Iterate through
/// a const reference or via cbegin()/cend() instead.
class IntervalSet : private vector<Interval> {
public:
    using iterator = vector<Interval>::iterator;
    using const_iterator = vector<Interval>::const_iterator;

    /// The empty set (no stored intervals).
    IntervalSet() = default;
    /// Defined out of line; presumably the set covering exactly `interval`.
    explicit IntervalSet(Interval interval);

    using vector<Interval>::erase;
    using vector<Interval>::size;
    using vector<Interval>::cbegin;
    using vector<Interval>::cend;

    /// Read-only iteration over the stored intervals.
    const_iterator begin() const;
    const_iterator end() const;

    /// Read-only access to the pos-th stored interval.
    const Interval& operator[](size_t pos) const;
    const Interval& at(size_t pos) const;

    /// Adds `closed_interval` to the set and returns an iterator into the
    /// stored intervals. The parameter name suggests both bounds are
    /// included -- TODO confirm against the definition.
    const_iterator insert(const Interval& closed_interval);

    /// Removes values from the set. The parameter name suggests the bounds
    /// of `open_interval` themselves are kept -- TODO confirm against the
    /// definition.
    void remove(const Interval& open_interval);

    /// Presumably in-place intersection with `other`, mirroring
    /// Interval::operator&= (defined out of line).
    IntervalSet& operator&=(const IntervalSet& other);

    /// Defined out of line; presumably set inclusion, as for Interval.
    [[nodiscard]]
    bool operator<=(const IntervalSet& other) const;

    /// Defined out of line; presumably strict set inclusion.
    [[nodiscard]]
    bool operator<(const IntervalSet& other) const;

    [[nodiscard]]
    bool operator>=(const IntervalSet& other) const
    {
        return other <= *this;
    }

    [[nodiscard]]
    bool operator>(const IntervalSet& other) const
    {
        return other < *this;
    }

    /// Element-wise equality of the stored intervals.
    [[nodiscard]]
    bool operator==(const IntervalSet& other) const
    {
        return size() == other.size() &&
               std::equal(begin(), end(), other.begin());
    }

    /// Negation of operator==. Spelled out explicitly because C++17 does not
    /// derive `!=` from `==`.
    [[nodiscard]]
    bool operator!=(const IntervalSet& other) const
    {
        return !(*this == other);
    }

    /// Returns a copy of *this combined with `other` via operator&=.
    [[nodiscard]]
    IntervalSet operator&&(const IntervalSet& other) const
    {
        IntervalSet result(*this);
        result &= other;
        return result;
    }

    /// Defined out of line; presumably true iff some stored interval
    /// contains `val`.
    [[nodiscard]]
    bool contains(const Value& val) const;

    void swap(IntervalSet& other);

private:
    // Mutable iteration, reserved for the implementation.
    iterator begin() { return vector<Interval>::begin(); }
    iterator end() { return vector<Interval>::end(); }
};

/// A conjunction of per-variable interval constraints, stored as
/// (variable index, Interval) pairs in a privately inherited vector.
///
/// A state satisfies the cube (contains()) iff the value of every listed
/// variable lies in its interval. Variables without an entry are
/// unconstrained, so an empty cube is satisfied by every state.
///
/// NOTE(review): operator== compares the stored pairs in order, so it
/// matches semantic equality only if the out-of-line code keeps entries in
/// a canonical (presumably variable-sorted) order -- confirm.
///
/// NOTE(review): the move constructor and move assignment are not noexcept,
/// so a std::vector<Cube> copies its elements on reallocation. Making them
/// noexcept requires updating the out-of-line definitions at the same time.
class Cube : private vector<std::pair<size_t, Interval>> {
public:
    using iterator = vector<std::pair<size_t, Interval>>::iterator;
    using const_iterator = vector<std::pair<size_t, Interval>>::const_iterator;
    using value_type = vector<std::pair<size_t, Interval>>::value_type;
    using reference = vector<std::pair<size_t, Interval>>::reference;
    using const_reference =
        vector<std::pair<size_t, Interval>>::const_reference;

    /// The empty (unconstrained) cube.
    Cube() = default;
    Cube(const Cube& cube);
    Cube(Cube&& cube);

    /// Defined out of line; presumably constrains every variable of `vec`
    /// to its value (a point cube) -- confirm.
    explicit Cube(const flat_state& vec);
    /// Defined out of line; presumably adopts the given (variable, interval)
    /// pairs as-is.
    explicit Cube(vector<std::pair<size_t, Interval>>&& data);

    using vector<std::pair<size_t, Interval>>::size;
    using vector<std::pair<size_t, Interval>>::empty;
    using vector<std::pair<size_t, Interval>>::clear;
    using vector<std::pair<size_t, Interval>>::swap;
    using vector<std::pair<size_t, Interval>>::begin;
    using vector<std::pair<size_t, Interval>>::end;
    using vector<std::pair<size_t, Interval>>::cbegin;
    using vector<std::pair<size_t, Interval>>::cend;

    using vector<std::pair<size_t, Interval>>::operator[];
    // Also exposes the vector's own assignment operators, so a raw vector of
    // (variable, interval) pairs can be assigned to a Cube directly.
    using vector<std::pair<size_t, Interval>>::operator=;

    Cube& operator=(const Cube& other);
    Cube& operator=(Cube&& other);

    /// Defined out of line; presumably per-variable intersection (cf. the
    /// private intersect() helper).
    Cube& operator&=(const Cube& other);

    /// Defined out of line; presumably per-variable widening (cf. the
    /// private merge() helper).
    Cube& operator|=(const Cube& other);

    /// Defined out of line; presumably inclusion, as for Interval.
    [[nodiscard]]
    bool operator<=(const Cube& other) const;

    /// Defined out of line; presumably strict inclusion.
    [[nodiscard]]
    bool operator<(const Cube& other) const;

    [[nodiscard]]
    bool operator>=(const Cube& other) const
    {
        return other <= *this;
    }

    [[nodiscard]]
    bool operator>(const Cube& other) const
    {
        return other < *this;
    }

    /// Element-wise equality of the stored (variable, interval) pairs.
    [[nodiscard]]
    bool operator==(const Cube& other) const
    {
        return size() == other.size() &&
               std::equal(begin(), end(), other.begin());
    }

    /// Negation of operator==. Spelled out explicitly because C++17 does not
    /// derive `!=` from `==`.
    [[nodiscard]]
    bool operator!=(const Cube& other) const
    {
        return !(*this == other);
    }

    /// Entry for variable `var` (defined out of line).
    [[nodiscard]]
    const_iterator find(size_t var) const;

    [[nodiscard]]
    iterator find(size_t var);

    /// Interval constraining `var` (defined out of line).
    [[nodiscard]]
    const Interval& get(size_t var) const;

    [[nodiscard]]
    Interval& get(size_t var);

    /// Whether the cube has an entry for `var` (defined out of line).
    [[nodiscard]]
    bool has(size_t var) const;

    /// True iff every constrained variable's value in `vec` lies within its
    /// interval. Debug builds assert that each variable index is in range
    /// of `vec`.
    [[nodiscard]]
    bool contains(const flat_state& vec) const
    {
        return std::all_of(begin(), end(), [&](const auto& i) {
            assert(i.first < vec.size());
            return i.second.contains(vec[i.first]);
        });
    }

    void set(size_t var, Interval i);

    void extend(size_t var, const Interval& values);

    void shrink(size_t var, const Interval& values);

    /// Returns an iterator and a flag, map-style; presumably inserts only
    /// when `var` has no entry yet -- confirm against the definition.
    std::pair<iterator, bool>
    emplace(size_t var, const Interval& values = Interval());

    iterator erase(iterator pos);

    iterator erase(iterator first, iterator last);

    /// Printable JSON proxy returned by json(const VariableSpace&). Holds
    /// non-owning pointers: it must not outlive the cube or variable space.
    struct JsonDump {
        JsonDump(const Cube* cube, const VariableSpace* vspace);

        const Cube* cube;
        const VariableSpace* variables;
    };

    [[nodiscard]]
    JsonDump json(const VariableSpace& vspace) const;

    /// Printable JSON proxy returned by json(const Model&). Holds non-owning
    /// pointers: it must not outlive the cube or model.
    struct JsonDumpModel {
        JsonDumpModel(const Cube* cube, const Model* model);

        const Cube* cube;
        const Model* model;
    };

    [[nodiscard]]
    JsonDumpModel json(const Model& model) const;

    void dump(std::ostream& out) const;

    /// Printable proxy returned by dump(const VariableSpace&). Holds
    /// non-owning pointers: it must not outlive the cube or variable space.
    struct Dump {
        Dump(const Cube* cube, const VariableSpace* vspace);

        const Cube* cube;
        const VariableSpace* variables;
    };

    [[nodiscard]]
    Dump dump(const VariableSpace& vspace) const;

private:
    void merge(Interval& i, const Interval& j) const;

    bool intersect(Interval& i, const Interval& j) const;
};

} // namespace police::ic3

namespace police {

// Stream-insertion operators for the IC3 interval and cube types; they are
// defined elsewhere. The Cube::Dump, Cube::JsonDump and Cube::JsonDumpModel
// overloads print the proxy objects returned by Cube::dump(const
// VariableSpace&) and the two Cube::json(...) overloads.
std::ostream&
operator<<(std::ostream& out, const police::ic3::Interval& interval);

std::ostream&
operator<<(std::ostream& out, const police::ic3::Cube& cube);

std::ostream&
operator<<(std::ostream& out, const police::ic3::Cube::Dump& proxy);

std::ostream&
operator<<(std::ostream& out, const police::ic3::Cube::JsonDump& proxy);

std::ostream&
operator<<(std::ostream& out, const police::ic3::Cube::JsonDumpModel& proxy);

/// Hashes an Interval by combining its (lb, ub) bounds.
template <>
struct hash<police::ic3::Interval> {
    [[nodiscard]] auto operator()(const police::ic3::Interval& interval) const
        -> std::size_t
    {
        // Produces the same tuple<const Value&, const Value&> as std::tie,
        // so the bounds are hashed by reference without copying.
        return get_hash(std::forward_as_tuple(interval.lb, interval.ub));
    }
};

/// Hashes an IntervalSet as the ordered sequence of its stored intervals.
template <>
struct hash<police::ic3::IntervalSet> {
    [[nodiscard]] auto
    operator()(const police::ic3::IntervalSet& intervals) const -> std::size_t
    {
        // `intervals` is const, so the public const begin()/end() are used.
        return hash_sequence(intervals.begin(), intervals.end());
    }
};

/// Hashes a Cube as the ordered sequence of its (variable, Interval) entries.
template <>
struct hash<police::ic3::Cube> {
    [[nodiscard]] auto operator()(const police::ic3::Cube& constraints) const
        -> std::size_t
    {
        return hash_sequence(constraints.begin(), constraints.end());
    }
};
} // namespace police
