import argparse
from distutils.util import strtobool
import numpy as np

def _str2bool(value):
    """Parse a command-line truth value, accepting the same spellings as ``distutils.util.strtobool``.

    ``y/yes/t/true/on/1`` map to True and ``n/no/f/false/off/0`` map to False,
    case-insensitively. Anything else raises ValueError, which argparse turns
    into a normal "invalid value" usage error. Implemented locally because
    ``distutils`` was removed from the standard library in Python 3.12.
    """
    lowered = value.lower()
    if lowered in ("y", "yes", "t", "true", "on", "1"):
        return True
    if lowered in ("n", "no", "f", "false", "off", "0"):
        return False
    raise ValueError(f"invalid truth value {value!r}")


def _build_parser():
    """Create the argument parser with every supported command-line option."""
    parser = argparse.ArgumentParser()

    # basic setup
    parser.add_argument("--wandb", type=_str2bool, default=False, nargs="?", const=True)
    parser.add_argument("--seed", type=int, default = 27)
    parser.add_argument("--device", type=int, default = 7)
    parser.add_argument("--directory", type=str, default = None)
    parser.add_argument("--load", type=_str2bool, default=False, nargs="?", const=True,
                        help = "whether to load model from directory, default false")

    # environment parameters
    parser.add_argument("--state_dim", type=int, default = 5,
                        help = "dimension of the state space")
    parser.add_argument("--action_n", type=int, default = 5,
                        help = "number of actions")
    parser.add_argument("--dataset", type=str, default = None,
                        help = "path to dataset if it is stored on disk; by default the dataset is regenerated at the beginning of each run")
    parser.add_argument("--env", type=str, default = 'LinearDynamicalSystem', choices = ['CyclicHard', 'CyclicEasy', 'HMM', 'MatRot', 'Grid', 'Dihedral', 'CyclicRealHMM', 'AutoRegression', 'LinearRegression', 'LinearDynamicalSystem'])
    parser.add_argument("--rank", type=int, default = None,
                        help = "deprecated, do not use")
    parser.add_argument("--perturb", type=int, default = 0,
                        help = "deprecated, do not use")
    parser.add_argument("--length", type=int, default = 120,
                        help = "length of the trajectory")
    # 0.0083333333 ~= 1/120
    parser.add_argument("--alpha", type=float, default = 0.0083333333,
                        help = "alpha for the CyclicHMM-HARD environment")
    parser.add_argument("--eps", type=float, default = 0.00,
                        help = "epsilon for the CyclicHMM-RND environment")
    parser.add_argument("--real", type=_str2bool, default=False, nargs="?", const=True,
                        help = "whether to use state instead of belief state as the state distribution")

    # network setup
    parser.add_argument("--agent", type=str, default = 'RNN', choices = ['RNN', 'LSTM', 'NRNN', 'TF', 'MLP'])
    parser.add_argument("--num_envs", type=int, default = 256,
                        help = "deprecated, do not use")
    # None means "pick per agent" (resolved in _resolve_derived_args).
    parser.add_argument("--num_layers", type=int, default = None, nargs = '?', const = 1,
                        help = "number of layers in the network")
    # None means "pick per environment" (resolved in _resolve_derived_args).
    parser.add_argument("--loss", type=str, default = None, nargs = '?', const = 'ce', choices = ['ce', 'mse'])
    parser.add_argument("--hidden_dim", type=int, default = 512)
    parser.add_argument("--pos_embed", type=str, default = 'learnable', choices = ['vanilla', 'learnable', 'none', 'additive'])
    parser.add_argument("--tf_model", type=str, default = 'scratch', choices = ['gpt2', 'scratch'],
                        help = "use hugging face gpt2 or a handcrafted transformer model")
    # parameters for the transformer model
    parser.add_argument("--nn_max_len", type=int, default = 256)
    parser.add_argument("--num_heads", type=int, default = 8)
    parser.add_argument("--mlp_layers", type=int, default = 2)
    parser.add_argument("--dropout", type=float, default = 0.1)

    # training setup
    parser.add_argument("--warmup", type=int, default = 4000,
                        help = "number of warmup steps")
    parser.add_argument("--num_traj", type=int, default = 5000000,
                        help = "number of trajectories in the training set")
    parser.add_argument("--curriculum", type=_str2bool, default=False, nargs="?", const=True,
                        help = "whether to use curriculum learning")
    parser.add_argument("--curriculum_double", type=_str2bool, default=False, nargs="?", const=True,
                        help = "whether to use double curriculum schedule")
    parser.add_argument("--curriculum_init", type=int, default = 2,
                        help = "if not using curriculum_double, the initial length of the curriculum")
    parser.add_argument("--curriculum_step", type=int, default = 2,
                        help = "if not using curriculum_double, the step size of the curriculum")
    # NOTE: when --curriculum is set, this value is recomputed from --epoch
    # in _resolve_derived_args, so a user-supplied value is overwritten.
    parser.add_argument("--curriculum_update_freq", type=int, default = 30,
                        help = "if not using curriculum_double, the frequency of updating the curriculum")
    parser.add_argument("--epoch", type=int, default = 100,
                        help = "number of training epochs")
    parser.add_argument("--batch_size", type=int, default = 256)
    parser.add_argument("--lr", type=float, default = 1e-3,
                        help = "initial learning rate")
    parser.add_argument("--lr_decay_gap", type=int, default = 20,
                        help = "number of epochs before decaying the learning rate")
    parser.add_argument("--lr_decay_rate", type=float, default = 0.5,
                        help = "decay rate of the learning rate")
    parser.add_argument("--max_grad_norm", type=float, default = 5.0)
    parser.add_argument("--fresh_data_per_epoch", type=_str2bool, default=False, nargs="?", const=True,
                        help = "whether to generate fresh data at the beginning of each epoch")

    # parallel setup
    parser.add_argument("--world_size", type=int, default = 1,
                        help = "number of parallel gpu workers")
    parser.add_argument("--gpu_bias", type=int, default = 6,
                        help = "the first gpu to use")
    parser.add_argument("--port", type=int, default = 12227)

    # enable block CoT
    parser.add_argument("--block_size", type=int, default = 0)  # if > 0, enable block CoT

    # logging setup
    parser.add_argument("--log_every_steps", type=int, default = 1000)
    parser.add_argument("--log_evaluate_step", type=int, default = 2)
    parser.add_argument("--noeval", type=_str2bool, default=False, nargs="?", const=True)

    return parser


def _resolve_derived_args(args, parser):
    """Fill in options whose defaults depend on other options, and validate the curriculum setup.

    Mutates ``args`` in place. Invalid combinations are reported through
    ``parser.error``, which prints a usage message and exits with status 2.
    """
    # Transformers default to 2 layers; every other agent defaults to 1.
    if args.num_layers is None:
        args.num_layers = 2 if args.agent == 'TF' else 1

    # Setting the appropriate loss for the environment: these environments
    # use MSE, all others use cross-entropy.
    if args.loss is None:
        if args.env in ('MatRot', 'LinearRegression', 'LinearDynamicalSystem'):
            args.loss = 'mse'
        else:
            args.loss = 'ce'

    # These two environments require args.real = True.
    if args.env in ('CyclicHard', 'LinearDynamicalSystem'):
        args.real = True

    if not args.curriculum:
        return

    # Work out how many curriculum stages there are, then spread the training
    # epochs evenly across them. Any user-supplied --curriculum_update_freq is
    # overwritten here.
    if args.curriculum_double:
        # NOTE(review): presumably the length doubles each stage, starting near
        # 2**num_layers and ending at length; confirm against the trainer.
        # int() keeps the result a plain int, as declared for this option.
        num_stages = 1 + int(np.ceil(np.log2(args.length) - args.num_layers))
    else:
        if args.curriculum_step <= 0:
            parser.error("curriculum_step must be positive")
        if args.curriculum_init > args.length:
            parser.error("curriculum_init must not exceed length")
        if (args.length - args.curriculum_init) % args.curriculum_step != 0:
            parser.error("curriculum_step should be a divisor of length - curriculum_init")
        num_stages = 1 + (args.length - args.curriculum_init) // args.curriculum_step

    # Clamp to at least one stage, so a large num_layers cannot make the
    # divisor zero or negative. Clamp to at least one epoch per stage, so
    # epoch < num_stages cannot yield a zero update frequency.
    num_stages = max(1, num_stages)
    args.curriculum_update_freq = max(1, args.epoch // num_stages)


def parse_args(argv=None):
    """Parse the training-run command line and fill in derived defaults.

    Args:
        argv: list of argument strings to parse. ``None`` (the default)
            parses ``sys.argv[1:]``, as ``argparse`` does.

    Returns:
        argparse.Namespace with every option set. ``num_layers`` and ``loss``
        are resolved when left unset. ``real`` is forced on for the
        CyclicHard and LinearDynamicalSystem environments.
        ``curriculum_update_freq`` is an int recomputed from ``epoch`` when
        ``--curriculum`` is set.

    Raises:
        SystemExit: on unparsable arguments or an inconsistent curriculum
            configuration (via ``parser.error``).
    """
    parser = _build_parser()
    args = parser.parse_args(argv)
    _resolve_derived_args(args, parser)
    return args