import time

import numpy as np
import numpy.linalg as LA

from .softset import SoftSet, _bmm_gram_impl, _matmul_impl, softset_membership_batch

# Shape of the random input used by test_construct and test_len:
# K row vectors, each of dimension D.
K, D = 3, 64


class TestSoftSet:
    """Unit tests and micro-benchmarks for ``SoftSet`` and its kernels."""

    @staticmethod
    def _best_time(fn, repeats=10):
        """Return the fastest of ``repeats`` timed calls to ``fn``, in seconds.

        ``fn`` is called once untimed beforehand so first-call costs do not
        skew the measurement. The minimum (rather than the total) is reported
        because it is the figure least inflated by scheduler and other
        background noise, which keeps relative-speed assertions from flaking.
        """
        fn()  # Warm-up call, deliberately not timed.
        best = float("inf")
        for _ in range(repeats):
            start = time.perf_counter()
            fn()
            best = min(best, time.perf_counter() - start)
        return best

    def test_construct(self):
        """``construct`` yields unit-norm, mutually orthogonal bases."""
        np.random.seed(0)
        A = np.random.rand(K, D).astype(np.float32)
        sset = SoftSet.construct(A)

        # Check normalized.
        assert np.allclose(LA.norm(sset.bases, axis=-1), np.ones(K))
        # Check orthogonalized: the Gram matrix over the last two axes must be
        # the identity. This covers every pair of basis vectors, not just the
        # first and the last. The ellipsis keeps the check valid if ``bases``
        # carries leading batch dimensions.
        gram = np.einsum("...kd,...ld->...kl", sset.bases, sset.bases)
        assert np.allclose(gram, np.eye(K), rtol=1e-3, atol=1e-3)

    def test_len(self):
        """``len`` reports the number of vectors the set was built from."""
        np.random.seed(0)
        A = np.random.rand(K, D).astype(np.float32)
        sset = SoftSet.construct(A)
        assert len(sset) == K

    def test_membership(self):
        """A query in the span scores 1; an orthogonal query scores 0."""
        A = np.array([[1.0, 0.0, 0.0], [0.0, 0.0, 1.0]]).astype(np.float32)
        sset = SoftSet.construct(A)
        q_in = np.array([1.0, 0.0, 0.0]).astype(np.float32)
        q_ex = np.array([0.0, 1.0, 0.0]).astype(np.float32)
        scores = sset.membership(q_in)
        assert np.allclose(scores, np.array(1.0).astype(np.float32))
        scores = sset.membership(q_ex)
        assert np.allclose(scores, np.array(0.0).astype(np.float32))

    def test_membership_batch(self):
        """Each query scores 1 against the set spanning it and 0 otherwise."""
        A = np.array([[1.0, 0.0, 0.0], [0.0, 0.0, 1.0]]).astype(np.float32)
        B = np.array([[0.0, 1.0, 0.0]]).astype(np.float32)
        ssetA = SoftSet.construct(A)
        ssetB = SoftSet.construct(B)
        q_in = np.array([1.0, 0.0, 0.0]).astype(np.float32)
        q_ex = np.array([0.0, 1.0, 0.0]).astype(np.float32)
        q = np.vstack([q_in, q_ex])
        scores = softset_membership_batch([ssetA, ssetB], q)
        assert np.allclose(scores, np.array([[1.0, 0.0], [0.0, 1.0]], dtype=np.float32))

    def test_matmul_impl(self):
        """``_matmul_impl(A, q)`` equals ``einsum('nd,qd->nq', A, q)``."""
        np.random.seed(0)  # Reproducible inputs, like the other tests.
        N, Q, D = 3, 5, 8
        A = np.random.rand(N, D).astype(np.float32)
        q = np.random.rand(Q, D).astype(np.float32)
        res = _matmul_impl(A, q)
        expected = np.einsum("nd,qd->nq", A, q)
        assert np.allclose(res, expected)

    def test_matmul_impl_speed(self):
        """``_matmul_impl`` is no slower than the equivalent ``np.einsum``."""
        np.random.seed(0)
        N, Q, D = 3, 100_000, 256
        A = np.random.rand(N, D).astype(np.float32)
        q = np.random.rand(Q, D).astype(np.float32)

        time_fast = self._best_time(lambda: _matmul_impl(A, q))
        time_naive = self._best_time(lambda: np.einsum("nd,qd->nq", A, q))
        assert time_fast <= time_naive

    def test_bmm_gram(self):
        """``_bmm_gram_impl`` matches the per-(n, q) outer product of projections."""
        np.random.seed(0)  # Reproducible inputs, like the other tests.
        N, Q, K, D = 3, 5, 7, 8
        A = np.random.rand(N, K, D).astype(np.float32)
        q = np.random.rand(Q, D).astype(np.float32)
        res = _bmm_gram_impl(A, q)
        bmm = np.einsum("nkd,qd->nqk", A, q)
        expected = np.einsum("nqk,nql->nqkl", bmm, bmm)
        assert np.allclose(res, expected)

    def test_bmm_gram_speed(self):
        """``_bmm_gram_impl`` is no slower than the two-step ``np.einsum`` form."""
        np.random.seed(0)
        N, Q, K, D = 3, 100_000, 2, 256
        A = np.random.rand(N, K, D).astype(np.float32)
        q = np.random.rand(Q, D).astype(np.float32)

        def naive():
            bmm = np.einsum("nkd,qd->nqk", A, q)
            return np.einsum("nqk,nql->nqkl", bmm, bmm)

        time_fast = self._best_time(lambda: _bmm_gram_impl(A, q))
        time_naive = self._best_time(naive)
        assert time_fast <= time_naive
