import torch

class HSCIC:
    """Empirical Hilbert-Schmidt Conditional Independence Criterion loss.

    Calling an instance with samples ``Y``, ``A`` and ``X`` returns the
    squared HSCIC estimate between ``Y`` and ``A`` given ``X``, averaged over
    the observed values of ``X``.  Conditional mean embeddings are obtained
    by kernel ridge regression on the Gram matrix of ``X``; all three
    variables use :meth:`gaussian_kernel`.

    Args:
        regularization: Ridge parameter ``lambda``; the regression system is
            ``K_X + n * lambda * I`` for ``n`` samples.
        bandwidth: Gaussian kernel bandwidth ``sigma`` used by
            :meth:`gaussian_kernel`.  The default equals the value that used
            to be hard-coded (0.1).
    """

    # Class-level fallback so instances that never ran __init__ with the
    # ``bandwidth`` argument (e.g. unpickled from older versions) still work.
    bandwidth = 0.1

    def __init__(self,
                 regularization=0.1,
                 bandwidth=0.1):
        # kernel ridge regularization strength
        self.regularization = regularization
        # Gaussian kernel bandwidth (sigma)
        self.bandwidth = bandwidth

    def __call__(self,
                 Y: torch.Tensor,
                 A: torch.Tensor,
                 X: torch.Tensor) -> torch.Tensor:
        """Return the mean per-sample HSCIC loss as a tensor of shape ``(1,)``.

        The first dimension of each input indexes samples; all remaining
        dimensions are flattened into features.  Inputs are cast to float32
        (differentiably, on their own device), so float64/integer/CUDA
        tensors are accepted.
        """
        A = self._as_float_matrix(A)
        X = self._as_float_matrix(X)
        Y = self._as_float_matrix(Y)

        gram_A = self.gaussian_kernel(A, A)
        gram_X = self.gaussian_kernel(X, X)
        gram_Y = self.gaussian_kernel(Y, Y)

        # Column i of ``weights`` solves W^T f = gram_X[i, :], which is exactly
        # the system inner_loss solves for sample i. One multi-RHS solve
        # replaces n separate solves of the same (loop-invariant) matrix.
        W = self._ridge_matrix(gram_X)
        weights = torch.linalg.solve(W.T, gram_X.T)

        per_sample = self._loss_from_weights(weights, gram_A, gram_Y)
        # mean over samples, kept as shape (1,) like the per-instance losses
        return per_sample.mean(dim=0, keepdim=True)

    # get loss given a single instance x
    def inner_loss(self, X, gram_A, gram_X, gram_Y):
        """Return the HSCIC loss for one conditioning instance, shape ``(1,)``.

        Args:
            X: Kernel evaluations of the instance against all samples, i.e. a
                row of ``gram_X`` (a 1-D tensor of length ``n``).
            gram_A, gram_X, gram_Y: ``(n, n)`` Gram matrices.
        """
        W = self._ridge_matrix(gram_X)
        f = torch.linalg.solve(W.T, X).reshape(-1, 1)
        return self._loss_from_weights(f, gram_A, gram_Y)

    def _ridge_matrix(self, gram_X):
        """Build ``gram_X + n * regularization * I`` on gram_X's device/dtype."""
        n_samples = gram_X.shape[0]
        identity = torch.eye(n_samples, dtype=gram_X.dtype, device=gram_X.device)
        return gram_X + n_samples * self.regularization * identity

    @staticmethod
    def _loss_from_weights(F, gram_A, gram_Y):
        """Evaluate the HSCIC expansion for every weight column of ``F``.

        For each column f this is
        ``f^T (K_A * K_Y) f - 2 f^T ((K_A f) * (K_Y f)) + (f^T K_A f)(f^T K_Y f)``
        (``*`` element-wise); returns one value per column.
        """
        M = gram_A @ F
        N = gram_Y @ F
        joint = (F * ((gram_A * gram_Y) @ F)).sum(dim=0)  # f^T (K_A*K_Y) f
        cross = (F * M * N).sum(dim=0)                     # f^T (K_A f * K_Y f)
        P = (F * M).sum(dim=0)                             # f^T K_A f
        Q = (F * N).sum(dim=0)                             # f^T K_Y f
        return joint - 2 * cross + P * Q

    @staticmethod
    def _as_float_matrix(t):
        """Reshape to ``(n_samples, n_features)`` and cast to float32."""
        return t.reshape(t.shape[0], -1).float()

    def gaussian_kernel(self, a, b):
        """Gaussian Gram matrix between the rows of 2-D tensors ``a`` and ``b``.

        Entry ``(i, j)`` is
        ``exp(-mean_d (a[i, d] - b[j, d])**2 / (2 * bandwidth**2))``.
        Note the squared difference is averaged, not summed, over features.
        """
        sq_dist = (a[:, None, :] - b[None, :, :]).pow(2).mean(dim=2)
        return torch.exp(-sq_dist / (2 * self.bandwidth ** 2))