#pragma once

#include "police/base_types.hpp"
#include "police/expressions/constants.hpp"
#include "police/expressions/expression.hpp"
#include "police/expressions/variable.hpp"
#include "police/storage/vector.hpp"
#include "police/utils/hash.hpp"
#include "police/utils/type_traits.hpp"

#include <algorithm>
#include <cassert>
#include <cmath>
#include <functional>
#include <iterator>
#include <limits>
#include <sstream>
#include <string>
#include <type_traits>
#include <utility>

namespace police {

/// Sparse linear combination `sum_i coef_i * ref_i`, stored as two parallel
/// vectors (`refs_`, `coefs_`) that always have the same length.
///
/// `insert` and `merge` keep `refs_` sorted in ascending `operator<` order
/// without duplicates (the assertion in `insert_at` checks this), which is
/// what the binary search in `find` / `insert` relies on.
///
/// Iterators dereference to a proxy whose `first` / `second` members are
/// references to the reference and coefficient at the iterator's position.
template <typename RefType, typename CoefType>
class LinearCombination {
private:
    /// Proxy "pair" of references into `refs_` / `coefs_`. `Const` selects
    /// whether the referenced elements are read-only.
    template <bool Const>
    struct __reference_type {
        __reference_type(
            make_const_t<Const, RefType>& ref,
            make_const_t<Const, CoefType>& coef)
            : first(ref)
            , second(coef)
        {
        }

        // Proxies are only ever created from element references; copying one
        // is disallowed.
        __reference_type(const __reference_type&) = delete;

        /// Writes through: assigns the values referenced by `other` to the
        /// elements referenced by this proxy.
        __reference_type& operator=(const __reference_type& other)
        {
            first = other.first;
            second = other.second;
            return *this;
        }

        /// Converts a mutable proxy into a read-only one.
        ///
        /// `C` must be deducible from the conversion's target type, so it
        /// appears directly in `__reference_type<C>`; placing it inside
        /// `enable_if_t` instead would be a non-deduced context and the
        /// operator could never be selected.
        template <bool C, typename = std::enable_if_t<C && !Const>>
        operator __reference_type<C>() const
        {
            return {first, second};
        }

        make_const_t<Const, RefType>& first;
        make_const_t<Const, CoefType>& second;
    };

    /// Result of `iterator::operator->`: a proxy that also acts as its own
    /// pointer so that `it->first` / `it->second` work.
    template <bool Const>
    struct __pointer_type : public __reference_type<Const> {
        using __reference_type<Const>::__reference_type;

        __reference_type<Const>* operator->() { return this; }

        const __reference_type<Const>* operator->() const { return this; }
    };

    /// Random-access iterator over the parallel vectors, represented as a
    /// pair of vector pointers plus a position index. Element access goes
    /// through the vectors' `at()`.
    template <bool Const>
    class __iterator {
        template <typename RefType_, typename CoefType_>
        friend class LinearCombination;

        // The const iterator's converting constructor reads the private
        // state of the mutable iterator.
        template <bool>
        friend class __iterator;

    public:
        using reference_type = __reference_type<Const>;
        using pointer_type = __pointer_type<Const>;
        using value_type = reference_type;
        using difference_type = int_t;

        __iterator() = default;

        __iterator(
            make_const_t<Const, vector<RefType>>* refs,
            make_const_t<Const, vector<CoefType>>* coefs,
            size_t pos = 0)
            : refs_(refs)
            , coefs_(coefs)
            , pos_(pos)
        {
        }

        /// Implicit conversion `iterator` -> `const_iterator` (only enabled
        /// on the const specialization).
        template <
            typename T,
            typename =
                std::enable_if_t<std::is_same_v<T, __iterator<false>> && Const>>
        __iterator(const T& other)
            : refs_(other.refs_)
            , coefs_(other.coefs_)
            , pos_(other.pos_)
        {
        }

        /// Iterators compare by position only.
        [[nodiscard]]
        bool operator==(const __iterator& other) const
        {
            return pos_ == other.pos_;
        }

        [[nodiscard]]
        auto operator<=>(const __iterator& other) const
        {
            return pos_ <=> other.pos_;
        }

        [[nodiscard]]
        difference_type operator-(const __iterator& other) const
        {
            return static_cast<difference_type>(pos_) -
                   static_cast<difference_type>(other.pos_);
        }

        __iterator& operator++()
        {
            ++pos_;
            return *this;
        }

        __iterator operator++(int)
        {
            __iterator temp(*this);
            ++*this;
            return temp;
        }

        __iterator& operator+=(difference_type n)
        {
            pos_ += n;
            return *this;
        }

        __iterator& operator--()
        {
            --pos_;
            return *this;
        }

        __iterator operator--(int)
        {
            __iterator temp(*this);
            --*this;
            return temp;
        }

        __iterator& operator-=(difference_type n) { return *this += -n; }

        [[nodiscard]]
        __iterator operator+(difference_type n) const
        {
            __iterator temp(*this);
            temp += n;
            return temp;
        }

        [[nodiscard]]
        __iterator operator-(difference_type n) const
        {
            __iterator temp(*this);
            temp -= n;
            return temp;
        }

        [[nodiscard]]
        friend __iterator operator+(difference_type n, const __iterator& it)
        {
            return it + n;
        }

        // NOTE(review): `n - it` has no iterator meaning; this returns
        // `it - n` and is kept only for source compatibility.
        [[nodiscard]]
        friend __iterator operator-(difference_type n, const __iterator& it)
        {
            return it - n;
        }

        [[nodiscard]]
        reference_type operator*() const
        {
            return reference_type(refs_->at(pos_), coefs_->at(pos_));
        }

        [[nodiscard]]
        pointer_type operator->() const
        {
            return {refs_->at(pos_), coefs_->at(pos_)};
        }

        [[nodiscard]]
        reference_type operator[](difference_type n) const
        {
            return {refs_->at(pos_ + n), coefs_->at(pos_ + n)};
        }

    private:
        make_const_t<Const, vector<RefType>>* refs_ = nullptr;
        make_const_t<Const, vector<CoefType>>* coefs_ = nullptr;
        size_t pos_ = 0;
    };

    static_assert(std::random_access_iterator<__iterator<true>>);

public:
    using reference_type = __reference_type<false>;
    using const_reference_type = __reference_type<true>;
    using iterator = __iterator<false>;
    using const_iterator = __iterator<true>;

    LinearCombination() = default;

    /// Number of stored (reference, coefficient) entries.
    [[nodiscard]]
    size_t size() const
    {
        return refs_.size();
    }

    [[nodiscard]]
    bool empty() const
    {
        return refs_.empty();
    }

    /// Removes all entries.
    void clear()
    {
        refs_.clear();
        coefs_.clear();
    }

    /// Reserves capacity for `size` entries in both vectors.
    void reserve(size_t size)
    {
        refs_.reserve(size);
        coefs_.reserve(size);
    }

    [[nodiscard]]
    iterator begin()
    {
        return {&refs_, &coefs_, 0};
    }

    [[nodiscard]]
    iterator end()
    {
        return {&refs_, &coefs_, size()};
    }

    [[nodiscard]]
    const_iterator begin() const
    {
        return {&refs_, &coefs_, 0};
    }

    [[nodiscard]]
    const_iterator end() const
    {
        return {&refs_, &coefs_, size()};
    }

    [[nodiscard]]
    const_iterator cbegin() const
    {
        return {&refs_, &coefs_, 0};
    }

    [[nodiscard]]
    const_iterator cend() const
    {
        return {&refs_, &coefs_, size()};
    }

    /// Returns an iterator to the entry for `ref`, or `end()` if absent.
    [[nodiscard]]
    iterator find(RefType ref)
    {
        return {
            &refs_,
            &coefs_,
            static_cast<const LinearCombination*>(this)->find(ref).pos_};
    }

    /// Returns an iterator to the entry for `ref`, or `end()` if absent.
    /// Binary search; relies on `refs_` being sorted.
    [[nodiscard]]
    const_iterator find(RefType ref) const
    {
        const auto pos = lower_bound(ref);
        if (pos == end() || ref < refs_[pos.pos_]) {
            return end();
        }
        return pos;
    }

    /// Inserts `(ref, coef)` at its sorted position if `ref` is not present.
    ///
    /// @return iterator to the entry for `ref` and `true` if it was inserted;
    ///         if `ref` already existed, its coefficient is left unchanged
    ///         and `false` is returned.
    std::pair<iterator, bool> insert(const RefType& ref, const CoefType& coef)
    {
        const auto pos = lower_bound(ref);
        bool inserted = false;
        if (pos == end() || ref < pos->first) {
            insert_at(pos.pos_, ref, coef);
            inserted = true;
        }
        return {pos, inserted};
    }

    /// Erases the entry at `pos`; returns an iterator to the same position,
    /// which now holds the following entry (or is `end()`).
    iterator erase(iterator pos)
    {
        refs_.erase(refs_.begin() + pos.pos_);
        coefs_.erase(coefs_.begin() + pos.pos_);
        return pos;
    }

    /// Erases the entries in `[first, last)`; returns `first`.
    iterator erase(iterator first, iterator last)
    {
        refs_.erase(refs_.begin() + first.pos_, refs_.begin() + last.pos_);
        coefs_.erase(coefs_.begin() + first.pos_, coefs_.begin() + last.pos_);
        return first;
    }

    /// Removes the entry for `ref` if present; no-op otherwise.
    void remove(RefType ref)
    {
        const auto pos = find(ref);
        if (pos != end()) {
            erase(pos);
        }
    }

    /// Pointer to the contiguous, sorted reference storage.
    [[nodiscard]]
    const RefType* refs() const
    {
        return refs_.data();
    }

    /// Pointer to the coefficient storage (parallel to `refs()`).
    [[nodiscard]]
    const CoefType* coefs() const
    {
        return coefs_.data();
    }

    /// Read-only proxy to the `idx`-th entry (asserted in range).
    [[nodiscard]]
    const_reference_type at(size_t idx) const
    {
        assert(idx < size());
        return {refs_[idx], coefs_[idx]};
    }

    /// Exact element-wise equality of references and coefficients.
    [[nodiscard]]
    bool operator==(const LinearCombination& other) const
    {
        return refs_ == other.refs_ && coefs_ == other.coefs_;
    }

    /// Computes `sum_i get_value(ref_i) * coef_i`.
    ///
    /// Entries are accumulated from the last to the first; that order is
    /// kept deliberately because it affects floating-point rounding.
    template <typename ValueGetter>
    [[nodiscard]]
    real_t evaluate(ValueGetter get_value) const
    {
        real_t result = 0.;
        for (size_t i = size(); i-- > 0;) {
            result += get_value(refs_[i]) * coefs_[i];
        }
        return result;
    }

    /// Drops every entry whose coefficient `number_utils::is_zero` reports
    /// as zero, compacting both vectors in place and preserving order.
    void remove_zero_coefficients()
    {
        size_t j = 0;
        for (size_t i = 0; i < size(); ++i) {
            if (!number_utils::is_zero(coefs_[i])) {
                if (i != j) {
                    refs_[j] = refs_[i];
                    coefs_[j] = coefs_[i];
                }
                ++j;
            }
        }
        refs_.erase(refs_.begin() + j, refs_.end());
        coefs_.erase(coefs_.begin() + j, coefs_.end());
    }

protected:
    /// Combined hash of the reference and coefficient vectors.
    [[nodiscard]]
    std::size_t elements_hash() const
    {
        return police::hash_combine(
            police::get_hash(refs_),
            police::get_hash(coefs_));
    }

    /// Merges the sorted range `[first, last)` into `*this`.
    ///
    /// Entries with a reference present on both sides become
    /// `op(own_coef, other_coef)`; references only in the range are
    /// inserted with `op(0, other_coef)`. Zero results are kept.
    template <typename BinOp = std::plus<>>
    void merge(const_iterator first, const_iterator last, BinOp op = BinOp())
    {
        iterator i = begin();
        while (i != end() && first != last) {
            if (i->first < first->first) {
                ++i;
            } else if (first->first < i->first) {
                insert_at(i.pos_, first->first, op(0., first->second));
                ++first;
            } else {
                i->second = op(i->second, first->second);
                ++i;
                ++first;
            }
        }
        // Everything left in the range is larger than all own references.
        for (; first != last; ++first) {
            refs_.push_back(first->first);
            coefs_.push_back(op(0., first->second));
        }
    }

    /// Multiplies every coefficient by `multiplier`.
    void scale_coefs(CoefType multiplier)
    {
        std::for_each(begin(), end(), [multiplier](auto&& elem) {
            elem.second *= multiplier;
        });
    }

    /// Builds `c0*x0 (+|-) c1*x1 ...` as an expression tree; an empty
    /// combination becomes the constant 0.
    [[nodiscard]]
    expressions::Expression as_expression() const
    {
        if (empty()) {
            return expressions::Constant(Value(static_cast<CoefType>(0)));
        } else {
            expressions::Expression result = get_expr(at(0));
            for (auto it = begin() + 1; it != end(); ++it) {
                add_to_expr(result, *it);
            }
            return result;
        }
    }

    /// Expression for a single leading term: `x` for a coefficient of
    /// (approximately) +1, `x * -1` for -1, otherwise `x * coef`.
    static expressions::Expression get_expr(const_reference_type ref)
    {
        if (has_unit_magnitude(ref.second)) {
            if (ref.second > 0.) {
                return expressions::Variable(ref.first);
            }
            return expressions::Variable(ref.first) *
                   expressions::MakeConstant()(static_cast<real_t>(-1.));
        }
        return expressions::Variable(ref.first) *
               expressions::MakeConstant()(ref.second);
    }

    /// Appends a term to `expr`: `+ x` / `- x` for coefficients of
    /// (approximately) +1 / -1, `+ x*c` / `- x*|c|` otherwise. A coefficient
    /// that is exactly zero adds nothing.
    static void
    add_to_expr(expressions::Expression& expr, const_reference_type ref)
    {
        if (has_unit_magnitude(ref.second)) {
            if (ref.second > 0.) {
                expr = expr + expressions::Variable(ref.first);
            } else {
                expr = expr - expressions::Variable(ref.first);
            }
        } else if (ref.second > 0.) {
            expr = expr + (expressions::Variable(ref.first) *
                           expressions::MakeConstant()(ref.second));
        } else if (ref.second < 0.) {
            expr = expr - (expressions::Variable(ref.first) *
                           expressions::MakeConstant()(-ref.second));
        }
    }

    /// Human-readable form `c0xr0 + c1xr1 - c2xr2 ...`; empty string when
    /// there are no entries.
    [[nodiscard]]
    std::string to_string() const
    {
        if (empty()) return "";
        std::ostringstream oss;
        oss << coefs_[0] << "x" << refs_[0];
        for (size_t i = 1; i < refs_.size(); ++i) {
            if (coefs_[i] < 0.) {
                oss << " - " << (-coefs_[i]);
            } else {
                oss << " + " << coefs_[i];
            }
            oss << "x" << refs_[i];
        }
        return oss.str();
    }

private:
    /// True iff `|coef|` is within `real_t` machine epsilon of 1.
    [[nodiscard]]
    static bool has_unit_magnitude(const CoefType& coef)
    {
        return std::abs(std::abs(coef) - 1.) <
               std::numeric_limits<real_t>::epsilon();
    }

    /// First position whose reference is not less than `ref`.
    [[nodiscard]]
    iterator lower_bound(const RefType& ref)
    {
        return {
            &refs_,
            &coefs_,
            static_cast<const LinearCombination*>(this)->lower_bound(ref).pos_};
    }

    /// First position whose reference is not less than `ref` (binary search
    /// over the sorted `refs_`).
    [[nodiscard]]
    const_iterator lower_bound(const RefType& ref) const
    {
        auto pos = std::lower_bound(refs_.begin(), refs_.end(), ref);
        return const_iterator(
            &refs_,
            &coefs_,
            static_cast<size_t>(std::distance(refs_.begin(), pos)));
    }

    /// Inserts `(ref, coef)` at index `pos`; the caller guarantees that this
    /// keeps `refs_` sorted (checked by the assertion).
    void insert_at(size_t pos, RefType ref, CoefType coef)
    {
        assert(pos == refs_.size() || ref < refs_[pos]);
        refs_.insert(refs_.begin() + pos, std::move(ref));
        coefs_.insert(coefs_.begin() + pos, std::move(coef));
    }

    vector<RefType> refs_;
    vector<CoefType> coefs_;
};

/// Affine expression over variable ids: a sparse `LinearCombination` of
/// `size_t` ids with `real_t` coefficients plus a constant term `bias`.
struct LinearExpression : public LinearCombination<size_t, real_t> {
    // `hash<LinearExpression>` needs access to the protected
    // `elements_hash()` of the base.
    template <typename>
    friend struct police::hash;

    LinearExpression() = default;

    /// Takes over the terms of `elems` and sets the constant term to `bias`.
    LinearExpression(LinearCombination<size_t, real_t> elems, real_t bias)
        : LinearCombination<size_t, real_t>(std::move(elems))
        , bias(bias)
    {
    }

    // Factories below are defined out of line (not in this header).

    [[nodiscard]]
    static LinearExpression
    from_expression(const expressions::Expression& expr);

    [[nodiscard]]
    static LinearExpression constant(real_t constant);

    [[nodiscard]]
    static LinearExpression
    unit(size_t var_id, real_t coef = 1., real_t constant = 0.);

    /// Drops every term and resets the constant to zero.
    void clear()
    {
        LinearCombination::clear();
        bias = 0.;
    }

    /// Term-wise subtraction: coefficients of shared ids are subtracted,
    /// ids only in `other` enter negated; the biases are subtracted too.
    LinearExpression& operator-=(const LinearExpression& other)
    {
        merge(other.cbegin(), other.cend(), std::minus<>());
        bias -= other.bias;
        return *this;
    }

    /// Term-wise addition, including the biases.
    LinearExpression& operator+=(const LinearExpression& other)
    {
        merge(other.cbegin(), other.cend(), std::plus<>());
        bias += other.bias;
        return *this;
    }

    /// Scales every coefficient and the bias by `multiplier`.
    LinearExpression& operator*=(real_t multiplier)
    {
        scale_coefs(multiplier);
        bias *= multiplier;
        return *this;
    }

    /// Scales by the reciprocal of `divisor`.
    LinearExpression& operator/=(real_t divisor)
    {
        const real_t reciprocal = 1. / divisor;
        return *this *= reciprocal;
    }

    [[nodiscard]]
    LinearExpression operator+(const LinearExpression& other) const
    {
        LinearExpression sum = *this;
        sum += other;
        return sum;
    }

    [[nodiscard]]
    LinearExpression operator-(const LinearExpression& other) const
    {
        LinearExpression difference = *this;
        difference -= other;
        return difference;
    }

    [[nodiscard]]
    LinearExpression operator*(real_t multiplier) const
    {
        LinearExpression scaled = *this;
        scaled *= multiplier;
        return scaled;
    }

    [[nodiscard]]
    LinearExpression operator/(real_t divisor) const
    {
        LinearExpression scaled = *this;
        scaled /= divisor;
        return scaled;
    }

    /// Value of the terms under `get_value`, plus the bias.
    template <typename ValueGetter>
    [[nodiscard]]
    real_t evaluate(ValueGetter get_value) const
    {
        const real_t terms = LinearCombination::evaluate(std::move(get_value));
        return terms + bias;
    }

    /// Expression tree of the terms followed by `+ bias` (always appended).
    [[nodiscard]]
    expressions::Expression as_expression() const
    {
        return LinearCombination::as_expression() +
               expressions::Constant(Value(bias));
    }

    /// Human-readable form: just the bias when there are no terms,
    /// otherwise the terms followed by ` + b` / ` - |b|` for a nonzero bias.
    [[nodiscard]]
    std::string to_string() const
    {
        std::ostringstream text;
        if (empty()) {
            text << bias;
            return text.str();
        }
        text << LinearCombination::to_string();
        if (bias > 0.) {
            text << " + " << bias;
        } else if (bias < 0.) {
            text << " - " << (-bias);
        }
        return text.str();
    }

    /// Exact equality of bias, references and coefficients.
    [[nodiscard]]
    bool operator==(const LinearExpression& e) const
    {
        if (bias != e.bias) {
            return false;
        }
        return LinearCombination::operator==(e);
    }

    real_t bias{};
};

/// Hash of a LinearExpression: combines the hash of its (reference,
/// coefficient) storage with the hash of its bias.
template <>
struct hash<LinearExpression> {
    [[nodiscard]]
    std::size_t operator()(const LinearExpression& e) const
    {
        const auto terms_hash = e.elements_hash();
        const auto bias_hash = police::get_hash(e.bias);
        return police::hash_combine(terms_hash, bias_hash);
    }
};

/// Stream insertion for LinearExpression. Only declared here; the definition
/// lives out of line (presumably formats like `to_string()` — TODO confirm).
std::ostream&
operator<<(std::ostream& out, const police::LinearExpression& expr);

} // namespace police
