from metrics.styles import *
import numpy as np
from matplotlib.ticker import ScalarFormatter
from matplotlib.ticker import FuncFormatter
from matplotlib.legend import Legend
import matplotlib.pyplot as plt


def comp_bound(eta, _G, gamma=1, approx=False):
    """Computes the breakeven dataset size n at which QGK becomes cheaper than a classical kernel.

    The breakeven solves C_QGK(n) = C_Classical(n) with
        C_QGK(n)       = 4**eta + n * gamma * g**2 + n * 8**eta + n**2 * 2**eta
        C_Classical(n) = n**2 * d,   where d = gamma * g and g = _G(eta),
    i.e. the quadratic A*n**2 + B*n + C = 0 with
        A = 2**eta - gamma * g,   B = gamma * g**2 + 8**eta,   C = 4**eta.

    Args:
        eta: Number of qubits (integer; numpy integers are accepted).
        _G: Function computing the number of groups g(eta) given eta.
        gamma: Compression factor d / g(eta) (default 1, i.e., 1:1 groups-to-input ratio).
        approx: If True, return the simplified bound B / (gamma*g - 2**eta), which
            assumes B**2 >> 4AC.

    Returns:
        The positive breakeven n as a float, or ``inf`` when gamma * g(eta) <= 2**eta.
        In that case A >= 0 (with B, C > 0), so C_QGK > C_Classical for every n > 0
        and no breakeven exists.
    """
    # numpy integer etas (e.g. from np.arange) would silently overflow int64 in 8**eta etc.;
    # Python ints keep the powers exact.
    if isinstance(eta, np.integer):
        eta = int(eta)
    g_eta = _G(eta)
    d = gamma * g_eta  # input dimensionality

    # Simplified version assuming B^2 >> 4AC, i.e. n ~ -B / A
    if approx:
        denom = d - 2**eta
        if denom <= 0:
            return np.inf
        return (gamma * g_eta**2 + 8**eta) / denom

    # Exact version of compression_breakeven, using the given _G (e.g. g = 3*2**eta - 6*(eta%2) + 3)
    # rather than the approximation g ~ 3*2**eta.
    A = float(2**eta - d)  # Adjusted for compression
    B = float(gamma * g_eta**2 + 8**eta)
    C = float(4**eta)
    if A >= 0:
        # The n**2 term of C_QGK dominates the classical one: QGK is never cheaper.
        return np.inf

    # With A < 0 and B, C > 0 the discriminant B^2 - 4AC = B^2 + 4|A|C is strictly positive,
    # and the root (-B - sqrt(disc)) / (2A) = (B + sqrt(disc)) / (2|A|) is the positive one.
    # hypot evaluates sqrt(B^2 + 4|A|C) without forming B**2, which would overflow for large eta.
    root_disc = np.hypot(B, 2.0 * np.sqrt(-A) * np.sqrt(C))
    return (B + root_disc) / (-2.0 * A)


def plot_efficiency_breakeven(etas, _G, secondary, approx=False, print_error=False, scale_y2=10,
                              save_path="plots/analysis/4-complexity.pdf"):
    """Plots the QGK-vs-classical breakeven dataset size and the input dimensionality over eta.

    Left y-axis (log): breakeven n from ``comp_bound`` for gamma=1 and gamma=_f(eta), with
    shaded regions labelled C_QGK < C_Classical and C_QGK > C_Classical.
    Right y-axis (log): input dimensionality g(eta) and _f(eta) * g(eta).

    Args:
        etas: Iterable of qubit counts (x-axis values).
        _G: Function computing the number of groups g(eta) given eta.
        secondary: Tuple (_f, _label) of the compression factor gamma = _f(eta) and its
            LaTeX label (inserted into ``$\\gamma=...$``).
        approx: Use the simplified breakeven bound instead of the exact quadratic root.
        print_error: When ``approx`` is set, print the summed absolute deviation of the
            approximate breakevens from the exact ones.
        scale_y2: Ratio between the left and right y-axis limits.
        save_path: File the figure is written to; its parent directory is created if needed.

    Returns:
        The matplotlib Figure.
    """
    from pathlib import Path

    etas = list(etas)  # iterated several times below; a generator would be exhausted
    _f, _label = secondary
    breakeven_g1 = [comp_bound(eta, _G, gamma=1, approx=approx) for eta in etas]
    breakeven_geta = [comp_bound(eta, _G, gamma=_f(eta), approx=approx) for eta in etas]

    if print_error and approx:
        # Deviation of the B^2 >> 4AC simplification from the exact quadratic root
        error1 = np.array([comp_bound(eta, _G, 1) for eta in etas]) - np.array(breakeven_g1)
        erroreta = np.array([comp_bound(eta, _G, _f(eta)) for eta in etas]) - np.array(breakeven_geta)
        print(f"Approximation error: {abs(error1).sum() + abs(erroreta).sum()}")

    fig, ax = plt.subplots(figsize=(4, 3))
    ax2 = ax.twinx()

    # Left y-axis: breakeven points [dataset size n]
    ax.plot(etas, breakeven_g1, marker='o', color="#424242", label=r'$\gamma=1$')
    ax.plot(etas, breakeven_geta, marker='s', color='#646464', label=rf'$\gamma={_label}$')

    # 40% headroom above the data; the right axis mirrors the left one scaled by 1/scale_y2
    y_lim = ax.get_ylim()[1] * 1.4
    ax.set_ylim(1, y_lim)
    ax2.set_ylim(1 / scale_y2, y_lim / scale_y2)

    # Region above the gamma=_f(eta) breakeven is labelled "QGK cheaper";
    # region below the gamma=1 breakeven is labelled "QGK more expensive".
    ax.fill_between(etas, breakeven_geta, y_lim, color=green, alpha=0.2, label='$C_{QGK} < C_{Classical}$')
    ax.fill_between(etas, 0, breakeven_g1, color=red, alpha=0.2, label='$C_{QGK} > C_{Classical}$')

    # Right y-axis: input dimensionality. The first (marker-less) copy of g(eta) only carries
    # the legend entry; the two marked curves draw the data points for gamma=1 and gamma=_f(eta).
    g_values = [_G(eta) for eta in etas]
    ax2.plot(etas, g_values, color=blue, linestyle='--', label=r'$\gamma g$ inputs with $g$ VGGs')
    ax2.plot(etas, g_values, marker='o', color=blue, linestyle='--')
    ax2.plot(etas, [_f(eta) * _G(eta) for eta in etas], marker='s', color="#3eb6ea", linestyle='--')

    # Style left axis
    ax.set_xlabel(r'Number of qubits $\eta$')
    ax.grid(True)
    ax.set_ylabel('Dataset Size $n$ ($n=10d$)')
    ax.set_yscale('log')
    ax.tick_params(axis='y', colors='#424242')
    ax.yaxis.label.set_color('#424242')

    # Handles are in plot order: the two breakeven curves, then the two shaded regions.
    # Curves go into their own one-line legend (lower right); add_artist keeps it alive
    # when ax.legend creates the main legend for the shaded regions (upper left).
    handles, labels = ax.get_legend_handles_labels()
    ax.add_artist(Legend(ax, handles[:2], labels[:2], loc='lower right', ncol=2, frameon=True))
    ax.legend(handles[2:], labels[2:], loc='upper left')

    # Style secondary y-axis
    ax2.set_ylabel('Input Dimensionality $d$')
    ax2.set_yscale('log')
    leg = ax2.legend(loc='lower right')
    leg.set_bbox_to_anchor((1, 0.15))  # shifted up, presumably to clear the lower-right curve legend
    ax2.tick_params(axis='y', colors=blue)
    ax2.yaxis.label.set_color(blue)
    ax.set_yticklabels([])  # Hide left tick labels but keep ticks

    ax.set_xlim(left=1, right=max(etas))
    ax.set_xticks(etas)
    ax.set_xticklabels([str(eta) for eta in etas])

    fig.tight_layout()
    Path(save_path).parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(save_path)
    return fig



def print_complexity_table(split=0.1, show_gamma=False, show_g=True, _G=None):
    """Prints LaTeX table rows comparing QGK and classical-kernel complexity per dataset.

    Each dataset's n samples are split into a part of size (1 - split) * n and one of size
    split * n; the reported cost is the sum of the costs of both parts, with
        C_QGK(eta, n, d)       = 4**eta + n * g(eta) * d + n * 8**eta + n**2 * 2**eta
        C_Classical(eta, n, d) = n**2 * d

    Args:
        split: Fraction of each dataset in the second part (presumably a test split).
        show_gamma: Append the compression ratio gamma ~ d / g(eta) to each row label.
        show_g: Append the number of groups g(eta) to each row label.
        _G: Function computing g(eta); defaults to the module-level ``g``, which is only
            defined when this file runs as a script.
    """
    if _G is None:
        _G = g

    def C_QGK(eta, n, d):
        return 4**eta + n * _G(eta) * d + n * 8**eta + n**2 * 2**eta

    def C_Classical(eta, n, d):
        return n**2 * d

    def format_O(o):
        return r'$\mathcal{O}(' + f"{o:.2e})$"

    datasets = {
        'moons': {'eta': 2, 'd': 2, 'n': 200},
        'circles': {'eta': 2, 'd': 2, 'n': 200},
        'Bank': {'eta': 2, 'd': 16, 'n': 200},
        'MNIST': {'eta': 5, 'd': 784, 'n': 1000},
        'CIFAR10': {'eta': 5, 'd': 3072, 'n': 1000},
    }

    print('\\textbf{Component} & \\textbf{QGK (ours)} & \\textbf{Classical Kernel} \\\\ \\hline')
    for label, cfg in datasets.items():
        # Row label like "($\eta=2, d=2, n=200,g=15$)": the f-string's leading backslash
        # turns the first key "eta" into the LaTeX symbol \eta.
        config_str = ', '.join(f"{k}={v}" for k, v in cfg.items())
        if show_g:
            config_str += f",g={_G(cfg['eta'])}"
        if show_gamma:
            config_str += f", \\gamma \\approx {round(cfg['d'] / _G(cfg['eta']), 2)}"
        # Total cost = cost on the (1 - split) part + cost on the split part
        c1 = {**cfg, 'n': cfg['n'] * (1 - split)}
        c2 = {**cfg, 'n': cfg['n'] * split}
        print(f"{label} ($\\{config_str}$) & {format_O(C_QGK(**c1) + C_QGK(**c2))} & {format_O(C_Classical(**c1) + C_Classical(**c2))} \\\\")


def print_efficiency_bounds(etas, _G, secondary, approx=False, baseline=False):
    """Prints LaTeX table rows of the efficiency bound eb = n* / d over eta.

    n* is the breakeven dataset size from ``comp_bound`` and d = gamma * g(eta) the input
    dimensionality, so eb is the breakeven dataset size per input dimension.

    Args:
        etas: Iterable of qubit counts (table columns).
        _G: Function computing the number of groups g(eta) given eta.
        secondary: Tuple (_f, _label) of the compression factor gamma = _f(eta) and its
            LaTeX label, used as the row subscript.
        approx: Use the simplified breakeven bound instead of the exact quadratic root.
        baseline: Also print the gamma = 1 row.
    """
    _f, _label = secondary

    def eb(eta, gamma=1):
        return comp_bound(eta, _G, gamma, approx) / (gamma * _G(eta))

    etas = list(etas)  # iterated once per row; a generator would be exhausted
    print(r'$\eta$ & ' + ' & '.join(str(eta) for eta in etas) + r' \\ \hline')
    if baseline:
        print(r'$\epsilon b_1$ & ' + ' & '.join(f'{eb(eta, 1):.2f}' for eta in etas) + r' \\ \hline')
    # Brace the subscript so multi-token labels such as \sqrt{2}^\eta or 1/\eta render correctly
    print(rf'$\epsilon b_{{{_label}}}$ & ' + ' & '.join(f'{eb(eta, _f(eta)):.2f}' for eta in etas) + r' \\ \hline')



if __name__ == '__main__':
    # Compression-factor schedules gamma = f(eta), each paired with its LaTeX label
    inverse = (lambda x: 1/x, r'1/\eta')
    linear = (lambda x: x, r'\eta')
    exp = (lambda x: np.sqrt(2)**x, r'\sqrt{2}^\eta')

    # Number of groups g(eta). `g` is also read as a module global by print_complexity_table.
    g = lambda eta: 3*2**eta - 6*(eta%2) + 3  # Current implementation
    g_approx = lambda eta: 3*2**eta  # Approximation g ~ 3*2**eta (paired with approx=True below)
    H = lambda eta: 4**eta - 1  # Use max groups

    # print_complexity_table(split=0.1)  # Appendix D Table 1

    # print_efficiency_bounds(range(2, 9), g, secondary=linear)
    # print_efficiency_bounds(range(2, 300), g_approx, exp, approx=True)

    plot_efficiency_breakeven(range(1, 9), g, secondary=linear)
    # plot_efficiency_breakeven(range(1, 200), g_approx, secondary=exp, approx=True)  # demonstrating Theorem D2
