from .label_rkme import LabelRKME
import numpy as np
import torch

class ClassWiseRKME(LabelRKME):
    """LabelRKME variant whose reduced set is stored and compared per class.

    ``save`` splits the reduced set (``Z``, ``beta``) by label ``y`` and
    writes one group of entries per class to an ``.npz`` archive.
    ``load_helper`` reads that archive back into dictionaries keyed by class
    label. ``class_distance`` compares one class of this object with one
    class of another.
    """

    def __init__(self, cfg, X, Y, **kwargs):
        super().__init__(cfg, X, Y, **kwargs)

    def save(self):
        """Write the per-class reduced set to ``self.path`` with ``np.savez``.

        Archive layout (``c`` ranges over the class labels):
          - ``'y'``: labels of all reduced-set points.
          - ``'classes'``: the class labels.
          - ``'Z-{c}'`` / ``'beta-{c}'``: points and weights labelled ``c``.
          - ``'norm-{c}'``: ``beta_c @ KZ_c @ beta_c`` as a Python float.
        """
        print('save to', self.path)
        # Convert to plain Python scalars once. Iterating a torch tensor
        # yields 0-d tensors whose str() is e.g. 'tensor(0)', which would
        # create keys such as 'Z-tensor(0)'. load_helper formats labels read
        # back from numpy ('Z-0') and could never find those keys.
        classes = self.classes.tolist()
        data = {
            'y': self.y.detach().cpu().numpy(),
            'classes': classes,
        }
        for c in classes:
            idx = (self.y == c)
            Z = self.Z[idx]
            beta = self.beta[idx]
            # Restrict the square matrix KZ (presumably the kernel Gram
            # matrix of Z -- confirm in LabelRKME) to this class's rows/cols.
            KZ = self.KZ[idx][:, idx]
            norm = (beta @ KZ @ beta).item()
            data[f'Z-{c}'] = Z.detach().cpu().numpy()
            data[f'beta-{c}'] = beta.detach().cpu().numpy()
            data[f'norm-{c}'] = norm
        np.savez(self.path, **data)

    def load_helper(self):
        """Load an archive written by ``save`` from ``self.path``.

        Sets ``self.classes`` (numpy array of labels) and ``self.y`` (tensor
        on ``self.device``). Also sets the dictionaries ``self.beta``,
        ``self.Z`` (tensors on ``self.device``) and ``self.norm`` (floats),
        each keyed by class label.
        """
        # np.load on an .npz returns an NpzFile that keeps the file open.
        # The context manager closes it. Every data[...] access reads a full
        # in-memory array, so the results stay valid after the file closes.
        with np.load(self.path) as data:
            self.classes = data['classes']
            self.y = torch.tensor(data['y']).to(self.device)
            self.beta = {}
            self.Z = {}
            self.norm = {}
            for c in self.classes:
                self.beta[c] = torch.from_numpy(data[f'beta-{c}']).to(self.device)
                self.Z[c] = torch.from_numpy(data[f'Z-{c}']).to(self.device)
                # Saved as a Python float, so it loads as a 0-d array;
                # unwrap it to a float.
                self.norm[c] = data[f'norm-{c}'].item()

    def class_distance(self, other, c1, c2, lambd=1.0):
        """Score class ``c1`` of ``self`` against class ``c2`` of ``other``.

        Computes ``norm1 + norm2 - 2 * cross``, where:
          - ``norm1`` and ``norm2`` are the stored per-class values
            ``beta @ KZ @ beta``;
          - ``cross = beta1 @ self.kernel_x(Z1, Z2) @ beta2``.

        If the stored norms were computed with the same kernel as
        ``kernel_x`` (not verifiable here), this is the squared RKHS distance
        between the two class embeddings. The function returns ``lambd`` minus
        that quantity, so, despite the name, larger values mean the two
        classes are closer.

        Reads the per-class ``Z`` / ``beta`` / ``norm`` dictionaries that
        ``load_helper`` populates, on both ``self`` and ``other``. ``other``'s
        tensors are moved to ``self.device`` before the kernel is evaluated.
        """
        assert c1 in self.classes and c2 in other.classes
        Z1 = self.Z[c1]
        Z2 = other.Z[c2].to(self.device)
        beta1 = self.beta[c1]
        beta2 = other.beta[c2].to(self.device)
        norm1 = self.norm[c1]
        norm2 = other.norm[c2]
        cross_norm = (beta1 @ self.kernel_x(Z1, Z2) @ beta2).item()
        return lambd - (norm1 + norm2 - 2 * cross_norm)