from Functions_gpu import Kernel_matrix, LG_sym, calc_differential_vec
import cupy as cp
import numpy as np
from cuml.model_selection import GridSearchCV

def Differential_method(X1, X2, K=200, k=30):
    """Compute differential vectors between two datasets in both directions.

    A kernel matrix is built for each input with ``Kernel_matrix``, each
    kernel is passed through ``LG_sym``, and ``calc_differential_vec`` is
    then evaluated crosswise: the first output of ``LG_sym`` for one
    dataset is combined with the third output for the other dataset.

    Args:
        X1: First dataset, forwarded unchanged to ``Kernel_matrix``
            (no conversion to a CuPy array is performed here).
        X2: Second dataset, forwarded unchanged to ``Kernel_matrix``.
        K: Second argument of ``Kernel_matrix``
            (presumably a neighbourhood size -- confirm in Functions_gpu).
        k: Third argument of ``calc_differential_vec``
            (presumably the number of components kept -- confirm).

    Returns:
        Tuple ``(s1, u1, s2, u2)``. ``(s1, u1)`` comes from the ``LG_sym``
        matrix of ``X1`` paired with the vectors of ``X2``; ``(s2, u2)`` is
        the mirrored combination.
    """
    kernel_a = Kernel_matrix(X1, K)
    kernel_b = Kernel_matrix(X2, K)

    # The middle return value of LG_sym is not needed by this routine.
    lap_a, _, vecs_a = LG_sym(kernel_a)
    lap_b, _, vecs_b = LG_sym(kernel_b)

    # The first column of each vector matrix is dropped before the cross
    # evaluation (presumably the trivial leading eigenvector -- confirm
    # against LG_sym).
    s2, u2 = calc_differential_vec(lap_b, vecs_a[:, 1:], k)
    s1, u1 = calc_differential_vec(lap_a, vecs_b[:, 1:], k)

    return s1, u1, s2, u2

def Shared_space(X1, X2, K=200, k0=None):
    """Compute the eigen-decomposition of the combined normalised kernels.

    Each kernel ``Ki = Kernel_matrix(Xi, K)`` is scaled symmetrically by
    the inverse square root of its row sums, ``Pi = Di @ Ki @ Di`` with
    ``Di = diag(rowsum(Ki) ** -0.5)``. The operator
    ``P_theta = P1 @ P2 + P2 @ P1`` is then decomposed with
    ``cp.linalg.eigh``.

    Args:
        X1: First dataset, forwarded unchanged to ``Kernel_matrix``.
        X2: Second dataset, forwarded unchanged to ``Kernel_matrix``.
        K: Second argument of ``Kernel_matrix``
            (presumably a neighbourhood size -- confirm in Functions_gpu).
        k0: If not ``None``, keep only the ``k0`` leading eigenpairs.

    Returns:
        Tuple ``(P_theta, d, v)``: the combined operator, its eigenvalues
        sorted in descending order, and the eigenvectors as columns in the
        same order (both truncated to ``k0`` entries when ``k0`` is given).
    """
    K1 = Kernel_matrix(X1, K)
    K2 = Kernel_matrix(X2, K)

    # D @ K @ D with a diagonal D only rescales row i and column j of K by
    # s_i and s_j. Broadcasting does that in O(n^2) without materialising
    # dense n x n diagonal matrices or running O(n^3) matrix products.
    # NOTE(review): assumes Kernel_matrix returns a dense 2-D CuPy array,
    # as the previous cp.diag / @ formulation already required.
    scale1 = cp.sum(K1, axis=1) ** (-0.5)
    scale2 = cp.sum(K2, axis=1) ** (-0.5)
    P1 = K1 * scale1[:, None] * scale1[None, :]
    P2 = K2 * scale2[:, None] * scale2[None, :]

    P_theta = P1 @ P2 + P2 @ P1

    # NOTE(review): eigh treats its input as symmetric; P_theta is
    # symmetric whenever P1 and P2 are -- confirm Kernel_matrix guarantees
    # symmetric kernels.
    d, v = cp.linalg.eigh(P_theta)

    # Reorder the eigenpairs so the largest eigenvalue comes first.
    idx_ = cp.argsort(d)[::-1]
    d = d[idx_]
    v = v[:, idx_]

    if k0 is not None:
        v = v[:, :k0]
        d = d[:k0]
    return P_theta, d, v

def Multiple_latent_variables(X1, X2, N=5, K=200, k=100, k0=100):
    """Extract a sequence of differential vectors of ``X1`` against ``X2``.

    The first vector is the leading column of ``u1`` returned by
    ``Differential_method(X1, X2)``. Every further vector is obtained by
    appending the most recent vector as an extra column to the shared-space
    basis from ``Shared_space`` and calling ``Differential_method`` again
    with the transposed, augmented basis in place of ``X2``.

    Args:
        X1: First dataset, forwarded to the helper routines.
        X2: Second dataset, forwarded to the helper routines.
        N: Number of vectors to extract; at least one vector is always
            produced, even when ``N`` is smaller than 1.
        K: Forwarded as ``K`` to ``Differential_method`` and
            ``Shared_space``.
        k: Forwarded as ``k`` to ``Differential_method``.
        k0: Forwarded as ``k0`` to ``Shared_space``.

    Returns:
        CuPy array holding one extracted vector per column.
    """
    _, u_first, _, _ = Differential_method(X1, X2, K=K, k=k)
    deltas = [u_first[:, 0]]

    _, _, basis = Shared_space(X1, X2, K=K, k0=k0)

    for _ in range(N - 1):
        # Grow the reference basis by the vector found in the previous step.
        latest_column = deltas[-1].reshape(-1, 1)
        basis = cp.hstack([basis, latest_column])

        # The augmented basis (transposed) takes the role of the second
        # dataset in the next round.
        _, u_next, _, _ = Differential_method(X1, basis.T, K=K, k=k)
        deltas.append(u_next[:, 0])

    return cp.array(deltas).T


def ELVES(X1, X2, N=2, K=200, k=500, k0=400, w1=0.5):
    """Score each position by the larger squared differential value.

    ``Multiple_latent_variables`` is run in both directions (``X1`` against
    ``X2`` and ``X2`` against ``X1``). For each direction the last extracted
    vector is squared element-wise, and the final score is the element-wise
    maximum of the two squared vectors.

    Args:
        X1: First dataset, forwarded to ``Multiple_latent_variables``.
        X2: Second dataset, forwarded to ``Multiple_latent_variables``.
        N: Number of vectors extracted per direction; only the last one
            contributes to the score.
        K: Forwarded as ``K`` to ``Multiple_latent_variables``.
        k: Forwarded as ``k`` to ``Multiple_latent_variables``.
        k0: Forwarded as ``k0`` to ``Multiple_latent_variables``.
        w1: Currently unused; kept so existing callers keep working.

    Returns:
        One-dimensional CuPy array with one score per position.
    """
    delta1 = Multiple_latent_variables(X1, X2, N=N, K=K, k=k, k0=k0)
    delta2 = Multiple_latent_variables(X2, X1, N=N, K=K, k=k, k0=k0)

    # Per-direction score: squared entries of the last extracted vector.
    score1 = delta1[:, -1] ** 2
    score2 = delta2[:, -1] ** 2

    # Both operands are built with cp.array, so take the element-wise
    # maximum with CuPy directly rather than through a NumPy ufunc.
    return cp.maximum(score1, score2)

