from typing import List

from BACKEND import cp, sp, to_cpu, to_gpu
from model.kernels import Kernel
from utils.metrics import SamplingMetric, sample_pairs

# Entropy criterion for selecting the sampling metrics

def calibrate_sampling_entropy(X, Y, metrics: List[SamplingMetric], num_samples=500, demean_x=True, demean_y=True, min_norm=1e-3):
    """Pick the (input metric, output metric) pair whose distance ratios have minimum entropy.

    For every ordered pair ``(d_in, d_out)`` drawn from ``metrics``, the ratio
    ``d_out(Y)[i, j] / (d_in(X)[i, j] + 1e-6)`` is evaluated over all sample
    pairs ``i < j``, normalised into a probability distribution, and its
    Shannon entropy (in bits) is computed. The pair with the lowest entropy
    is returned.

    Args:
        X: Input samples; only the first ``num_samples`` along axis 0 are used.
            Demeaning over axis 2 implies at least 3-D arrays -- presumably
            (samples, channels, time); TODO confirm against callers.
        Y: Output samples, aligned with ``X`` along axis 0.
        metrics: Candidate metrics. Each is called on the whole sub-sampled
            array; its result is indexed with ``[i, j]`` pairs, so it is
            expected to return a square (N, N) distance matrix.
        num_samples: Maximum number of leading samples to use.
        demean_x: Subtract the mean over axis 2 of ``X`` before computing
            distances.
        demean_y: Subtract the mean over axis 2 of ``Y`` before computing
            distances.
        min_norm: A sample is excluded when the smallest norm (over the last
            axis) of its (demeaned) ``X`` is below this value. Every pair that
            involves an excluded sample gets an output distance of zero and
            so contributes nothing to the entropy.

    Returns:
        ``[d_in, d_out]`` for the lowest-entropy pair. Pairs whose ratio
        distribution is degenerate (total mass zero, negative, NaN or inf)
        are skipped. If every pair is degenerate, ``[metrics[0], metrics[0]]``
        is returned; if ``metrics`` is empty, ``[]`` is returned.
    """
    num_samples = min(num_samples, X.shape[0])
    X = to_gpu(X[:num_samples]).copy()
    Y = to_gpu(Y[:num_samples]).copy()
    if demean_x:
        X -= X.mean(axis=2, keepdims=True)
    if demean_y:
        Y -= Y.mean(axis=2, keepdims=True)

    n_metrics = len(metrics)

    # Unique unordered sample pairs (i < j).
    i_upper, j_upper = cp.triu_indices(num_samples, k=1)
    n_uniq = len(i_upper)

    # Per-sample validity mask. A pair is kept only if BOTH of its samples pass.
    # Masking only the rows of the full matrix (before taking the upper
    # triangle) would drop a sample only when it happened to be the lower index.
    mask_norm = (cp.linalg.norm(X, axis=-1).min(axis=-1) >= min_norm)
    pair_mask = mask_norm[i_upper] & mask_norm[j_upper]

    dists_in = cp.zeros((n_metrics, n_uniq), dtype=cp.float32)
    dists_out = cp.zeros((n_metrics, n_uniq), dtype=cp.float32)
    for k, d_in in enumerate(metrics):
        dists_in[k] = d_in(X)[i_upper, j_upper]
    for k, d_out in enumerate(metrics):
        dists_out[k] = d_out(Y)[i_upper, j_upper] * pair_mask

    # Fallback matches what the original returned when every pair scored H = 0.
    best_pair = [metrics[0], metrics[0]] if metrics else []
    # Plain float: `cp.infty` is an alias removed in NumPy 2.0.
    best_score = float("inf")

    for i_in in range(n_metrics):
        for j_out in range(n_metrics):
            G = (dists_out[j_out] / (dists_in[i_in] + 1e-6)).astype(cp.float64)
            total = float(G.sum())
            # A zero / negative / NaN / inf total gives no usable distribution.
            # Previously such pairs scored entropy 0 and were selected as "best".
            if not 0.0 < total < float("inf"):
                continue
            p = G / total
            # Substituting 1 where p <= 0 makes log2 return 0 there, so those
            # entries contribute 0 instead of 0 * -inf = NaN.
            H = -float(cp.sum(p * cp.log2(cp.where(p > 0, p, 1.0))))
            if H < best_score:
                best_score = H
                best_pair = [metrics[i_in], metrics[j_out]]

    return best_pair


