import numpy as np
import pytest

from .sparse_matrix import SparseVector

# Shared, sorted, duplicate-free index set that the tests build their
# SparseVector instances from; several tests also compare results against it
# directly, so its contents must stay fixed.
indices = np.array((0, 1, 3, 5, 9, 10))


class TestSparseVector:
    """Unit tests for ``SparseVector`` built from the module-level ``indices``."""

    @pytest.fixture
    def sparse_vector(self) -> SparseVector:
        """Return a fresh vector over a private copy of ``indices``.

        The module-level array is also used as the expected value in several
        assertions. Handing each test its own copy keeps that array intact even
        if a ``SparseVector`` operation were to modify its storage in place, so
        no test can leak state into another.
        """
        return SparseVector(indices.copy())

    def test_len(self, sparse_vector: SparseVector):
        """``len`` reports the number of stored indices."""
        assert len(sparse_vector) == len(sparse_vector.indices)
        # Also pin against the source data: the check above alone would still
        # pass if construction silently dropped entries.
        assert len(sparse_vector) == len(indices)

    @pytest.mark.parametrize("i", [-2, -1, 0, 1, 2, slice(2, 4), slice(6, None)])
    def test_getitem(self, sparse_vector: SparseVector, i: int | slice):
        """Indexing the vector mirrors indexing its ``indices`` array."""
        assert np.array_equal(sparse_vector.indices[i], sparse_vector[i].indices)

    @pytest.mark.parametrize("shift", [-2, -1, 0, 1, 2])
    def test_lshift(self, sparse_vector: SparseVector, shift: int):
        """``v << k`` subtracts ``k`` from every index."""
        assert np.array_equal((sparse_vector << shift).indices, indices - shift)

    @pytest.mark.parametrize("shift", [-2, -1, 0, 1, 2])
    def test_rshift(self, sparse_vector: SparseVector, shift: int):
        """``v >> k`` adds ``k`` to every index."""
        assert np.array_equal((sparse_vector >> shift).indices, indices + shift)

    def test_and(self, sparse_vector: SparseVector):
        """``&`` keeps only the indices present in both operands."""
        a = sparse_vector
        assert np.array_equal((a & a).indices, a.indices)
        assert len(a & SparseVector()) == 0
        b = SparseVector(np.array([8, 2, 3, 9]))
        assert np.array_equal((a & b).indices, np.array([3, 9]))
        # The augmented form must produce the same result as the binary one.
        a &= b
        assert np.array_equal(a.indices, np.array([3, 9]))

    def test_shift_and(self, sparse_vector: SparseVector):
        """Intersection works on shifted operands."""
        a = sparse_vector
        # A zero shift must behave exactly like the unshifted vector.
        assert np.array_equal(((a << 0) & a).indices, a.indices)
        assert len(((a << 0) & SparseVector())) == 0
        b = SparseVector(np.array([8, 2, 3, 9]))
        assert np.array_equal(((a << 0) & b).indices, np.array([3, 9]))
        # a << 1 is [-1, 0, 2, 4, 8, 9]; its overlap with b is {2, 8, 9}.
        # Result order is not asserted here, hence the sort.
        assert np.array_equal(np.sort(((a << 1) & b).indices), np.array([2, 8, 9]))
        # a << 10 makes every index negative, so nothing overlaps with b.
        assert np.array_equal(((a << 10) & b).indices, np.array([], dtype=np.int64))
        assert np.array_equal(((b << 0) & a).indices, np.array([3, 9]))
        # b << 1 is [7, 1, 2, 8]; only 1 is also in a.
        assert np.array_equal(((b << 1) & a).indices, np.array([1]))

    def test_or(self, sparse_vector: SparseVector):
        """``|`` yields the sorted set of indices present in either operand."""
        a = sparse_vector
        assert np.array_equal((a | a).indices, a.indices)
        assert len(a | SparseVector(np.array([]))) == len(a)
        b = SparseVector(np.array([8, 2, 3, 9]))
        assert np.array_equal((a | b).indices, np.array([0, 1, 2, 3, 5, 8, 9, 10]))

    def test_union(self, sparse_vector: SparseVector):
        """For disjoint operands, ``union`` concatenates the index arrays in order."""
        a = sparse_vector
        assert len(a.union(SparseVector(np.array([])))) == len(a)
        # b shares no index with a, so the union is a plain concatenation.
        b = SparseVector(np.array([8, 2, 4, 15]))
        assert np.array_equal(
            a.union(b).indices, np.concatenate([a.indices, b.indices])
        )

    def test_union_from_iterable(self, sparse_vector: SparseVector):
        """``union_from_iterable`` folds ``union`` over its inputs."""
        a = sparse_vector
        assert len(SparseVector.union_from_iterable([a])) == len(a)
        assert len(
            SparseVector.union_from_iterable([a, SparseVector(np.array([]))])
        ) == len(a)
        b = SparseVector(np.array([8, 2, 4, 15]))
        assert np.array_equal(
            SparseVector.union_from_iterable([a, b]).indices,
            np.concatenate([a.indices, b.indices]),
        )
        # An empty iterable yields the same indices as a default-constructed vector.
        assert np.array_equal(
            SparseVector.union_from_iterable([]).indices, SparseVector().indices
        )
