import argparse
import numpy as np
import time
import gym
from copy import copy
import os

import torch
import torch.multiprocessing as mp
from torchvision import transforms

from src.networks import *
from src.dynamicmap import DynamicMap
from src.rl import GlimpseAgent
from src.goalsearch import GoalSearchSimple

from pytorch_rl import callbacks, agents, algorithms, policies, networks
from pytorch_rl.utils import ImgToTensor
from gym_minigrid.wrappers import *
from src.minigrid import OneHotDynamicObjectsWrapper

# def preprocess(x):
    # return x

# ---------------------------------------------------------------------------
# Experiment configuration.
#
# The configuration and make_env() live at module level (not under the
# __main__ guard) so that they still exist when this module is re-imported by
# a worker process under the "spawn" start method, where guarded code does
# not run.
# NOTE(review): whether the agent actually hands make_env to worker processes
# is not visible from this file -- confirm against pytorch_rl.
# ---------------------------------------------------------------------------
SEED = 123
max_train_steps = 80000000

# Alternative configuration (full PhysEnv), kept for reference:
# trunk = ConvTrunk84
# nb_actions = 4
# env_name = 'PhysEnv-v1'

# Active configuration: goal search.
ENV_SIZE = 10
CHANNELS = 4
NB_ACTIONS = 4
env_name = 'GoalSearch-v2'


def select_device(preferred_index=1):
    """Return the torch device to train on.

    Args:
        preferred_index: CUDA device ordinal to use when it exists.

    Returns:
        ``cuda:<preferred_index>`` when CUDA is available and that ordinal is
        present, ``cuda:0`` when CUDA is available but the ordinal is not
        (e.g. a single-GPU machine), and ``cpu`` otherwise.
    """
    if not torch.cuda.is_available():
        return torch.device("cpu")
    if 0 <= preferred_index < torch.cuda.device_count():
        return torch.device("cuda:{}".format(preferred_index))
    # Requested GPU does not exist; fall back to the first one instead of
    # failing later with an invalid device ordinal.
    return torch.device("cuda:0")


def make_env(name=None):
    """Construct the environment identified by ``name``.

    Args:
        name: Environment identifier. Defaults to the module-level
            ``env_name`` so the function can be called with no arguments.

    Returns:
        A ``GoalSearchSimple(ENV_SIZE)`` for ``'GoalSearch-v2'``, the wrapped
        MiniGrid environment for ``'minigrid-v0'``, or ``None`` for
        ``'PhysEnv-v2'`` (no environment is constructed for that name).

    Raises:
        ValueError: If ``name`` is not one of the identifiers above.
    """
    if name is None:
        name = env_name

    if name == 'GoalSearch-v2':
        # Use ENV_SIZE so the grid always matches obs_shape built in main().
        return GoalSearchSimple(ENV_SIZE)
    if name == 'minigrid-v0':
        env = gym.make('MiniGrid-Dynamic-Obstacles-16x16-v0')
        env = OneHotDynamicObjectsWrapper(env)
        return ImgObsWrapper(env)  # Get rid of the 'mission' field
    if name == 'PhysEnv-v2':
        # Placeholder branch: nothing is built for this name.
        return None
    raise ValueError("Unknown env_name: {!r}".format(name))


def main():
    """Configure PPO and the on-policy agent, then run training."""
    torch.multiprocessing.set_start_method("spawn")

    device = select_device(1)

    np.random.seed(SEED)
    torch.manual_seed(SEED)

    trunk = FlattenTrunk

    # NOTE(review): this instance is not used below; it is kept because
    # constructing it may have side effects (e.g. consuming random numbers).
    env = make_env()
    obs_shape = (CHANNELS, ENV_SIZE, ENV_SIZE)

    policy = policies.MultinomialPolicy()
    ppo = algorithms.PPO(
        actor_critic_arch=networks.ActorCritic,
        trunk_arch=trunk,
        state_shape=obs_shape,
        action_space=NB_ACTIONS,
        policy=policy,
        ppo_epochs=4,
        clip_param=0.1,
        target_kl=0.01,
        minibatch_size=256,
        device=device,
        gamma=0.99,
        lam=0.95,
        clip_value_loss=False,
        value_loss_weighting=0.5,
        entropy_weighting=0.01)

    save_dir = '{}/train_rl_full_1'.format(env_name)
    calls = [
        callbacks.PrintCallback(freq=10),
        callbacks.SaveMetrics(
            save_dir=save_dir,
            freq=1000,),
    ]
    agent = agents.MultithreadedOnPolicyDiscreteAgent(
        algorithm=ppo,
        policy=policy,
        nb_rollout_steps=128,
        state_shape=obs_shape,
        # 1% more environment steps than max_train_steps.
        max_env_steps=1.01*max_train_steps,
        test_freq=100000,
        nb_threads=4,
        frame_stack=1,
        device=device,
        callbacks=calls,)
    preprocess = ImgToTensor()
    # Registered after construction because it needs the agent's own
    # ``tosave`` attribute.
    agent.callbacks.append(callbacks.SaveNetworks(
        save_dir=save_dir,
        freq=100,
        network_func=agent.tosave))
    agent.train(make_env, preprocess)


if __name__ == '__main__':
    main()

