"""Compress the conditional distribution of a synthetic regression dataset."""

# Import required modules
import optax
import jax
import jax.numpy as jnp
import jax.random as jr
from jax import Array

from tqdm import tqdm

from coreax.kernels import SquaredExponentialKernel, median_heuristic
from coreax.data import SupervisedData
from coreax.solvers import (
    ExactAverageConditionalKIP,
    ExactAverageConditionalKernelHerding,
    ExactPseudoJointKernelHerding,
    ExactJointKIP,
)

# Enable float64 first: JAX documents ``jax_enable_x64`` as a setting to apply at
# startup, so it is set before any other JAX call (including the backend query below).
jax.config.update("jax_enable_x64", True)

# Check device
print(f"Jax backend: {jax.default_backend()}")

# Set top-seed
SEED = 10001

#################################### GENERATE DATA ####################################
DATA_NAME = "exact_linear"

DATA_SIZE, VALIDATION_SIZE, TEST_SIZE = 8000, 1000, 1000

# Generate features from a Gaussian distribution
X_STANDARD_DEVIATION = 1
X_MEAN = 1


def _draw_features(seed_multiplier, num_points):
    """Draw a ``(num_points, 1)`` Gaussian feature column.

    The PRNG key is built from ``seed_multiplier * SEED`` and the standard normal
    draws are scaled by ``X_STANDARD_DEVIATION`` and shifted by ``X_MEAN``.
    """
    standard_draws = jr.normal(
        jr.key(seed_multiplier * SEED),
        shape=(num_points, 1),
    )
    return X_MEAN + X_STANDARD_DEVIATION * standard_draws


# Training, test and validation features use keys SEED, 2 * SEED and 3 * SEED
features = _draw_features(1, DATA_SIZE)
test_features = _draw_features(2, TEST_SIZE)
validation_features = _draw_features(3, VALIDATION_SIZE)

# Set model parameters
BIAS = -0.5
SLOPE = 0.5


def generating_function(x):
    """Evaluate the noise-free linear model ``BIAS + SLOPE * x`` at ``x``."""
    scaled_input = SLOPE * x
    return BIAS + scaled_input


Y_STANDARD_DEVIATION = 0.5


def generating_noise(x):
    """Homoscedastic noise function.

    The noise level does not depend on the value of ``x``; only the shape of ``x``
    is used.

    :param x: Array of inputs at which the noise standard deviation is required
    :return: Array with the same shape as ``x`` in which every entry equals
        ``Y_STANDARD_DEVIATION``
    """
    # ``jnp.full`` handles inputs of any shape; repeating along the first axis and
    # reshaping only works when ``x`` holds exactly ``x.shape[0]`` elements.
    return jnp.full(x.shape, Y_STANDARD_DEVIATION)


# Generate responses
def _simulate_responses(inputs, seed_multiplier, num_points):
    """Evaluate the generating function at ``inputs`` and add scaled Gaussian noise.

    The standard normal noise has shape ``(num_points, 1)``, is drawn with the key
    built from ``seed_multiplier * SEED`` and is scaled by ``generating_noise(inputs)``.
    """
    standard_noise = jr.normal(jr.key(seed_multiplier * SEED), shape=(num_points, 1))
    return generating_function(inputs) + generating_noise(inputs) * standard_noise


# Training, test and validation responses use keys 4 * SEED, 5 * SEED and 6 * SEED
responses = _simulate_responses(features, 4, DATA_SIZE)
test_responses = _simulate_responses(test_features, 5, TEST_SIZE)
validation_responses = _simulate_responses(validation_features, 6, VALIDATION_SIZE)


# Put the training data in a container
data = SupervisedData(features, responses)


################################ SETUP KERNEL FUNCTIONS ################################
# Choose each length-scale with the median heuristic, evaluated on the first 1000
# training points only rather than the full dataset
feature_kernel = SquaredExponentialKernel(median_heuristic(features[:1000]))
response_kernel = SquaredExponentialKernel(median_heuristic(responses[:1000]))

# Compute the exact value of the first term of the AMCMD.
# With s2 = Y_STANDARD_DEVIATION**2 and l2 = response_kernel.length_scale**2 the
# expression below is sqrt((s2 + l2) / ((1 / l2 + 1 / s2) * s2 * (2 * s2 + l2))),
# which simplifies algebraically to sqrt(l2 / (2 * s2 + l2)).
_response_variance = Y_STANDARD_DEVIATION**2
_squared_length_scale = response_kernel.length_scale**2
FIRST_TERM = jnp.sqrt(
    (_response_variance + _squared_length_scale)
    / (
        ((1 / _squared_length_scale) + (1 / _response_variance))
        * _response_variance
        * (2 * _response_variance + _squared_length_scale)
    )
)

####################### CROSS-VALIDATE REGULARISATION PARAMETER #######################
# Define a coarse, decreasing grid of candidate values for the initial check; it
# alternates between 1 and 5 times each power of ten from 1e1 down to 1e-4
COARSE_GRID = jnp.array(
    [1e1, 5e0, 1e0, 5e-1, 1e-1, 5e-2, 1e-2, 5e-3, 1e-3, 5e-4, 1e-4]
)


# Define the validation loss function
def validation_loss(
    regularisation_parameter: Array,
    feature_gramian: Array,
    cross_feature_gramian: Array,
    response_gramian: Array,
    cross_response_gramian: Array,
):
    """
    Compute the loss on a validation set for one regularisation parameter.

    The coefficients solve the linear system
    ``(feature_gramian + regularisation_parameter * I) @ coefficients
    = cross_feature_gramian``. The loss is ``-2 * sum(coefficients *
    cross_response_gramian)`` plus the quadratic form of the coefficients in
    ``response_gramian``, summed over columns, all divided by the number of columns
    of ``cross_feature_gramian``.

    :param regularisation_parameter: Scalar added to the diagonal of
        ``feature_gramian`` before solving
    :param feature_gramian: Square Gramian of the batch features
    :param cross_feature_gramian: Gramian between the batch features (rows) and,
        at the call sites in this script, the validation features (columns)
    :param response_gramian: Square Gramian of the batch responses
    :param cross_response_gramian: Gramian between the batch responses and the
        validation responses; must match ``cross_feature_gramian`` in shape
    :return: Scalar validation loss
    """
    batch_size, num_columns = feature_gramian.shape[0], cross_feature_gramian.shape[1]

    # Solve the regularised system for the coefficients
    regularised_gramian = feature_gramian + regularisation_parameter * jnp.eye(
        batch_size
    )
    weights = jnp.linalg.solve(regularised_gramian, cross_feature_gramian)

    # Cross term and quadratic term of the validation loss
    cross_term = jnp.sum(weights * cross_response_gramian)
    quadratic_term = jnp.sum((weights.T @ response_gramian) * weights.T)

    return (1 / num_columns) * (quadratic_term - 2 * cross_term)


# Vmap the validation loss over its first argument only, so that an array of
# regularisation parameters can be scored against one shared set of Gramians.
vmapped_validation_loss = jax.vmap(validation_loss, in_axes=(0, None, None, None, None))

# Define the number of subsets we will take, and the size of those subsets
NUM_SETS = 10
BATCH_SIZE = 1000
# NOTE(review): jr.key(SEED) is also the key used to draw ``features``; confirm that
# reusing it here is intended.
rough_keys = jr.split(jr.key(SEED), (NUM_SETS,))
coarse_losses = jnp.zeros((NUM_SETS, len(COARSE_GRID)))
print("\nFinding average rough minima...")
for i in tqdm(range(NUM_SETS)):
    # Sample BATCH_SIZE distinct training indices for this subset
    batch_indices = jr.choice(
        rough_keys[i], features.shape[0], shape=(BATCH_SIZE,), replace=False
    )
    batch_features = features[batch_indices]
    batch_responses = responses[batch_indices]

    # Compute the validation loss of every coarse-grid value on this batch
    batch_losses = vmapped_validation_loss(
        COARSE_GRID,
        feature_kernel.compute(batch_features, batch_features),
        feature_kernel.compute(batch_features, validation_features),
        response_kernel.compute(batch_responses, batch_responses),
        response_kernel.compute(batch_responses, validation_responses),
    )

    # Row ``i`` holds the losses for subset ``i``
    coarse_losses = coarse_losses.at[i].set(batch_losses)

# Average the losses over the subsets and choose the region to zoom into: the coarse
# grid entries either side of the one with the smallest (non-NaN) averaged loss.
averaged_coarse_losses = coarse_losses.mean(axis=0)
best_coarse_idx = jnp.nanargmin(averaged_coarse_losses)
# Clamp both neighbours to valid grid positions. Unclamped, a minimiser at the first
# grid entry gives index -1, which wraps around to the *last* grid value and makes
# the zoom span almost the whole grid; a minimiser at the last entry would index
# one past the end of the grid.
last_coarse_idx = averaged_coarse_losses.shape[0] - 1
min_idx = jnp.clip(best_coarse_idx - 1, 0, last_coarse_idx)
max_idx = jnp.clip(best_coarse_idx + 1, 0, last_coarse_idx)

# Generate an evenly spaced grid over the zoomed in section
FINE_SIZE = 25
fine_grid = jnp.linspace(COARSE_GRID[min_idx], COARSE_GRID[max_idx], FINE_SIZE)
# NOTE(review): rough_keys[-1] was already consumed by the last coarse batch draw;
# confirm that splitting it again here is intended.
fine_keys = jr.split(rough_keys[-1], (NUM_SETS,))
fine_losses = jnp.zeros((NUM_SETS, FINE_SIZE))
print("\nFinding average fine minima...")
for i in tqdm(range(NUM_SETS)):
    # Sample BATCH_SIZE distinct training indices for this subset
    batch_indices = jr.choice(
        fine_keys[i], DATA_SIZE, shape=(BATCH_SIZE,), replace=False
    )
    batch_features = features[batch_indices]
    batch_responses = responses[batch_indices]

    # Compute the validation loss of every fine-grid value on this batch
    batch_losses = vmapped_validation_loss(
        fine_grid,
        feature_kernel.compute(batch_features, batch_features),
        feature_kernel.compute(batch_features, validation_features),
        response_kernel.compute(batch_responses, batch_responses),
        response_kernel.compute(batch_responses, validation_responses),
    )

    # Row ``i`` holds the losses for subset ``i``
    fine_losses = fine_losses.at[i].set(batch_losses)

# Set the optimal regularisation parameter: the fine-grid value with the smallest
# (non-NaN) loss averaged over the subsets, converted to a Python scalar
averaged_fine_losses = fine_losses.mean(axis=0)
optimal_regularisation_parameter = fine_grid[
    jnp.nanargmin(averaged_fine_losses)
].item()
print(f"Optimal Regularisation Parameter: {optimal_regularisation_parameter}")

################################ SET SOLVER HYPERPARAMS ################################

# Features and responses each get their own Adam optimiser; both follow the same
# constant learning-rate schedule.
SCHEDULE = 1e-1
constant_schedule = optax.constant_schedule(SCHEDULE)
feature_optimiser, response_optimiser = (
    optax.adam(learning_rate=constant_schedule) for _ in range(2)
)

CORESET_SIZE = 500  # number of pairs to optimise
MAX_STEPS = 250  # maximum number of gradient steps
NUM_SEEDS = 10  # number of initial seeds for optimisation to check
CONVERGENCE_PARAMETER = 1e-7

# One random key per run; run ``i`` of every solver below is seeded with run_keys[i]
NUM_RUNS = 20  # Number of runs of each coreset solver to use
run_keys = jr.split(jr.key(SEED), (NUM_RUNS,))

################################ COMPUTE HERD CORESETS ################################

print("\nConstruct the Joint Kernel Herding coresets...")
# Settings shared by every run; only the random key changes from run to run
_jkh_settings = dict(
    coreset_size=CORESET_SIZE,
    feature_kernel=feature_kernel,
    response_kernel=response_kernel,
    feature_optimiser=feature_optimiser,
    response_optimiser=response_optimiser,
    batch_size=1,  # Data is not used to estimate the loss function
    max_steps=MAX_STEPS,
    convergence_parameter=CONVERGENCE_PARAMETER,
    num_seeds=NUM_SEEDS,
    track_info=False,
    bias=BIAS,
    slope=SLOPE,
    feature_mean=X_MEAN,
    feature_standard_deviation=X_STANDARD_DEVIATION,
    response_standard_deviation=Y_STANDARD_DEVIATION,
)
jkh_coresets = []
jkh_states = []
for i in range(NUM_RUNS):
    print(f"--------------------------- {i + 1}/{NUM_RUNS} ---------------------------")
    jkh_solver = ExactPseudoJointKernelHerding(
        random_key=run_keys[i], **_jkh_settings
    )

    # Keep the solver state and the (features, responses) pair of the coreset
    jkh_coreset, jkh_state = jkh_solver.reduce(data)
    jkh_states.append(jkh_state)
    jkh_coresets.append((jkh_coreset.coreset.data, jkh_coreset.coreset.supervision))

print("\nConstruct the Average Conditional Kernel Herding coresets...")
# Settings shared by every run; only the random key changes from run to run
_ackh_settings = dict(
    coreset_size=CORESET_SIZE,
    feature_kernel=feature_kernel,
    response_kernel=response_kernel,
    feature_optimiser=feature_optimiser,
    response_optimiser=response_optimiser,
    batch_size=1,  # Data is not used to estimate the loss function
    max_steps=MAX_STEPS,
    convergence_parameter=CONVERGENCE_PARAMETER,
    num_seeds=NUM_SEEDS,
    track_info=False,
    regularisation_parameter=optimal_regularisation_parameter,
    bias=BIAS,
    slope=SLOPE,
    feature_mean=X_MEAN,
    feature_standard_deviation=X_STANDARD_DEVIATION,
    response_standard_deviation=Y_STANDARD_DEVIATION,
)
ackh_coresets = []
ackh_states = []
for i in range(NUM_RUNS):
    print(f"--------------------------- {i + 1}/{NUM_RUNS} ---------------------------")
    ackh_solver = ExactAverageConditionalKernelHerding(
        random_key=run_keys[i], **_ackh_settings
    )

    # Keep the solver state and the (features, responses) pair of the coreset
    ackh_coreset, ackh_state = ackh_solver.reduce(data)
    ackh_states.append(ackh_state)
    ackh_coresets.append((ackh_coreset.coreset.data, ackh_coreset.coreset.supervision))


############################### CONSTRUCT KIP CORESETS ###############################
# Set KIP solver specific hyperparams
MAX_ITERATIONS = 500

# Set up an array of coreset sizes to construct using KIP: NUM_CONSTRUCTIONS values
# evenly spaced from 1 to CORESET_SIZE, cast to integers
NUM_CONSTRUCTIONS = 20
CORESET_SIZES = jnp.linspace(1, CORESET_SIZE, NUM_CONSTRUCTIONS).astype(jnp.int64)

print("\nConstruct the Average Conditional Kernel Inducing Point coresets...")
# Settings shared by every construction; the coreset size and random key vary
_ackip_settings = dict(
    feature_kernel=feature_kernel,
    response_kernel=response_kernel,
    regularisation_parameter=optimal_regularisation_parameter,
    max_iterations=MAX_ITERATIONS,
    target_sample_size=None,
    coreset_sample_size=None,
    convergence_parameter=CONVERGENCE_PARAMETER,
    feature_optimiser=feature_optimiser,
    response_optimiser=response_optimiser,
    track_info=False,
    num_seeds=NUM_SEEDS,
    bias=BIAS,
    slope=SLOPE,
    feature_mean=X_MEAN,
    feature_standard_deviation=X_STANDARD_DEVIATION,
    response_standard_deviation=Y_STANDARD_DEVIATION,
)
# Entry [i][j] of each list belongs to run ``i`` and coreset size CORESET_SIZES[j]
ackip_coresets, ackip_states = [], []
for i in range(NUM_RUNS):
    print(f"--------------------------- {i + 1}/{NUM_RUNS} ---------------------------")
    _ackip_run_coresets, _ackip_run_states = [], []
    for j in range(NUM_CONSTRUCTIONS):
        # NOTE(review): coreset_size is a JAX integer scalar here, whereas the herding
        # solvers receive a Python int - confirm the solver accepts both.
        ackip_solver = ExactAverageConditionalKIP(
            coreset_size=CORESET_SIZES[j],
            random_key=run_keys[i],  # same key for every size within a run
            **_ackip_settings,
        )

        ackip_coreset, ackip_state = ackip_solver.reduce(data)
        _ackip_run_states.append(ackip_state)
        _ackip_run_coresets.append(
            (ackip_coreset.coreset.data, ackip_coreset.coreset.supervision)
        )
    ackip_coresets.append(_ackip_run_coresets)
    ackip_states.append(_ackip_run_states)

print("\nConstruct the Joint Kernel Inducing Point coresets...")
# Settings shared by every construction; the coreset size and random key vary
_jkip_settings = dict(
    feature_kernel=feature_kernel,
    response_kernel=response_kernel,
    max_iterations=MAX_ITERATIONS,
    target_sample_size=None,
    coreset_sample_size=None,
    convergence_parameter=CONVERGENCE_PARAMETER,
    feature_optimiser=feature_optimiser,
    response_optimiser=response_optimiser,
    track_info=False,
    num_seeds=NUM_SEEDS,
    bias=BIAS,
    slope=SLOPE,
    feature_mean=X_MEAN,
    feature_standard_deviation=X_STANDARD_DEVIATION,
    response_standard_deviation=Y_STANDARD_DEVIATION,
)
# Entry [i][j] of each list belongs to run ``i`` and coreset size CORESET_SIZES[j]
jkip_coresets, jkip_states = [], []
for i in range(NUM_RUNS):
    print(f"--------------------------- {i + 1}/{NUM_RUNS} ---------------------------")
    _jkip_run_coresets, _jkip_run_states = [], []
    for j in range(NUM_CONSTRUCTIONS):
        jkip_solver = ExactJointKIP(
            coreset_size=CORESET_SIZES[j],
            random_key=run_keys[i],  # same key for every size within a run
            **_jkip_settings,
        )

        jkip_coreset, jkip_state = jkip_solver.reduce(data)
        _jkip_run_states.append(jkip_state)
        _jkip_run_coresets.append(
            (jkip_coreset.coreset.data, jkip_coreset.coreset.supervision)
        )
    jkip_coresets.append(_jkip_run_coresets)
    jkip_states.append(_jkip_run_states)

# Save all the results in a dict: the experiment configuration goes under "SETUP" and
# each solver gets its own "<NAME>_RESULTS" entry holding its coresets and states
result_dict = {
    "SETUP": {
        "SEED": SEED,
        "DATA": {
            "DATA_NAME": DATA_NAME,
            "TRAINING": (features, responses),
            "TEST": (test_features, test_responses),
            "VALIDATION": (validation_features, validation_responses),
            "DATA_SIZE": DATA_SIZE,
            "VALIDATION_SIZE": VALIDATION_SIZE,
            "TEST_SIZE": TEST_SIZE,
            "X_STANDARD_DEVIATION": X_STANDARD_DEVIATION,
            "X_MEAN": X_MEAN,
            "Y_STANDARD_DEVIATION": Y_STANDARD_DEVIATION,
            "PARAMS": {
                "BIAS": BIAS,
                "SLOPE": SLOPE,
            },
            "FIRST_TERM": FIRST_TERM,
        },
        "CROSS_VALIDATION": {
            "COARSE_GRID": COARSE_GRID,
            "NUM_SETS": NUM_SETS,
            "BATCH_SIZE": BATCH_SIZE,
            "REGULARISATION_PARAMETER": optimal_regularisation_parameter,
        },
        "CORESET_CONSTRUCTION": {
            "FEATURE_KERNEL": feature_kernel,
            "RESPONSE_KERNEL": response_kernel,
            "SCHEDULE": SCHEDULE,
            "CORESET_SIZE": CORESET_SIZE,
            "MAX_STEPS": MAX_STEPS,
            "NUM_SEEDS": NUM_SEEDS,
            "CONVERGENCE_PARAMETER": CONVERGENCE_PARAMETER,
            "NUM_RUNS": NUM_RUNS,
            "MAX_ITERATIONS": MAX_ITERATIONS,
            "NUM_CONSTRUCTIONS": NUM_CONSTRUCTIONS,
            "CORESET_SIZES": CORESET_SIZES,
        },
    },
    "JKH_RESULTS": {
        "CORESETS": jkh_coresets,
        "STATES": jkh_states,
    },
    "ACKH_RESULTS": {
        "CORESETS": ackh_coresets,
        "STATES": ackh_states,
    },
    "JKIP_RESULTS": {
        "CORESETS": jkip_coresets,
        "STATES": jkip_states,
    },
    "ACKIP_RESULTS": {
        "CORESETS": ackip_coresets,
        "STATES": ackip_states,
    },
}

# NOTE(review): the payload is a plain dict rather than an array, so reading the file
# back presumably requires ``allow_pickle=True`` - confirm when loading.
jnp.save(f"results_{SEED}_{DATA_NAME}", result_dict)
