"""Construct Bilateral Coresets."""

#################################### Import Modules ####################################

import jax
import jax.random as jr
import jax.numpy as jnp
import optax


from coreax.solvers import LinearBilateralDistributionCompression, KernelInducingPoints
from coreax.data import Data
from coreax.kernels import SquaredExponentialKernel, median_heuristic

################################### Set JAX Settings ###################################

# Ask JAX to use the GPU platform and turn on 64-bit types, then print the backend
# that is actually active so the run log records it.
# NOTE(review): these flags are presumably meant to be set before any JAX
# computation runs, so this block should stay ahead of the array code — confirm.
jax.config.update("jax_platform_name", "gpu")
jax.config.update("jax_enable_x64", True)
print(f"Jax backend: {jax.default_backend()}")

################################## Set File Settings ###################################

# Base seed: used for the synthetic data, the random lift and the solver keys.
SEED = 10001

# Size of the synthetic data set and dimension of the space it is lifted into.
DATA_SIZE = 20000
AMBIENT_DIMENSION = 250

# Coreset sizes swept in every construction, and the number of constructions.
CORESET_SIZES = [2, 5, 10, 25, 50, 75, 100, 150]
NUM_CONSTRUCTIONS = 5

# Projection-stage settings (forwarded to the bilateral compression solver).
ORTHONORMAL = True
NUM_PROJECTION_SEEDS = 1
NUM_PROJECTION_EPOCHS = 3
PROJECTION_BATCH_SIZE = 64
PROJECTION_CONVERGENCE_PARAMETER = 0
PROJECTION_RATE = 1e-3

# Coreset-stage settings (forwarded to both solvers).
NUM_CORESET_SEEDS = 1
MAX_CORESET_ITERATIONS = 3000
CORESET_CONVERGENCE_PARAMETER = 1e-7
CORESET_RATE = 1e-1

# Adam optimisers driven by constant learning-rate schedules. The two coreset-stage
# optimisers share CORESET_RATE but are built as separate objects.
projection_optimiser = optax.adam(
    learning_rate=optax.constant_schedule(PROJECTION_RATE),
)
coreset_optimiser = optax.adam(
    learning_rate=optax.constant_schedule(CORESET_RATE),
)
kip_optimiser = optax.adam(
    learning_rate=optax.constant_schedule(CORESET_RATE),
)

##################################### Generate Data ####################################
# Progress banner announcing the data-generation stage.
print("\n------------------------ Generating Data ------------------------\n")


def gmm_sample(probs, mus, covs, num_components, seed):
    """Sample from the Gaussian mixture model.

    :param probs: 1-D array of mixture weights, one entry per mixture component;
        passed to :func:`jax.random.choice` as the selection probabilities.
    :param mus: Array of component means, indexed by component along axis 0; the
        size of axis 1 is the dimension of the samples.
    :param covs: Array of component covariances, indexed by component along axis 0;
        each entry is promoted to at least two dimensions before use.
    :param num_components: Total number of points to draw from the mixture (despite
        the name, this is the sample count, not the number of mixture components).
    :param seed: Integer seed from which all randomness in this function is derived.
    :return: Array with ``num_components`` rows and ``mus.shape[1]`` columns holding
        the samples in shuffled order.
    """
    num_mixtures = probs.shape[0]

    # Derive independent sub-keys so that the component assignment, each component's
    # Gaussian draws and the final shuffle never share a random stream.
    keys = jr.split(jr.key(seed), num_mixtures + 2)
    assignment_key, shuffle_key, component_keys = keys[0], keys[1], keys[2:]

    assignments = jr.choice(
        assignment_key,
        num_mixtures,
        shape=(num_components,),
        replace=True,
        p=probs,
    )
    # A fixed-length bincount keeps entry k aligned with component k even when some
    # component receives no draws at all.
    mixture_counts = jnp.bincount(assignments, length=num_mixtures)

    # The empty leading chunk keeps the concatenation valid if every count is zero.
    chunks = [jnp.zeros((0, mus.shape[1]))]
    for k in range(num_mixtures):
        count = int(mixture_counts[k])
        if count == 0:
            continue
        chunks.append(
            jr.multivariate_normal(
                component_keys[k],
                mus[k],
                jnp.atleast_2d(covs[k]),
                shape=(count,),
            )
        )

    # Samples are grouped by component at this point; shuffle the rows.
    x = jnp.concatenate(chunks, axis=0)
    return jr.permutation(shuffle_key, x)


# Low-dimensional Gaussian mixture: nine planar component centres, divided by 1.5
# and then shifted so that their average sits at the origin.
DATA_NAME = "gaussian_mixture"
MUS = jnp.array(
    [[0, 0], [1, 1], [-1, -1], [-1, 1], [1, -1], [2, 0], [-2, 0], [0, 2], [0, -2]]
) / 1.5
MUS -= MUS.mean(axis=0)
INTRINSIC_DIMENSION = MUS.shape[-1]

# Equal mixture weights, normalised to sum to one.
WEIGHTS = jnp.ones(len(MUS))
WEIGHTS = WEIGHTS / WEIGHTS.sum()

# Every component shares the same isotropic covariance, identity / 20.
Sigma = jnp.eye(2) / 20
COVS = jnp.stack([Sigma] * WEIGHTS.shape[0])

# Draw the intrinsic (low-dimensional) data set.
X_INTRINSIC = gmm_sample(WEIGHTS, MUS, COVS, DATA_SIZE, SEED)

# Lift the samples into the ambient space with a random Gaussian matrix.
key, subkey = jr.split(jr.PRNGKey(SEED))
V_TRUE = jr.normal(subkey, (INTRINSIC_DIMENSION, AMBIENT_DIMENSION))
X = jnp.matmul(X_INTRINSIC, V_TRUE)

###################################### Save Setup ######################################

# Record every setting together with the generated data so that a run can be
# reproduced and analysed later.
setup_dict = {
    "SETUP": {
        "SEED": SEED,
        "DATA_NAME": DATA_NAME,
        "DATA_SIZE": DATA_SIZE,
        "AMBIENT_DIMENSION": AMBIENT_DIMENSION,
        "CORESET_SIZES": CORESET_SIZES,
        "INTRINSIC_DIMENSION": INTRINSIC_DIMENSION,
        "NUM_PROJECTION_EPOCHS": NUM_PROJECTION_EPOCHS,
        "MAX_CORESET_ITERATIONS": MAX_CORESET_ITERATIONS,
        "NUM_CORESET_SEEDS": NUM_CORESET_SEEDS,
        "NUM_PROJECTION_SEEDS": NUM_PROJECTION_SEEDS,
        "ORTHONORMAL": ORTHONORMAL,
        "PROJECTION_BATCH_SIZE": PROJECTION_BATCH_SIZE,
        "CORESET_CONVERGENCE_PARAMETER": CORESET_CONVERGENCE_PARAMETER,
        "PROJECTION_CONVERGENCE_PARAMETER": PROJECTION_CONVERGENCE_PARAMETER,
        "PROJECTION_RATE": PROJECTION_RATE,
        "CORESET_RATE": CORESET_RATE,
        "NUM_CONSTRUCTIONS": NUM_CONSTRUCTIONS,
        "COVS": COVS,
        "MUS": MUS,
        "WEIGHTS": WEIGHTS,
    },
    "DATA": {
        "X_INTRINSIC": X_INTRINSIC,
        "V_TRUE": V_TRUE,
        "X": X,
    },
}

# NOTE(review): a dict is handed to the array saver, so it is presumably stored as
# a pickled object and will need allow_pickle=True when loaded — confirm. The
# target directory is not created here.
jnp.save(f"./synthetic/{DATA_NAME}/setup_{SEED}_{DATA_NAME}", setup_dict)


################################## Construct Coresets ##################################
# Kernel shared by both solvers; its parameter is set by the median heuristic
# evaluated on the first 1000 rows of X.
reconstruction_kernel = SquaredExponentialKernel(median_heuristic(X[:1000]))

# One PRNG key per construction. Within a construction the same key is handed to
# both solvers at every coreset size.
construct_keys = jr.split(jr.key(SEED), NUM_CONSTRUCTIONS)
for i in range(NUM_CONSTRUCTIONS):
    # Per-construction accumulators, one entry per coreset size.
    bdc_pb_coresets_, bdc_pb_states_, kip_coresets_ = [], [], []

    for coreset_size in CORESET_SIZES:
        print(
            f"\n------- {i + 1}/{NUM_CONSTRUCTIONS}: Coreset Size: {coreset_size} -------\n"
        )
        print("Bilateral Distribution Compression (Pull Back Kernel)...")
        bdc_pb_solver = LinearBilateralDistributionCompression(
            coreset_size=coreset_size,
            intrinsic_dimension=INTRINSIC_DIMENSION,
            random_key=construct_keys[i],
            reconstruction_kernel=reconstruction_kernel,
            compression_kernel="pull_back",
            orthonormal=ORTHONORMAL,
            num_projection_seeds=NUM_PROJECTION_SEEDS,
            num_projection_epochs=NUM_PROJECTION_EPOCHS,
            projection_optimiser=projection_optimiser,
            num_coreset_seeds=NUM_CORESET_SEEDS,
            max_coreset_iterations=MAX_CORESET_ITERATIONS,
            projection_convergence_parameter=PROJECTION_CONVERGENCE_PARAMETER,
            coreset_convergence_parameter=CORESET_CONVERGENCE_PARAMETER,
            coreset_optimiser=coreset_optimiser,
            track_info=False,
            projection_batch_size=PROJECTION_BATCH_SIZE,
            validation_data=X[:1000],
        )

        bdc_pb_coreset, bdc_pb_state = bdc_pb_solver.reduce(Data(X))

        # Only the autoencoder part of the solver state is kept for saving.
        bdc_pb_coresets_.append(bdc_pb_coreset)
        bdc_pb_states_.append(bdc_pb_state.autoencoder)

        print("Kernel Inducing Points...")
        kip_solver = KernelInducingPoints(
            coreset_size=coreset_size,
            random_key=construct_keys[i],
            kernel=reconstruction_kernel,
            target_sample_size=None,
            coreset_sample_size=None,
            num_seeds=NUM_CORESET_SEEDS,
            max_iterations=MAX_CORESET_ITERATIONS,
            convergence_parameter=CORESET_CONVERGENCE_PARAMETER,
            optimiser=kip_optimiser,
            track_info=False,
        )

        # The solver state is not saved for this baseline, so it is discarded; the
        # coreset itself must be collected or the saved results would be empty.
        kip_coreset, _ = kip_solver.reduce(Data(X))
        kip_coresets_.append(kip_coreset)

    # Save the results of this construction to their own file.
    result_dict = {
        "BDC_PULLBACK_RESULTS": {
            "STATES": bdc_pb_states_,
            "CORESETS": bdc_pb_coresets_,
        },
        "KIP_RESULTS": {
            "CORESETS": kip_coresets_,
        },
    }

    jnp.save(f"./synthetic/{DATA_NAME}/results_{SEED}_{DATA_NAME}_{i}", result_dict)
