import model
import simulator
import agent
import observer

import numpy as np
import pickle

import sys

class ExperimentRun:
    def __init__(self, model_, model_bounds, rng, max_step_count):
        self.model = model_
        self.model_bounds = model_bounds
        self.rng = rng
        self.max_step_count = max_step_count

        self.ideal_agent = agent.KnownPOAgent(self.model)
        self.agent = agent.ACRLAgent(self.model_bounds, self.model.state_rewards)
        self.ablation_agent = agent.AblationACRLAgent(self.model_bounds, self.model.state_rewards)
        self.ucrl_agent = agent.UCRLAgent(self.model_bounds, self.model.state_rewards)

        self.acrl_observer = observer.Observer(self.model)
        self.acrl_simulator = simulator.Simulator(self.model, self.agent, self.acrl_observer, self.rng)
        
        self.ablation_observer = observer.Observer(self.model)
        self.ablation_simulator = simulator.Simulator(self.model, self.ablation_agent, self.ablation_observer, self.rng)

        self.ucrl_observer = observer.Observer(self.model)
        self.ucrl_simulator = simulator.Simulator(self.model, self.ucrl_agent, self.ucrl_observer, self.rng)

    def run(self, verbose=False):
        for i in range(self.max_step_count):
            self.acrl_simulator.step()
            self.ablation_simulator.step()
            self.ucrl_simulator.step()

            if i > 0 and i % 1000 == 0 and verbose:
                print(f"after {i} steps")
                print("Trailing gain (learning): ", self.acrl_observer.get_past_n_gain(10000))
                print("Ideal gain: ", self.ideal_agent.get_estimated_gain())
                print("Trailing gain (ablation): ", self.ablation_observer.get_past_n_gain(10000))
                print("Trailing gain (ucrl): ", self.ucrl_observer.get_past_n_gain(10000))
                print("Optimistic gain (learning): ", self.agent.get_estimated_gain())

    def summarize(self):
        ideal_gain = self.ideal_agent.get_estimated_gain()
        timestep = 10000
        return {
                "acrl": self.acrl_observer.summarize(ideal_gain, timestep=timestep),
                "ablation": self.ablation_observer.summarize(ideal_gain, timestep=timestep),
                "ucrl": self.ucrl_observer.summarize(ideal_gain, timestep=timestep)
            }

class Experiment:
    def __init__(self, all_bounds, runs_per_bound, rng, max_step_count, rerun=False, starting_no=0):
        self.all_bounds = all_bounds
        self.runs_per_bound = runs_per_bound
        self.max_step_count = max_step_count
        self.rng = rng
        self.rerun = rerun
        self.starting_no = starting_no

        #self.runs = [[ExperimentRun(bound, self.rng, self.max_step_count) for i in range(self.runs_per_bound)] for bound in self.all_bounds]
        self.failed_runs = []
        
        if self.rerun:
            self.models = []
            for i, bound in enumerate(all_bounds):
                self.models.append([])
                lstate = bound.capacities[0]
                rstate = bound.capacities[1]
                for run in range(runs_per_bound):
                    with open(f"exp_out/bound_{lstate}_{rstate}_states/model_{run}", "rb") as f:
                        self.models[-1].append(pickle.load(f))
        else:
            self.models = [[model.generate_model(bound, model.RewardGenerator(rng), rng) for i in range(self.runs_per_bound)] for bound in all_bounds]

            for i, models in enumerate(self.models):
                bound = self.all_bounds[i]
                lstate = bound.capacities[0]
                rstate = bound.capacities[1]
                for run, model_ in enumerate(models):
                    with open(f"exp_out/bound_{lstate}_{rstate}_states/model_{run}", "wb") as f:
                        pickle.dump(model_, f)

    def run(self):
        for i, bound in enumerate(self.all_bounds):
            lstate = bound.capacities[0]
            rstate = bound.capacities[1]
            for run in range(self.starting_no,self.runs_per_bound):
                print(f"Running run # {run}")
                exp_run = ExperimentRun(self.models[i][run], bound, self.rng, self.max_step_count)
                try:
                    exp_run.run(verbose=True)
                except Exception as e:
                    self.failed_runs.append([run, str(e)])
                    continue
                with open(f"exp_out/bound_{lstate}_{rstate}_states/run_{run}", "wb") as f:
                    pickle.dump(exp_run.summarize(), f)
                #if i == 9:
                #    raise Exception("stop")

if __name__ == "__main__":
    # schedule: 
    # seed 1000: 6 types, 11 states
    # seed 2000: 6 types, 21 states
    # seed 3000: 6 types, 51 states
    # seed 4000: 6 types, 101 states
    bound_no = int(sys.argv[1])

    if len(sys.argv) > 2:
        rerun = int(sys.argv[2]) == 1
        if rerun:
            starting_no = int(sys.argv[3])
        else:
            starting_no = 0
    else:
        rerun = False
        starting_no = 0

    rng = np.random.default_rng(seed=(1000*bound_no))
    bounds = [
            model.ModelBounds([3,3],[5,5]),
            model.ModelBounds([3,3],[10,10]),
            model.ModelBounds([3,3],[25,25]),
            model.ModelBounds([3,3],[50,50]),
            ]

    experiment = Experiment([bounds[bound_no-1]], 50, rng, 10000000, rerun=rerun, starting_no=starting_no)

    experiment.run()
