import os

import numpy as np
import matplotlib.pyplot as plt
import torch

from sklearn.model_selection import GridSearchCV
from sklearn.preprocessing import StandardScaler
from sklearn.svm import LinearSVC

# Local modules are imported last so that, as before, anything they export
# (e.g. via the star import) takes precedence over the names above.
from utils import *
from quad_jax import QuadraticClassifier

# Experiment configuration.
d = 5  # feature dimension; passed to data_generation and QuadraticClassifier(dim=d)
noise = 0.1  # forwarded to data_generation
nuc_norm = 1  # forwarded to data_generation
r = 5  # forwarded to data_generation (presumably the rank of A_true — TODO confirm)
# Sample sizes are 10 * d * floor(log d). The floor is clamped to at least 1
# because int(np.log(d)) is 0 for d < 3, which would otherwise produce empty
# train/test sets; for d >= 3 the sizes are unchanged (d = 5 -> 50 samples).
_log_factor = max(1, int(np.log(d)))
n_train = 10 * d * _log_factor
n_test = 10 * d * _log_factor
epochs = 10  # NOTE(review): not referenced anywhere in this script
batch_size = 20  # NOTE(review): not referenced anywhere in this script
rep = 1  # number of independent data draws = rows of the result tables below

# Init
# Values of the parameter s passed to anisotropize().
s_list = [0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1]
# Hyper-parameter grid for GridSearchCV; the key must match the name of a
# QuadraticClassifier constructor parameter.
lam = {"lmbda": [0.1, 0.5, 1, 2.5, 5, 7.5, 10, 12.5, 15, 17.5, 20]}
# Result tables indexed [repetition, index into s_list]. Despite the "error"
# names they are filled with estimator .score() values, not errors.
train_error_nuc = np.zeros((rep, len(s_list)))
test_error_nuc = np.zeros((rep, len(s_list)))
train_error_fro = np.zeros((rep, len(s_list)))
test_error_fro = np.zeros((rep, len(s_list)))


# Single predefined fold for GridSearchCV over the training set stacked on top
# of itself: rows [0, n_train) are fitted, rows [n_train, 2 * n_train) are scored.
# NOTE(review): the scored half is an exact copy of the fitted half, so lambda
# is selected by training-set score rather than held-out performance — confirm
# this is intended.
fold = [(np.arange(n_train), np.arange(n_train) + n_train)]

# Norm name -> (train-score table, test-score table) filled for that norm.
_score_tables = {
    "nuc": (train_error_nuc, test_error_nuc),
    "fro": (train_error_fro, test_error_fro),
}


def _stacked(arr):
    """Return a new array holding ``arr`` followed by itself along axis 0."""
    return np.concatenate([arr, arr])


for rep_idx in range(rep):
    print(rep_idx)
    X_train_0, y_train_0, X_test_0, y_test_0, A_true = data_generation(n_train, n_test, d, nuc_norm, noise, r)
    for s_idx, s in enumerate(s_list):
        print(s)
        X_tr = anisotropize(X_train_0, s)
        X_te = anisotropize(X_test_0, s)
        y_tr, y_te = y_train_0, y_test_0

        # Build one grid search per norm (nuc first, then fro) ...
        searches = {
            norm: GridSearchCV(QuadraticClassifier(dim=d, norm=norm), lam, cv=fold, refit=True, n_jobs=-1)
            for norm in _score_tables
        }
        # ... fit them all on the stacked training data ...
        for search in searches.values():
            search.fit(_stacked(X_tr), _stacked(y_tr))
        # ... then record each refitted model's train and test scores.
        for norm, (train_scores, test_scores) in _score_tables.items():
            train_scores[rep_idx, s_idx] = searches[norm].score(X_tr, y_tr)
            test_scores[rep_idx, s_idx] = searches[norm].score(X_te, y_te)


# Persist the raw score tables first, so that a plotting failure (e.g. the
# output directory for savefig missing) or a blocking plt.show() cannot lose
# or delay the experiment's results.
np.save("train_nuc_ani_d={}.npy".format(d), train_error_nuc)
np.save("test_nuc_ani_d={}.npy".format(d), test_error_nuc)
np.save("train_fro_ani_d={}.npy".format(d), train_error_fro)
np.save("test_fro_ani_d={}.npy".format(d), test_error_fro)

# Plot mean (1 - score) against s, with the spread across repetitions as error
# bars (the std of 1 - score equals the std of score). Solid = train, dashed = test.
for scores, label, color, linestyle in (
    (train_error_nuc, "Nuc train", 'r', '-'),
    (test_error_nuc, "Nuc test", 'r', '--'),
    (train_error_fro, "Fro train", 'b', '-'),
    (test_error_fro, "Fro test", 'b', '--'),
):
    plt.errorbar(s_list, 1 - np.mean(scores, axis=0), yerr=np.std(scores, axis=0),
                 label=label, c=color, ls=linestyle)
plt.xlabel("s")
plt.ylabel("1 - score")
plt.legend()
os.makedirs("../log", exist_ok=True)  # savefig does not create missing directories
plt.savefig("../log/fig-{}.png".format(d), format='png')
plt.show()
