import os
import torch
import numpy as np
from .utils import rbf_kernel
from .base import Base
from fast_pytorch_kmeans import KMeans
from time import time

class RKME(Base):
    """Reduced-set kernel embedding of the data ``X``.

    ``generate_helper`` fits ``m`` reduced points ``Z`` and weights ``beta``
    so that ``sum_j beta_j k(z_j, .)`` approximates ``(1/n) sum_i k(x_i, .)``
    under ``self.kernel_x``. ``norm`` caches ``beta^T K(Z, Z) beta`` (the
    squared norm of the reduced embedding), so :meth:`compare` only needs the
    cross term.

    ``cfg``, ``X``, ``n``, ``device``, ``path``, ``kernel_x`` and ``get`` are
    provided by ``Base`` (not visible here).
    """

    def __init__(self, cfg, X, **kwargs):
        super().__init__(cfg, X, **kwargs)

    def generate_helper(self, *args, **kwargs):
        """Fit ``Z`` and ``beta`` and cache the squared embedding norm.

        Reads ``cfg['RKME_reduced_size']`` (number of reduced points) and
        ``cfg['steps']`` (number of alternating Z/beta refinement rounds).
        Extra positional/keyword arguments are accepted and ignored.
        """
        reduced_size = self.cfg['RKME_reduced_size']
        self.__init_Z_by_kMeans(reduced_size)
        self.__update_beta()
        for _ in range(self.cfg['steps']):
            # One gradient step on Z, then re-solve beta exactly for the new Z.
            self.__update_Z()
            self.__update_beta()
        # Squared norm of sum_j beta_j k(z_j, .) is beta^T K(Z, Z) beta.
        self.norm = (self.beta @ self.KZ @ self.beta).item()

    def __init_Z_by_kMeans(self, reduced_size):
        """Initialise ``Z`` with k-means centroids of ``X`` and cache ``K(Z, Z)``."""
        # k-means cannot produce more centroids than there are samples.
        reduced_size = min(reduced_size, self.X.shape[0])
        kmeans = KMeans(n_clusters=reduced_size, mode='euclidean', max_iter=100, verbose=0)
        kmeans.fit(self.X)
        self.Z = kmeans.centroids.double()
        self.KZ = self.kernel_x(self.Z, self.Z)

    def __update_beta(self):
        """Solve for the optimal weights ``beta`` given the current ``Z``.

        For fixed ``Z`` the embedding error is quadratic in ``beta`` and is
        minimised by ``K(Z, Z) beta = K(Z, X) 1 / n``. A small ridge term
        (1e-6) keeps the system well conditioned.
        """
        KZX = self.kernel_x(self.Z, self.X)    # (m, n)
        q = KZX.sum(dim=1) / self.n            # (m,)
        ridge = 1e-6 * torch.eye(
            self.KZ.shape[0], dtype=self.KZ.dtype, device=self.KZ.device
        )
        # solve() is cheaper and more accurate than forming the explicit inverse.
        self.beta = torch.linalg.solve(self.KZ + ridge, q)

    @torch.no_grad()
    def __update_Z(self):
        """Take one gradient-descent step on ``Z`` and refresh ``K(Z, Z)``.

        Objective in ``Z`` (dropping the ``Z``-independent data-data term)::

            J(Z) = beta^T K(Z, Z) beta - (2 / n) beta^T K(Z, X) 1

        Assuming ``kernel_x`` is a Gaussian RBF (``rbf_kernel`` is imported by
        this module -- TODO confirm), ``dk(z, y)/dz = -2*gamma*(z - y)*k(z, y)``::

            dJ/dz_i = -4*gamma*beta_i * [ sum_l beta_l k(z_i, z_l) (z_i - z_l)
                                          - (1/n) sum_j k(z_i, x_j) (z_i - x_j) ]

        The constant ``gamma`` is folded into ``cfg['step_size']``.
        """
        beta = self.beta                       # (m,)
        Z = self.Z                             # (m, d)
        X = self.X.to(Z.dtype)                 # (n, d); match Z for the matmul below

        # sum_l beta_l k(z_i, z_l) (z_i - z_l); self.KZ is K(Z, Z) for the current Z.
        w_zz = self.KZ * beta                  # w_zz[i, l] = beta_l * k(z_i, z_l)
        pull_Z = w_zz.sum(dim=1, keepdim=True) * Z - w_zz @ Z

        # (1/n) sum_j k(z_i, x_j) (z_i - x_j)
        w_zx = self.kernel_x(Z, self.X) / self.n   # (m, n)
        pull_X = w_zx.sum(dim=1, keepdim=True) * Z - w_zx @ X

        # beta^T K(Z, Z) beta depends on z_i through both arguments of the
        # symmetric kernel, so its derivative gets the same factor 2 as the
        # cross term: the two pulls are weighted equally.
        grad_z = -4 * beta[:, None] * (pull_Z - pull_X)

        Z = Z - self.cfg['step_size'] * grad_z
        self.Z = Z
        self.KZ = self.kernel_x(Z, Z)

    def compare(self, rkme):
        """Return the squared RKHS distance between this embedding and ``rkme``.

        Computed as ``norm_1 + norm_2 - 2 * beta_1^T K(Z_1, Z_2) beta_2``, using
        the cached ``norm`` of each side. ``Z``/``beta`` are fetched via
        ``Base.get`` (presumably returning the attribute of ``self`` or of the
        given object -- confirm in ``Base``).
        """
        Z1 = self.get('Z')
        Z2 = self.get('Z', rkme)
        beta1 = self.get('beta').double()
        beta2 = self.get('beta', rkme).double()
        KZ12 = self.kernel_x(Z1, Z2)
        cross_norm = (beta1 @ KZ12 @ beta2).item()
        return self.norm + rkme.norm - 2 * cross_norm

    def save(self):
        """Write ``Z``, ``beta`` and ``norm`` to ``self.path`` with ``np.savez``."""
        print('save to', self.path)
        np.savez(
            self.path,
            Z=self.Z.detach().cpu().numpy(),
            beta=self.beta.detach().cpu().numpy(),
            norm=self.norm
        )

    def load_helper(self):
        """Restore ``Z``, ``beta`` and ``norm`` from the archive at ``self.path``.

        ``KZ`` is not restored; it is only needed while fitting.
        """
        # NOTE(review): nothing written by save() needs pickling; allow_pickle=True
        # lets a crafted archive execute code on load -- consider dropping it.
        # The context manager closes the underlying zip file handle.
        with np.load(self.path, allow_pickle=True) as data:
            self.Z = torch.from_numpy(data['Z']).to(self.device)
            self.beta = torch.from_numpy(data['beta']).to(self.device)
            # Store a Python float, matching generate_helper's `.item()`.
            self.norm = float(data['norm'])
