"""Construct Bilateral Coresets."""

#################################### Import Modules ####################################

import os
from time import time
from typing import Callable

import equinox as eqx
import jax
import jax.numpy as jnp
import jax.random as jr
import optax
from jax import vmap
from sklearn.preprocessing import StandardScaler

from coreax.data import Data
from coreax.kernels import SquaredExponentialKernel, median_heuristic
from coreax.solvers import (
    BaseCoder,
    KernelInducingPoints,
    LinearBilateralDistributionCompression,
    NonLinearBilateralDistributionCompression,
)
from coreax.solvers.autoencoders import Autoencoder, Decoder, Encoder

################################### Set JAX Settings ###################################

# Request the GPU platform and enable 64-bit precision before any arrays are
# created, then report which backend JAX actually selected.
# NOTE(review): the platform is hard-coded to "gpu"; presumably this script is only
# meant to run on GPU machines - confirm.
jax.config.update("jax_platform_name", "gpu")
jax.config.update("jax_enable_x64", True)
print(f"Jax backend: {jax.default_backend()}")

################################## Set File Settings ###################################

# Base seed: used for data generation, the ground-truth decoder and for deriving the
# per-construction random keys.
SEED = 10001

# Number of points in every constructed coreset.
CORESET_SIZE = 200
# Dimension of the ambient space.
AMBIENT_DIMENSION = 500
# Training epochs for the linear projection and for the autoencoder respectively.
NUM_PROJECTION_EPOCHS = 10
NUM_AUTOENCODER_EPOCHS = 20

# Iteration cap and number of seeds passed to every coreset solver.
MAX_CORESET_ITERATIONS = 2500
NUM_CORESET_SEEDS = 5
# Number of projection seeds and orthonormality flag passed to the linear solver.
NUM_PROJECTION_SEEDS = 1
ORTHONORMAL = True
# Batch size; used for both the projection and the autoencoder training.
PROJECTION_BATCH_SIZE = 1024
# Number of independent repetitions of the coreset construction loop.
NUM_CONSTRUCTIONS = 25
# Convergence threshold shared by all solvers.
CONVERGENCE_PARAMETER = 1e-8
# Divides the median heuristic for the KIP kernel; also passed to the bilateral
# solvers as `compression_kernel`.
MEDIAN_HEURISTIC_DIVISOR = 10

# Learning rates. CORESET_RESPONSE_RATE is only recorded in the saved setup in the
# visible source; the coreset optimiser below uses CORESET_FEATURE_RATE.
PROJECTION_RATE = 1e-3
AUTOENCODER_RATE = 1e-3
CORESET_FEATURE_RATE = 1e-3
CORESET_RESPONSE_RATE = 1e-3

# Adam optimisers with constant learning-rate schedules, one per training stage.
projection_optimiser = optax.adam(
    learning_rate=optax.constant_schedule(PROJECTION_RATE)
)
autoencoder_optimiser = optax.adam(
    learning_rate=optax.constant_schedule(AUTOENCODER_RATE)
)
coreset_optimiser = optax.adam(
    learning_rate=optax.constant_schedule(CORESET_FEATURE_RATE)
)

##################################### Generate Data ####################################
# Announce the data stage; the data below is generated synthetically in this file.
print("\n------------------------ Loading Data ------------------------\n")
# NOTE(review): `key` and `subkey` are not referenced again in the visible part of
# this file - confirm whether they are still needed.
key, subkey = jr.split(jr.PRNGKey(SEED), 2)

# Dataset identifier; used to build the output paths under ./synthetic/.
DATA_NAME = "clusters"

# pylint:disable=invalid-name


def sample_triangle(sample_key: jax.Array, vertices: jax.Array, n: int) -> jax.Array:
    """
    Draw ``n`` points inside the triangle spanned by ``vertices``.

    :param sample_key: JAX random key; split once to draw the two coordinates
    :param vertices: Array whose three rows are the triangle's corners
    :param n: Number of points to draw
    :return: Array with one sampled point per row
    """
    first_key, second_key = jax.random.split(sample_key)
    s = jax.random.uniform(first_key, shape=(n, 1))
    t = jax.random.uniform(second_key, shape=(n, 1))

    # Pairs that land outside the simplex (s + t > 1) are reflected back inside it
    outside = (s + t) > 1.0
    s, t = jnp.where(outside, 1.0 - s, s), jnp.where(outside, 1.0 - t, t)

    # Rows of barycentric weights; each row is non-negative and sums to one
    weights = jnp.concatenate([1.0 - s - t, s, t], axis=1)

    # Each point is the weighted (convex) combination of the three corners
    return weights @ vertices


def sample_annulus(
    sample_key: jax.Array,
    n: int,
    center: jax.Array,
    r_inner: float = 5.0,
    r_outer: float = 10.0,
):
    """
    Draw ``n`` points from the ring between two radii around ``center``.

    :param sample_key: JAX random key; split into a radius key and an angle key
    :param n: Number of points to draw
    :param center: Offset added to every sampled point
    :param r_inner: Inner radius of the ring
    :param r_outer: Outer radius of the ring
    :return: Array with one sampled point per row
    """
    radius_key, angle_key = jax.random.split(sample_key, 2)

    # The square root of a uniform draw over the squared radii gives the radius
    unit = jax.random.uniform(radius_key, shape=(n,))
    radius = jnp.sqrt((r_outer**2 - r_inner**2) * unit + r_inner**2)

    angle = jax.random.uniform(angle_key, shape=(n,), minval=0.0, maxval=2 * jnp.pi)
    offsets = jnp.stack([radius * jnp.cos(angle), radius * jnp.sin(angle)], axis=1)
    return offsets + center


def sample_spiral(
    sample_key, n, turns=3.0, a=0.0, b=0.25, center=(0.0, 0.0), noise=0.0
):
    """
    Place ``n`` points along an Archimedean spiral ``r = a + b * t``.

    The angles are evenly spaced (not random); randomness only enters through the
    optional additive Gaussian noise.

    :param sample_key: JAX random key, used only when ``noise`` is positive
    :param n: Number of points along the spiral
    :param turns: Number of full revolutions covered by the angle range
    :param a: Radius at angle zero
    :param b: Radius growth per radian
    :param center: Offset added to every point
    :param noise: Standard deviation of the Gaussian jitter; ``0`` disables it
    :return: Array with one point per row
    """
    angles = jnp.linspace(0.0, 2 * jnp.pi * turns, n)
    radii = a + b * angles
    curve = jnp.stack([radii * jnp.cos(angles), radii * jnp.sin(angles)], axis=1)

    if noise > 0:
        curve = curve + jax.random.normal(sample_key, shape=curve.shape) * noise

    return curve + jnp.array(center)


def sample_x_cross(sample_key, n, arm_length=1.0, arm_width=0.1, center=(0.0, 0.0)):
    """
    Draw ``n`` points from an X made of two diagonal bars crossing at ``center``.

    Each point is assigned to one of the two bars with probability one half and is
    then drawn uniformly within that bar.

    :param sample_key: JAX random key; split into bar-choice, along and across keys
    :param n: Number of points to draw
    :param arm_length: Half-length of each bar, measured from the crossing point
    :param arm_width: Full width of each bar
    :param center: Offset added to every point
    :return: Array with one point per row
    """
    choice_key, along_key, across_key = jax.random.split(sample_key, 3)

    # +1 selects the bar along y = x, -1 the bar along y = -x
    on_rising_bar = jax.random.bernoulli(choice_key, 0.5, (n,))
    sign = jnp.where(on_rising_bar, 1.0, -1.0)

    # Position in the bar's own frame: `along` its axis and `across` its width
    along = jax.random.uniform(along_key, (n,)) * (2 * arm_length) - arm_length
    across = jax.random.uniform(across_key, (n,)) * arm_width - arm_width / 2

    # Rotate the local frame by +45 or -45 degrees depending on the chosen bar
    scale = 1.0 / jnp.sqrt(2.0)
    xs = (along - sign * across) * scale
    ys = (sign * along + across) * scale

    return jnp.stack([xs, ys], axis=1) + jnp.array(center)


def generate_clusters(n_noise: int, n_cluster: int, seed: int):
    """
    Generate a labelled two-dimensional dataset of background noise plus shapes.

    The components, in order, are: uniform background noise (label -1), a Gaussian
    blob (0), a noisy sinusoid along x (1), an annulus (2), a rectangle (3), a noisy
    parabola segment (4), a noisy sinusoid along y (5), a triangle (6), two mirrored
    half-size triangles sharing one label (7), a noisy spiral (8) and an X-shaped
    cross (9).

    All randomness is derived from ``seed``; no module-level seed is consulted, so
    two calls with different seeds do not share random keys.

    :param n_noise: Number of background-noise points
    :param n_cluster: Number of points per shape; each of the two mirrored triangles
        receives ``n_cluster // 2`` points
    :param seed: Integer seed; every draw uses ``jr.key(seed)`` or
        ``jr.key(seed + 1)``
    :return: Tuple ``(x, y)``: the standardised features with rows permuted, and the
        labels permuted with the same key so that rows and labels stay aligned
    """
    # NOTE(review): the same two keys are reused for every component below, so the
    # underlying random draws are shared between shapes. This is kept as-is because
    # re-keying would change the generated dataset - confirm whether it is intended.
    x = jr.uniform(jr.key(seed), minval=-38, maxval=40, shape=(n_noise, 1))
    y = jr.uniform(jr.key(seed + 1), minval=-38, maxval=30, shape=(n_noise, 1))
    X_noise = jnp.hstack((x, y))
    Y_noise = -1 * jnp.ones(X_noise.shape[0])

    # NOTE(review): this covariance matrix is not symmetric ([0] vs [-3] off the
    # diagonal) - confirm the intended value; left unchanged to preserve the data.
    X1 = jr.multivariate_normal(
        jr.key(seed),
        mean=jnp.array([-10, -10]),
        cov=jnp.array([[5, 0], [-3, 5]]),
        shape=(n_cluster,),
    )
    Y1 = 0 * jnp.ones(X1.shape[0])

    # Noisy sinusoid on a linear trend, parameterised by x
    x = jr.uniform(jr.key(seed), minval=-10, maxval=18, shape=(n_cluster,))
    y = (
        1 / 5 * x
        + 5
        + jnp.sin(x)
        + 1 * jr.normal(jr.key(seed + 1), shape=(x.shape[0],))
    )
    X2 = jnp.hstack((x.reshape(-1, 1), y.reshape(-1, 1)))
    Y2 = 1 * jnp.ones(X2.shape[0])

    X3 = sample_annulus(jr.key(seed), n_cluster, jnp.array([-19, 19]), 4, 8)
    Y3 = 2 * jnp.ones(X3.shape[0])

    # Axis-aligned rectangle
    x = jr.uniform(jr.key(seed), minval=15, maxval=35, shape=(n_cluster, 1))
    y = jr.uniform(jr.key(seed + 1), minval=-35, maxval=-25, shape=(n_cluster, 1))
    X4 = jnp.hstack((x, y))
    Y4 = 3 * jnp.ones(X4.shape[0])

    # Noisy parabola segment
    x = jr.uniform(jr.key(seed), minval=25, maxval=35, shape=(n_cluster,))
    y = -38 + 1 / 30 * x**2 + 2 * jr.normal(jr.key(seed + 1), shape=(x.shape[0],))
    X5 = jnp.hstack((x.reshape(-1, 1), y.reshape(-1, 1)))
    Y5 = 4 * jnp.ones(X5.shape[0])

    # Noisy cosine on a linear trend, parameterised by y
    y = jr.uniform(jr.key(seed), minval=-20, maxval=7, shape=(n_cluster,))
    x = (
        -28
        + 1 / 5 * y
        + 5
        + jnp.cos(1 / 2 * y)
        + 1 * jr.normal(jr.key(seed + 1), shape=(y.shape[0],))
    )
    X6 = jnp.hstack((x.reshape(-1, 1), y.reshape(-1, 1)))
    Y6 = 5 * jnp.ones(X6.shape[0])

    A = jnp.array([0, 25])
    B = jnp.array([34, 25])
    C = jnp.array([34, 12])
    V = jnp.stack([A, B, C], axis=0)
    X7 = sample_triangle(jr.key(seed), V, n=n_cluster)
    Y7 = 6 * jnp.ones(X7.shape[0])

    # Two triangles mirrored about x = 0; together they carry a single label
    A = jnp.array([4, -25])
    B = jnp.array([-10, -30])
    C = jnp.array([-10, -20])
    V = jnp.stack([A, B, C], axis=0)
    X8 = sample_triangle(jr.key(seed), V, n=n_cluster // 2)
    Y8 = 7 * jnp.ones(X8.shape[0])

    A = jnp.array([-4, -25])
    B = jnp.array([10, -30])
    C = jnp.array([10, -20])
    V = jnp.stack([A, B, C], axis=0)
    X9 = sample_triangle(jr.key(seed), V, n=n_cluster // 2)
    Y9 = 7 * jnp.ones(X9.shape[0])

    # Keyed on the `seed` argument (not the module-level SEED) so that datasets
    # generated with different seeds do not share the spiral noise / cross samples
    X10 = sample_spiral(
        jr.key(seed), n=n_cluster, turns=1.15, a=0, b=1.25, center=(8, -9), noise=0.5
    )
    Y10 = 8 * jnp.ones(X10.shape[0])

    X11 = sample_x_cross(
        jr.key(seed), n=n_cluster, arm_length=8, arm_width=3, center=(-25, -29)
    )
    Y11 = 9 * jnp.ones(X11.shape[0])

    x = jnp.vstack((X_noise, X1, X2, X3, X4, X5, X6, X7, X8, X9, X10, X11))
    y = jnp.hstack((Y_noise, Y1, Y2, Y3, Y4, Y5, Y6, Y7, Y8, Y9, Y10, Y11))

    # The same key is used for both permutations so features and labels are
    # shuffled into the same order
    return jr.permutation(
        jr.key(seed), StandardScaler().fit_transform(x), axis=0
    ), jr.permutation(jr.key(seed), y, axis=0)


# Number of points per cluster shape (passed as `n_cluster`) and number of
# background-noise points (passed as `n_noise`) for the training set.
N_CLUSTERS = 10000
N_NOISE = 2000
# Training set in the two-dimensional intrinsic space, with its labels.
X_INTRINSIC, Y = generate_clusters(N_NOISE, N_CLUSTERS, SEED)
# Test set: one tenth of the size, generated from a different seed.
X_INTRINSIC_TEST, Y_TEST = generate_clusters(N_NOISE // 10, N_CLUSTERS // 10, SEED * 2)

############################# Project to Ambient Dimension #############################


class TanhDecoder(BaseCoder):
    """Fully connected decoder with tanh between consecutive linear layers."""

    num_hidden_layers: int

    def __init__(
        self,
        random_key,
        ambient_dimension: int,
        intrinsic_dimension: int,
        num_hidden_layers: int,
        hidden_layer_sizes: list,
        output_transformation: Callable = jax.nn.identity,
    ):
        """
        Build the layer stack mapping the intrinsic space to the ambient space.

        The stack is: a linear layer from ``intrinsic_dimension`` to the first
        hidden size, then tanh; a linear layer plus tanh between each pair of
        consecutive hidden sizes; a linear layer from the last hidden size to
        ``ambient_dimension``; and finally ``output_transformation``.

        :param random_key: JAX random key used to initialise the linear layers
        :param ambient_dimension: Output dimension of the decoder
        :param intrinsic_dimension: Input dimension of the decoder
        :param num_hidden_layers: Number of hidden layers; must equal
            ``len(hidden_layer_sizes)``
        :param hidden_layer_sizes: Width of each hidden layer, in order
        :param output_transformation: Callable appended after the final linear layer
        """
        assert len(hidden_layer_sizes) == num_hidden_layers, (
            "Number of hidden layers should equal number of provided layer sizes."
        )

        input_key, output_key = jax.random.split(random_key)

        # Input projection followed by its activation
        stack = [
            eqx.nn.Linear(
                intrinsic_dimension,
                hidden_layer_sizes[0],
                key=input_key,
            ),
            jax.nn.tanh,
        ]

        # NOTE(review): `input_key` has already been consumed by the first layer and
        # is split again here; kept as-is because changing it alters the weights.
        hidden_keys = jax.random.split(input_key, num_hidden_layers)
        final_index = num_hidden_layers - 1
        for index in range(num_hidden_layers):
            width = hidden_layer_sizes[index]
            if index != final_index:
                # Hidden-to-hidden layer with its activation
                stack.extend(
                    [
                        eqx.nn.Linear(
                            width,
                            hidden_layer_sizes[index + 1],
                            key=hidden_keys[index],
                        ),
                        jax.nn.tanh,
                    ]
                )
                continue
            # Last hidden layer maps straight to the ambient dimension
            stack.append(
                eqx.nn.Linear(
                    width,
                    ambient_dimension,
                    key=output_key,
                )
            )

        # Apply the output transformation last
        stack.append(output_transformation)
        self.layers = stack
        self.num_hidden_layers = num_hidden_layers


# Initialise the ground-truth decoder (tanh activations) that lifts the
# two-dimensional intrinsic data into the ambient space. The output width is taken
# from AMBIENT_DIMENSION so the data always matches the value saved in the setup.
decoder = TanhDecoder(
    random_key=jr.key(SEED),
    ambient_dimension=AMBIENT_DIMENSION,
    intrinsic_dimension=2,
    num_hidden_layers=4,
    hidden_layer_sizes=[32, 64, 128, 256],
)
# Apply the decoder row-wise to the training and test sets
X = vmap(decoder)(X_INTRINSIC)
X_TEST = vmap(decoder)(X_INTRINSIC_TEST)

###################################### Save Setup ######################################

# Record every experiment setting together with the generated data so that a run
# can be reproduced and evaluated from the saved file alone.
setup_dict = {
    "SETUP": {
        "SEED": SEED,
        "DATA_NAME": DATA_NAME,
        "N_CLUSTERS": N_CLUSTERS,
        "N_NOISE": N_NOISE,
        "CORESET_SIZE": CORESET_SIZE,
        "AMBIENT_DIMENSION": AMBIENT_DIMENSION,
        "NUM_PROJECTION_EPOCHS": NUM_PROJECTION_EPOCHS,
        "NUM_AUTOENCODER_EPOCHS": NUM_AUTOENCODER_EPOCHS,
        "MAX_CORESET_ITERATIONS": MAX_CORESET_ITERATIONS,
        "NUM_CORESET_SEEDS": NUM_CORESET_SEEDS,
        "NUM_PROJECTION_SEEDS": NUM_PROJECTION_SEEDS,
        "ORTHONORMAL": ORTHONORMAL,
        "PROJECTION_BATCH_SIZE": PROJECTION_BATCH_SIZE,
        # The autoencoder is trained with the same batch size as the projection
        "AUTOENCODER_BATCH_SIZE": PROJECTION_BATCH_SIZE,
        "CONVERGENCE_PARAMETER": CONVERGENCE_PARAMETER,
        "PROJECTION_RATE": PROJECTION_RATE,
        "AUTOENCODER_RATE": AUTOENCODER_RATE,
        "CORESET_FEATURE_RATE": CORESET_FEATURE_RATE,
        "CORESET_RESPONSE_RATE": CORESET_RESPONSE_RATE,
        "NUM_CONSTRUCTIONS": NUM_CONSTRUCTIONS,
        "MEDIAN_HEURISTIC_DIVISOR": MEDIAN_HEURISTIC_DIVISOR,
    },
    "DATA": {
        "X_INTRINSIC": X_INTRINSIC,
        "X": X,
        "Y": Y,
        "X_INTRINSIC_TEST": X_INTRINSIC_TEST,
        "X_TEST": X_TEST,
        "Y_TEST": Y_TEST,
    },
}

# Create the output directory up front so the save (and the later per-construction
# saves into the same directory) cannot fail on a missing path.
setup_output_dir = f"./synthetic/{DATA_NAME}"
os.makedirs(setup_output_dir, exist_ok=True)

# NOTE(review): a dict is not an array, so this presumably relies on NumPy pickling
# the object; loading it back would then need `allow_pickle=True` - confirm.
jnp.save(f"{setup_output_dir}/setup_{SEED}_{DATA_NAME}", setup_dict)

################################## Construct Coresets ##################################
# Kernel for the reconstruction objective, parameterised by the median heuristic of
# the first 1000 rows of the ambient data.
reconstruction_kernel = SquaredExponentialKernel(median_heuristic(X[:1000]))

# Loop-invariant inputs, built once rather than once per construction.
# NOTE(review): this kernel uses the first 2500 rows whereas the reconstruction
# kernel uses the first 1000 - confirm the difference is intentional.
kip_kernel = SquaredExponentialKernel(
    median_heuristic(X[:2500]) / MEDIAN_HEURISTIC_DIVISOR
)
validation_data = X_TEST[:1000]

# One random key per repetition of the experiment
construct_keys = jr.split(jr.key(SEED), num=(NUM_CONSTRUCTIONS,))

for i in range(NUM_CONSTRUCTIONS):
    # NOTE(review): the same key seeds the encoder, the decoder and all three
    # solvers within a repetition - confirm this sharing is intended.
    construct_key = construct_keys[i]

    print(f"\n------- {i + 1}/{NUM_CONSTRUCTIONS} -------\n")
    print("BDC-NL...")
    nonlinear_solver = NonLinearBilateralDistributionCompression(
        coreset_size=CORESET_SIZE,
        autoencoder=Autoencoder(
            encoder=Encoder(
                random_key=construct_key,
                ambient_dimension=X.shape[1],
                intrinsic_dimension=2,
                num_hidden_layers=1,
                hidden_layer_sizes=[64],
            ),
            decoder=Decoder(
                random_key=construct_key,
                ambient_dimension=X.shape[1],
                intrinsic_dimension=2,
                num_hidden_layers=1,
                hidden_layer_sizes=[64],
            ),
        ),
        random_key=construct_key,
        reconstruction_kernel=reconstruction_kernel,
        # NOTE(review): an integer divisor is passed where the parameter name
        # suggests a kernel - confirm against the solver's signature.
        compression_kernel=MEDIAN_HEURISTIC_DIVISOR,
        num_autoencoder_epochs=NUM_AUTOENCODER_EPOCHS,
        autoencoder_optimiser=autoencoder_optimiser,
        num_coreset_seeds=NUM_CORESET_SEEDS,
        max_coreset_iterations=MAX_CORESET_ITERATIONS,
        convergence_parameter=CONVERGENCE_PARAMETER,
        coreset_optimiser=coreset_optimiser,
        track_info=False,
        autoencoder_batch_size=PROJECTION_BATCH_SIZE,
        validation_data=validation_data,
    )

    t0 = time()
    nonlinear_coreset, nonlinear_state = nonlinear_solver.reduce(Data(X))
    # JAX dispatches work asynchronously; wait for the results before reading the
    # clock so the recorded time covers the whole computation.
    jax.block_until_ready((nonlinear_coreset, nonlinear_state))
    nonlinear_timer = time() - t0

    # Save autoencoder ready for deserialisation
    eqx.tree_serialise_leaves(
        f"./synthetic/{DATA_NAME}/{DATA_NAME}_model_{i}.eqx",
        nonlinear_state.autoencoder,
    )

    print("BDC-L...")
    linear_solver = LinearBilateralDistributionCompression(
        coreset_size=CORESET_SIZE,
        intrinsic_dimension=5,
        random_key=construct_key,
        reconstruction_kernel=reconstruction_kernel,
        compression_kernel=MEDIAN_HEURISTIC_DIVISOR,
        orthonormal=ORTHONORMAL,
        num_projection_seeds=NUM_PROJECTION_SEEDS,
        num_projection_epochs=NUM_PROJECTION_EPOCHS,
        projection_optimiser=projection_optimiser,
        num_coreset_seeds=NUM_CORESET_SEEDS,
        max_coreset_iterations=MAX_CORESET_ITERATIONS,
        projection_convergence_parameter=CONVERGENCE_PARAMETER,
        coreset_convergence_parameter=CONVERGENCE_PARAMETER,
        coreset_optimiser=coreset_optimiser,
        track_info=False,
        projection_batch_size=PROJECTION_BATCH_SIZE,
        validation_data=validation_data,
    )

    t0 = time()
    linear_coreset, linear_state = linear_solver.reduce(Data(X))
    jax.block_until_ready((linear_coreset, linear_state))
    linear_timer = time() - t0

    print("ADC...")
    kip_solver = KernelInducingPoints(
        coreset_size=CORESET_SIZE,
        random_key=construct_key,
        kernel=kip_kernel,
        target_sample_size=None,
        coreset_sample_size=None,
        num_seeds=NUM_CORESET_SEEDS,
        max_iterations=MAX_CORESET_ITERATIONS,
        convergence_parameter=CONVERGENCE_PARAMETER,
        optimiser=coreset_optimiser,
        track_info=False,
    )

    t0 = time()
    kip_coreset, kip_state = kip_solver.reduce(Data(X))
    jax.block_until_ready((kip_coreset, kip_state))
    kip_timer = time() - t0

    ################################# Save the result #################################

    # Keys are kept exactly as downstream readers expect them (including the
    # "LINEAR_TIMES" spelling, which differs from the other two "*_TIME" keys).
    result_dict = {
        "RESULTS": {
            "LINEAR_RESULTS": {
                "LINEAR_AUTOENCODER": linear_state.autoencoder,
                "LINEAR_CORESET": linear_coreset,
                "LINEAR_TIMES": linear_timer,
            },
            "NONLINEAR_RESULTS": {
                "NONLINEAR_CORESET": nonlinear_coreset,
                "NONLINEAR_TIME": nonlinear_timer,
            },
            "KIP_RESULTS": {
                "KIP_CORESET": kip_coreset,
                "KIP_TIME": kip_timer,
            },
        },
    }

    jnp.save(f"./synthetic/{DATA_NAME}/results_{i}_{SEED}_{DATA_NAME}", result_dict)
