from kernels import QEK, QGK, HEE
import numpy as np; import torch as th

from qiskit_ibm_runtime.fake_provider import FakeQuitoV2
import torch


# Number of random input rows drawn per dataset for the dry-run below.
n = 8

# Seed the global NumPy and PyTorch RNGs so every run draws the same inputs.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# Per-dataset settings for the kernel depth survey.
#   eta     -- default ``eta`` handed to kernels whose entry does not set one
#   d       -- input dimensionality (columns of the random inputs); also the
#              default ``inputs`` for kernels whose entry does not set one
#   n       -- not read by the dry-run loop in this script
#   kernels -- label -> constructor kwargs; the kernel class is the first word
#              of the label ("QEK Adapted" builds a QEK)
# Each kernel entry is its own dict object. Never share one between datasets:
# any consumer that fills defaults in place would leak values across datasets.
config = {
    'moons / circles': dict(
        eta=2, d=2, n=200,
        kernels={
            'QGK': dict(optimizer='Adam'),
            'QEK': dict(reupload=1, optimizer='Adam'),
            'HEE': dict(layers=2),

            'QEK Adapted': dict(reupload=0, optimizer='Adam'),
            'HEE Adapted': dict(layers=3),
        },
    ),

    'bank': dict(
        eta=2, d=16, n=200,
        kernels={
            'QGK': dict(optimizer='Adam'),
            'QEK': dict(reupload=1, optimizer='Adam'),
            'HEE': dict(inputs=2, layers=2),

            'QEK Adapted': dict(reupload=0, optimizer='Adam'),
            'HEE Adapted': dict(inputs=2, layers=3),
        },
    ),

    'MNIST': dict(
        eta=5, d=784, n=1000,
        kernels={
            'QGK': dict(optimizer='Adam'),
            'QEK': dict(eta=5, reupload=1, optimizer='Adam'),
            'HEE': dict(layers=2, inputs=5),

            'QEK Adapted': dict(eta=5, reupload=0, optimizer='Adam'),
            'HEE Adapted': dict(layers=166, inputs=5),
        },
    ),

    'CIFAR10': dict(
        eta=5, d=3072, n=1000,
        kernels={
            'QGK': dict(optimizer='Adam'),
            'QEK': dict(eta=5, reupload=1, optimizer='Adam'),
            'HEE': dict(layers=2, inputs=5),

            'QEK Adapted': dict(eta=5, reupload=0, optimizer='Adam'),
            'HEE Adapted': dict(layers=166, inputs=5),
        },
    ),
}

# Debug aid: list the kernel variants configured for the first dataset.
print(config['moons / circles']['kernels'].keys())

# Kernel class for each config label's base name ("QEK Adapted" -> QEK).
# An explicit table replaces eval(): only known classes can be built, and a
# typo in a label raises a KeyError naming the bad key instead of evaluating
# arbitrary text as code.
KERNEL_CLASSES = {'QGK': QGK, 'QEK': QEK, 'HEE': HEE}

# For every dataset: build each configured kernel in simulation mode with a
# FakeQuitoV2 backend, dry-run it on `n` random input rows, and print the
# truncated mean of the result as one LaTeX-style table row.
# NOTE: cfg['n'] is not read here; the module-level `n` sets the sample count.
for ds, cfg in config.items():
    a = torch.randn(n, cfg['d'], dtype=torch.float64)
    b = torch.randn(n, cfg['d'], dtype=torch.float64)
    print(f"{ds}: ", end='')

    for label, kernel_kwargs in cfg['kernels'].items():
        # Copy before filling dataset-level defaults so `config` is not mutated.
        kwargs = dict(kernel_kwargs)
        kwargs.setdefault('inputs', cfg['d'])
        kwargs.setdefault('eta', cfg['eta'])

        model = KERNEL_CLASSES[label.split(' ')[0]](mode='simulation', **kwargs)
        model.backend = FakeQuitoV2()
        # Presumably per-pair circuit depths when dry_run=True -- TODO confirm
        # against the kernel classes; only their mean is reported.
        depths = model.kernel(a, b, dry_run=True)
        print(int(np.mean(depths)), end=' & ')
    # NOTE(review): every value is followed by ' & ', so the row ends in
    # '& \\' (an empty trailing cell), and the dataset name shares a cell with
    # the first value; confirm this matches the target table layout.
    print('\\\\')
    print()
