import os
import copy
import yaml
import random
import argparse
import setproctitle

import torch
import numpy as np

from runner.mas_evaluator import ModelSimEvaluator

def _str2bool(value):
    """Parse a command-line string into a bool.

    ``argparse`` with ``type=bool`` converts every non-empty string (including
    ``"False"`` and ``"0"``) to ``True``, so boolean options need this parser.
    Accepts true/t/yes/y/1 and false/f/no/n/0 (case-insensitive).

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("true", "t", "yes", "y", "1"):
        return True
    if lowered in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError("expected a boolean value, got {!r}".format(value))


def get_args(argv=None):
    """Build the CLI parser and parse the command line.

    Args:
        argv: optional list of argument strings; ``None`` (the default) makes
            argparse read ``sys.argv[1:]`` as before.

    Returns:
        argparse.Namespace with all options below.
    """
    parser = argparse.ArgumentParser(description="Model as Simulator")

    # environment settings
    parser.add_argument("--env", type=str, default="d4rl")
    parser.add_argument("--env-name", type=str, default="hopper-medium-v2")

    # dynamics model parameters
    # "sadm" -- our proposed self-transition any-step dynamics model
    # "adm"  -- original any-step dynamics model
    # "en"   -- ensemble dynamics model
    parser.add_argument("--dyna-model", type=str, default="sadm", choices=["sadm", "adm", "en"])
    parser.add_argument("--model-hidden-dim", type=int, default=200)
    parser.add_argument("--rollout-length", type=int, default=1000)
    parser.add_argument("--given-reward", type=_str2bool, default=True)
    parser.add_argument("--load-label", type=str, default="sadm-sac")
    parser.add_argument("--load-time", type=str, default="25-0521-094857")
    parser.add_argument("--load-seed", type=int, default=0)
    parser.add_argument("--max-adm-step", type=int, default=5)                          # maximum length of rnn input
    # only for sadm
    parser.add_argument("--n-starts", type=int, default=5)

    # policy parameters
    parser.add_argument("--algo", type=str, default="sac", choices=["sac", "td3", "ppo"])
    # dimensions of actor/critic hidden layers, e.g. `--ac-hidden-dims 256 256`
    # (type=list would split "256" into ['2', '5', '6'])
    parser.add_argument("--ac-hidden-dims", type=int, nargs="+", default=[256, 256])
    parser.add_argument("--actor-lr", type=float, default=1e-4)                         # learning rate of actor
    parser.add_argument("--lr-schedule", type=_str2bool, default=False)
    parser.add_argument("--critic-lr", type=float, default=3e-4)                        # learning rate of critic
    parser.add_argument("--gamma", type=float, default=0.99)                            # discount factor
    parser.add_argument("--tau", type=float, default=0.005)                             # update rate of target network
    parser.add_argument("--penalty-coef", type=float, default=0)                        # penalty coefficient
    # for SAC
    parser.add_argument("--alpha", type=float, default=0.05)                            # weight of entropy
    parser.add_argument("--auto-alpha", type=_str2bool, default=True)                   # auto alpha adjustment
    parser.add_argument("--alpha-lr", type=float, default=1e-4)                         # learning rate of alpha
    parser.add_argument("--target-entropy", type=int, default=None)                     # target entropy
    parser.add_argument("--deterministic-backup", type=_str2bool, default=False)
    parser.add_argument("--q-clip", type=float, default=None)
    # for TD3
    parser.add_argument("--explore-noise", type=float, default=0.25)
    parser.add_argument("--target-policy-noise", type=float, default=0.2)
    parser.add_argument("--noise-clip", type=float, default=0.5)
    # for PPO
    parser.add_argument("--gae-lambda", type=float, default=0.95)
    parser.add_argument("--clip-ratio", type=float, default=0.2)
    parser.add_argument("--value-clip", type=_str2bool, default=True)
    parser.add_argument("--value-coef", type=float, default=0.5)
    parser.add_argument("--entropy-coef", type=float, default=0.01)
    parser.add_argument("--max-grad-norm", type=float, default=1.0)
    parser.add_argument("--ppo-slice", type=int, default=24)
    parser.add_argument("--ppo-epoch", type=int, default=5)

    # running parameters
    parser.add_argument("--device", type=str, default="cuda:0")
    parser.add_argument("--seed", type=int, default=0)

    args = parser.parse_args(argv)
    return args

def _load_env_config(args):
    """Overlay the per-environment YAML config onto ``args`` in place.

    The file is ``./config/<env>/<env_name minus its trailing -vN>.yml``, resolved
    relative to the current working directory. Every top-level key becomes an
    attribute of ``args``, overwriting any command-line value with the same name.
    """
    # rsplit so only the trailing version suffix ("-v2") is stripped, not an
    # earlier "-v..." that happens to appear in the name.
    config_name = args.env_name.rsplit("-v", 1)[0]
    algo_yml_path = "./config/{}/{}.yml".format(args.env, config_name)
    with open(algo_yml_path, 'r') as f:
        algo_yml = yaml.load(f, Loader=yaml.FullLoader)
    # An empty YAML file loads as None; treat it as "no overrides".
    for key, value in (algo_yml or {}).items():
        setattr(args, key, value)


def _seed_everything(seed):
    """Seed Python, NumPy and PyTorch RNGs and force deterministic cuDNN."""
    random.seed(seed)
    np.random.seed(seed)
    # NOTE: setting PYTHONHASHSEED at runtime does not change hash randomization
    # of the already-running interpreter; it only affects child processes.
    os.environ["PYTHONHASHSEED"] = str(seed)

    # set seed of torch
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def main():
    """Parse CLI args, merge the env YAML config, seed all RNGs, and run the evaluator."""
    args = get_args()
    # Config values are applied after CLI parsing, so they take precedence
    # (including e.g. `seed` if the YAML defines it).
    _load_env_config(args)

    setproctitle.setproctitle("{} {}".format(args.algo.upper(), args.env_name))

    _seed_everything(args.seed)

    # The evaluator receives its own deep copy of the namespace.
    runner = ModelSimEvaluator(copy.deepcopy(args))
    runner.run()


if __name__ == "__main__":
    main()