import argparse
import tensorflow as tf
# Command-line interface. Parsing happens at import time, before the MPI / RL
# imports further down, so usage errors are reported before those imports run.
# The resulting `args` namespace is read by the module-level code below and is
# also forwarded as a whole to `master.start` from `train()`.
parser = argparse.ArgumentParser()
# Positional run name; joined onto --logdir to form the log directory.
parser.add_argument('savename', type=str)
# The following options are not read in the visible part of this file; they
# reach `master.start` through `args`.
parser.add_argument('--task', type=str)
parser.add_argument('--num_subs', type=int)
parser.add_argument('--macro_duration', type=int)
parser.add_argument('--num_rollouts', type=int)
parser.add_argument('--iter', type=int, default=10000)
parser.add_argument('--warmup_time', type=int)
parser.add_argument('--train_time', type=int)
parser.add_argument('--force_subpolicy', type=int)
# Base random seed; train() offsets it per MPI rank.
parser.add_argument('--seed', type=int, default=1401)
parser.add_argument('--num_hid', type=int, default=2)
# Taken as a string here and converted to a bool with str2bool() after parsing.
parser.add_argument('--replay', type=str)
# Boolean switch; its purpose is not visible in this file.
parser.add_argument('-s', action='store_true')
# Name of the checkpoint (under <loaddir>/savedir/checkpoints/) that callback()
# restores at iteration 0. Omit to start without restoring.
parser.add_argument('--continue_iter', type=str)
# Parent directory for logs and saved checkpoints.
parser.add_argument('--logdir', type=str, default="../logs")
# Directory checkpoints are restored from when --continue_iter is given.
parser.add_argument('--loaddir', type=str, default="")
# NOTE(review): -1 presumably means "no fixed task" -- confirm against master.py.
parser.add_argument('--fixed_task', type=int, default=-1)
args = parser.parse_args()

# Example invocation:
#   python main.py --task MovementBandits-v0 --num_subs 2 --macro_duration 10 --num_rollouts 1000 --warmup_time 60 --train_time 1 --replay True test

from mpi4py import MPI
from rl_algs.common import set_global_seeds, tf_util as U
from rl_algs import logger
import os.path as osp
import gym, logging
import numpy as np
from collections import deque
from gym import spaces
import misc_util
import sys
import shutil
import subprocess
import master

def str2bool(v):
    """Convert a command-line style boolean token to a ``bool``.

    Args:
        v: Either an actual ``bool`` (returned unchanged) or a string.
            Matching is case-insensitive. ``yes/true/t/y/1`` map to ``True``
            and ``no/false/f/n/0`` map to ``False``.

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If ``v`` is not a bool and is not one of
            the recognised tokens. This includes ``None`` and other
            non-string values, which previously failed with an unrelated
            ``AttributeError`` from ``.lower()``.
    """
    if isinstance(v, bool):
        return v
    if not isinstance(v, str):
        raise argparse.ArgumentTypeError('Boolean value expected.')

    token = v.lower()  # normalise once instead of once per branch
    if token in ('yes', 'true', 't', 'y', '1'):
        return True
    if token in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')

# `--replay` has no argparse default, so it is None when the flag is omitted.
# Treat that as "not replaying" instead of crashing inside str2bool(), and
# parse the string only once; `args.replay` is normalised to the same bool
# because `args` is later handed to master.start().
replay = str2bool(args.replay) if args.replay is not None else False
args.replay = replay

# Run-specific paths derived from the command line.
RELPATH = osp.join(args.savename)          # run name, relative to the log root
LOGDIR = osp.join(args.logdir, RELPATH)    # where this run writes logs/checkpoints
LOADDIR = osp.join(args.loaddir)           # where checkpoints are restored from


def callback(it):
    """Hook passed to ``master.start``; saves and restores checkpoints.

    * On MPI rank 0, when not replaying, the TensorFlow state is saved under
      ``LOGDIR/savedir/checkpoints/<it>`` whenever ``it`` is a multiple of 5
      and greater than 3.
    * When ``it == 0`` and ``--continue_iter`` was given, the state is restored
      from ``LOADDIR/savedir/checkpoints/<continue_iter>``. The same
      checkpoint is then read back to plot softmax probabilities for the
      master policy and four sub-policies on the custom Taxi environment.

    Args:
        it: Iteration counter supplied by the caller.
    """
    on_root = MPI.COMM_WORLD.Get_rank() == 0
    if on_root and it % 5 == 0 and it > 3 and not replay:
        save_path = osp.join(LOGDIR, "savedir/", 'checkpoints', '%.5i' % it)
        U.save_state(save_path)

    resuming = it == 0 and args.continue_iter is not None
    if not resuming:
        return

    ckpt_path = osp.join(LOADDIR, "savedir/", "checkpoints/", str(args.continue_iter))
    U.load_state(ckpt_path)

    # Plotting part
    ckpt = tf.train.NewCheckpointReader(ckpt_path)
    shapes_dict = ckpt.get_variable_to_shape_map()  # handy for listing the variable names
    from gym.envs.toy_text.taxi_custom import Taxi
    from scipy.special import softmax
    env = Taxi()

    def layer_probs(scope):
        # softmax over the last axis of the scope's `w` tensor plus its `b` tensor
        logits = ckpt.get_tensor(scope + "/w") + ckpt.get_tensor(scope + "/b")
        return softmax(logits, axis=-1)

    per_sub = [layer_probs("sub_policy_%d/polfinal" % idx) for idx in range(4)]
    subpolicy_probs = np.stack(per_sub, axis=1)
    master_probs = layer_probs("policy/masterpol_final")
    env.plot_params(option_probs=master_probs, policy_probs=subpolicy_probs)


def train():
    """Set up the per-process session, seed and communicator, then run master.

    Each MPI process enters a single-threaded TensorFlow session, seeds itself
    with ``args.seed + 1000 * rank``, and disables logging unless it is rank 0.
    Ranks are then partitioned by ``rank % 10`` and each partition gets its own
    communicator, which is what ``master.start`` receives (not COMM_WORLD).
    """
    world = MPI.COMM_WORLD
    rank = world.Get_rank()

    # Enter the session for the lifetime of the process.
    session = U.single_threaded_session()
    session.__enter__()

    # Distinct seed per rank.
    workerseed = args.seed + 1000 * rank
    set_global_seeds(workerseed)

    # Only rank 0 is allowed to log.
    if rank != 0:
        logger.set_level(logger.DISABLED)

    # Split into theta groups: ranks sharing the same residue modulo 10
    # end up in the same communicator.
    mygroup = rank % 10
    members = list(range(mygroup, world.size, 10))
    theta_group = world.Get_group().Incl(members)
    comm = world.Create(theta_group)
    comm.Barrier()

    # `comm` is this rank's theta-group communicator, NOT the world communicator.
    master.start(callback, args=args, workerseed=workerseed, rank=rank, comm=comm)


def main():
    """Entry point: refuse to reuse an existing log directory, then train.

    Rank 0 checks whether ``LOGDIR`` already exists. If it does, the whole MPI
    job is aborted so a previous run's logs and checkpoints are never
    overwritten. A message explaining the abort is written to stderr first,
    because the abort itself is otherwise silent. All ranks then synchronise
    on a barrier and run ``train()`` inside a logger session that writes to
    ``LOGDIR``.
    """
    if MPI.COMM_WORLD.Get_rank() == 0 and osp.exists(LOGDIR):
        # Explain the abort before tearing the job down; existing data is
        # deliberately left alone rather than removed.
        sys.stderr.write(
            "Log directory %r already exists; aborting. "
            "Remove it or choose a different savename.\n" % (LOGDIR,)
        )
        sys.stderr.flush()
        MPI.COMM_WORLD.Abort()
    # Every rank waits here until rank 0 has finished the check above.
    MPI.COMM_WORLD.Barrier()
    with logger.session(dir=LOGDIR, format_strs=['stdout', 'log', 'tensorboard']):
        train()


# Start training only when executed as a script. Note that the command-line
# arguments are parsed at import time regardless of this guard.
if __name__ == '__main__':
    main()
