import torch
import numpy as np
from .sinkhorn_stab import sinkhorn_stabilized

# dtype/device keyword arguments splatted into torch tensor factories
# (e.g. ``torch.ones(..., **platform)``) when no platform is given.
DEFAULT_PLATFORM = dict(dtype=torch.float64, device="cpu")


class WassersteinByGroup:
    """Score groups of features by a Sinkhorn transport cost between X and Y.

    For every group of columns, the pairwise Euclidean cost matrix ``C``
    between the rows of ``X`` and ``Y`` (restricted to those columns) is
    built, a coupling ``PI`` is obtained from ``sinkhorn_stabilized`` and the
    transport cost ``sum(PI * C)`` is recorded for the group.

    Parameters
    ----------
    eps : float, default 0.01
        Forwarded to ``sinkhorn_stabilized``; presumably the entropic
        regularization strength (TODO confirm against that helper).

    Attributes (set by ``fit``)
    ---------------------------
    distances_ : np.ndarray
        Transport cost of each group, in the order of ``groups``.
    groups_ : np.ndarray
        The groups as a NumPy array; a 1-D object array if they are ragged.
    sorted_group_importance : np.ndarray
        ``groups_`` reordered by decreasing transport cost.
    a_, b_ : torch.Tensor
        Row weights of ``X`` and ``Y`` used for the transport problems.
    """

    modelname = "Wasserstein"

    def __init__(self, eps=0.01):
        self.eps = eps

    @staticmethod
    def _groups_to_array(groups):
        """Convert ``groups`` to a NumPy array, tolerating groups of unequal size.

        Equal-sized groups give the same regular array ``np.array`` returns.
        Ragged groups, which NumPy >= 1.24 refuses to convert implicitly
        (``ValueError``), are stored in a 1-D object array instead.
        """
        try:
            return np.array(groups)
        except ValueError:
            ragged = np.empty(len(groups), dtype=object)
            for i, grp in enumerate(groups):
                ragged[i] = grp
            return ragged

    @staticmethod
    def _distances_to_numpy(distances):
        """Return the tensor ``distances`` as a host-side NumPy array.

        ``detach`` is needed because the costs inherit ``requires_grad`` from
        ``X``/``Y`` (through ``cdist``), and such tensors cannot be converted
        to NumPy directly.
        """
        return distances.detach().cpu().numpy()

    def fit(self, X, Y, groups, a=None, b=None, platform=None, **kargs):
        """Compute the transport cost between ``X`` and ``Y`` for each group.

        Parameters
        ----------
        X : torch.Tensor
            Samples indexed as ``X[:, grp]``; assumed ``(n1, n_features)``.
        Y : torch.Tensor
            Samples indexed as ``Y[:, grp]``; assumed ``(n2, n_features)``.
        groups : iterable
            Each element selects the columns of one group (e.g. a list of
            column indices). Groups may have different sizes.
        a, b : torch.Tensor, optional
            Weights of the rows of ``X`` and ``Y``. Default to uniform
            weights ``1/n1`` and ``1/n2``.
        platform : dict, optional
            dtype/device keyword arguments used to build the default weights
            and forwarded to ``sinkhorn_stabilized``. Defaults to
            ``DEFAULT_PLATFORM``.
        **kargs
            Unused.

        Returns
        -------
        WassersteinByGroup
            ``self``, with the fitted attributes set.
        """
        # None sentinel instead of a shared dict default evaluated at import.
        if platform is None:
            platform = DEFAULT_PLATFORM
        # ``groups`` is traversed twice below; a one-shot iterator (e.g. a
        # generator) would be exhausted after the first pass.
        if iter(groups) is groups:
            groups = list(groups)

        n1, n2 = X.shape[0], Y.shape[0]

        # Optimal transport marginals: uniform unless provided.
        if a is None:
            a = torch.ones((n1,), **platform) / n1
        if b is None:
            b = torch.ones((n2,), **platform) / n2

        costs = []
        for grp in groups:
            C = torch.cdist(X[:, grp], Y[:, grp], p=2)
            PI = sinkhorn_stabilized(a, b, C, self.eps, platform)
            # Transport cost <PI, C> = sum_ij PI_ij * C_ij.
            costs.append(torch.einsum("ij,ij", PI, C))

        if costs:
            self.distances_ = self._distances_to_numpy(torch.stack(costs))
        else:
            # torch.stack rejects an empty list; report "no groups" instead.
            self.distances_ = np.empty(0)
        self.groups_ = self._groups_to_array(groups)
        # Groups ordered from largest to smallest transport cost.
        self.sorted_group_importance = self.groups_[np.argsort(self.distances_)[::-1]]

        self.a_ = a
        self.b_ = b
        return self
