"""Construct Bilateral Coresets."""

#################################### Import Modules ####################################

import os
from time import time

import equinox as eqx
import jax
import jax.numpy as jnp
import jax.random as jr
import optax
from sklearn.datasets import fetch_openml
from sklearn.preprocessing import StandardScaler

from coreax.data import SupervisedData
from coreax.kernels import SquaredExponentialKernel, median_heuristic
from coreax.solvers import (
    SupervisedLinearBilateralDistributionCompression,
    SupervisedNonLinearBilateralDistributionCompression,
)
from coreax.solvers.autoencoders import Autoencoder, Encoder, Decoder
################################### Set JAX Settings ###################################

# Apply the JAX configuration flags in one place, then report which backend JAX
# actually selected so a mismatch with the requested platform is visible in the log.
for _flag, _setting in (("jax_platform_name", "gpu"), ("jax_enable_x64", True)):
    jax.config.update(_flag, _setting)
print(f"Jax backend: {jax.default_backend()}")

################################## Set File Settings ###################################

# Seed from which the PRNG keys used later in the script are derived.
SEED = 10001

# Sweep over the retained (intrinsic) dimension; every value is run once per
# construction with both the linear and the non-linear solver.
INTRINSIC_DIMENSIONS = [2, 5, 10, 15, 20, 25, 30, 40, 50]
CORESET_SIZE = 200
NUM_PROJECTION_EPOCHS = 10
NUM_AUTOENCODER_EPOCHS = 50

# Settings handed to the solvers. PROJECTION_BATCH_SIZE is also used as the
# autoencoder batch size, and CONVERGENCE_PARAMETER is shared by every convergence
# argument the solvers accept.
MAX_CORESET_ITERATIONS = 5000
NUM_CORESET_SEEDS = 10
NUM_PROJECTION_SEEDS = None
ORTHONORMAL = True
PROJECTION_BATCH_SIZE = 1024
NUM_CONSTRUCTIONS = 10  # number of repetitions of the whole dimension sweep
CONVERGENCE_PARAMETER = 1e-8

# Learning rates, one per optimiser built below.
PROJECTION_RATE = 1e-3
AUTOENCODER_RATE = 1e-3
CORESET_FEATURE_RATE = 1e-2
CORESET_RESPONSE_RATE = 1e-2


def _constant_rate_adam(rate: float) -> optax.GradientTransformation:
    """Return an Adam optimiser whose learning rate is a constant schedule at ``rate``."""
    return optax.adam(learning_rate=optax.constant_schedule(rate))


projection_optimiser = _constant_rate_adam(PROJECTION_RATE)
autoencoder_optimiser = _constant_rate_adam(AUTOENCODER_RATE)
coreset_feature_optimiser = _constant_rate_adam(CORESET_FEATURE_RATE)
coreset_response_optimiser = _constant_rate_adam(CORESET_RESPONSE_RATE)


####################################### Load Data ######################################
print("\n------------------------ Loading Data ------------------------\n")
key, subkey = jr.split(jr.PRNGKey(SEED), 2)

# Fetch MNIST from OpenML and pull features / targets out as NumPy arrays.
DATA_NAME = "mnist_latent"
mnist = fetch_openml("mnist_784")
X = mnist.data.to_numpy()
Y = mnist.target.to_numpy()

# Keep the integer class labels, and build a one-hot response matrix with one column
# per distinct label present in the data.
Y_CLASSES = jnp.array(list(map(int, Y)))
_num_classes = jnp.unique(Y_CLASSES).shape[0]
Y = jax.nn.one_hot(Y_CLASSES, _num_classes)

# Min-max scale the features to [0, 1] using the global minimum and maximum, which
# are computed over all rows before the train/test split.
_feature_min = X.min()
_feature_max = X.max()
X = (X - _feature_min) / (_feature_max - _feature_min)

# The first 60000 rows are used for training; everything after is held out.
_NUM_TRAIN = 60000
X, X_TEST = X[:_NUM_TRAIN], X[_NUM_TRAIN:]
Y, Y_TEST = Y[:_NUM_TRAIN], Y[_NUM_TRAIN:]
Y_CLASSES, Y_CLASSES_TEST = Y_CLASSES[:_NUM_TRAIN], Y_CLASSES[_NUM_TRAIN:]

# Standardise the one-hot responses; the scaler is fitted on the training responses
# only and then applied to both splits.
response_scaler = StandardScaler().fit(Y)
Y, Y_TEST = (
    jnp.asarray(response_scaler.transform(_responses)) for _responses in (Y, Y_TEST)
)

###################################### Save Setup ######################################

# Record every experiment setting together with the preprocessed data, so that a run
# can be inspected or re-analysed from the saved file alone. Keys and their order are
# kept stable because downstream code reads this file by key.
setup_dict = {
    "SETUP": {
        "SEED": SEED,
        "DATA_NAME": DATA_NAME,
        "CORESET_SIZE": CORESET_SIZE,
        "INTRINSIC_DIMENSIONS": INTRINSIC_DIMENSIONS,
        "NUM_PROJECTION_EPOCHS": NUM_PROJECTION_EPOCHS,
        "NUM_AUTOENCODER_EPOCHS": NUM_AUTOENCODER_EPOCHS,
        "MAX_CORESET_ITERATIONS": MAX_CORESET_ITERATIONS,
        "NUM_CORESET_SEEDS": NUM_CORESET_SEEDS,
        "NUM_PROJECTION_SEEDS": NUM_PROJECTION_SEEDS,
        "ORTHONORMAL": ORTHONORMAL,
        "PROJECTION_BATCH_SIZE": PROJECTION_BATCH_SIZE,
        # The autoencoder is trained with the same batch size as the projection.
        "AUTOENCODER_BATCH_SIZE": PROJECTION_BATCH_SIZE,
        "CONVERGENCE_PARAMETER": CONVERGENCE_PARAMETER,
        "PROJECTION_RATE": PROJECTION_RATE,
        "AUTOENCODER_RATE": AUTOENCODER_RATE,
        "CORESET_FEATURE_RATE": CORESET_FEATURE_RATE,
        "CORESET_RESPONSE_RATE": CORESET_RESPONSE_RATE,
        "NUM_CONSTRUCTIONS": NUM_CONSTRUCTIONS,
    },
    "DATA": {
        "X": X,
        "Y": Y,
        "Y_CLASSES": Y_CLASSES,
        "X_TEST": X_TEST,
        "Y_TEST": Y_TEST,
        "Y_CLASSES_TEST": Y_CLASSES_TEST,
    },
}

# Make sure the output directory exists before the first write; without this a fresh
# checkout fails with FileNotFoundError only after the dataset has been downloaded.
# Every later artefact of this script is written into the same directory.
os.makedirs(f"./real/{DATA_NAME}", exist_ok=True)

# NOTE(review): a dict (not an array) is passed to ``jnp.save``; this presumably
# relies on the object being pickled, so loaders will need ``allow_pickle=True`` —
# confirm against the analysis code.
jnp.save(f"./real/{DATA_NAME}/setup_{SEED}_{DATA_NAME}", setup_dict)

################################## Construct Coresets ##################################
# Each kernel is parameterised by the median heuristic evaluated on the first 1000
# training rows of the features and of the responses respectively.
_feature_heuristic = median_heuristic(X[:1000])
_response_heuristic = median_heuristic(Y[:1000])
reconstruction_kernel = SquaredExponentialKernel(_feature_heuristic)
response_kernel = SquaredExponentialKernel(_response_heuristic)

# One PRNG key per construction, all derived from the same seed.
_base_key = jr.key(SEED)
construct_keys = jr.split(_base_key, num=(NUM_CONSTRUCTIONS,))

# For every construction index and every retained dimension, fit the non-linear
# (autoencoder-based) and the linear bilateral solvers on the training data, time each
# ``reduce`` call, and write the autoencoder weights and a results dictionary to disk.
for i in range(NUM_CONSTRUCTIONS):
    # The key depends only on the construction index, so every intrinsic dimension
    # within one construction starts from the same randomness.
    # NOTE(review): this single key is passed unchanged to the encoder, the decoder
    # and both solvers. JAX recommends splitting a key rather than reusing it for
    # several consumers; it is left as-is here because splitting would change the
    # random draws and therefore the results of existing runs.
    construct_key = construct_keys[i]

    for p in INTRINSIC_DIMENSIONS:
        print(
            f"\n------- {i + 1}/{NUM_CONSTRUCTIONS}, {p} dimensions retained -------\n"
        )
        print("BDC-NL...")
        nonlinear_solver = SupervisedNonLinearBilateralDistributionCompression(
            coreset_size=CORESET_SIZE,
            autoencoder=Autoencoder(
                encoder=Encoder(
                    random_key=construct_key,
                    ambient_dimension=X.shape[1],
                    intrinsic_dimension=p,
                    num_hidden_layers=1,
                    hidden_layer_sizes=[256],
                ),
                decoder=Decoder(
                    random_key=construct_key,
                    ambient_dimension=X.shape[1],
                    intrinsic_dimension=p,
                    num_hidden_layers=1,
                    hidden_layer_sizes=[256],
                    # Features were min-max scaled, so the decoder output is passed
                    # through a sigmoid.
                    output_transformation=jax.nn.sigmoid,
                ),
            ),
            random_key=construct_key,
            reconstruction_kernel=reconstruction_kernel,
            response_kernel=response_kernel,
            compression_kernel="median_heuristic",
            num_autoencoder_epochs=NUM_AUTOENCODER_EPOCHS,
            # Use the optimiser built from AUTOENCODER_RATE. Previously the projection
            # optimiser was passed here, which left AUTOENCODER_RATE without effect
            # even though it is recorded in the saved setup.
            autoencoder_optimiser=autoencoder_optimiser,
            num_coreset_seeds=NUM_CORESET_SEEDS,
            max_coreset_iterations=MAX_CORESET_ITERATIONS,
            convergence_parameter=CONVERGENCE_PARAMETER,
            coreset_feature_optimiser=coreset_feature_optimiser,
            coreset_response_optimiser=coreset_response_optimiser,
            track_info=False,
            autoencoder_batch_size=PROJECTION_BATCH_SIZE,
            validation_features=X_TEST,
            validation_responses=Y_TEST,
        )

        # Wall-clock time of the full non-linear reduction, including construction of
        # the SupervisedData wrapper.
        t0 = time()
        nonlinear_coreset, nonlinear_state = nonlinear_solver.reduce(
            SupervisedData(X, Y)
        )
        nonlinear_timer = time() - t0

        # Save autoencoder ready for deserialisation
        eqx.tree_serialise_leaves(
            f"./real/{DATA_NAME}/{DATA_NAME}_model_{i}_{p}.eqx",
            nonlinear_state.autoencoder,
        )

        print("BDC-L...")
        linear_solver = SupervisedLinearBilateralDistributionCompression(
            coreset_size=CORESET_SIZE,
            intrinsic_dimension=p,
            random_key=construct_key,
            reconstruction_kernel=reconstruction_kernel,
            response_kernel=response_kernel,
            compression_kernel="median_heuristic",
            orthonormal=ORTHONORMAL,
            num_projection_seeds=NUM_PROJECTION_SEEDS,
            num_projection_epochs=NUM_PROJECTION_EPOCHS,
            projection_optimiser=projection_optimiser,
            num_coreset_seeds=NUM_CORESET_SEEDS,
            max_coreset_iterations=MAX_CORESET_ITERATIONS,
            projection_convergence_parameter=CONVERGENCE_PARAMETER,
            coreset_convergence_parameter=CONVERGENCE_PARAMETER,
            coreset_feature_optimiser=coreset_feature_optimiser,
            coreset_response_optimiser=coreset_response_optimiser,
            track_info=False,
            projection_batch_size=PROJECTION_BATCH_SIZE,
            validation_features=X_TEST,
            validation_responses=Y_TEST,
        )

        # Wall-clock time of the full linear reduction, measured the same way.
        t0 = time()
        linear_coreset, linear_state = linear_solver.reduce(SupervisedData(X, Y))
        linear_timer = time() - t0

        ################################# Save the result #################################

        # Key names are kept exactly as they were (including "LINEAR_TIMES" versus
        # "NONLINEAR_TIME") because downstream readers look results up by key.
        result_dict = {
            "RESULTS": {
                "LINEAR_RESULTS": {
                    "LINEAR_AUTOENCODER": linear_state.autoencoder,
                    "LINEAR_CORESET": linear_coreset,
                    "LINEAR_TIMES": linear_timer,
                },
                "NONLINEAR_RESULTS": {
                    "NONLINEAR_CORESET": nonlinear_coreset,
                    "NONLINEAR_TIME": nonlinear_timer,
                },
            },
        }

        jnp.save(f"./real/{DATA_NAME}/results_{i}_{p}_{SEED}_{DATA_NAME}", result_dict)
