# Computes the expressibility of circuits,
# inspired by https://arxiv.org/abs/1905.10876 and https://github.com/vbelis/triple_e.
# Generalized to Spectral Concentration (KL divergence on the normalized eigenvalue spectrum)
# for comparability with classical kernels.

from typing import Callable
import torch as th


def kernel_expressibility(kernel: Callable, inputs=None, bounds=(0, 1), seed=None,
                          n_shots: int = 100, return_histogram=False, epstol=1e-18):
    """
    Estimate kernel expressibility as the spectral concentration of its Gram matrix.

    ``n_shots`` points are drawn uniformly from ``bounds``, the kernel is
    evaluated on that sample against itself, and the eigenvalues of the
    resulting matrix are normalized to sum to one. The returned score is the
    KL divergence of this normalized spectrum from the uniform distribution:
    0 for a perfectly flat spectrum, ``log(n_shots)`` when a single eigenvalue
    carries all the weight.

    Args:
        kernel: Callable taking ``(X, X)`` (two ``(n_shots, dim)`` tensors) and
            returning an ``(n_shots, n_shots)`` similarity matrix, either as a
            tensor or as anything ``torch.as_tensor`` accepts.
        inputs: Dimensionality of the synthetic input. If falsy, the kernel's
            ``inputs`` attribute is used instead.
        bounds: Sampling range for the synthetic input; only its min and max
            are used.
        seed: Optional seed for reproducibility. When given, torch's global
            RNG is re-seeded.
        n_shots: Number of synthetic input points.
        return_histogram: Whether to also return the reference and spectrum.
        epstol: Threshold below which spectrum entries are ignored, also added
            to the reference inside the logarithm to avoid log(0).
    Returns:
        The KL divergence as a 0-dim tensor, or ``(D_kl, (ref, spectrum))``
        when ``return_histogram`` is true.
    Raises:
        ValueError: If neither ``inputs`` nor ``kernel.inputs`` provides the
            input dimension.
    """
    if seed is not None:
        th.manual_seed(seed)

    dim = inputs or getattr(kernel, "inputs", None)
    if dim is None:
        raise ValueError("Input dimension must be provided.")

    lower, upper = min(bounds), max(bounds)
    X = (upper - lower) * th.rand(n_shots, dim) + lower

    # Evaluate the kernel on one sample against itself: eigvalsh assumes a
    # symmetric (Hermitian) matrix and only reads one triangle, so a
    # cross-kernel matrix between two different samples would silently yield
    # a meaningless spectrum.
    with th.no_grad():
        K = th.as_tensor(kernel(X, X))

    # Normalized eigenvalue spectrum; the clamp removes tiny negative values
    # caused by round-off.
    eigvals = th.linalg.eigvalsh(K).clamp(min=0)
    total = eigvals.sum()
    if total > 0:
        spectrum = eigvals / total
    else:
        # No positive eigenvalue: report an all-zero spectrum instead of the
        # NaNs a 0/0 division would produce. The divergence is then 0.
        spectrum = th.zeros_like(eigvals)

    # Uniform reference distribution
    ref = th.ones_like(spectrum) / len(spectrum)

    # KL divergence, restricted to entries with non-negligible mass
    # (the limit of p * log(p) for p -> 0 is 0).
    valid = spectrum > epstol
    D_kl = th.sum(spectrum[valid] * th.log(spectrum[valid] / (ref[valid] + epstol)))

    return (D_kl, (ref, spectrum)) if return_histogram else D_kl


if __name__ == '__main__':
    # Plot the spectral concentration of the quantum kernels (QGK, QEK) and
    # two classical baselines (RBF, linear) against the number of qubits.
    from pathlib import Path

    from kernels import QGK, QEK
    import matplotlib.pyplot as plt
    # NOTE(review): blue, red, yellow, orange and pltargs are not defined in
    # this file and presumably come from this wildcard import — confirm.
    from metrics.styles import *
    from sklearn.metrics.pairwise import rbf_kernel, linear_kernel

    plt.figure(figsize=(4, 3))

    etas = range(1, 7)
    seed = 42

    def G(n):
        # Input dimension used for n qubits: (2**n)**2 - 1 == 4**n - 1.
        return (2**n)**2 - 1

    # Classical kernels wrapped so they return torch tensors.
    rbf = lambda x, y: th.tensor(rbf_kernel(x, y, gamma=1.0))
    linear = lambda x, y: th.tensor(linear_kernel(x, y))

    # Scores are converted to plain floats so matplotlib receives homogeneous
    # numeric data. QEK is not evaluated for the first qubit count; NaN leaves
    # a gap at that position (presumably QEK is undefined there — confirm).
    series = [
        ("QGK (ours)", blue,
         [float(kernel_expressibility(QGK(eta, G(eta)), seed=seed)) for eta in etas]),
        ("QEK", red,
         [float("nan"),
          *(float(kernel_expressibility(QEK(eta, G(eta)), seed=seed)) for eta in etas[1:])]),
        ("RBF", yellow,
         [float(kernel_expressibility(rbf, inputs=G(eta), seed=seed)) for eta in etas]),
        ("Linear", orange,
         [float(kernel_expressibility(linear, inputs=G(eta), seed=seed)) for eta in etas]),
    ]

    for label, color, values in series:
        # print(label + ' & ' + ' & '.join(str(v) for v in values))  # LaTeX table row
        plt.plot(etas, values, label=label, color=color, **pltargs)

    plt.ylabel('Spectral Concentration')
    # Raw string: '\e' is not a valid escape sequence in a normal string literal.
    plt.xlabel(r'Number of inputs (Number of qubits $\eta$)')  # ($4^\eta-1$)
    plt.xticks(range(1, len(etas)+1), [f'{G(eta)}({eta})' for eta in etas])
    plt.legend(loc="lower left")
    plt.tight_layout()

    # Make sure the target directory exists before saving.
    out_path = Path("plots/analysis/3-expressibility.pdf")
    out_path.parent.mkdir(parents=True, exist_ok=True)
    plt.savefig(out_path)

