import os
import yaml
import argparse

from mmengine.runner import Runner
from mmengine.config import Config

# from data.get_partial_data import create_train_val_dataloaders_config_from_data_path_list
import torch

def argparse_setup(argv=None):
    """Parse the command-line options of the evaluation script.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse falls back to ``sys.argv[1:]``, which is
            the behaviour of the original zero-argument call.

    Returns:
        argparse.Namespace: Parsed options with the attributes ``config``,
        ``local_rank``, ``checkpoint``, ``launcher``, ``project_dirs`` and
        ``exp_name``.

    Side effects:
        Sets ``os.environ['LOCAL_RANK']`` to ``str(args.local_rank)`` when
        that variable is not already present in the environment.
    """
    parser = argparse.ArgumentParser(description='Train a model')
    # NOTE: --config and --checkpoint used to be `required=True`. That made
    # the declared default of --config unreachable, and `main` in this file
    # overwrites both values for every model it evaluates, so forcing the
    # user to supply them served no purpose. They are optional now; passing
    # them explicitly still works exactly as before.
    parser.add_argument('--config', type=str, default='config.yaml', help='Path to config file')
    # Both spellings are accepted; the value is stored as `args.local_rank`.
    parser.add_argument('--local_rank', '--local-rank', type=int, default=0)
    parser.add_argument("--checkpoint", type=str, default=None, help="Checkpoint file to load")
    parser.add_argument(
        '--launcher',
        choices=['none', 'pytorch', 'slurm', 'mpi'],
        default='none',
        help='job launcher')
    parser.add_argument(
        "--project_dirs",
        type=str,
        default="missing_person_obj_det",
        help="project directories for saving the model and log files")
    parser.add_argument(
        '--exp_name',
        default='',
        type=str,
        help='experiment name, used for saving the model and log files')
    args = parser.parse_args(argv)
    # Never clobber a LOCAL_RANK that is already set in the environment;
    # only fall back to the command-line value when it is missing.
    os.environ.setdefault('LOCAL_RANK', str(args.local_rank))
    return args


# Each row pairs an experiment YAML with the checkpoint evaluated for it.
# Keeping the pairs side by side makes it impossible for the two derived
# lists below to get out of step.
_MODEL_RUNS = [
    (
        "exp_config/retinanet.yaml",
        "missing_person/250423_missing_person_retinanet_annotation_v0.3/best_coco_bbox_mAP_epoch_12.pth",
    ),
    (
        "exp_config/faster_rcnn.yaml",
        "missing_person/250425_missing_person_faster_rcnn_annotation_v0.3/best_coco_bbox_mAP_epoch_9.pth",
    ),
    (
        "exp_config/ssd.yaml",
        "missing_person/250503_missing_person_ssd_forestperson_v3/epoch_2.pth",
    ),
    (
        "exp_config/yolov3.yaml",
        "missing_person/250425_missing_person_yolov3_annotation_v0.3/epoch_49_best.pth",
    ),
    (
        "exp_config/yolox.yaml",
        "missing_person/250422_missing_person_yolox_annotation_v0.3/best_coco_bbox_mAP_epoch_74.pth",
    ),
]

# Experiment config files, in evaluation order.
config_names = [cfg_path for cfg_path, _ in _MODEL_RUNS]

# Checkpoint files; entry i belongs to config_names[i].
checkpoints = [ckpt_path for _, ckpt_path in _MODEL_RUNS]


def main():
    """Run the test loop for every config/checkpoint pair and save the results.

    For each pair taken positionally from ``config_names`` and
    ``checkpoints`` this function:

    1. loads the experiment YAML and reads its ``base_config`` entry,
    2. loads that file with ``Config.fromfile`` and overrides ``work_dir``,
       ``launcher``, ``load_from``, the test batch size, ``experiment_name``,
       ``default_hooks`` and the test evaluator's ``outfile_prefix``,
    3. builds a ``Runner`` from the config and calls ``runner.test()``.

    Whatever ``runner.test()`` returns is stored under the config name, and
    the collected mapping is dumped to ``overall_results.yaml`` at the end.

    Raises:
        ValueError: If ``config_names`` and ``checkpoints`` differ in length.
    """
    args = argparse_setup()

    # zip() stops at the shorter sequence, so a forgotten entry in either
    # list would silently skip models. Fail loudly instead.
    if len(config_names) != len(checkpoints):
        raise ValueError(
            f"config_names has {len(config_names)} entries but checkpoints "
            f"has {len(checkpoints)}; they must be paired one-to-one"
        )

    # Neither value depends on the model being evaluated, so compute once.
    work_dir = os.path.join(args.project_dirs, args.exp_name)
    season_prediction_result = "season_prediction_results2"

    results = {}
    for config_name, checkpoint in zip(config_names, checkpoints):
        # The command-line values are replaced by the current pair.
        args.config = config_name
        args.checkpoint = checkpoint

        print(f"Running with config: {config_name} and checkpoint: {checkpoint}")
        with open(args.config, 'r', encoding='utf-8') as file:
            config = yaml.safe_load(file)
        print(config)

        # 1. Load the base config referenced by the experiment YAML and
        #    apply the per-run overrides.
        base_config = Config.fromfile(config["base_config"])
        base_config["work_dir"] = work_dir

        base_config.launcher = args.launcher
        base_config.load_from = args.checkpoint

        base_config.test_dataloader.batch_size = 1

        # The experiment name comes from the command line, not from the YAML.
        base_config["experiment_name"] = args.exp_name

        # Replace whatever default_hooks the base config declares with an
        # empty dict. NOTE(review): presumably so that hooks customised by
        # the base config are not used during testing - confirm.
        base_config["default_hooks"] = dict()

        # Use the checkpoint's parent directory name as the output prefix,
        # giving each model its own prefix under the shared results folder.
        outfile_prefix = os.path.join(
            season_prediction_result, checkpoint.split("/")[-2]
        )
        base_config.test_evaluator.update(
            dict(outfile_prefix=outfile_prefix)
        )

        # 2. Build the runner and run the test loop.
        runner = Runner.from_cfg(base_config)
        metrics = runner.test()

        results[config_name] = {
            "test": metrics
        }
        print(f"Finished running with config: {config_name} and checkpoint: {checkpoint}")
        print(f"Test results: {metrics}")
    print(results)
    # Save the results to a file
    with open("overall_results.yaml", 'w', encoding='utf-8') as file:
        yaml.dump(results, file)


    
    
# Run the evaluation only when this file is executed as a script.
if __name__ == "__main__":
    main()