import os
import yaml
import argparse

from mmengine.runner import Runner
from mmengine.config import Config

# from data.get_partial_data import create_train_val_dataloaders_config_from_data_path_list
import torch

from torch.utils.data import Subset
import numpy as np

def argparse_setup(argv=None):
    """Parse command-line options for a training run.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]``. ``None`` (the default) keeps the normal
            command-line behaviour; a list makes the parser usable from
            tests or other scripts.

    Returns:
        argparse.Namespace holding the parsed options.

    Side effects:
        Sets ``os.environ['LOCAL_RANK']`` from ``--local_rank`` only when the
        variable is not already defined, so a value exported by the launcher
        takes precedence over the CLI flag.
    """
    parser = argparse.ArgumentParser(description='Train a model')
    # --config is mandatory, so it carries no default: a default on a required
    # option is never used and only misleads readers of --help.
    parser.add_argument(
        '--config',
        required=True,
        type=str,
        help='Path to the YAML run config (must define base_config)')
    parser.add_argument(
        '--epoch',
        type=int,
        default=-1,
        help='Override train_cfg.max_epochs; -1 keeps the value from the base config')
    parser.add_argument('--local_rank', '--local-rank', type=int, default=0)
    parser.add_argument(
        '--launcher',
        choices=['none', 'pytorch', 'slurm', 'mpi'],
        default='none',
        help='job launcher')
    parser.add_argument(
        "--project_dirs",
        type=str,
        default="missing_person_obj_det",
        help="project directories for saving the model and log files")
    parser.add_argument(
        '--exp_name',
        default='',
        type=str,
        help='experiment name, used for saving the model and log files')
    parser.add_argument(
        '--selected_season',
        default='',
        type=str,
        help='season filter applied to the train/val datasets (empty keeps the config value)')
    parser.add_argument(
        '--selected_place',
        default='',
        type=str,
        help='place filter (currently not applied to the datasets)')
    args = parser.parse_args(argv)
    # Respect a LOCAL_RANK already exported by the launcher (e.g. torchrun).
    os.environ.setdefault('LOCAL_RANK', str(args.local_rank))
    return args

def main():
    """Train a model from a YAML run config plus command-line overrides.

    Reads the YAML file named by ``--config``. Its ``base_config`` entry is
    the path of an mmengine config file, which is loaded with
    ``Config.fromfile`` and then patched from the CLI:

    * ``--epoch``                          -> ``train_cfg.max_epochs`` (unless -1)
    * ``--project_dirs`` / ``--exp_name``  -> ``work_dir``
    * ``--launcher``                       -> ``launcher``
    * ``--selected_season``                -> train/val ``dataset.selected_season``
    * ``--exp_name``                       -> ``experiment_name`` (only when non-empty)

    A ``Runner`` is then built from the patched config and ``runner.train()``
    is called.

    Raises:
        KeyError: if the YAML file does not define a ``base_config`` entry.
    """
    import warnings

    args = argparse_setup()
    with open(args.config, 'r', encoding='utf-8') as file:
        config = yaml.safe_load(file)
    print(config)

    # An empty YAML file loads as None; report it the same way as a missing key
    # instead of failing later with an opaque "NoneType is not subscriptable".
    if not isinstance(config, dict) or "base_config" not in config:
        raise KeyError(f"{args.config} must define a 'base_config' entry")

    # TMP: share tensors between processes via the file_system strategy.
    # NOTE(review): presumably a workaround for shared-memory / open-file limits
    # with many DataLoader workers — confirm and revisit.
    torch.multiprocessing.set_sharing_strategy('file_system')

    # 2. set up the training environment
    base_config = Config.fromfile(config["base_config"])
    if args.epoch != -1:
        base_config["train_cfg"]["max_epochs"] = args.epoch
    base_config["work_dir"] = os.path.join(args.project_dirs, args.exp_name)
    base_config.launcher = args.launcher

    if args.selected_season != "":
        base_config.train_dataloader.dataset.selected_season = args.selected_season
        base_config.val_dataloader.dataset.selected_season = args.selected_season
    if args.selected_place != "":
        # The selected_place dataset override is disabled; tell the user instead
        # of silently ignoring the flag.
        warnings.warn("--selected_place is currently ignored; the dataset override is disabled")

    # Only set an explicit experiment name when one was given. An empty string
    # is not None, so the Runner would use it as the experiment name; leaving
    # the key unset keeps any value from the config file or the Runner default.
    if args.exp_name:
        base_config["experiment_name"] = args.exp_name
    runner = Runner.from_cfg(base_config)

    # 3. train with the runner
    runner.train()
    # runner.val()
    
if __name__ == "__main__":
    # Script entry point: parse CLI arguments, build the mmengine Runner and train.
    main()