#include "police/verifiers/ic3/cube.hpp"
#include "police/model.hpp"
#include "police/storage/variable_space.hpp"
#include "police/utils/io.hpp"

#include <algorithm>
#include <cassert>
#include <iterator>

namespace police::ic3 {

/// Builds the degenerate interval [val, val] containing exactly one value.
Interval::Interval(Value val)
    : Interval(val, val)
{
}

/// Builds the closed interval [lower, upper]; no ordering check is done here.
Interval::Interval(Value lower, Value upper)
    : lb(std::move(lower))
    , ub(std::move(upper))
{
}

/// Creates a set holding exactly one interval.
IntervalSet::IntervalSet(Interval interval)
{
    vector<Interval>& storage = *this;
    storage.push_back(std::move(interval));
}

/// Read-only iterator to the first (lowest) interval of the set.
IntervalSet::const_iterator IntervalSet::begin() const
{
    const vector<Interval>& storage = *this;
    return storage.cbegin();
}

/// Read-only past-the-end iterator of the set.
IntervalSet::const_iterator IntervalSet::end() const
{
    const vector<Interval>& storage = *this;
    return storage.cend();
}

/// Bounds-checked access to the interval at index `pos` (same as `at`).
const Interval& IntervalSet::operator[](size_t pos) const
{
    return at(pos);
}

/// Bounds-checked access to the interval at index `pos`.
const Interval& IntervalSet::at(size_t pos) const
{
    const vector<Interval>& storage = *this;
    return storage.at(pos);
}

/// Adds the closed interval [lb, ub] to the set.
///
/// The set is kept as a sorted list of pairwise disjoint closed intervals.
/// Every stored interval that shares at least one point with
/// `closed_interval` is merged with it into a single interval. If there is no
/// such interval, `closed_interval` is inserted at its sorted position.
///
/// Returns an iterator to the interval that now contains `closed_interval`.
IntervalSet::const_iterator IntervalSet::insert(const Interval& closed_interval)
{
    assert(closed_interval.lb <= closed_interval.ub);
    // First interval that is not entirely to the left of `closed_interval`.
    auto first = std::lower_bound(
        begin(),
        end(),
        closed_interval.lb,
        [](const Interval& interval, const Value& lb) {
            return interval.ub < lb;
        });
    // First interval that lies entirely to the right of `closed_interval`;
    // everything in [first, last) overlaps (or touches) it.
    auto last = std::upper_bound(
        first,
        end(),
        closed_interval.ub,
        [](const Value& ub, const Interval& interval) {
            return ub < interval.lb;
        });
    if (first == last) {
        // No overlap: the new interval sits strictly between its neighbours
        // (or at either end) and must not swallow the gap next to them.
        return vector<Interval>::insert(first, closed_interval);
    }
    // Collapse [first, last) together with `closed_interval` into *first.
    // Copy the new upper end before mutating anything, in case
    // `closed_interval` refers to an element of this set.
    const Value merged_ub = std::max((last - 1)->ub, closed_interval.ub);
    first->lb = std::min(first->lb, closed_interval.lb);
    first->ub = merged_ub;
    erase(first + 1, last);
    // Erasing after `first` does not invalidate `first`.
    return first;
}

void IntervalSet::remove(const Interval& open_interval)
{
    assert(open_interval.lb <= open_interval.ub);
    auto pos = std::upper_bound(
        begin(),
        end(),
        open_interval.lb,
        [](const auto& lb, auto&& interval) { return interval.ub > lb; });
    if (pos == end()) {
        return;
    }
    pos->ub = open_interval.lb;
    if (pos->lb <= pos->ub) {
        ++pos;
    }
    assert(pos == end() || pos->lb > open_interval.lb);
    auto ub = std::lower_bound(
        pos,
        end(),
        open_interval.ub,
        [](auto&& interval, const Value& ub) { return interval.lb < ub; });
    if (ub == pos) {
        return;
    }
    if ((ub - 1)->ub >= open_interval.ub) {
        --ub;
        ub->lb = open_interval.ub;
    }
    erase(pos, ub);
}

namespace {
std::pair<bool, bool>
covers_other(const IntervalSet& subj, const IntervalSet& other)
{
    bool strict = false;
    auto pos = subj.begin();
    const auto e = subj.end();
    for (const auto& interval : other) {
        pos = std::lower_bound(
            pos,
            e,
            interval.lb,
            [](const Interval& i, const Value& lb) { return i.ub < lb; });
        if (pos == e || pos->lb > interval.lb || pos->ub < interval.ub) {
            return {false, false};
        }
        strict = strict || pos->lb != interval.lb || pos->ub != interval.ub;
    }
    return {true, strict};
}
} // namespace

/// True iff every interval of this set lies inside an interval of `other`.
bool IntervalSet::operator<=(const IntervalSet& other) const
{
    const std::pair<bool, bool> verdict = covers_other(other, *this);
    return verdict.first;
}

/// True iff `other` covers this set and the two sets are not identical.
bool IntervalSet::operator<(const IntervalSet& other) const
{
    const std::pair<bool, bool> verdict = covers_other(other, *this);
    return verdict.second;
}

/// Returns true iff `val` lies in one of the closed intervals of the set.
bool IntervalSet::contains(const Value& val) const
{
    // Intervals ending below `val` form a prefix of the sorted list; the
    // first interval after that prefix is the only candidate.
    const auto candidate = std::partition_point(
        begin(), end(), [&val](const Interval& i) { return i.ub < val; });
    if (candidate == end()) {
        return false;
    }
    return candidate->lb <= val;
}

/// Exchanges the contents of this set and `other`.
void IntervalSet::swap(IntervalSet& other)
{
    vector<Interval>& storage = *this;
    storage.swap(other);
}

/// Intersects this set in place with `other`.
///
/// This is a single merge pass over both sorted interval lists. Overlaps are
/// emitted in order, and intervals of *this that overlap nothing are erased.
IntervalSet& IntervalSet::operator&=(const IntervalSet& other)
{
    // s & s == s. Handling this up front is also required for correctness:
    // the loop below inserts into *this, which would shift or reallocate the
    // very storage `j` walks when the two sets alias.
    if (this == &other) {
        return *this;
    }
    auto i = begin();
    auto j = other.begin();
    auto je = other.end();
    while (i != end() && j != je) {
        if (i->ub < j->lb) {
            // *i ends before *j starts. Later intervals of `other` start even
            // later, so *i overlaps none of them.
            i = erase(i);
        } else if (j->ub < i->lb) {
            ++j;
        } else if (i->ub < j->ub) {
            // *i ends inside *j: clip its lower end and keep it.
            i->lb = std::max(i->lb, j->lb);
            ++i;
        } else {
            // *j ends inside *i: put the overlap in front of *i, then continue
            // with the rest of *i starting at j->ub. Assuming `other` is
            // disjoint, that shared endpoint is clipped or erased later on.
            i = vector<Interval>::insert(
                    i,
                    Interval(std::max(i->lb, j->lb), j->ub)) +
                1;
            i->lb = j->ub;
            ++j;
        }
    }
    // Whatever remains of *this lies beyond the last interval of `other`.
    erase(i, end());
    return *this;
}

/// Mutable lookup of `var`; returns end() when `var` is unconstrained.
Cube::iterator Cube::find(size_t var)
{
    // Reuse the const lookup and translate the position back to a mutable
    // iterator.
    const Cube& self = *this;
    const auto hit = self.find(var);
    const auto offset = std::distance(cbegin(), hit);
    return begin() + offset;
}

/// Binary search for the entry of `var` in the variable-sorted cube.
/// Returns end() when `var` is not constrained.
Cube::const_iterator Cube::find(size_t var) const
{
    const auto slot = std::partition_point(
        begin(), end(), [var](const auto& entry) { return entry.first < var; });
    const bool found = slot != end() && slot->first == var;
    return found ? slot : end();
}

/// Interval constraining `var`; `var` must be constrained (checked by assert).
const Interval& Cube::get(size_t var) const
{
    assert(has(var));
    const auto entry = find(var);
    return entry->second;
}

/// Mutable interval constraining `var`; `var` must be constrained (checked by
/// assert).
Interval& Cube::get(size_t var)
{
    assert(has(var));
    const auto entry = find(var);
    return entry->second;
}

/// True iff the cube constrains `var`.
bool Cube::has(size_t var) const
{
    const auto entry = find(var);
    return entry != end();
}

/// Widens `i` to the smallest interval containing both `i` and `j` (the
/// hull).
void Cube::merge(Interval& i, const Interval& j) const
{
    // Same selections as std::min / std::max: replace only on a strict
    // improvement.
    if (j.lb < i.lb) {
        i.lb = j.lb;
    }
    if (i.ub < j.ub) {
        i.ub = j.ub;
    }
}

/// Narrows `i` to its intersection with `j`.
/// Returns false when the intersection is empty (lb > ub afterwards).
bool Cube::intersect(Interval& i, const Interval& j) const
{
    // Same selections as std::max / std::min: replace only on a strict
    // improvement.
    if (i.lb < j.lb) {
        i.lb = j.lb;
    }
    if (j.ub < i.ub) {
        i.ub = j.ub;
    }
    return i.lb <= i.ub;
}

/// Adds the constraint `var in values` unless `var` is already constrained.
/// Returns the entry for `var` and whether it was newly inserted; an existing
/// entry is left untouched.
std::pair<Cube::iterator, bool>
Cube::emplace(size_t var, const Interval& values)
{
    auto slot = std::lower_bound(
        vector::begin(),
        vector::end(),
        var,
        [](const auto& entry, size_t key) { return entry.first < key; });
    const bool exists = slot != end() && slot->first == var;
    if (exists) {
        return {slot, false};
    }
    return {insert(slot, {var, values}), true};
}

/// Weakens the constraint on `var` to also allow `values`, taking the hull of
/// the existing interval and `values`. Adds `var in values` if `var` was
/// unconstrained.
void Cube::extend(size_t var, const Interval& values)
{
    const auto [entry, inserted] = emplace(var, values);
    if (!inserted) {
        merge(entry->second, values);
    }
}

/// Narrows the constraint on `var` to its intersection with `values`, or adds
/// `var in values` if `var` was unconstrained. The intersection result is not
/// checked, so the stored interval may end up empty (lb > ub).
void Cube::shrink(size_t var, const Interval& values)
{
    const auto entry = find(var);
    if (entry == end()) {
        emplace(var, values);
        return;
    }
    intersect(entry->second, values);
}

/// Conjoins `other` into this cube. Variables constrained by only one side
/// keep that constraint; variables constrained by both get the intersection
/// of the two intervals. The emptiness of the intersection is not checked.
Cube& Cube::operator&=(const Cube& other)
{
    auto mine = vector::begin();
    auto theirs = other.begin();
    const auto theirs_end = other.end();
    while (mine != vector::end() && theirs != theirs_end) {
        if (mine->first < theirs->first) {
            // Constrained only here: keep as is.
            ++mine;
            continue;
        }
        if (theirs->first < mine->first) {
            // Constrained only by `other`: copy it in, before `mine`.
            mine = insert(mine, *theirs) + 1;
        } else {
            intersect(mine->second, theirs->second);
            ++mine;
        }
        ++theirs;
    }
    // Remaining variables of `other` all come after ours.
    for (; theirs != theirs_end; ++theirs) {
        push_back(*theirs);
    }
    return *this;
}

/// Replaces this cube with the smallest cube containing both this cube and
/// `other`.
///
/// A variable stays constrained only if both cubes constrain it, and its
/// interval becomes the hull of the two intervals. A variable constrained by
/// just one side is unconstrained in the union.
Cube& Cube::operator|=(const Cube& other)
{
    auto i = vector::begin();
    auto j = other.begin();
    const auto e = other.end();
    while (i != vector::end() && j != e) {
        if (i->first < j->first) {
            // Unconstrained in `other`, hence unconstrained in the union.
            i = vector::erase(i);
        } else if (j->first < i->first) {
            ++j;
        } else {
            // Union must weaken, so take the hull, not the intersection.
            merge(i->second, j->second);
            ++i;
            ++j;
        }
    }
    // `other` is exhausted: none of our remaining variables is constrained
    // there, so drop them as well.
    vector::erase(i, vector::end());
    return *this;
}

/// Non-owning printer binding a cube to the variable space used for names and
/// domains. Both pointers must outlive the printer.
Cube::JsonDump::JsonDump(const Cube* source, const VariableSpace* vars)
    : cube(source)
    , variables(vars)
{
}

/// JSON printer for this cube; `vspace` must outlive the returned object.
Cube::JsonDump Cube::json(const VariableSpace& vspace) const
{
    return JsonDump(this, &vspace);
}

void Cube::dump(std::ostream& out) const
{
    out << "[";
    for (auto i = begin(); i != end(); ++i) {
        if (i != begin()) {
            out << ", ";
        }
        out << i->first << ": <" << i->second.lb << ", " << i->second.ub << ">";
    }
    out << "]";
}

namespace {
/// Compares the constraints of cube `a` against those of cube `b`; both are
/// sorted by variable index.
///
/// Returns a bit mask:
/// - bit 0 (value 1): every variable constrained by `b` is also constrained
///   by `a`, with a's interval `<=` b's interval.
/// - bit 1 (value 2): at least one matched variable seen before the scan
///   stopped has a's interval `<` b's interval.
/// A mismatch (a variable of `b` missing from `a`, or an interval pair failing
/// `<=`) yields 0 immediately.
unsigned check_subsumption(const Cube& a, const Cube& b)
{
    if (a.size() < b.size()) {
        // `a` cannot constrain every variable of `b`.
        return 0U;
    }
    auto lhs = a.begin();
    const auto lhs_end = a.end();
    auto rhs = b.begin();
    const auto rhs_end = b.end();
    bool strict = false;
    // Stop early once fewer entries of `a` remain than of `b`: the rest of `b`
    // can then no longer be matched.
    while (lhs != lhs_end && rhs != rhs_end &&
           (lhs_end - lhs) >= (rhs_end - rhs)) {
        if (lhs->first < rhs->first) {
            // Variable constrained only by `a`.
            ++lhs;
            continue;
        }
        if (rhs->first < lhs->first || !(lhs->second <= rhs->second)) {
            return 0U;
        }
        if (lhs->second < rhs->second) {
            strict = true;
        }
        ++lhs;
        ++rhs;
    }
    const unsigned covered = (rhs == rhs_end) ? 1U : 0U;
    return covered | (strict ? 2U : 0U);
}
} // namespace

/// True iff this cube is subsumed by `other` with at least one strictly
/// tighter interval on a shared variable (both subsumption bits set).
bool Cube::operator<(const Cube& other) const
{
    const unsigned flags = check_subsumption(*this, other);
    return flags == 3U;
}

/// True iff every variable of `other` is constrained here at least as
/// tightly (subsumption bit 0 set).
bool Cube::operator<=(const Cube& other) const
{
    const unsigned flags = check_subsumption(*this, other);
    return (flags & 1U) != 0U;
}

/// Copies all (variable, interval) entries of `cube`.
Cube::Cube(const Cube& cube)
    : vector(static_cast<const vector&>(cube))
{
}

/// Takes ownership of pre-built (variable, interval) entries.
///
/// The entries must be sorted by variable index with each variable appearing
/// at most once: `find`, `emplace` and the merge-style operators
/// binary-search or walk the list assuming unique keys. This is checked in
/// debug builds; duplicates used to slip past the old is_sorted check.
Cube::Cube(vector<std::pair<size_t, Interval>>&& data)
    : vector(std::move(data))
{
    assert(
        std::adjacent_find(
            begin(),
            end(),
            [](const auto& a, const auto& b) { return !(a.first < b.first); }) ==
        end());
}

/// Steals the entries of `cube`, leaving it in a valid but unspecified state.
Cube::Cube(Cube&& cube)
    : vector(static_cast<vector&&>(cube))
{
}

/// Builds the cube of a single state: variable i is constrained to the
/// interval assigned from vec[i]. Presumably this uses the implicit
/// Interval(Value) conversion, i.e. [vec[i], vec[i]]; verify against the
/// flat_state element type.
Cube::Cube(const flat_state& vec)
    : vector(vec.size())
{
    // Unsigned index matching vec.size(): no signed narrowing of the size.
    for (size_t i = 0; i < vec.size(); ++i) {
        auto& entry = at(i);
        entry.first = i;
        entry.second = vec[i];
    }
}

/// Replaces this cube's entries with a copy of `other`'s.
Cube& Cube::operator=(const Cube& other)
{
    static_cast<vector&>(*this) = other;
    return *this;
}

/// Replaces this cube's entries with `other`'s, moving them.
Cube& Cube::operator=(Cube&& other)
{
    static_cast<vector&>(*this) = std::move(other);
    return *this;
}

/// Constrains `var` to exactly `i`, overwriting any existing constraint.
void Cube::set(size_t var, Interval i)
{
    emplace(var).first->second = std::move(i);
}

/// Removes the constraint at `it`; returns the iterator following it.
Cube::iterator Cube::erase(iterator it)
{
    vector& entries = *this;
    return entries.erase(it);
}

/// Removes the constraints in [first, last); returns the iterator following
/// the removed range.
Cube::iterator Cube::erase(iterator first, iterator last)
{
    vector& entries = *this;
    return entries.erase(first, last);
}

/// Human-readable printer for this cube; `vars` must outlive the result.
Cube::Dump Cube::dump(const VariableSpace& vars) const
{
    Dump printer(this, &vars);
    return printer;
}

/// Non-owning printer binding a cube to the variable space used for names and
/// domains. Both pointers must outlive the printer.
Cube::Dump::Dump(const Cube* source, const VariableSpace* vars)
    : cube(source)
    , variables(vars)
{
}

/// Non-owning JSON printer binding a cube to the model that supplies variable
/// names, domains and value names. Both pointers must outlive the printer.
Cube::JsonDumpModel::JsonDumpModel(const Cube* source, const Model* owner)
    : cube(source)
    , model(owner)
{
}

/// Model-aware JSON printer for this cube; `model` must outlive the result.
Cube::JsonDumpModel Cube::json(const Model& model) const
{
    return JsonDumpModel(this, &model);
}

} // namespace police::ic3

namespace police {

/// Writes an interval as `[lb, ub]`.
std::ostream& operator<<(std::ostream& out, const police::ic3::Interval& i)
{
    out << "[" << i.lb << ", " << i.ub << "]";
    return out;
}

/// Writes a cube using raw variable indices (see Cube::dump).
std::ostream& operator<<(std::ostream& out, const police::ic3::Cube& cube)
{
    cube.dump(out);
    return out;
}

/// Writes a cube as a sequence of JSON constraint objects
/// `{"var": <name>, "op": <op>, "value": <int>}`.
///
/// A point interval becomes one "=" object. Otherwise a lower bound above the
/// domain minimum emits ">=" and an upper bound below the domain maximum
/// emits "<=". A bound that coincides with the opposite domain end is printed
/// as "=". Names are written without JSON escaping.
std::ostream&
operator<<(std::ostream& out, const police::ic3::Cube::JsonDump& cube)
{
    const auto& vars = *cube.variables;
    // Constraint objects go to the outer stream; separators go to the stream
    // handed to the element printer (as before).
    const auto emit = [&out, &vars](size_t var, Value value, std::string_view op) {
        out << "{\"var\": \"" << vars[var].name << "\", \"op\": \""
            << op << "\", \"value\": " << static_cast<int_t>(value) << "}";
    };
    out << generic_print_sequence(
        *cube.cube,
        [&emit, &vars](
            std::ostream& os,
            const std::pair<size_t, police::ic3::Interval>& entry) {
            const police::ic3::Interval& range = entry.second;
            if (range.lb == range.ub) {
                emit(entry.first, range.lb, "=");
                return;
            }
            const auto domain_ub = vars[entry.first].type.get_upper_bound();
            const auto domain_lb = vars[entry.first].type.get_lower_bound();
            const bool bounded_below = range.lb > domain_lb;
            if (bounded_below) {
                emit(entry.first, range.lb, range.lb == domain_ub ? "=" : ">=");
            }
            if (range.ub < domain_ub) {
                if (bounded_below) {
                    os << ", ";
                }
                emit(entry.first, range.ub, range.ub == domain_lb ? "=" : "<=");
            }
        });
    return out;
}

/// Writes a cube as a sequence of JSON constraint objects
/// `{"var": <name>, "op": <op>, "value": <int>, "value_name": <name>}`, using
/// the model for variable names, domains and value names.
///
/// The bound selection is the same as for JsonDump: a point interval becomes
/// "=", otherwise ">=" and/or "<=" are emitted for bounds strictly inside the
/// domain. Strings are written without JSON escaping.
std::ostream&
operator<<(std::ostream& out, const police::ic3::Cube::JsonDumpModel& cube)
{
    const auto& model = *cube.model;
    // Constraint objects go to the outer stream; separators go to the stream
    // handed to the element printer (as before).
    const auto emit = [&out, &model](size_t var, Value value, std::string_view op) {
        out << "{\"var\": \"" << model.variables[var].name
            << "\", \"op\": \"" << op
            << "\", \"value\": " << static_cast<int_t>(value)
            << ", \"value_name\": " << "\""
            << model.get_value_name(var, value) << "\""
            << "}";
    };
    out << generic_print_sequence(
        *cube.cube,
        [&emit, &model](
            std::ostream& os,
            const std::pair<size_t, police::ic3::Interval>& entry) {
            const police::ic3::Interval& range = entry.second;
            if (range.lb == range.ub) {
                emit(entry.first, range.lb, "=");
                return;
            }
            const auto domain_ub =
                model.variables[entry.first].type.get_upper_bound();
            const auto domain_lb =
                model.variables[entry.first].type.get_lower_bound();
            const bool bounded_below = range.lb > domain_lb;
            if (bounded_below) {
                emit(entry.first, range.lb, range.lb == domain_ub ? "=" : ">=");
            }
            if (range.ub < domain_ub) {
                if (bounded_below) {
                    os << ", ";
                }
                emit(entry.first, range.ub, range.ub == domain_lb ? "=" : "<=");
            }
        });
    return out;
}

/// Writes a cube as human-readable constraints such as `x=3`, `y>=1` or
/// `z<=4`. A bound equal to the variable's domain bound is omitted.
std::ostream& operator<<(std::ostream& out, const police::ic3::Cube::Dump& cube)
{
    const auto& vars = *cube.variables;
    out << generic_print_sequence(
        *cube.cube,
        [&vars](
            std::ostream& os,
            const std::pair<size_t, police::ic3::Interval>& entry) {
            const police::ic3::Interval& range = entry.second;
            if (range.lb == range.ub) {
                os << vars[entry.first].name << "=" << range.lb;
                return;
            }
            const bool bounded_below =
                range.lb != vars[entry.first].type.get_lower_bound();
            if (bounded_below) {
                os << vars[entry.first].name << ">=" << range.lb;
            }
            if (range.ub != vars[entry.first].type.get_upper_bound()) {
                if (bounded_below) {
                    os << ", ";
                }
                os << vars[entry.first].name << "<=" << range.ub;
            }
        });
    return out;
}

} // namespace police
