from fast_pytorch_kmeans import KMeans
import torch
import numpy as np
from .base import CMEBase
from time import time

class LabelRKME(CMEBase):
    """Label-aware reduced kernel mean embedding of a labelled dataset (X, Y).

    The data are summarised by ``reduced_size`` points ``Z`` with labels ``y`` and
    weights ``beta``. Pairs (point, label) are compared with the product kernel

        k_aug((a, ya), (b, yb)) = kernel_x(a, b) * (1 + [ya == yb]),

    so same-label pairs weigh twice as much as cross-label pairs. ``beta`` is the
    (ridge-regularised) closed-form minimiser of the squared distance between
    sum_i beta_i * k_aug((z_i, y_i), .) and (1/n) * sum_j k_aug((x_j, Y_j), .);
    ``Z`` is refined by gradient steps on that same objective.
    """

    def __init__(self, cfg, X, Y, **kwargs):
        """Store the data sorted by label and read the reduction budget.

        Args:
            cfg: config mapping; ``cfg['reduced_size']`` is read here, and
                ``cfg['steps']`` / ``cfg['step_size']`` during generation.
            X: feature tensor, rows aligned with ``Y``.
            Y: label tensor (presumably 1-D, one label per row of ``X``).
        """
        super().__init__(cfg, X, Y, **kwargs)
        self.reduced_size = cfg['reduced_size']
        sorted_indices = torch.argsort(Y)
        self.X = X[sorted_indices].to(self.device)
        self.Y = Y[sorted_indices].to(self.device)
        self.class_ratios = None

    def generate_helper(self, *args, **kwargs):
        """Build the reduced embedding: KMeans init, then alternate Z / beta updates."""
        start = time()
        self.__init_Z_by_kMeans()
        self.__update_beta()
        for _ in range(self.cfg['steps']):
            self.__update_Z()
            self.__update_beta()
        print('Reduction time:', time() - start)

    @staticmethod
    def __label_match(rows, cols, dtype):
        """Return the label-kernel factor ``1 + [rows[i] == cols[j]]`` as a matrix."""
        return (rows[:, None] == cols[None, :]).to(dtype) + 1.0

    def __init_Z_by_kMeans(self):
        """Initialize Z (and its labels y) with per-class KMeans centroids.

        The budget ``reduced_size`` is split across classes proportionally to their
        frequency in ``Y`` (truncated to integers, remainder given to the most frequent
        class). A class whose share truncates to zero contributes no reduced points.
        """
        self.classes, counts = torch.unique(self.Y, return_counts=True)
        class_ratios = counts / counts.sum()
        sample_per_class = (class_ratios * self.reduced_size).long()
        discrepancy = self.reduced_size - sample_per_class.sum()
        if discrepancy != 0:
            max_count_idx = torch.argmax(counts)
            sample_per_class[max_count_idx] += discrepancy
            assert sample_per_class[max_count_idx] > 0
        assert self.reduced_size == sample_per_class.sum().item()
        self.class_ratios = sample_per_class / self.reduced_size

        centroids_per_class = []
        for cls, num_samples in zip(self.classes, sample_per_class.tolist()):
            if num_samples == 0:
                # KMeans cannot fit zero clusters; this class gets no reduced points.
                continue
            kmeans = KMeans(n_clusters=num_samples, mode="euclidean", max_iter=100, verbose=0)
            kmeans.fit(self.X[self.Y == cls])
            centroids_per_class.append(kmeans.centroids)

        self.Z = torch.cat(centroids_per_class, dim=0).double()
        # Labels of the reduced points, in the same class-sorted order as Z.
        self.y = torch.repeat_interleave(self.classes, sample_per_class)

    def __update_beta(self):
        """Recompute beta in closed form for the current Z.

        Solves ``(K_aug(Z, Z) + 1e-5 * I) beta = mean_j K_aug(Z, x_j)`` and refreshes
        ``self.KZ`` (= K_aug(Z, Z)) and ``self.norm`` (= beta^T K_aug(Z, Z) beta) so that
        Z, KZ, beta and norm always describe the same embedding. Idempotent: the
        kernel is rebuilt from Z rather than modified in place.
        """
        Z, y = self.Z, self.y
        # Label masks compare labels directly, so rows/columns need not share class blocks.
        KZ = self.kernel_x(Z, Z) * self.__label_match(y, y, Z.dtype)
        KZX = self.kernel_x(Z, self.X) * self.__label_match(y, self.Y, Z.dtype)
        C = torch.sum(KZX, dim=1) / self.n

        # Small ridge term keeps the system well conditioned.
        ridge = 1e-5 * torch.eye(KZ.shape[0], dtype=KZ.dtype, device=KZ.device)
        beta = torch.linalg.solve(KZ + ridge, C)

        self.KZ = KZ
        self.beta = beta
        self.norm = (beta @ KZ @ beta).item()

    def __update_Z(self):
        """Take one gradient-descent step on Z with beta held fixed.

        Vectorised form of, for each reduced point i,
            term_1 = 2 * sum_k beta_k m_ik k(z_i, z_k) (z_i - z_k)
            term_2 = -2/n * sum_j m_ij k(z_i, x_j) (z_i - x_j)
            grad_i = -2 * beta_i * (term_1 + term_2)
        where m is the label factor 1 + [labels equal].
        """
        Z, beta, y = self.Z, self.beta, self.y
        # The matmul below needs X in Z's dtype (elementwise ops would have promoted it).
        X = self.X.to(Z.dtype)

        W_zz = beta[None, :] * self.__label_match(y, y, Z.dtype) * self.kernel_x(Z, Z)
        W_zx = self.__label_match(y, self.Y, Z.dtype) * self.kernel_x(Z, self.X) / self.n

        # sum_k W[i, k] * (z_i - a_k) == W.sum(1)[i] * z_i - (W @ A)[i]
        term_1 = 2 * (W_zz.sum(dim=1, keepdim=True) * Z - W_zz @ Z)
        term_2 = -2 * (W_zx.sum(dim=1, keepdim=True) * Z - W_zx @ X)
        grad_Z = -2 * beta[:, None] * (term_1 + term_2)

        self.Z = Z - self.cfg['step_size'] * grad_Z

    def save(self):
        """Write Z, KZ, beta, y, classes, class_ratios and norm to ``self.path`` (npz)."""
        print('save to', self.path)
        np.savez(
            self.path,
            Z=self.Z.detach().cpu().numpy(),
            KZ=self.KZ.detach().cpu().numpy(),
            beta=self.beta.detach().cpu().numpy(),
            y=self.y.detach().cpu().numpy(),
            classes=self.classes.tolist(),
            class_ratios=self.class_ratios.tolist(),
            norm=self.norm
        )

    def load_helper(self):
        """Restore the fields written by ``save`` from ``self.path``.

        Tensors are placed on ``self.device``; ``classes`` and ``class_ratios`` are
        kept as numpy arrays, and ``norm`` is returned as a Python float.
        """
        # Context manager closes the underlying npz file handle.
        with np.load(self.path) as data:
            self.Z = torch.tensor(data['Z'], device=self.device)
            self.KZ = torch.tensor(data['KZ'], device=self.device)
            self.beta = torch.tensor(data['beta'], device=self.device)
            self.y = torch.tensor(data['y'], device=self.device)
            self.classes = data['classes']
            self.class_ratios = data['class_ratios']
            self.norm = data['norm'].item()
