import torch.nn as nn
import torch

from modules.hoi4abot.hoibot.modules.head.InteractionHead import InteractionHead
from modules.hoi4abot.hoibot.hoitbot import HOIBOT


class HOIBOT_BOTH(nn.Module):
    """Wrap a detection model and an anticipation model behind one forward pass.

    Feature extraction and semantic extraction are executed once, through the
    detection model, and their results are handed to both wrapped models. Each
    model then runs its own blender / box embedder / transformer / interaction
    head, and the two results are written back into the batch under the keys
    ``"detection"`` and ``"anticipation"``.

    NOTE(review): both wrapped models are presumably ``HOIBOT`` instances;
    this class only relies on the attributes and methods it calls on them.
    """

    def __init__(self, model_detection, model_anticipation, device):
        """Store both models and expose the detection model's shared parts.

        Args:
            model_detection: model used for the shared steps and for the
                ``"detection"`` output.
            model_anticipation: model used for the ``"anticipation"`` output.
            device: device on which the per-entity frame indices are created.
        """
        super().__init__()
        # Submodule assignment order fixes the order of parameters() and of
        # the state-dict keys, so it is kept stable.
        self.model_detection = model_detection
        self.model_anticipation = model_anticipation

        # Aliases onto the detection model's components.
        self.feature_extractor = self.model_detection.feature_extractor
        self.blender = self.model_detection.blender

        self.future_nums = ["detection", "anticipation"]
        self.isloaded = False
        self.device = device

    def info_model(self):
        """Print the detection model's summary and this wrapper's parameter counts."""
        separator = "-" * 150
        print("Information regarding HOIBOT DA: number of parameters")
        print(separator)
        self.model_detection.info_model()
        print(separator)

        # Single pass over the parameters, tallying trainable and overall sizes.
        n_trainable = 0
        n_all = 0
        for param in self.parameters():
            size = param.numel()
            n_all += size
            if param.requires_grad:
                n_trainable += size

        print(f"[HOIBOT trainable] - {n_trainable / 10 ** 6:.2f}M")
        print(f"[HOIBOT all] - {n_all / 10 ** 6:.2f}M")

    def load_weights_backbone(self, state_dict_detection, state_dict_anticipation):
        """Load one state dict into each wrapped model and mark the wrapper as loaded.

        Args:
            state_dict_detection: state dict for ``model_detection``.
            state_dict_anticipation: state dict for ``model_anticipation``.
        """
        targets = (
            ("model_detection", state_dict_detection),
            ("model_anticipation", state_dict_anticipation),
        )
        for attr_name, weights in targets:
            report = getattr(self, attr_name).load_state_dict(weights)
            print(f"HOIBOT loaded: {attr_name}. Incompatible keys {report}")
        self.isloaded = True

    def shared_steps(self, batch):
        """Run the steps that are computed once and reused by both models.

        Returns:
            A tuple ``(batch, semantics)``: the batch as returned by the
            detection model's ``extract_features`` and the output of its
            ``semantic_extractor`` on that batch.
        """
        detector = self.model_detection
        # Patch tokens and cls tokens come from the detection model.
        features = detector.extract_features(batch)
        # Semantic embeddings per pair.
        return features, detector.semantic_extractor(features)

    def finish_feature_extraction_individual(self, model, batch):
        """Build the per-entity tokens and the human prefix for one model.

        Args:
            model: the wrapped model whose blender / box embedder are used.
            batch: mapping providing ``"bboxes"``, ``"patch_tokens"``,
                ``"binary_masks"`` and ``"cls_tokens"``.

        Returns:
            A tuple ``(tokens, prepend_human)``. When the model's box embedder
            is in use, ``tokens`` is the blended visual features concatenated
            with the box embeddings along the last dimension and
            ``prepend_human`` is the cls tokens passed through
            ``unionwrapper.prepare_human_pos``; otherwise they are the blended
            visual features and the raw cls tokens.
        """
        # Column 0 of each box is used as an index into the patch tokens
        # (presumably the frame the box belongs to -- confirm with the loader).
        frame_ids = batch["bboxes"][:, 0].to(device=self.device, dtype=torch.long)
        entity_patches = batch["patch_tokens"][frame_ids]
        appearance = model.blender(entity_patches, batch["binary_masks"])

        if model.bbox_embedder.isused():
            location = model.bbox_embedder(batch)
            entity_tokens = torch.cat([appearance, location], dim=-1)
            human_prefix = model.unionwrapper.prepare_human_pos(batch["cls_tokens"])
            return entity_tokens, human_prefix

        return appearance, batch["cls_tokens"]

    def obtain_interaction(self, model, batch, semantics):
        """Run one model's pair pipeline and return its interaction-head output.

        Args:
            model: the wrapped model to run (detection or anticipation).
            batch: mapping produced by :meth:`shared_steps`.
            semantics: semantic embeddings produced by :meth:`shared_steps`.
        """
        tokens, human_prefix = self.finish_feature_extraction_individual(model, batch)

        # Column 0 / column 1 of pair_idxes select the human / object entity.
        pairs = batch["pair_idxes"]
        human_tokens = tokens[pairs[:, 0]]
        object_tokens = tokens[pairs[:, 1]]

        human_masks, object_masks = model.unionwrapper.blender(batch["binary_masks"], pairs)

        prepared = model.prepare_windows(
            batch,
            human_tokens,
            object_tokens,
            human_prefix,
            semantics,
            human_masks,
            object_masks,
        )
        human_windows, object_windows, temporal_idx, padding_masks = prepared

        main_branch, second_branch = model.transformer(
            human_windows, object_windows, temporal_idx, padding_masks
        )

        # Optional extra transformer stage that refines the main branch only.
        if model.extend_head_transformer:
            main_branch = model.transformer_head(
                main_branch, second_branch, temporal_idx, padding_masks
            )

        return model.interaction_head(
            main_branch,
            second_branch,
            windows=batch["windows"],
            windows_out=batch["windows_out"],
            im_idxes=batch["im_idxes"],
        )

    def forward(self, batch):
        """Run both models on ``batch`` and store their outputs in it.

        Returns:
            The batch returned by :meth:`shared_steps`, updated with the keys
            ``"detection"`` and ``"anticipation"``.
        """
        batch, semantics = self.shared_steps(batch)
        # Both models run before the batch is touched, detection first.
        results = {
            "detection": self.obtain_interaction(self.model_detection, batch, semantics),
            "anticipation": self.obtain_interaction(self.model_anticipation, batch, semantics),
        }
        batch.update(results)
        return batch
