import argparse

# Command-line interface of the script.
# Every option carries its description as ``help=`` so that ``--help`` is
# self-documenting; the formatter appends each option's default value.
# Option names, types and defaults are unchanged.
parser = argparse.ArgumentParser(
    formatter_class=argparse.ArgumentDefaultsHelpFormatter)
par = parser.add_argument  # short alias that keeps the option table compact

# Environment setup
par("--env_id", default="simple_speaker_listener", type=str,
    help="Environment name")
par("--model_name", default="test", type=str,
    help="Name for the log")
par("--seed", default=1, type=int,
    help="Random seed")
par("--agent_algo", default="cen", type=str,
    help='Algorithm of the agents; "dec" selects DDPG')
par("--adv_algo", default="cen", type=str,
    help="Algorithm of adversaries (e.g., for predator)")

# Encapsulation setup
par("--size_batch", default=1024, type=int,
    help="Size of the training batch")
par("--len_buf", default=int(1e6), type=int,
    help="Size of buffer")
par("--len_epi", default=25, type=int,
    help="Maximum length of episodes (only 25 is currently supported)")
par("--n_epi", default=10000, type=int,
    help="Number of episodes")
#par("--n_epi",          default=25000, type=int)

# Exploration setup
# NOTE(review): the original comment described both noise options as the
# "End of noise"; s_/e_ are read here as start/end -- confirm.
par("--s_noise", default=0.3, type=float,
    help="Start of noise for exploration")
par("--e_noise", default=0, type=float,
    help="End of noise for exploration")
# NOTE(review): presumably the number of episodes over which exploration
# is applied -- confirm against the training loop.
par("--n_epi_exp", default=25000, type=int,
    help="Number of episodes for exploration")

# Evaluation setup
par("--save_interval", default=500, type=int,
    help="Run test every save_interval")
par("--save_iter", default=25, type=int,
    help="At each test, run save_iter episodes")

# Extra setup
par("--freq_update", default=100, type=int,
    help="Update networks every freq_update steps")
par("--hidden_dim", default=128, type=int,
    help="Neural networks hidden dimensions")
par("--lr", default=0.01, type=float,
    help="Learning rate for both the critic and the policy")
par("--tau", default=0.01, type=float,
    help="Soft target update")

# Experimental setup
par("--drop_p", default=0, type=float,
    help="Training data missing probability")
# Here, "scale" means the slow-down of IA-MARL training caused by the
# mask-based update; the unscale option undoes that slow-down.
par("--unscale", default=False, action='store_true',
    help="Undo the training slow-down of IA-MARL caused by the "
         "mask-based update")

# Buffer
# NOTE(review): the two buffer options had no description in the original.
par("--buf_train", default=2, type=int,
    help="Buffer train parameter")
par("--buf_flush", default=2, type=int,
    help="Buffer flush parameter (the script overwrites it with "
         "--buf_train after parsing)")
par("--pre_epi", default=2500, type=int,
    help="Pretraining episodes for GAIN")

# Parameters for the ablation study and gain loss estimation
par("--ablation", default=False, action='store_true',
    help="If set, run simple-IA-MARL")
par("--random_imputation", default=False, action='store_true',
    help="If set, use random imputation instead of GAIN imputation")
par("--imputation_loss", default=False, action='store_true',
    help="For debugging")

# Parse the process command line (sys.argv) once, at import time; each option
# registered on `parser` becomes an attribute of `cfg`.
cfg = parser.parse_args()


import torch
from pathlib import Path
from tensorboardX import SummaryWriter
import numpy as np
from utils import make_env
# Run setup: derive extra config fields, validate them, then create the
# logger, seed the RNGs and build the environment.

# Mask type follows the ablation switch:
#   0 -> simple-IA-MARL (sample all data), 2 -> otherwise.
cfg.mask_type = 0 if cfg.ablation else 2

# Containers used only when the imputation-loss debug option is enabled.
if cfg.imputation_loss:
    gain_path = Path('model/gain_loss/{}'.format(cfg.env_id))
    pre_gain = [[], []]
    post_gain = [[], []]
    # pre_gain = Path('./model/gain_loss/{}/pre_{}_R{}'.format(cfg.env_id, cfg.model_name, cfg.seed))
    # post_gain = Path('./model/gain_loss/{}/post_{}_R{}'.format(cfg.env_id, cfg.model_name, cfg.seed))

# Currently, the implementation supports only environments having 25 steps
# per episode. This is checked with an explicit raise (a bare ``assert`` is
# removed under ``python -O``) and before the log writer and the environment
# are created, so an invalid value fails without side effects.
if cfg.len_epi != 25:
    raise ValueError(
        "--len_epi must be 25 (got {}): only environments with 25 steps "
        "per episode are supported".format(cfg.len_epi))

log_dir = Path('./model/{}/{}_R{}'.format(cfg.env_id, cfg.model_name, cfg.seed))
eval_dir = Path('./model/{}'.format(cfg.env_id))

logger = SummaryWriter(str(log_dir))

# One derived seed shared by torch, numpy and the environment.
run_seed = cfg.seed * 5 + 995
torch.manual_seed(run_seed)
np.random.seed(run_seed)

env = make_env(cfg.env_id)
env.seed(run_seed)

# Post-parse adjustments of the configuration.
# NOTE(review): the reason for the extra episode is not visible here --
# presumably so the last episode index lands on a save interval; confirm.
cfg.n_epi += 1
# Any value given for --buf_flush is replaced by --buf_train.
cfg.buf_flush = cfg.buf_train
# cfg.n_epi_exp = cfg.n_epi
