import argparse

import torch
import torch.multiprocessing as mp

from src.networks import *

from src.rl import DMMAgent, AttentionConstrainedEnvironment

from pytorch_rl import callbacks, agents, algorithms, policies, networks

from src.goalsearch import GoalSearchSimple

if __name__ == '__main__':
    # Import numpy explicitly instead of relying on `from src.networks import *`
    # happening to re-export it.
    import numpy as np

    # Use the "spawn" start method for worker processes (presumably because
    # CUDA cannot be initialised in forked subprocesses).
    mp.set_start_method("spawn")

    # Prefer the second GPU when one exists; fall back to the first GPU on a
    # single-GPU machine (where "cuda:1" is an invalid device ordinal), and to
    # the CPU when CUDA is unavailable.
    if torch.cuda.is_available():
        device = torch.device("cuda:1" if torch.cuda.device_count() > 1 else "cuda:0")
    else:
        device = torch.device("cpu")
    SEED = 456
    mode = 'follow'  # passed to DMMAgent and used in the save-directory name
    max_train_steps = 80000000

    np.random.seed(SEED)
    torch.manual_seed(SEED)

    # --- Environment configuration: keep exactly one section uncommented. ---

    # physenv
    ATTN_SIZE = 21
    ENV_SIZE = 84
    CHANNELS = 3
    NB_ACTIONS = 4
    MAP_SIZE = 21
    MAP_CHANNELS = 48
    trunk = ConvTrunk21
    env_name = 'PhysEnv-v1'

    # minigrid (ATTN_SIZE, ENV_SIZE, CHANNELS, MAP_SIZE and MAP_CHANNELS must
    # also be set if the physenv section above is commented out)
    # trunk = FlattenTrunk
    # NB_ACTIONS = 3
    # env_name = 'minigrid-v0'

    # # goalsearch
    # ATTN_SIZE = 3
    # ENV_SIZE = 10
    # CHANNELS = 4
    # NB_ACTIONS = 4
    # MAP_SIZE = 10
    # MAP_CHANNELS = 16
    # trunk = FlattenTrunk
    # env_name = 'GoalSearch-v2'

    # Shape of a single observation and of the agent's internal map state.
    obs_shape = (CHANNELS, ENV_SIZE, ENV_SIZE)
    state_shape = (MAP_CHANNELS, MAP_SIZE, MAP_SIZE)

    def make_env():
        """Build one AttentionConstrainedEnvironment for the configured `env_name`.

        Returns:
            An AttentionConstrainedEnvironment wrapping the selected base env
            (``None`` is passed as the base env for 'PhysEnv-v1').

        Raises:
            ValueError: if `env_name` is not one of the supported names.
        """
        # NOTE(review): a nested function cannot be pickled; if DMMAgent.train
        # sends make_env to "spawn"ed worker processes, it must be moved to
        # module level -- confirm against DMMAgent.
        if env_name == 'minigrid-v0':
            # NOTE(review): `gym`, `OneHotDynamicObjectsWrapper` and
            # `ImgObsWrapper` are not imported in this file; presumably they
            # arrive via `from src.networks import *` -- confirm before use.
            env = gym.make('MiniGrid-Dynamic-Obstacles-16x16-v0')
            env = OneHotDynamicObjectsWrapper(env)
            env = ImgObsWrapper(env)  # Get rid of the 'mission' field
            return AttentionConstrainedEnvironment(ENV_SIZE, ATTN_SIZE, device, env)
        elif env_name == 'PhysEnv-v1':
            return AttentionConstrainedEnvironment(ENV_SIZE, ATTN_SIZE, device, None)
        elif env_name == 'GoalSearch-v2':
            return AttentionConstrainedEnvironment(ENV_SIZE, ATTN_SIZE, device, GoalSearchSimple(10))
        else:
            raise ValueError("Unknown env_name: {!r}".format(env_name))

    policy = policies.MultinomialPolicy()
    # Metrics and network checkpoints go to "<env_name>/train_rl_<mode>_2", so
    # switching environments no longer writes into the PhysEnv directory.
    savedir = '{}/train_rl_{}_2'.format(env_name, mode)
    calls = [callbacks.PrintCallback(freq=10),
             callbacks.SaveMetrics(
                 save_dir=savedir,
                 freq=1000,),
             ]
    ppo = algorithms.PPO(
        actor_critic_arch=networks.ActorCritic,
        trunk_arch=trunk,
        state_shape=state_shape,
        action_space=NB_ACTIONS,
        policy=policy,
        ppo_epochs=4,
        clip_param=0.1,
        target_kl=0.01,
        minibatch_size=256,
        device=device,
        gamma=0.99,
        lam=0.95,
        clip_value_loss=False,
        value_loss_weighting=0.5,
        entropy_weighting=0.01)
    agent = DMMAgent(
        algorithm=ppo,
        policy=policy,
        mode=mode,
        nb_threads=4,
        nb_rollout_steps=128,
        # 1% headroom over the training budget (a float); presumably so the
        # env-step cap is not reached before training ends -- confirm.
        max_env_steps=1.01*max_train_steps,
        state_shape=state_shape,
        obs_shape=obs_shape,
        nb_actions=NB_ACTIONS,
        attn_size=ATTN_SIZE,
        batchsize=8,
        device=device,
        callbacks=calls,)
    # Added after construction because the callback needs `agent.tosave`.
    agent.callbacks.append(callbacks.SaveNetworks(
        save_dir=savedir,
        freq=100,
        network_func=agent.tosave))
    # finally train
    agent.train(make_env)

