import model
import simulator
import agent
import observer

from learners.UCRL2 import UCRL2
from learners.UCRL3 import UCRL3
from learners.KL_UCRL import KL_UCRL

import numpy as np
import pickle

import sys

class LearnerExperimentRun:
    """A single run comparing UCRL2, UCRL3 and KL-UCRL on one model.

    Each learner is wrapped in an ``agent.LearnerAgent`` and driven by its own
    ``simulator.Simulator`` with a dedicated ``observer.Observer``. All three
    simulators receive the same ``rng`` object. A ``KnownPOAgent`` built from
    the model supplies the reference gain used by ``summarize``.

    Flags:
        no_ucrl3: do not step UCRL3 and leave it out of the summary.
        ucrl3_only: step and summarize UCRL3 only.
    """

    def __init__(self, model_, model_bounds, rng, max_step_count, no_ucrl3=False, ucrl3_only=False):
        self.model = model_
        self.model_bounds = model_bounds
        self.rng = rng
        self.max_step_count = max_step_count
        self.no_ucrl3 = no_ucrl3
        self.ucrl3_only = ucrl3_only

        n_states = self.model_bounds.n_states
        n_actions = self.model_bounds.n_actions
        delta = 0.05  # handed to every learner (presumably its confidence level)

        def build_agent(learner_cls):
            # Wrap a freshly constructed learner of the given class in a LearnerAgent.
            return agent.LearnerAgent(self.model_bounds, self.model.state_rewards, learner_cls(n_states, n_actions, delta))

        def attach_simulator(learner_agent):
            # Give the agent its own observer and a simulator that uses the shared rng.
            obs = observer.Observer(self.model)
            return obs, simulator.Simulator(self.model, learner_agent, obs, self.rng)

        self.ideal_agent = agent.KnownPOAgent(self.model)
        self.ucrl2_agent = build_agent(UCRL2)
        self.ucrl3_agent = build_agent(UCRL3)
        self.kl_agent = build_agent(KL_UCRL)

        # Every learner is reset with argument 0 before any simulation step.
        for learner_agent in (self.ucrl2_agent, self.ucrl3_agent, self.kl_agent):
            learner_agent.learner.reset(0)

        self.ucrl2_observer, self.ucrl2_simulator = attach_simulator(self.ucrl2_agent)
        self.ucrl3_observer, self.ucrl3_simulator = attach_simulator(self.ucrl3_agent)
        self.kl_observer, self.kl_simulator = attach_simulator(self.kl_agent)

    def run(self, verbose=False):
        """Advance the enabled simulators ``max_step_count`` times.

        UCRL3 is stepped unless ``no_ucrl3``; UCRL2 and KL-UCRL are stepped
        unless ``ucrl3_only``. Within one iteration the order is UCRL3, UCRL2,
        KL-UCRL. When ``verbose`` is true, every 1000th iteration (skipping 0)
        prints each enabled learner's ``get_past_n_gain(10000)``.
        """
        # (progress label, simulator, observer) for every enabled learner, in step order.
        lanes = []
        if not self.no_ucrl3:
            lanes.append(("Trailing gain (ucrl3): ", self.ucrl3_simulator, self.ucrl3_observer))
        if not self.ucrl3_only:
            lanes.append(("Trailing gain (ucrl2): ", self.ucrl2_simulator, self.ucrl2_observer))
            lanes.append(("Trailing gain (kl-ucrl): ", self.kl_simulator, self.kl_observer))

        for step in range(self.max_step_count):
            for _label, sim, _obs in lanes:
                sim.step()
            if verbose and step > 0 and step % 1000 == 0:
                print(f"after {step} steps")
                for label, _sim, obs in lanes:
                    print(label, obs.get_past_n_gain(10000))

    def summarize(self):
        """Summarize each enabled learner's observer against the ideal gain.

        Returns:
            dict mapping "ucrl2" / "ucrl3" / "kl" to the result of that
            learner's ``Observer.summarize(ideal_gain, timestep=10000)``.
            ``no_ucrl3`` takes precedence over ``ucrl3_only`` when both are set.
        """
        ideal_gain = self.ideal_agent.get_estimated_gain()
        timestep = 10000
        if self.no_ucrl3:
            keys = ("ucrl2", "kl")
        elif self.ucrl3_only:
            keys = ("ucrl3",)
        else:
            keys = ("ucrl2", "ucrl3", "kl")
        return {
            key: getattr(self, key + "_observer").summarize(ideal_gain, timestep=timestep)
            for key in keys
        }

class LearnerExperiment:
    """Run every configured (bound, run) pair and pickle each run's summary.

    For each bound, the models ``exp_out/bound_{l}_{r}_states/model_{run}`` for
    ``run`` in ``range(runs_per_bound)`` are unpickled up front, where ``l`` and
    ``r`` are ``bound.capacities[0]`` and ``bound.capacities[1]``. ``run()`` then
    simulates runs ``starting_no .. runs_per_bound - 1`` and writes each
    summary into the same directory.
    """

    def __init__(self, all_bounds, runs_per_bound, rng, max_step_count, rerun=False, starting_no=0, no_ucrl3=False, ucrl3_only=False):
        """Load all pickled models for every bound.

        Args:
            all_bounds: model bounds to run; ``capacities[0]`` and
                ``capacities[1]`` of each select its data directory.
            runs_per_bound: number of pickled models to load per bound.
            rng: random generator passed to every run.
            max_step_count: simulation steps per run.
            rerun: stored only; resuming is controlled by ``starting_no``.
            starting_no: index of the first run that ``run()`` executes.
            no_ucrl3, ucrl3_only: forwarded to each ``LearnerExperimentRun``
                and used to choose the output file name.

        Raises:
            OSError: if a model file cannot be opened.
        """
        self.all_bounds = all_bounds
        self.runs_per_bound = runs_per_bound
        self.max_step_count = max_step_count
        self.rng = rng
        self.rerun = rerun
        self.starting_no = starting_no
        self.no_ucrl3 = no_ucrl3
        self.ucrl3_only = ucrl3_only

        # NOTE(review): never populated in this file; kept for external users.
        self.failed_runs = []

        # models[i][run] is the unpickled model for all_bounds[i], run `run`.
        self.models = []
        for bound in all_bounds:
            bound_dir = self._bound_dir(bound.capacities[0], bound.capacities[1])
            loaded = []
            for run in range(runs_per_bound):
                # NOTE: pickle.load can execute arbitrary code; only load
                # experiment outputs produced locally.
                with open(f"{bound_dir}/model_{run}", "rb") as f:
                    loaded.append(pickle.load(f))
            self.models.append(loaded)

    @staticmethod
    def _bound_dir(lstate, rstate):
        """Return the data directory for a bound with the given capacities."""
        return f"exp_out/bound_{lstate}_{rstate}_states"

    def _summary_path(self, lstate, rstate, run):
        """Return the output path for one run's summary, named after the active flags."""
        if self.no_ucrl3:
            prefix = "no_ucrl3_baselines"
        elif self.ucrl3_only:
            prefix = "ucrl3_baselines"
        else:
            prefix = "baselines"
        return f"{self._bound_dir(lstate, rstate)}/{prefix}_{run}"

    def run(self):
        """Simulate every selected run and pickle its summary to disk."""
        for i, bound in enumerate(self.all_bounds):
            lstate = bound.capacities[0]
            rstate = bound.capacities[1]
            for run in range(self.starting_no, self.runs_per_bound):
                print(f"Running run # {run}")
                exp_run = LearnerExperimentRun(self.models[i][run], bound, self.rng, self.max_step_count, no_ucrl3=self.no_ucrl3, ucrl3_only=self.ucrl3_only)
                exp_run.run(verbose=True)
                # Summarize before opening the output file, so a failure in
                # summarize() does not leave an empty or truncated pickle behind.
                summary = exp_run.summarize()
                with open(self._summary_path(lstate, rstate, run), "wb") as f:
                    pickle.dump(summary, f)

if __name__ == "__main__":
    # Usage:
    #   python <this script> BOUND_NO [RERUN STARTING_NO NO_UCRL3 UCRL3_ONLY]
    #
    #   BOUND_NO     1-based index into `bounds` below; the RNG seed is
    #                1000 * BOUND_NO.
    #   RERUN        1 to start from STARTING_NO instead of run 0.
    #   STARTING_NO  first run index; only read when RERUN is 1, but it must
    #                still be given positionally to pass the later flags.
    #   NO_UCRL3     1 to skip UCRL3 (run UCRL2 and KL-UCRL only).
    #   UCRL3_ONLY   1 to run UCRL3 only.
    #
    # schedule:
    # seed 1000: 6 types, 11 states
    # seed 2000: 6 types, 21 states
    # seed 3000: 6 types, 51 states
    # seed 4000: 6 types, 101 states
    if len(sys.argv) < 2:
        sys.exit("usage: BOUND_NO [RERUN STARTING_NO NO_UCRL3 UCRL3_ONLY]")

    def _flag(index):
        """Return True iff sys.argv[index] is present and equals 1."""
        return len(sys.argv) > index and int(sys.argv[index]) == 1

    bound_no = int(sys.argv[1])
    rerun = _flag(2)
    starting_no = int(sys.argv[3]) if rerun else 0
    no_ucrl3 = _flag(4)
    ucrl3_only = _flag(5)
    if no_ucrl3 and ucrl3_only:
        # Together these flags step no learner at all, yet the summary would
        # still overwrite the no_ucrl3 baselines file.
        sys.exit("NO_UCRL3 and UCRL3_ONLY are mutually exclusive")

    bounds = [
            model.ModelBounds([3,3],[5,5]),
            model.ModelBounds([3,3],[10,10]),
            model.ModelBounds([3,3],[25,25]),
            model.ModelBounds([3,3],[50,50]),
            ]
    if not 1 <= bound_no <= len(bounds):
        # bounds[bound_no - 1] would silently wrap around for 0 or negatives.
        sys.exit(f"BOUND_NO must be in 1..{len(bounds)}, got {bound_no}")

    rng = np.random.default_rng(seed=(1000*bound_no))
    experiment = LearnerExperiment(
        [bounds[bound_no-1]],
        runs_per_bound=50,
        rng=rng,
        max_step_count=10000000,
        rerun=rerun,
        starting_no=starting_no,
        no_ucrl3=no_ucrl3,
        ucrl3_only=ucrl3_only,
    )

    experiment.run()
