# Computes the Meyer-Wallach entanglement capability of parameterized quantum circuits,
# based on: https://arxiv.org/abs/1905.10876

import numpy as np
import torch as th

def entanglement(kernel, bounds=(0,1), seed=None, n_shots:int=100):
    """Estimate the Meyer-Wallach entanglement capability of a kernel's circuit.

    Draws ``n_shots`` inputs uniformly from ``bounds``, evolves them with
    ``kernel.evolve`` and returns the sample mean of the Meyer-Wallach measure

        Q(psi) = 4/eta * sum_k D(iota_k(0) psi, iota_k(1) psi),
        D(u, v) = sum_{i<j} |u_i v_j - u_j v_i|^2,

    where ``iota_k(b) psi`` keeps the ``2**(eta-1)`` amplitudes whose basis
    index has bit ``k`` equal to ``b``.

    Args:
        kernel: Object exposing ``eta`` (number of qubits), ``inputs`` (number
            of input features) and ``evolve(data)``, which maps an
            ``[n_shots, inputs]`` batch to state amplitudes indexed along
            dim 1 by the ``2**eta`` computational basis states.
        bounds: Range to sample kernel inputs from, uniformly in ``[min, max)``.
        seed: Seed for ``torch.manual_seed``; ``None`` leaves the RNG untouched.
        n_shots: Number of random states to average over (must be >= 1).

    Returns:
        Scalar tensor with the mean entanglement (0 for product states,
        1 for e.g. GHZ states).

    Raises:
        ValueError: If ``n_shots`` is smaller than 1.
    """
    if n_shots < 1:
        raise ValueError(f"n_shots must be >= 1, got {n_shots}")
    if seed is not None:
        th.manual_seed(seed)

    eta = kernel.eta
    lo, hi = min(bounds), max(bounds)

    # Random inputs in [lo, hi), evolved to states without tracking gradients.
    data = (hi - lo) * th.rand(n_shots, kernel.inputs) + lo
    with th.no_grad():
        states = kernel.evolve(data)

    basis = th.arange(2 ** eta)

    # Each projection iota_k(b)|psi> has 2**(eta-1) amplitudes, so the wedge
    # sum must run over *all* pairs i < j of that length.
    half = 2 ** (eta - 1)
    i, j = th.triu_indices(half, half, offset=1)

    total = 0
    for k in range(eta):
        bit = basis // 2 ** k % 2          # value of bit k for every basis index
        u = states[:, bit == 0]            # iota_k(0)|psi>
        v = states[:, bit == 1]            # iota_k(1)|psi>
        wedge = u[:, i] * v[:, j] - u[:, j] * v[:, i]
        total = total + th.sum(th.abs(wedge) ** 2)

    # Sum over shots divided by n_shots gives the sample mean of Q.
    return 4 / eta * total / n_shots



if __name__ == '__main__':
    from pathlib import Path

    import matplotlib.pyplot as plt

    from kernels import QGK, QEK
    from metrics.styles import *  # plot styles used below: blue, red, pltargs

    # Plot the entanglement capability of QGK and QEK for eta = 2..6 qubits;
    # the eta = 1 tick is left empty.
    etas = range(1, 7)
    seed = 42
    n_shots = 100

    def G(n):
        """Second kernel argument: (2**n)**2 - 1 (presumably the number of
        su(2**n) generators the kernel uses -- confirm against kernels.py)."""
        return (2 ** n) ** 2 - 1

    plt.figure(figsize=(4, 3))

    for label, Kernel, color in (("QGK (ours)", QGK, blue), ("QEK", QEK, red)):
        results = [
            float(entanglement(Kernel(eta, G(eta)), seed=seed, n_shots=n_shots))
            for eta in etas[1:]
        ]
        # NaN leaves a gap at eta = 1, which is not evaluated.
        plt.plot(etas, [float("nan"), *results], label=label, color=color, **pltargs)

    # Raw string: "\e" is an invalid escape sequence in a normal literal.
    plt.xlabel(r'Number of qubits $\eta$')
    plt.ylabel('Entanglement')
    plt.xticks(etas)

    plt.legend(loc="upper left")
    plt.tight_layout()

    out = Path("plots/analysis/2-entanglement.pdf")
    out.parent.mkdir(parents=True, exist_ok=True)
    plt.savefig(out)

