from itertools import product

import matplotlib.pyplot as plt
import numpy as np
from numpy.polynomial import chebyshev as ch
from numpy.polynomial import legendre as lg
from scipy.linalg import sqrtm
from scipy.special import expit
from sklearn.svm import LinearSVC

from dcolor import DColor

# Maximum Chebyshev degree per coordinate of the tensor-product basis
# (passed as [d, d] to chebvander2d) and number of training grid points per
# axis (passed to linspace for the training mesh).
d, n = 30, 10

def It(i, j):
  """Chebyshev-weighted inner product of first-kind polynomials T_i and T_j.

  Value of the integral of T_i(x) T_j(x) / sqrt(1 - x**2) over [-1, 1] for
  non-negative integer degrees: 0 when i != j, pi when i == j == 0, and
  pi / 2 when i == j > 0.
  """
  if i != j:
    return 0
  return np.pi if i == 0 else np.pi / 2

def Iu(i, j):
  """Chebyshev-weighted inner product of second-kind polynomials U_i and U_j.

  Value of the integral of U_i(x) U_j(x) / sqrt(1 - x**2) over [-1, 1]. For
  non-negative integers this is pi * (min(i, j) + 1) when i and j have the
  same parity and 0 otherwise; the even and odd branches below are two ways
  of writing that closed form. With min(i, j) == -1 both formulas give a
  (signed) zero, consistent with U_{-1} = 0.
  """
  parities = (i % 2, j % 2)
  lo = min(i, j)
  if parities == (0, 0):
    return np.pi + 2 * np.pi * np.floor(lo / 2)
  if parities == (1, 1):
    return 2 * np.pi * np.ceil(lo / 2)
  # Mixed parity: the integrand is odd on a symmetric interval.
  return 0

# Gram matrix of the gradients of the tensor-product basis
# f_{jk}(x1, x2) = T_j(x1) T_k(x2) under the Chebyshev weight in each
# coordinate. Since d/dx T_j = j U_{j-1}:
#   <grad f_{jk}, grad f_{lm}> = j l Iu(j-1, l-1) It(k, m)
#                              + k m It(j, l) Iu(k-1, m-1).
# Row/column index (d+1)*j + k matches the column order produced by
# ch.chebvander2d(..., [d, d]).
# Each term factorises over the two coordinates, so Sigma is a sum of two
# Kronecker products of (d+1) x (d+1) one-dimensional Gram matrices. This
# replaces ~(d+1)**4 Python-level It/Iu calls with ~2*(d+1)**2.
deg = range(d + 1)
T_gram = np.array([[It(k, m) for m in deg] for k in deg], dtype=float)
D_gram = np.array([[j * l * Iu(j - 1, l - 1) for l in deg] for j in deg],
                  dtype=float)
Sigma = np.kron(D_gram, T_gram) + np.kron(T_gram, D_gram)

# The constant basis function T_0(x1) T_0(x2) has zero gradient, so its
# row/column (index 0) is identically zero; drop it so Sigma is invertible.
Sigma = Sigma[1:, 1:]

# Training points on an n x n grid and a dense 200 x 200 evaluation grid,
# both over [-1, 1]^2.
Xs, Ys = np.meshgrid(np.linspace(-1, 1, n), np.linspace(-1, 1, n))
Xt, Yt = np.meshgrid(np.linspace(-1, 1, 200), np.linspace(-1, 1, 200))

# Labels: +1 outside the square max(|x1|, |x2|) <= 0.5, -1 inside.
# np.where (instead of np.sign(... - 0.5)) keeps the problem strictly binary
# even when a grid point lands exactly on the boundary (e.g. n = 5). A zero
# label would otherwise create a third class and break the binary-only
# coef_[0] / probability-column logic below.
y = np.where(np.maximum(np.abs(Xs), np.abs(Ys)).ravel() > 0.5, 1.0, -1.0)

# Chebyshev tensor-product features T_j(x1) T_k(x2), 0 <= j, k <= d, with the
# constant column (j = k = 0) dropped to match Sigma's dimensions.
X_train_orth = ch.chebvander2d(Xs.ravel(), Ys.ravel(), [d, d])[:, 1:]
X_test_orth = ch.chebvander2d(Xt.ravel(), Yt.ravel(), [d, d])[:, 1:]

# Whitening map Sigma^{-1/2}. It is a dense sqrtm of a ~961 x 961 matrix, so
# compute it once and reuse it for both feature sets. np.real discards
# round-off imaginary parts that sqrtm may return.
Sigma_inv_sqrt = np.real(sqrtm(np.linalg.inv(Sigma)))
X_train_sync = X_train_orth @ Sigma_inv_sqrt.T
X_test_sync = X_test_orth @ Sigma_inv_sqrt.T

fig, ax = plt.subplots(1, 2, sharex=True, sharey=True)

# Left panel: linear SVM trained directly on the Chebyshev features.
cls = LinearSVC()
cls.fit(X_train_orth, y)

# Hyperplane in the Chebyshev coefficient basis (copied so it does not alias
# the fitted model's coef_ array).
w_orth = cls.coef_[0].copy()
b_orth = cls.intercept_[0]

# Logistic squashing of the decision function. For a binary problem this is
# exactly column 1 of sklearn's private _predict_proba_lr, but it uses only
# the public decision_function API.
yhat = expit(cls.decision_function(X_test_orth))

pc0 = ax[0].contourf(Xt, Yt, yhat.reshape(Xt.shape))
ax[0].set_title('orthonormal')
ax[0].set_xlabel('$x_1$')
ax[0].set_ylabel('$x_2$')
# ax[0].scatter(Xs, Ys, c='r')

# Right panel: the same classifier on the whitened ("synchronized") features.
cls = LinearSVC()
cls.fit(X_train_sync, y)

# X_train_sync = X_train_orth @ Sigma^{-1/2} (symmetric up to round-off), so
# the decision function X_sync @ c + b equals X_orth @ (Sigma^{-1/2} c) + b.
# The Chebyshev-basis weights comparable to w_orth are therefore
# Sigma^{-1/2} c, i.e. the solution of Sigma^{1/2} w = c. Multiplying by
# Sigma^{1/2} would map in the opposite direction.
w_sync = np.linalg.solve(np.real(sqrtm(Sigma)), cls.coef_[0])
b_sync = cls.intercept_[0]

yhat = expit(cls.decision_function(X_test_sync))

pc1 = ax[1].contourf(Xt, Yt, yhat.reshape(Xt.shape))
ax[1].set_title('synchronized')
ax[1].set_xlabel('$x_1$')

# pc1 = ax[1].scatter(Xs, Ys, c='r')

plt.colorbar(pc0, ax=ax[0])
plt.colorbar(pc1, ax=ax[1])

fig.savefig('cheb2d_real.eps', bbox_inches='tight', dpi=150)
plt.show()