import numpy as np
import matplotlib.pyplot as plt
import torch

from sklearn.svm import LinearSVC

from utils import *
from quad_jax import QuadraticClassifier, batch_classifier

from sklearn.preprocessing import StandardScaler

# ----------------------------------------------------------------------------
# Experiment parameters
# ----------------------------------------------------------------------------
d = 100          # ambient dimension
eta = 0.5        # the signal subspace has dimension floor(d**eta)
noise = 0.001    # forwarded to inject_noise (utils); presumably a label-flip rate -- confirm
nuc_norm = 1     # the ground-truth matrix is rescaled to this nuclear norm
r = 1            # unused by the sweep below; generate() receives d**s as its radius argument

# int() truncates log(d): 20 * 100 * int(4.6...) = 8000 samples per split for d = 100.
_samples_per_split = 20 * d * int(np.log(d))
n_train = _samples_per_split
n_test = _samples_per_split

epochs = 2000
batch_size = n_train   # batch size equal to the training-set size
rep = 5                # number of independent repetitions

# Values of s swept over, and the regularisation grid handed to GridSearchCV.
s_list = np.linspace(0, 0.8, 10)
lam = {"lmbda": [0.01, 0.1, 0.5, 1, 5]}

# Result tables indexed [repetition, index into s_list]; four independent arrays.
_table_shape = (rep, len(s_list))
train_error_nuc, test_error_nuc, train_error_fro, test_error_fro = (
    np.zeros(_table_shape) for _ in range(4)
)


def data_generation(n_train, n_test, dim, nuc_norm, noise, rank):
    """Draw a synthetic train/test problem labelled by a random low-rank matrix.

    Args:
        n_train: number of training points.
        n_test: number of test points.
        dim: ambient dimension; the ground-truth matrix is (dim, dim).
        nuc_norm: nuclear norm the ground-truth matrix is rescaled to.
        noise: forwarded to ``inject_noise`` (utils) to corrupt the labels.
        rank: forwarded to ``rank_constraint`` (utils) when building the
            ground-truth matrix.

    Returns:
        ``(X_train, y_train, X_test, y_test, A_true)``.
    """
    # Ground truth: a Gaussian matrix passed through rank_constraint (utils;
    # presumably a truncation to rank `rank` -- confirm), then rescaled so its
    # nuclear norm equals `nuc_norm`. `rank` used to be ignored because this
    # matrix was immediately overwritten by a rank-1 outer product.
    A_true = rank_constraint(np.random.randn(dim, dim), rank)
    A_true *= nuc_norm / np.linalg.norm(A_true, ord='nuc')

    def _labels(X):
        # Sign of batch_classifier with zero linear term and zero offset,
        # then corrupted by inject_noise.
        return inject_noise(np.sign(batch_classifier(A_true, np.zeros(dim), 0, X)), noise)

    # Points come from sample_spherical (utils; presumably uniform on the unit sphere).
    X_train = sample_spherical(n_train, dim)
    y_train = _labels(X_train)

    X_test = sample_spherical(n_test, dim)
    y_test = _labels(X_test)

    return X_train, y_train, X_test, y_test, A_true

def generate(n, d, eta, r):
    """Sample ``n`` points in R^d with a low-dimensional signal component.

    A random orthogonal basis ``O`` of R^d is taken from the left singular
    vectors of a Gaussian (d, d) matrix. Its first ``d_zero = floor(d**eta)``
    columns form ``U`` (the signal subspace); the rest form ``U_bot``.
    Each sample is ``U @ z1 + U_bot @ z2``, where:

    * ``z1`` is a normalised Gaussian vector in R^d_zero scaled to norm
      ``r * sqrt(d_zero)``;
    * ``z2`` is a normalised Gaussian vector in R^(d - d_zero) scaled to norm
      ``sqrt(d - d_zero)``.

    Args:
        n: number of samples.
        d: ambient dimension.
        eta: exponent that sets the signal-subspace dimension floor(d**eta).
        r: multiplier on the signal-component radius.

    Returns:
        ``(X, U)``, where ``X`` has shape (n, d) and ``U`` has shape (d, d_zero).
    """
    d_zero = int(np.floor(d**eta))
    O = np.linalg.svd(np.random.randn(d, d))[0]

    U = O[:, :d_zero]
    U_bot = O[:, d_zero:]

    def _sphere_rows(radius, dim):
        # n Gaussian rows, each normalised to Euclidean norm `radius`.
        # Random draws happen in the same order as the per-row version.
        Z = np.random.randn(n, dim)
        return radius * Z / np.linalg.norm(Z, axis=1, keepdims=True)

    Z_1 = _sphere_rows(r * np.sqrt(d_zero), d_zero)
    Z_2 = _sphere_rows(np.sqrt(d - d_zero), d - d_zero)

    # Row-wise U @ z_1 + U_bot @ z_2, done as two matrix products.
    X = Z_1 @ U.T + Z_2 @ U_bot.T

    return X, U

# ----------------------------------------------------------------------------
# Sweep: each repetition draws a fresh ground truth; each s draws a new dataset
# and fits nuclear-norm and Frobenius-norm regularised classifiers.
# NOTE(review): GridSearchCV is not imported explicitly in this file; it is
# presumably provided by `from utils import *` -- verify.
# ----------------------------------------------------------------------------
# Signal-subspace dimension; must agree with the d_zero computed inside generate().
d_zero = int(np.floor(d**eta))

for r_ in range(rep):
    print(r_)
    # Rank-1 ground truth on the signal subspace, rescaled to nuclear norm nuc_norm.
    u, v = np.random.randn(2, d_zero)
    A_true = np.outer(u, v)
    A_true *= nuc_norm / np.linalg.norm(A_true, ord='nuc')
    for i_s, s in enumerate(s_list):
        print(s)
        # generate() scales the signal-component radius by its 4th argument.
        X, U = generate(n_train + n_test, d, eta, d**(s))
        X_train_s = X[:n_train, :]
        X_test_s = X[n_train:, :]

        # Lift A_true into R^{d x d} once and reuse it to label both splits.
        A_lifted = U @ A_true @ U.T
        y_train_s = inject_noise(np.sign(batch_classifier(A_lifted, np.zeros(d), 0, X_train_s)), noise)
        y_test_s = inject_noise(np.sign(batch_classifier(A_lifted, np.zeros(d), 0, X_test_s)), noise)

        # GridSearchCV gets the training set twice. The first copy is the fit
        # fold and the second copy is the validation fold, so lambda is chosen
        # by training-set score; the test split is never passed to fit().
        # With refit=True the best model is then refit on both copies.
        cv_split = [(np.arange(n_train), n_train + np.arange(n_train))]
        X_cv = np.concatenate([X_train_s, X_train_s])
        y_cv = np.concatenate([y_train_s, y_train_s])

        nuc = GridSearchCV(QuadraticClassifier(dim=d, norm='nuc'), lam, cv=cv_split, refit=True, n_jobs=4)
        fro = GridSearchCV(QuadraticClassifier(dim=d, norm='fro'), lam, cv=cv_split, refit=True, n_jobs=4)

        nuc.fit(X_cv, y_cv, n_epoch=epochs, batch_size=batch_size)
        fro.fit(X_cv, y_cv, n_epoch=epochs, batch_size=batch_size)

        # Despite the *_error_* names, these tables store GridSearchCV.score
        # values (presumably accuracy); the plotting code converts via 1 - mean.
        train_error_nuc[r_, i_s] = nuc.score(X_train_s, y_train_s)
        test_error_nuc[r_, i_s] = nuc.score(X_test_s, y_test_s)
        train_error_fro[r_, i_s] = fro.score(X_train_s, y_train_s)
        test_error_fro[r_, i_s] = fro.score(X_test_s, y_test_s)


import os

# The *_error_* tables hold per-(repetition, s) scores, so 1 - mean is plotted.
# The spread (std) is unchanged by the 1 - x flip.
_curves = [
    ("Nuc train", train_error_nuc, {"c": 'r'}),
    ("Nuc test", test_error_nuc, {"c": 'r', "ls": '--'}),
    ("Fro train", train_error_fro, {"c": 'b'}),
    ("Fro test", test_error_fro, {"c": 'b', "ls": '--'}),
]
for _label, _scores, _style in _curves:
    plt.errorbar(s_list, 1 - np.mean(_scores, axis=0), yerr=np.std(_scores, axis=0), label=_label, **_style)
plt.xlabel("s")
plt.ylabel("1 - score")
plt.legend()

fig_path = "../log/fig-{}.png".format(d)
# Create the output directory up front so a long sweep cannot fail at its very last step.
os.makedirs(os.path.dirname(fig_path), exist_ok=True)
plt.savefig(fig_path, format='png')
print("HEREZOZOZO")
#plt.show()

# Saving the raw result tables is currently disabled (kept as an inert string literal).
'''
np.save("train_nuc_ani_d={}.npy".format(d), train_error_nuc)
np.save("test_nuc_ani_d={}.npy".format(d), test_error_nuc)
np.save("train_fro_ani_d={}.npy".format(d), train_error_fro)
np.save("test_fro_ani_d={}.npy".format(d), test_error_fro)
'''
