# Copyright (c) Facebook, Inc. and its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Hydra defaults list. Entries are composed in order and later entries override
# earlier ones, so with `_self_` first the `env_config` and `task` groups can
# override values defined in this file.
defaults:
- _self_
- env_config: crafter
- task: run
- override hydra/job_logging: colorlog
- override hydra/hydra_logging: colorlog
# To launch on SLURM, uncomment the line below (needs the Hydra submitit launcher
# plugin). Like the logging entries above, replacing a Hydra default requires the
# `override` keyword (Hydra >= 1.1).
# - override hydra/launcher: submitit_slurm

# # To be used with the Hydra submitit_slurm launcher if you have a SLURM cluster.
# # pip install hydra-core hydra_colorlog hydra-submitit-launcher
# # These can also be set on the command line, e.g. `hydra.launcher.partition=dev`.
# hydra:
#   launcher:
#     timeout_min: 4300
#     cpus_per_task: 20
#     gpus_per_node: 2
#     tasks_per_node: 1
#     mem_gb: 20
#     nodes: 1
#     partition: dev
#     comment: null
#     max_num_timeout: 5  # will requeue on timeout or preemption


# Optional run label. Sweeping it with Hydra multirun (e.g. `-m name=1,2,3,4,5`)
# gives several runs with otherwise identical params.
name: null
# Root directory interpolated into the `${data_root}/...` paths further down.
data_root: null

## WANDB settings
wandb: false              # Enable wandb logging.
project: nethack-impala   # The wandb project name.
entity: null              # The wandb user to log to (explicit null, not a bare empty value).
group: test               # The wandb group for the run.
wandb_name: null

# POLYBEAST ENV settings
seed: 0
mock: false                   # Use mock environment instead of NetHack.
single_ttyrec: true           # Record ttyrec only for actor 0.
num_seeds: 0                  # If larger than 0, sample a fixed number of environment seeds to use.
write_profiler_trace: false   # Collect and write a profiler trace for chrome://tracing/.
fn_penalty_step: constant     # Function to accumulate penalty.
penalty_time: 0.0             # Penalty per time step in the episode.
penalty_step: -0.01           # Penalty per step in the episode.
reward_lose: 0                # Reward for losing (dying before finding the staircase).
reward_win: 100               # Reward for winning (finding the staircase).
state_counter: none           # Method for counting state visits (the string 'none', not YAML null).
character: 'mon-hum-neu-mal'  # Specification of the NetHack character.
# Typical characters we use:
#   'mon-hum-neu-mal'
#   'val-dwa-law-fem'
#   'wiz-elf-cha-mal'
#   'tou-hum-neu-fem'
#   '@'   # random (used in Challenge assessment)

# RUN settings.
mode: train                  # Training or test mode.
env: nethack-score           # Name of Gym environment to create.
# Env (task) names: challenge, staircase, pet, eat, gold, score, scout, oracle.

# TRAINING settings.
num_actors: 256              # Number of actors.
# Written as `1.0e+9` rather than `1e9`: plain PyYAML (YAML 1.1 float rules need
# a dot and a signed exponent) would load `1e9` as the string "1e9", while this
# form is a float under both PyYAML and OmegaConf.
total_steps: 1.0e+9          # Total environment steps to train for. Will be cast to int.
total_episodes: null
batch_size: 32               # Learner batch size.
unroll_length: 80            # The unroll length (time dimension).
num_learner_threads: 1       # Number of learner threads.
max_learner_queue_size: null
num_inference_threads: 1     # Number of inference threads.
disable_cuda: false          # Disable CUDA.
learner_device: cuda:0       # Set learner device.
actor_device: cuda:0         # Set actor device.

# OPTIMIZER settings (RMSProp).
learning_rate: 0.0002        # Learning rate.
grad_norm_clipping: 40       # Global gradient norm clip threshold.
alpha: 0.99                  # RMSProp smoothing constant.
momentum: 0                  # RMSProp momentum.
epsilon: 0.000001            # RMSProp epsilon (1e-6).

# LOSS settings.
entropy_cost: 0.001          # Entropy cost/multiplier.
baseline_cost: 0.5           # Baseline cost/multiplier.
discounting: 0.999           # Discounting factor.
normalize_reward: true       # Normalize reward by dividing by a running estimate of its standard deviation.
clip_reward: false           # Clip reward to [-1, 1].

# MODEL settings.
model: baseline              # Name of model to build (see models/__init__.py).
use_lstm: true               # Use LSTM in agent model.
hidden_dim: 256              # Size of hidden representations.
embedding_dim: 64            # Size of glyph embeddings.
layers: 5                    # Number of ConvNet layers for the glyph model.
crop_dim: 9                  # Size of crop (c x c).
use_index_select: true       # Whether to use index_select instead of embedding lookup (for speed reasons).
restrict_action_space: true  # Use a restricted action space (only nethack.USEFUL_ACTIONS).

# Message-encoder settings.
msg:
  hidden_dim: 64             # Hidden dimension for message encoder.
  embedding_dim: 32          # Embedding dimension for characters in message encoder.

# TEST settings.
load_dir: null               # Path to load a model from for testing.

savedir: null

save_env: false
pred_model: false

# Experiment flags. Their semantics are defined by the training code that reads
# them, which is not visible in this file.
actor_load_dir: null
policy_migrate: false
no_train_actor: false
no_train_pred: true
no_reward_pred: false
no_contrast_loss: false

multi_objective: false
include_new_tasks: false
num_objectives: 1
cluster_load_dir: null
cluster_pred_model_load_dir: null
cluster_threshold: null
done_at_reward: false
causal_graph_load_path: null
objective_as_input: false
resume: false
no_reward_exploration: false
frame_pred_delta: false
# NOTE(review): the name suggests a numeric weight, but the default is a boolean;
# confirm the type the consumer expects.
frame_pred_alpha: false
transpose_cnn_version: 1
dynamics_pred: false
dynamics_k: 10
dynamics_z: 1
dynamics_contrast: true
dynamics_contrast_coef: 0.1
dynamics_downsample: false
dynamics_discrete: true
contrast_with_pred_diff: false
baseline_only: false
frame_pred_no_moving: false
frame_pred_embed_eps: 0.0
frame_pred_embed_scale: false
frame_pred_embed_std_detach: true
frame_pred_causal_predictor: false
frame_pred_rnd: false
frame_pred_causal_clustering: false
frame_pred_causal_clustering_criterion: 'causal'
frame_pred_mask_inventory: false
frame_pred_chain_contrast: false
frame_pred_discrete_exploration: false
frame_pred_variational_loss: false
frame_pred_error_prediction: false

pred_no_next_frame: false

goal_generation: false

crafter_original: false
crafter_static: false
crafter_repeat_deduction: 0.0

predict_items:
  reward: [1]

# OmegaConf interpolation: resolves to the current value of `env`. Presumably
# used to look up per-environment entries in the tables below; confirm in caller.
dict_key: ${env}


# Per-environment checkpoint/data paths and settings. Every path interpolates
# `data_root`, which defaults to null, so set it (e.g. `data_root=/path/to/data`)
# before relying on these entries.
actor_load_dirs:
  crafter: "${data_root}/crafter/vec-original_env/checkpoint.tar"
  crafter-vanila: "${data_root}/crafter/run/test-vanila/checkpoint.tar"
  minigrid-keycorridor: "${data_root}/minigrid/keycorridor/run-and-pred/fixed-contrast-only-no-next/checkpoint.tar"
  minigrid-blockedunlockpickup: "${data_root}/minigrid/blockedunlockpickup/run/test-fixed/checkpoint.tar"
  minigrid-distractions: "${data_root}/minigrid/distractions/run/test-hard/checkpoint.tar"
  # Alternative (disabled):
  # minigrid-keycorridor: "${data_root}/minigrid/keycorridor/run/test/checkpoint.tar"

# Actor checkpoints for the multi-objective setting.
mo_actor_load_dirs:
  crafter: "${data_root}/crafter/mo/explore-new-tasks/checkpoint.tar"

num_objectives_dict:
  crafter: 17
  crafter-vanila: 18
  minigrid-distractions: 19

cluster_load_dirs:
  # Alternative (disabled):
  # crafter: "${data_root}/crafter/test-pred-adam/cluster.data"
  crafter: "${data_root}/crafter/test-pred-adam/cluster.new.data"
  crafter-mo: "${data_root}/crafter/mo-pred/mo-contrast-only/cluster.data"
  crafter-vanila: "${data_root}/crafter/pred/vanila/cluster.data"
  minigrid-distractions: "${data_root}/minigrid/distractions/pred/contrast-only/cluster.data"

cluster_pred_model_load_dirs:
  # Alternative (disabled):
  # crafter: "${data_root}/crafter/test-pred-adam"
  crafter: "${data_root}/crafter/mo-pred/mo-contrast-only"
  crafter-vanila: "${data_root}/crafter/pred/vanila"
  minigrid-distractions: "${data_root}/minigrid/distractions/pred/contrast-only"

cluster_thresholds:
  crafter: 3.5
  crafter-vanila: 0.08
  minigrid-distractions: 100

causal_graph_load_paths:
  # Alternative (disabled):
  # crafter: "${data_root}/crafter/test-pred-adam/graph.new.data"
  crafter: "${data_root}/crafter/mo-pred/mo-contrast-only/graph.data"
  crafter-vanila: "${data_root}/crafter/pred/vanila/graph.data"
  minigrid-distractions: "${data_root}/minigrid/distractions/pred/contrast-only/graph.data"
  # minigrid-distractions: null


# Per-environment event counts (only crafter is listed here).
num_events_dict:
  crafter: 22


use_crafter_monitor: false
# RND settings; rnd_output_dim is presumably the RND embedding size (confirm in model code).
use_rnd: false
rnd_output_dim: 128

pred_supervised: false
num_events: null
