"""Construct Bilateral Coresets."""

#################################### Import Modules ####################################

import os
from time import time

import equinox as eqx
import jax
import jax.numpy as jnp
import jax.random as jr
import optax
from sklearn.preprocessing import StandardScaler

from coreax.data import SupervisedData
from coreax.kernels import PCIMQKernel, median_heuristic
from coreax.solvers import (
    JointKIP,
    SupervisedLinearBilateralDistributionCompression,
    SupervisedNonLinearBilateralDistributionCompression,
)
from coreax.solvers.autoencoders import Autoencoder, Decoder, Encoder
################################### Set JAX Settings ###################################

# Request the GPU backend and 64-bit precision before any arrays are created,
# then report which backend JAX actually selected.
for _flag, _setting in (("jax_platform_name", "gpu"), ("jax_enable_x64", True)):
    jax.config.update(_flag, _setting)
_backend = jax.default_backend()
print(f"Jax backend: {_backend}")
del _flag, _setting, _backend

################################## Set File Settings ###################################

# -- Data and model sizes ----------------------------------------------------------
SEED = 10001

DATA_SIZE = 20_000
TEST_SIZE = 1_000
CORESET_SIZE = 200
AMBIENT_DIMENSION = 200
INTRINSIC_DIMENSION = 3
NUM_PROJECTION_EPOCHS = 3
NUM_AUTOENCODER_EPOCHS = 3

# -- Solver settings ---------------------------------------------------------------
MAX_CORESET_ITERATIONS = 5_000
NUM_CORESET_SEEDS = 1
NUM_PROJECTION_SEEDS = 1
ORTHONORMAL = True
PROJECTION_BATCH_SIZE = 64
NUM_CONSTRUCTIONS = 25
CONVERGENCE_PARAMETER = 1e-8

# -- Learning rates ----------------------------------------------------------------
PROJECTION_RATE = 1e-2
AUTOENCODER_RATE = 1e-2
CORESET_FEATURE_RATE = 1e-2
CORESET_RESPONSE_RATE = 1e-2


def _constant_rate_adam(rate):
    """Build an Adam optimiser whose learning rate is held fixed at ``rate``."""
    schedule = optax.constant_schedule(rate)
    return optax.adam(learning_rate=schedule)


projection_optimiser = _constant_rate_adam(PROJECTION_RATE)
autoencoder_optimiser = _constant_rate_adam(AUTOENCODER_RATE)
coreset_feature_optimiser = _constant_rate_adam(CORESET_FEATURE_RATE)
coreset_response_optimiser = _constant_rate_adam(CORESET_RESPONSE_RATE)


##################################### Generate Data ####################################
print("\n------------------------ Loading Data ------------------------\n")
key_u, key_v, key_noise, key_eps, key_proj = jr.split(jr.PRNGKey(SEED), 5)

DATA_NAME = "swiss_roll_linear_imq"
GEN_SIZE = DATA_SIZE + TEST_SIZE

# Latent Swiss-roll parameters: angle U and height V, one draw per generated point.
U = jr.uniform(key_u, (GEN_SIZE,), minval=1.5 * jnp.pi, maxval=4.5 * jnp.pi)
V = jr.uniform(key_v, (GEN_SIZE,), minval=0.0, maxval=20.0)

# Three roll coordinates per point plus Gaussian noise with standard deviation 0.5.
_roll_coordinates = jnp.stack([U * jnp.cos(U), V, U * jnp.sin(U)], axis=1)
_roll_noise = 0.5 * jr.normal(key_noise, shape=(GEN_SIZE, 3))
X_INTRINSIC = _roll_coordinates + _roll_noise
del _roll_coordinates, _roll_noise

# Random linear map of the roll coordinates into AMBIENT_DIMENSION features
# (the matmul requires INTRINSIC_DIMENSION to equal the 3 stacked coordinates).
V_TRUE = jr.normal(key_proj, (AMBIENT_DIMENSION, INTRINSIC_DIMENSION))
features = X_INTRINSIC @ V_TRUE.T

# Response: smooth function of the rescaled latent parameters plus noise with
# standard deviation 0.1, stored as a column vector.
u = U / (3.0 * jnp.pi)  # ∈ [0.5, 1.5]
v = V / 20.0  # ∈ [0, 1]
f0 = 4.0 * (u - 0.5) ** 2 + jnp.pi * v
responses = (f0 + 0.1 * jr.normal(key_eps, (GEN_SIZE,))).reshape(-1, 1)

# First DATA_SIZE rows are used for construction; the last TEST_SIZE rows are held out.
X, X_TEST = features[:DATA_SIZE], features[DATA_SIZE:]
Y, Y_TEST = responses[:DATA_SIZE], responses[DATA_SIZE:]
del features, responses


def _standardise(train, test):
    """Fit a ``StandardScaler`` on ``train`` only and apply it to both splits.

    Returns ``(scaler, train_scaled, test_scaled)``, with both scaled splits
    converted back to JAX arrays.
    """
    scaler = StandardScaler().fit(train)
    return (
        scaler,
        jnp.asarray(scaler.transform(train)),
        jnp.asarray(scaler.transform(test)),
    )


feature_scaler, X, X_TEST = _standardise(X, X_TEST)
response_scaler, Y, Y_TEST = _standardise(Y, Y_TEST)


###################################### Save Setup ######################################

# Record every hyper-parameter used below together with the (standardised) data
# splits, so each run can be reproduced and evaluated later.
setup_dict = {
    "SETUP": {
        "SEED": SEED,
        "DATA_NAME": DATA_NAME,
        "CORESET_SIZE": CORESET_SIZE,
        "INTRINSIC_DIMENSION": INTRINSIC_DIMENSION,
        "NUM_PROJECTION_EPOCHS": NUM_PROJECTION_EPOCHS,
        "NUM_AUTOENCODER_EPOCHS": NUM_AUTOENCODER_EPOCHS,
        "MAX_CORESET_ITERATIONS": MAX_CORESET_ITERATIONS,
        "NUM_CORESET_SEEDS": NUM_CORESET_SEEDS,
        "NUM_PROJECTION_SEEDS": NUM_PROJECTION_SEEDS,
        "ORTHONORMAL": ORTHONORMAL,
        "PROJECTION_BATCH_SIZE": PROJECTION_BATCH_SIZE,
        # The autoencoder is trained with the projection batch size (there is no
        # separate autoencoder batch-size setting in this script).
        "AUTOENCODER_BATCH_SIZE": PROJECTION_BATCH_SIZE,
        "CONVERGENCE_PARAMETER": CONVERGENCE_PARAMETER,
        "PROJECTION_RATE": PROJECTION_RATE,
        "AUTOENCODER_RATE": AUTOENCODER_RATE,
        "CORESET_FEATURE_RATE": CORESET_FEATURE_RATE,
        "CORESET_RESPONSE_RATE": CORESET_RESPONSE_RATE,
        "NUM_CONSTRUCTIONS": NUM_CONSTRUCTIONS,
    },
    "DATA": {
        "X": X,
        "Y": Y,
        "X_TEST": X_TEST,
        "Y_TEST": Y_TEST,
    },
}

# ``jnp.save`` (numpy's ``save``) does not create missing parent directories, so
# a fresh checkout would fail here; create the output folder first. The results
# and model files written later in the script go to this same folder.
os.makedirs(f"./synthetic/{DATA_NAME}", exist_ok=True)
# A dict is saved as a pickled 0-d object array with a ".npy" suffix appended;
# presumably readers load it with ``allow_pickle=True`` and ``.item()`` — confirm.
jnp.save(f"./synthetic/{DATA_NAME}/setup_{SEED}_{DATA_NAME}", setup_dict)

################################## Construct Coresets ##################################
# Kernel bandwidths are set by the median heuristic on the first 1000 training rows.
reconstruction_kernel = PCIMQKernel(median_heuristic(X[:1000]))
response_kernel = PCIMQKernel(median_heuristic(Y[:1000]))

# NOTE(review): this root key reuses SEED, which also seeded the data-generation
# keys; consider deriving it from a distinct seed to make the separation explicit.
construct_keys = jr.split(jr.key(SEED), num=(NUM_CONSTRUCTIONS,))

# Every solver reduces the same supervised dataset, so build it once.
training_data = SupervisedData(X, Y)


def _timed_reduce(solver, dataset):
    """Run ``solver.reduce(dataset)`` and return ``(coreset, state, seconds)``.

    JAX dispatches device work asynchronously, so the outputs are blocked on
    before the clock stops. Without this, the timer could stop before the
    computation has actually finished.
    """
    start = time()
    coreset, state = solver.reduce(dataset)
    jax.block_until_ready((coreset, state))
    return coreset, state, time() - start


for i in range(NUM_CONSTRUCTIONS):
    print(f"\n------- {i + 1}/{NUM_CONSTRUCTIONS} -------\n")

    # Give the encoder, the decoder and the solvers independent keys. Previously
    # one key was reused for all of them, which correlates the two networks'
    # initialisations. All three solvers still share ``solver_key``, as before,
    # so each method starts from the same key.
    solver_key, encoder_key, decoder_key = jr.split(construct_keys[i], 3)

    print("BDC-NL...")
    nonlinear_solver = SupervisedNonLinearBilateralDistributionCompression(
        coreset_size=CORESET_SIZE,
        autoencoder=Autoencoder(
            encoder=Encoder(
                random_key=encoder_key,
                ambient_dimension=X.shape[1],
                intrinsic_dimension=INTRINSIC_DIMENSION,
                num_hidden_layers=2,
                hidden_layer_sizes=[128, 64],
            ),
            decoder=Decoder(
                random_key=decoder_key,
                ambient_dimension=X.shape[1],
                intrinsic_dimension=INTRINSIC_DIMENSION,
                num_hidden_layers=2,
                hidden_layer_sizes=[64, 128],
            ),
        ),
        random_key=solver_key,
        reconstruction_kernel=reconstruction_kernel,
        response_kernel=response_kernel,
        compression_kernel="median_heuristic",
        num_autoencoder_epochs=NUM_AUTOENCODER_EPOCHS,
        # Use the optimiser built from AUTOENCODER_RATE. The projection optimiser
        # was previously passed here, leaving AUTOENCODER_RATE (which is recorded
        # in the saved setup) unused.
        autoencoder_optimiser=autoencoder_optimiser,
        num_coreset_seeds=NUM_CORESET_SEEDS,
        max_coreset_iterations=MAX_CORESET_ITERATIONS,
        convergence_parameter=CONVERGENCE_PARAMETER,
        coreset_feature_optimiser=coreset_feature_optimiser,
        coreset_response_optimiser=coreset_response_optimiser,
        track_info=False,
        autoencoder_batch_size=PROJECTION_BATCH_SIZE,
        # NOTE(review): the held-out split doubles as validation data here; if it
        # is also used for final evaluation, those scores are not fully held out.
        validation_features=X_TEST,
        validation_responses=Y_TEST,
    )
    nonlinear_coreset, nonlinear_state, nonlinear_timer = _timed_reduce(
        nonlinear_solver, training_data
    )

    # Save the trained autoencoder leaves so it can be deserialised later into a
    # model with the same structure.
    eqx.tree_serialise_leaves(
        f"./synthetic/{DATA_NAME}/{DATA_NAME}_model_{i}.eqx",
        nonlinear_state.autoencoder,
    )

    print("BDC-L...")
    linear_solver = SupervisedLinearBilateralDistributionCompression(
        coreset_size=CORESET_SIZE,
        intrinsic_dimension=INTRINSIC_DIMENSION,
        random_key=solver_key,
        reconstruction_kernel=reconstruction_kernel,
        response_kernel=response_kernel,
        compression_kernel="median_heuristic",
        orthonormal=ORTHONORMAL,
        num_projection_seeds=NUM_PROJECTION_SEEDS,
        num_projection_epochs=NUM_PROJECTION_EPOCHS,
        projection_optimiser=projection_optimiser,
        num_coreset_seeds=NUM_CORESET_SEEDS,
        max_coreset_iterations=MAX_CORESET_ITERATIONS,
        projection_convergence_parameter=CONVERGENCE_PARAMETER,
        coreset_convergence_parameter=CONVERGENCE_PARAMETER,
        coreset_feature_optimiser=coreset_feature_optimiser,
        coreset_response_optimiser=coreset_response_optimiser,
        track_info=False,
        projection_batch_size=PROJECTION_BATCH_SIZE,
        # NOTE(review): same held-out split used for validation as above.
        validation_features=X_TEST,
        validation_responses=Y_TEST,
    )
    linear_coreset, linear_state, linear_timer = _timed_reduce(
        linear_solver, training_data
    )

    print("ADC...")
    kip_solver = JointKIP(
        coreset_size=CORESET_SIZE,
        random_key=solver_key,
        feature_kernel=reconstruction_kernel,
        response_kernel=response_kernel,
        target_sample_size=None,
        coreset_sample_size=None,
        num_seeds=NUM_CORESET_SEEDS,
        max_iterations=MAX_CORESET_ITERATIONS,
        convergence_parameter=CONVERGENCE_PARAMETER,
        feature_optimiser=coreset_feature_optimiser,
        response_optimiser=coreset_response_optimiser,
        track_info=False,
    )
    kip_coreset, _, kip_timer = _timed_reduce(kip_solver, training_data)

    ################################# Save the result #################################

    # The same key names are kept so that existing readers of these files still work.
    result_dict = {
        "RESULTS": {
            "LINEAR_RESULTS": {
                "LINEAR_AUTOENCODER": linear_state.autoencoder,
                "LINEAR_CORESET": linear_coreset,
                # NOTE: plural "TIMES" differs from the other "*_TIME" keys; kept
                # as-is for compatibility with existing result loaders.
                "LINEAR_TIMES": linear_timer,
            },
            "NONLINEAR_RESULTS": {
                "NONLINEAR_CORESET": nonlinear_coreset,
                "NONLINEAR_TIME": nonlinear_timer,
            },
            "KIP_RESULTS": {
                "KIP_CORESET": kip_coreset,
                "KIP_TIME": kip_timer,
            },
        }
    }

    jnp.save(f"./synthetic/{DATA_NAME}/results_{i}_{SEED}_{DATA_NAME}", result_dict)