import numpy as np
import sklearn
from sklearn.linear_model import Ridge
from sklearn.kernel_ridge import KernelRidge

# Return pairwise kernel
def return_pairwise_kernel(kernel_method):
    """Return the scikit-learn pairwise kernel function named by ``kernel_method``.

    Args:
        kernel_method: One of ``'linear'``, ``'rbf'`` or ``'polynomial'``.

    Returns:
        The matching callable from ``sklearn.metrics.pairwise``
        (``linear_kernel``, ``rbf_kernel`` or ``polynomial_kernel``).

    Raises:
        NotImplementedError: If ``kernel_method`` is not one of the names above.
    """
    # A bare `import sklearn` is not guaranteed to load the `sklearn.metrics`
    # subpackage; import it explicitly instead of relying on other sklearn
    # imports in this module having loaded it as a side effect.
    from sklearn.metrics import pairwise

    kernels = {
        'linear': pairwise.linear_kernel,
        'rbf': pairwise.rbf_kernel,
        'polynomial': pairwise.polynomial_kernel,
    }
    try:
        return kernels[kernel_method]
    except (KeyError, TypeError):
        # TypeError covers unhashable inputs, which the original `==` chain
        # also reported as NotImplementedError.
        raise NotImplementedError(
            f"Unsupported kernel_method: {kernel_method!r}"
        ) from None


# CLIP-UCB bonus
def clip_ucb_bonus(prompt2gen_feat, drep, kernel_method, kernel, alpha, gamma, N_feat=512, apply_rff=False):
    """Compute a per-arm UCB exploration width for a single prompt feature.

    For each arm ``g`` a ridge regressor is fit on ``drep[g]`` with targets
    ``drep[g] @ prompt2gen_feat.T`` and evaluated at ``prompt2gen_feat``. The
    width returned for that arm is ``sqrt(max(0, c - prediction))``.

    Args:
        prompt2gen_feat: Feature of the current prompt. Presumably shape
            ``(1, D)``, since it is transposed and passed straight to
            ``predict``. Confirm against callers.
        drep: Sequence of length G. ``drep[g]`` is the 2-D array of features
            already collected for arm ``g``; it is used as the training inputs.
        kernel_method: Kernel name given to ``KernelRidge`` (non-RFF path only).
        kernel: Pairwise kernel callable ``kernel(X=..., Y=...)`` used to
            compute ``c = k(x, x)`` (non-RFF path only).
        alpha: Regularization strength of the ridge regressors.
        gamma: Kernel coefficient given to ``KernelRidge`` (non-RFF path only).
        N_feat: Kept for backward compatibility. The feature size is now
            inferred from ``prompt2gen_feat``, so this value is not needed.
        apply_rff: If True, fit a linear ``Ridge`` on the inputs (presumably
            random Fourier features) and use ``c = 1``.

    Returns:
        ``np.ndarray`` of shape ``(G,)``. An arm with no collected features
        gets ``sqrt(max(0, c))``, i.e. its prediction is treated as 0.
    """
    x = prompt2gen_feat
    G = len(drep)

    if apply_rff:
        # c = 1 as in the original `1 - bonus`; presumably the RFF features
        # approximate an RBF kernel, for which k(x, x) = 1.
        c = 1.0

        def make_regressor():
            return Ridge(alpha=alpha)
    else:
        # Infer the feature size instead of hard-coding reshape(1, N_feat),
        # which raised for any embedding whose size differs from N_feat.
        x_row = x.reshape(1, -1)
        c = kernel(X=x_row, Y=x_row).flat[0]

        def make_regressor():
            return KernelRidge(kernel=kernel_method, alpha=alpha, gamma=gamma)

    bonus = np.zeros((G,))
    for g in range(G):
        feats_g = drep[g]
        if len(feats_g) == 0:
            # fit() rejects zero samples; leave the prediction at 0 so the arm
            # receives the full width sqrt(max(0, c)).
            continue
        # NOTE(review): the target is the *linear* similarity feats_g @ x.T even
        # when kernel_method is non-linear, while c is computed with `kernel`.
        # If the intent is the kernel posterior width
        # k(x, x) - k_x^T (K + alpha I)^{-1} k_x, the target should be
        # k(feats_g, x) using the same kernel parameters. Confirm.
        # NOTE(review): Ridge() fits an intercept by default (fit_intercept=True)
        # whereas KernelRidge has none; confirm the RFF path intends that.
        reg_target_g = np.squeeze(feats_g @ x.T, axis=-1)
        reg_g = make_regressor().fit(feats_g, reg_target_g)
        bonus[g] = reg_g.predict(x)[0]

    return np.sqrt(np.maximum(0.0, c - bonus))


class RFFKernel:
    """Random Fourier feature map for a Gaussian (RBF) kernel.

    ``transform`` returns ``sqrt(2 / D) * cos(X @ W.T + b)``. The frequencies
    are drawn as ``W ~ N(0, 1 / sigma**2)`` and the phases as
    ``b ~ U[0, 2*pi)``. For large ``D = num_features``,
    ``transform(x) @ transform(y)`` approximates
    ``exp(-||x - y||**2 / (2 * sigma**2))``.
    """

    def __init__(self, input_dim, num_features, sigma, random_state=None):
        """Draw the random frequencies and phases.

        Args:
            input_dim: Dimension of the input vectors.
            num_features: Number of random features D.
            sigma: RBF bandwidth; frequencies are drawn with std ``1 / sigma``.
            random_state: Optional int seed or ``np.random.Generator`` for a
                reproducible feature map. ``None`` (default) draws from the
                global ``np.random`` state, exactly as before.
        """
        self.input_dim = input_dim
        self.num_features = num_features
        self.sigma = sigma
        # Keep the legacy global RNG when no seed is given, so callers that
        # rely on np.random.seed(...) still get identical features.
        rng = np.random if random_state is None else np.random.default_rng(random_state)
        self.W = rng.normal(scale=1.0 / sigma, size=(num_features, input_dim))
        self.b = rng.uniform(0, 2 * np.pi, num_features)

    def transform(self, X):
        """Map ``X`` (1-D of length ``input_dim`` or 2-D ``(n, input_dim)``) to
        random features whose last axis has length ``num_features``."""
        projection = np.dot(X, self.W.T) + self.b
        return np.sqrt(2.0 / self.num_features) * np.cos(projection)
