from algos import PPO, RolloutStorage, ACAgent
from models import MultigridNetwork, NetHackAgentNet


def model_for_multigrid_agent(env, recurrent_arch=None, recurrent_hidden_size=256):
    """Build a ``MultigridNetwork`` configured from ``env``'s spaces.

    Args:
        env: Environment exposing ``observation_space`` and ``action_space``.
            The observation space must be indexable by ``'direction'`` and
            that entry must expose a ``high`` sequence.
        recurrent_arch: Forwarded unchanged to the network.
        recurrent_hidden_size: Forwarded unchanged to the network.

    Returns:
        The constructed ``MultigridNetwork`` instance.
    """
    obs_space = env.observation_space
    act_space = env.action_space

    # The scalar input dimension is one more than the first upper bound of
    # the 'direction' observation.
    direction_upper = obs_space['direction'].high[0]

    network_kwargs = dict(
        observation_space=obs_space,
        action_space=act_space,
        scalar_fc=5,
        scalar_dim=direction_upper + 1,
        recurrent_arch=recurrent_arch,
        recurrent_hidden_size=recurrent_hidden_size,
    )
    return MultigridNetwork(**network_kwargs)


def model_for_minihack_agent(env, recurrent_arch=None, recurrent_hidden_size=256):
    """Build a ``NetHackAgentNet`` configured from ``env``'s spaces.

    Args:
        env: Environment exposing ``observation_space`` and an
            ``action_space`` with an ``n`` attribute.
        recurrent_arch: Only tested against ``None`` here; any non-``None``
            value turns on ``use_lstm``.
        recurrent_hidden_size: Passed to the network as ``rnn_hidden_size``.

    Returns:
        The constructed ``NetHackAgentNet`` instance.
    """
    obs_space = env.observation_space
    act_space = env.action_space

    wants_lstm = recurrent_arch is not None

    return NetHackAgentNet(
        observation_shape=obs_space,
        num_actions=act_space.n,
        use_lstm=wants_lstm,
        rnn_hidden_size=recurrent_hidden_size,
    )


def model_for_env_agent(
    env_name,
    env,
    recurrent_arch=None,
    recurrent_hidden_size=256,
    use_global_critic=False,
    use_skip=False,
    choose_start_pos=False,
    use_popart=False,
    adv_use_popart=False,
    use_categorical_adv=False,
    use_goal=False,
    num_goal_bins=1):
    """Select and build the agent model matching ``env_name``'s prefix.

    Names starting with ``'MultiGrid'`` are handled by
    ``model_for_multigrid_agent``; names starting with ``'MiniHack'`` by
    ``model_for_minihack_agent``.

    Args:
        env_name: Environment identifier string; only its prefix is inspected.
        env: Environment instance handed to the selected builder.
        recurrent_arch: Forwarded to the selected builder.
        recurrent_hidden_size: Forwarded to the selected builder.
        use_global_critic, use_skip, choose_start_pos, use_popart,
        adv_use_popart, use_categorical_adv, use_goal, num_goal_bins:
            Accepted as part of the signature but not read by this function.

    Returns:
        The model produced by the selected builder.

    Raises:
        ValueError: If ``env_name`` matches neither supported prefix.
    """
    builder_kwargs = dict(
        env=env,
        recurrent_arch=recurrent_arch,
        recurrent_hidden_size=recurrent_hidden_size,
    )

    if env_name.startswith('MultiGrid'):
        return model_for_multigrid_agent(**builder_kwargs)

    if env_name.startswith('MiniHack'):
        return model_for_minihack_agent(**builder_kwargs)

    raise ValueError(f'Unsupported environment {env_name}.')


def make_agent(name, env, args, device='cpu'):
    """Assemble an actor-critic agent: model, PPO learner and rollout storage.

    Args:
        name: Agent name. Accepted as part of the signature; not read here.
        env: Environment exposing ``observation_space``, ``action_space`` and
            ``get_max_episode_steps()``.
        args: Namespace-like configuration object. The attributes read are
            ``algo``, ``env_name``, ``num_steps``, ``num_processes``,
            ``recurrent_agent``, ``recurrent_arch``, ``recurrent_hidden_size``,
            ``entropy_coef``, ``ppo_epoch``, ``num_mini_batch``,
            ``max_grad_norm``, ``clip_param``, ``value_loss_coef``, ``lr``,
            ``eps``, ``clip_value_loss`` and ``log_grad_norm``. The optional
            ``use_popart`` and ``handle_timelimits`` entries are looked up via
            ``vars(args)`` and default to ``False``.
        device: Device passed to the agent's ``.to()`` call.

    Returns:
        An ``ACAgent`` wrapping the PPO learner and its rollout storage,
        moved to ``device``.

    Raises:
        ValueError: If ``args.algo`` is not ``'ppo'``, or if
            ``args.env_name`` is rejected by ``model_for_env_agent``.
    """
    # Fail fast, before any model is built, and report the algorithm name the
    # caller actually requested.
    if args.algo != 'ppo':
        raise ValueError(f'Unsupported RL algorithm {args.algo}.')

    observation_space = env.observation_space
    action_space = env.action_space

    # The model is only recurrent when explicitly requested via recurrent_agent.
    recurrent_arch = args.recurrent_arch if args.recurrent_agent else None
    use_popart = vars(args).get('use_popart', False)

    # Create model instance
    actor_critic = model_for_env_agent(
        args.env_name, env,
        recurrent_arch=recurrent_arch,
        recurrent_hidden_size=args.recurrent_hidden_size)

    # Time-limit handling is enabled only if the env reports an episode step
    # limit and the option is switched on in args.
    use_proper_time_limits = \
        env.get_max_episode_steps() is not None and vars(args).get('handle_timelimits', False)

    # Create PPO
    algo = PPO(
        actor_critic=actor_critic,
        clip_param=args.clip_param,
        ppo_epoch=args.ppo_epoch,
        num_mini_batch=args.num_mini_batch,
        value_loss_coef=args.value_loss_coef,
        entropy_coef=args.entropy_coef,
        lr=args.lr,
        eps=args.eps,
        max_grad_norm=args.max_grad_norm,
        clip_value_loss=args.clip_value_loss,
        log_grad_norm=args.log_grad_norm
    )

    # Create storage
    # NOTE(review): storage receives args.recurrent_arch directly, not the
    # recurrent_agent-gated value given to the model above. Kept as-is;
    # confirm against RolloutStorage whether this is intentional.
    storage = RolloutStorage(
        model=actor_critic,
        num_steps=args.num_steps,
        num_processes=args.num_processes,
        observation_space=observation_space,
        action_space=action_space,
        recurrent_hidden_state_size=args.recurrent_hidden_size,
        recurrent_arch=args.recurrent_arch,
        use_proper_time_limits=use_proper_time_limits,
        use_popart=use_popart
    )

    return ACAgent(algo=algo, storage=storage).to(device)
