import torch
import torch.nn as nn
from modules.hoi4abot.hoibot.modules.semantic_extractor.clip_module import SemanticExtractor

class SemanticWrapper(nn.Module):
    """Select and wrap the backend that provides object semantic embeddings.

    ``semantic_type`` chooses the backend:

    * ``"CLIP"`` -- a ``SemanticExtractor`` built with ``use_clip=True``.
    * ``"file"`` -- a ``SemanticExtractor`` built with ``use_clip=False``; the
      masking probability and augmentation flag are forwarded to it.
    * ``"None"`` (the string, not the ``None`` singleton) -- semantics are
      disabled, ``forward`` returns ``None`` and ``prepend_semantics`` is a
      no-op.
    """

    # Values of ``semantic_type`` that ``__init__`` knows how to build.
    _SUPPORTED_TYPES = ("CLIP", "file", "None")
    # Embedding size reported when semantics are disabled.
    # NOTE(review): presumably chosen to match a downstream default -- confirm.
    _DISABLED_EMBEDDING_SIZE = 384

    def __init__(self, semantic_type, device, annotation_dir, simple_semantics, semantic_masking_prob,
                 augmentation_semantic):
        """Build the extractor matching ``semantic_type``.

        :param semantic_type: one of ``"CLIP"``, ``"file"`` or ``"None"``.
        :param device: forwarded to ``SemanticExtractor``.
        :param annotation_dir: forwarded to ``SemanticExtractor``.
        :param simple_semantics: forwarded to ``SemanticExtractor``.
        :param semantic_masking_prob: forwarded as ``masking_prob``; only used
            for the ``"file"`` backend.
        :param augmentation_semantic: forwarded to ``SemanticExtractor``; only
            used for the ``"file"`` backend.
        :raises ValueError: if ``semantic_type`` is not a supported value.
        """
        super().__init__()
        # Fail early with a readable message instead of an UnboundLocalError
        # further down when the type is misspelled or passed as ``None``.
        if semantic_type not in self._SUPPORTED_TYPES:
            raise ValueError(
                f"Unsupported semantic_type {semantic_type!r}; expected one of {self._SUPPORTED_TYPES}"
            )
        self.semantic_type = semantic_type

        if semantic_type == "None":
            semantic_extractor = nn.Identity()
            embedding_size = self._DISABLED_EMBEDDING_SIZE
        else:
            extractor_kwargs = dict(
                device=device,
                annotation_dir=annotation_dir,
                use_clip=(semantic_type == "CLIP"),
                clip_model="ViT-L/14",
                simple_semantics=simple_semantics,
            )
            if semantic_type == "file":
                # Masking / augmentation are only configured for the file backend;
                # the CLIP backend keeps the extractor's own defaults.
                extractor_kwargs.update(
                    masking_prob=semantic_masking_prob,
                    augmentation_semantic=augmentation_semantic,
                )
            semantic_extractor = SemanticExtractor(**extractor_kwargs)
            embedding_size = semantic_extractor.embedding_size

        self.embedding_size = embedding_size
        self.semantic_extractor = semantic_extractor

    def isused(self):
        """Return ``True`` unless semantics are disabled (``semantic_type == "None"``)."""
        return self.semantic_type != "None"

    def info_model(self):
        """Print the number of trainable parameters, in millions (0 when disabled)."""
        if self.semantic_type != "None":
            total_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
        else:
            total_params = 0
        print(f"[Semantic Extractor {self.semantic_type}] - {total_params / 10 ** 6:.2f}M")

    def forward(self, batch):
        """
        Get the semantic embedding for the object of every pair in the batch.

        :param batch: mapping that must contain ``"pair_idxes"`` and either
            precomputed ``"semantics"`` or ``"pred_labels"`` (object type
            indices) from which the extractor computes them.
        :return: ``None`` when semantics are disabled; otherwise the semantics
            rows selected by ``batch["pair_idxes"][:, 1]`` with a singleton
            axis inserted at position 1.
        :raises ValueError: if ``semantic_type`` was changed to an unsupported value.
        """
        if self.semantic_type == "None":
            return None

        if "semantics" in batch:
            # Previously this case left ``semantics`` unbound and crashed.
            # NOTE(review): assumes precomputed semantics are indexed per object
            # like the extractor output -- confirm against the data loader.
            semantics = batch["semantics"]
        elif self.semantic_type == "CLIP":
            semantics = self.semantic_extractor.get_feature_object_idx(batch["pred_labels"])
        elif self.semantic_type == "file":
            semantics = self.semantic_extractor.get_from_index(batch["pred_labels"].tolist())
        else:
            raise ValueError(f"Unsupported semantic_type {self.semantic_type!r}")

        # Column 1 of ``pair_idxes`` is used as the row index into ``semantics``.
        object_idx = batch["pair_idxes"][:, 1]
        return semantics[object_idx][:, None]

    def prepend_semantics(self, prepend_object_semantics, transformer_input_objects):
        """Prepend the semantics along dim 1 of the transformer input.

        :param prepend_object_semantics: tensor that gets a new axis at
            position 1 before concatenation; ignored when semantics are disabled.
        :param transformer_input_objects: tensor to extend along dim 1.
        :return: the concatenated tensor, or ``transformer_input_objects``
            unchanged when semantics are disabled.
        """
        if self.semantic_type == "None":
            return transformer_input_objects
        return torch.cat([prepend_object_semantics[:, None], transformer_input_objects], dim=1)

if __name__ == '__main__':
    # Manual smoke test: build the file-backed wrapper and print its
    # trainable-parameter count.
    from configs.paths import annotations_dir

    # Fall back to CPU so the script does not crash on machines without CUDA.
    device = "cuda:0" if torch.cuda.is_available() else "cpu"

    # Sample object-label indices; kept for manual experiments, nothing below
    # consumes it yet.
    batch = {"pred_labels": torch.randint(low=0, high=20, size=(20,)).to(device)}

    semantic_wrapper = SemanticWrapper(
        semantic_type="file",
        device=device,
        annotation_dir=annotations_dir,
        simple_semantics=True,
        semantic_masking_prob=0.0,
        augmentation_semantic=False,
    )
    semantic_wrapper.info_model()