import os
import torch
from .utils import rbf_kernel
from .label_rkme import LabelRKME

class HeterogeneousFeature(LabelRKME):
    """
    Heterogeneous Feature (HF) specification built on top of ``LabelRKME``.

    HF is designed to solve the heterogeneous feature problem, but this
    implementation focuses on the homogeneous feature setting, so parts of
    the full method (such as subspace learning) are omitted.
    """

    def __init__(self, cfg, X, Y, **kwargs):
        """
        Build the specification.

        All arguments are forwarded unchanged to ``LabelRKME.__init__``.
        """
        super().__init__(cfg, X, Y, **kwargs)

    def compare(self, other, *, class_weight=10):
        """
        Return a dissimilarity score between this specification and ``other``.

        The score is the label-agnostic kernel distance between the two
        specifications, plus ``class_weight`` times the mean per-class kernel
        distance over the labels that appear in both specifications.
        If the specifications share no label, only the label-agnostic term
        is returned.

        Args:
            other: Another specification exposing ``get`` / ``kernel_x``
                (presumably another ``LabelRKME``-style object — confirm
                against callers).
            class_weight: Weight applied to the averaged per-class term.
                Defaults to 10, the value that used to be hard-coded here.

        Returns:
            The combined distance score.
        """
        dis = self.__class_distance(other)
        # NOTE(review): __class_distance reads other's labels via
        # self.get('y', other); this reads them via other.get('y').
        # Assumed equivalent — confirm against LabelRKME.get.
        common_classes = set(self.get('y').tolist()) & set(other.get('y').tolist())
        n_common = len(common_classes)
        # Accumulate term by term (same order of operations as before) so the
        # default class_weight reproduces previous results exactly.
        for label in common_classes:
            dis += class_weight * self.__class_distance(other, cls=label) / n_common
        return dis

    def __class_distance(self, other, cls=None):
        """
        Compute ``norm1 + norm2 - 2 * beta1 @ K(Z1, Z2) @ beta2``.

        Here ``K`` is ``self.kernel_x``, and ``Z`` / ``beta`` are the points
        and weights returned by ``get``.

        Args:
            other: The specification to compare against.
            cls: If ``None``, all points are used, together with the stored
                ``'norm'`` values of both specifications. Otherwise only the
                points (and weights) whose label equals ``cls`` are kept, and
                both self-norms are recomputed as ``beta @ K(Z, Z) @ beta``
                on that subset.

        Returns:
            The distance value (a Python float for the cross term; the
            ``cls=None`` branch adds the stored norms as returned by ``get``).
        """
        Z1 = self.get('Z')
        Z2 = self.get('Z', other)
        beta1 = self.get('beta')
        beta2 = self.get('beta', other)

        if cls is None:
            # NOTE(review): presumably the precomputed beta @ K(Z, Z) @ beta
            # over all points — confirm in LabelRKME.
            norm1 = self.get('norm')
            norm2 = other.get('norm')
        else:
            # Build each label mask once and reuse it for points and weights.
            mask1 = self.get('y') == cls
            mask2 = self.get('y', other) == cls
            Z1, beta1 = Z1[mask1], beta1[mask1]
            Z2, beta2 = Z2[mask2], beta2[mask2]
            norm1 = (beta1 @ self.kernel_x(Z1, Z1) @ beta1).item()
            norm2 = (beta2 @ other.kernel_x(Z2, Z2) @ beta2).item()

        cross_norm = (beta1 @ self.kernel_x(Z1, Z2) @ beta2).item()
        return norm1 + norm2 - 2 * cross_norm