"""Construct Bilateral Coresets."""

#################################### Import Modules ####################################

from time import time
import jax
import jax.random as jr
import jax.numpy as jnp
import equinox as eqx
import pandas as pd
import optax
from sklearn.preprocessing import StandardScaler

from coreax.solvers import (
    SupervisedLinearBilateralDistributionCompression,
    SupervisedNonLinearBilateralDistributionCompression,
)
from coreax.data import SupervisedData
from coreax.kernels import SquaredExponentialKernel, median_heuristic
from coreax.solvers.autoencoders import Autoencoder, Encoder, Decoder
################################### Set JAX Settings ###################################

# Apply the JAX configuration flags used by this script (GPU platform, x64 enabled),
# then print the backend that JAX reports as its default.
for _option, _setting in (("jax_platform_name", "gpu"), ("jax_enable_x64", True)):
    jax.config.update(_option, _setting)
print(f"Jax backend: {jax.default_backend()}")

################################## Set File Settings ###################################

# Seed from which every PRNG key in this script is derived.
SEED = 10001

# Experiment grid: each construction is repeated for every intrinsic dimension.
INTRINSIC_DIMENSIONS = [5, 10, 15, 20, 25, 30, 35, 40, 45, 50]
NUM_CONSTRUCTIONS = 25
CORESET_SIZE = 200

# Training budgets handed to the solvers.
NUM_PROJECTION_EPOCHS = 10
NUM_AUTOENCODER_EPOCHS = 50
MAX_CORESET_ITERATIONS = 5000
PROJECTION_BATCH_SIZE = 1024
CONVERGENCE_PARAMETER = 1e-8

# Remaining solver options.
NUM_CORESET_SEEDS = 10
NUM_PROJECTION_SEEDS = None
ORTHONORMAL = True

# Learning rates, one per optimiser built below.
PROJECTION_RATE = 1e-3
AUTOENCODER_RATE = 1e-3
CORESET_FEATURE_RATE = 1e-2
CORESET_RESPONSE_RATE = 1e-2


def _adam_with_constant_rate(rate):
    """Return an Adam optimiser driven by a constant schedule at ``rate``."""
    return optax.adam(learning_rate=optax.constant_schedule(rate))


projection_optimiser = _adam_with_constant_rate(PROJECTION_RATE)
autoencoder_optimiser = _adam_with_constant_rate(AUTOENCODER_RATE)
coreset_feature_optimiser = _adam_with_constant_rate(CORESET_FEATURE_RATE)
coreset_response_optimiser = _adam_with_constant_rate(CORESET_RESPONSE_RATE)


###################################### Load Data #######################################
print("\n------------------------ Loading Data ------------------------\n")
key, subkey = jr.split(jr.PRNGKey(SEED), 2)

DATA_NAME = "ct_slice_latent"
data = pd.read_csv(f"./real/{DATA_NAME}/ct_slice.csv")

# The leading column is an index, so it is discarded right after conversion.
D = jnp.array(data.to_numpy())[:, 1:]

# The last column is the response (kept two-dimensional); the rest are features.
X, Y = D[:, :-1], D[:, [-1]]

# Draw the row indices without replacement and reorder features and responses
# with the same index vector so the pairs stay aligned.
_num_rows = X.shape[0]
shuffle_idcs = jr.choice(
    jr.key(SEED), jnp.arange(_num_rows), shape=(_num_rows,), replace=False
)
X, Y = X[shuffle_idcs], Y[shuffle_idcs]

# The first 50000 shuffled rows are used for training; the remainder is held out.
X, X_TEST = X[:50000], X[50000:]
Y, Y_TEST = Y[:50000], Y[50000:]

# Standardise with scalers fitted on the training split only, then apply the
# fitted transform to both the training and the held-out arrays.
feature_scaler = StandardScaler().fit(X)
X, X_TEST = (jnp.asarray(feature_scaler.transform(part)) for part in (X, X_TEST))

response_scalaer = StandardScaler().fit(Y)
Y, Y_TEST = (jnp.asarray(response_scalaer.transform(part)) for part in (Y, Y_TEST))

###################################### Save Setup ######################################

# Snapshot of the experiment configuration and the preprocessed data splits, written
# next to the raw data so that a run can be identified by its seed and dataset name.
setup_dict = {
    "SETUP": {
        "SEED": SEED,
        "DATA_NAME": DATA_NAME,
        "CORESET_SIZE": CORESET_SIZE,
        "INTRINSIC_DIMENSIONS": INTRINSIC_DIMENSIONS,
        "NUM_PROJECTION_EPOCHS": NUM_PROJECTION_EPOCHS,
        "NUM_AUTOENCODER_EPOCHS": NUM_AUTOENCODER_EPOCHS,
        "MAX_CORESET_ITERATIONS": MAX_CORESET_ITERATIONS,
        "NUM_CORESET_SEEDS": NUM_CORESET_SEEDS,
        "NUM_PROJECTION_SEEDS": NUM_PROJECTION_SEEDS,
        "ORTHONORMAL": ORTHONORMAL,
        "PROJECTION_BATCH_SIZE": PROJECTION_BATCH_SIZE,
        # The autoencoder is trained with the projection batch size, so that value
        # is what gets recorded under the autoencoder key as well.
        "AUTOENCODER_BATCH_SIZE": PROJECTION_BATCH_SIZE,
        "CONVERGENCE_PARAMETER": CONVERGENCE_PARAMETER,
        "PROJECTION_RATE": PROJECTION_RATE,
        "AUTOENCODER_RATE": AUTOENCODER_RATE,
        "CORESET_FEATURE_RATE": CORESET_FEATURE_RATE,
        "CORESET_RESPONSE_RATE": CORESET_RESPONSE_RATE,
        "NUM_CONSTRUCTIONS": NUM_CONSTRUCTIONS,
    },
    "DATA": {
        "X": X,
        "Y": Y,
        "X_TEST": X_TEST,
        "Y_TEST": Y_TEST,
    },
}

jnp.save(f"./real/{DATA_NAME}/setup_{SEED}_{DATA_NAME}", setup_dict)

################################## Construct Coresets ##################################
# Each kernel is parameterised by the median heuristic evaluated on the first 1000
# training rows of the corresponding array.
reconstruction_kernel = SquaredExponentialKernel(median_heuristic(X[:1000]))
response_kernel = SquaredExponentialKernel(median_heuristic(Y[:1000]))

# One PRNG key per construction.
construct_keys = jr.split(jr.key(SEED), num=(NUM_CONSTRUCTIONS,))

# The training data does not change between runs, so wrap it once up front.
training_data = SupervisedData(X, Y)

for i in range(NUM_CONSTRUCTIONS):
    # NOTE(review): this one key seeds the encoder, the decoder and both solvers, and
    # is shared by every intrinsic dimension of construction ``i``. Confirm the reuse
    # is intentional (e.g. for comparability) rather than splitting it per consumer.
    construct_key = construct_keys[i]

    for p in INTRINSIC_DIMENSIONS:
        print(
            f"\n------- {i + 1}/{NUM_CONSTRUCTIONS}, {p} dimensions retained -------\n"
        )

        # ---------------------------- Non-linear (BDC-NL) ----------------------------
        print("BDC-NL...")
        nonlinear_solver = SupervisedNonLinearBilateralDistributionCompression(
            coreset_size=CORESET_SIZE,
            autoencoder=Autoencoder(
                encoder=Encoder(
                    random_key=construct_key,
                    ambient_dimension=X.shape[1],
                    intrinsic_dimension=p,
                    num_hidden_layers=1,
                    hidden_layer_sizes=[128],
                ),
                decoder=Decoder(
                    random_key=construct_key,
                    ambient_dimension=X.shape[1],
                    intrinsic_dimension=p,
                    num_hidden_layers=1,
                    hidden_layer_sizes=[128],
                ),
            ),
            random_key=construct_key,
            reconstruction_kernel=reconstruction_kernel,
            response_kernel=response_kernel,
            compression_kernel="median_heuristic",
            num_autoencoder_epochs=NUM_AUTOENCODER_EPOCHS,
            # Use the optimiser built from AUTOENCODER_RATE (the rate recorded in the
            # saved setup) instead of the projection optimiser.
            autoencoder_optimiser=autoencoder_optimiser,
            num_coreset_seeds=NUM_CORESET_SEEDS,
            max_coreset_iterations=MAX_CORESET_ITERATIONS,
            convergence_parameter=CONVERGENCE_PARAMETER,
            coreset_feature_optimiser=coreset_feature_optimiser,
            coreset_response_optimiser=coreset_response_optimiser,
            track_info=False,
            autoencoder_batch_size=PROJECTION_BATCH_SIZE,
            validation_features=X_TEST,
            validation_responses=Y_TEST,
        )

        # JAX dispatches work asynchronously; block on the outputs so the measured
        # wall-clock time covers the computation and not only its dispatch.
        t0 = time()
        nonlinear_coreset, nonlinear_state = jax.block_until_ready(
            nonlinear_solver.reduce(training_data)
        )
        nonlinear_timer = time() - t0

        # Save autoencoder ready for deserialisation
        eqx.tree_serialise_leaves(
            f"./real/{DATA_NAME}/{DATA_NAME}_model_{i}_{p}.eqx",
            nonlinear_state.autoencoder,
        )

        # ------------------------------- Linear (BDC-L) ------------------------------
        print("BDC-L...")
        linear_solver = SupervisedLinearBilateralDistributionCompression(
            coreset_size=CORESET_SIZE,
            intrinsic_dimension=p,
            random_key=construct_key,
            reconstruction_kernel=reconstruction_kernel,
            response_kernel=response_kernel,
            compression_kernel="median_heuristic",
            orthonormal=ORTHONORMAL,
            num_projection_seeds=NUM_PROJECTION_SEEDS,
            num_projection_epochs=NUM_PROJECTION_EPOCHS,
            projection_optimiser=projection_optimiser,
            num_coreset_seeds=NUM_CORESET_SEEDS,
            max_coreset_iterations=MAX_CORESET_ITERATIONS,
            projection_convergence_parameter=CONVERGENCE_PARAMETER,
            coreset_convergence_parameter=CONVERGENCE_PARAMETER,
            coreset_feature_optimiser=coreset_feature_optimiser,
            coreset_response_optimiser=coreset_response_optimiser,
            track_info=False,
            projection_batch_size=PROJECTION_BATCH_SIZE,
            validation_features=X_TEST,
            validation_responses=Y_TEST,
        )

        t0 = time()
        linear_coreset, linear_state = jax.block_until_ready(
            linear_solver.reduce(training_data)
        )
        linear_timer = time() - t0

        # ------------------------------ Save the result ------------------------------
        # Dictionary keys are kept exactly as they were so that existing readers of
        # the saved results continue to work.
        result_dict = {
            "RESULTS": {
                "LINEAR_RESULTS": {
                    "LINEAR_AUTOENCODER": linear_state.autoencoder,
                    "LINEAR_CORESET": linear_coreset,
                    "LINEAR_TIMES": linear_timer,
                },
                "NONLINEAR_RESULTS": {
                    "NONLINEAR_CORESET": nonlinear_coreset,
                    "NONLINEAR_TIME": nonlinear_timer,
                },
            }
        }

        jnp.save(f"./real/{DATA_NAME}/results_{i}_{p}_{SEED}_{DATA_NAME}", result_dict)
