#pragma once

#include "police/base_types.hpp"
#include "police/storage/lin_vector.hpp"
#include "police/storage/matrix.hpp"
#include "police/storage/matrix_view.hpp"
#include "police/storage/vector.hpp"

#include <algorithm>
#include <cassert>
#include <iostream>

namespace police {

/// One row of a SparseMatrix: a list of (column, value) pairs that callers
/// keep ordered by increasing column index.
template <typename Value>
class SparseRow : public vector<std::pair<size_t, Value>> {
public:
    /// Replaces the row's contents with one (column, value) pair per element
    /// of `values`, in column order. Zero elements are stored explicitly.
    SparseRow& operator=(const LinVector<Value>& values)
    {
        this->clear();
        this->reserve(values.size());
        for (size_t c = 0; c < values.size(); ++c) {
            this->emplace_back(c, values[c]);
        }
        // The original fell off the end here, which is undefined behavior
        // for a function returning a reference.
        return *this;
    }
};

template <typename Value>
class SparseMatrix {
public:
    using value_type = Value;
    using row_type = SparseRow<value_type>;
    using view_type = MatrixView<SparseMatrix<Value>>;

    SparseMatrix() = default;

    SparseMatrix(size_t num_rows, size_t num_cols)
        : rows_(num_rows)
        , cols_(num_cols)
        , entries_(num_rows)
    {
    }

    explicit SparseMatrix(const Matrix<value_type>& matrix)
        : rows_(matrix.num_rows())
        , cols_(matrix.num_cols())
        , entries_(rows_)
    {
        for (size_t r = 0; r < num_rows(); ++r) {
            auto& row = entries_[r];
            for (size_t c = 0; c < num_cols(); ++c) {
                const auto coef = matrix.at(r, c);
                if (coef != 0) {
                    row.emplace_back(c, coef);
                }
            }
        }
    }

    operator view_type() const { return view(0, 0); }

    [[nodiscard]]
    view_type view(
        size_t row_begin,
        size_t col_begin,
        size_t row_end = -1,
        size_t col_end = -1) const
    {
        return view_type(this, row_begin, col_begin, row_end, col_end);
    }

    [[nodiscard]]
    value_type at(size_t row, size_t col) const
    {
        auto pos = std::lower_bound(
            entries_[row].begin(),
            entries_[row].end(),
            col,
            [](const auto& a, const size_t& col) { return a.first < col; });
        if (pos == entries_[row].end() || pos->first > col) {
            return 0;
        }
        return pos->second;
    }

    [[nodiscard]]
    size_t num_cols() const
    {
        return cols_;
    }

    [[nodiscard]]
    size_t num_rows() const
    {
        return rows_;
    }

    void set(size_t row, size_t col, value_type value)
    {
        auto pos = std::lower_bound(
            entries_[row].begin(),
            entries_[row].end(),
            col,
            [](const auto& a, const size_t& col) { return a.first < col; });
        if (value == 0.) {
            if (pos != entries_[row].end()) {
                entries_[row].erase(pos);
            }
        } else {
            if (pos == entries_[row].end() || pos->first > col) {
                entries_[row].insert(
                    pos,
                    std::pair<size_t, value_type>(col, value));
            } else {
                pos->second = value;
            }
        }
    }

    SparseMatrix& operator*=(value_type mult)
    {
        if (mult == 0.) {
            entries_.swap(vector<row_type>(rows_));
            return *this;
        }
        for (size_t r = 0; r < num_rows(); ++r) {
            for (auto& [col, coef] : entries_[r]) {
                coef *= mult;
            }
        }
        return *this;
    }

    SparseMatrix& operator/=(value_type mult)
    {
        return (*this *= (static_cast<value_type>(1) / mult));
    }

    [[nodiscard]]
    SparseMatrix operator*(value_type mult) const
    {
        SparseMatrix aux(*this);
        aux *= mult;
        return aux;
    }

    [[nodiscard]]
    SparseMatrix operator/(value_type mult) const
    {
        return (*this * (static_cast<value_type>(1) / mult));
    }

    template <typename F>
    [[nodiscard]]
    SparseMatrix map(F f) const
    {
        SparseMatrix res(*this);
        for (size_t r = 0; r < num_rows(); ++r) {
            for (auto& [col, coef] : entries_[r]) {
                coef = f(coef);
            }
        }
        return *this;
    }

    [[nodiscard]]
    LinVector<value_type> operator*(const LinVector<value_type>& vec) const
    {
        assert(vec.size() == num_cols());
        LinVector<value_type> result(num_rows(), 0);
        for (size_t r = 0; r < num_rows(); ++r) {
            value_type val = 0;
            for (const auto& [col, coef] : entries_[r]) {
                val += coef * vec[col];
            }
            result[r] = val;
        }
        return result;
    }

    [[nodiscard]]
    SparseMatrix operator*(const SparseMatrix& other) const
    {
        assert(num_cols() == other.num_rows());
        const auto transposed = other.transpose();
        SparseMatrix result(num_rows(), other.num_cols());
        for (size_t r = 0; r < rows_; ++r) {
            const row_type& row = entries_[r];
            if (row.empty()) {
                continue;
            }
            for (size_t c = 0; c < other.cols_; ++c) {
                const row_type& col = transposed.entries_[c];
                bool zero = true;
                value_type val = 0;
                for (size_t i = 0, j = 0; i < row.size() && j < col.size();) {
                    const auto x = row[i].first;
                    if (x == col[j].first) {
                        zero = false;
                        val += row[i].second * col[j].second;
                    }
                    i += x <= col[j].first;
                    j += x >= col[j].first;
                }
                if (!zero) {
                    result.entries_[r].emplace_back(c, val);
                }
            }
        }
        return result;
    }

    SparseMatrix<value_type>& operator+=(const SparseMatrix<value_type>& mat)
    {
        assert(mat.num_rows() == num_rows());
        assert(mat.num_cols() == num_cols());
        for (size_t r = 0; r < num_rows(); ++r) {
            auto& row = this->row(r);
            const auto& other = mat.row(r);
            row_type new_row;
            new_row.reserve(row.size());
            size_t i = 0, j = 0;
            for (; i < row.size() && j < other.size();) {
                const auto col = row[i].first;
                if (col == other[j].first) {
                    const auto val = row[i].second + other[j].second;
                    if (val != 0) {
                        new_row.emplace_back(col, val);
                    }
                    ++i;
                    ++j;
                } else if (col < other[j].first) {
                    new_row.push_back(row[i]);
                    ++i;
                } else {
                    new_row.push_back(other[j]);
                    ++j;
                }
            }
            new_row.insert(new_row.end(), row.begin() + i, row.end());
            new_row.insert(new_row.end(), other.begin() + j, other.end());
            row.swap(new_row);
        }
        return *this;
    }

    [[nodiscard]]
    SparseMatrix<value_type>
    operator+(const SparseMatrix<value_type>& other) const
    {
        auto result(*this);
        result += other;
        return result;
    }

    SparseMatrix<value_type>& inplace_hadamard(const Matrix<value_type>& matrix)
    {
        assert(rows_ == matrix.num_rows() && cols_ == matrix.num_cols());
        for (size_t r = 0; r < entries_.size(); ++r) {
            auto& row = entries_[r];
            size_t i = 0, j = 0;
            for (; i < row.size(); ++i) {
                const auto col = row[i].first;
                const auto coef = row[i].second * matrix.at(r, col);
                if (coef != 0) {
                    row[j] = {col, coef};
                    ++j;
                }
            }
            row.erase(row.begin() + j, row.end());
        }
        return *this;
    }

    SparseMatrix<value_type>&
    inplace_hadamard_division(const SparseMatrix<value_type>& matrix)
    {
        assert(rows_ == matrix.rows_ && cols_ == matrix.cols_);
        for (size_t r = 0; r < entries_.size(); ++r) {
            auto& row = entries_[r];
            const auto& other_row = matrix.entries_[r];
            size_t i = 0, j = 0, k = 0;
            for (; i < row.size() && j < other_row.size();) {
                const auto x = row[i].first;
                if (x == other_row[j].first) {
                    row[k] = {x, row[i].second / other_row[j].second};
                    ++k;
                }
                i += x <= other_row[j].first;
                j += x >= other_row[j].first;
            }
            row.erase(row.begin() + k, row.end());
        }
        return *this;
    }

    SparseMatrix<value_type>&
    inplace_hadamard(const SparseMatrix<value_type>& matrix)
    {
        assert(rows_ == matrix.rows_ && cols_ == matrix.cols_);
        for (size_t r = 0; r < entries_.size(); ++r) {
            auto& row = entries_[r];
            const auto& other_row = matrix.entries_[r];
            size_t i = 0, j = 0, k = 0;
            for (; i < row.size() && j < other_row.size();) {
                const auto x = row[i].first;
                if (x == other_row[j].first) {
                    row[k] = {x, row[i].second * other_row[j].second};
                    ++k;
                }
                i += x <= other_row[j].first;
                j += x >= other_row[j].first;
            }
            row.erase(row.begin() + k, row.end());
        }
        return *this;
    }

    [[nodiscard]]
    LinVector<value_type>
    diagonal_product(const SparseMatrix<value_type>& mat) const
    {
        assert(num_rows() == mat.num_cols());
        assert(num_cols() == mat.num_rows());
        LinVector<value_type> result(num_rows());
        const auto trans = mat.transpose();
        for (size_t r = 0; r < num_rows(); ++r) {
            const auto& row = this->row(r);
            const auto& col = trans.row(r);
            value_type val = 0;
            for (size_t i = 0, j = 0; i < row.size() && j < col.size();) {
                const auto idx = row[i].first;
                if (idx == col[j].first) {
                    val += row[i].second * col[j].second;
                }
                i += idx <= col[j].first;
                j += idx >= col[j].first;
            }
            result[r] = val;
        }
        return result;
    }

    void add_row_vector(const LinVector<value_type>& row_vec)
    {
        assert(row_vec.size() == num_cols());
        row_type new_row;
        row_type aux;
        for (size_t r = 0; r < num_rows(); ++r) {
            auto& row = entries_[r];
            row_type* dest[] = {&new_row, &aux};
            size_t j = 0;
            for (size_t i = 0; i < row.size(); ++j) {
                const value_type val = row[i].first == j
                                           ? (row[i].second + row_vec[j])
                                           : row_vec[j];
                dest[val == 0.]->emplace_back(j, val);
                i += row[i].first == j;
            }
            for (; j < row_vec.size(); ++j) {
                const value_type val = row_vec[j];
                dest[val == 0.]->emplace_back(j, val);
            }
            row.swap(new_row);
            new_row.clear();
            aux.clear();
        }
    }

    void add_column_vector(const LinVector<value_type>& vec)
    {
        assert(vec.size() == num_rows());
        row_type new_row;
        row_type aux;
        for (size_t r = 0; r < num_rows(); ++r) {
            const auto val = vec[r];
            if (val == 0.) {
                continue;
            }
            auto& row = entries_[r];
            row_type* dest[] = {&new_row, &aux};
            size_t col = 0;
            for (size_t i = 0; i < row.size(); ++col) {
                const auto new_val =
                    val + (row[i].first == col ? row[i].second : 0.);
                dest[new_val == 0.]->emplace_back(col, new_val);
                i += row[i].first == col;
            }
            for (; col < num_cols(); ++col) {
                new_row.emplace_back(col, val);
            }
            row.swap(new_row);
            new_row.clear();
            aux.clear();
        }
    }

    [[nodiscard]]
    SparseMatrix transpose() const
    {
        SparseMatrix result(cols_, rows_);
        for (size_t r = 0; r < rows_; ++r) {
            const auto& row = entries_[r];
            for (const auto& [col, val] : row) {
                assert(col < cols_);
                result.entries_[col].emplace_back(r, val);
            }
        }
        return result;
    }

    [[nodiscard]]
    const vector<row_type>& entries() const
    {
        return entries_;
    }

    [[nodiscard]]
    vector<row_type>& entries()
    {
        return entries_;
    }

    [[nodiscard]]
    const row_type& row(size_t r) const
    {
        assert(r < num_rows());
        return entries()[r];
    }

    [[nodiscard]]
    row_type& row(size_t r)
    {
        assert(r < num_rows());
        return entries()[r];
    }

    void swap(SparseMatrix<value_type>& other)
    {
        std::swap(rows_, other.rows_);
        std::swap(cols_, other.cols_);
        entries_.swap(other.entries_);
    }

    void clear()
    {
        rows_ = 0;
        cols_ = 0;
        entries_.clear();
    }

    void resize(size_t rows, size_t cols)
    {
        entries_.resize(rows);
        if (cols < cols_) {
            for (int i = rows - 1; i >= 0; --i) {
                auto& row = entries_[i];
                for (int j = row.size() - 1; j >= 0; --j) {
                    if (row[j].first < cols) {
                        break;
                    }
                    row.pop_back();
                }
            }
        }
        rows_ = rows;
        cols_ = cols;
    }

    [[nodiscard]]
    Matrix<value_type> unpack() const
    {
        Matrix<value_type> res(rows_, cols_);
        for (int r = rows_ - 1; r >= 0; --r) {
            const auto& row = entries_[r];
            for (const auto& [c, w] : row) {
                res.at(r, c) = w;
            }
        }
        return res;
    }

    void fill_row(size_t row, value_type val)
    {
        entries_[row].resize(cols_, 0);
        for (int c = cols_ - 1; c >= 0; --c) {
            entries_[row][c].first = c;
            entries_[row][c].second = val;
        }
    }

    [[nodiscard]]
    static SparseMatrix diagonal(size_t n)
    {
        SparseMatrix res(n, n);
        for (int o = n - 1; o >= 0; --o) {
            res.set(o, o, 1);
        }
        return res;
    }

private:
    size_t rows_ = 0;
    size_t cols_ = 0;
    vector<row_type> entries_;
};

/// Sparse-times-dense product `a * b`, returned as a sparse matrix.
///
/// Empty rows of `a` produce empty result rows. Result rows are filled in
/// column order, and sums equal to zero are not stored.
template <typename value_type>
[[nodiscard]]
SparseMatrix<value_type>
operator*(const SparseMatrix<value_type>& a, const Matrix<value_type>& b)
{
    assert(a.num_cols() == b.num_rows());
    const size_t out_cols = b.num_cols();
    SparseMatrix<value_type> product(a.num_rows(), out_cols);
    for (size_t i = 0; i < a.num_rows(); ++i) {
        const auto& lhs_row = a.row(i);
        if (lhs_row.empty()) {
            continue;
        }
        auto& out_row = product.row(i);
        for (size_t j = 0; j < out_cols; ++j) {
            value_type sum = 0;
            for (const auto& entry : lhs_row) {
                sum += entry.second * b.at(entry.first, j);
            }
            if (!(sum == 0.)) {
                out_row.emplace_back(j, sum);
            }
        }
    }
    return product;
}

/// Scalar-times-matrix product `mult * b`. It copies `b` and scales every
/// stored coefficient by `mult`, the same as `b * mult`.
template <typename value_type>
[[nodiscard]]
SparseMatrix<value_type>
operator*(value_type mult, const SparseMatrix<value_type>& b)
{
    SparseMatrix<value_type> scaled(b);
    scaled *= mult;
    return scaled;
}

/// Dense-times-sparse product `a * b`, returned as a sparse matrix.
///
/// `b` is transposed once so that each of its columns can be walked as a
/// sparse row. Sums equal to zero are not stored.
template <typename value_type>
[[nodiscard]]
SparseMatrix<value_type>
operator*(const Matrix<value_type>& a, const SparseMatrix<value_type>& b)
{
    assert(a.num_cols() == b.num_rows());
    const auto bt = b.transpose();
    SparseMatrix<value_type> product(a.num_rows(), b.num_cols());
    for (size_t i = 0; i < a.num_rows(); ++i) {
        auto& out_row = product.row(i);
        for (size_t j = 0; j < bt.num_rows(); ++j) {
            value_type sum = 0;
            for (const auto& entry : bt.row(j)) {
                sum += entry.second * a.at(i, entry.first);
            }
            if (!(sum == 0.)) {
                out_row.emplace_back(j, sum);
            }
        }
    }
    return product;
}

/// Returns the element-wise product of two equally sized sparse matrices.
/// Only columns stored in both operands survive.
template <typename value_type>
[[nodiscard]]
SparseMatrix<value_type> hadamard_product(
    const SparseMatrix<value_type>& a,
    const SparseMatrix<value_type>& b)
{
    SparseMatrix<value_type> product = a;
    product.inplace_hadamard(b);
    return product;
}

/// Returns the element-wise quotient `a / b` of two equally sized sparse
/// matrices, computed by SparseMatrix::inplace_hadamard_division on a copy
/// of `a`.
template <typename value_type>
[[nodiscard]]
SparseMatrix<value_type> hadamard_division(
    const SparseMatrix<value_type>& a,
    const SparseMatrix<value_type>& b)
{
    SparseMatrix<value_type> quotient = a;
    quotient.inplace_hadamard_division(b);
    return quotient;
}

/// Adds a single-column sparse matrix to `value` element-wise and returns
/// `value`.
///
/// Not marked [[nodiscard]]: compound assignment is normally used as a
/// statement (`v += m;`), and the attribute made every such use warn.
template <typename value_type>
LinVector<value_type>&
operator+=(LinVector<value_type>& value, const SparseMatrix<value_type>& matrix)
{
    assert(matrix.num_cols() == 1u);
    assert(matrix.num_rows() == value.size());
    for (size_t n = 0; n < matrix.num_rows(); ++n) {
        // With one column, each row holds at most one stored entry.
        for (const auto& entry : matrix.row(n)) {
            value[n] += entry.second;
        }
    }
    return value;
}

/// Returns `value` plus a single-column sparse matrix, leaving both operands
/// unchanged.
template <typename value_type>
[[nodiscard]]
LinVector<value_type> operator+(
    const LinVector<value_type>& value,
    const SparseMatrix<value_type>& matrix)
{
    auto sum = value;
    sum += matrix;
    return sum;
}

/// Returns the diagonal of the product `a * b` without forming it:
/// diag[i] = sum_k a(i, k) * b(k, i).
///
/// Only the stored entries of `b` are visited; implicit zeros contribute
/// nothing.
template <typename value_type>
[[nodiscard]]
LinVector<value_type>
diagonal_product(const Matrix<value_type>& a, const SparseMatrix<value_type>& b)
{
    assert(a.num_rows() == b.num_cols());
    assert(a.num_cols() == b.num_rows());
    LinVector<value_type> diag(a.num_rows(), 0);
    for (size_t k = 0; k < b.num_rows(); ++k) {
        for (const auto& entry : b.row(k)) {
            const size_t i = entry.first;
            diag[i] += a.at(i, k) * entry.second;
        }
    }
    return diag;
}

/// Writes the matrix in dense form: one line per row, values separated by a
/// single space, implicit zeros printed as `0`, and no trailing newline.
template <typename Value>
std::ostream& operator<<(std::ostream& out, const SparseMatrix<Value>& matrix)
{
    const size_t rows = matrix.num_rows();
    const size_t cols = matrix.num_cols();
    for (size_t r = 0; r < rows; ++r) {
        const auto& entries = matrix.row(r);
        size_t next = 0;
        for (size_t c = 0; c < cols; ++c) {
            const char* separator = c == 0 ? "" : " ";
            out << separator;
            if (next < entries.size() && entries[next].first == c) {
                out << entries[next].second;
                ++next;
            } else {
                out << 0;
            }
        }
        const char* terminator = r + 1 == rows ? "" : "\n";
        out << terminator;
    }
    return out;
}

} // namespace police
