import train
import argparse
import numpy as np

# Command-line interface for the training launcher. The parsed namespace is
# handed unchanged (apart from the seed / epsilon handling in the main guard)
# to train.run.
parser = argparse.ArgumentParser()

# --- General parameters -----------------------------------------------------
parser.add_argument("--env", required=True,
                    help="name of the environment to train on (REQUIRED)")
parser.add_argument("--log-interval", type=int, default=1,
                    help="number of updates between two logs (default: 1)")
parser.add_argument("--save-interval", type=int, default=10,
                    help="number of updates between two saves (default: 10, 0 means no saving)")
parser.add_argument("--procs", type=int, default=16,
                    help="number of processes (default: 16)")
parser.add_argument("--frames", type=int, default=2000000,
                    help="number of frames of training (default: 2000000)")
parser.add_argument("--device", type=int, default=0,
                    help="cuda device to use, if available (default: 0)")

# --- Experiment hyperparameters ---------------------------------------------
parser.add_argument("--nseeds", type=int, default=10,
                    help="number of seeds to run")
parser.add_argument("--seed", type=int, default=None,
                    help="seed, if not given it'll run for nseeds")
# NOTE: single-dash multi-character flags; argparse stores them as
# ``args.nA`` and ``args.ud`` respectively.
parser.add_argument("-nA", action="store_true",
                    help="if set, it'll run non adaptive")
parser.add_argument("-ud", action="store_true",
                    help="include uniform preferences in loss function")
# Kept as a raw string here; the comma-separated values are converted to
# floats after parsing.
parser.add_argument("--epsilon", type=str, default=None,
                    help="the pref vector, sep. by commas (if not set it's calculated)")

# --- Parameters for the main algorithm --------------------------------------
parser.add_argument("--pgd-lr", type=float, default=0.03,
                    help="learning rate for the projected gradient descent algorithm (default: 0.03)")
parser.add_argument("--epochs", type=int, default=10,
                    help="number of epochs (default: 10)")
parser.add_argument("--batch-size", type=int, default=64,
                    help="batch size (default: 64)")
parser.add_argument("--frames-per-proc", type=int, default=256,
                    help="number of frames per process before update (default: 256)")
parser.add_argument("--discount", type=float, default=0.99,
                    help="discount factor (default: 0.99)")
parser.add_argument("--lr", type=float, default=0.0003,
                    help="learning rate for Adam network optimizer (default: 0.0003)")
parser.add_argument("--gae-lambda", type=float, default=0.95,
                    help="lambda coefficient in GAE formula (default: 0.95, 1 means no gae)")
parser.add_argument("--entropy-coef", type=float, default=0.001,
                    help="entropy term coefficient (default: 0.001)")
parser.add_argument("--value-loss-coef", type=float, default=0.5,
                    help="value loss term coefficient (default: 0.5)")
parser.add_argument("--max-grad-norm", type=float, default=0.5,
                    help="maximum norm of gradient (default: 0.5)")
parser.add_argument("--optim-eps", type=float, default=1e-8,
                    help="Adam optimizer epsilon (default: 1e-8)")
parser.add_argument("--clip-eps", type=float, default=0.2,
                    help="clipping epsilon (default: 0.2)")
parser.add_argument("--obs-clip", type=float, default=None,
                    help="observation clipping (default: None)")
parser.add_argument("-s", action="store_true", help="if set store to tmp dir")

def resolve_seeds(seed, nseeds):
    """Return the list of seeds to launch training runs with.

    Args:
        seed: The explicit seed from ``--seed``, or ``None`` if it was not
            supplied on the command line.
        nseeds: Number of random seeds to draw when ``seed`` is ``None``.

    Returns:
        ``[seed]`` when a seed was supplied (``0`` included); otherwise a list
        of ``nseeds`` Python ints drawn with ``np.random.randint`` from the
        half-open range ``[1, 100000)``.
    """
    # Test against None rather than truthiness: ``--seed 0`` is a valid
    # request and must not fall through to the random-seed branch.
    if seed is not None:
        return [seed]
    return np.random.randint(low=1, high=100000, size=(nseeds,)).tolist()


def parse_epsilon(text):
    """Convert the raw ``--epsilon`` value into a list of floats.

    Args:
        text: Comma-separated numbers as a single string, or ``None`` when the
            option was not given.

    Returns:
        ``None`` if ``text`` is ``None``; otherwise one float per
        comma-separated entry, in order.

    Raises:
        ValueError: If any entry cannot be converted by ``float``.
    """
    if text is None:
        return None
    return [float(part) for part in text.split(",")]


if __name__ == "__main__":
    args = parser.parse_args()
    SEEDS = resolve_seeds(args.seed, args.nseeds)

    # Replace the raw string with the parsed list (stays None when unset).
    args.epsilon = parse_epsilon(args.epsilon)

    for seed in SEEDS:
        # The same namespace is reused for every run; only the seed changes.
        args.seed = seed
        train.run(args)
