from __future__ import annotations

from dataclasses import dataclass

import numba as nb
import numpy as np
import numpy.linalg as LA

from softmatcha.typing import Vector

NUMPY_AVX512F_SUPPORTED = "AVX512F" in np.show_config(mode="dicts").get(
    "SIMD Extensions", {}
).get("found", [])


@nb.njit(nb.float32(nb.float32[:, :], nb.float32[:]), fastmath=True, cache=True)
def _softset_membership(soft_set: Vector, q: Vector) -> float:
    """Compute the membership degree of a single query in a single soft set.

    The query is projected onto the bases of the soft set and the largest
    singular value of that projection (treated as a one-column matrix) is
    returned. For a single column this equals the Euclidean norm of the
    projection.

    Args:
        soft_set (Vector): Orthonormal bases of the linear subspace, a float32
          array of shape `(num_bases, embed_dim)`.
        q (Vector): Normalized query vector, a float32 array of shape
          `(embed_dim,)`.

    Returns:
        float: Membership degree.

    Examples:
        >>> S = SoftSet.construct(np.random.rand(4, 300).astype(np.float32))
        >>> q = np.random.rand(300).astype(np.float32)
        >>> _softset_membership(S.bases, q)
    """
    bases = np.ascontiguousarray(soft_set)
    query = np.ascontiguousarray(q)

    # Coordinates of the query along each basis vector: shape (num_bases,).
    projection = bases @ query

    # cos(theta) is read off as the top singular value of the column matrix.
    _, singular_values, _ = LA.svd(projection[:, None])
    return singular_values.max()


@nb.njit(
    nb.float32[:, :, :, :](nb.float32[:, :, :], nb.float32[:, :]),
    parallel=True,
    fastmath=True,
    cache=True,
)
def _bmm_gram_impl(soft_sets: Vector, q: Vector) -> Vector:
    """Project every query onto every soft set and build the per-pair gram matrix.

    For each (set, query) pair the projection vector `p = bases @ query` is
    computed, then its outer product `p p^T` is stored. The caller uses this
    symmetric matrix to obtain singular values of the batched product.

    Args:
        soft_sets (Vector): Stacked orthonormal bases, a float32 array of shape
          `(num_sets, num_bases, embed_dim)`.
        q (Vector): Normalized query vectors, a float32 array of shape
          `(num_queries, embed_dim)`.

    Returns:
        Vector: Symmetric gram matrices of shape
          `(num_sets, num_queries, num_bases, num_bases)`.
    """
    n_sets, n_bases, dim = soft_sets.shape
    n_queries = q.shape[0]

    queries = np.ascontiguousarray(q)
    bases = np.ascontiguousarray(soft_sets)

    # NOTE(review): every loop level uses prange; per the Numba documentation
    # only the outermost prange of a nest runs in parallel -- confirm if the
    # inner ones are intended to matter.

    # projections[s, t, b] = <bases[s, b, :], queries[t, :]>
    projections = np.zeros((n_sets, n_queries, n_bases), dtype=np.float32)
    for s in nb.prange(n_sets):
        for b in nb.prange(n_bases):
            for t in nb.prange(n_queries):
                for d in nb.prange(dim):
                    projections[s, t, b] += bases[s, b, d] * queries[t, d]

    # gram[s, t] = outer(projections[s, t], projections[s, t])
    gram = np.zeros(
        (n_sets, n_queries, n_bases, n_bases),
        dtype=np.float32,
    )
    for s in nb.prange(n_sets):
        for t in nb.prange(n_queries):
            for b1 in nb.prange(n_bases):
                for b2 in nb.prange(n_bases):
                    gram[s, t, b1, b2] += projections[s, t, b1] * projections[s, t, b2]
    return gram


@nb.njit(
    nb.float32[:, :](nb.float32[:, :], nb.float32[:, :], nb.float32),
    parallel=True,
    fastmath=True,
    cache=True,
)
def _matmul_numba_impl(a: Vector, b: Vector, minimum: float) -> Vector:
    """Compute `a @ b.T` with a Numba kernel and clip the result from below.

    Args:
        a (Vector): The input matrix, float32 of shape `(a_length, embed_dim)`.
        b (Vector): The other matrix, float32 of shape `(b_length, embed_dim)`.
        minimum (float): Lower bound applied element-wise to the product.

    Returns:
        Vector: Matrix of shape `(a_length, b_length)` whose entries are
          `max(dot(a[i], b[j]), minimum)`.
    """
    rows, dim = a.shape
    cols = b.shape[0]

    lhs = np.ascontiguousarray(a)
    rhs = np.ascontiguousarray(b)

    # Row-by-row dot products; both operands are indexed along their last axis.
    product = np.zeros((rows, cols), dtype=np.float32)
    for r in nb.prange(rows):
        for c in nb.prange(cols):
            for d in nb.prange(dim):
                product[r, c] += lhs[r, d] * rhs[c, d]
    return np.maximum(product, minimum)


def _matmul_impl(a: Vector, b: Vector, minimum: float = 0.0) -> Vector:
    """Compute `a @ b.T` clipped from below, choosing the backend at runtime.

    NumPy's `@` operator is used when `NUMPY_AVX512F_SUPPORTED` is true;
    otherwise the work is delegated to `_matmul_numba_impl`.

    Args:
        a (Vector): The input matrix of shape `(a_length, embed_dim)`.
        b (Vector): The other matrix of shape `(b_length, embed_dim)`.
        minimum (float): Lower bound applied element-wise to the product.

    Returns:
        Vector: Clipped product matrix of shape `(a_length, b_length)`.
    """
    if not NUMPY_AVX512F_SUPPORTED:
        return _matmul_numba_impl(a, b, minimum)

    product = a @ b.T
    return np.maximum(product, minimum)


def softset_membership_batch(soft_sets: list[SoftSet], q: Vector) -> Vector:
    """Compute membership degrees between the soft sets and vectors.

    The soft sets are stacked into one zero-padded array so that sets with
    fewer bases can share a batch with larger ones. When every set has a
    single basis, the result is the matrix of dot products clipped below at
    0.0. Otherwise the membership is the square root of the leading singular
    value of the per-pair gram matrix.

    Args:
        soft_sets (list[SoftSet]): Soft sets that consist of the orthonormal bases of
          linear subspaces, each of shape `(num_bases, embed_dim)`.
        q (Vector): Normalized query vectors of shape `(num_queries, embed_dim,)`.
          The compiled kernels are declared with float32 signatures.

    Returns:
        Vector: Membership degrees of shape `(num_sets, num_queries)`. An empty
          `soft_sets` yields an empty float32 array of shape `(0, num_queries)`.

    Examples:
        >>> S1 = SoftSet.construct(np.random.rand(4, 300).astype(np.float32))
        >>> S2 = SoftSet.construct(np.random.rand(3, 300).astype(np.float32))
        >>> q = np.random.rand(5, 300).astype(np.float32)
        >>> scores = softset_membership_batch([S1, S2], q)
        >>> scores.shape
        (2, 5)
    """
    # With no sets, `max()` below would raise; return the same empty result
    # that `softset_membership_forloop` produces for this input.
    if len(soft_sets) == 0:
        return np.zeros((0, q.shape[0]), dtype=np.float32)

    # Construct batched soft sets, zero-padding sets with fewer bases.
    num_bases = max(len(s) for s in soft_sets)
    batched_soft_sets = np.zeros(
        (len(soft_sets), num_bases, q.shape[-1]), dtype=np.float32
    )
    for i, s in enumerate(soft_sets):
        np.copyto(batched_soft_sets[i, : len(s)], s.bases)

    if num_bases == 1:
        # Single-basis sets: plain dot products, clipped below at zero.
        return _matmul_impl(batched_soft_sets.squeeze(1), q, minimum=0.0)

    # The gram matrix is symmetric, so `hermitian=True` applies; its leading
    # singular value is the squared singular value of the projection.
    bmm_gram = _bmm_gram_impl(batched_soft_sets, q)
    sigma = LA.svd(bmm_gram, compute_uv=False, hermitian=True)[:, :, 0] ** 0.5
    return sigma


def softset_membership_forloop(soft_sets: list[SoftSet], q: Vector) -> Vector:
    """Compute membership degrees one (set, query) pair at a time.

    For every pair, the query is projected onto the bases of the set, the
    outer product of that projection with itself is formed, and the square
    root of its leading singular value is stored as the score.

    Args:
        soft_sets (list[SoftSet]): Soft sets that consist of the orthonormal bases of
          linear subspaces, each of shape `(num_bases, embed_dim)`.
        q (Vector): Normalized query vectors of shape `(num_queries, embed_dim,)`.

    Returns:
        Vector: Membership degrees, float32 of shape `(num_sets, num_queries)`.

    Examples:
        >>> S1 = SoftSet.construct(np.random.rand(4, 300))
        >>> S2 = SoftSet.construct(np.random.rand(3, 300))
        >>> q = np.random.rand(5, 300)
        >>> scores = softset_membership_forloop([S1, S2], q)
        >>> scores.shape
        (2, 5)
    """
    num_queries = q.shape[0]
    scores = np.zeros((len(soft_sets), num_queries), dtype=np.float32)

    for set_idx, soft_set in enumerate(soft_sets):
        for query_idx in range(num_queries):
            # Coordinates of this query along each basis vector.
            projection = np.einsum("kd,d->k", soft_set.bases, q[query_idx])
            # Symmetric outer product of the projection with itself.
            gram = projection[..., :, None] @ projection[..., None, :]
            # Compute cos(theta) using SVD of the gram matrix.
            singular_values = LA.svd(gram, compute_uv=False, hermitian=True)
            scores[set_idx, query_idx] = (singular_values**0.5)[..., 0]
    return scores


@dataclass
class SoftSet:
    """Soft set class for various soft set operations.

    References:
        Ishibashi et al., 2022,
        "Subspace Representations for Soft Set Operations and Sentence Similarities".
        https://arxiv.org/abs/2210.13034

    - bases (Vector): Orthonormal bases of shpe `(num_bases, embed_dim)`.
    """

    bases: Vector

    def __len__(self) -> int:
        return len(self.bases)

    @classmethod
    def construct(cls, elements: Vector) -> SoftSet:
        """Construct the soft set from element vectors.

        A soft set is based on the linear subspace of element vectors and represented
        by the orthonormalized bases.
        This implement uses QR decomposition to compute the orthonormalized bases.

        Args:
            elements (Vector): Element vectors of shape `(num_vectors, embed_dim)`.

        Returns:
            SoftSet: Linear subspace based set of shape `(num_vectors, embed_dim)`.
        """
        subspace = LA.qr(elements.T).Q.T
        neg_mask = (subspace * elements).sum(axis=-1) < 0.0
        subspace[neg_mask] *= -1.0
        return cls(subspace)

    def membership(self, q: Vector) -> float:
        """Compute membership degree between the soft set and a vector.

        Args:
            q (Vector): Query vectors of shape `(embed_dim,)`.

        Returns:
            float: Membership degree.

        Examples:
            >>> S = SoftSet.construct(np.random.rand(4, 300))
            >>> q = np.random.rand(300)
            >>> S.membership(q)
        """
        return _softset_membership(self.bases, q)
