import gym
import test_envs
import tensorflow as tf
import rollouts
from policy_network import Policy
from subpolicy_network import SubPolicy
from observation_network import Features
from learner import Learner
import rl_algs.common.tf_util as U
import numpy as np
# from tinkerbell import logger
from rl_algs import logger
import pickle

def start(callback, args, workerseed, rank, comm):
    """Build the master policy and sub-policies for ``args.task`` and train them.

    For each of ``args.iter`` outer iterations this resets and re-syncs the
    master policy, chooses a goal (broadcast from rank 0, or
    ``args.fixed_task``), and then runs ``warmup_time + train_time`` inner
    updates. Each inner update draws one rollout batch, updates the master
    policy on it, then updates the sub-policies.

    Args:
        callback: Called with the outer-iteration index at the start of every
            outer iteration.
        args: Run configuration. Fields read here: ``task``, ``num_subs``,
            ``num_hid``, ``macro_duration``, ``num_rollouts``,
            ``warmup_time``, ``train_time``, ``iter`` and ``fixed_task``
            (``-1`` means "use the goal broadcast from rank 0").
        workerseed: Seed for the environment and NumPy's global RNG.
        rank: This worker's rank; only used to stagger the starting point of
            the first outer iteration's inner loop.
        comm: Communicator used to broadcast the goal from rank 0 and passed
            to ``Learner`` (presumably an MPI communicator -- confirm in caller).
    """
    env = gym.make(args.task)
    env.seed(workerseed)
    np.random.seed(workerseed)
    ob_space = env.observation_space
    ac_space = env.action_space

    num_subs = args.num_subs
    num_hid = args.num_hid
    macro_duration = args.macro_duration
    num_rollouts = args.num_rollouts
    warmup_time = args.warmup_time
    train_time = args.train_time

    num_batches = 15

    # Observation input shared by every network below.
    ob = U.get_placeholder(name="ob", dtype=tf.float32, shape=[None, ob_space.shape[0]])

    policy = Policy(name="policy", ob=ob, ac_space=ac_space, hid_size=64, num_hid_layers=num_hid, num_subpolicies=num_subs)
    old_policy = Policy(name="old_policy", ob=ob, ac_space=ac_space, hid_size=64, num_hid_layers=num_hid, num_subpolicies=num_subs)

    sub_policies = [SubPolicy(name="sub_policy_%i" % x, ob=ob, ac_space=ac_space, hid_size=64, num_hid_layers=num_hid) for x in range(num_subs)]
    old_sub_policies = [SubPolicy(name="old_sub_policy_%i" % x, ob=ob, ac_space=ac_space, hid_size=64, num_hid_layers=num_hid) for x in range(num_subs)]

    # Learner gets local comm.
    # Master Adam is initialized with the local comm; subpolicy Adams with the global comm.
    learner = Learner(env, policy, old_policy, sub_policies, old_sub_policies, comm, clip_param=0.2, entcoeff=0, optim_epochs=10, optim_stepsize=3e-4, optim_batchsize=64)
    rollout = rollouts.traj_segment_generator(policy, sub_policies, env, macro_duration, num_rollouts, stochastic=True, args=args)

    for x in range(args.iter):
        callback(x)
        if x == 0:
            learner.syncSubpolicies()
            print("synced subpols")

        # Run the inner meta-episode.
        # Master policy reset and synchronization locally!
        policy.reset()
        learner.syncMasterPolicies()

        # Every rank must take part in the broadcast (it is a collective
        # call), even when a fixed task overrides the broadcast goal.
        env.env.randomizeCorrect()
        shared_goal = comm.bcast(env.env.realgoal, root=0)
        if args.fixed_task == -1:
            env.env.realgoal = shared_goal
        else:
            env.env.realgoal = args.fixed_task

        print("It is iteration %d so i'm changing the goal to %s" % (x, env.env.realgoal))

        # On the first iteration, start the counter below zero by
        # (rank % 10) tenths of the full inner-loop length, so different
        # workers run different numbers of extra pre-warmup updates.
        # The parentheses matter: the stride is a tenth of
        # (warmup_time + train_time), not warmup_time + train_time / 10.
        if x == 0:
            stagger_stride = int((warmup_time + train_time) / 10)
            mini_ep = -(rank % 10) * stagger_stride
        else:
            mini_ep = 0

        logger.info(f"Epoch {x+1}")
        logger.logkv("Env", env.env.realgoal)
        print(mini_ep)
        while mini_ep < warmup_time + train_time:
            mini_ep += 1
            # Gather episode data (they use D=2000 in paper).
            rolls = next(rollout)
            allrolls = [rolls]

            # Train theta (master policy).
            rollouts.add_advantage_macro(rolls, macro_duration, 0.99, 0.98)
            gmean, lmean = learner.updateMasterPolicy(rolls)
            ldiscret = np.mean(rolls["discrets"])
            # updateMasterPolicy is globally synchronized, so gmean is a
            # correct cross-worker average (per the original note).

            # Train phi (sub-policies). The flag is True once the warmup
            # phase is over -- presumably it gates sub-policy updates;
            # confirm in Learner.updateSubPolicies.
            test_seg = rollouts.prepare_allrolls(allrolls, macro_duration, 0.99, 0.98, num_subpolicies=num_subs)
            learner.updateSubPolicies(test_seg, num_batches, (mini_ep >= warmup_time))

            # Log the global mean when goals are shared, the local one otherwise.
            if args.fixed_task == -1:
                logger.logkv("Return", gmean)
            else:
                logger.logkv("Return", lmean)
                logger.logkv("DiscountedReturn", ldiscret)
            logger.dumpkvs()



