import torch
import numpy as np
from .ncme import NCME
from fast_pytorch_kmeans import KMeans
from tqdm import trange
from time import time

class RNCME(NCME):
    """NCME variant represented by a reduced set of points ``V`` and coefficients ``R``.

    ``generate_helper`` greedily builds ``V`` (one row per selected point) and
    ``R`` (one coefficient row per selected point), and ``norm`` stores
    ``tr(R.T @ K(V, V) @ R)``.

    The attributes ``cfg``, ``task``, ``device``, ``path``, ``Y``, ``Q``,
    ``BX``, ``lambd``, ``classes``, ``classes_onehot``, ``n_classes`` and the
    methods ``get`` / ``kernel_y`` are not defined in this class; they are
    presumably provided by ``NCME`` -- verify against the base class.
    """

    def __init__(self, cfg, X, Y, **kwargs):
        """Forward all arguments unchanged to the ``NCME`` constructor."""
        super().__init__(cfg, X, Y, **kwargs)

    def compare(self, rncme, **kwargs):
        """Return the squared distance between this embedding and ``rncme``.

            ||PsiV1 R1 - PsiV2 R2||^2
          = tr(R1.T @ KV1 @ R1) + tr(R2.T @ KV2 @ R2) - 2 * tr(R1.T @ KV12 @ R2)

        The first two traces are the cached ``norm`` values of the two models;
        only the cross term is computed here.

        Args:
            rncme: the other model; must expose ``norm`` and be readable
                through ``self.get(name, rncme)``.
            **kwargs: accepted for interface compatibility; unused here.

        Returns:
            The scalar squared distance.
        """
        R1 = self.get('R')
        V1 = self.get('V')
        R2 = self.get('R', rncme)
        V2 = self.get('V', rncme)
        KV12 = self.kernel_y(V1, V2)
        cross_norm = torch.trace(R1.T @ KV12 @ R2).item()
        return self.norm + rncme.norm - 2 * cross_norm

    def save(self):
        """Write ``R``, ``V`` and ``norm`` to ``self.path`` with ``np.savez``."""
        print('save to', self.path)
        np.savez(
            self.path,
            R=self.R.detach().cpu().numpy(),
            V=self.V.detach().cpu().numpy(),
            norm=self.norm,
        )

    def load_helper(self):
        """Restore ``R``, ``V`` and ``norm`` from the archive at ``self.path``.

        ``R`` and ``V`` are moved to ``self.device``; ``norm`` is stored as a
        Python float, the same type ``generate_helper`` produces.
        """
        # NOTE(review): allow_pickle=True permits arbitrary code execution when
        # loading an untrusted file. ``save`` only writes numeric arrays, so
        # pickling looks unnecessary -- consider allow_pickle=False once it is
        # confirmed that no existing archives rely on it.
        # The context manager closes the underlying file handle of the NpzFile.
        with np.load(self.path, allow_pickle=True) as data:
            self.R = torch.from_numpy(data['R']).to(self.device)
            self.V = torch.from_numpy(data['V']).to(self.device)
            # np.savez stores the scalar as a 0-d array; unwrap it to a float.
            self.norm = float(data['norm'].item())

    def generate_helper(self, *args, **kwargs):
        """ V, R = argmin_{V, R} ||Psi V - Psi R||^2

        Greedily selects ``reduced_size`` points: each iteration picks the next
        point ``v`` and its coefficient row ``r`` and appends them to
        ``self.V`` / ``self.R``. Afterwards ``R`` is divided by
        ``reduced_size``, ``self.norm`` is computed and the reconstruction
        error against the original (``Q``, ``Y``) representation is printed.

        ``*args`` / ``**kwargs`` are accepted but not forwarded to the parent.
        """
        KY = super().generate_helper()
        # Never select more points than KY has rows.
        reduced_size = min(self.cfg['reduced_size'], KY.shape[0])

        if self.task == 'regression':
            self.__init_regression(reduced_size)
        else:
            self.__init_classification()

        for t in trange(reduced_size):
            v, Wt_psiv, norm = self.__find_next_v(t)
            r = self.__find_next_r(Wt_psiv, norm)
            # Detach so the stored history carries no autograd graph.
            self.V = torch.cat([self.V, v.unsqueeze(0).detach()], dim=0)
            self.R = torch.cat([self.R, r.unsqueeze(0).detach()], dim=0)

        self.R /= reduced_size
        KV = self.kernel_y(self.V, self.V)
        self.norm = torch.trace(self.R.T @ KV @ self.R).item()

        # Eval: ||Psi_V R - Psi_Y Q||^2 expanded into norms and a cross term.
        ori_norm = torch.trace(self.Q.T @ KY @ self.Q).item()  # Q.T @ KY @ Q
        KVY = self.kernel_y(self.V, self.Y)
        cross_norm = torch.trace(self.R.T @ KVY @ self.Q).item()
        print(f'Reconstruct error: {self.norm + ori_norm - 2 * cross_norm:.4f}')

    def __init_classification(self):
        """Prepare the state used by the classification search.

        Caches ``term_1 = K(classes_onehot, Y) @ Q`` (it does not depend on the
        iteration) and creates empty ``K_VV``, ``V`` and ``R`` buffers that
        grow by one column / row per selected point.
        """
        K_VY = self.kernel_y(self.classes_onehot, self.Y)  # (n_classes, n)
        self.term_1 = K_VY @ self.Q  # (nv, d)
        # note that len(self.classes) != self.n_classes
        self.K_VV = torch.empty(len(self.classes), 0, device=self.device)
        self.V    = torch.empty(0, self.n_classes, device=self.device)
        self.R    = torch.empty(0, self.cfg['phi_output_dim'], device=self.device)

    def __init_regression(self, n_clusters):
        ''' initialize V by k-means

        Creates empty ``V`` / ``R`` buffers and stores the k-means centroids of
        ``self.Y`` (cast to double) in ``self.init_V``; centroid ``t`` is the
        starting point of the search in iteration ``t``.

        NOTE(review): ``V`` is created with width 1, so the later ``torch.cat``
        in ``generate_helper`` presumably requires scalar regression targets.
        '''
        self.V = torch.empty(0, 1, device=self.device)
        self.R = torch.empty(0, self.cfg['phi_output_dim'], device=self.device)
        kmeans = KMeans(
            n_clusters=n_clusters,
            mode='euclidean',
            max_iter=100,
            verbose=0
        )
        kmeans.fit(self.Y)
        self.init_V = kmeans.centroids.double()

    def __find_next_v(self, t):
        """Select the next point ``v`` maximising ``||W_t psi(v)||^2``.

        Classification: exhaustive argmax over the rows of
        ``self.classes_onehot``.
        Regression: ``cfg['steps']`` clipped gradient-ascent steps starting
        from ``self.init_V[t]``.

        Args:
            t: zero-based iteration index (number of points already selected).

        Returns:
            ``(v, Wt_psiv, norm)`` where ``Wt_psiv`` is evaluated at the
            returned ``v`` and ``norm`` is its Euclidean norm (a float).
        """
        if self.task == 'classification':
            V = self.classes_onehot              # (nv, d)
            Wt_psiV = self.__Wt_psiV(t, V)       # (nv, d)
            norms = torch.sum(Wt_psiV**2, dim=1) # length: nv
            v_idx = torch.argmax(norms).item()
            v = V[v_idx]
            Wt_psiv = Wt_psiV[v_idx]
            norm = norms[v_idx].item()
            return v, Wt_psiv, norm**0.5

        v = self.init_V[t].to(self.device).requires_grad_()
        max_norm = 0.1
        for _ in range(self.cfg['steps']):
            loss = -self.__Wt_psiv_2(t, v)
            # retain_graph is kept from the original code in case kernel_y
            # reuses graph-attached tensors between calls.
            grad = torch.autograd.grad(loss, v, retain_graph=True)[0]
            grad_norm = torch.norm(grad)
            # Clip the gradient to length max_norm.
            if grad_norm > max_norm:
                grad = grad * (max_norm / grad_norm.item())
            # The step size is bounded using the pre-clipping gradient norm.
            adaptive_lr = min(self.cfg['step_size'], 0.1 / (grad_norm.item() + 1e-8))
            # Re-create v as a fresh leaf so the graph does not chain across
            # steps; the gradient w.r.t. v is unaffected by v's history.
            v = (v - adaptive_lr * grad).detach().requires_grad_()

        # Evaluate at the final v so the vector and its norm are consistent
        # (previously the norm came from the loss before the last step, and
        # `loss` was undefined when cfg['steps'] == 0).
        v = v.detach()
        with torch.no_grad():
            Wt_psiv = self.__Wt_psiv(t, v)
        norm = torch.sum(Wt_psiv**2).item() ** 0.5
        return v, Wt_psiv, norm

    def __find_next_r(self, Wt_psiv, norm):
        """ r = (B_X / lambda) * (Wt_psiv / norm), B_X = 1

        The code uses ``self.BX`` and ``self.lambd`` for ``B_X`` and
        ``lambda``; ``norm`` is expected to be the norm of ``Wt_psiv``.
        """
        return self.BX * Wt_psiv / (norm * self.lambd)

    @torch.no_grad()
    def __Wt_psiV(self, t, V):
        """
            For classification task which has determined label space
            _V: the candidate set of next v  (nv, d)
            W_t Psi(_V) in R^{nv * d}
            (W_t Psi(_V))_{kl} = sum_{i=1}^n q_{il} ky(y_i, _v_k) - 1/t sum_{j=1}^t r_{jl} ky(v_j, _v_k)

            Side effect: for t > 0 the kernel column of the most recently
            selected point is appended to ``self.K_VV``, so this must be called
            exactly once per iteration.
        """
        term_2 = 0
        if t > 0:
            # Must be 2-D, presumably (nv, 1), to be appended as a column
            # below -- TODO confirm kernel_y's output shape for a 1-D argument.
            K_Vv = self.kernel_y(V, self.V[-1])
            self.K_VV = torch.cat([self.K_VV, K_Vv], dim=1) # (nv, t)
            term_2 = self.K_VV @ self.R / t                 # (nv, d)
        return self.term_1 - term_2                         # (nv, d)

    def __Wt_psiv(self, t, v):
        """Evaluate ``W_t psi(v)`` for a single candidate ``v``.

        Computes ``Q.T @ K(Y, v)`` and, for ``t > 0``, subtracts
        ``R.T @ K(V, v) / t``. Size-1 dimensions are squeezed from the result.
        """
        KYv = self.kernel_y(self.Y, v)
        term_1 = self.Q.T @ KYv             # (d, 1)
        term_2 = 0
        if t > 0:
            KVv = self.kernel_y(self.V, v)  # (t, 1)
            term_2 = (self.R.T @ KVv) / t   # (d, 1)
        return (term_1 - term_2).squeeze()  # size-1 dims removed, presumably (d,)

    def __Wt_psiv_2(self, t, v):
        """Return the squared Euclidean norm of ``W_t psi(v)``."""
        return torch.sum(self.__Wt_psiv(t, v)**2)