import numpy as np
import torch
import os
from .base import CMEBase
from .neural_phi import NeuralPhi
from time import time

class NCME(CMEBase):
    """CMEBase subclass whose feature map is a ``NeuralPhi`` model.

    The attributes ``cfg``, ``X``, ``Y``, ``n``, ``lambd``, ``device``,
    ``path``, ``kernel_y`` and ``get`` are used here but not defined in this
    class; presumably they are provided by ``CMEBase`` -- confirm there.
    """

    def __init__(self, cfg, X, Y, **kwargs):
        """Initialise the base class and build the ``NeuralPhi`` feature map.

        Args:
            cfg: Configuration mapping; ``cfg['fit']`` is read later by
                ``generate_helper``.
            X: Forwarded unchanged to ``CMEBase.__init__``.
            Y: Forwarded unchanged to ``CMEBase.__init__``.
            **kwargs: Forwarded to both ``CMEBase`` and ``NeuralPhi``.
        """
        super().__init__(cfg, X, Y, **kwargs)
        self.phi = NeuralPhi(cfg, **kwargs)

    def generate_helper(self, *args, **kwargs):
        """Optionally fit ``phi``, then compute ``Q`` and ``norm``.

        Sets ``self.BX``, ``self.Q`` and ``self.norm``.

        Returns:
            The Gram matrix ``kernel_y(Y, Y)``.
        """
        start = time()
        KY = self.kernel_y(self.Y, self.Y)
        if self.cfg['fit']:
            self.phi.fit(self.X, KY)
        # The timer also covers the kernel evaluation above, and the line is
        # printed even when fitting is skipped.
        print('phi fitting time:', time() - start)

        Phi = self.phi(self.X).T    # (d, n)
        # NOTE(review): BX is not read anywhere in this class; presumably a
        # placeholder expected by CMEBase -- confirm.
        self.BX = 1
        I = torch.eye(Phi.shape[0], device=self.device)
        # Regularised (d, d) system: Phi Phi^T + n * lambda * I.
        invW = torch.linalg.inv(Phi @ Phi.T + self.n * self.lambd * I)
        self.Q = Phi.T @ invW
        # tr(Q^T KY Q), stored as a plain Python float.
        self.norm = torch.trace(self.Q.T @ KY @ self.Q).item()
        return KY

    def compare(self, ncme):
        """Return the squared distance between this estimator and ``ncme``.

              ||Psi1 Q1 - Psi2 Q2||^2
            = tr(Q1.T @ KY1 @ Q1) + tr(Q2.T @ KY2 @ Q2) - 2 * tr(Q1.T @ KY12 @ Q2)

        The first two terms are the cached ``norm`` values of the two
        objects; only the cross term is computed here.

        Args:
            ncme: The other estimator; its ``Q``, ``Y`` and ``norm`` are read.
        """
        Q1 = self.get('Q')
        Y1 = self.get('Y')
        Q2 = self.get('Q', ncme)
        Y2 = self.get('Y', ncme)
        KY12 = self.kernel_y(Y1, Y2)
        cross_norm = torch.trace(Q1.T @ KY12 @ Q2).item()
        return self.norm + ncme.norm - 2 * cross_norm

    def save(self):
        """Write ``Y``, ``Q`` and ``norm`` to ``self.path`` as an ``.npz`` archive."""
        print('save to', self.path)
        # NOTE(review): np.savez appends '.npz' when the file name lacks it;
        # confirm self.path already ends in '.npz' so load_helper finds it.
        np.savez(
            self.path,
            Y=self.Y.detach().cpu().numpy(),
            Q=self.Q.detach().cpu().numpy(),
            norm=self.norm
        )

    def load_helper(self):
        """Restore ``Y``, ``Q`` and ``norm`` from the archive at ``self.path``."""
        # NOTE(review): allow_pickle=True can execute arbitrary code when the
        # archive comes from an untrusted source. save() only writes plain
        # numeric arrays, so this could likely be tightened to False.
        # The context manager closes the underlying file handle once the
        # arrays have been read into memory.
        with np.load(self.path, allow_pickle=True) as data:
            self.Y = torch.from_numpy(data['Y']).to(self.device)
            self.Q = torch.from_numpy(data['Q']).to(self.device)
            # Convert the stored 0-d array back to a Python float so that
            # `norm` has the same type as after generate_helper().
            self.norm = float(data['norm'].item())