import detectron2.utils.comm as comm
from detectron2.checkpoint import DetectionCheckpointer
from detectron2.config import get_cfg
from detectron2.engine import default_argument_parser, default_setup, launch

from pt import add_config
from pt.engine.trainer import PTrainer

# to register
from pt.modeling.meta_arch.rcnn import GuassianGeneralizedRCNN
from pt.modeling.proposal_generator.rpn import GuassianRPN
from pt.modeling.roi_heads.roi_heads import GuassianROIHead
import pt.data.datasets.builtin
from pt.modeling.backbone.vgg import build_vgg_backbone
from pt.modeling.backbone.resnet_ibna import build_resnet_ibna_backbone
from pt.modeling.backbone.resnet_ibnb import build_resnet_ibnb_backbone
from pt.modeling.backbone.resnet_ibnb_fullIN import build_resnet_ibnb_fullIN_backbone
from pt.modeling.backbone.clip_backbone import build_clip_resnet_backbone
from pt.modeling.backbone.resnet_clip import build_resnet_clip_backbone
from pt.modeling.anchor_generator import DifferentiableAnchorGenerator

from pt.modeling.meta_arch.ts_ensemble import EnsembleTSModel
from shutil import copyfile
import os


def setup(args):
    """Build and return the frozen config for this run.

    Starts from detectron2's default config and registers the project-specific
    keys via ``add_config``. It then layers on the YAML file named by
    ``args.config_file`` and the override list in ``args.opts``. The config is
    frozen before ``default_setup`` runs detectron2's standard per-run setup.

    Args:
        args: parsed command-line namespace; ``config_file`` and ``opts`` are
            read here, and the whole namespace is forwarded to ``default_setup``.

    Returns:
        The frozen config node.
    """
    config = get_cfg()
    add_config(config)
    # The file is merged first and the command-line list second, so CLI
    # overrides take precedence over values in the YAML file.
    config.merge_from_file(args.config_file)
    config.merge_from_list(args.opts)
    config.freeze()
    default_setup(config, args)
    return config


def main(args):
    """Train, or evaluate, the teacher-student detector configured by ``args``.

    Builds the config with ``setup`` and snapshots two files into
    ``cfg.OUTPUT_DIR``: the config file and
    ``pt/modeling/roi_heads/fast_rcnn.py``. Then it does one of two things:

    - If ``args.eval_only`` is set, it loads weights into an
      ``EnsembleTSModel`` and evaluates its teacher.
    - Otherwise, it runs training.

    Args:
        args: parsed command-line namespace. Reads ``config_file``,
            ``eval_only`` and ``resume``; the namespace is also passed to
            ``setup``.

    Returns:
        The result of ``Trainer.test`` in eval-only mode, otherwise the
        result of ``trainer.train()``.

    Raises:
        ValueError: if ``cfg.UNSUPNET.Trainer`` is not a known trainer name.
    """
    cfg = setup(args)

    # Only the main process writes the snapshots. Every worker started by
    # ``launch`` runs ``main``, and detectron2's ``default_setup`` creates
    # OUTPUT_DIR only on the main process with no barrier. Other ranks could
    # otherwise race on the same files or find the directory missing.
    if comm.is_main_process():
        copyfile(args.config_file, os.path.join(cfg.OUTPUT_DIR, 'cfg.yaml'))
        # NOTE(review): this path is relative to the current working
        # directory, so the script assumes it is launched from the repo root.
        copyfile(
            'pt/modeling/roi_heads/fast_rcnn.py',
            os.path.join(cfg.OUTPUT_DIR, 'fast_rcnn.py'),
        )

    if cfg.UNSUPNET.Trainer == "pt":
        Trainer = PTrainer
    else:
        raise ValueError("Trainer Name is not found.")

    if args.eval_only:
        # Weights are loaded into the teacher-student wrapper, so build both
        # networks and wrap them before checkpoint loading.
        model = Trainer.build_model(cfg)
        model_teacher = Trainer.build_model(cfg)
        ensem_ts_model = EnsembleTSModel(model_teacher, model)

        DetectionCheckpointer(
            ensem_ts_model, save_dir=cfg.OUTPUT_DIR
        ).resume_or_load(cfg.MODEL.WEIGHTS, resume=args.resume)
        # The teacher is the network that gets evaluated here;
        # ``ensem_ts_model.modelStudent`` is the other half of the pair.
        return Trainer.test(cfg, ensem_ts_model.modelTeacher)

    trainer = Trainer(cfg)
    trainer.resume_or_load(resume=args.resume)
    return trainer.train()


if __name__ == "__main__":
    # Parse the standard detectron2 command line and echo it to the log.
    # Then hand off to ``launch``, which runs ``main`` with ``(args,)``.
    args = default_argument_parser().parse_args()
    print("Command Line Args:", args)

    # Multi-machine settings, forwarded to ``launch`` as keyword arguments.
    distributed = {
        "num_machines": args.num_machines,
        "machine_rank": args.machine_rank,
        "dist_url": args.dist_url,
    }
    launch(main, args.num_gpus, args=(args,), **distributed)
