
import sys

sys.path.append("xxxx/CIL-EDBL")
from lib.ExemplarManager import ExemplarManager
from lib.approach.MCFM import MCFM_handler
from lib.loss import *
from lib.dataset import *
from lib.config import MCFM_cfg, update_config
from lib.utils.utils import (
    create_logger,
)

import torch
import os
import argparse
import warnings


def parse_args(argv=None):
    """Parse the command-line arguments of the EDBL / MCFM training script.

    Args:
        argv: Optional list of argument strings to parse instead of the real
            command line. ``None`` (the default) makes argparse read
            ``sys.argv[1:]``, which is the original behaviour.

    Returns:
        argparse.Namespace with two attributes:
          * ``cfg``: path of the YAML config file to load.
          * ``opts``: list of every token after the recognised options
            (``argparse.REMAINDER``). These are meant to modify config options
            from the command line; the caller hands the whole namespace to
            ``update_config``.
    """
    parser = argparse.ArgumentParser(description="codes for EDBL")

    parser.add_argument(
        "--cfg",
        help="decide which cfg to use",
        required=False,
        # Alternative configs: ../configs/MCFM_cifar10.yaml, ../configs/MCFM_tiny.yaml
        default="../configs/MCFM_cifar100.yaml",
        type=str,
    )
    parser.add_argument(
        "opts",
        help="modify config options using the command-line",
        default=None,
        # REMAINDER swallows every remaining token, so overrides like
        # "KEY VALUE KEY VALUE" can be appended after --cfg.
        nargs=argparse.REMAINDER,
    )
    return parser.parse_args(argv)


if __name__ == "__main__":
    # Script entry point: load the config, build the dataset split handler and
    # the exemplar manager, then run training through MCFM_handler.
    args = parse_args()
    update_config(MCFM_cfg, args)

    # Apply the GPU mask immediately after the config is known, before any
    # other project code runs. CUDA only honours CUDA_VISIBLE_DEVICES if it is
    # set before the CUDA context is created.
    if MCFM_cfg.availabel_cudas:
        # Strip whitespace and drop empty items (e.g. "0, 1" or "0,1,") so the
        # env var is well-formed and matches the device count printed below.
        visible_gpus = [
            gpu.strip()
            for gpu in str(MCFM_cfg.availabel_cudas).split(",")
            if gpu.strip()
        ]
        if visible_gpus:
            os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(visible_gpus)
            # Visible devices are renumbered from 0 once masked.
            device_ids = list(range(len(visible_gpus)))
            print(device_ids)

    logger, log_file = create_logger(MCFM_cfg, "log")
    batch_train_logger, _ = create_logger(MCFM_cfg, "batch_train_log")
    warnings.filterwarnings("ignore")

    # NOTE(security): eval() executes the configured string as Python code.
    # Presumably it names a dataset class brought in by
    # `from lib.dataset import *`. Only run with trusted config files / CLI opts.
    dataset_split_handler = eval(MCFM_cfg.DATASET.dataset)(MCFM_cfg, split_selected_data=None)
    device = torch.device("cpu" if MCFM_cfg.CPU_MODE else "cuda")

    exemplar_cfg = MCFM_cfg.exemplar_manager
    # Image transforms are only needed when exemplars are stored as original
    # images; otherwise the exemplar manager receives None for both.
    if exemplar_cfg.store_original_imgs:
        exemplar_img_transform_for_val = dataset_split_handler.val_test_dataset_transform
        exemplar_img_transform_for_train = transforms.Compose(
            [*AVAILABLE_TRANSFORMS[dataset_split_handler.dataset_name]['train_transform']]
        )
    else:
        exemplar_img_transform_for_val = None
        exemplar_img_transform_for_train = None

    exemplar_manager = ExemplarManager(exemplar_cfg.memory_budget, exemplar_cfg.mng_approach,
                                       exemplar_cfg.store_original_imgs,
                                       exemplar_cfg.norm_exemplars,
                                       exemplar_cfg.centroid_order,
                                       img_transform_for_val=exemplar_img_transform_for_val,
                                       img_transform_for_train=exemplar_img_transform_for_train,
                                       device=device)

    mcfm_handler = MCFM_handler(dataset_split_handler, exemplar_manager, MCFM_cfg, logger, batch_train_logger, device)
    mcfm_handler.cil_train_main()
