from cvxpy.lin_ops.lin_utils import multiply
import matplotlib.pyplot as plt
import numpy as np
import cvxpy as cp
from scipy.linalg import sqrtm
from dcolor import DColor

# Typeset mathtext (e.g. the $\Re[z]$ / $\Im[z]$ axis labels) with STIX sans-serif glyphs.
plt.rcParams['mathtext.fontset'] = 'stixsans'

# Problem sizes: n training points, d monomial features.
n = 10
d = 20

# Training points lie on the curve x + 0.25i*sin(pi*x) for n equispaced
# x in [-1, 1]; each point is labelled by the sign of its real part.
# NOTE: for odd n the midpoint has real part 0, so np.sign gives it label 0.
_grid = np.linspace(-1, 1, n)
X_train = _grid + 0.25j * np.sin(np.pi * _grid)
y_train = np.sign(X_train.real)

# Alternative training sets that were tried (kept for reference):
#   X_train = (1j - _grid) / (1j + _grid)
#   X_train = np.exp(.5j * np.pi * _grid)
#   y_train = np.sign(_grid)

def gamma(_n):
    """Return pi / (_n + 1), the normalisation constant of the monomial z**_n.

    pi / (n + 1) is the integral of |z|**(2n) over the unit disk, i.e. the
    squared L2 norm of z**n on the disk.  ``vander`` divides z**i by
    sqrt(gamma(i)).  Works on scalars and elementwise on NumPy arrays.
    """
    return np.pi / (_n + 1)

def vander(Z, num_terms=None):
    """Return the normalised monomial feature matrix of the points ``Z``.

    Column ``i`` holds ``Z ** i / sqrt(gamma(i))`` for ``i = 0 .. num_terms-1``,
    so a 1-D ``Z`` of length m yields an array of shape ``(m, num_terms)``.

    Args:
        Z: complex evaluation points (typically a 1-D NumPy array).
        num_terms: number of monomials; defaults to the module-level ``d``
            (read at call time, as before).
    """
    if num_terms is None:
        num_terms = d
    return np.array([Z ** i / np.sqrt(gamma(i)) for i in range(num_terms)]).T

# Normalised monomial features of the training points, shape (n, d).
V = vander(X_train)

# Diagonal weights i**2 * gamma(i-1) / gamma(i) for i = 1 .. d-1.  Since
# gamma(k) is the squared norm of z**k, this is the squared norm of the
# derivative of the i-th normalised monomial.  Every entry is positive
# (it simplifies to i * (i + 1)).
Sigma = np.diag([i ** 2 * gamma(i - 1) / gamma(i) for i in range(1, d)])

# Change of basis: leave the constant term alone and rescale the others by
# Sigma**(-1/2).  Sigma is diagonal and positive, so its inverse square root
# is the elementwise reciprocal square root of the diagonal.  This is exact,
# needs no dense inverse or general matrix square root, and also handles
# d == 1 (an empty Sigma).
S = np.eye(d)
S[1:, 1:] = np.diag(1.0 / np.sqrt(np.diag(Sigma)))

# Identity matrices of the sample and feature dimensions.
In, Id = np.eye(n), np.eye(d)

# The complex weight vector w = wr + 1j*wi is split into real and imaginary
# parts so the problem is posed over real cvxpy variables.
wr = cp.Variable(d)
wi = cp.Variable(d)
ksi = cp.Variable(n)  # per-sample slack on the margin constraint
nu = cp.Variable(n)   # per-sample bound on |Im((V w)_k)|

# Real and imaginary parts of the feature matrix.  They are Parameters so the
# same problem can be re-solved for a different basis without rebuilding it.
Vr = cp.Parameter((n, d))
Vi = cp.Parameter((n, d))

# Re(V w) = Vr @ wr - Vi @ wi,   Im(V w) = Vr @ wi + Vi @ wr.
constraints = [
    ksi >= 0,
    nu >= 0,
    # Elementwise margin: y_k * Re(V w)_k >= 1 - ksi_k.  cp.multiply is
    # required.  With `y_train * expr` on two length-n vectors, cvxpy performs
    # matrix (inner) multiplication, which collapses the n constraints into one.
    cp.multiply(y_train, Vr @ wr - Vi @ wi) >= 1 - ksi,
    # |Im(V w)_k| <= nu_k, written as two linear inequalities.
    Vr @ wi + Vi @ wr <= nu,
    -nu <= Vr @ wi + Vi @ wr,
]

# quad_form(x, I) is just ||x||^2; sum_squares says so directly.
obj = cp.Minimize(
    1 / n * (cp.sum_squares(ksi) + cp.sum_squares(nu))
    + cp.sum_squares(wr)
    + cp.sum_squares(wi)
)

problem = cp.Problem(obj, constraints)


def _solve_for(features):
    """Solve ``problem`` for the given complex feature matrix.

    Loads the real and imaginary parts of ``features`` into ``Vr`` / ``Vi``
    and solves.  Returns the complex weights ``wr + 1j * wi``.

    Raises:
        RuntimeError: if the solver does not report an (inaccurate-)optimal
            status, in which case the variable values may be ``None``.
    """
    Vr.value = np.real(features)
    Vi.value = np.imag(features)
    problem.solve()
    if problem.status not in (cp.OPTIMAL, cp.OPTIMAL_INACCURATE):
        raise RuntimeError(f"optimization failed with status {problem.status!r}")
    return wr.value + 1j * wi.value


# Fit in the normalised monomial basis ...
w_orth = _solve_for(V)

# ... and again after the change of basis S.
V = V @ S

w_sync = _solve_for(V)

def _evaluate_model(Z, weights, basis=None):
    """Evaluate the linear model ``vander(z) [@ basis] @ weights`` at every point of ``Z``.

    ``Z`` may be any array-like of complex points (scalar, list or ndarray of
    any shape).  It is flattened for feature construction, and the result is
    reshaped to ``np.asarray(Z).shape``.
    """
    points = np.asarray(Z)
    features = vander(points.ravel())
    if basis is not None:
        features = features @ basis
    return (features @ weights).reshape(points.shape)


def clp_orth(Z):
    """Evaluate the model fitted in the normalised monomial basis (``w_orth``) at ``Z``."""
    return _evaluate_model(Z, w_orth)


def clp_sync(Z):
    """Evaluate the model fitted after the change of basis ``S`` (``w_sync``) at ``Z``."""
    return _evaluate_model(Z, w_sync, S)

# Domain-colouring plots of both fitted models over [-1.2, 1.2]^2, with the
# training points overlaid in black, written to disk.eps.
dc = DColor(xmin=-1.2, xmax=1.2, ymin=-1.2, ymax=1.2, samples=200)

fig, ax = plt.subplots(1, 2, sharex='all', sharey='all')

panel_specs = (
    (ax[0], clp_orth, 'Orthonormal'),
    (ax[1], clp_sync, 'Synchronized'),
)
for panel_idx, (panel_ax, panel_fn, panel_title) in enumerate(panel_specs):
    dc.plot(panel_fn, ax=panel_ax)
    panel_ax.scatter(np.real(X_train), np.imag(X_train), c='black')
    panel_ax.set_title(panel_title)
    panel_ax.set_xlabel(r'$\Re[z]$')
    if panel_idx == 0:
        # The y axis is shared, so only the left panel gets a label.
        panel_ax.set_ylabel(r'$\Im[z]$')

fig.savefig('disk.eps', bbox_inches='tight', dpi=100)