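# IMPORTANT: to split work across GPUs, run one process per GPU and give each
# a different set of video folders (--videos); the confidence value is fixed
# in this script and is not varied.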

import argparse
import glob
import os
import shutil
import subprocess
import sys

import numpy as np

# Command-line interface; parsed at module level into `args`.
parser = argparse.ArgumentParser()
parser.add_argument(
    '--gpu',
    type=str,
    default=None,
    help="CUDA_VISIBLE_DEVICES parameter",
)
# The script body iterates over this list, so it must always be supplied
# (a None default made the script crash with an opaque TypeError).
parser.add_argument(
    '--videos',
    nargs='+',
    type=str,
    required=True,
    help="KITTI-STEP sequence IDs to evaluate, e.g. 0000 0002",
)
parser.add_argument(
    '--weights',
    type=str,
    default="../../../../checkpoints/maskformer_swin_s_sem_cityscapes_maskpatch16_joint.pkl",
    help="Checkpoint path forwarded to val_single_image.py as MODEL.WEIGHTS",
)

args = parser.parse_args()


# KITTI-STEP sequence IDs in each split; video_split() rejects anything else.
TRAIN_VIDEOS = ("0000", "0001", "0003", "0004", "0005", "0009",
                "0011", "0012", "0015", "0017", "0019", "0020")
VAL_VIDEOS = ("0002", "0006", "0007", "0008", "0010", "0013",
              "0014", "0016", "0018")

# Selects the kitti_step/dropout_video_4chan/<conf> folder. This script varies
# the video folder, not this value.
DROPOUT_CONF = 0.1

# Model config forwarded to val_single_image.py.
CONFIG_FILE = ("configs/kitti_step/semantic-segmentation/swin/"
               "maskformer2_swin_small_bs16_90k_ttt_mae.yaml")
# Alternative config that was previously used here:
# "configs/kitti_step/semantic-segmentation/swin/maskformer2_swin_tiny_bs16_90k_ttt_drop_vid.yaml"


def video_split(vid):
    """Return the KITTI-STEP split ("train" or "val") that contains *vid*.

    Raises:
        ValueError: if *vid* is in neither split. Unknown IDs are rejected
            rather than defaulting to "val".
    """
    if vid in TRAIN_VIDEOS:
        return "train"
    if vid in VAL_VIDEOS:
        return "val"
    raise ValueError(
        f"Unknown KITTI-STEP video {vid!r}: not in the train or val split")


def replace_tree(src, dst):
    """Copy directory *src* to *dst*, deleting any existing *dst* first."""
    if os.path.exists(dst):
        shutil.rmtree(dst)
    shutil.copytree(src, dst)


def read_last_miou(log_path):
    """Return the mIoU string parsed from the last line of *log_path*.

    The value is the text after the final ': ' on that line, up to the first
    ',' and with surrounding whitespace stripped.

    Raises:
        RuntimeError: if the log file is empty.
    """
    with open(log_path, 'r') as fp:
        lines = fp.readlines()
    if not lines:
        raise RuntimeError(f"{log_path} is empty; cannot read mIoU")
    return lines[-1].split(': ')[-1].split(',')[0].strip()


if __name__ == '__main__':
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    # --gpu defaults to None and os.environ only accepts str values, so only
    # restrict the visible devices when a GPU was actually requested.
    if args.gpu is not None:
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

    if not args.videos:
        parser.error("--videos is required (e.g. --videos 0000 0002)")

    detectron2_root = os.getenv("DETECTRON2_DATASETS")
    if detectron2_root is None:
        # Explicit exit instead of assert: asserts are stripped under -O.
        sys.exit('Need to set $DETECTRON2_DATASETS environment variable!')

    # Validate every ID up front so a typo fails before any evaluation runs.
    try:
        splits = {vid: video_split(vid) for vid in args.videos}
    except ValueError as err:
        parser.error(str(err))

    base_dir = os.path.join(detectron2_root, "kitti_step/images")
    dropout_dir = os.path.join(detectron2_root, "kitti_step/dropout_video_4chan",
                               str(DROPOUT_CONF))
    for vid in args.videos:
        ttt_in_dir = os.path.join(base_dir, splits[vid], vid)
        num_imgs = len(glob.glob(os.path.join(ttt_in_dir, "*.png")))
        if num_imgs == 0:
            print(f"warning: no *.png frames found in {ttt_in_dir}",
                  file=sys.stderr)

        root_dir = os.path.join("exp_dir", "kitti_step_baselines", vid)
        os.makedirs(root_dir, exist_ok=True)

        # 1. Create output folder and copy relevant data over from dataset root
        data_train_dir = os.path.join(root_dir, "train")
        data_val_dir = os.path.join(root_dir, "val")
        replace_tree(os.path.join(dropout_dir, "train", vid), data_train_dir)
        replace_tree(os.path.join(dropout_dir, "val", vid), data_val_dir)

        # 2. Run one evaluation per frame index, then parse the last line of
        #    <root_dir>/log.txt. This assumes val_single_image.py writes its
        #    result there (OUTPUT_DIR is root_dir); confirm.
        log_path = os.path.join(root_dir, "log.txt")
        all_miou = []
        for i in range(num_imgs):
            subprocess.run([
                sys.executable, "val_single_image.py",
                "--num-gpus", "1",
                "--config-file", CONFIG_FILE,
                "--ttt_in_dir", ttt_in_dir,
                "--video", vid,
                "--st_iters", str(i),
                "--eval-only", "1",
                "MODEL.WEIGHTS", args.weights,
                "OUTPUT_DIR", root_dir,
            ], check=True)
            all_miou.append(read_last_miou(log_path))

        # 3. Write one value per line, in frame order.
        with open(os.path.join(root_dir, "performance.txt"), 'w') as fp:
            fp.writelines(f"{miou}\n" for miou in all_miou)

        # Remove the per-video data copies made in step 1.
        shutil.rmtree(data_train_dir)
        shutil.rmtree(data_val_dir)
        