from vec_env import HashVectorizedPartialEnvironment, HashMHMVectorizedEnvironment
from runner import Runner
from agent import Agent

import pickle
from pathlib import Path
from datetime import datetime


class Runnables:
    """Base experiment driver: trains fresh agents over a list of environments.

    Subclasses configure the environment factory (``self.env``), the
    environment names, and the per-environment step budgets / epsilons.
    ``run`` then executes ``self.runs`` runs per environment and logs each run to
    ``experiments/<family>/<env_name>/<timestamp>/logs.csv``.
    """

    def __init__(self, num_views):
        # Destination for the pickled transfer table when ``run(save=True)``.
        self.runner_file_name = ""
        # Source of a previously pickled transfer table when ``run(save=False)``.
        self.runner_load_file_name = ""

        self.runs = 5  # runs per environment
        self.run_number = None  # if set, used as the run id instead of the loop index
        self.num_views = num_views
        self.num_s_batches = 50

        # Environment factory, called as ``self.env(num_views, env_name)``;
        # subclasses must set it.
        self.env = None

        self.environments = ["MiniGrid-DoorKey-8x8-v0"]
        # Per-environment defaults, indexed in parallel with ``environments``.
        self.num_steps = []
        self.epsilons = []

    def run(self, args, save):
        """Train ``self.runs`` agents for every configured environment.

        Any of ``args.env_name``, ``args.run_num``, ``args.num_steps``,
        ``args.epsilon``, ``args.gamma``, ``args.alpha`` and ``args.rho`` that is
        not ``None`` overrides the corresponding class default.

        Args:
            args: namespace carrying the optional overrides listed above.
            save: if true, pickle each trained agent's transfer table to
                ``self.runner_file_name`` after its run. If false, seed each new
                agent with the table loaded from ``self.runner_load_file_name``.
        """
        for i, default_env_name in enumerate(self.environments):
            print("NEW RUN")
            for j in range(self.runs):
                print("NEW RUNNER")

                env_name = args.env_name if args.env_name is not None else default_env_name

                runner = Runner(self.num_views)
                runner.env = self.env(self.num_views, env_name)
                runner.agent = Agent(self.num_views, runner.env.get_action_space())
                if not save:
                    # Start from a previously saved transfer table.
                    runner.agent.set_transfer_table(self.load())

                # Run id precedence: CLI flag, then fixed class id, then loop index.
                if args.run_num is not None:
                    _run_num = args.run_num
                elif self.run_number is not None:
                    _run_num = self.run_number
                else:
                    _run_num = j

                # The per-environment lists are only indexed when no override is
                # given, so an empty default list is fine if the flag is set.
                _num_steps = args.num_steps if args.num_steps is not None else self.num_steps[i]
                _epsilon = args.epsilon if args.epsilon is not None else self.epsilons[i]  # exploration rate
                _gamma = args.gamma if args.gamma is not None else 0.99
                _alpha = args.alpha if args.alpha is not None else 0.2
                _rho = args.rho if args.rho is not None else 1.0

                steps = 0
                total_eps = 0
                episodes_successful = 0

                # NOTE(review): str(datetime.now()) contains ':' and ' ', which are
                # not valid in Windows paths; consider a filesystem-safe format.
                date_and_time = str(datetime.now())
                log_dir = Path("experiments") / env_name.split('-')[0] / env_name / date_and_time
                log_dir.mkdir(parents=True, exist_ok=True)
                log_name = (log_dir / "logs.csv").as_posix()

                print("NUMBER OF STEPS: ", _num_steps)

                while steps < _num_steps:
                    r_steps, r_episodes_successful = runner.explore_and_navigate(
                        log_name,
                        steps,
                        total_eps,
                        episodes_successful,
                        alpha=_alpha,
                        gamma=_gamma,
                        epsilon=_epsilon,
                        rho=_rho,
                        exp_thresh=50,
                        run_num=_run_num,
                    )

                    steps += r_steps
                    total_eps += 1
                    episodes_successful += r_episodes_successful

                if save:
                    self.save_runner(runner)

        print("DONE")

    def save_runner(self, runner):
        """Pickle ``runner.agent``'s transfer table to ``self.runner_file_name``.

        The table is fetched before the file is opened, so a failure there
        leaves no truncated file behind. The parent directory is created if
        needed, so a long training run does not fail at the final save.
        """
        table = runner.agent.get_transfer_table()
        Path(self.runner_file_name).parent.mkdir(parents=True, exist_ok=True)
        with open(self.runner_file_name, "wb") as f:
            pickle.dump(table, f)

    def load(self):
        """Return the transfer table unpickled from ``self.runner_load_file_name``.

        Security: ``pickle.load`` can execute arbitrary code. Only point this
        at files produced by ``save_runner``.
        """
        with open(self.runner_load_file_name, "rb") as f:
            return pickle.load(f)


class RunnableUnlockPickup(Runnables):
    """MiniGrid DoorKey configuration: one 30M-step run with epsilon 0.1."""

    def __init__(self, num_views):
        super().__init__(num_views)

        self.runs = 1
        # Uncomment to pin the run id: self.run_number = 1
        self.num_views = num_views
        self.num_s_batches = 50

        self.env = HashVectorizedPartialEnvironment

        self.environments = ["MiniGrid-DoorKey-8x8-v0"]

        # To start from a saved transfer table, set
        # self.runner_load_file_name = "saved/yourFile" and call run(args, save=False).

        self.num_steps = [30000000]
        self.epsilons = [0.1]

    def run(self, args, save):
        """Direct the saved model into a fresh timestamped folder, then train.

        The model path ``saved/<family>/<env_name>/<timestamp>/model.p`` lies
        inside the directory created here, so ``save_runner`` can open it. When
        ``args.env_name`` is ``None``, the first configured environment is used,
        matching the fallback in ``Runnables.run``.
        """
        env_name = args.env_name if args.env_name is not None else self.environments[0]
        date_and_time = str(datetime.now())
        save_dir = Path("saved") / env_name.split('-')[0] / env_name / date_and_time
        save_dir.mkdir(parents=True, exist_ok=True)
        self.runner_file_name = (save_dir / "model.p").as_posix()
        super().run(args, save)


class RunnableMiniHack(Runnables):
    """MiniHack random-room configuration: three 15M-step runs with epsilon 0.3."""

    def __init__(self, num_views):
        super().__init__(num_views)

        # Artefact file paths. Of these, only runner_file_name is read by
        # Runnables.save_runner in this file.
        self.bead_file_name = "saved/bead_UP.p"
        self.trainer_file_name = "saved/trainer_UP.p"
        self.trainer_agent_file_name = "saved/tt_UP.p"
        self.runner_file_name = "saved/runner_UP.p"

        self.runs = 3
        # Uncomment to pin every run to the same run id:
        # self.run_number = 10

        self.num_views, self.num_s_batches = num_views, 50

        self.env = HashMHMVectorizedEnvironment
        self.environments = ["MiniHack-Room-Random-15x15-v0"]

        self.num_steps, self.epsilons = [15_000_000], [0.3]

    def run(self, args, save):
        """Delegate unchanged to ``Runnables.run``."""
        return super().run(args, save)
