import numpy as np
from scipy.spatial.transform import Rotation
import miniball
from geom_median.numpy import compute_geometric_median

np.random.seed(2023)

# ----------------------------------------------------------------------------------------------------------------------

# index: every point of a regular (n + 1)^3 integer grid, stored column-wise with shape (3, (n + 1) ** 3)
n = 10
l = np.arange(n + 1)
i = np.stack(np.meshgrid(l, l, l)).reshape(3, -1)

# factor: grid indices rescaled to [0, 1]
y = i / n

# transformation: a random rotation followed by an element-wise exp, applied twice.
# With no random_state, scipy draws the rotation from NumPy's global RNG, so the seed above fixes it.
r = Rotation.random()


def k(y):
    """Rotate the columns of ``y`` (3-vectors, after squeezing singleton axes) by ``r``, then exponentiate."""
    return np.exp(r.apply(y.squeeze().T).T)


def t(y):
    """Apply ``k`` twice."""
    return k(k(y))


# Per-row minimum and range of the transformed grid, used for min-max scaling in ``g``.
# ``t(y)`` is evaluated once; ``np.ptp`` replaces the ``ndarray.ptp`` method, which was removed in NumPy 2.0.
_t_y = t(y)
t_min = _t_y.min(1)
t_ptp = np.ptp(_t_y, axis=1)
del _t_y


def g(y):
    """Transform ``y`` with ``t`` and min-max scale each row using the statistics of the transformed grid.

    On the grid ``y`` itself every row of the result spans exactly [0, 1].
    """
    return (t(y) - t_min[:, None]) / t_ptp[:, None]

# observation
x = g(y)

# representation
# Every representation supplies one code per factor: entry j (first axis of an array, or position in a tuple)
# is a (samples, dims) array describing factor j.
z_entanglement = np.copy(x)[:, :, None]  # the observation itself: row j of x is the code of factor j
z_rotation = r.apply(y.squeeze().T).T[:, :, None]  # rotated factors, without the exp non-linearity
z_duplicate = (y.T, y.T, y[2:].T)  # factors 0 and 1 see all three factors; factor 2 sees only itself
z_complement = tuple(np.delete(y, j, axis=0).T for j in range(3))  # factor j sees the two other factors
z_misalignment = tuple(y[[j]].T for j in (1, 2, 0))  # factor j sees factor (j + 1) % 3
z_redundancy = (np.stack((y[0], -y[0]), axis=-1), y[1, :, None], y[2, :, None])  # factor 0 encoded twice
z_contraction = y[..., None] * 0.01  # factors shrunk by a factor of 100
z_nonlinear = y[..., None] ** 2  # factors squared element-wise
z_constant = np.zeros_like(x[..., None])  # all zeros
z_random = np.random.uniform(0, 1, size=x.shape)[..., None]  # noise from the global RNG, independent of y

# ----------------------------------------------------------------------------------------------------------------------

def q_product(y: np.ndarray, z: np.ndarray, aggregate, deviation):
    """Sum, over factors, of the aggregated spread of each factor's code when that factor is held fixed.

    ``y`` and ``z`` are walked in lockstep. For factor values ``yi`` and codes ``zi``, the rows of ``zi`` are
    grouped by the distinct values of ``yi``. ``deviation`` is evaluated on each group, and the per-group
    results are reduced with ``aggregate``. The per-factor results are then added with ``np.sum``.

    Args:
        y: factor values, one row per factor.
        z: codes, one entry per factor whose first axis indexes the same samples as the matching row of ``y``.
        aggregate: reduction applied to the list of per-group deviations (e.g. ``np.max`` or ``np.mean``).
        deviation: dispersion measure evaluated on the codes of one group.

    Returns:
        The ``np.sum`` of the per-factor aggregated deviations.
    """
    per_factor = []
    for factor_values, codes in zip(y, z):
        spreads = [deviation(codes[factor_values == level]) for level in np.unique(factor_values)]
        per_factor.append(aggregate(spreads))
    return np.sum(per_factor)


# max of Euclidean distances
def ball(z):
    """Smallest enclosing ball of the rows of ``z``, as returned by ``miniball.get_bounding_ball``.

    The points are first perturbed by 1e-12-scaled Gaussian noise from NumPy's global RNG; presumably this
    guards against degenerate inputs such as repeated points (TODO confirm).
    """
    jitter = 1e-12 * np.random.randn(*z.shape)
    return miniball.get_bounding_ball(z + jitter)


def radius(z):
    """Radius of the enclosing ball (the second item returned by ``ball`` is the squared radius)."""
    squared_radius = ball(z)[1]
    return np.sqrt(squared_radius)


# sum of Euclidean distances
def median(z):
    """Geometric median of the rows of ``z``, computed with ``geom_median``."""
    return compute_geometric_median(z).median


def mean_absolute_deviation_around_median(z):
    """Mean Euclidean distance from the rows of ``z`` to their geometric median."""
    distances = np.linalg.norm(z - median(z), axis=-1)
    return distances.mean()


# sum of squared Euclidean distances
def variance(z):
    """Total variance: per-column population variances summed (mean squared distance to the mean)."""
    return np.sum(np.var(z, axis=0))


# pairwise Euclidean distances
def pairwise_distance(z):
    """Square matrix of Euclidean distances between every pair of rows of ``z``."""
    differences = z[:, None] - z
    return np.linalg.norm(differences, axis=-1)


def diameter(z):
    """Largest pairwise Euclidean distance between rows of ``z``."""
    return pairwise_distance(z).max()


def mean_pairwise_distance(z):
    """Half the mean of the full distance matrix (zero diagonal and both orderings of each pair included)."""
    return pairwise_distance(z).mean() * 0.5

# ----------------------------------------------------------------------------------------------------------------------

def fnum(num):
    """Format ``exp(-num)`` with two decimals, so a score of 0 is displayed as 1.00."""
    return f'{np.exp(-num):4.2f}'


# One printed row per representation; the five columns follow the order of the measure pairs below.
for z_name, z in (
    ('entanglement', z_entanglement),
    ('rotation', z_rotation),
    ('duplicate', z_duplicate),
    ('complement', z_complement),
    ('misalignment', z_misalignment),
    ('redundancy', z_redundancy),
    ('contraction', z_contraction),
    ('nonlinear', z_nonlinear),
    ('constant', z_constant),
    ('random', z_random),
):
    # start a new line with the representation name left-aligned in 12 characters
    print('\n' + z_name.ljust(12), end=' ')

    # (aggregate over factor values, deviation within one group)
    for aggregate, deviation in (
        (np.max, radius),
        (np.mean, mean_absolute_deviation_around_median),
        (np.mean, variance),
        (np.max, diameter),
        (np.mean, mean_pairwise_distance),
    ):
        score = q_product(y, z, aggregate, deviation)
        print(fnum(score), end=' ')

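# Output of the loop above (np.random.seed(2023)). Columns, in order: radius (max), mean absolute deviation
# around the median (mean), variance (mean), diameter (max), mean pairwise distance (mean). Each entry is
# exp(-score), so 1.00 corresponds to a (near-)zero score.
# entanglement 0.44 0.75 0.96 0.19 0.82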
# rotation     0.22 0.51 0.80 0.05 0.64
# duplicate    0.24 0.43 0.67 0.06 0.56
# complement   0.12 0.28 0.55 0.01 0.42
# misalignment 0.22 0.44 0.74 0.05 0.58
# redundancy   1.00 1.00 1.00 1.00 1.00
# contraction  1.00 1.00 1.00 1.00 1.00
# nonlinear    1.00 1.00 1.00 1.00 1.00
# constant     1.00 1.00 1.00 1.00 1.00
# random       0.22 0.48 0.78 0.05 0.61
