from config import *
from run.train import execute, train, fit
from run.plot import *
from metrics.styles import *

# Shared settings for every experiment row in the comparison loop below.
# Dataset config; unpacked with `**` next to each algorithm's config, so it
# must be a mapping. MOONS is presumably supplied by one of the star imports.
data = MOONS
# Seeds 0..7, handed to `execute` as its second positional argument.
seeds = list(range(8))
# Step count given to the trained algorithm constructors and to `write`.
steps = 100
# Forwarded as the `simulate` keyword to the runner for the quantum-kernel rows.
simulate = True

# Experiment table, one row per method compared on `data`.
# Row layout: (label, algorithm config, runner, runner kwargs, style triple).
#   - label: passed to `write` as-is; the trailing spaces pad labels to equal width.
#   - algorithm config: a mapping merged over `data` before being given to `execute`.
#   - runner: `train` or `fit`, passed to `execute` as its third positional argument.
#   - runner kwargs: forwarded to `execute` as keyword arguments.
#   - style triple: (colour, flag, flag), splatted positionally into `write`.
#     NOTE(review): the last flag is True exactly on the `fit` rows; the meaning
#     of both flags is defined by `write` and is not visible here — confirm.
# Every constructor in the table runs once, up front, before the first `execute`.
_experiments = [
  ('QGK (ours)', QGK(2, steps), train, {'simulate': simulate}, (blue, True, False)),
  ('QEK       ', QEK(2, steps), train, {'simulate': simulate}, (red, True, False)),
  ('HEE Linear', HEE(2, steps), train, {'simulate': simulate}, (green, True, False)),
  ('QGK Static', QGK_STATIC(2), fit, {'add_loss': True, 'simulate': simulate}, (blue, True, True)),
  ('HEE       ', HEE(2), fit, {'add_loss': True, 'simulate': simulate}, (green, True, True)),
  ('RBF       ', RBF(), fit, {'add_loss': True}, (yellow, True, True)),
  ('Linear    ', LIN(), fit, {'add_loss': True}, (orange, True, True)),

  ('HEE-D     ', HEE(2, 0, 3), fit, {'add_loss': True, 'simulate': simulate}, (darkgreen, True, True)),
  ('QEK-N     ', QEK(2, steps, 0), train, {'simulate': simulate}, (lightred, True, False)),
]

for label, alg, run, args, cfg in _experiments:
  # Run the method over all seeds, then hand only the first element of what
  # `execute` returns to `write`, together with the row's label and style.
  write(
    label, *cfg, data, steps,
    execute({**data, **alg}, seeds, run, **args)[0],
  )
