import os

import numpy as np
import matplotlib.pyplot as plt
import torch

from sklearn.svm import LinearSVC

from utils import *
from quad_jax import QuadraticClassifier

def data_generation(n_train, n_test, dim, nuc_norm, noise, rank):
    """Generate a noisy quadratic classification problem.

    A ground-truth matrix ``A_true`` (dim x dim) is drawn with i.i.d. standard
    normal entries, passed through ``rank_constraint(A_true, rank)`` and then
    rescaled so that its nuclear norm equals ``nuc_norm``. Train and test
    inputs are drawn with ``sample_spherical`` and labelled with
    ``sign(x^T A_true x)``; the labels are then corrupted by
    ``inject_noise(labels, noise)``.

    Args:
        n_train: number of training samples.
        n_test: number of test samples.
        dim: input dimension (A_true is dim x dim).
        nuc_norm: target nuclear norm of A_true.
        noise: label-noise parameter forwarded to ``inject_noise``.
        rank: rank parameter forwarded to ``rank_constraint``.

    Returns:
        Tuple ``(X_train, y_train, X_test, y_test, A_true)``.
    """
    A_true = np.random.randn(dim, dim)
    A_true = rank_constraint(A_true, rank)
    A_true *= nuc_norm / np.linalg.norm(A_true, ord='nuc')

    # Train set is sampled and noised before the test set, which keeps the
    # same order of random draws as the original inline code.
    X_train, y_train = _sample_labelled(n_train, dim, A_true, noise)
    X_test, y_test = _sample_labelled(n_test, dim, A_true, noise)

    return X_train, y_train, X_test, y_test, A_true


def _sample_labelled(n, dim, A, noise):
    """Draw ``n`` inputs via sample_spherical and label them by the noisy sign of x^T A x."""
    X = sample_spherical(n, dim)
    # Row-wise quadratic form x_i^T A x_i, vectorised instead of a Python loop
    # over np.dot; assumes X is (n, dim), as required by ``X @ A``.
    y = np.sign(np.einsum('ij,ij->i', X @ A, X))
    return X, inject_noise(y, noise)

from sklearn.preprocessing import StandardScaler

def train(X_train, y_train, X_test, y_test, lr, lam, epochs, batch_size, norm):
    """Fit a QuadraticClassifier and report its errors and nuclear norm.

    Args:
        X_train: training inputs, one sample per row.
        y_train: training labels.
        X_test: held-out inputs, one sample per row.
        y_test: held-out labels.
        lr: accepted for API compatibility but NOT forwarded to the
            classifier, so it currently has no effect.
        lam: regularisation weight, passed as ``lmbda``.
        epochs: number of training epochs, passed as ``n_epoch``.
        batch_size: accepted for API compatibility but NOT forwarded to the
            classifier, so it currently has no effect.
        norm: regulariser name passed to QuadraticClassifier (this script
            uses ``'nuc'`` and ``'fro'``).

    Returns:
        Tuple ``(train_error, test_error, nuclear_norm)``. Each error is
        ``1 - classifier.score(...)`` and ``nuclear_norm`` is the nuclear norm
        of the learned ``classifier.A``.
    """
    # Take the dimension from the data rather than from the module-level
    # global ``d``. This removes a hidden dependency on a name defined later in
    # the file, and guarantees the model matches the features it is fitted on.
    classifier = QuadraticClassifier(dim=X_train.shape[1], lmbda=lam, norm=norm)
    classifier.fit(X_train, y_train, n_epoch=epochs)
    train_error = 1 - classifier.score(X_train, y_train)
    test_error = 1 - classifier.score(X_test, y_test)
    return train_error, test_error, np.linalg.norm(classifier.A, ord='nuc')


# Problem size and data-generation settings.
d = 50          # input dimension (A_true is d x d)
noise = 0.1     # label-noise level forwarded to data_generation
nuc_norm = 1    # nuclear norm that A_true is rescaled to
r = 5           # rank parameter forwarded to data_generation
# NOTE(review): int() floors log(d) *before* multiplying
# (log(50) ~ 3.91 -> 3, giving 1500 samples) -- confirm this is intended.
n_train = n_test = 10 * d * int(np.log(d))


# Init: sweep grids and result containers.
# s_list holds the anisotropy levels handed to anisotropize(); lam holds the
# regularisation weights. Each results dict maps str(s) -> {str(lam): [...]};
# the per-lambda lists are created lazily inside the training loop.
s_list = [0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1]
lam = [0.1, 0.5, 1, 2.5, 5, 7.5, 10, 12.5, 15, 17.5, 20]
train_error_nuc_dict = {str(s): {} for s in s_list}
test_error_nuc_dict = {str(s): {} for s in s_list}
train_error_fro_dict = {str(s): {} for s in s_list}
test_error_fro_dict = {str(s): {} for s in s_list}

# Optimisation hyper-parameters forwarded to train().
# NOTE(review): train() does not pass lr or batch_size on to the classifier,
# so changing them currently has no effect -- confirm.
epochs = 25000
batch_size = 20
lr = 0.0001
rep = 5

# Pair each regulariser name with the (train, test) result dicts it fills.
# Insertion order keeps 'nuc' trained before 'fro', as in the original code.
norm_results = {
    'nuc': (train_error_nuc_dict, test_error_nuc_dict),
    'fro': (train_error_fro_dict, test_error_fro_dict),
}

# For each repetition, draw a fresh problem. Then, for every anisotropy level s
# and regularisation weight l, train with both regularisers and record the errors.
for r_ in range(rep):
    print(r_)
    X_train_0, y_train_0, X_test_0, y_test_0, A_true = data_generation(n_train, n_test, d, nuc_norm, noise, r)
    for s in s_list:
        # Only the inputs are transformed; the labels are reused unchanged.
        X_train_s = anisotropize(X_train_0, s)
        X_test_s = anisotropize(X_test_0, s)
        for l in lam:
            print("s = {}, lam = {}".format(s, l))
            for norm, (train_dict, test_dict) in norm_results.items():
                # The third return value (nuclear norm of the fit) is not recorded.
                train_err, test_err, _ = train(X_train_s, y_train_0, X_test_s, y_test_0, lr, l, epochs, batch_size, norm=norm)
                # setdefault creates the per-lambda list the first time it is needed.
                train_dict[str(s)].setdefault(str(l), []).append(train_err)
                test_dict[str(s)].setdefault(str(l), []).append(test_err)

# Plot the mean and standard deviation (across repetitions) of each error
# against the anisotropy level s. Each lambda gets four curves
# (nuc/fro x train/test), all on one figure. Train curves use the default
# line style; test curves are dashed.
curve_specs = (
    (train_error_nuc_dict, "Nuc train", {}),
    (test_error_nuc_dict, "Nuc test", {'ls': '--'}),
    (train_error_fro_dict, "Fro train", {}),
    (test_error_fro_dict, "Fro test", {'ls': '--'}),
)
for l in lam:
    for results, name, style in curve_specs:
        # One row per s; each row holds the errors appended once per repetition.
        errors = np.array([results[str(s)][str(l)] for s in s_list])
        # Put lambda in the label. Without it, every lambda contributes four
        # identically-labelled curves and the legend cannot tell them apart.
        plt.errorbar(s_list, errors.mean(axis=1), yerr=errors.std(axis=1),
                     label="{} (lam={})".format(name, l), **style)

plt.legend()

# Save the raw results before the figure. They are the expensive output of
# the run and must not be lost if writing the figure fails. np.save stores
# each dict as a pickled 0-d object array; read it back with
# np.load(path, allow_pickle=True).item().
np.save("train_nuc_ani_d={}.npy".format(d), train_error_nuc_dict)
np.save("test_nuc_ani_d={}.npy".format(d), test_error_nuc_dict)
np.save("train_fro_ani_d={}.npy".format(d), train_error_fro_dict)
np.save("test_fro_ani_d={}.npy".format(d), test_error_fro_dict)

# savefig does not create missing directories, so create the output folder first.
os.makedirs("old_figs", exist_ok=True)
plt.savefig("old_figs/fig-{}.png".format(d), format='png')
