import sys

import torch
import torch.nn as nn


class Model_Wrapper():
    """
    Wrapper that gives all supported HOI models a common, ``nn.Module``-like interface.

    It keeps the training code modular: the concrete network is selected from
    ``cfg["MODEL"]["MODEL_NAME"]`` and built by :meth:`load_model` (which, for the
    fine-tuned variants, also loads their checkpoints). ``to``, ``train``/``eval``,
    ``parameters``, ``state_dict`` and ``load_state_dict`` are forwarded to the
    wrapped model stored in ``self.model``.

    NOTE(review): an earlier description said that, because CHTR does not classify
    the action, this wrapper returns None as the action classification. No such
    handling is visible in this class; confirm whether it lives in the wrapped models.
    """

    # Model names that are built through ``load_hoibot``.
    _HOIBOT_MODELS = ("HOIBOT", "HOIBOT_b", "HOIBOT_DA", "HOIBOT_stacked", "HOIBOT_all", "HYDRA", "BOTH")

    # Dataset keys copied from the current config into an older checkpoint's config
    # by ``update_if_needed``.
    _DATASET_KEYS_TO_SYNC = (
        "object_classes",
        "interaction_classes",
        "spatial_class_idxes",
        "action_class_idxes",
        "num_object_classes",
        "num_interaction_classes",
        "num_action_classes",
        "num_spatial_classes",
    )

    def __init__(self, cfg, visualize=True):
        """
        :param cfg: configuration mapping; ``cfg["MODEL"]["MODEL_NAME"]`` selects the model.
        :param visualize: if True, call ``info_model()`` on HOIBOT-family models once built.
        """
        self.cfg = cfg
        self.training = False
        # Forward the caller's flag (it used to be hard-coded to True, ignoring visualize=False).
        self.modelname = self.load_model(cfg, visualize=visualize)

    def __call__(self, batch):
        """
        Run the wrapped model on ``batch`` and return its output unchanged.

        :param batch: dict of tensors. Example sizes from one batch:
            'pred_labels': torch.Size([789])
            'frames': torch.Size([173, 3, 224, 224])
            'bboxes': torch.Size([789, 5])
            'ids': torch.Size([789])
            'pair_idxes': torch.Size([2830, 2])
            'pair_human_ids': torch.Size([2830])
            'pair_object_ids': torch.Size([2830])
            'im_idxes': torch.Size([2830])
            'binary_masks': torch.Size([789, 224, 224])
            'interactions_gt': torch.Size([217, 50])
            'spatial_gt': torch.Size([217, 8])
            'action_gt': torch.Size([217, 42])
            'exist_mask': torch.Size([217])
            'change_mask': torch.Size([217])
            'windows': torch.Size([217, 2830])
            'windows_out': torch.Size([217, 2830])
            'out_im_idxes': torch.Size([BatchSize]) - int that indicates the indices of the batch sizes per frame
        :return: whatever the wrapped model returns. The original documentation described it
            as (action_classes, human_motion) of shape ([B], [B,T,P]);
            NOTE(review): confirm against the wrapped model.
        """
        return self.model(batch)

    def load_model(self, cfg, visualize=True):
        """
        Build the model named by ``cfg["MODEL"]["MODEL_NAME"]`` and store it in ``self.model``.

        :param cfg: configuration mapping.
        :param visualize: if True, call ``info_model()`` on HOIBOT-family models.
        :return: the model name.
        :raises ValueError: if the model name is not supported. (This used to be an
            ``assert``, which is stripped under ``python -O``.)
        """
        modelname = cfg["MODEL"]["MODEL_NAME"]
        if modelname == "STTranGaze":
            from hoi.model.sttran.load_sttran import load_sttran
            self.model, _message = load_sttran(cfg)
        elif modelname in self._HOIBOT_MODELS:
            self.model = self.load_hoibot(cfg, modelname)
            if visualize:
                self.model.info_model()
        else:
            raise ValueError(
                f"Error when defining the model type: {modelname!r} is not implemented in ModelWrapper "
                f"(expected 'STTranGaze' or one of {list(self._HOIBOT_MODELS)})"
            )
        return modelname

    def load_hoibot(self, cfg, modelname):
        """
        Build a HOIBOT-family model.

        The fine-tuned variants ("HOIBOT_DA", "HOIBOT_all", "HYDRA", "BOTH") recursively
        build their backbone with ``cfg["FINETUNE"]["backbone"]`` as the model name.

        :param cfg: configuration mapping (or, for "HOIBOT_DA", the config stored in the checkpoint).
        :param modelname: one of the HOIBOT-family names.
        :return: the constructed model.
        :raises ValueError: if ``modelname`` is not a HOIBOT-family name. (This used to
            call ``sys.exit(0)``, which reported success.)
        """
        if modelname in ("HOIBOT", "HOIBOT_b", "HOIBOT_stacked"):
            return self._build_hoibot(cfg, modelname)
        if modelname == "HOIBOT_DA":
            return self._load_hoibot_da(cfg)
        if modelname == "HOIBOT_all":
            from modules.hoi4abot.hoibot.hoibot_DA import HOIBOT_DA
            hoibot_detection = self.load_hoibot(cfg, cfg["FINETUNE"]["backbone"])
            return HOIBOT_DA(hoibot_detection, future_num=cfg["FINETUNE"]["future_num"])
        if modelname == "HYDRA":
            return self._load_hydra(cfg)
        if modelname == "BOTH":
            return self._load_both(cfg)
        raise ValueError(f"ERROR! You should never get here: {modelname}")

    def _build_hoibot(self, cfg, modelname):
        """Instantiate a plain HOIBOT ("HOIBOT", "HOIBOT_b" or "HOIBOT_stacked") from ``cfg``."""
        if modelname == "HOIBOT_stacked":
            from modules.hoi4abot.hoibot.hoitbot_stacked import HOIBOT
        else:
            from modules.hoi4abot.hoibot.hoitbot import HOIBOT
        model_cfg = cfg["MODEL"]
        dataset_cfg = cfg["TRAINER"]["DATASET"]
        img_size = dataset_cfg["img_size"]
        return HOIBOT(
            device=model_cfg["device"],
            num_spatial_classes=dataset_cfg["num_spatial_classes"],
            num_action_classes=dataset_cfg["num_action_classes"],
            obj_class_names=dataset_cfg["object_classes"],
            embedding_dimension=model_cfg["embedding_dim"],
            hoi_feature=model_cfg["hoi_feature"],
            input_image_size=(img_size, img_size),
            # NOTE(review): derived as img_size // 14 (presumably the count of 14-px
            # patches per side) - confirm this is what HOIBOT expects as patch_size.
            patch_size=img_size // 14,
            sliding_window=model_cfg["sttran_sliding_window"],
            blender_type=model_cfg["blender_type"],
            moa_eps=model_cfg["moa_eps"],
            depth=model_cfg["depth"],
            dual_transformer_type=model_cfg["dual_transformer_type"],
            num_heads=model_cfg["num_heads"],
            box_encoder_type=model_cfg["box_encoder_type"],
            semantic_type=model_cfg["semantic_type"],
            do_regression=model_cfg["do_regression"],
            use_feature_extractor=False,
            train_feature_extractor=model_cfg["train_feature_extractor"],
            semantic_masking_prob=model_cfg["semantic_masking_prob"],
            augmentation_semantic=model_cfg["augmentation_semantic"],
            # use_semantic_extractor=model_cfg["use_semantic_extractor"],
            # use_sam_prompt_encoder=model_cfg["use_sam_prompt_encoder"],
            annotation_dir=cfg["PATHS"]["annotations"],
            max_length=dataset_cfg["max_clip_length"],
            mlp_ratio=model_cfg["mlp_ratio"],
            drop_rate=model_cfg["dropout"],
            pos_embed_type=model_cfg["pos_embed_type"],
            simple_semantics=model_cfg["simple_semantics"],
            image_cls_type=model_cfg["image_cls_type"],
            mainbranch=model_cfg["mainbranch"],
            # Optional keys: older configs may not define them.
            do_inference=model_cfg["do_inference"] if "do_inference" in model_cfg else False,
            union_box=model_cfg["union_box"],
            head_cls_type=model_cfg["head_cls_type"],
            bigextractor=modelname == "HOIBOT_b",
            extend_head_transformer=model_cfg["add_extension"] if "add_extension" in model_cfg else False,
            concat_extension=model_cfg["concat_extension"] if "concat_extension" in model_cfg else False,
        )

    def _load_hoibot_da(self, cfg):
        """
        Build HOIBOT_DA on top of a backbone described by the checkpoint at ``cfg["FINETUNE"]["weights_path"]``.

        The backbone is built from the config stored under the checkpoint's "info" key.
        The checkpoint's "state_dict" is then loaded into the backbone.
        """
        from modules.hoi4abot.hoibot.hoibot_DA import HOIBOT_DA
        checkpoint = torch.load(cfg["FINETUNE"]["weights_path"])
        info = checkpoint["info"]
        # Older checkpoints lack the dataset metadata; fill it in from the current config.
        if "object_classes" not in info["TRAINER"]["DATASET"]:
            self.update_if_needed(info, cfg)
        hoibot_detection = self.load_hoibot(info, cfg["FINETUNE"]["backbone"])
        hoibot_da = HOIBOT_DA(hoibot_detection, future_num=cfg["FINETUNE"]["future_num"])
        hoibot_da.load_weights_backbone(checkpoint["state_dict"])
        return hoibot_da

    def _load_hydra(self, cfg):
        """Build HOIBOT_HYDRA and load one checkpoint per entry of ``cfg["FINETUNE"]["anticipation_heads"]``."""
        from modules.hoi4abot.hoibot.hoibot_hydra import HOIBOT_HYDRA
        anticipation_heads = cfg["FINETUNE"]["anticipation_heads"]
        hoibot_detection = self.load_hoibot(cfg, cfg["FINETUNE"]["backbone"])
        hoibot_hydra = HOIBOT_HYDRA(hoibot_detection, future_nums=anticipation_heads)
        for head_fut_num in anticipation_heads:
            hoibot_hydra.load_weights_backbone(self._load_head_state_dict(cfg, head_fut_num), head_fut_num)
        return hoibot_hydra

    def _load_both(self, cfg):
        """
        Build HOIBOT_BOTH from two backbones and load their weights.

        Head 0 supplies the detection weights; a non-zero head supplies the anticipation weights.

        :raises ValueError: if ``anticipation_heads`` does not contain both a 0 head and a
            non-zero head (this used to surface as an ``UnboundLocalError``).
        """
        from modules.hoi4abot.hoibot.hoibot_both import HOIBOT_BOTH
        anticipation_heads = cfg["FINETUNE"]["anticipation_heads"]

        model_detection = self.load_hoibot(cfg, cfg["FINETUNE"]["backbone"])
        model_anticipation = self.load_hoibot(cfg, cfg["FINETUNE"]["backbone"])
        hoibot_both = HOIBOT_BOTH(model_detection, model_anticipation, device=cfg["MODEL"]["device"])

        state_dict_detection = None
        state_dict_anticipation = None
        for head_fut_num in anticipation_heads:
            state_dict = self._load_head_state_dict(cfg, head_fut_num, map_location="cpu")
            if head_fut_num == 0:
                state_dict_detection = state_dict
            else:
                # NOTE(review): each non-zero head overwrites the previous one, so only the
                # last non-zero head is used - confirm this is intended.
                state_dict_anticipation = state_dict

        if state_dict_detection is None or state_dict_anticipation is None:
            raise ValueError(
                "BOTH requires cfg['FINETUNE']['anticipation_heads'] to contain head 0 (detection) "
                f"and at least one non-zero head (anticipation); got {anticipation_heads!r}"
            )
        hoibot_both.load_weights_backbone(state_dict_detection, state_dict_anticipation)
        return hoibot_both

    def _load_head_state_dict(self, cfg, head_fut_num, map_location=None):
        """Load the checkpoint for anticipation head ``head_fut_num`` and return its bare state dict."""
        # The path template is formatted with (head number, file extension).
        path = str(cfg["FINETUNE"]["all_heads_path"]).format(head_fut_num, "pt")
        return self._unwrap_state_dict(torch.load(path, map_location=map_location))

    @staticmethod
    def _unwrap_state_dict(checkpoint):
        """Return ``checkpoint["state_dict"]`` when present, otherwise ``checkpoint`` itself."""
        return checkpoint["state_dict"] if "state_dict" in checkpoint else checkpoint

    def update_if_needed(self, info, cfg):
        """
        Patch an older checkpoint config ``info`` in place with values from ``cfg``.

        Copies the dataset class metadata and ``PATHS``, and overrides ``MODEL.device``.
        """
        dataset_cfg = cfg["TRAINER"]["DATASET"]
        info["TRAINER"]["DATASET"].update({key: dataset_cfg[key] for key in self._DATASET_KEYS_TO_SYNC})
        info["PATHS"] = cfg["PATHS"]
        info["MODEL"]["device"] = cfg["MODEL"]["device"]

    def to(self, device):
        """Move the wrapped model to ``device``; returns ``self`` for chaining, like ``nn.Module.to``."""
        self.model.to(device)
        return self

    def parameters(self):
        """Return the wrapped model's parameters."""
        return self.model.parameters()

    def train(self, mode=True):
        """Put the wrapped model in training (``mode=True``) or evaluation mode; returns ``self``."""
        self.training = bool(mode)
        if mode:
            self.model.train()
        else:
            self.model.eval()
        return self

    def eval(self):
        """Put the wrapped model in evaluation mode; returns ``self``."""
        return self.train(False)

    def __str__(self):
        return self.model.__str__()

    def __repr__(self):
        return self.model.__repr__()

    def load_state_dict(self, state_dict, strict=True):
        """
        Load weights into the wrapped model.

        Accepts either a bare state dict or a checkpoint with a "state_dict" key.
        For HYDRA, loading is skipped (returning 0) when the model reports it is
        already loaded.
        """
        if self.modelname == "HYDRA" and self.model.isloaded:
            return 0
        return self.model.load_state_dict(state_dict=self._unwrap_state_dict(state_dict), strict=strict)

    def state_dict(self):
        """Return the wrapped model's state dict."""
        return self.model.state_dict()

if __name__ == '__main__':
    # Smoke test: build the wrapper from the training config for the chosen future horizon.
    from hoi.configs.paths import project_path
    from hoi.configs.cfg_to_info import cfg_to_info, default_opt

    future = 0
    opt = default_opt()
    opt.name = f"exp_f{future}"
    opt.cfg = project_path + f"/hoi/configs/cfg/train_hyp_f{future}.yaml"
    # NOTE(review): "HOIBOT_extended" is not one of the model names Model_Wrapper.load_model
    # accepts; if cfg_to_info copies opt.modelname into cfg["MODEL"]["MODEL_NAME"], the
    # wrapper will reject it - confirm.
    opt.modelname = "HOIBOT_extended"
    cfg = cfg_to_info(opt)
    modelwrapper = Model_Wrapper(cfg=cfg)

