import numpy as np
from numpy.polynomial import chebyshev as ch
from sklearn.svm import LinearSVC
import matplotlib.pyplot as plt
from scipy.linalg import sqrtm
from dcolor import DColor
from itertools import product

def Iu(i, j):
    """Return the (i, j) entry of a parity-structured weight table.

    The value is non-zero only when ``i`` and ``j`` have the same parity:

    * both even: ``pi + 2*pi*floor(min(i, j) / 2)``
    * both odd:  ``2*pi*ceil(min(i, j) / 2)``
    * otherwise: ``0``

    NOTE(review): the name and the way the caller scales the result by
    ``j * k`` suggest this is an inner product of Chebyshev polynomials of
    the second kind U_i and U_j -- confirm against the derivation.
    """
    parities = (i % 2, j % 2)

    if parities == (0, 0):
        half_of_smaller = min(i, j) / 2
        return np.pi + 2 * np.pi * np.floor(half_of_smaller)

    if parities == (1, 1):
        half_of_smaller = min(i, j) / 2
        return 2 * np.pi * np.ceil(half_of_smaller)

    # Mixed parity (or non-integer arguments): no contribution.
    return 0

# ---------------------------------------------------------------------------
# Experiment set-up: fit a linear SVM on Chebyshev features of a 1-D input,
# once on the raw features ("orthogonal") and once on features whitened by
# Sigma^{-1/2} ("synchronized"), and plot both classifiers on [-1, 1].
# ---------------------------------------------------------------------------
n = 6    # number of training points
d = 30   # highest Chebyshev degree used as a feature

Xs = np.linspace(-1, 1, n)      # training inputs
Xt = np.linspace(-1, 1, 1000)   # dense grid used for plotting
# Labels are sign(x). NOTE: with an odd n a sample lands on x = 0 and gets
# label 0, which would turn this into a three-class problem.
y = np.sign(Xs)

# Columns are T_1(x) .. T_d(x); column 0 (the constant T_0) is dropped.
X_train = ch.chebvander(Xs, d)[:, 1:]
X_test = ch.chebvander(Xt, d)[:, 1:]

fig, ax = plt.subplots(1, 2, sharex=True, sharey=True)

# --- Classifier on the raw Chebyshev features ------------------------------
cls = LinearSVC()
cls.fit(X_train, y)

# Keep private copies: `cls` is rebound to a second classifier further down.
w_orth = cls.coef_[0].copy()
b_orth = cls.intercept_

# NOTE: _predict_proba_lr is a private scikit-learn helper; column 1 is the
# score for the second class in cls.classes_ (+1 for the labels used here).
yhat = cls._predict_proba_lr(X_test)[:, 1]

ax[0].plot(Xt, yhat)
ax[0].set_title('orthogonal')
ax[0].set_xlabel('$x$')
ax[0].set_ylabel('$h(x)$')

# --- Classifier on whitened features ----------------------------------------
# Sigma[j-1, k-1] = j * k * Iu(j-1, k-1) for j, k = 1..d.
# NOTE(review): presumably the Gram matrix of the derivatives
# T_j' = j * U_{j-1} -- confirm against the derivation.
# Built directly as d x d; the j = 0 / k = 0 row and column are identically
# zero and were discarded anyway.
Sigma = np.zeros((d, d))
for j, k in product(range(1, d + 1), repeat=2):
    Sigma[j - 1, k - 1] = j * k * Iu(j - 1, k - 1)

# Whitening matrix, computed once and shared by the train and test features.
Sigma_inv_sqrt = sqrtm(np.linalg.inv(Sigma))

X_train = X_train @ Sigma_inv_sqrt.T
X_test = X_test @ Sigma_inv_sqrt.T

cls = LinearSVC()
cls.fit(X_train, y)

# clp_sync() whitens its features with sqrtm(inv(Sigma)) before taking the dot
# product with w_sync, exactly like the training data above, so the fitted
# weights must be used unchanged. Pre-multiplying them by sqrtm(Sigma) (as was
# done before) cancels that whitening and makes the complex-plane plot show a
# different function from the classifier plotted on the real line.
w_sync = cls.coef_[0].copy()
b_sync = cls.intercept_

yhat = cls._predict_proba_lr(X_test)[:, 1]

ax[1].plot(Xt, yhat)
ax[1].set_title('synchronized')
ax[1].set_xlabel('$x$')

fig.savefig('cheb1d_real.eps', bbox_inches='tight', dpi=100)

# Fresh figure for the complex-plane plots drawn further down.
fig, ax = plt.subplots(1, 2, sharex=True, sharey=True)

def clp_orth(Z):
    """Evaluate the affine score ``phi(z) @ w_orth + b_orth`` element-wise.

    ``phi(z)`` holds the Chebyshev features T_1(z) .. T_d(z) (the constant
    column is dropped). Reads the module-level ``d``, ``w_orth`` and
    ``b_orth``.

    Parameters
    ----------
    Z : ndarray
        Points at which to evaluate; any shape.

    Returns
    -------
    ndarray
        Scores with the same shape as ``Z``.
    """
    features = ch.chebvander(Z.flatten(), d)[:, 1:]
    scores = features @ w_orth + b_orth
    return scores.reshape(Z.shape)

def clp_sync(Z):
    """Evaluate ``phi(z) @ w_sync + b_sync`` on whitened Chebyshev features.

    The features T_1(z) .. T_d(z) are right-multiplied by
    ``sqrtm(inv(Sigma)).T`` before the dot product. Reads the module-level
    ``d``, ``Sigma``, ``w_sync`` and ``b_sync``.

    Parameters
    ----------
    Z : ndarray
        Points at which to evaluate; any shape.

    Returns
    -------
    ndarray
        Scores with the same shape as ``Z``.
    """
    cheb_features = ch.chebvander(Z.flatten(), d)[:, 1:]
    # The whitening matrix is recomputed from Sigma on every call.
    whitening = sqrtm(np.linalg.inv(Sigma)).T
    scores = (cheb_features @ whitening) @ w_sync + b_sync
    return scores.reshape(Z.shape)


# Plot both score functions over the square [-1.2, 1.2] x [-1.2, 1.2] using
# the third-party DColor helper, one panel per classifier, and save the figure.
dc = DColor(xmin=-1.2, xmax=1.2, ymin=-1.2, ymax=1.2, samples=200)

dc.plot(clp_orth, ax=ax[0])
ax[0].set_title('orthogonal')
# Raw strings: '\R' and '\I' are invalid escape sequences in a normal string
# literal (SyntaxWarning on Python 3.12+). The resulting text is unchanged.
ax[0].set_xlabel(r'$\Re[z]$')
ax[0].set_ylabel(r'$\Im[z]$')

dc.plot(clp_sync, ax=ax[1])
ax[1].set_title('synchronized')
ax[1].set_xlabel(r'$\Re[z]$')

fig.savefig('cheb1d_imag.eps', bbox_inches='tight', dpi=100)

plt.show()