import torch.nn as nn
import torch

from modules.hoi4abot.hoibot.modules.head.InteractionHead import InteractionHead
from modules.hoi4abot.hoibot.hoitbot import HOIBOT

class HOIBOT_DA(nn.Module):
    """Anticipation wrapper around a frozen, pretrained HOIBOT detection model.

    The wrapped detection model (including its own interaction head) has all
    parameters frozen and is kept in eval mode; only the newly created
    anticipation head in ``anticipation_heads`` is trainable.

    Args:
        model_detection: HOIBOT model whose ``backbone`` produces the features
            consumed by the anticipation head.
        future_num: Anticipation horizon; its string form is the key of the
            head stored in ``anticipation_heads``.
    """

    def __init__(self, model_detection: "HOIBOT", future_num=1):
        super().__init__()
        self.model_detection = model_detection
        # Alias of the detection model's feature extractor. It is registered as a
        # submodule here as well; parameters() de-duplicates shared parameters.
        self.feature_extractor = model_detection.feature_extractor

        self.future_num = future_num
        # New head built with the same constructor parameters as the detection
        # model's interaction head; no weights are copied from it here.
        self.anticipation_heads = nn.ModuleDict({
            f"{future_num}": InteractionHead(**self.model_detection.interaction_head.return_params())
        })
        self.freeze_nohead()

    def info_model(self):
        """Print parameter counts (in millions) for the detector and this wrapper."""
        print("Information regarding HOIBOT DA: number of parameters")
        print("---"*50)
        self.model_detection.info_model()
        print("---"*50)

        total_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
        print("[{}] - {:.2f}M".format("HOIBOT trainable", total_params / 10 ** 6))

        total_params = sum(p.numel() for p in self.parameters())
        print("[{}] - {:.2f}M".format("HOIBOT all", total_params / 10 ** 6))

    def train(self, mode: bool = True):
        """Set the training mode while keeping the frozen detector in eval mode.

        ``nn.Module.train`` recurses into every child module, so without this
        override ``model.train()`` would switch ``model_detection`` back to
        training mode. That would re-enable, e.g., any dropout it contains and
        batch-norm running-stat updates, even though its parameters are frozen.
        ``eval()`` calls ``train(False)`` and is therefore handled here too.

        Args:
            mode: True for training mode, False for evaluation mode.

        Returns:
            self, matching ``nn.Module.train``.
        """
        super().train(mode)
        self.model_detection.eval()
        return self

    def freeze_nohead(self):
        """Freeze the whole detection model: eval mode and ``requires_grad=False``.

        The anticipation head(s) are not touched and remain trainable.
        """
        self.model_detection.eval()
        for param in self.model_detection.parameters():
            param.requires_grad = False

    def load_weights_backbone(self, state_dict):
        """Load ``state_dict`` into the detection model, then re-freeze it.

        ``load_state_dict`` is called with its default ``strict=True``, so key
        mismatches raise; the printed result is informational.
        """
        incompatibles = self.model_detection.load_state_dict(state_dict)
        print(f"HOIBOT loaded. Incompatible keys {incompatibles}")
        self.freeze_nohead()

    def forward(self, batch):
        """Run the frozen backbone and the anticipation head on ``batch``.

        Reads ``batch["windows"]``, ``batch["windows_out"]`` and
        ``batch["im_idxes"]``. The head's output is merged into ``batch`` in
        place via ``batch.update``, and the same object is returned.
        """
        # Only the two feature branches are used by the anticipation head.
        mainbranch, secondbranch, _temporal_idx, _temporal_padding_masks = self.model_detection.backbone(batch)
        head_ = self.anticipation_heads[f"{self.future_num}"]
        output = head_(
            mainbranch,
            secondbranch,
            windows=batch["windows"],
            windows_out=batch["windows_out"],
            im_idxes=batch["im_idxes"],
        )
        batch.update(output)
        return batch