from quad_ascent import *
from utils import DataLoader, iso_scale, normalize, compute_quadratic_features
import numpy as np
import jax.numpy as jnp
from sklearn.preprocessing import scale
from sklearn.svm import LinearSVC
import pandas as pd
import matplotlib.pyplot as plt
from utils import *
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--dim", type=int, default=10, help="Dimension")
parser.add_argument("--id", type=int, default=0, help="Dimension")
args = parser.parse_args()

lmbda = 1
epochs = 5000
noise = 0.001
isotropy = 2

np.random.seed(23451*2**args.id)
A = np.random.randn(args.dim, args.dim)
A_true = A/np.linalg.norm(A, ord='nuc')


def generate(n, d, eta, r):
    d_zero = int(np.floor(d**eta))
    O = np.linalg.svd(np.random.randn(d, d))[0]

    U = O[:, :d_zero]
    U_bot = O[:, d_zero:]

    r1 = r * np.sqrt(d_zero)
    Unif_1 = [r1 * x / np.linalg.norm(x) for x in np.random.randn(n, d_zero)]

    r2 = np.sqrt(d - d_zero)
    Unif_2 = [
        r2 * x / np.linalg.norm(x) for x in np.random.randn(n, d - d_zero)
    ]

    X = [
        np.dot(U, z_1) + np.dot(U_bot, z_2)
        for (z_1, z_2) in zip(Unif_1, Unif_2)
    ]

    return np.array(X), U



def data_generation(d, isotropic=False):

    n_train = 5 * d * int(np.log(d))
    n_test = 5 * d * int(np.log(d))

    X = np.array([
            np.sqrt(d) * x / np.linalg.norm(x)
            for x in np.random.randn(n_train + n_test, d)
        ])

    if not isotropic:
        X, _ = generate(n_train + n_test, d, 0.5, d**2)
    else:
        X, _ = generate(n_train + n_test, d, 0.5, 1)

    X_train = X[:n_train, :]
    X_test = X[n_train:, :]

    training_mean = np.mean(X_train, axis=0)
    training_component_stds = np.std(X_train, axis=0)
    X_train_scaled = scale(X_train)
    X_test_scaled = (X_test - training_mean) / training_component_stds

    y_train = inject_noise(
        np.sign(batch_classifier(A_true, np.zeros(d), 0, X_train_scaled)),
        noise)
    y_test = inject_noise(
        np.sign(batch_classifier(A_true, np.zeros(d), 0, X_test_scaled)),
        noise)

    sigma_train = np.average([np.outer(x, x) for x in X_train_scaled], axis=0)
    sigma_test = np.average([np.outer(x, x) for x in X_test_scaled], axis=0)

    train_dim = np.trace(sigma_train) / np.linalg.norm(sigma_train, ord=2)
    test_dim = np.trace(sigma_test) / np.linalg.norm(sigma_test, ord=2)

    print("INTRINSICS ARE :", train_dim, test_dim, d)

    return X_train_scaled, y_train, X_test_scaled, y_test

d = args.dim
nuc = QuadraticClassifier(dim=d, lmbda=lmbda, norm='nuc')
fro = QuadraticClassifier(dim=d, lmbda=lmbda, norm='fro')

X_train, y_train, X_test, y_test = data_generation(d, isotropic=False)

nuc.fit(X_train,
        y_train, 
        X_test, 
        y_test,
        n_epoch=epochs,
        plot=(args.id == 0),
        fname="worst nuc aniso" + str(d))

NUC_ANISO_MAX = batch_loss(nuc.A, nuc.b, nuc.c, X_train, y_train, X_test, y_test)

fro.fit(X_train,
        y_train, 
        X_test, 
        y_test,
        n_epoch=epochs,
        plot=(args.id == 0),
        fname="worst fro aniso" + str(d))

FRO_ANISO_MAX = batch_loss(fro.A, fro.b, fro.c, X_train, y_train, X_test, y_test)

X_train, y_train, X_test, y_test = data_generation(d, isotropic=True)

nuc.fit(X_train,
        y_train, 
        X_test, 
        y_test,
        n_epoch=epochs,
        plot=(args.id == 0),
        fname="worst nuc iso" + str(d))

NUC_ISO_MAX = batch_loss(nuc.A, nuc.b, nuc.c, X_train, y_train, X_test, y_test)

fro.fit(X_train,
        y_train, 
        X_test, 
        y_test,
        n_epoch=epochs,
        plot=(args.id == 0),
        fname="worst fro iso" + str(d))

FRO_ISO_MAX = batch_loss(fro.A, fro.b, fro.c, X_train, y_train, X_test, y_test)

data = {'nuc_iso': NUC_ISO_MAX,
        'nuc_aniso': NUC_ANISO_MAX,
        'fro_iso': FRO_ISO_MAX,
        'fro_aniso': FRO_ANISO_MAX}

df = pd.Series(data)
df.to_csv("../log/" + "worst_dim_" + str(args.dim) + "_run_"+str(args.id)+".csv")

