"""Construct Bilateral Coresets."""

#################################### Import Modules ####################################

import os
from time import time

import equinox as eqx
import jax
import jax.numpy as jnp
import jax.random as jr
import optax
from jax import vmap
from sklearn.preprocessing import StandardScaler

from coreax.data import SupervisedData
from coreax.kernels import SquaredExponentialKernel, median_heuristic
from coreax.solvers import (
    SupervisedLinearBilateralDistributionCompression,
    SupervisedNonLinearBilateralDistributionCompression,
    JointKIP,
)
from coreax.solvers.autoencoders import Autoencoder, Encoder, Decoder
################################### Set JAX Settings ###################################

# Request the GPU platform and enable 64-bit values, then report the backend
# that JAX resolves as its default.
_JAX_OPTIONS = (
    ("jax_platform_name", "gpu"),
    ("jax_enable_x64", True),
)
for _option_name, _option_value in _JAX_OPTIONS:
    jax.config.update(_option_name, _option_value)
print(f"Jax backend: {jax.default_backend()}")

################################## Set File Settings ###################################

# Base seed from which the PRNG keys used later in the script are derived.
SEED = 10_001

# Dataset, coreset and dimensionality settings.
DATA_SIZE = 20_000
TEST_SIZE = 1_000
CORESET_SIZE = 200
AMBIENT_DIMENSION = 200
INTRINSIC_DIMENSION = 3

# Training lengths for the linear projection and the autoencoder.
NUM_PROJECTION_EPOCHS = 5
NUM_AUTOENCODER_EPOCHS = 5

# Solver settings shared by the constructions below.
MAX_CORESET_ITERATIONS = 5_000
NUM_CORESET_SEEDS = 1
NUM_PROJECTION_SEEDS = 1
ORTHONORMAL = True
PROJECTION_BATCH_SIZE = 64
NUM_CONSTRUCTIONS = 25
CONVERGENCE_PARAMETER = 1e-8

# Learning rates, one per optimiser defined below.
PROJECTION_RATE = 1e-2
AUTOENCODER_RATE = 1e-2
CORESET_FEATURE_RATE = 1e-2
CORESET_RESPONSE_RATE = 1e-2


def _constant_rate_adam(rate):
    """Return an Adam optimiser whose learning rate is a constant schedule."""
    return optax.adam(learning_rate=optax.constant_schedule(rate))


projection_optimiser = _constant_rate_adam(PROJECTION_RATE)
autoencoder_optimiser = _constant_rate_adam(AUTOENCODER_RATE)
coreset_feature_optimiser = _constant_rate_adam(CORESET_FEATURE_RATE)
coreset_response_optimiser = _constant_rate_adam(CORESET_RESPONSE_RATE)


##################################### Generate Data ####################################
print("\n------------------------ Loading Data ------------------------\n")
# NOTE: key_proj is not used anywhere in the visible script. It is kept so the
# five-way split, and therefore every other key and the generated data, stays
# exactly as before.
key_u, key_v, key_noise, key_eps, key_proj = jr.split(jr.PRNGKey(SEED), 5)

DATA_NAME = "swiss_roll_nonlinear"
GEN_SIZE = DATA_SIZE + TEST_SIZE

# Every file this script writes lives under this directory. Create it up front
# so the first write does not fail with FileNotFoundError on a fresh checkout.
os.makedirs(f"./synthetic/{DATA_NAME}", exist_ok=True)

# Swiss-roll coordinates (U cos U, V, U sin U) with additive Gaussian noise
# scaled by 0.5.
U = jr.uniform(key_u, (GEN_SIZE,), minval=1.5 * jnp.pi, maxval=4.5 * jnp.pi)
V = jr.uniform(key_v, (GEN_SIZE,), minval=0.0, maxval=20.0)
X_INTRINSIC = jnp.stack([U * jnp.cos(U), V, U * jnp.sin(U)], axis=1) + 0.5 * jr.normal(
    key_noise, shape=(GEN_SIZE, 3)
)

# Randomly initialised decoder used as the data-generating map; it is saved so
# the same map can be deserialised later.
true_decoder = Decoder(
    random_key=jr.key(2 * SEED),
    ambient_dimension=AMBIENT_DIMENSION,
    intrinsic_dimension=INTRINSIC_DIMENSION,
    num_hidden_layers=2,
    hidden_layer_sizes=[64, 128],
)
eqx.tree_serialise_leaves(
    f"./synthetic/{DATA_NAME}/{DATA_NAME}_true_decoder.eqx", true_decoder
)

# Apply the decoder to each intrinsic point independently.
# Presumably this yields AMBIENT_DIMENSION-dimensional features — TODO confirm
# against the Decoder implementation.
features = vmap(true_decoder)(X_INTRINSIC)

# Responses depend on the noise-free roll parameters, rescaled to unit-sized
# ranges, plus Gaussian noise scaled by 0.1; stored as a single column.
u = U / (3.0 * jnp.pi)  # ∈ [0.5, 1.5]
v = V / 20.0  # ∈ [0, 1]
f0 = 4.0 * (u - 0.5) ** 2 + jnp.pi * v
responses = (f0 + 0.1 * jr.normal(key_eps, (GEN_SIZE,))).reshape(-1, 1)

# The first DATA_SIZE rows form the training set, the remainder the test set.
X = features[:DATA_SIZE]
Y = responses[:DATA_SIZE]
X_TEST = features[DATA_SIZE:]
Y_TEST = responses[DATA_SIZE:]
del features, responses

# Standardise using statistics fitted on the training split only, then apply
# the same transform to the test split.
feature_scaler = StandardScaler().fit(X)
X = jnp.asarray(feature_scaler.transform(X))
X_TEST = jnp.asarray(feature_scaler.transform(X_TEST))

response_scaler = StandardScaler().fit(Y)
Y = jnp.asarray(response_scaler.transform(Y))
Y_TEST = jnp.asarray(response_scaler.transform(Y_TEST))

###################################### Save Setup ######################################

# Record the experiment configuration together with the standardised data
# splits so a later run can reload exactly what was used here.
setup_dict = {
    "SETUP": {
        "SEED": SEED,
        "DATA_NAME": DATA_NAME,
        "CORESET_SIZE": CORESET_SIZE,
        "INTRINSIC_DIMENSION": INTRINSIC_DIMENSION,
        "NUM_PROJECTION_EPOCHS": NUM_PROJECTION_EPOCHS,
        "NUM_AUTOENCODER_EPOCHS": NUM_AUTOENCODER_EPOCHS,
        "MAX_CORESET_ITERATIONS": MAX_CORESET_ITERATIONS,
        "NUM_CORESET_SEEDS": NUM_CORESET_SEEDS,
        "NUM_PROJECTION_SEEDS": NUM_PROJECTION_SEEDS,
        "ORTHONORMAL": ORTHONORMAL,
        "PROJECTION_BATCH_SIZE": PROJECTION_BATCH_SIZE,
        # The autoencoder is trained with the projection batch size.
        "AUTOENCODER_BATCH_SIZE": PROJECTION_BATCH_SIZE,
        "CONVERGENCE_PARAMETER": CONVERGENCE_PARAMETER,
        "PROJECTION_RATE": PROJECTION_RATE,
        "AUTOENCODER_RATE": AUTOENCODER_RATE,
        "CORESET_FEATURE_RATE": CORESET_FEATURE_RATE,
        "CORESET_RESPONSE_RATE": CORESET_RESPONSE_RATE,
        "NUM_CONSTRUCTIONS": NUM_CONSTRUCTIONS,
    },
    "DATA": {
        "X": X,
        "Y": Y,
        "X_TEST": X_TEST,
        "Y_TEST": Y_TEST,
    },
}

# NOTE(review): the saved object is a dict rather than an array, so it is
# presumably stored as a pickled object; loading it back likely requires
# allow_pickle=True — confirm.
jnp.save(f"./synthetic/{DATA_NAME}/setup_{SEED}_{DATA_NAME}", setup_dict)

################################## Construct Coresets ##################################
# Each kernel is parameterised by the median heuristic evaluated on the first
# 1000 training rows of the corresponding array.
feature_length_scale = median_heuristic(X[:1000])
reconstruction_kernel = SquaredExponentialKernel(feature_length_scale)
response_length_scale = median_heuristic(Y[:1000])
response_kernel = SquaredExponentialKernel(response_length_scale)

# One PRNG key per repeated construction.
# NOTE(review): this is seeded with the same SEED as the data-generation key;
# confirm that sharing the seed is intended.
construct_keys = jr.split(jr.key(SEED), num=(NUM_CONSTRUCTIONS,))

# Every solver compresses the same training data on every repetition, so the
# container is built once here instead of three times per iteration.
training_data = SupervisedData(X, Y)

for i in range(NUM_CONSTRUCTIONS):
    print(f"\n------- {i + 1}/{NUM_CONSTRUCTIONS} -------\n")

    # All three solvers in a repetition are seeded with the same key.
    # NOTE(review): the encoder, the decoder and the non-linear solver also share
    # this key. JAX guidance is to avoid key reuse; it is left unchanged here so
    # results stay comparable with earlier runs — confirm whether it is intended.
    construct_key = construct_keys[i]

    print("BDC-NL...")
    nonlinear_solver = SupervisedNonLinearBilateralDistributionCompression(
        coreset_size=CORESET_SIZE,
        autoencoder=Autoencoder(
            encoder=Encoder(
                random_key=construct_key,
                ambient_dimension=X.shape[1],
                intrinsic_dimension=INTRINSIC_DIMENSION,
                num_hidden_layers=2,
                hidden_layer_sizes=[128, 64],
            ),
            decoder=Decoder(
                random_key=construct_key,
                ambient_dimension=X.shape[1],
                intrinsic_dimension=INTRINSIC_DIMENSION,
                num_hidden_layers=2,
                hidden_layer_sizes=[64, 128],
            ),
        ),
        random_key=construct_key,
        reconstruction_kernel=reconstruction_kernel,
        response_kernel=response_kernel,
        compression_kernel="median_heuristic",
        num_autoencoder_epochs=NUM_AUTOENCODER_EPOCHS,
        # Use the dedicated autoencoder optimiser (built from AUTOENCODER_RATE)
        # rather than the projection optimiser, so that rate takes effect.
        autoencoder_optimiser=autoencoder_optimiser,
        num_coreset_seeds=NUM_CORESET_SEEDS,
        max_coreset_iterations=MAX_CORESET_ITERATIONS,
        convergence_parameter=CONVERGENCE_PARAMETER,
        coreset_feature_optimiser=coreset_feature_optimiser,
        coreset_response_optimiser=coreset_response_optimiser,
        track_info=False,
        autoencoder_batch_size=PROJECTION_BATCH_SIZE,
        validation_features=X_TEST,
        validation_responses=Y_TEST,
    )

    t0 = time()
    nonlinear_coreset, nonlinear_state = nonlinear_solver.reduce(training_data)
    # JAX dispatches work asynchronously; wait for the results to materialise so
    # the timer measures the computation rather than just its dispatch.
    jax.block_until_ready((nonlinear_coreset, nonlinear_state))
    nonlinear_timer = time() - t0

    # Save autoencoder ready for deserialisation
    eqx.tree_serialise_leaves(
        f"./synthetic/{DATA_NAME}/{DATA_NAME}_model_{i}.eqx",
        nonlinear_state.autoencoder,
    )

    print("BDC-L...")
    linear_solver = SupervisedLinearBilateralDistributionCompression(
        coreset_size=CORESET_SIZE,
        intrinsic_dimension=INTRINSIC_DIMENSION,
        random_key=construct_key,
        reconstruction_kernel=reconstruction_kernel,
        response_kernel=response_kernel,
        compression_kernel="median_heuristic",
        orthonormal=ORTHONORMAL,
        num_projection_seeds=NUM_PROJECTION_SEEDS,
        num_projection_epochs=NUM_PROJECTION_EPOCHS,
        projection_optimiser=projection_optimiser,
        num_coreset_seeds=NUM_CORESET_SEEDS,
        max_coreset_iterations=MAX_CORESET_ITERATIONS,
        projection_convergence_parameter=CONVERGENCE_PARAMETER,
        coreset_convergence_parameter=CONVERGENCE_PARAMETER,
        coreset_feature_optimiser=coreset_feature_optimiser,
        coreset_response_optimiser=coreset_response_optimiser,
        track_info=False,
        projection_batch_size=PROJECTION_BATCH_SIZE,
        validation_features=X_TEST,
        validation_responses=Y_TEST,
    )

    t0 = time()
    linear_coreset, linear_state = linear_solver.reduce(training_data)
    jax.block_until_ready((linear_coreset, linear_state))
    linear_timer = time() - t0

    print("ADC...")
    kip_solver = JointKIP(
        coreset_size=CORESET_SIZE,
        random_key=construct_key,
        feature_kernel=reconstruction_kernel,
        response_kernel=response_kernel,
        target_sample_size=None,
        coreset_sample_size=None,
        num_seeds=NUM_CORESET_SEEDS,
        max_iterations=MAX_CORESET_ITERATIONS,
        convergence_parameter=CONVERGENCE_PARAMETER,
        feature_optimiser=coreset_feature_optimiser,
        response_optimiser=coreset_response_optimiser,
        track_info=False,
    )

    t0 = time()
    kip_coreset, _ = kip_solver.reduce(training_data)
    jax.block_until_ready(kip_coreset)
    kip_timer = time() - t0

    ################################# Save the result #################################

    # NOTE(review): the linear timing is stored under "LINEAR_TIMES" (plural)
    # while the other two use "..._TIME". The key is kept as-is because readers
    # of these files may depend on it.
    result_dict = {
        "RESULTS": {
            "LINEAR_RESULTS": {
                "LINEAR_AUTOENCODER": linear_state.autoencoder,
                "LINEAR_CORESET": linear_coreset,
                "LINEAR_TIMES": linear_timer,
            },
            "NONLINEAR_RESULTS": {
                "NONLINEAR_CORESET": nonlinear_coreset,
                "NONLINEAR_TIME": nonlinear_timer,
            },
            "KIP_RESULTS": {
                "KIP_CORESET": kip_coreset,
                "KIP_TIME": kip_timer,
            },
        },
    }

    jnp.save(f"./synthetic/{DATA_NAME}/results_{i}_{SEED}_{DATA_NAME}", result_dict)
