"""Compress the conditional distribution of a synthetic regression dataset."""

# Import required modules
import optax
import jax
import jax.numpy as jnp
import jax.random as jr
from jax import Array

from tqdm import tqdm
from sklearn.preprocessing import StandardScaler

from coreax.kernels import SquaredExponentialKernel, median_heuristic
from coreax.data import SupervisedData
from coreax.solvers import (
    AverageConditionalKIP,
    AverageConditionalKernelHerding,
    PseudoJointKernelHerding,
    JointKIP,
)

# Check device
print(f"Jax backend: {jax.default_backend()}")

# Enable float64. This is done before any arrays are created below so that they are
# created with 64-bit precision.
jax.config.update("jax_enable_x64", True)

# Set top-seed; every PRNG key used in this script is derived from it.
# NOTE(review): jr.key(SEED) itself is reused later for batch sampling and solver runs.
SEED = 10001

#################################### GENERATE DATA ####################################
DATA_NAME = "heteroscedastic_non_linear"

# Number of training, validation and test points
DATA_SIZE = 8000
VALIDATION_SIZE = 1000
TEST_SIZE = 1000

# Generate one-dimensional features from a Gaussian distribution with mean X_MEAN and
# standard deviation X_STANDARD_DEVIATION. Each split is drawn with its own key (a
# different multiple of SEED) so the splits do not reuse the same random numbers.
X_STANDARD_DEVIATION = 2
X_MEAN = 0
features = X_MEAN + X_STANDARD_DEVIATION * jr.normal(
    jr.key(SEED),
    shape=(DATA_SIZE, 1),
)
test_features = X_MEAN + X_STANDARD_DEVIATION * jr.normal(
    jr.key(2 * SEED),
    shape=(TEST_SIZE, 1),
)
validation_features = X_MEAN + X_STANDARD_DEVIATION * jr.normal(
    jr.key(3 * SEED),
    shape=(VALIDATION_SIZE, 1),
)

# Set model parameters: the mean response is a sum of Gaussian-shaped bumps
# a_i * exp(-(x - b_i)^2 / c_i), so A sets the amplitudes, B the centres and C the widths.
A_PARAMS = jnp.array([3, -3, 6, -6])
B_PARAMS = jnp.array([-5, -2, 2, 5])
C_PARAMS = jnp.array([1, 0.1, 2, 0.5])


def generating_function(x, a_params=A_PARAMS, b_params=B_PARAMS, c_params=C_PARAMS):
    """
    Evaluate the noiseless nonlinear mean function elementwise at ``x``.

    Computes ``sum_i a_params[i] * exp(-(x - b_params[i])^2 / c_params[i])``, summing
    the terms in index order.

    :param x: Input value(s); each term broadcasts against ``x``
    :param a_params: Multiplier (amplitude) of each bump
    :param b_params: Offset (centre) of each bump
    :param c_params: Divisor of the squared distance (width) of each bump
    :return: The function evaluated at ``x``; the integer ``0`` if ``a_params`` is empty
    """
    return sum(
        (
            a_params[k] * jnp.exp(-((x - b_params[k]) ** 2) / c_params[k])
            for k in range(a_params.shape[0])
        ),
        0,
    )


# Noise scale = additive floor + |multiplicative amplitude * sin(x)|
Y_MULTIPLICATIVE_STANDARD_DEVIATION = 0.75
Y_ADDITIVE_STANDARD_DEVIATION = 0.1


def generating_noise(x):
    """
    Return the input-dependent (heteroscedastic) noise standard deviation at ``x``.

    The scale is ``Y_ADDITIVE_STANDARD_DEVIATION +
    |Y_MULTIPLICATIVE_STANDARD_DEVIATION * sin(x)|``. The absolute value is
    non-negative, so the scale never drops below the additive term.
    """
    oscillation = jnp.abs(Y_MULTIPLICATIVE_STANDARD_DEVIATION * jnp.sin(x))
    return Y_ADDITIVE_STANDARD_DEVIATION + oscillation


# Generate responses as y = f(x) + sigma(x) * eps with eps ~ N(0, 1); each split uses
# its own key.
responses = generating_function(features) + generating_noise(features) * jr.normal(
    jr.key(4 * SEED), shape=(DATA_SIZE, 1)
)
test_responses = generating_function(test_features) + generating_noise(
    test_features
) * jr.normal(jr.key(5 * SEED), shape=(TEST_SIZE, 1))
# The noise scale is evaluated at the validation features (previously it was evaluated
# at the scalar VALIDATION_SIZE). This gives the validation set the same
# heteroscedastic noise as the training and test sets.
validation_responses = generating_function(validation_features) + generating_noise(
    validation_features
) * jr.normal(jr.key(6 * SEED), shape=(VALIDATION_SIZE, 1))

# Scale the features; the scaler is fitted on the training split only and then
# applied unchanged to the test and validation splits
feature_scaler = StandardScaler().fit(features)
features = jnp.asarray(feature_scaler.transform(features))
test_features = jnp.asarray(feature_scaler.transform(test_features))
validation_features = jnp.asarray(feature_scaler.transform(validation_features))

# Scale the responses in the same way (fitted on the training split only)
response_scaler = StandardScaler().fit(responses)
responses = jnp.asarray(response_scaler.transform(responses))
test_responses = jnp.asarray(response_scaler.transform(test_responses))
validation_responses = jnp.asarray(response_scaler.transform(validation_responses))

# Extract scale and mean as Python scalars (.item() requires the data to have a single
# column, which holds here)
feature_scale = jnp.array(feature_scaler.scale_).item()
feature_mean = jnp.array(feature_scaler.mean_).item()
response_scale = jnp.array(response_scaler.scale_).item()
response_mean = jnp.array(response_scaler.mean_).item()


def func(x):
    """
    Evaluate the noiseless mean function on standardised inputs.

    ``x`` is mapped back to the original feature scale and the generating function is
    applied. The result is then standardised with the response scaler's mean and
    scale, so both input and output live in the scaled space used for training.
    """
    original_x = feature_mean + feature_scale * x
    raw_response = generating_function(original_x)
    return (raw_response - response_mean) / response_scale


def noise(x):
    """
    Evaluate the heteroscedastic noise scale on standardised inputs.

    ``x`` is mapped back to the original feature scale before ``generating_noise`` is
    applied.

    NOTE(review): unlike ``func``, the output is not divided by ``response_scale``, so
    it is in the original (unscaled) response units. Confirm this is intended where it
    is combined with ``func``.
    """
    return generating_noise(feature_mean + feature_scale * x)


# Wrap the (scaled) training pairs in a supervised-data container for the solvers
data = SupervisedData(features, responses)


################################ SETUP KERNEL FUNCTIONS ################################
# Choose each kernel's length scale with the median heuristic. Only the first 1000
# points are used, presumably to keep the pairwise computation cheap.
feature_kernel = SquaredExponentialKernel(median_heuristic(features[:1000]))
response_kernel = SquaredExponentialKernel(median_heuristic(responses[:1000]))

####################### CROSS-VALIDATE REGULARISATION PARAMETER #######################
# Coarse, descending grid of candidate regularisation parameters (1x and 0.5x each
# power of ten from 10 down to 1e-4) used for the initial search
COARSE_GRID = jnp.array(
    [10.0, 5.0, 1.0, 0.5, 0.1, 0.05, 0.01, 0.005, 0.001, 0.0005, 0.0001]
)


# Define the validation loss function
def validation_loss(
    regularisation_parameter: Array,
    feature_gramian: Array,
    cross_feature_gramian: Array,
    response_gramian: Array,
    cross_response_gramian: Array,
):
    """
    Compute the conditional-mean-embedding loss on a validation set.

    Notation:
    - ``K`` is the training feature Gramian and ``L`` the training response Gramian.
    - ``K_xv`` and ``L_yv`` are the cross Gramians between training and validation
      points.
    - ``lambda`` is the regularisation parameter.
    - ``m`` is the number of validation points.

    The coefficients are ``B = (K + lambda I)^{-1} K_xv`` and the returned value is

        (1 / m) * (-2 * sum(B * L_yv) + sum_j B[:, j]^T L B[:, j]).

    This equals the mean squared RKHS error of the predicted response embeddings, minus
    the ``k(y_j, y_j)`` term, which does not depend on ``lambda``.

    :param regularisation_parameter: Scalar ridge parameter ``lambda``
    :param feature_gramian: Square training feature Gramian, ``(n, n)``
    :param cross_feature_gramian: Training-validation feature Gramian, ``(n, m)``
    :param response_gramian: Square training response Gramian, ``(n, n)``
    :param cross_response_gramian: Training-validation response Gramian, ``(n, m)``
    :return: Scalar validation loss
    """
    num_train = feature_gramian.shape[0]
    num_validation = cross_feature_gramian.shape[1]

    # Kernel ridge regression coefficients, one column per validation point
    regularised_gramian = feature_gramian + regularisation_parameter * jnp.eye(num_train)
    coefficients = jnp.linalg.solve(regularised_gramian, cross_feature_gramian)

    # Inner products between true and predicted embeddings, summed over validation points
    cross_term = jnp.sum(coefficients * cross_response_gramian)
    # Squared RKHS norms of the predicted embeddings, summed over validation points
    weighted = jnp.dot(coefficients.T, response_gramian)
    norm_term = jnp.sum(weighted * coefficients.T)

    return (1 / num_validation) * (-2 * cross_term + norm_term)


# Vmap the validation loss ready for evaluating across an array of regularisation
# parameters.
vmapped_validation_loss = jax.vmap(validation_loss, in_axes=(0, None, None, None, None))

# Define the number of subsets we will take, and the size of those subsets
NUM_SETS = 10
BATCH_SIZE = 1000


def _batched_validation_losses(grid: Array, keys: Array) -> Array:
    """
    Evaluate the validation loss over ``grid`` on one random training batch per key.

    :param grid: 1D array of candidate regularisation parameters
    :param keys: Array of PRNG keys, one per batch of ``BATCH_SIZE`` training pairs
    :return: Array of shape ``(keys.shape[0], grid.shape[0])``; row ``b`` holds the
        losses for the batch sampled with ``keys[b]``
    """
    num_batches = keys.shape[0]
    losses = jnp.zeros((num_batches, grid.shape[0]))
    for batch in tqdm(range(num_batches)):
        # Sample a batch of training pairs without replacement
        batch_indices = jr.choice(
            keys[batch], features.shape[0], shape=(BATCH_SIZE,), replace=False
        )
        batch_features = features[batch_indices]
        batch_responses = responses[batch_indices]

        # Evaluate the validation loss for every grid value on this batch
        losses = losses.at[batch, :].set(
            vmapped_validation_loss(
                grid,
                feature_kernel.compute(batch_features, batch_features),
                feature_kernel.compute(batch_features, validation_features),
                response_kernel.compute(batch_responses, batch_responses),
                response_kernel.compute(batch_responses, validation_responses),
            )
        )
    return losses


rough_keys = jr.split(jr.key(SEED), (NUM_SETS,))
print("\nFinding average rough minima...")
coarse_losses = _batched_validation_losses(COARSE_GRID, rough_keys)

# Average the losses over batches and zoom into the bracket formed by the grid points
# either side of the minimum. The bracket is clamped to the grid. Without the clamp, a
# minimum at the first grid point gives index -1, which wraps round to the smallest
# value, and a minimum at the last grid point indexes past the end of the grid.
averaged_coarse_losses = coarse_losses.mean(axis=0)
best_coarse_idx = int(jnp.nanargmin(averaged_coarse_losses))
min_idx = max(best_coarse_idx - 1, 0)
max_idx = min(best_coarse_idx + 1, COARSE_GRID.shape[0] - 1)

FINE_SIZE = 25
fine_grid = jnp.linspace(COARSE_GRID[min_idx], COARSE_GRID[max_idx], FINE_SIZE)
fine_keys = jr.split(rough_keys[-1], (NUM_SETS,))
print("Finding average fine minima...")
fine_losses = _batched_validation_losses(fine_grid, fine_keys)

# Set the optimal regularisation parameter: the fine-grid value with the lowest
# batch-averaged validation loss
optimal_regularisation_parameter = fine_grid[
    jnp.nanargmin(fine_losses.mean(axis=0))
].item()
print(f"Optimal Regularisation Parameter: {optimal_regularisation_parameter}")

################################ SET SOLVER HYPERPARAMS ################################

# Learning rate used by every optimiser below, held constant throughout optimisation
SCHEDULE = 1e-1
constant_schedule = optax.constant_schedule(SCHEDULE)
# Separate, identically configured Adam optimisers for the features and the responses
feature_optimiser, response_optimiser = (
    optax.adam(learning_rate=constant_schedule) for _ in range(2)
)

CORESET_SIZE = 250  # number of pairs to optimise
MAX_STEPS = 250  # maximum number of gradient steps
NUM_SEEDS = 10  # number of initial seeds for optimisation to check
CONVERGENCE_PARAMETER = 1e-7

# Number of runs of each coreset solver; run i uses run_keys[i]
NUM_RUNS = 20
run_keys = jr.split(jr.key(SEED), (NUM_RUNS,))

################################ COMPUTE HERD CORESETS ################################
# Keyword arguments shared by both herding solvers; runs differ only in their key
_herding_kwargs = {
    "coreset_size": CORESET_SIZE,
    "feature_kernel": feature_kernel,
    "response_kernel": response_kernel,
    "feature_optimiser": feature_optimiser,
    "response_optimiser": response_optimiser,
    "batch_size": None,
    "max_steps": MAX_STEPS,
    "convergence_parameter": CONVERGENCE_PARAMETER,
    "num_seeds": NUM_SEEDS,
    "track_info": False,
}

print("\nConstruct the Joint Kernel Herding coresets...")
jkh_coresets, jkh_states = [], []
for i in range(NUM_RUNS):
    print(f"--------------------------- {i + 1}/{NUM_RUNS} ---------------------------")
    jkh_solver = PseudoJointKernelHerding(random_key=run_keys[i], **_herding_kwargs)
    jkh_coreset, jkh_state = jkh_solver.reduce(data)
    jkh_states.append(jkh_state)
    jkh_coresets.append((jkh_coreset.coreset.data, jkh_coreset.coreset.supervision))

print("\nConstruct the Average Conditional Kernel Herding coresets...")
ackh_coresets, ackh_states = [], []
for i in range(NUM_RUNS):
    print(f"--------------------------- {i + 1}/{NUM_RUNS} ---------------------------")
    # The average-conditional solver also takes the cross-validated regularisation
    ackh_solver = AverageConditionalKernelHerding(
        random_key=run_keys[i],
        regularisation_parameter=optimal_regularisation_parameter,
        **_herding_kwargs,
    )
    ackh_coreset, ackh_state = ackh_solver.reduce(data)
    ackh_states.append(ackh_state)
    ackh_coresets.append((ackh_coreset.coreset.data, ackh_coreset.coreset.supervision))


############################### CONSTRUCT KIP CORESETS ###############################
# Set KIP solver specific hyperparams
MAX_ITERATIONS = 500

# Set up an array of coreset sizes to construct using KIP: NUM_CONSTRUCTIONS values
# evenly spaced from 1 to CORESET_SIZE, truncated to integers
NUM_CONSTRUCTIONS = 20
CORESET_SIZES = jnp.linspace(1, CORESET_SIZE, NUM_CONSTRUCTIONS).astype(jnp.int64)

# Keyword arguments shared by both KIP solvers; constructions differ only in their
# key and coreset size
_kip_kwargs = {
    "feature_kernel": feature_kernel,
    "response_kernel": response_kernel,
    "max_iterations": MAX_ITERATIONS,
    "target_sample_size": None,
    "coreset_sample_size": None,
    "convergence_parameter": CONVERGENCE_PARAMETER,
    "feature_optimiser": feature_optimiser,
    "response_optimiser": response_optimiser,
    "track_info": False,
    "num_seeds": NUM_SEEDS,
}

# Results are nested as [run][construction]
print("\nConstruct the Average Conditional Kernel Inducing Point coresets...")
ackip_coresets, ackip_states = [], []
for i in range(NUM_RUNS):
    print(f"--------------------------- {i + 1}/{NUM_RUNS} ---------------------------")
    run_coresets, run_states = [], []
    for j in range(NUM_CONSTRUCTIONS):
        ackip_solver = AverageConditionalKIP(
            coreset_size=CORESET_SIZES[j],
            random_key=run_keys[i],
            regularisation_parameter=optimal_regularisation_parameter,
            **_kip_kwargs,
        )
        ackip_coreset, ackip_state = ackip_solver.reduce(data)
        run_states.append(ackip_state)
        run_coresets.append(
            (ackip_coreset.coreset.data, ackip_coreset.coreset.supervision)
        )
    ackip_coresets.append(run_coresets)
    ackip_states.append(run_states)

print("\nConstruct the Joint Kernel Inducing Point coresets...")
jkip_coresets, jkip_states = [], []
for i in range(NUM_RUNS):
    print(f"--------------------------- {i + 1}/{NUM_RUNS} ---------------------------")
    run_coresets, run_states = [], []
    for j in range(NUM_CONSTRUCTIONS):
        jkip_solver = JointKIP(
            coreset_size=CORESET_SIZES[j],
            random_key=run_keys[i],
            **_kip_kwargs,
        )
        jkip_coreset, jkip_state = jkip_solver.reduce(data)
        run_states.append(jkip_state)
        run_coresets.append(
            (jkip_coreset.coreset.data, jkip_coreset.coreset.supervision)
        )
    jkip_coresets.append(run_coresets)
    jkip_states.append(run_states)

# Save all the results in a nested dict. ``jnp.save`` stores the dict as a pickled 0-d
# object array, so every value must be picklable. Reload it with
# ``numpy.load(path, allow_pickle=True).item()``.
# The optax optimisers are tuples of locally defined functions, which the standard
# pickle module cannot serialise. They are therefore recorded by description; SCHEDULE
# holds the learning rate needed to rebuild them.
_optimiser_description = f"optax.adam(learning_rate=optax.constant_schedule({SCHEDULE}))"
result_dict = {
    "SETUP": {
        "SEED": SEED,
        "DATA": {
            "DATA_NAME": DATA_NAME,
            "TRAINING": (features, responses),
            "TEST": (test_features, test_responses),
            "VALIDATION": (validation_features, validation_responses),
            "DATA_SIZE": DATA_SIZE,
            "VALIDATION_SIZE": VALIDATION_SIZE,
            "TEST_SIZE": TEST_SIZE,
            "X_STANDARD_DEVIATION": X_STANDARD_DEVIATION,
            "X_MEAN": X_MEAN,
            "PARAMS": {
                "A_PARAMS": A_PARAMS,
                "B_PARAMS": B_PARAMS,
                "C_PARAMS": C_PARAMS,
            },
            "Y_ADDITIVE_STANDARD_DEVIATION": Y_ADDITIVE_STANDARD_DEVIATION,
            "Y_MULTIPLICATIVE_STANDARD_DEVIATION": Y_MULTIPLICATIVE_STANDARD_DEVIATION,
            "FEATURE_SCALER": {
                "SCALER": feature_scaler,
                "MEAN": feature_scaler.mean_,
                "SCALE": feature_scaler.scale_,
            },
            "RESPONSE_SCALER": {
                "SCALER": response_scaler,
                "MEAN": response_scaler.mean_,
                "SCALE": response_scaler.scale_,
            },
        },
        "CROSS_VALIDATION": {
            "COARSE_GRID": COARSE_GRID,
            "NUM_SETS": NUM_SETS,
            "BATCH_SIZE": BATCH_SIZE,
            "REGULARISATION_PARAMETER": optimal_regularisation_parameter,
        },
        "CORESET_CONSTRUCTION": {
            "FEATURE_KERNEL": feature_kernel,
            "RESPONSE_KERNEL": response_kernel,
            "SCHEDULE": SCHEDULE,
            "FEATURE_OPTIMISER": _optimiser_description,
            "RESPONSE_OPTIMISER": _optimiser_description,
            "CORESET_SIZE": CORESET_SIZE,
            "MAX_STEPS": MAX_STEPS,
            "NUM_SEEDS": NUM_SEEDS,
            "CONVERGENCE_PARAMETER": CONVERGENCE_PARAMETER,
            "NUM_RUNS": NUM_RUNS,
            "MAX_ITERATIONS": MAX_ITERATIONS,
            "NUM_CONSTRUCTIONS": NUM_CONSTRUCTIONS,
            "CORESET_SIZES": CORESET_SIZES,
        },
    },
    "JKH_RESULTS": {"CORESETS": jkh_coresets, "STATES": jkh_states},
    "ACKH_RESULTS": {"CORESETS": ackh_coresets, "STATES": ackh_states},
    "JKIP_RESULTS": {"CORESETS": jkip_coresets, "STATES": jkip_states},
    "ACKIP_RESULTS": {"CORESETS": ackip_coresets, "STATES": ackip_states},
}

jnp.save(f"results_{SEED}_{DATA_NAME}", result_dict)
