import sys

import torch
import torch.nn as nn


class Model_Wrapper():
    """
    Wrapper that gives the training code one model-agnostic entry point.

    It selects and builds the underlying model from the config (see
    ``load_model``) and forwards the usual ``torch.nn.Module``-style calls
    (``__call__``, ``to``, ``parameters``, ``train``/``eval``,
    ``state_dict``/``load_state_dict``, ``str``/``repr``) to it. The training
    loop therefore does not need to know which concrete model is in use.

    NOTE(review): the intent described originally is that, for models that do
    not classify actions (e.g. CHTR), the wrapper returns None for the action
    classes. No such branch exists here yet; only ``"STTranGaze"`` is supported.
    """

    def __init__(self, cfg, visualize=True):
        """
        :param cfg: nested config mapping. ``cfg["MODEL"]["MODEL_NAME"]`` selects
            the model and ``cfg["MODEL"]["gaze_usage"]`` is forwarded to its builder.
        :param visualize: forwarded to ``load_model`` (not used there at present).
        """
        self.cfg = cfg
        # Wrapper-level mode flag, updated by train()/eval(). It is not read
        # from the wrapped model at construction time.
        self.training = False
        # Pass the caller's flag through instead of silently forcing True.
        self.modelname = self.load_model(cfg, visualize=visualize)

    def __call__(self, batch):
        """
        Run the wrapped model on ``batch`` and return whatever it returns.

        :param batch: dict of tensors. Example shapes from one batch:
            'pred_labels': torch.Size([789])
            'frames': torch.Size([173, 3, 224, 224])
            'bboxes': torch.Size([789, 5])
            'ids': torch.Size([789])
            'pair_idxes': torch.Size([2830, 2])
            'pair_human_ids': torch.Size([2830])
            'pair_object_ids': torch.Size([2830])
            'im_idxes': torch.Size([2830])
            'binary_masks': torch.Size([789, 224, 224])
            'interactions_gt': torch.Size([217, 50])
            'spatial_gt': torch.Size([217, 8])
            'action_gt': torch.Size([217, 42])
            'exist_mask': torch.Size([217])
            'change_mask': torch.Size([217])
            'windows': torch.Size([217, 2830])
            'windows_out': torch.Size([217, 2830])
            'out_im_idxes': torch.Size([BatchSize]) - ints indicating the
                indices of the batch sizes per frame
        :return: the wrapped model's output, unchanged.
            NOTE(review): previously documented as
            ``(action_classes, human_motion)`` of shape ``([B], [B,T,P])``;
            confirm against the wrapped model's forward().
        """
        return self.model(batch)

    def load_model(self, cfg, visualize=True):
        """
        Build the model named by ``cfg["MODEL"]["MODEL_NAME"]`` into ``self.model``.

        :param cfg: nested config mapping (see ``__init__``).
        :param visualize: accepted for interface compatibility; currently unused.
        :return: the model name string.
        :raises ValueError: if the model name is not supported.
        """
        modelname = cfg["MODEL"]["MODEL_NAME"]
        if modelname == "STTranGaze":
            # Imported inside the branch so the dependency is only needed
            # when this model is actually selected.
            from modules.sthoip_transformer.STTRAN_Wrapper import STTRAN_Wrapper
            # Call the locally imported builder. There is no
            # ``self.STTRAN_Wrapper`` attribute. Its result is unpacked into
            # (model, message); the message is not used here.
            self.model, _message = STTRAN_Wrapper(cfg, cfg["MODEL"]["gaze_usage"])
        else:
            # Explicit raise rather than ``assert`` so the check is not
            # stripped when Python runs with -O.
            raise ValueError(f"Error when defining the model type: {modelname!r}")
        return modelname

    def to(self, device):
        """Move the wrapped model to ``device``; returns ``self`` for chaining."""
        self.model.to(device)
        return self

    def parameters(self):
        """Return the wrapped model's parameters (e.g. for an optimizer)."""
        return self.model.parameters()

    def train(self, mode=True):
        """
        Set training mode on the wrapper and the wrapped model.

        :param mode: True for training mode, False for evaluation mode.
        :return: ``self``, mirroring ``nn.Module.train``.
        """
        self.training = mode
        self.model.train(mode)
        return self

    def eval(self):
        """Put the wrapper and the wrapped model into evaluation mode; returns ``self``."""
        self.training = False
        self.model.eval()
        return self

    def __str__(self):
        return str(self.model)

    def __repr__(self):
        return repr(self.model)

    def load_state_dict(self, state_dict, strict=True):
        """Load ``state_dict`` into the wrapped model and return its result."""
        return self.model.load_state_dict(state_dict=state_dict, strict=strict)

    def state_dict(self):
        """Return the wrapped model's state dict."""
        return self.model.state_dict()


if __name__ == '__main__':
    # Smoke test: build the wrapper from the default training configuration.
    # The star import must stay at module level (``import *`` is not allowed
    # inside a function); it provides ``project_path``.
    from configs.paths import *
    from configs.cfg_to_info import cfg_to_info, default_opt

    future_horizon = 0
    opt = default_opt()
    opt.name = f"exp_f{future_horizon}"
    opt.cfg = project_path + "/hoi/configs/cfg/train_hyp_f0.yaml"
    opt.modelname = "STTRAN"
    modelwrapper = Model_Wrapper(cfg=cfg_to_info(opt))

