import gym_minigrid
import gym
from numpy.core.fromnumeric import product

from utils.wrappers import MinigridDoorKeyWrapper, MinigridDoorKeyTabularWrapper
from utils.q_learning import q_learning
import argparse
from numpy import savetxt
import numpy as np
import os
from multiprocessing import Pool
import datetime
from functools import partial
from utils.plot_utils import plot_from_folder
from time import sleep, time
import random


def main(method, run_index, args):
    """Train one Q-learning agent and save its learning curve as CSV.

    Args:
        method: Exploration strategy name. Forwarded to ``q_learning`` as
            ``explore_method`` and also used as the log sub-directory name.
        run_index: Index of this run. Used as the RNG seed and as the CSV
            file stem.
        args: Parsed command-line namespace. Reads ``env_size``,
            ``num_steps``, ``depth``, ``learning_rate``, ``epsilon`` and
            ``log_path``.

    Writes ``<args.log_path>/<method>/<run_index>.csv`` containing the
    ``steps`` and ``rewards`` returned by ``q_learning`` as two
    comma-separated columns (presumably equal-length 1-D sequences; TODO
    confirm against ``q_learning``).
    """
    # Seed both global RNGs with the run index so that different runs differ
    # but each run is reproducible. ``random`` is seeded as well because the
    # training code / wrappers may draw from it (not visible from here).
    random.seed(run_index)
    np.random.seed(run_index)

    # Create the environment. ``size`` is forwarded to the env constructor,
    # so the grid is presumably env_size x env_size despite the "8x8" id.
    env = gym.make("MiniGrid-DoorKey-8x8-v0", size=args.env_size)
    # A single reseed value means every reset uses seed 42, i.e. (presumably)
    # the same grid layout for all episodes and all runs.
    env = gym_minigrid.wrappers.ReseedWrapper(env, seeds=[42])
    env = MinigridDoorKeyTabularWrapper(env)
    env = MinigridDoorKeyWrapper(env)

    # Train. The learned Q-table itself is not persisted.
    _q_values, rewards, steps = q_learning(
        env,
        num_steps=args.num_steps,
        explore_method=method,
        depth=args.depth,
        learning_rate=args.learning_rate,
        epsilon=args.epsilon,
    )

    # Save the learning curve; plotting happens later from these CSV files.
    out_dir = os.path.join(args.log_path, method)
    os.makedirs(out_dir, exist_ok=True)
    savetxt(
        os.path.join(out_dir, f"{run_index}.csv"),
        np.array([steps, rewards]).T,
        delimiter=",",
    )


if __name__ == "__main__":

    # Command-line interface for the exploration-method sweep.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--log-path",
        help="Path for log files",
        default="logs",
        type=str,
    )

    parser.add_argument(
        "-n",
        "--num-steps",
        help="Number of steps to train for",
        default=5000000,
        type=int,
    )
    parser.add_argument(
        "--env-size",
        help="Size of square grid environment",
        default=19,
        type=int,
    )
    parser.add_argument(
        "--epsilon",
        help="epsilon parameter of epsilon-greedy",
        default=0.3,
        type=float,
    )
    parser.add_argument(
        "--depth",
        help="depth of EASEE",
        default=6,
        type=int,
    )
    # Required: without it args.explore_methods would be None and the run
    # list below would fail with an opaque TypeError instead of a usage error.
    parser.add_argument(
        "--explore-methods",
        help="EASEE or classic",
        nargs="+",
        required=True,
        type=str,
    )
    parser.add_argument(
        "--learning-rate",
        help="Q-learning learning rate",
        default=0.1,
        type=float,
    )

    parser.add_argument(
        "--num-seeds",
        help="Number of trials to run",
        default=1,
        type=int,
    )

    # Note: despite the name, this sets the number of worker *processes*
    # in the multiprocessing Pool below.
    parser.add_argument(
        "--num-threads",
        help="Number of threads to run in parallel",
        default=1,
        type=int,
    )

    # Each invocation logs into its own timestamped sub-directory.
    date = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
    args = parser.parse_args()
    args.log_path = os.path.join(args.log_path, date)

    # One job per (method, seed index) pair; the index doubles as RNG seed.
    thread_args = [
        (method, idx)
        for method in args.explore_methods
        for idx in range(args.num_seeds)
    ]
    print(thread_args)

    # Fan the runs out over worker processes; each writes its own CSV.
    with Pool(processes=args.num_threads) as pool:
        pool.starmap(
            partial(main, args=args),
            thread_args,
        )

    # Aggregate all per-run CSVs into a single PDF named after the timestamp
    # (written relative to the current working directory).
    plot_from_folder(args.log_path, args.explore_methods, f"{date}.pdf")