from memory import EMT, CMT
from learners import MemoryToCBLearner, StackedMemLearner
from evaluators import SlimOnlineOnPolicyEvaluation

from coba.environments import Environments
from coba.learners     import VowpalEpsilonLearner
from coba.experiments  import Experiment, ClassEnvironmentTask

# Keyword arguments forwarded to `Experiment.config(**config)` in the main block.
config = dict(processes=8, chunk_by='task', maxtasksperchunk=None, maxchunksperchild=0)

# Shared epsilon handed as the first argument to every learner constructed below.
epsilon = 0.1

if __name__ == '__main__':

    # --- Learners under comparison (constructed in the order they are listed) ---

    # Parametric baseline.
    parametric_cb = VowpalEpsilonLearner(epsilon, features=["a", "xa", "xxa"])

    # EMT-CB, self-consistent variant.
    emt_cb_consistent = MemoryToCBLearner(
        epsilon,
        EMT(split=100, scorer=3, router=2, bound=-1, interactions=['xa']),
    )

    # EMT-CB, not self-consistent variant.
    emt_cb_inconsistent = MemoryToCBLearner(
        epsilon,
        EMT(split=50, scorer=4, router=2, bound=-1, interactions=[]),
    )

    # CMT-CB.
    cmt_cb = MemoryToCBLearner(
        epsilon,
        CMT(n_nodes=2000, leaf_multiplier=9, dream_repeats=10, alpha=0.50, interactions=['xa']),
    )

    # PEMT-CB: EMT memory stacked with a parametric "xxa" learner.
    pemt_cb = StackedMemLearner(
        epsilon,
        EMT(split=100, scorer=3, router=2, bound=-1, interactions=['xa']),
        "xxa", False, True,
    )

    # PCMT-CB: CMT memory stacked with a parametric "xxa" learner.
    pcmt_cb = StackedMemLearner(
        epsilon,
        CMT(n_nodes=2000, leaf_multiplier=9, dream_repeats=10, alpha=0.50, interactions=['xa']),
        "xxa", False, True,
    )

    learners = [parametric_cb, emt_cb_consistent, emt_cb_inconsistent, cmt_cb, pemt_cb, pcmt_cb]

    description = "Full 50 replicate run of the datasets used in the ICLR 2023 paper."

    # Point `log` at a file (e.g. "./results/ICLR-2023-unbounded.log.gz") to save
    # results while the experiment runs; per the original author's note, execution
    # can then be stopped and resumed at a later date. None disables the log file.
    log = None

    # --- Environments ---

    def _task_order(env):
        # Sort key: group environments by shuffle, then by OpenML task id.
        params = env.params
        return (params['shuffle'], params['openml_task'])

    # Smaller alternative for quick checks:
    #   Environments.from_template("./experiments/sanity.json", n_shuffle=1, n_take=1000)
    environments = sorted(Environments.from_template("./experiments/unbounded.json"), key=_task_order)

    # --- Run and report ---

    experiment = Experiment(
        environments,
        learners,
        description,
        environment_task=ClassEnvironmentTask(),
        evaluation_task=SlimOnlineOnPolicyEvaluation(),
    )
    configured = experiment.config(**config)
    result = configured.evaluate(log)

    # Plot only the environments for which every learner finished.
    finished = result.filter_fin()
    finished.plot_learners()
