import numpy as np
from .mmd import kernelwidthPair, grbf, MMD_unbiased


class MMD_distances:
    """Rank column groups by the MMD distance between two samples.

    For every group of columns, ``X`` and ``Y`` restricted to those columns
    are compared using kernel matrices from ``grbf`` (with a bandwidth from
    ``kernelwidthPair``) and the estimate returned by ``MMD_unbiased``.
    Groups with a larger distance are ranked first.
    """

    modelname = "MMD"

    def fit(self, X, Y, groups, **kargs):
        """Compute one MMD value per column group and rank the groups.

        Parameters
        ----------
        X, Y : array-like
            Two samples, indexed as ``X[:, grp]`` / ``Y[:, grp]`` --
            presumably (n_samples, n_features) arrays; TODO confirm.
        groups : sequence
            Column selectors, one per group (e.g. lists of column indices).
            Groups may have different sizes.
        **kargs
            Accepted for interface compatibility; ignored.

        Returns
        -------
        self
            The fitted object, with attributes:

            ``distances_`` -- distance of each group, in the order of
            ``groups`` (empty when ``groups`` is empty).
            ``groups_`` -- ``groups`` as a NumPy array; a 1-D object array
            when the groups are ragged.
            ``sorted_group_importance`` -- ``groups_`` reordered by
            decreasing ``distances_``.
        """
        distances = [self._group_mmd(X, Y, grp) for grp in groups]

        # np.stack refuses an empty list, so no groups needs its own path.
        self.distances_ = np.stack(distances) if distances else np.empty(0)
        self.groups_ = self._as_group_array(groups)

        # argsort is ascending; reverse it to put the largest distance first.
        # NOTE(review): argsort places NaN last, so after the reversal any NaN
        # distance is ranked as the most important group.
        order = np.argsort(self.distances_)[::-1]
        self.sorted_group_importance = self.groups_[order]
        return self

    @staticmethod
    def _group_mmd(X, Y, grp):
        """Return the MMD estimate between X and Y restricted to columns grp."""
        Xgrp = np.array(X[:, grp])
        Ygrp = np.array(Y[:, grp])

        # A single bandwidth, chosen from both samples, is shared by all
        # three kernel matrices.
        sigma = kernelwidthPair(Xgrp, Ygrp)
        Kxx = grbf(Xgrp, Xgrp, sigma)
        Kyy = grbf(Ygrp, Ygrp, sigma)
        Kxy = grbf(Xgrp, Ygrp, sigma)
        return MMD_unbiased(Kxx, Kyy, Kxy)

    @staticmethod
    def _as_group_array(groups):
        """Convert groups to a NumPy array, using dtype=object if ragged."""
        try:
            return np.array(groups)
        except ValueError:
            # NumPy >= 1.24 raises on ragged nested sequences instead of
            # silently building an object array; build that array explicitly,
            # one entry per group, so it can still be fancy-indexed.
            arr = np.empty(len(groups), dtype=object)
            for i, grp in enumerate(groups):
                arr[i] = grp
            return arr
        
