import torch
import numpy as np
from sympy import * 

# Module-level default device string ("cuda:0" = first CUDA GPU).
# NOTE(review): hard-coded; any tensor moved here with `.to(device)` requires a
# CUDA-enabled torch build and a GPU at runtime.
device = "cuda:0"
class Kernel_Lib:
    """Kernel functions computed from pairwise Gram / distance matrices.

    All methods operate on ``torch`` tensors and never modify their inputs.
    """

    def __init__(self):
        super().__init__()

    @staticmethod
    def _arccos_step(Sigma):
        """Apply one layer of the normalized arc-cosine (ReLU) recursion.

        Each entry is turned into a correlation
        ``rho_ij = Sigma_ij / sqrt(Sigma_ii) / sqrt(Sigma_jj)``, clamped to
        [-1, 1] so ``arccos`` stays real, and mapped through the degree-0 and
        degree-1 arc-cosine formulas.

        Args:
            Sigma: square covariance matrix of the current layer.

        Returns:
            ``(Sigma_dot, Sigma_next)`` with
            ``Sigma_dot = (pi - arccos(rho)) / pi`` and
            ``Sigma_next = (rho * (pi - arccos(rho)) + sqrt(1 - rho**2))
            * sqrt(Sigma_ii) * sqrt(Sigma_jj) / pi``.
        """
        norm = torch.sqrt(torch.diag(Sigma))
        row = norm.unsqueeze(1)  # (N, 1): broadcasts sqrt(Sigma_ii) along rows
        col = norm.unsqueeze(0)  # (1, N): broadcasts sqrt(Sigma_jj) along columns
        # Out-of-place clamp: no aliasing, and the caller's Sigma is untouched.
        rho = (Sigma / row / col).clamp(-1.0, 1.0)
        angle_term = torch.pi - torch.arccos(rho)
        Sigma_dot = angle_term / torch.pi
        Sigma_next = (rho * angle_term + torch.sqrt(1 - rho * rho)) * row * col / torch.pi
        return Sigma_dot, Sigma_next

    def NTK_goon(self, K, Sigma, depth, fix=False):
        """Continue the NTK recursion from an existing ``(K, Sigma)`` state.

        Applies ``depth - 1`` layers of ``K <- Sigma_dot * K + Sigma_next``
        (see ``_arccos_step``); ``depth <= 1`` applies none.

        Args:
            K: current NTK matrix.
            Sigma: current covariance matrix (same shape as ``K``).
            depth: number of layers plus one.
            fix: if True, subtract the final ``Sigma`` from the result.

        Returns:
            The updated NTK matrix. ``K`` and ``Sigma`` are not modified.
        """
        for _ in range(depth - 1):
            Sigma_dot, Sigma = self._arccos_step(Sigma)
            K = Sigma_dot * K + Sigma
        if fix:
            # Out-of-place on purpose: with depth <= 1, K is still the caller's
            # tensor, and ``K -= Sigma`` would overwrite it.
            K = K - Sigma
        return K

    def NTK(self, inner_matrix, depth=5, fix=False):
        """Compute the ReLU-network NTK from a Gram (inner-product) matrix.

        Starts from ``Sigma = inner_matrix`` and ``K = inner_matrix`` (or
        ``K = 0`` when ``fix``), then applies ``depth - 1`` arc-cosine layers.

        NOTE(review): a C-style version of this step divides both terms by
        9.424777960769379 (3*pi) rather than pi; this code keeps pi. Confirm
        which normalization is intended.

        Args:
            inner_matrix: square Gram matrix of the inputs.
            depth: number of layers plus one.
            fix: if True, start ``K`` at zero and subtract the final ``Sigma``.

        Returns:
            The NTK matrix, on the same device as ``inner_matrix``.
            ``inner_matrix`` is not modified.
        """
        Sigma = inner_matrix.clone()  # never hand the caller's tensor back
        # zeros_like follows inner_matrix's device/dtype instead of a hard-coded GPU.
        K = torch.zeros_like(inner_matrix) if fix else Sigma
        return self.NTK_goon(K, Sigma, depth, fix=fix)

    def Gaussian_Kernel(self, sub_mode_matrix, gama=1):
        """Return ``exp(-gama * sub_mode_matrix)`` element-wise.

        Args:
            sub_mode_matrix: pairwise distance matrix (presumably squared
                Euclidean distances -- TODO confirm with callers).
            gama: scale applied before exponentiation.
        """
        return torch.exp(-gama * sub_mode_matrix)

    def Poly_Kernel(self, inner_matrix, k=2, c=1):
        """Return the polynomial kernel ``(inner_matrix + c) ** k`` element-wise."""
        return (inner_matrix + c) ** k

    def un_NTK(self, inner_matrix, depth=5):
        """Run the NTK-style recursion without diagonal normalization.

        Entries are only clamped to [-1, 1] before ``arccos``; there is no
        division by the diagonal. This presumably expects an already
        normalized Gram matrix -- TODO confirm with callers. Both terms carry
        an extra factor 1/2 compared with ``NTK``.

        Args:
            inner_matrix: square Gram matrix.
            depth: number of layers plus one.

        Returns:
            The resulting kernel matrix. ``inner_matrix`` is not modified.
        """
        K = inner_matrix.clone()  # the initial K is the raw, unclipped input
        Sigma = inner_matrix
        for _ in range(depth - 1):
            # Clamp into a new tensor, so clipping only affects the arccos input
            # and never leaks into K.
            rho = Sigma.clamp(-1.0, 1.0)
            angle_term = torch.pi - torch.arccos(rho)
            Sigma_dot = angle_term / torch.pi / 2
            Sigma = (rho * angle_term + torch.sqrt(1 - rho * rho)) / torch.pi / 2
            K = Sigma_dot * K + Sigma
        return K