from utils import DataLoader, iso_scale, normalize, compute_quadratic_features
import numpy as np
import jax.numpy as jnp
from sklearn.preprocessing import scale
from sklearn.model_selection import ShuffleSplit, GridSearchCV
from quad_jax import QuadraticClassifier, batch_loss, batch_classifier
from sklearn.svm import LinearSVC
import pandas as pd
import matplotlib.pyplot as plt
from utils import *
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--dim", type=int, default=10, help="Dimension")
args = parser.parse_args()

lmbda = 1
epochs = 50000
noise = 0.9
isotropy = 2
n_runs = 1

data = {
    'anisotropic_nuc_train': [0 for _ in range(n_runs)],
    'anisotropic_fro_train': [0 for _ in range(n_runs)],
    'isotropic_nuc_train': [0 for _ in range(n_runs)],
    'isotropic_fro_train': [0 for _ in range(n_runs)],
    'anisotropic_nuc_test': [0 for _ in range(n_runs)],
    'anisotropic_fro_test': [0 for _ in range(n_runs)],
    'isotropic_nuc_test': [0 for _ in range(n_runs)],
    'isotropic_fro_test': [0 for _ in range(n_runs)]
}


def generate(n, d, eta, r):
    d_zero = int(np.floor(d**eta))
    O = np.linalg.svd(np.random.randn(d, d))[0]

    U = O[:, :d_zero]
    U_bot = O[:, d_zero:]

    r1 = r * np.sqrt(d_zero)
    Unif_1 = [r1 * x / np.linalg.norm(x) for x in np.random.randn(n, d_zero)]

    r2 = np.sqrt(d - d_zero)
    Unif_2 = [
        r2 * x / np.linalg.norm(x) for x in np.random.randn(n, d - d_zero)
    ]

    X = [
        np.dot(U, z_1) + np.dot(U_bot, z_2)
        for (z_1, z_2) in zip(Unif_1, Unif_2)
    ]

    return np.array(X), U


def data_generation(d, isotropic=False):
    A_true = np.zeros((d, d))
    A_true[0, 0] = 1

    n_train = 20 * d * int(np.log(d))
    n_test = 20 * d * int(np.log(d))

    if not isotropic:
        X, _ = generate(n_train + n_test, d, 0, d**(isotropy))
    else:
        X = np.array([
            np.sqrt(d) * x / np.linalg.norm(x)
            for x in np.random.randn(n_train + n_test, d)
        ])
    X_train = X[:n_train, :]
    X_test = X[n_train:, :]

    training_mean = np.mean(X_train, axis=0)
    training_component_stds = np.std(X_train, axis=0)
    X_train_scaled = scale(X_train)
    X_test_scaled = (X_test - training_mean) / training_component_stds

    y_train = inject_noise(
        np.sign(batch_classifier(A_true, np.zeros(d), 0, X_train_scaled)),
        noise)
    y_test = inject_noise(
        np.sign(batch_classifier(A_true, np.zeros(d), 0, X_test_scaled)),
        noise)

    sigma_train = np.average([np.outer(x, x) for x in X_train_scaled], axis=0)
    sigma_test = np.average([np.outer(x, x) for x in X_test_scaled], axis=0)

    train_dim = np.trace(sigma_train) / np.linalg.norm(sigma_train, ord=2)
    test_dim = np.trace(sigma_test) / np.linalg.norm(sigma_test, ord=2)

    print("INTRINSICS ARE :", train_dim, test_dim, d)

    return X_train_scaled, y_train, X_test_scaled, y_test


d = args.dim
for r in range(n_runs):
    nuc = QuadraticClassifier(dim=d, lmbda=lmbda, norm='nuc')
    fro = QuadraticClassifier(dim=d, lmbda=lmbda, norm='fro')

    X_train, y_train, X_test, y_test = data_generation(d, isotropic=False)

    nuc.fit(X_train,
            y_train,
            n_epoch=epochs,
            batch_size=len(X_train),
            plot=(r == n_runs - 1),
            fname="SYNTH NUC LOW" + str(d))
    fro.fit(X_train,
            y_train,
            n_epoch=epochs,
            batch_size=len(X_train),
            plot=(r == n_runs - 1),
            fname="SYNTH FRO LOW" + str(d))

    nuc_train_loss = batch_loss(nuc.A, nuc.b, nuc.c, X_train, y_train)
    nuc_test_loss = batch_loss(nuc.A, nuc.b, nuc.c, X_test, y_test)
    fro_train_loss = batch_loss(fro.A, fro.b, fro.c, X_train, y_train)
    fro_test_loss = batch_loss(fro.A, fro.b, fro.c, X_test, y_test)

    data['anisotropic_nuc_train'][r] = nuc_train_loss
    data['anisotropic_fro_train'][r] = fro_train_loss
    data['anisotropic_nuc_test'][r] = nuc_test_loss
    data['anisotropic_fro_test'][r] = fro_test_loss

    X_train, y_train, X_test, y_test = data_generation(d, isotropic=True)

    nuc.fit(X_train,
            y_train,
            n_epoch=epochs,
            batch_size=len(X_train),
            plot=(r == n_runs - 1),
            fname="SYNTH NUC HIGH" + str(d))
    fro.fit(X_train,
            y_train,
            n_epoch=epochs,
            batch_size=len(X_train),
            plot=(r == n_runs - 1),
            fname="SYNTH FRO HIGH" + str(d))

    nuc_train_loss = batch_loss(nuc.A, nuc.b, nuc.c, X_train, y_train)
    nuc_test_loss = batch_loss(nuc.A, nuc.b, nuc.c, X_test, y_test)
    fro_train_loss = batch_loss(fro.A, fro.b, fro.c, X_train, y_train)
    fro_test_loss = batch_loss(fro.A, fro.b, fro.c, X_test, y_test)

    data['isotropic_nuc_train'][r] = nuc_train_loss
    data['isotropic_fro_train'][r] = fro_train_loss
    data['isotropic_nuc_test'][r] = nuc_test_loss
    data['isotropic_fro_test'][r] = fro_test_loss

df = pd.DataFrame(data)
df.to_csv("../log/" + "proper_synth_dim" + str(args.dim) + ".csv")

print("DONEZOO")
