import argparse
import os
from pathlib import Path

def str2bool(v):
    return v.lower() == "true"


def get_default_parser():
    parser = argparse.ArgumentParser(description="RL", conflict_handler="resolve")
    add_args(parser)
    return parser


def add_args(parser):
    #############################
    # INTERVALS
    #############################

    current_dir = Path.cwd()
    log_path = str(current_dir)
    
    parser.add_argument("--sanity", action="store_true", default=False)
    parser.add_argument("--sanity-verbose", action="store_true", default=False)
    parser.add_argument(
        "--sanity-cmd",
        type=str,
        default="",
        help="""
            Example format "after_rollout:5" stops after 5 calls to after_rollout
            """,
    )

    parser.add_argument(
        "--log-interval",
        type=int,
        default=1,
        help="log interval, one log per n updates (default: 1)",
    )
    parser.add_argument(
        "--save-interval",
        type=int,
        default=-1,
        help="save interval, one save per n updates (default: 100)",
    )
    parser.add_argument(
        "--eval-interval",
        type=int,
        default=100,
        help="eval interval, one eval per n updates (default: -1)",
    )

    #############################
    # RUN CONFIG
    #############################
    parser.add_argument(
        "--env-log-dir",
        default="/tmp/gym",
        help="directory to save agent logs (default: /tmp/gym)",
    )

    parser.add_argument(
        "--log-dir",
        default=os.path.join(log_path, "data", "log"),
        help="directory to save agent logs (default: /tmp/gym)",
    )

    parser.add_argument(
        "--sync",
        action="store_true",
        default=False,
        help="Whether to sync with properties specified in config",
    )

    parser.add_argument(
        "--save-dir",
        default=os.path.join(log_path, "data", "trained_models"),
        help="directory to save agent trained models (default: ./data/trained_models/)",
    )

    parser.add_argument("--prefix", default="debug", help="Run identifier")

    #############################
    # RL LOOP
    #############################
    parser.add_argument(
        "--num-env-steps",
        type=float,
        default=5e6,
        help="number of environment steps to train (default: 1e8)",
    )

    parser.add_argument(
        "--num-processes",
        type=int,
        default=32,
        help="how many training CPU processes to use (default: 32)",
    )

    parser.add_argument(
        "--env-name",
        help="environment to train on (default: PongNoFrameskip-v4)",
        required=True,
    )

    parser.add_argument(
        "--gamma",
        type=float,
        default=0.99,
        help="discount factor for rewards (default: 0.99)",
    )

    parser.add_argument(
        "--use-proper-time-limits",
        action="store_true",
        default=False,
        help="compute returns taking into account time limits",
    )

    parser.add_argument(
        "--num-steps",
        type=int,
        default=128,
        help="number of forward steps in A2C/PPO (old default: 128)",
    )

    parser.add_argument(
        "--seed", type=int, default=31, help="random seed (default: 31)"
    )
    parser.add_argument(
        "--detect-nan",
        type=str2bool,
        default=False,
        help="""
            If true, will use `torch.autograd.set_detect_anomaly`.
            """,
    )

    #############################
    # EVAL / RETRAIN
    #############################
    parser.add_argument(
        "--eval-only",
        action="store_true",
        default=False,
        help="""
            Enables rendering text to the observation from `mod_render_frames`
            of the env interface or the alg_env_settings.
            """,
    )
    parser.add_argument("--render-metric", action="store_true", default=False)
    parser.add_argument(
        "--render-succ-fails",
        action="store_true",
        default=False,
        help="""
            Renders successes and failures separately. When true,
            eval-num-processes must be 1.
            """,
    )
    parser.add_argument("--num-eval", type=int, default=5)
    parser.add_argument(
        "--log-smooth-len",
        type=int,
        default=100,
        help="Smoothing window for all logged statistics",
    )
    parser.add_argument(
        "--num-render", type=int, default=0, help="None places no limit"
    )
    parser.add_argument(
        "--resume", default=False, action="store_true", help="Resume training"
    )

    parser.add_argument(
        "--load-file", default="", help=".pt weights file for resuming or evaluating"
    )

    parser.add_argument(
        "--eval-num-processes",
        type=int,
        default=None,
        help="# of evaluation processes. When None use the same # as in non-evaluation",
    )
    parser.add_argument("--vid-fps", type=float, default=30.0)
    parser.add_argument("--vid-dir", type=str, default=os.path.join(log_path, "data", "vids"))

    #############################
    # POLICY
    #############################
    parser.add_argument(
        "--cuda", type=str2bool, default=True, help="disables CUDA training"
    )

    #############################
    # IMITATION LEARNING
    #############################
    parser.add_argument(
        "--eval-save",
        action="store_true",
        default=False,
        help="Save the trajectories from the evaluation",
    )
    parser.add_argument("--traj-dir", type=str, default=os.path.join(log_path, "data", "traj"))

    #############################
    # ENV
    #############################
    parser.add_argument("--normalize-env", type=str2bool, default=True)
    parser.add_argument("--clip-actions", type=str2bool, default=False)
    parser.add_argument("--frame-stack", type=str2bool, default=False)
    parser.add_argument("--policy-ob-key", type=str, default="observation")
    parser.add_argument("--force-multi-proc", type=str2bool, default=False)
    parser.add_argument(
        "--context-mode",
        type=str,
        default="fork",
        help="""
            The mode used for multi-processing. Valid options are "fork",
            "spawn", "forkserver". "fork" will not work if you are doing
            rendering in MuJoCo.
            """,
    )
    parser.add_argument("--transpose-frame", type=str2bool, default=True)
    parser.add_argument(
        "--warp-frame",
        type=str2bool,
        default=False,
        help="""
            Resizes to 84x84 and makes grayscale
            """,
    )
    parser.add_argument("--time-limit", type=float, default=None)

    #############################
    # RAY
    #############################
    parser.add_argument("--ray", action="store_true")
    parser.add_argument("--ray-config", type=str, default="{}")
    parser.add_argument("--ray-cpus", type=float, default=1)
    parser.add_argument("--ray-gpus", type=float, default=0)
    parser.add_argument("--ray-nsamples", type=int, default=1)
    parser.add_argument(
        "--ray-debug",
        action="store_true",
        help="""
            Turns on internal non-ray logging. Also enables `local_mode` for
            Ray.
            """,
    )

    #############################
    # MISC
    #############################
    parser.add_argument(
        "--use-gae",
        type=str2bool,
        default=True,
        help="use generalized advantage estimation",
    )

    parser.add_argument(
        "--gae-lambda",
        type=float,
        default=0.95,
        help="gae lambda parameter (default: 0.95)",
    )

    #############################
    parser.add_argument(
        "--eval-gen-heatmap", 
        type=str2bool, 
        default=False,
        help='add eval gen heatmap',
    )
    parser.add_argument(
        '--bc-init', 
        type=str2bool, 
        default=False,
        help='add bc init',    
    )
    parser.add_argument(
        "--add-expert-regularizer",
        type=str2bool,
        default=False,
        help='add regularizer to actor loss',
    )
    parser.add_argument(
        "--add-non-expert-entropy",
        type=str2bool,
        default=False,
        help='increase the entropy of non expert term',
    )
    parser.add_argument(
        "--non-expert-entropy-coef",
        type=float,
        default=1,
        help='coef of the entropy for non expert actions',
    )
    parser.add_argument(
        "--add-coef-decay",
        type=str2bool,
        default=True,
        help='add coefficient decay to regularizer',
    )
    parser.add_argument(
        "--add-expert-restricted-regularizer",
        type=str2bool,
        default=False,
        help='add restricted regularizer to actor loss',
    )
    parser.add_argument(
        "--expert-restricted-batch-size",
        type=int,
        default=32,
        help='batch size for restricted regularizer',
    )
    parser.add_argument(
        "--restricted-regularizer-sum-log",
        type=str2bool,
        default=False,
        help='For restricted regularizer, sum(logpi)',
    )
    parser.add_argument(
        "--restricted-regularizer-sum-prob",
        type=str2bool,
        default=True,
        help='For restricted regularizer, sum(pi)',
    )
    parser.add_argument(
        "--restricted-regularizer-related-ac",
        type=str2bool,
        default=True,
        help='For restricted regularizer, sum(pi)',
    )
    parser.add_argument(
        "--restricted-regularizer-sum-prob-wrap-log",
        type=str2bool,
        default=False,
        help='For restricted regularizer, log(sum(pi))',
    )
    parser.add_argument(
        "--restricted-regularizer-sum-prob-clamp-threshold",
        type=float,
        default=0,
        help="restricted regularizer sum prob's clamp threshold",
    )
    parser.add_argument(
        "--restricted-regularizer-sum-prob-update-frequency",
        type=int,
        default=1,
        help="restricted regularizer sum prob's udpate frequency",
    )
    parser.add_argument(
        "--reg-with-clamp",
        type=str2bool,
        default=False,
        help='when doing log(sum_pi), clamp the value',
    )
    parser.add_argument(
        "--exres-coef",
        type=float,
        default=1e-3,
        help="restricted regularizer coefficient",
    )
    parser.add_argument(
        "--exres-coef-decay-end-percent",
        type=float,
        default=0,
        help='coefficient decay ending percent',
    )
    parser.add_argument(
        "--exres-coef-decay-end-ts",
        type=float,
        default=1.0,
        help='coefficient decay ending timestep, 1.0 means the end of training',
    )
    parser.add_argument(
        "--add-kl-regularizer",
        type=str2bool,
        default=False,
        help='add kl divergence to actor loss',
    )
    parser.add_argument(
        "--kl-coef",
        type=float,
        default=1e-1,
        help="kl divergence coefficient",
    )
    parser.add_argument(
        "--add-ce-regularizer",
        type=str2bool,
        default=False,
        help='add cross entropy to actor loss',
    )
    parser.add_argument(
        "--ce-coef",
        type=float,
        default=1,
        help="cross entropy coefficient",
    )

    # regularizer for continous action space with box constraint
    parser.add_argument(
        "--reg-const-bound",
        type=float,
        default=1.0,
        help='add box regularizer to actor loss',
    )
    ## cdf regularizer
    parser.add_argument(
        "--add-cdf-regularizer",
        type=str2bool,
        default=False,
        help='add box regularizer to actor loss',
    )
    parser.add_argument(
        "--cdf-regularizer-coef",
        type=float,
        default=1.0,
        help='add cdf regularizer coef',
    )
    ## non-expert sampler regularizer
    parser.add_argument(
        "--add-pdf-regularizer",
        type=str2bool,
        default=False,
        help='add non-expert regularizer to actor loss',
    )
    parser.add_argument(
        "--pdf-coef",
        type=float,
        default=1.0,
        help='add non-expert regularizer coef',
    )
    parser.add_argument(
        "--pdf-n-samples",
        type=int,
        default=100,
        help='generate how many samples for non-expert sampler',
    )
    parser.add_argument(
        "--pdf-eps",
        type=float,
        default=1e-2,
        help='non-expert sampling epsilon',
    )

    # widowx save torque values
    parser.add_argument(
        "--widowx-save-torque",
        type=str2bool,
        default=False,
    )

    # add Out-of-Constraint Action estimation
    parser.add_argument(
        "--add-ooc-action-estimation",
        type=str2bool,
        default=True,
        help='add Out-of-Constraint Action estimation',
    )

    # add a new type of regularizer on proximity function
    parser.add_argument(
        "--add-regularizer-on-proximity",
        type=str2bool,
        default=False,
        help='regularizer: use bc to get the none-expert actions',
    )
    parser.add_argument(
        "--add-regularizer-on-proximity-infer-from-next",
        type=str2bool,
        default=False,
        help='regularizer: use bc to get the none-expert actions',
    )
    parser.add_argument(
        "--anneal-regularizer-on-proximity",
        type=str2bool,
        default=False,
        help='regularizer: use bc to get the none-expert actions',
    )
    parser.add_argument(
        "--anneal-regularizer-on-proximity-end-ts",
        type=float,
        default=0.5,
        help='regularizer: use bc to get the none-expert actions',
    )



    # add a new type of regularizer on proximity function, unkown actions will be inferred from next
    parser.add_argument(
        "--add-regularizer-on-proximity-infer-from-next",
        type=str2bool,
        default=False,
        help='',
    )

    # add a new type of regularizer: bc modeling regularizer
    parser.add_argument(
        "--add-bc-model-regularizer",
        type=str2bool,
        default=False,
        help='regularizer: use bc to get the none-expert actions',
    )
    parser.add_argument(
        "--add-bc-model-regularizer-reward-bonus",
        type=str2bool,
        default=False,
        help='regularizer: use bc to get the none-expert actions',
    )
    parser.add_argument(
        "--bc-model-regularizer-reward-bonus-value",
        type=float,
        default=1e-2,
        help='',
    )
    parser.add_argument(
        "--bc-model-checkpoint-path",
        type=str,
        default=None,
        help='regularizer: use bc to get the none-expert actions',
    )
    parser.add_argument(
        "--bc-model-reg-coef",
        type=float,
        default=1e-2,
        help='',
    )
    parser.add_argument(
        "--bc-model-regularizer-bound",
        type=float,
        default=None,
        help='',
    )
    parser.add_argument(
        "--add-bc-model-regularizer-log-sum-pi",
        type=str2bool,
        default=False,
        help='regularizer: use bc to get the none-expert actions',
    )


    # add linear decay to prox est
    parser.add_argument(
        "--pf-est-decay",
        type=str2bool,
        default=True,
        help='',
    )
    parser.add_argument(
        "--pf-est-decay-end-point",
        type=float,
        default=0.5,
        help='',
    )
    parser.add_argument(
        "--pf-est-decay-min-crop",
        type=float,
        default=0,
        help='',
    )

    # debug
    parser.add_argument(
        "--add-widowx-cube-height",
        type=str2bool,
        default=False,
        help='',
    )
    parser.add_argument(
        "--visualize-minigrid-path",
        type=str2bool,
        default=False,
        help='',
    )
    parser.add_argument(
        "--visualize-minigrid-path-saving-dir",
        type=str,
        default=None,
        help='',
    )






