import torch
import numpy as np
from .base import CMEBase

class CME(CMEBase):
    """Kernel operator estimate of the form ``Psi @ W @ Phi.T``.

    ``Phi`` / ``Psi`` are the implicit feature maps of the ``X`` / ``Y``
    samples and ``W = (KX + n * lambd * I)^-1``, so only Gram matrices and
    the dense (n, n) matrix ``W`` are ever materialised.  The class name
    suggests a conditional mean embedding (TODO confirm).

    NOTE(review): ``n``, ``lambd``, ``device``, ``path``, ``kernel_x``,
    ``kernel_y`` and ``get`` are used here but defined elsewhere,
    presumably by ``CMEBase`` — confirm.
    """

    # Maximum number of leading samples kept.  ``W`` is a dense (n, n)
    # inverse, so memory grows as O(n^2) and the inversion as O(n^3); the
    # specific value presumably fits the target hardware.  Subclasses may
    # override it.
    MAX_SAMPLES = 30000

    def __init__(self, cfg, X, Y, **kwargs):
        """Build the estimator from at most ``MAX_SAMPLES`` rows of X and Y."""
        n_max = self.MAX_SAMPLES
        super().__init__(cfg, X[:n_max], Y[:n_max], **kwargs)

    @staticmethod
    def _trace_of_product(A, B):
        """Return ``tr(A @ B)`` (as a 0-d tensor) without forming ``A @ B``.

        Uses ``tr(A @ B) = sum_ij A_ij * B_ji``; ``A`` is (p, q), ``B`` is (q, p).
        """
        return torch.sum(A * B.transpose(0, 1))

    def generate_helper(self, *args, **kwargs):
        """Fit ``W`` and cache the squared norm of the operator.

        Sets ``self.W = (KX + n * lambd * I)^-1`` and
        ``self.norm = tr(W KY W KX)`` as a Python float.  For a symmetric
        kernel ``W`` is symmetric, so this equals ``tr(W.T KY W KX)``.
        Extra positional/keyword arguments are accepted and ignored.
        """
        KX = self.kernel_x(self.X, self.X)  # (n, n)
        KY = self.kernel_y(self.Y, self.Y)  # (n, n)
        I = torch.eye(self.n, device=self.device)
        self.W = torch.linalg.inv(KX + self.n * self.lambd * I)
        # tr((W KY) (W KX)): two matmuls instead of the three needed to
        # build the whole chain W @ KY @ W @ KX.
        self.norm = self._trace_of_product(self.W @ KY, self.W @ KX).item()

    def compare(self, cme):
        """Squared distance between this operator and another ``CME``.

              ||Psi1 W1 Phi1.T - Psi2 W2 Phi2.T||^2
            = tr(W1 KY1 W1 KX1) + tr(W2 KY2 W2 KX2) - 2 tr(W1 KY12 W2 KX21)
            = norm1 + norm2 - 2 * norm12

        Args:
            cme: the other estimator; its ``norm`` must already be computed
                (via ``generate_helper`` or ``load_helper``).

        Returns:
            The squared distance as a float.  Rounding error can make it
            slightly negative when the two operators nearly coincide.
        """
        X1 = self.get('X')              # (N1, dx)
        Y1 = self.get('Y')              # (N1, dy)
        W1 = self.get('W')              # (N1, N1)
        X2 = self.get('X', cme)         # (N2, dx)
        Y2 = self.get('Y', cme)         # (N2, dy)
        W2 = self.get('W', cme)         # (N2, N2)
        KX21 = self.kernel_x(X2, X1)    # (N2, N1)
        KY12 = self.kernel_y(Y1, Y2)    # (N1, N2)
        # tr((W1 KY12) (W2 KX21)) without forming the final (N1, N1) product.
        cross_norm = self._trace_of_product(W1 @ KY12, W2 @ KX21).item()
        return self.norm + cme.norm - 2 * cross_norm

    def save(self):
        """Write ``X``, ``Y``, ``W`` and ``norm`` to ``self.path`` as an .npz."""
        # NOTE(review): np.savez appends '.npz' when self.path lacks it, but
        # np.load in load_helper opens self.path verbatim — confirm that
        # self.path already ends in '.npz'.
        print('save to', self.path)
        np.savez(
            self.path,
            X=self.X.detach().cpu().numpy(),
            Y=self.Y.detach().cpu().numpy(),
            W=self.W.detach().cpu().numpy(),
            norm=self.norm
        )

    def load_helper(self):
        """Restore ``X``, ``Y``, ``W`` and ``norm`` written by ``save``."""
        # save() stores only numeric arrays, so pickle support is not needed
        # (and leaving it off avoids running code from a crafted file).  The
        # context manager closes the underlying zip file handle.
        with np.load(self.path) as data:
            self.X = torch.from_numpy(data['X']).to(self.device)
            self.Y = torch.from_numpy(data['Y']).to(self.device)
            self.W = torch.from_numpy(data['W']).to(self.device)
            # Stored as a 0-d array; convert back to the Python float that
            # generate_helper produces so compare() keeps returning a float.
            self.norm = float(data['norm'])