import torch
import numpy as np
import os
from .rncme import RNCME
from tqdm import trange
from time import time

class RNCMETime(RNCME):
    """RNCME variant for studying time/performance w.r.t. ``reduced_size``.

    On top of the reduced set ``V`` and weights ``R`` built by the parent
    class, it records:

    * ``time``  -- elapsed wall-clock seconds (``time.time()``) after each
      greedy step of :meth:`generate_helper`;
    * ``norms`` -- ``norms[m - 1] = tr((R[:m] / m).T @ K_V[:m, :m] @ (R[:m] / m))``
      for every prefix size ``m``, where ``K_V = kernel_y(V, V)``.
    """

    def __init__(self, cfg, X, Y, **kwargs):
        super().__init__(cfg, X, Y, **kwargs)
        # Per-step elapsed times; filled by generate_helper or load_helper.
        self.time = []

    @staticmethod
    def _prefix_traces(RA, K, RB, max_m):
        """Return ``[tr((RA[:m] / m).T @ K[:m, :m] @ (RB[:m] / m)) for m = 1..max_m]``.

        Used for both the self norms (generate_helper) and the cross terms
        (compare). Both sides of the distance are therefore computed with
        identical arithmetic, which matters because compare subtracts
        nearly equal quantities.
        """
        traces = []
        for m in range(1, 1 + max_m):
            RA_m = RA[:m] / m
            RB_m = RB[:m] / m
            traces.append(torch.trace(RA_m.T @ K[:m, :m] @ RB_m).item())
        return traces

    def compare(self, rncme, **kwargs):
        """Squared distance between this embedding and ``rncme``'s, per prefix size.

            ||PsiV1 R1 - PsiV2 R2||^2
          = tr(R1.T @ KV1 @ R1) + tr(R2.T @ KV2 @ R2) - 2 * tr(R1.T @ KV12 @ R2)

        The first two terms come from each side's precomputed ``norms``. The
        cross term (with ``R[:m] / m`` on both sides) is computed here and
        cached in ``logs/cross_norms/<task>/U<user_id>_L<learnware_id>.npy``.

        Keyword Args:
            user_id, learnware_id: identify the cache file (required).
            max_m: number of prefix sizes to evaluate (default 30000). It is
                clamped to the number of prefixes available in both
                ``self.norms`` and ``rncme.norms``.

        Returns:
            np.ndarray of length ``max_m``. Entry ``i`` is the squared
            distance using the first ``i + 1`` reduced points of each side.
        """
        cross_norms_path = os.path.join(
            'logs', 'cross_norms', self.cfg['task'],
            f'U{kwargs["user_id"]}_L{kwargs["learnware_id"]}.npy',
        )
        # Only prefixes with a precomputed norm on both sides are meaningful.
        # Larger m used to cache cross terms built from truncated R slices
        # and then raise IndexError on norms[m].
        max_m = min(kwargs.get('max_m', 30000), len(self.norms), len(rncme.norms))

        cross_norms = None
        if os.path.exists(cross_norms_path):
            cached = np.load(cross_norms_path)
            # The cache is not keyed by max_m; reuse it only if it is long enough.
            if len(cached) >= max_m:
                cross_norms = cached
        if cross_norms is None:
            os.makedirs(os.path.dirname(cross_norms_path), exist_ok=True)
            R1 = self.get('R')
            V1 = self.get('V')
            R2 = self.get('R', rncme)
            V2 = self.get('V', rncme)
            KV12 = self.kernel_y(V1, V2)
            cross_norms = np.array(self._prefix_traces(R1, KV12, R2, max_m))
            np.save(cross_norms_path, cross_norms)

        # Same float64 operations, in the same order, as the per-element
        # expression a + b - 2 * c.
        return (
            np.asarray(self.norms[:max_m])
            + np.asarray(rncme.norms[:max_m])
            - 2 * np.asarray(cross_norms[:max_m])
        )

    def save(self):
        """Write R, V, norms and time to ``self.path`` with ``np.savez``."""
        print('save to', self.path)
        np.savez(
            self.path,
            R=self.R.detach().cpu().numpy(),
            V=self.V.detach().cpu().numpy(),
            norms=self.norms,
            time=self.time,
        )

    def load_helper(self):
        """Restore R, V (moved to ``self.device``), norms and time from ``self.path``."""
        # The context manager closes the underlying NpzFile. Arrays read from
        # it are materialized in memory and remain valid after closing.
        # NOTE(review): allow_pickle=True is kept as before. save() only writes
        # numeric arrays, so it is likely unnecessary; only load trusted files.
        with np.load(self.path, allow_pickle=True) as data:
            self.R = torch.from_numpy(data['R']).to(self.device)
            self.V = torch.from_numpy(data['V']).to(self.device)
            self.norms = data['norms']
            self.time = data['time']

    def generate_helper(self, *args, **kwargs):
        """Greedily build V, R = argmin_{V, R} ||Psi Q - Psi_V R||^2, timing each step.

        Runs ``cfg['reduced_size']`` greedy steps using the parent RNCME's
        private routines; each step appends one row to ``V`` and one to ``R``.
        After each step, the elapsed wall-clock time is appended to
        ``self.time``. Finally, ``self.norms`` is filled for every prefix size.
        """
        # Fresh record: repeated calls would otherwise accumulate timings, and
        # after load_helper self.time is an ndarray, which has no append().
        self.time = []
        start_time = time()
        # Skips RNCME.generate_helper and runs the next class in the MRO
        # (presumably because the greedy loop below replaces RNCME's version).
        super(RNCME, self).generate_helper(*args, **kwargs)
        # NOTE(review): reduced_size is NOT clamped to self.Y.shape[0]. The
        # former min(...) line was immediately overwritten, so dropping it
        # keeps the effective behavior. Confirm whether clamping is wanted.
        reduced_size = self.cfg['reduced_size']

        if self.task == 'regression':
            self._RNCME__init_regression(reduced_size)
        else:
            self._RNCME__init_classification()

        for t in trange(reduced_size):
            v, Wt_psiv, norm = self._RNCME__find_next_v(t)
            r = self._RNCME__find_next_r(Wt_psiv, norm)
            self.V = torch.cat([self.V, v.unsqueeze(0).detach()], dim=0)
            self.R = torch.cat([self.R, r.unsqueeze(0).detach()], dim=0)
            self.time.append(time() - start_time)

        # R itself is not normalized; each prefix R[:m] is scaled by 1/m instead.
        KV = self.kernel_y(self.V, self.V)
        self.norms = self._prefix_traces(self.R, KV, self.R, reduced_size)