import numpy as np
from kernel import * 

# Single fixed-seed generator shared by every test in this module. All tests
# consume from this one stream, so the samples a test sees are reproducible
# only for a fixed selection and order of tests.
rng = np.random.default_rng(seed=1234)

class TestStaticMMD:
    """Tests for the batch MMD estimators ``BiasedMMD`` and ``RFFMMD``.

    All samples come from the module-level ``rng``, so the data each test
    sees depends on which tests ran before it in the same session.
    """

    n = 1000  # points drawn per sample
    d = 10  # dimension used by the estimator-agreement test
    num_rff = 10000  # number of random features passed to RFFMMD

    def _shifted_samples(self, d=1, shift=1):
        """Draw two samples of ``self.n`` points in ``d`` dimensions.

        X has standard-normal entries; Y has normal entries with mean
        ``shift`` in every coordinate. X is drawn before Y, which matches the
        draw order the null-rejection tests originally used inline.
        """
        X = rng.normal(size=(self.n, d))
        Y = rng.normal(loc=shift, size=(self.n, d))
        return X, Y

    def test_estimates_of_biased_and_RFF_MMD_are_close(self):
        """BiasedMMD and RFFMMD agree within 1e-3 on same-distribution data."""
        for gamma in [0.25, 0.5, 1, 2]:  # check for different kernel scales
            k = Gauss(gamma=gamma)
            X = rng.normal(size=(self.n, self.d))
            Y = rng.normal(size=(self.n, self.d))
            biased = BiasedMMD(kernel=k).mmd(X, Y)
            rff = RFFMMD(kernel=k, num_omegas=self.num_rff).mmd(X, Y)
            # Report which scale failed; the loop otherwise hides it.
            assert np.abs(biased - rff) < 1e-3, (
                f"gamma={gamma}: biased={biased}, rff={rff}"
            )

    def test_RFF_MMD_rejects_null(self):
        """RFFMMD exceeds the alpha=0.01 threshold for a unit mean shift in 1-D."""
        k = Gauss(gamma=1)
        X, Y = self._shifted_samples()
        stat = RFFMMD(kernel=k, num_omegas=self.num_rff).mmd(X, Y)
        # NOTE(review): the RFF statistic is compared against BiasedMMD's
        # threshold, presumably because RFFMMD approximates the biased
        # estimator — confirm this is intended.
        threshold = BiasedMMD.threshold(n=self.n, m=self.n, alpha=.01)
        assert stat > threshold, (
            f"RFF statistic {stat} did not exceed threshold {threshold}"
        )

    def test_biased_MMD_rejects_null(self):
        """BiasedMMD exceeds its alpha=0.01 threshold for a unit mean shift in 1-D."""
        k = Gauss(gamma=1)
        X, Y = self._shifted_samples()
        stat = BiasedMMD(kernel=k).mmd(X, Y)
        threshold = BiasedMMD.threshold(n=self.n, m=self.n, alpha=.01)
        assert stat > threshold, (
            f"biased statistic {stat} did not exceed threshold {threshold}"
        )

class TestStreamingMMD:
    """Tests that StreamingRFFMMD agrees with batch RFFMMD at each split point."""

    n = 2**10  # size of the first sample X
    m = 2**10 - 1  # size of the second sample Y
    d = 2  # dimension of every point
    num_rff = 100  # number of random features used by the streaming detector

    def test_streaming_and_static_computation_agree_across_all_splits(self):
        """Stream X followed by Y, then compare each streamed MMD value with a
        batch RFFMMD computed on the same split and the same random features."""
        k = Gauss(gamma=1)
        X = rng.normal(size=(self.n, self.d))
        Y = rng.normal(loc=10, size=(self.m, self.d))

        cd = StreamingRFFMMD(kernel=k, d=self.d, num_omegas=self.num_rff)
        data = np.concatenate((X, Y))
        for point in data:
            cd.insert(point)

        # Split positions are the cumulative sums of 1024, 512, ..., 2, i.e.
        # 1024, 1536, ..., 2046 — equivalent to the original
        # np.cumsum(2**np.array([*range(1, 11)])[::-1]).
        # NOTE(review): the exponent range 10..1 is tied to n = 2**10 and
        # presumably mirrors the splits StreamingRFFMMD tracks; update it if
        # n or m change.
        splits = np.cumsum(2 ** np.arange(10, 0, -1))

        # Nothing is inserted inside the loop, so query the detector once and
        # reuse a single reference estimator (omegas are passed explicitly).
        streaming_values = cd.mmd_values()
        reference = RFFMMD(kernel=k)
        for i, split in enumerate(splits):
            static_value = reference.mmd(data[:split], data[split:], omegas=cd.omegas)
            assert np.abs(streaming_values[i] - static_value) < 1e-10, (
                f"split #{i} at position {split}: "
                f"streaming={streaming_values[i]}, static={static_value}"
            )