
"""
This code builds upon the methods and reference implementation from:

    Shubhanshu Shekhar and Aaditya Ramdas.
    "Nonparametric Two-Sample Testing by Betting."
    IEEE Transactions on Information Theory, 2024 (arXiv preprint, 2021).
    arXiv: https://arxiv.org/abs/2112.09162
    Code: https://github.com/sshekhar17/nonparametric-testing-by-betting

The original repository provides the foundational implementation of sequential
nonparametric hypothesis testing using e-processes. 
"""


from functools import partial
from math import sqrt 

import numpy as np 
from src.SeqTestsUtils import deLaPenaMartingale
from src.utils import RBFkernel, median_heuristic



def computeMMD(X, Y, kernel=None, perm=None, biased=True):
    """
    Compute the quadratic-time MMD statistic between two samples.

    The samples are stacked into the pooled array Z = [X; Y]. `perm` then
    selects which rows of Z play the role of X (its first len(X) entries) and
    which play the role of Y (the remaining entries). This lets the caller
    evaluate the statistic on a permutation of the pooled data.

    X       :ndarray    (nX, ndims) size observations
    Y       :ndarray    (nY, ndims) size observations, nY > 0
    kernel  :callable   kernel(A, B) -> gram matrix between rows of A and B;
                        if None, `RBFkernel` with bandwidth
                        `median_heuristic(Z)` on the pooled sample is used
    perm    :ndarray    index array of length nX + nY into Z (presumably a
                        permutation of range(nX + nY)); identity if None
    biased  :bool       if True, compute the biased (V-statistic) MMD; the
                        unbiased estimator is not implemented

    returns
    -------
    mmd     :float      square root of the biased MMD^2 estimate. The squared
                        value is clamped at 0 first, so floating-point
                        round-off cannot make the square root fail.

    raises
    ------
    NotImplementedError     if `biased` is False
    """
    Z = np.concatenate((X, Y), axis=0)
    nX, nZ = len(X), len(Z)

    if kernel is None:
        bw = median_heuristic(Z)
        kernel = partial(RBFkernel, bw=bw)

    if perm is None:
        perm = np.arange(nZ)

    # The first nX permuted indices form the "X" sample; the rest form "Y".
    idxX, idxY = perm[:nX], perm[nX:]
    X_, Y_ = Z[idxX], Z[idxY]

    KXX = kernel(X_, X_)
    KYY = kernel(Y_, Y_)
    KXY = kernel(X_, Y_)

    nY = nZ - nX
    nY2, nX2, nXY = nY * nY, nX * nX, nX * nY
    assert nY > 0
    if not biased:
        # TODO: unbiased (U-statistic) estimator.
        raise NotImplementedError

    mmd2 = (1 / nX2) * KXX.sum() + (1 / nY2) * KYY.sum() - (2 / nXY) * KXY.sum()
    # For a positive-definite kernel the biased MMD^2 is a squared RKHS norm
    # and hence >= 0. However, cancellation in a + b - 2c can produce tiny
    # negative values, on which math.sqrt would raise ValueError.
    return sqrt(max(mmd2, 0.0))


def kernelMMDprediction(X, Y, kernel=None, post_processing=None):
    """
    Compute sequential kernel-MMD witness predictions on paired observations.

    For each i >= 1, the raw prediction is built from the first i pairs and
    evaluated at the new pair (X[i], Y[i]). Assuming kernel(A, B)[a, b] is
    k(A[a], B[b]), it is

        F_[i] = mean_{j<i} [k(X[i], X[j]) - k(X[i], Y[j])]
                - mean_{j<i} [k(X[j], Y[i]) - k(Y[j], Y[i])].

    For i > 10 the raw value is divided by max_{j<i} |F_[j]|, unless that
    scale is at most machine epsilon. The current value is excluded from the
    scale, so normalized values may exceed 1 in magnitude.

    X               :ndarray    observations, presumably (n, ndims); n > 20
    Y               :ndarray    observations with the same length as X
    kernel          :callable   kernel(A, B) -> gram matrix; if None,
                                `RBFkernel` with bandwidth from
                                `median_heuristic` on the first 20 points of
                                X and of Y is used
    post_processing :str|None   'sinh', 'tanh', 'arctan' (scaled by 2/pi) or
                                'delapena' (applies `deLaPenaMartingale`);
                                any other value leaves the predictions as-is

    returns
    -------
    F               :ndarray    (n,) predictions; F[0] is 0 before
                                post-processing.
    """
    nX, nY = len(X), len(Y)
    assert nX == nY
    assert nX > 20

    if kernel is None:
        # The bandwidth is chosen from the first 20 points of each sample only.
        bw = median_heuristic(np.concatenate((X[:20], Y[:20]), axis=0))
        kernel = partial(RBFkernel, bw=bw)
    KXX = kernel(X, X)
    KYY = kernel(Y, Y)
    KXY = kernel(X, Y)

    eps = np.finfo(float).eps
    F = np.zeros((nX,))
    F_ = np.zeros((nX,))  # raw (unnormalized) predictions
    # Running max of |F_[:i]|; F_[0] == 0, so 0.0 is the right start value.
    # np.maximum propagates NaN the same way np.max over the prefix would.
    scale = 0.0
    for i in range(1, nX):
        termX = np.mean(KXX[i, :i] - KXY[i, :i])
        termY = np.mean(KXY[:i, i] - KYY[:i, i])
        F_[i] = termX - termY

        if i <= 10 or scale <= eps:
            F[i] = F_[i]
        else:
            F[i] = F_[i] / scale
        scale = np.maximum(scale, np.abs(F_[i]))

    if post_processing == 'sinh':
        F = np.sinh(F)
    elif post_processing == 'tanh':
        F = np.tanh(F)
    elif post_processing == 'arctan':
        F = (2 / np.pi) * np.arctan(F)
    elif post_processing == 'delapena':
        F = deLaPenaMartingale(F)
    return F

