from config import *
from run.train import execute, train, fit
from run.plot import *
from metrics.styles import *

# Dataset settings shared by every run; they are unpacked with ** further
# down, so CIFAR10 has to be a mapping.
data = CIFAR10
# Seeds 0..7, handed to execute for each experiment.
seeds = list(range(8))
# Step count passed to the QGK/QEK/HEE variants run with `train` and to write.
steps = 100

# Experiment table, one row per entry, processed in the order listed:
#   (label, algorithm settings, runner, extra keyword arguments for execute,
#    leading positional arguments for write).
# The algorithm settings are unpacked with ** next to `data`, so every
# constructor below has to return a mapping.
# NOTE(review): the write arguments look like (color, flag, flag); what the
# two booleans control is decided inside write -- confirm there.
experiments = (
  ('QGK (ours)', QGK(5, steps), train, {}, (blue, True, False)),
  ('QEK       ', QEK(5, steps), train, {}, (red, True, False)),
  ('HEE Linear', HEE(5, steps), train, {}, (green, True, False)),
  ('QGK Static', QGK_STATIC(6), fit, {'add_loss': True}, (blue, True, True)),
  ('HEE       ', HEE(5), fit, {'add_loss': True, 'pca_features': 5}, (green, True, True)),
  ('RBF       ', RBF(), fit, {'add_loss': True}, (yellow, True, True)),
  ('Linear    ', LIN(), fit, {'add_loss': True}, (orange, True, True)),

  ('HEE-D     ', HEE(5, 0, 166), fit, {'add_loss': True, 'pca_features': 5}, (darkgreen, True, True)),
  ('QEK-N     ', QEK(5, steps, 0), train, {}, (lightred, True, False)),
)

for label, alg_cfg, runner, run_kwargs, write_args in experiments:
  # Dataset settings first, algorithm settings second: on a key clash the
  # algorithm's value wins.
  settings = {**data, **alg_cfg}
  outcome = execute(settings, seeds, runner, **run_kwargs)
  # Only the first element of what execute returns is handed to write.
  write(label, *write_args, data, steps, outcome[0])
