"""Compress the conditional distribution of the MNIST classification dataset."""

# Import required modules
import optax
import jax
import jax.numpy as jnp
import jax.random as jr
from jax import Array

from tqdm import tqdm
from sklearn.preprocessing import StandardScaler
from sklearn.datasets import fetch_openml

from coreax.kernels import SquaredExponentialKernel, median_heuristic, IndicatorKernel
from coreax.data import SupervisedData
from coreax.solvers import (
    AverageConditionalKIP,
    AverageConditionalKernelHerding,
    PseudoJointKernelHerding,
    JointKIP,
    HerdingExhaustiveSearch,
    KIPExhaustiveSearch,
)


# Report which backend JAX picked up for this run
print(f"Jax backend: {jax.default_backend()}")

# Switch JAX to double precision before any of the script's arrays are created
jax.config.update("jax_enable_x64", True)

# Root seed: every PRNG key used below is split from jr.key(SEED)
SEED = 10_001

###################################### LOAD DATA ######################################
DATA_NAME = "mnist"
DATA_SIZE = 8000
VALIDATION_SIZE = 1000
TEST_SIZE = 1000

# Fetch MNIST from OpenML: pixel features as float64 rows, targets as an (N, 1) column
print("Loading and formatting data...")
mnist = fetch_openml("mnist_784")
X = mnist.data.to_numpy().astype(jnp.float64)
y = mnist.target.to_numpy().astype(jnp.float64).reshape(-1, 1)

# Contiguous row ranges: the training rows come first, then test, then validation
features, responses = X[:DATA_SIZE], y[:DATA_SIZE]

test_features, test_responses = (
    X[DATA_SIZE : DATA_SIZE + TEST_SIZE],
    y[DATA_SIZE : DATA_SIZE + TEST_SIZE],
)

validation_features, validation_responses = (
    X[DATA_SIZE + TEST_SIZE : DATA_SIZE + TEST_SIZE + VALIDATION_SIZE],
    y[DATA_SIZE + TEST_SIZE : DATA_SIZE + TEST_SIZE + VALIDATION_SIZE],
)

# Standardise every split using statistics fitted on the training features only
feature_scaler = StandardScaler().fit(features)
features, test_features, validation_features = (
    jnp.asarray(feature_scaler.transform(split))
    for split in (features, test_features, validation_features)
)

# Wrap the training pairs in coreax's supervised-data container
data = SupervisedData(features, responses)

################################ SETUP KERNEL FUNCTIONS ################################
# Squared-exponential kernel on the features; its length scale comes from the
# median heuristic evaluated on only the first 1000 training rows to keep it cheap.
feature_kernel = SquaredExponentialKernel(median_heuristic(features[:1000]))

# Indicator kernel on the responses
response_kernel = IndicatorKernel()

####################### CROSS-VALIDATE REGULARISATION PARAMETER #######################
# Coarse candidate grid for the regularisation parameter, in descending order:
# every power of ten from 1e1 down to 1e-9, with the half-way value 5 x 10^k
# inserted between consecutive powers.
# fmt: off
COARSE_GRID = jnp.array([
    1e1, 5e0,
    1e0, 5e-1,
    1e-1, 5e-2,
    1e-2, 5e-3,
    1e-3, 5e-4,
    1e-4, 5e-5,
    1e-5, 5e-6,
    1e-6, 5e-7,
    1e-7, 5e-8,
    1e-8, 5e-9,
    1e-9,
])
# fmt: on


# Define the validation loss function
def validation_loss(
    regularisation_parameter: Array,
    feature_gramian: Array,
    cross_feature_gramian: Array,
    response_gramian: Array,
    cross_response_gramian: Array,
) -> Array:
    """Compute the loss on a validation set.

    The training features are regressed onto every validation feature with a
    ridge-regularised kernel solve, giving an ``(n, m)`` coefficient matrix ``A``
    whose column ``j`` weights the ``n`` training responses for validation point
    ``j``. For that validation point the code evaluates

        -2 * sum_i A_ij l(y_i, y_j) + A_j^T L A_j,

    i.e. the expansion of ``||phi(y_j) - sum_i A_ij phi(y_i)||^2`` without the
    ``l(y_j, y_j)`` term, which does not depend on the regularisation parameter.
    The per-point values are averaged over the ``m`` validation points.

    :param regularisation_parameter: Ridge parameter added to the diagonal of
        ``feature_gramian``
    :param feature_gramian: Training-vs-training feature Gram matrix, ``(n, n)``
    :param cross_feature_gramian: Training-vs-validation feature Gram matrix,
        ``(n, m)``
    :param response_gramian: Training-vs-training response Gram matrix, ``(n, n)``
    :param cross_response_gramian: Training-vs-validation response Gram matrix,
        ``(n, m)``
    :return: Scalar validation loss, up to an additive constant that does not
        depend on ``regularisation_parameter``
    """
    train_size = feature_gramian.shape[0]
    # The loss is a sum over validation points (the columns), so the mean must be
    # taken over the validation size; dividing by the training size is only
    # correct when the two happen to be equal.
    validation_size = cross_feature_gramian.shape[1]

    # Regularise the feature gramian
    regularised_feature_gramian = feature_gramian + regularisation_parameter * jnp.eye(
        train_size
    )

    # Solve (K + lambda I) A = K_cross for the (n, m) coefficient matrix A
    coefficients = jnp.linalg.solve(
        a=regularised_feature_gramian, b=cross_feature_gramian
    )

    # Cross term: sum over j and i of A_ij * l(y_i, y_j)
    term_2 = (coefficients * cross_response_gramian).sum()
    # Quadratic term: sum over j of A_j^T L A_j (no symmetry of L assumed)
    term_3 = (coefficients.T.dot(response_gramian) * coefficients.T).sum()

    return (-2 * term_2 + term_3) / validation_size


# Vmap the validation loss ready for evaluating across an array of regularisation
# parameters.
vmapped_validation_loss = jax.vmap(validation_loss, in_axes=(0, None, None, None, None))

# Define the number of subsets we will take, and the size of those subsets
NUM_SETS = 10
BATCH_SIZE = 1000


def _batched_validation_losses(grid: Array, keys: Array) -> Array:
    """Evaluate the validation loss over ``grid`` on one random training batch per key.

    For every key, ``BATCH_SIZE`` distinct training indices are drawn and the
    vmapped validation loss is evaluated against the full validation split.

    :param grid: 1-D array of candidate regularisation parameters
    :param keys: Array of PRNG keys, one per batch
    :return: Array of shape ``(len(keys), len(grid))``; row ``r`` holds the losses
        for the batch drawn with ``keys[r]``
    """
    losses = jnp.zeros((len(keys), len(grid)))
    for row in tqdm(range(len(keys))):
        batch_indices = jr.choice(
            keys[row], features.shape[0], shape=(BATCH_SIZE,), replace=False
        )
        # Gather the batch once and reuse it for all four Gram matrices
        batch_features = features[batch_indices]
        batch_responses = responses[batch_indices]
        losses = losses.at[row, :].set(
            vmapped_validation_loss(
                grid,
                feature_kernel.compute(batch_features, batch_features),
                feature_kernel.compute(batch_features, validation_features),
                response_kernel.compute(batch_responses, batch_responses),
                response_kernel.compute(batch_responses, validation_responses),
            )
        )
    return losses


rough_keys = jr.split(jr.key(SEED), (NUM_SETS,))
print("\nFinding average rough minima...")
coarse_losses = _batched_validation_losses(COARSE_GRID, rough_keys)

# Average the losses and choose the region to zoom into that contains the minima
averaged_coarse_losses = coarse_losses.mean(axis=0)
if bool(jnp.isnan(averaged_coarse_losses).all()):
    raise ValueError(
        "Every coarse-grid validation loss is NaN; cannot choose a minimum."
    )
best_coarse_idx = int(jnp.nanargmin(averaged_coarse_losses))
# Clamp the neighbours to the grid. Unclamped, an optimum at index 0 gives a lower
# neighbour of -1, which wraps round to the *last* (smallest) grid value and makes
# the fine grid miss the optimum entirely.
min_idx = max(best_coarse_idx - 1, 0)
max_idx = min(best_coarse_idx + 1, len(COARSE_GRID) - 1)

FINE_SIZE = 25
fine_grid = jnp.linspace(COARSE_GRID[min_idx], COARSE_GRID[max_idx], FINE_SIZE)
# NOTE(review): rough_keys[-1] has already been consumed by jr.choice above, so
# splitting it again re-uses a key; kept unchanged for reproducibility with earlier
# runs, but consider deriving the fine keys from an independent key.
fine_keys = jr.split(rough_keys[-1], (NUM_SETS,))
print("Finding average fine minima...")
fine_losses = _batched_validation_losses(fine_grid, fine_keys)

# Set the optimal regularisation parameter
optimal_regularisation_parameter = fine_grid[
    jnp.nanargmin(fine_losses.mean(axis=0))
].item()
print(f"Optimal Regularisation Parameter: {optimal_regularisation_parameter}")

################################ SET SOLVER HYPERPARAMS ################################

# Coreset budget and optimisation limits shared by the solvers below
CORESET_SIZE = 250  # number of pairs to optimise
MAX_STEPS = 250  # maximum number of gradient steps
NUM_SEEDS = 10  # number of initial seeds for optimisation to check
CONVERGENCE_PARAMETER = 1e-7

# Features are optimised with Adam at a fixed learning rate of SCHEDULE
SCHEDULE = 1e-1
constant_schedule = optax.constant_schedule(SCHEDULE)
feature_optimiser = optax.adam(learning_rate=constant_schedule)

# Candidate responses: the distinct training labels as an (n_labels, 1) column
response_optimiser = HerdingExhaustiveSearch(jnp.unique(responses).reshape(-1, 1))

NUM_RUNS = 20  # Number of runs of each coreset solver to use
# NOTE(review): run_keys is split from the same root key, jr.key(SEED), as the
# cross-validation batch keys; depending on JAX's split implementation the leading
# keys of the two splits may coincide. Consider a separate root key.
run_keys = jr.split(jr.key(SEED), (NUM_RUNS,))

################################ COMPUTE HERD CORESETS ################################

# Keyword arguments common to both kernel-herding solvers; per run only the random
# key changes (and the average-conditional solver also gets the regulariser).
_herding_settings = {
    "coreset_size": CORESET_SIZE,
    "feature_kernel": feature_kernel,
    "response_kernel": response_kernel,
    "feature_optimiser": feature_optimiser,
    "response_optimiser": response_optimiser,
    "batch_size": None,
    "max_steps": MAX_STEPS,
    "convergence_parameter": CONVERGENCE_PARAMETER,
    "num_seeds": NUM_SEEDS,
    "track_info": False,
}

print("\nConstruct the Joint Kernel Herding coresets...")
jkh_coresets, jkh_states = [], []
for i in range(NUM_RUNS):
    print(f"--------------------------- {i + 1}/{NUM_RUNS} ---------------------------")
    jkh_solver = PseudoJointKernelHerding(random_key=run_keys[i], **_herding_settings)
    jkh_coreset, jkh_state = jkh_solver.reduce(data)
    jkh_states.append(jkh_state)
    # Keep only the selected (features, responses) arrays of the coreset
    jkh_coresets.append((jkh_coreset.coreset.data, jkh_coreset.coreset.supervision))

print("\nConstruct the Average Conditional Kernel Herding coresets...")
ackh_coresets, ackh_states = [], []
for i in range(NUM_RUNS):
    print(f"--------------------------- {i + 1}/{NUM_RUNS} ---------------------------")
    ackh_solver = AverageConditionalKernelHerding(
        random_key=run_keys[i],
        regularisation_parameter=optimal_regularisation_parameter,
        **_herding_settings,
    )
    ackh_coreset, ackh_state = ackh_solver.reduce(data)
    ackh_states.append(ackh_state)
    ackh_coresets.append((ackh_coreset.coreset.data, ackh_coreset.coreset.supervision))


############################### CONSTRUCT KIP CORESETS ###############################
# Set KIP solver specific hyperparams
MAX_ITERATIONS = 500
response_optimiser = KIPExhaustiveSearch(jnp.unique(responses).reshape(-1, 1))

# Set up an array of coreset sizes to construct using KIP: NUM_CONSTRUCTIONS values
# evenly spaced between 1 and CORESET_SIZE, truncated to integers.
NUM_CONSTRUCTIONS = 20
CORESET_SIZES = jnp.linspace(1, CORESET_SIZE, NUM_CONSTRUCTIONS).astype(jnp.int64)

# Keyword arguments common to both KIP solvers; only the coreset size and the
# random key change between constructions.
_kip_settings = {
    "feature_kernel": feature_kernel,
    "response_kernel": response_kernel,
    "max_iterations": MAX_ITERATIONS,
    "target_sample_size": None,
    "coreset_sample_size": None,
    "convergence_parameter": CONVERGENCE_PARAMETER,
    "feature_optimiser": feature_optimiser,
    "response_optimiser": response_optimiser,
    "track_info": False,
    "num_seeds": NUM_SEEDS,
}

print("\nConstruct the Average Conditional Kernel Inducing Point coresets...")
ackip_coresets, ackip_states = (
    [[] for _ in range(NUM_RUNS)],
    [[] for _ in range(NUM_RUNS)],
)
for i in range(NUM_RUNS):
    print(f"--------------------------- {i + 1}/{NUM_RUNS} ---------------------------")
    for j in range(NUM_CONSTRUCTIONS):
        # CORESET_SIZES[j] is a 0-d JAX array; pass a plain Python int so the size
        # is a concrete integer even if the solver uses it to build array shapes
        # under tracing.
        ackip_solver = AverageConditionalKIP(
            coreset_size=int(CORESET_SIZES[j]),
            random_key=run_keys[i],
            regularisation_parameter=optimal_regularisation_parameter,
            **_kip_settings,
        )

        ackip_coreset, ackip_state = ackip_solver.reduce(data)
        ackip_states[i].append(ackip_state)
        ackip_coresets[i].append(
            (ackip_coreset.coreset.data, ackip_coreset.coreset.supervision)
        )

print("\nConstruct the Joint Kernel Inducing Point coresets...")
jkip_coresets, jkip_states = (
    [[] for _ in range(NUM_RUNS)],
    [[] for _ in range(NUM_RUNS)],
)
for i in range(NUM_RUNS):
    print(f"--------------------------- {i + 1}/{NUM_RUNS} ---------------------------")
    for j in range(NUM_CONSTRUCTIONS):
        jkip_solver = JointKIP(
            coreset_size=int(CORESET_SIZES[j]),
            random_key=run_keys[i],
            **_kip_settings,
        )

        jkip_coreset, jkip_state = jkip_solver.reduce(data)
        jkip_states[i].append(jkip_state)
        jkip_coresets[i].append(
            (jkip_coreset.coreset.data, jkip_coreset.coreset.supervision)
        )

# Gather the configuration and every solver's output into one nested dict, in the
# same key order as it is built here.
result_dict = {
    "SETUP": {
        "SEED": SEED,
        "DATA": {
            "TRAINING": (features, responses),
            "TEST": (test_features, test_responses),
            "VALIDATION": (validation_features, validation_responses),
            "DATA_NAME": DATA_NAME,
            "DATA_SIZE": DATA_SIZE,
            "VALIDATION_SIZE": VALIDATION_SIZE,
            "TEST_SIZE": TEST_SIZE,
            "FEATURE_SCALER": {
                "SCALER": feature_scaler,
                "MEAN": feature_scaler.mean_,
                "SCALE": feature_scaler.scale_,
            },
        },
        "CROSS_VALIDATION": {
            "COARSE_GRID": COARSE_GRID,
            "NUM_SETS": NUM_SETS,
            "BATCH_SIZE": BATCH_SIZE,
            "REGULARISATION_PARAMETER": optimal_regularisation_parameter,
        },
        "CORESET_CONSTRUCTION": {
            "FEATURE_KERNEL": feature_kernel,
            "RESPONSE_KERNEL": response_kernel,
            "SCHEDULE": SCHEDULE,
            # NOTE(review): the feature/response optimisers are intentionally not
            # stored (they were commented out here); presumably they do not
            # serialise cleanly — confirm before adding them.
            "CORESET_SIZE": CORESET_SIZE,
            "MAX_STEPS": MAX_STEPS,
            "NUM_SEEDS": NUM_SEEDS,
            "CONVERGENCE_PARAMETER": CONVERGENCE_PARAMETER,
            "NUM_RUNS": NUM_RUNS,
            "MAX_ITERATIONS": MAX_ITERATIONS,
            "NUM_CONSTRUCTIONS": NUM_CONSTRUCTIONS,
            "CORESET_SIZES": CORESET_SIZES,
        },
    },
    "JKH_RESULTS": {"CORESETS": jkh_coresets, "STATES": jkh_states},
    "ACKH_RESULTS": {"CORESETS": ackh_coresets, "STATES": ackh_states},
    "JKIP_RESULTS": {"CORESETS": jkip_coresets, "STATES": jkip_states},
    "ACKIP_RESULTS": {"CORESETS": ackip_coresets, "STATES": ackip_states},
}

# NOTE(review): saving a dict this way presumably stores it as a pickled object
# array; reading it back likely needs allow_pickle=True and .item() — confirm.
jnp.save(f"results_{SEED}_{DATA_NAME}", result_dict)
