"""Construct Bilateral Coresets."""

#################################### Import Modules ####################################

import os
from time import time

import equinox as eqx
import jax
import jax.numpy as jnp
import jax.random as jr
import optax
from sklearn.datasets import fetch_openml
from sklearn.preprocessing import StandardScaler

from coreax.data import SupervisedData
from coreax.kernels import SquaredExponentialKernel, median_heuristic
from coreax.solvers import (
    JointKIP,
    M3D,
    SupervisedLinearBilateralDistributionCompression,
    SupervisedNonLinearBilateralDistributionCompression,
)
from coreax.solvers.autoencoders import Autoencoder, Decoder, Encoder

################################### Set JAX Settings ###################################

# Apply the JAX configuration flags up front, before any array is created: request
# the GPU platform and switch on 64-bit types.
for option_name, option_value in (
    ("jax_platform_name", "gpu"),
    ("jax_enable_x64", True),
):
    jax.config.update(option_name, option_value)

# Report which backend JAX actually selected.
print(f"Jax backend: {jax.default_backend()}")

################################## Set File Settings ###################################

# Base seed: used to create the PRNG keys and embedded in the output file names.
SEED = 10001

# Number of independent repetitions of the whole construction loop.
NUM_CONSTRUCTIONS = 10

# Settings shared by the coreset solvers.
CORESET_SIZE = 200
MAX_CORESET_ITERATIONS = 2500
NUM_CORESET_SEEDS = 10
CONVERGENCE_PARAMETER = 1e-8

# Dimension passed as ``intrinsic_dimension`` to the encoder, decoder and linear solver.
INTRINSIC_DIMENSION = 28

# Linear projection settings (the batch size is also used as the autoencoder batch
# size of the non-linear solver).
NUM_PROJECTION_EPOCHS = 10
NUM_PROJECTION_SEEDS = None
ORTHONORMAL = True
PROJECTION_BATCH_SIZE = 1024

# Autoencoder training length for the non-linear solver.
NUM_AUTOENCODER_EPOCHS = 50

# M3D-specific settings.
ITERATIONS_PER_MODEL = 5
M3D_BATCH_SIZE = 128

# Learning rates.
PROJECTION_RATE = 1e-3
AUTOENCODER_RATE = 1e-3
CORESET_FEATURE_RATE = 1e-2
CORESET_RESPONSE_RATE = 1e-2
M3D_RATE = 1


def _constant_rate_adam(rate):
    """Return an Adam optimiser whose learning rate is held constant at ``rate``."""
    return optax.adam(learning_rate=optax.constant_schedule(rate))


projection_optimiser = _constant_rate_adam(PROJECTION_RATE)
autoencoder_optimiser = _constant_rate_adam(AUTOENCODER_RATE)
coreset_feature_optimiser = _constant_rate_adam(CORESET_FEATURE_RATE)
coreset_response_optimiser = _constant_rate_adam(CORESET_RESPONSE_RATE)


##################################### Generate Data ####################################
print("\n------------------------ Loading Data ------------------------\n")
# NOTE(review): ``key`` and ``subkey`` are not referenced again in the visible script.
key, subkey = jr.split(jr.PRNGKey(SEED), 2)

# Fetch MNIST from OpenML and pull the features and labels out as NumPy arrays.
DATA_NAME = "mnist"
mnist = fetch_openml("mnist_784")
X = mnist.data.to_numpy()
Y = mnist.target.to_numpy()

# Integer class labels, and a one-hot encoding with one column per distinct label.
Y_CLASSES = jnp.array(list(map(int, Y)))
num_classes = len(jnp.unique(Y_CLASSES))
Y = jax.nn.one_hot(Y_CLASSES, num_classes)

# Min-max scale the features to [0, 1] using the extremes of the whole feature array.
feature_min, feature_max = X.min(), X.max()
X = (X - feature_min) / (feature_max - feature_min)

# The first 60000 rows form the training set; everything after is held out.
train_size = 60000
X, X_TEST = X[:train_size], X[train_size:]
Y, Y_TEST = Y[:train_size], Y[train_size:]
Y_CLASSES, Y_CLASSES_TEST = Y_CLASSES[:train_size], Y_CLASSES[train_size:]

# Standardise the one-hot responses with statistics fitted on the training rows only.
response_scaler = StandardScaler().fit(Y)
Y, Y_TEST = (
    jnp.asarray(response_scaler.transform(Y)),
    jnp.asarray(response_scaler.transform(Y_TEST)),
)

###################################### Save Setup ######################################

# Record every hyperparameter of this run, together with the processed data, so the
# experiment can be reproduced and evaluated later. Keys mirror the constant names.
setup_dict = {
    "SETUP": {
        "SEED": SEED,
        "DATA_NAME": DATA_NAME,
        "CORESET_SIZE": CORESET_SIZE,
        "INTRINSIC_DIMENSION": INTRINSIC_DIMENSION,
        "NUM_PROJECTION_EPOCHS": NUM_PROJECTION_EPOCHS,
        "NUM_AUTOENCODER_EPOCHS": NUM_AUTOENCODER_EPOCHS,
        "MAX_CORESET_ITERATIONS": MAX_CORESET_ITERATIONS,
        "NUM_CORESET_SEEDS": NUM_CORESET_SEEDS,
        "NUM_PROJECTION_SEEDS": NUM_PROJECTION_SEEDS,
        "ORTHONORMAL": ORTHONORMAL,
        "PROJECTION_BATCH_SIZE": PROJECTION_BATCH_SIZE,
        # The autoencoder is trained with the projection batch size.
        "AUTOENCODER_BATCH_SIZE": PROJECTION_BATCH_SIZE,
        "CONVERGENCE_PARAMETER": CONVERGENCE_PARAMETER,
        "PROJECTION_RATE": PROJECTION_RATE,
        "AUTOENCODER_RATE": AUTOENCODER_RATE,
        "CORESET_FEATURE_RATE": CORESET_FEATURE_RATE,
        "CORESET_RESPONSE_RATE": CORESET_RESPONSE_RATE,
        "NUM_CONSTRUCTIONS": NUM_CONSTRUCTIONS,
        # M3D settings, recorded alongside the others so the M3D baseline is
        # reproducible from the saved setup as well.
        "ITERATIONS_PER_MODEL": ITERATIONS_PER_MODEL,
        "M3D_BATCH_SIZE": M3D_BATCH_SIZE,
        "M3D_RATE": M3D_RATE,
    },
    "DATA": {
        "X": X,
        "Y": Y,
        "Y_CLASSES": Y_CLASSES,
        "X_TEST": X_TEST,
        "Y_TEST": Y_TEST,
        "Y_CLASSES_TEST": Y_CLASSES_TEST,
    },
}

# Create the output directory if needed; otherwise the save below (and every later
# save into the same directory) fails after the data has already been processed.
output_directory = f"./real/{DATA_NAME}"
os.makedirs(output_directory, exist_ok=True)

# NOTE(review): ``setup_dict`` is a nested dict rather than an array; presumably it is
# written as a pickled object array (as with ``numpy.save``), so loading it would
# need ``allow_pickle=True`` -- confirm against the reader.
jnp.save(f"{output_directory}/setup_{SEED}_{DATA_NAME}", setup_dict)

################################## Construct Coresets ##################################
# Squared-exponential kernels for the features and the responses, each constructed
# from the ``median_heuristic`` value of the first 1000 rows of the respective array.
feature_subsample, response_subsample = X[:1000], Y[:1000]
reconstruction_kernel = SquaredExponentialKernel(median_heuristic(feature_subsample))
response_kernel = SquaredExponentialKernel(median_heuristic(response_subsample))

# One PRNG key per construction, all derived from SEED.
construct_keys = jr.split(jr.key(SEED), num=NUM_CONSTRUCTIONS)


def encoder_generator(random_key):
    """
    Build a freshly initialised encoder with this script's fixed architecture.

    :param random_key: Key forwarded to ``Encoder`` as its ``random_key``
    :return: An ``Encoder`` configured with ``ambient_dimension=X.shape[1]``,
        ``intrinsic_dimension=INTRINSIC_DIMENSION`` and a single hidden layer of
        size 256
    """
    architecture = {
        "ambient_dimension": X.shape[1],
        "intrinsic_dimension": INTRINSIC_DIMENSION,
        "num_hidden_layers": 1,
        "hidden_layer_sizes": [256],
    }
    return Encoder(random_key=random_key, **architecture)


def _timed_reduce(solver, dataset):
    """
    Run ``solver.reduce(dataset)`` and measure its wall-clock duration.

    JAX dispatches work asynchronously, so the clock is only stopped once every array
    in the result has finished computing. The measured time still includes any
    compilation triggered by the call.

    :param solver: Object exposing a ``reduce(dataset)`` method that returns a
        ``(coreset, state)`` pair
    :param dataset: Dataset handed to ``solver.reduce``
    :return: Tuple ``(coreset, state, elapsed_seconds)``
    """
    start = time()
    result = solver.reduce(dataset)
    # Wait for all pending device work before reading the clock.
    jax.block_until_ready(result)
    elapsed = time() - start
    coreset, state = result
    return coreset, state, elapsed


# The datasets are identical for every construction, so build them once up front:
# standardised one-hot responses for the kernel-based solvers, integer class labels
# for M3D.
one_hot_data = SupervisedData(X, Y)
class_label_data = SupervisedData(X, Y_CLASSES)

for i in range(NUM_CONSTRUCTIONS):
    print(f"\n------- {i + 1}/{NUM_CONSTRUCTIONS} -------\n")

    # All solvers of one construction share this key, as before. The encoder and the
    # decoder get distinct sub-keys so that their initialisations are not drawn from
    # one and the same key.
    construction_key = construct_keys[i]
    encoder_key, decoder_key = jr.split(construction_key)

    print("BDC-NL...")
    nonlinear_solver = SupervisedNonLinearBilateralDistributionCompression(
        coreset_size=CORESET_SIZE,
        autoencoder=Autoencoder(
            # Same architecture as the encoders handed to M3D below.
            encoder=encoder_generator(encoder_key),
            decoder=Decoder(
                random_key=decoder_key,
                ambient_dimension=X.shape[1],
                intrinsic_dimension=INTRINSIC_DIMENSION,
                num_hidden_layers=1,
                hidden_layer_sizes=[256],
                output_transformation=jax.nn.sigmoid,
            ),
        ),
        random_key=construction_key,
        reconstruction_kernel=reconstruction_kernel,
        response_kernel=response_kernel,
        compression_kernel="median_heuristic",
        num_autoencoder_epochs=NUM_AUTOENCODER_EPOCHS,
        # Use the optimiser configured with AUTOENCODER_RATE (the projection
        # optimiser was passed here previously).
        autoencoder_optimiser=autoencoder_optimiser,
        num_coreset_seeds=NUM_CORESET_SEEDS,
        max_coreset_iterations=MAX_CORESET_ITERATIONS,
        convergence_parameter=CONVERGENCE_PARAMETER,
        coreset_feature_optimiser=coreset_feature_optimiser,
        coreset_response_optimiser=coreset_response_optimiser,
        track_info=False,
        autoencoder_batch_size=PROJECTION_BATCH_SIZE,
        validation_features=X_TEST[:1000],
        validation_responses=Y_TEST[:1000],
    )
    nonlinear_coreset, nonlinear_state, nonlinear_timer = _timed_reduce(
        nonlinear_solver, one_hot_data
    )

    # Save autoencoder ready for deserialisation
    eqx.tree_serialise_leaves(
        f"./real/{DATA_NAME}/{DATA_NAME}_model_{i}.eqx", nonlinear_state.autoencoder
    )

    print("BDC-L...")
    linear_solver = SupervisedLinearBilateralDistributionCompression(
        coreset_size=CORESET_SIZE,
        intrinsic_dimension=INTRINSIC_DIMENSION,
        random_key=construction_key,
        reconstruction_kernel=reconstruction_kernel,
        response_kernel=response_kernel,
        compression_kernel="median_heuristic",
        orthonormal=ORTHONORMAL,
        num_projection_seeds=NUM_PROJECTION_SEEDS,
        num_projection_epochs=NUM_PROJECTION_EPOCHS,
        projection_optimiser=projection_optimiser,
        num_coreset_seeds=NUM_CORESET_SEEDS,
        max_coreset_iterations=MAX_CORESET_ITERATIONS,
        projection_convergence_parameter=CONVERGENCE_PARAMETER,
        coreset_convergence_parameter=CONVERGENCE_PARAMETER,
        coreset_feature_optimiser=coreset_feature_optimiser,
        coreset_response_optimiser=coreset_response_optimiser,
        track_info=False,
        projection_batch_size=PROJECTION_BATCH_SIZE,
        validation_features=X_TEST[:1000],
        validation_responses=Y_TEST[:1000],
    )
    linear_coreset, linear_state, linear_timer = _timed_reduce(
        linear_solver, one_hot_data
    )

    print("ADC...")
    kip_solver = JointKIP(
        coreset_size=CORESET_SIZE,
        random_key=construction_key,
        feature_kernel=reconstruction_kernel,
        response_kernel=response_kernel,
        target_sample_size=None,
        coreset_sample_size=None,
        num_seeds=NUM_CORESET_SEEDS,
        max_iterations=MAX_CORESET_ITERATIONS,
        convergence_parameter=CONVERGENCE_PARAMETER,
        feature_optimiser=coreset_feature_optimiser,
        response_optimiser=coreset_response_optimiser,
        track_info=False,
    )
    kip_coreset, _, kip_timer = _timed_reduce(kip_solver, one_hot_data)

    print("M3D...")
    m3d_solver = M3D(
        coreset_size=CORESET_SIZE,
        random_key=construction_key,
        encoder_generator=encoder_generator,
        compression_kernel="median_heuristic",
        # Divide the iteration budget by the iterations spent per model.
        max_coreset_iterations=MAX_CORESET_ITERATIONS // ITERATIONS_PER_MODEL,
        coreset_optimiser=optax.sgd(optax.constant_schedule(M3D_RATE)),
        iterations_per_model=ITERATIONS_PER_MODEL,
        batch_size=M3D_BATCH_SIZE,
        track_info=False,
    )
    # M3D is given the integer class labels rather than the one-hot responses.
    m3d_coreset, _, m3d_timer = _timed_reduce(m3d_solver, class_label_data)

    ################################# Save the result #################################

    # NOTE(review): the linear timing is stored under "LINEAR_TIMES" while the other
    # methods use "..._TIME"; the key is kept as-is so existing readers keep working.
    # NOTE(review): ``linear_state.autoencoder`` is assumed to exist on the linear
    # solver's state -- confirm against the solver implementation.
    result_dict = {
        "RESULTS": {
            "LINEAR_RESULTS": {
                "LINEAR_AUTOENCODER": linear_state.autoencoder,
                "LINEAR_CORESET": linear_coreset,
                "LINEAR_TIMES": linear_timer,
            },
            "NONLINEAR_RESULTS": {
                "NONLINEAR_CORESET": nonlinear_coreset,
                "NONLINEAR_TIME": nonlinear_timer,
            },
            "KIP_RESULTS": {
                "KIP_CORESET": kip_coreset,
                "KIP_TIME": kip_timer,
            },
            "M3D_RESULTS": {
                "M3D_CORESET": m3d_coreset,
                "M3D_TIME": m3d_timer,
            },
        }
    }

    jnp.save(f"./real/{DATA_NAME}/results_{i}_{SEED}_{DATA_NAME}", result_dict)
