import numpy as np
import seaborn as sn
import pandas as pd
from sklearn.base import BaseEstimator
from sklearn.gaussian_process.kernels import RBF
from sklearn.kernel_approximation import Nystroem

import pandas as pd
from matplotlib import cm
import matplotlib as mpl
from matplotlib import pyplot as plt

def dist(a: np.ndarray, b: np.ndarray) -> float:
    """Return the Euclidean (l2) distance between two 1-D arrays.

    Args:
        a: One-dimensional array.
        b: One-dimensional array with the same shape as ``a``.

    Returns:
        ``sqrt(sum((a - b) ** 2))``; ``0.0`` for empty inputs.

    Raises:
        AssertionError: If the shapes differ or the inputs are not 1-D.
    """
    # Kept as asserts so existing callers still see AssertionError.
    assert a.shape == b.shape, "a and b must have the same shape"
    assert len(a.shape) == 1, "a and b must be one-dimensional"

    # Vectorised norm: avoids a Python-level loop over the coordinates and
    # no longer shadows the builtin ``sum``.
    return np.linalg.norm(a - b)

def square_root(X, rank):
    """Return a rank-limited factor Y of X built from its SVD.

    X is handed to ``np.linalg.svd`` with ``hermitian=True``, so it is
    treated as a symmetric (Hermitian) matrix. The result keeps the first
    ``rank`` left singular vectors, each scaled by the square root of its
    singular value. For a positive semidefinite X this makes ``Y @ Y.T``
    the rank-``rank`` truncation of X.

    Args:
        X: Square symmetric matrix.
        rank: Number of leading singular directions to keep.

    Returns:
        Array with ``X.shape[0]`` rows and at most ``rank`` columns.
    """
    left_vectors, singular_values, _ = np.linalg.svd(X, hermitian=True)
    column_scale = np.sqrt(singular_values[:rank])
    # Broadcasting scales column k by sqrt(singular value k).
    return left_vectors[:, :rank] * column_scale

class RFF(BaseEstimator):
    """Random Fourier feature map for the RBF kernel.

    Samples ``D`` random frequencies and maps every input row to the
    concatenation of ``D`` cosine and ``D`` sine features, so the output
    has ``2 * D`` columns. The frequencies are drawn for the kernel
    ``K(x) = exp(-gamma * |x|_2^2)``.

    Parameters
    ----------
    gamma : float
        Kernel coefficient in ``exp(-gamma * |x|_2^2)``.
    D : int
        Number of random frequencies; the feature map has ``2 * D`` columns.
    metric : str, default "rbf"
        Stored on the instance; the sampling below only implements the
        RBF case and does not read this value.

    Attributes
    ----------
    w : ndarray of shape (D, n_features)
        Random frequencies, set by ``fit_transform``.
    fitted : bool
        ``True`` once ``fit_transform`` has drawn the frequencies.
    """

    def __init__(self, gamma, D, metric="rbf"):
        self.gamma = gamma
        self.metric = metric
        # Number of random frequencies (output width is 2 * D).
        self.D = D
        self.fitted = False

    def fit_transform(self, X):
        """Draw random frequencies and return the Fourier features of X.

        Frequencies come from the global NumPy random state, so results
        are reproducible only when the caller seeds ``np.random``.

        Args:
            X: Array of shape (n_samples, n_features).

        Returns:
            Array of shape (n_samples, 2 * D): cosine features in the
            first ``D`` columns, sine features in the last ``D``.
        """
        n_features = X.shape[1]

        # D iid samples from p(w) for the (rbf) kernel
        # K(x) = exp(-gamma * |x|_{l_2}^2), i.e. w ~ N(0, 2 * gamma * I).
        self.w = np.sqrt(2 * self.gamma) * np.random.normal(
            size=(self.D, n_features)
        )
        self.fitted = True

        # Project once and reuse for both the cosine and sine halves.
        projection = X.dot(self.w.T)
        scale = np.sqrt(1 / self.D)
        return np.hstack(
            (scale * np.cos(projection), scale * np.sin(projection))
        )

# Experiment parameters.
N = 100  # number of sample points per trial
d = 60  # dimension of the input points
T = 20  # number of independent trials per target dimension

# Target embedding dimensions to compare.
RANGE = [16, 24, 30, 36, 48, 54, 60, 68, 76, 84, 90]
DATA = []

# All index pairs (i, j) with i < j; identical for every trial.
_PAIR_I, _PAIR_J = np.triu_indices(N, k=1)


def _max_relative_error(embedding, exact, valid):
    """Return the worst relative distortion of pairwise distances.

    ``embedding`` holds one embedded point per row, ``exact`` the reference
    distance of every (i, j) pair with i < j, and ``valid`` a boolean mask
    of the pairs to score. Returns 0.0 when no pair is valid.
    """
    approx = np.linalg.norm(embedding[_PAIR_I] - embedding[_PAIR_J], axis=1)
    relative = np.abs(approx[valid] - exact[valid]) / exact[valid]
    return np.max(relative, initial=0.0)


# RBF kernel with length scale 1.0; paired with gamma = 0.5 in RFF below.
kernel = RBF(1.0)

for D in RANGE:
    for trial in range(T):
        # Input data. Re-seeding per trial means every target dimension is
        # evaluated on the same T point sets.
        np.random.seed(trial)
        X = np.random.normal(size=(N, d))

        # (Accurate) kernel embedding: X0 @ X0.T must reproduce K.
        K = kernel(X)
        X0 = square_root(K, N)
        assert np.max(np.abs(X0 @ X0.T - K)) <= 1e-8

        # Approximate embeddings. Statement order fixes the order of the
        # random draws (RFF frequencies first, then the JL matrix).
        rff = RFF(gamma=0.5, D=D // 2)
        X1 = rff.fit_transform(X)
        X2 = square_root(K, D)
        X3 = X0 @ (np.sqrt(1 / D) * np.random.normal(size=(N, D)))

        assert X1.shape == X2.shape  # under same target dimension

        # Exact feature-space distance of each pair, from the kernel matrix.
        exact = np.sqrt(
            K[_PAIR_I, _PAIR_I] + K[_PAIR_J, _PAIR_J] - 2 * K[_PAIR_I, _PAIR_J]
        )
        # Skip (near-)coincident pairs to avoid dividing by ~0.
        valid = exact > 1e-8

        for method, embedding in (("RFF", X1), ("SVD", X2), ("JL", X3)):
            DATA.append({
                "Method": method,
                "Iteration": trial,
                "Target Dimension": D,
                "Relative Error": _max_relative_error(embedding, exact, valid),
            })

DATA = pd.DataFrame(DATA)
sn.lineplot(
    data=DATA,
    x="Target Dimension",
    y="Relative Error",
    hue="Method",
    style="Method",
    dashes=True,
    markers=True,
)
plt.show()

'''
mpl.rcParams['figure.subplot.left'] = 0.125
mpl.rcParams['figure.subplot.bottom'] = 0.125
mpl.rcParams['figure.subplot.right'] = 0.95
mpl.rcParams['figure.subplot.top'] = 0.975
mpl.rcParams['figure.subplot.hspace'] = 0.1
mpl.rcParams['figure.subplot.wspace'] = 0.1

plt.plot(RANGE, A, color='r', label='RFF')
plt.plot(RANGE, B, color='b', label='SVD')
plt.plot(RANGE, C, color='g', label='JL')

plt.legend(fontsize=15, loc='best')
plt.xticks(fontsize=15)
plt.yticks(fontsize=15)

plt.xlabel('Target Dimension', fontsize=17)
plt.ylabel('Relative Error', fontsize=17)
plt.grid(linestyle='dashed')
plt.show()
'''
