import torch
import numpy as np
from .flat_model_para import get_id_pos

def centerize(vecs, dim=0, keepdims=True):
    """Return ``vecs`` with its mean along ``dim`` subtracted.

    Args:
        vecs: a NumPy array or a PyTorch tensor.
        dim: axis along which the mean is taken.
        keepdims: keep the reduced axis (size 1) in the mean so it
            broadcasts back against ``vecs``.

    Returns:
        ``vecs - mean`` of the same type as the input.

    Raises:
        TypeError: if ``vecs`` is neither an ``np.ndarray`` nor a
            ``torch.Tensor`` (message: "only np.ndarray or torch.Tensor
            is supported").
    """
    if isinstance(vecs, torch.Tensor):
        # PyTorch's documented flag name is ``keepdim`` (no trailing "s").
        mean = vecs.mean(dim=dim, keepdim=keepdims)
    elif isinstance(vecs, np.ndarray):
        mean = vecs.mean(axis=dim, keepdims=keepdims)
    else:
        raise TypeError("仅支持 np.ndarray 或 torch.Tensor")
    return vecs - mean


def get_kernel(vecs_centered):
    """Return the Gram matrix of the rows of ``vecs_centered``, scaled by 1/n.

    Computes ``X @ X.T / X.shape[0]`` with ``X = vecs_centered``, where n is
    the number of rows. Works for any array type supporting ``@``, ``.T`` and
    ``.shape`` (e.g. NumPy arrays or PyTorch tensors). No centering is done
    here; the caller is expected to pass already-centered rows.
    """
    num_rows = vecs_centered.shape[0]
    gram = vecs_centered @ vecs_centered.T
    return gram / num_rows


class Kernel():
    """A square kernel (Gram) matrix with labelled rows/columns.

    ``kernel`` is assumed to be a square ``torch.Tensor`` (the methods below
    use ``.size(0)``, ``.device`` and torch linear-algebra ops on it). The
    double-centered version of the kernel is computed lazily and cached.
    """

    def __init__(self, kernel, centered_kernel=None, ids=None):
        """
        Args:
            kernel: square (n, n) kernel matrix.
            centered_kernel: optional precomputed ``center_kernel(kernel)``;
                when ``None`` it is computed on first ``get_centered_kernel``.
            ids: labels for the rows/columns of ``kernel``; defaults to
                ``list(range(len(kernel)))``.
        """
        self.kernel = kernel
        self.centered_kernel = centered_kernel
        self.ids = list(range(len(kernel))) if ids is None else ids

    @staticmethod
    def center_kernel(A):
        """Double-center a square kernel matrix.

        Input:  A (n x n) kernel A = G Gᵀ.
        Output: B (n x n) centered kernel B = H Hᵀ with H = C G, where
                C = I - 1 1ᵀ / n is the centering matrix (H is G with the
                mean over its rows removed).

        B = C A C is evaluated element-wise as
            B_ij = A_ij - mean_k A_kj - mean_l A_il + mean_kl A_kl,
        which costs O(n²) instead of the O(n³) of two dense matrix products
        and never materialises C.
        """
        col_mean = A.mean(dim=0, keepdim=True)  # (1, n): (1 1ᵀ A / n) row
        row_mean = A.mean(dim=1, keepdim=True)  # (n, 1): (A 1 1ᵀ / n) column
        return A - col_mean - row_mean + A.mean()

    def get_centered_kernel(self):
        """Return the double-centered kernel, computing and caching it once."""
        if self.centered_kernel is None:
            self.centered_kernel = self.center_kernel(self.kernel)
        return self.centered_kernel

    def get_sub(self, ids):
        """Return a new ``Kernel`` restricted to the rows/columns for ``ids``.

        NOTE(review): presumably ``get_id_pos(ids, self.ids)`` returns the
        positions of ``ids`` inside ``self.ids`` — confirm against its
        definition. The centered kernel is not carried over because centering
        depends on the selected subset.
        """
        positions = get_id_pos(ids, self.ids)
        sub_kernel = self.kernel[positions][:, positions]
        return Kernel(sub_kernel, ids=ids)

    @staticmethod
    def get_top_eigen(kernel, k=2, cpu=False, eps=1e-6):
        """Return the ``k`` largest eigenvalues and their eigenvectors.

        Args:
            kernel: square matrix; ``eigh`` treats it as symmetric/Hermitian.
            k: number of leading eigenpairs to return (``torch.topk`` raises
                if ``k`` exceeds the matrix size).
            cpu: if True, decompose with ``np.linalg.eigh`` on the CPU and
                move the results back to ``kernel``'s device.
            eps: jitter added to the diagonal on the torch path only.

        Returns:
            ``(e_values, e_vectors)``: eigenvalues in descending order, shape
            (k,), and the matching eigenvectors as columns, shape (n, k).
        """
        if cpu:
            device = kernel.device
            # detach(): NumPy cannot read a tensor that requires grad, and no
            # gradient can flow through the NumPy decomposition anyway.
            kernel_np = kernel.detach().cpu().numpy()
            eigenvalues, eigenvectors = np.linalg.eigh(kernel_np)
            eigenvalues = torch.tensor(eigenvalues, device=device)
            eigenvectors = torch.tensor(eigenvectors, device=device)
        else:
            # Diagonal jitter for numerical stability of eigh.
            # NOTE(review): this shifts every returned eigenvalue up by eps
            # relative to the cpu=True path (eigenvectors are unaffected) —
            # confirm that difference is intended.
            jitter = torch.eye(kernel.size(0), device=kernel.device) * eps
            eigenvalues, eigenvectors = torch.linalg.eigh(kernel + jitter)

        e_values, indices = torch.topk(eigenvalues, k=k)
        e_vectors = eigenvectors[:, indices]
        return e_values, e_vectors

    @staticmethod
    def get_projection(kernel, e_vectors, normalize_e=False, normalize_i=False):
        """Project every sample onto the directions encoded by ``e_vectors``.

        With K = G Gᵀ and a coefficient vector v, the feature-space direction
        is w = Gᵀ v, so sample i projects to (K v)_i.

        Args:
            kernel: (n, n) kernel matrix K.
            e_vectors: (n, k) coefficient vectors as columns (a single (n,)
                vector is also accepted).
            normalize_e: divide component j by ||w_j|| = sqrt(v_jᵀ K v_j),
                i.e. project onto the unit direction.
            normalize_i: divide row i by ||g_i|| = sqrt(K_ii), turning the
                projection into a cosine.

        Returns:
            Projections with the same shape as ``kernel @ e_vectors``.
        """
        values = kernel @ e_vectors  # (n, k) or (n,)
        if normalize_e:
            # Diagonal of Vᵀ K V only: sum_i V_ij (K V)_ij = v_jᵀ K v_j, shape
            # (k,), which broadcasts across the rows of ``values``. The full
            # (k, k) matrix does not broadcast against (n, k).
            norm_e = torch.sqrt((e_vectors * values).sum(dim=0))
            values = values / norm_e
        if normalize_i:
            norm_i = torch.sqrt(torch.diag(kernel, 0))  # (n,)
            if values.dim() > 1:
                # Per-sample norm must scale rows, so broadcast as (n, 1).
                norm_i = norm_i.unsqueeze(-1)
            values = values / norm_i
        return values

