# IMPORTANT
# When splitting the sweep across GPUs, vary the video FOLDER (--videos)
# between GPUs, not the confidence ratio (--used).

import shutil
import argparse
import subprocess
import os
import sys
import numpy as np

# Define and run hyperparameter sweeps. Grid:
#     - used ratio (--used):        0.1, 0.2, 0.3, 0.4, 0.5
#     - mask ratio (--masked):      0.6, 0.7, 0.8, 0.9, 0.95
#     - videos (--videos):          0001, 0019, 0020
#     - window size (--win_sizes):  1, 2, 4, 8

parser = argparse.ArgumentParser(
    allow_abbrev=False,
    description="Launch train_ttt_mae.py once for every combination of the "
                "given videos, used ratios, mask ratios and window sizes.",
)

parser.add_argument(
    '--gpu',
    type=str,
    default=None,
    help="CUDA_VISIBLE_DEVICES parameter",
)

# Sweep axes. The main loop iterates over (or concatenates) each of these, so
# a None default only produced a TypeError later; require them up front.
# NOTE: argparse %-formats help strings, so a literal percent sign is '%%'.
parser.add_argument(
    '--used',
    nargs='+',
    type=float,
    required=True,
    help="Top k%% most confident: one or more ratios, each selecting the "
         "data under kitti_step/dropout_video_4chan/<ratio>",
)

parser.add_argument(
    '--masked',
    nargs='+',
    type=float,
    required=True,
    help="One or more mask ratios, each passed to the trainer as --drop_ratio",
)

parser.add_argument(
    '--mask_type',
    type=str,
    required=True,
    help="Random or confidence (passed to the trainer as --mask_type)",
)

parser.add_argument(
    '--videos',
    nargs='+',
    type=str,
    required=True,
    help="One or more KITTI-STEP video ids, e.g. 0001 0019",
)

parser.add_argument(
    '--win_sizes',
    nargs='+',
    type=int,
    required=True,
    help="One or more window sizes, each passed to the trainer as --win_size",
)

# Trainer / solver settings shared by every run in the sweep.
parser.add_argument(
    '--batch_size',
    type=int,
    default=None,
    help="Batch size (SOLVER.IMS_PER_BATCH)",
)

parser.add_argument(
    '--accum_iter',
    type=int,
    default=None,
    help="Accumulative iterations (SOLVER.ACCUM_ITER)",
)

parser.add_argument(
    '--base_lr',
    type=float,
    default=0.0001,
    help="Base learning rate (SOLVER.BASE_LR)",
)

parser.add_argument(
    '--setting',
    type=str,
    default="standard",
    help="Passed to the trainer as --ttt_setting",
)

parser.add_argument(
    '--weights',
    type=str,
    default="../../../../checkpoints/maskformer_swin_s_sem_cityscapes_maskpatch16_joint.pkl",
    help="Initial model checkpoint (MODEL.WEIGHTS)",
)

parser.add_argument(
    '--output_dir',
    type=str,
    default="output",
    help="Parent folder for the per-run output folders",
)

parser.add_argument(
    '--st_iters',
    type=int,
    default=20,
    help="Passed to the trainer as --st_iters",
)

# store_true already defaults to False.
parser.add_argument(
    '--restart_optimizer',
    action='store_true',
    help='Restart the optimizer between frames.',
)

args = parser.parse_args()


def replace_tree(src, dst):
    """Copy the directory tree ``src`` to ``dst``, replacing ``dst`` if present.

    ``shutil.copytree`` fails when the destination already exists, so a copy
    left behind by an earlier (e.g. failed) run is deleted first.
    """
    if os.path.exists(dst):
        shutil.rmtree(dst)
    shutil.copytree(src, dst)


if __name__ == '__main__':
    # Run train_ttt_mae.py once for every (video, used ratio, mask ratio,
    # window size) combination given on the command line. Each run gets its
    # own folder under --output_dir, which is deleted after the run succeeds.

    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    if args.gpu is not None:
        # os.environ only accepts strings; when --gpu is omitted, keep the
        # CUDA_VISIBLE_DEVICES (if any) inherited from the caller.
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

    detectron2_root = os.getenv("DETECTRON2_DATASETS")
    if detectron2_root is None:
        # Explicit exit rather than ``assert`` so the check survives ``python -O``.
        sys.exit('Need to set $DETECTRON2_DATASETS environment variable!')

    # KITTI-STEP video ids by split. Any video not listed in train_set is read
    # from the "val" image split; val_set is kept for reference only.
    train_set = frozenset({"0000", "0001", "0003", "0004", "0005", "0009",
                           "0011", "0012", "0015", "0017", "0019", "0020"})
    val_set = frozenset({"0002", "0006", "0007", "0008", "0010", "0013",
                         "0014", "0016", "0018"})

    kitti_root = os.path.join(detectron2_root, "kitti_step")
    # Shared by every run launched from this invocation.
    exp_dir = f"exp_dir/mae_{args.batch_size}_{args.setting}_{args.base_lr}"

    for vid in args.videos:
        split = "train" if vid in train_set else "val"
        in_dir = os.path.join(kitti_root, "images", split, vid)

        for conf in args.used:
            # The data to stage depends only on (conf, vid), so resolve it once
            # here instead of once per mask ratio / window size.
            dropout_dir = os.path.join(kitti_root, "dropout_video_4chan", str(conf))
            train_dir = os.path.join(dropout_dir, "train", vid)
            val_dir = os.path.join(dropout_dir, "val", vid)

            for drop_ratio in args.masked:
                for win_size in args.win_sizes:
                    run_name = "_".join([
                        vid,
                        args.mask_type,
                        str(conf),
                        str(drop_ratio),
                        str(win_size),
                        str(args.batch_size),
                    ])
                    root_dir = os.path.join(args.output_dir, run_name)
                    os.makedirs(root_dir, exist_ok=True)

                    # 1. Stage a fresh copy of the dropout data for this run.
                    data_dir = os.path.join(root_dir, "data")
                    os.makedirs(data_dir, exist_ok=True)
                    replace_tree(train_dir, os.path.join(data_dir, "train"))
                    replace_tree(val_dir, os.path.join(data_dir, "val"))

                    # 2. Train. check=True raises on a non-zero exit, which
                    #    aborts the sweep and skips the cleanup below.
                    subprocess.run([
                        sys.executable, "train_ttt_mae.py",
                        "--num-gpus", "1",
                        "--config-file", "configs/kitti_step/semantic-segmentation/swin/maskformer2_swin_small_bs16_90k_ttt_mae.yaml",
                        "--st_iters", str(args.st_iters),
                        "--win_size", str(win_size),
                        "--ttt_setting", args.setting,
                        "--drop_aug",
                        "--mask_type", args.mask_type,
                        "--drop_ratio", str(drop_ratio),
                        "--ttt_in_dir", in_dir,
                        "--ttt_out_dir", data_dir,
                        "--exp_dir", exp_dir,
                        "MODEL.WEIGHTS", args.weights,
                        "OUTPUT_DIR", root_dir,
                        "SOLVER.IMS_PER_BATCH", str(args.batch_size),
                        "SOLVER.ACCUM_ITER", str(args.accum_iter),
                        "SOLVER.BASE_LR", str(args.base_lr),
                        "TTT.USE_SEG_HEAD", "False",
                        "TTT.RESTART_OPTIMIZER", 'True' if args.restart_optimizer else 'False',
                    ], check=True)

                    # 3. Remove this run's folder (staged data and OUTPUT_DIR).
                    if os.path.exists(root_dir):
                        shutil.rmtree(root_dir)
            


