import torch.nn as nn
import torch

from modules.hoi4abot.hoibot.modules.head.InteractionHead import InteractionHead
from modules.hoi4abot.hoibot.hoitbot import HOIBOT


class HOIBOT_HYDRA(nn.Module):
    """Multi-head wrapper that adds anticipation heads to a detection model.

    The wrapped ``model_detection`` supplies the shared backbone and the
    ``future_num == 0`` interaction head. For every other value in
    ``future_nums`` an extra ``InteractionHead`` is built from the parameters
    returned by ``model_detection.interaction_head.return_params()``. All heads
    therefore consume the same backbone outputs in ``forward``.

    Args:
        model_detection: Detection model exposing ``feature_extractor``,
            ``blender``, ``interaction_head``, ``backbone`` and ``info_model``.
        future_nums: Horizons to run; ``0`` selects the detection model's own
            interaction head. Defaults to ``[0, 1, 3, 5]``.
    """

    def __init__(self, model_detection, future_nums=None):
        super().__init__()
        # Fresh list per instance: a list literal as the default argument
        # would be shared (and mutable) across every HOIBOT_HYDRA created.
        if future_nums is None:
            future_nums = [0, 1, 3, 5]
        self.model_detection = model_detection
        # Direct references to components of the wrapped detection model.
        self.feature_extractor = model_detection.feature_extractor
        self.blender = self.model_detection.blender

        self.future_nums = list(future_nums)
        # Horizon 0 reuses model_detection.interaction_head, so only the
        # non-zero horizons get a new head. ModuleDict keys must be strings.
        self.anticipation_heads = nn.ModuleDict({
            f"{future_num}": InteractionHead(**self.model_detection.interaction_head.return_params())
            for future_num in self.future_nums
            if future_num != 0
        })
        # Set to True by load_weights_backbone.
        self.isloaded = False

    def info_model(self):
        """Print parameter counts (in millions) for the detection model and this wrapper."""
        print("Information regarding HOIBOT DA: number of parameters")
        print("---" * 50)
        self.model_detection.info_model()
        print("---" * 50)

        total_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
        print("[{}] - {:.2f}M".format("HOIBOT trainable", total_params / 10 ** 6))

        total_params = sum(p.numel() for p in self.parameters())
        print("[{}] - {:.2f}M".format("HOIBOT all", total_params / 10 ** 6))

    def load_weights_backbone(self, state_dict, future_num=0):
        """Load weights into the detection model or into one anticipation head.

        Args:
            state_dict: Mapping from parameter/buffer names to tensors.
            future_num: ``0`` loads ``state_dict`` unchanged into
                ``model_detection``. Any other value keeps only the keys that
                contain ``anticipation_heads.<future_num>.`` and removes that
                prefix from them. The result is loaded into the matching
                anticipation head.

        Raises:
            KeyError: If no anticipation head exists for ``future_num``.
            Errors from the underlying ``load_state_dict`` call (e.g. torch's
            strict-mode ``RuntimeError`` on key mismatches) propagate.
        """
        if future_num == 0:
            incompatibles = self.model_detection.load_state_dict(state_dict)
        else:
            # Match this head's exact prefix. Matching the bare
            # "anticipation_heads" substring would also keep the other heads'
            # (unstripped) keys, and the strict load would reject them.
            prefix = f"anticipation_heads.{future_num}."
            head_dict = {k.replace(prefix, ""): v for k, v in state_dict.items() if prefix in k}
            incompatibles = self.anticipation_heads[f"{future_num}"].load_state_dict(head_dict)
        print(f"HOIBOT loaded: {future_num}. Incompatible keys {incompatibles}")
        self.isloaded = True

    def forward(self, batch):
        """Run the shared backbone once, then every configured head.

        Args:
            batch: Dict passed to ``model_detection.backbone``. It must also
                provide the ``"windows"``, ``"windows_out"`` and
                ``"im_idxes"`` entries forwarded to each head.

        Returns:
            The same ``batch`` dict, updated in place with one entry
            ``"future_num_<n>"`` per ``n`` in ``future_nums``.
        """
        # The backbone's temporal index / padding-mask outputs are not used here.
        mainbranch, secondbranch, _temporal_idx, _temporal_padding_masks = self.model_detection.backbone(batch)
        for future_num in self.future_nums:
            if future_num == 0:
                head_ = self.model_detection.interaction_head
            else:
                head_ = self.anticipation_heads[f"{future_num}"]
            output = head_(mainbranch, secondbranch, windows=batch["windows"], windows_out=batch["windows_out"], im_idxes=batch["im_idxes"])
            batch.update({f"future_num_{future_num}": output})
        return batch
