from __future__ import annotations

from dataclasses import dataclass, field
from typing import Iterable

import h5py
import numba as nb
import numpy as np
import numpy.typing as npt


@nb.njit(nb.boolean[:](nb.int64[:], nb.int64[:]), parallel=True, cache=True)
def _isin_impl(
    a: npt.NDArray[np.int64], b: npt.NDArray[np.int64]
) -> npt.NDArray[np.bool_]:
    """Return a boolean mask telling which entries of ``a`` occur in ``b``.

    Compiled by numba with an explicit signature, so both arguments must be
    1-D int64 arrays.

    Args:
        a: Values to test; the result has one flag per entry of ``a``.
        b: Values to test against.

    Returns:
        Boolean array of length ``a.shape[0]`` where position ``k`` is true
        when ``a[k]`` is present in ``b``.
    """
    # Build the lookup set once, before the loop, so each test is a set probe.
    members = set(b)
    n = a.shape[0]
    hits = np.empty(n, dtype=nb.boolean)
    # Every iteration writes only its own slot of ``hits``.
    for k in nb.prange(n):
        hits[k] = a[k] in members
    return hits


def _isin_impl_nojit(
    a: npt.NDArray[np.int64], b: npt.NDArray[np.int64]
) -> npt.NDArray[np.int64]:
    """Return the entries of ``b`` that also occur in ``a`` (pure Python).

    Unlike ``_isin_impl`` this returns the filtered values rather than a mask.
    Order and repeated values follow ``b``.

    Args:
        a: Values to test against.
        b: 1-D array whose entries are filtered.

    Returns:
        ``b[mask]`` where ``mask[i]`` is true when ``b[i]`` is present in
        ``a``; empty when either input is empty.
    """
    # The lookup set is built from ``a`` under its own name. Rebinding ``b``
    # to a set would break the ``b.shape`` / ``b[mask]`` accesses below.
    members = set(a.tolist())
    mask = np.array([value in members for value in b.tolist()], dtype=bool)
    return b[mask]


@dataclass
class SparseVector:
    """Sparse vector represented only by the integer positions it stores.

    The operators and ``union*`` helpers return new ``SparseVector`` objects;
    ``sort`` is the only method here that modifies the stored array in place.
    """

    # Stored positions; defaults to an empty int64 array.
    indices: npt.NDArray[np.int64] = field(
        default_factory=lambda: np.array([], dtype=np.int64)
    )

    def __len__(self) -> int:
        """Return the number of stored positions."""
        return len(self.indices)

    def __getitem__(self, arg) -> SparseVector:
        """Return a vector wrapping ``self.indices[arg]``.

        ``arg`` is forwarded unchanged to the array's own indexing.
        """
        return type(self)(self.indices[arg])

    def __lshift__(self, shift: int) -> SparseVector:
        """Return a vector with every position decreased by ``shift``."""
        return type(self)(self.indices - shift)

    def __rshift__(self, shift: int) -> SparseVector:
        """Return a vector with every position increased by ``shift``."""
        return type(self)(self.indices + shift)

    def __and__(self, other: SparseVector) -> SparseVector:
        """Return the positions of ``other`` that also occur in ``self``.

        Order and repeated values follow ``other.indices``. The membership
        test is delegated to ``_isin_impl``, which is declared for 1-D int64
        arrays.
        """
        return type(self)(other.indices[_isin_impl(other.indices, self.indices)])

    def __or__(self, other: SparseVector) -> SparseVector:
        """Return the sorted, de-duplicated union of both position sets."""
        return type(self)(np.union1d(self.indices, other.indices))

    @classmethod
    def union_from_iterable(cls, others: Iterable[SparseVector]) -> SparseVector:
        """Concatenate the positions of every vector in ``others``.

        The result is a plain concatenation in iteration order: it is neither
        sorted nor de-duplicated.

        Args:
            others: Any iterable of vectors, including one-shot iterators
                and generators.

        Returns:
            A vector holding all positions, or an empty vector when
            ``others`` yields nothing.
        """
        # Walk ``others`` exactly once: a generator would be exhausted by a
        # separate emptiness check and leave nothing to concatenate.
        arrays = [o.indices for o in others]
        return cls(np.concatenate(arrays)) if arrays else cls()

    def union(self, *other: SparseVector) -> SparseVector:
        """Concatenate this vector's positions with those of ``other``.

        Like ``union_from_iterable``, the result is neither sorted nor
        de-duplicated.
        """
        return type(self)(np.concatenate([self.indices, *[o.indices for o in other]]))

    def sort(self) -> SparseVector:
        """Sort the stored positions in place and return ``self`` for chaining."""
        self.indices.sort()
        return self


@dataclass
class SparseMatrixCSR:
    """Rows of index-only sparse vectors packed into two flat arrays.

    Row ``i`` is stored as ``indices[indptr[i]:indptr[i + 1]]``, so ``indptr``
    has one more entry than there are rows. Per the annotations, either field
    may be a NumPy array or an ``h5py.Dataset``.
    """

    # Row boundaries into ``indices``; length is number of rows + 1.
    indptr: npt.NDArray[np.int64] | h5py.Dataset
    # Positions of all rows laid out back to back.
    indices: npt.NDArray[np.int64] | h5py.Dataset

    def __len__(self) -> int:
        """Return the number of rows."""
        return len(self.indptr) - 1

    def __getitem__(self, i: int) -> SparseVector:
        """Return row ``i``; shorthand for ``getrow(i)``."""
        return self.getrow(i)

    @classmethod
    def pack(cls, sparse_vectors: list[SparseVector]) -> SparseMatrixCSR:
        """Build a matrix whose rows are the given vectors, in order.

        Args:
            sparse_vectors: Row vectors to pack; may be empty.

        Returns:
            A matrix with ``len(sparse_vectors)`` rows. For empty input the
            result has ``indptr == [0]`` and an empty int64 ``indices``.
        """
        rows = list(sparse_vectors)
        # Explicit int64 keeps indptr's dtype independent of the platform's
        # default integer width.
        indptr = np.cumsum([0, *(len(x) for x in rows)], dtype=np.int64)
        if rows:
            indices = np.concatenate([x.indices for x in rows])
        else:
            # np.concatenate rejects an empty sequence, so build the empty
            # payload directly.
            indices = np.array([], dtype=np.int64)
        return cls(indptr, indices)

    def unpack(self) -> list[SparseVector]:
        """Return every row as a separate ``SparseVector``, in row order."""
        return [self.getrow(i) for i in range(len(self))]

    def getrow(self, i: int) -> SparseVector:
        """Return row ``i`` as a ``SparseVector``.

        Args:
            i: Row number; negative values count from the last row.

        Raises:
            IndexError: If ``i`` is outside ``[-len(self), len(self))``.
        """
        n_rows = len(self)
        # Normalize explicitly: indexing indptr with a raw negative ``i``
        # would pair the wrong boundaries and return the wrong slice.
        if i < 0:
            i += n_rows
        if not 0 <= i < n_rows:
            raise IndexError(f"row index out of range for {n_rows} rows")
        return SparseVector(self.indices[self.indptr[i] : self.indptr[i + 1]])
