import argparse
import torch


def _build_parser():
    """Build the command-line parser for REgion-Aware Dynamic Exploration.

    Help strings use argparse's ``%(default)s`` placeholder so that the
    defaults shown by ``--help`` always match the values passed to
    ``add_argument`` (hand-written defaults in help text had drifted).

    Returns:
        argparse.ArgumentParser: parser with every supported option registered.
    """
    parser = argparse.ArgumentParser(
        description='REgion-Aware Dynamic Exploration')

    # General Arguments
    parser.add_argument('--seed', type=int, default=1,
                        help='random seed (default: %(default)s)')
    parser.add_argument('--auto_gpu_config', type=int, default=0)
    # Either the literal string "auto" or an integer given as a string;
    # only converted to int by the auto GPU configuration in get_args().
    parser.add_argument('--total_num_scenes', type=str, default="auto")
    parser.add_argument('--gpus', type=int, default=2)
    parser.add_argument('-n', '--num_processes', type=int, default=1,
                        help="""how many training processes to use
                                (default: %(default)s). Overridden when
                                auto_gpu_config=1 and training on gpus""")
    parser.add_argument('--num_processes_per_gpu', type=int, default=3)  # 6
    parser.add_argument('--num_processes_on_first_gpu',
                        type=int, default=2)  # 3
    parser.add_argument('--eval', type=int, default=1,
                        help='0: Train, 1: Evaluate (default: %(default)s)')

    parser.add_argument('--use_gtsem', type=int, default=0,
                        help='0: eval, 1: eval_with_bbox')
    parser.add_argument('--num_eval_episodes', type=int, default=2000,
                        help="number of test episodes per scene")
    parser.add_argument('--episodes_per_scene', type=int, default=100,
                        help="number of episodes to run per scene before switching")
    parser.add_argument('--num_train_episodes', type=int, default=10000,
                        help="""number of train episodes per scene
                                before loading the next scene""")
    parser.add_argument('--no_cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument("--sim_gpu_id", type=int, default=0,
                        help="gpu id on which scenes are loaded")
    parser.add_argument("--sem_gpu_id", type=int, default=-1,
                        help="""gpu id for semantic model,
                                -1: same as sim gpu, -2: cpu""")

    # Logging, loading models, visualization
    parser.add_argument('--log_interval', type=int, default=5,
                        help="""log interval, one log per n updates
                                (default: %(default)s)""")
    parser.add_argument('-d', '--dump_location', type=str, default="./tmp/",
                        help='path to dump models and log (default: %(default)s)')
    parser.add_argument('--exp_name', type=str, default="exp_880_2",
                        help='experiment name (default: %(default)s)')
    parser.add_argument('-v', '--visualize', type=int, default=2,
                        help="""1: Render the observation and
                                   the predicted semantic map,
                                2: Render the observation with semantic
                                   predictions and the predicted semantic map
                                (default: %(default)s)""")
    parser.add_argument('--print_images', type=int, default=1,
                        help='1: save visualization as images')

    # Environment / sensor setup
    parser.add_argument('-efw', '--env_frame_width', type=int, default=1280,
                        help='Environment frame width (default: %(default)s)')
    parser.add_argument('-efh', '--env_frame_height', type=int, default=960,
                        help='Environment frame height (default: %(default)s)')
    parser.add_argument('-fw', '--frame_width', type=int, default=160,
                        help='Frame width (default: %(default)s)')
    parser.add_argument('-fh', '--frame_height', type=int, default=120,
                        help='Frame height (default: %(default)s)')
    parser.add_argument('-el', '--max_episode_length', type=int, default=500,
                        help="""Maximum episode length""")
    parser.add_argument("--task_config", type=str,
                        default="tasks/objectnav_hm3d.yaml",
                        help="path to config yaml containing task information")
    parser.add_argument("--split", type=str, default="val_880",
                        help="dataset split (val) ")
    parser.add_argument('--camera_height', type=float, default=0.88,
                        help="agent camera height in metres")
    parser.add_argument('--traversable_height_threshold', type=float, default=0.875,
                        help="traversable height in metres")
    parser.add_argument('--hfov', type=float, default=110.0,
                        help="horizontal field of view in degrees")
    parser.add_argument('--turn_angle', type=float, default=60,
                        help="Agent turn angle in degrees")
    parser.add_argument('--min_depth', type=float, default=0,
                        help="Minimum depth for depth sensor in meters")
    parser.add_argument('--max_depth', type=float, default=100.0,
                        help="Maximum depth for depth sensor in meters")
    parser.add_argument('--success_dist', type=float, default=1.0,
                        help="success distance threshold in meters")
    parser.add_argument('--floor_thr', type=int, default=50,
                        help="floor threshold in cm")
    parser.add_argument('--min_d', type=float, default=1.5,
                        help="min distance to goal during training in meters")
    parser.add_argument('--max_d', type=float, default=100.0,
                        help="max distance to goal during training in meters")
    parser.add_argument('--version', type=str, default="v1.1",
                        help="dataset version")

    # Model Hyperparameters
    parser.add_argument('--agent', type=str, default="sem_exp")
    parser.add_argument('--reward_coeff', type=float, default=0.1,
                        help="Object goal reward coefficient")
    # A category count: parsed as int so a CLI-supplied value has the same
    # type as the default (previously `type=float` turned "66" into 66.0).
    parser.add_argument('--num_sem_categories', type=int, default=66)

    # YOLO Detection Parameters
    parser.add_argument('--yolo_conf_thresh', type=float, default=0.75,
                        help="YOLO confidence threshold for detection filtering (default: %(default)s)")
    parser.add_argument('--yolo_mask_thresh', type=float, default=0.75,
                        help="YOLO mask threshold for binary mask generation (default: %(default)s)")
    parser.add_argument('--yolo_binary_thresh', type=float, default=0.5,
                        help="YOLO binary threshold for converting probability mask to binary mask (default: %(default)s)")
    parser.add_argument('--sem_pred_prob_thr', type=float, default=0.6,
                        help="MASKRcnnSemantic prediction confidence threshold")

    # TV black pixel filtering
    parser.add_argument('--tv_black_pixel_threshold', type=float, default=0.8,
                        help="TV black pixel ratio threshold: if black pixels in TV mask exceed this ratio, the detection is filtered out (default: %(default)s)")

    # TV specific YOLO thresholds
    parser.add_argument('--tv_yolo_conf_thresh', type=float, default=0.85,
                        help="TV-specific YOLO confidence threshold (default: %(default)s)")
    parser.add_argument('--tv_yolo_mask_thresh', type=float, default=0.85,
                        help="TV-specific YOLO mask threshold (default: %(default)s)")
    parser.add_argument('--tv_yolo_binary_thresh', type=float, default=0.5,
                        help="TV-specific YOLO binary threshold (default: %(default)s)")

    # Mapping
    parser.add_argument('--global_downscaling', type=int, default=2)
    parser.add_argument('--vision_range', type=int, default=100)
    parser.add_argument('--map_resolution', type=int, default=5)
    parser.add_argument('--du_scale', type=int, default=1)
    parser.add_argument('--map_size_cm', type=int, default=2400)
    parser.add_argument('--cat_pred_threshold', type=float, default=5)
    parser.add_argument('--map_pred_threshold', type=float, default=1)
    parser.add_argument('--exp_pred_threshold', type=float, default=1)

    parser.add_argument('--collision_threshold', type=float, default=0.20)

    return parser


def get_args(argv=None):
    """Parse command-line arguments and derive device/process settings.

    Args:
        argv: optional list of argument strings to parse. When None (the
            default) argparse reads ``sys.argv[1:]``, as before.

    Returns:
        argparse.Namespace with every parsed option plus ``args.cuda``
        (True when CUDA is available and ``--no_cuda`` was not given).
        Without CUDA, ``sem_gpu_id`` is forced to -2 (CPU). With CUDA and
        ``--auto_gpu_config``, ``total_num_scenes``, ``num_processes``,
        ``num_processes_per_gpu``, ``num_processes_on_first_gpu`` and
        ``sim_gpu_id`` are overwritten from the detected GPU memory.

    Raises:
        AssertionError: if auto GPU configuration cannot determine the
            number of scenes or finds insufficient GPU memory.
        SystemExit: on invalid command-line arguments (argparse behavior).
    """
    args = _build_parser().parse_args(argv)

    args.cuda = not args.no_cuda and torch.cuda.is_available()

    if args.cuda:
        if args.auto_gpu_config:
            num_gpus = torch.cuda.device_count()
            if args.total_num_scenes != "auto":
                args.total_num_scenes = int(args.total_num_scenes)
            elif "objectnav_gibson" in args.task_config and \
                    "train" in args.split:
                args.total_num_scenes = 25
            elif "objectnav_gibson" in args.task_config and \
                    "val" in args.split:
                args.total_num_scenes = 5
            else:
                # Raised explicitly rather than `assert False` so the check
                # survives `python -O`; otherwise the string "auto" would
                # reach the arithmetic below and fail with a TypeError.
                raise AssertionError("Unknown task config, please specify"
                                     " total_num_scenes")

            # GPU Memory required for the SemExp model:
            #       0.8 + 0.4 * args.total_num_scenes (GB)
            # GPU Memory required per thread: 2.6 (GB)
            min_memory_required = max(0.8 + 0.4 * args.total_num_scenes, 2.6)
            # Automatically configure number of training threads based on
            # number of GPUs available and GPU memory size.
            # gpu_memory tracks the smallest total memory (GB) seen so far;
            # 1000 is only an upper-bound starting value.
            gpu_memory = 1000
            for i in range(num_gpus):
                gpu_memory = min(gpu_memory,
                                 torch.cuda.get_device_properties(
                                     i).total_memory
                                 / 1024 / 1024 / 1024)
                assert gpu_memory > min_memory_required, \
                    """Insufficient GPU memory for GPU {}, gpu memory ({}GB)
                    needs to be greater than {}GB""".format(
                        i, gpu_memory, min_memory_required)

            num_processes_per_gpu = int(gpu_memory / 2.6)
            # GPU 0 also hosts the model, so it fits fewer processes.
            num_processes_on_first_gpu = \
                int((gpu_memory - min_memory_required) / 2.6)

            if args.eval:
                max_threads = num_processes_per_gpu * (num_gpus - 1) \
                    + num_processes_on_first_gpu
                assert max_threads >= args.total_num_scenes, \
                    """Insufficient GPU memory for evaluation"""

            if num_gpus == 1:
                args.num_processes_on_first_gpu = num_processes_on_first_gpu
                args.num_processes_per_gpu = 0
                args.num_processes = num_processes_on_first_gpu
                assert args.num_processes > 0, "Insufficient GPU memory"
            else:
                num_threads = num_processes_per_gpu * (num_gpus - 1) \
                    + num_processes_on_first_gpu
                # Never start more processes than there are scenes.
                num_threads = min(num_threads, args.total_num_scenes)
                args.num_processes_per_gpu = num_processes_per_gpu
                args.num_processes_on_first_gpu = max(
                    0,
                    num_threads - args.num_processes_per_gpu * (num_gpus - 1))
                args.num_processes = num_threads

            # NOTE(review): set to 1 even when num_gpus == 1; presumably the
            # consumer clamps it to an existing device — confirm.
            args.sim_gpu_id = 1

            print("Auto GPU config:")
            print("Number of processes: {}".format(args.num_processes))
            print("Number of processes on GPU 0: {}".format(
                args.num_processes_on_first_gpu))
            print("Number of processes per GPU: {}".format(
                args.num_processes_per_gpu))
    else:
        # -2 selects the CPU for the semantic model (see --sem_gpu_id help).
        args.sem_gpu_id = -2

    return args
