import torch
import torch.nn as nn
import torch.nn.functional as F

import numpy as np
from modules.hoi4abot.hoibot.modules.feature_extractor.dinov2 import Dinov2
from modules.hoi4abot.hoibot.modules.prompt_encoder.bbox_encoder import Bbox_Wrapper
from modules.hoi4abot.hoibot.modules.semantic_extractor.SemanticWrapper import SemanticWrapper
from modules.hoi4abot.hoibot.modules.transformer.getter import get_transformer
from modules.hoi4abot.hoibot.modules.patch_blender.blender import Blender
from modules.hoi4abot.hoibot.modules.head.InteractionHead import InteractionHead
from modules.hoi4abot.hoibot.modules.prompt_encoder.union_box import UnionWrapper
from modules.hoi4abot.hoibot.modules.transformer.transformer_head import TransformerHead


from einops import rearrange, reduce, repeat
from typing import Callable, Optional, Tuple, Union

class HOIBOT(nn.Module):
    """Human-object interaction model assembled from pluggable sub-modules.

    As wired in ``forward``:
      1. A DINOv2 ``feature_extractor`` yields per-frame cls/patch tokens
         (skipped if the batch already carries them).
      2. A ``Blender`` pools patch tokens per entity using binary masks.
         Box embeddings and semantic embeddings are attached when their
         wrappers report being used.
      3. Human/object tokens are packed into fixed-size sliding windows.
      4. A transformer encodes the windows.
      5. An ``InteractionHead`` produces the outputs merged back into
         ``batch``. It is optionally preceded by a ``TransformerHead``.
    """

    def __init__(
        self,
        device,
        num_action_classes,
        num_spatial_classes,
        obj_class_names,
        embedding_dimension = 384,
        hoi_feature = "learnable",
        input_image_size: Tuple[int, int]=(224, 224),
        patch_size: Union[int, Tuple[int, int]] = 16,
        sliding_window=4,
        max_length = 16,
        depth = 4,
        dual_transformer_type="dual",
        num_heads = 8,
        mlp_ratio = 4.0,
        drop_rate=0.2,
        blender_type = "MOA",
        moa_eps = 0.0,
        box_encoder_type = "Box",
        semantic_type = "file",
        do_regression=False,
        use_feature_extractor = False,
        train_feature_extractor = False,
        annotation_dir="/annotations/",
        semantic_masking_prob=0.0,
        augmentation_semantic=False,
        pos_embed_type="sinusoidal",
        simple_semantics=True,
        image_cls_type = "mean",
        mainbranch ="objects",
        do_inference=False,
        union_box = "none",
        head_cls_type="single",
        bigextractor=False,
        extend_head_transformer = False,
        concat_extension=False,
    ):
        """Build all sub-modules.

        Only the parameters actually read below have an effect. In this
        class, ``obj_class_names``, ``max_length`` and
        ``use_feature_extractor`` are accepted but unused.

        ``image_cls_type`` is compared case-sensitively against "Mean",
        "None" and "Linear". NOTE(review): the default "mean" matches none
        of them, so per-frame cls tokens are left un-aggregated. Confirm
        that this default is intended.
        """
        super().__init__()
        print("in model", patch_size)
        print("in model", input_image_size)
        self.device = device
        self.embedding_dimension = embedding_dimension
        self.sliding_window = sliding_window
        self.dual_transformer_type = dual_transformer_type
        self.image_cls_type = image_cls_type
        self.do_inference = do_inference
        self.bigextractor = bigextractor
        self.extend_head_transformer = extend_head_transformer

        # Constructed exactly once. Re-creating it later without ``name=``
        # would silently discard the ``bigextractor`` backbone choice.
        dinoname = "dinov2_vits14" if not bigextractor else "dinov2_vitb14"
        self.feature_extractor = Dinov2(name=dinoname, device=device, train_feature_extractor=train_feature_extractor)
        if self.image_cls_type == "Linear":
            # Learned weighting of the ``sliding_window`` per-frame cls tokens into one.
            self.img_cls_projection = nn.Linear(self.sliding_window, 1)

        bbox_embedding_size = embedding_dimension if not bigextractor else 384
        self.bbox_embedder = Bbox_Wrapper(box_encoder_type, bbox_embedding_size, input_image_size, patch_size, device)

        self.semantic_extractor = SemanticWrapper(semantic_type, device, annotation_dir, simple_semantics, semantic_masking_prob, augmentation_semantic)
        if self.bigextractor:
            # Map semantic embeddings to the (visual + box) width of the object stream.
            self.fc_semantic = nn.Linear(self.semantic_extractor.embedding_size, self.feature_extractor.embedding_dim + bbox_embedding_size)

        self.blender = Blender(blender_type, self.feature_extractor.embedding_dim, input_image_size, patch_size, moa_eps, device)
        self.unionwrapper = UnionWrapper(union_box, windows_size=sliding_window, patch_size=patch_size, embedding_dimension=bbox_embedding_size,
                                         img_size=input_image_size, only_spatial=False)

        # The transformer only owns the HOI token when no extended head transformer is used.
        use_hoi_token = hoi_feature == "learnable" and not self.extend_head_transformer

        # Input width is (visual + box) for humans concatenated with (visual + box) for objects.
        self.transformer = get_transformer(dual_transformer_type, (self.feature_extractor.embedding_dim + bbox_embedding_size) * 2, sliding_window,
                                           depth, semantic_type, mlp_ratio, drop_rate, num_heads,
                                           pos_embed_type, image_cls_type, use_hoi_token, mainbranch, just_the_passed=True)

        interaction_embed = self.transformer.embedding_dim
        if self.extend_head_transformer:
            hoi_feature = "learnable_head" if hoi_feature == "learnable" else hoi_feature

            if concat_extension and (self.transformer.mainbranch_id == self.transformer.secondbranch_id):
                # NOTE(review): `assert True` can never fail, so this check is a
                # no-op. Its message ("only possible when ... the same IDs") also
                # contradicts the condition, which fires when the IDs ARE equal.
                # Confirm the intended rule before turning it into a real check.
                assert True, f"With transformer extension, the concat option is only possible when both main branch and secondbranch have the same IDs." \
                             f"Concat {concat_extension} | MainBranch {self.transformer.mainbranch_id} | SecondBranch {self.transformer.secondbranch_id} "

            interaction_embed = interaction_embed * 2 if concat_extension else interaction_embed
            self.transformer_head = TransformerHead(
                hoi_token_type=hoi_feature,
                concat=concat_extension,
                embed_dim=interaction_embed,
                windows_size=sliding_window,
                depth=depth // 2,  # number of blocks
                mlp_ratio=mlp_ratio,
                drop_rate=drop_rate,
                add_id=self.transformer.mainbranch_id
            )

        idx_ = self.transformer.mainbranch_id
        self.interaction_head = InteractionHead(hoi_feature, mainbranch, interaction_embed, sliding_window, num_action_classes, num_spatial_classes, do_regression, idx_windows=idx_,
                                                as_mlp=head_cls_type == "mlp", just_the_passed=True, half_prepend=False)

    def info_model(self):
        """Print per-sub-module info followed by trainable and total parameter counts (in millions)."""
        print("Information regarding HOIBOT model: number of parameters")
        print("---" * 10)

        self.feature_extractor.info_model()
        self.bbox_embedder.info_model()
        self.semantic_extractor.info_model()
        self.blender.info_model()
        self.transformer.info_model()
        self.interaction_head.info_model()
        if self.extend_head_transformer:
            self.transformer_head.info_model()

        print("---" * 10)
        trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
        print("[{}] - {:.2f}M".format("HOIBOT trainable", trainable_params / 10 ** 6))

        all_params = sum(p.numel() for p in self.parameters())
        print("[{}] - {:.2f}M".format("HOIBOT all", all_params / 10 ** 6))

    def extract_features(self, batch):
        """
        Compute (if missing) and aggregate the image cls tokens of a batch.

        If ``batch`` lacks "cls_tokens" or "patch_tokens", both are computed
        from ``batch["frames"]`` with the feature extractor, then passed
        through ``bbox_embedder.extend_projection``. A 5-D frames tensor is
        flattened over its first two dims for extraction and unflattened
        afterwards.

        The per-frame cls tokens are then aggregated per batch element
        according to ``image_cls_type``:
          - "Mean" / "None": averaged.
          - "Linear": packed into a window and projected by
            ``img_cls_projection``.
          - Any other value: left as they are.

        :param batch: dict. Must contain "out_im_idxes". Must also contain
            "frames" unless "cls_tokens" and "patch_tokens" are already
            present.
        :return: the same dict, updated in place with "cls_tokens" (and
            "patch_tokens" when they were computed here).
        """
        if "cls_tokens" not in batch or "patch_tokens" not in batch:
            images = batch["frames"]
            not_flattened = images.ndim == 5
            if not_flattened:
                b, t = images.shape[:2]
                images = rearrange(images, "b t c h w -> (b t) c h w ")

            # [(B T) C H W] -> ( cls_tokens: [(B T) 1 EmbedDim] , patch_tokens: [(B T) NumTokens EmbedDim] )
            cls_tokens, patch_tokens = self.feature_extractor(images)
            cls_tokens, patch_tokens = self.bbox_embedder.extend_projection(cls_tokens, patch_tokens)

            if not_flattened:
                cls_tokens = rearrange(cls_tokens, "(b t) d e -> b t d e ", b=b, t=t)
                patch_tokens = rearrange(patch_tokens, "(b t) d e -> b t d e ", b=b, t=t)

            batch.update({"cls_tokens": cls_tokens, "patch_tokens": patch_tokens})

        batchid_frame = self.get_batch_id_per_frame(batch["out_im_idxes"])
        batch_ids = torch.unique(batchid_frame)
        if self.image_cls_type in ("Mean", "None"):
            if self.do_inference:
                # Inference mode averages every frame into one token (no per-batch split).
                cls_tokens = [reduce(batch["cls_tokens"], "t e -> e", reduction="mean")]
            else:
                cls_tokens = [
                    reduce(batch["cls_tokens"][batchid_frame == index], "t e -> e", reduction="mean")
                    for index in batch_ids
                ]
            cls_tokens = torch.stack(cls_tokens)
        elif self.image_cls_type == "Linear":
            tokens_batched = torch.zeros(len(batch_ids), self.sliding_window, self.feature_extractor.embedding_dim, device=self.device)
            for index in batch_ids:
                frame_tokens = batch["cls_tokens"][batchid_frame == index]
                # Frames are left-aligned; unused window slots stay zero.
                tokens_batched[index, :frame_tokens.shape[0]] = frame_tokens
            cls_tokens = self.img_cls_projection(rearrange(tokens_batched, "b w e -> b e w")).squeeze(-1)
        else:
            # Unrecognised type (e.g. the lowercase default "mean"): keep the
            # per-frame tokens. Without this branch `cls_tokens` would be
            # unbound whenever the tokens were precomputed.
            cls_tokens = batch["cls_tokens"]
        batch.update({"cls_tokens": cls_tokens})
        return batch

    def get_batch_id_per_frame(self, im_idx_out):
        """Expand per-batch last-frame indices into one batch id per frame.

        Entry ``i`` of ``im_idx_out`` is treated as the (inclusive) index of
        the last frame belonging to batch element ``i``. The frames after the
        previous element's last frame, up to and including that index, get
        id ``i``.

        Example: ``[1, 4, 5] -> [0, 0, 1, 1, 1, 2]``.

        :param im_idx_out: iterable of 0-d tensors (``.item()`` is called on each).
        :return: CPU ``torch.long`` tensor of batch ids. It is empty for empty input.
        """
        batch_ids = []
        prev_end = 0
        for i, last_idx in enumerate(im_idx_out):
            end = last_idx.item() + 1
            batch_ids.extend([i] * (end - prev_end))
            prev_end = end
        return torch.tensor(batch_ids, dtype=torch.long)

    def get_features(self, batch):
        """
        Compute the per-entity tokens, the human prepend token(s) and the semantics.

        :param batch: dict read for "bboxes" and "binary_masks", plus whatever
            ``extract_features`` and the semantic/box wrappers read.
        :return: (tokens, prepend_human, semantics).
            - tokens: blended visual features, one row per row of
              ``batch["bboxes"]``. Box embeddings are concatenated on the
              last dim when the box embedder is used.
            - prepend_human: the aggregated cls tokens. When the box embedder
              is used, they are first passed through
              ``unionwrapper.prepare_human_pos``.
            - semantics: output of ``semantic_extractor``. Its shape is
              defined by that wrapper; confirm there.
        """
        # Extract Patch Tokens and Cls Tokens
        batch = self.extract_features(batch)

        # Extract Semantic Embeddings per pair
        semantics = self.semantic_extractor(batch)

        # Column 0 of each bbox row is used as its frame index into patch_tokens.
        frame_indices = batch["bboxes"][:, 0].to(dtype=torch.long, device=self.device)
        patch_tokens = batch["patch_tokens"][frame_indices]
        visual_features = self.blender(patch_tokens, batch["binary_masks"])

        # Obtain embeddings for the locations of each item
        if not self.bbox_embedder.isused():
            tokens = visual_features
            prepend_human = batch["cls_tokens"]
        else:
            box_embeddings = self.bbox_embedder(batch)
            tokens = torch.cat([visual_features, box_embeddings], dim=-1)
            prepend_human = self.unionwrapper.prepare_human_pos(batch["cls_tokens"])

        return tokens, prepend_human, semantics

    def prepare_windows(self, batch, token_humans, token_objects, prepend_human, semantics, bin_masks_humans=None, bin_masks_objects=None):
        """
        Pack per-pair human/object tokens into fixed-size sliding windows.

        Each entry of ``batch["windows"]`` is summed to get its length and
        used to index the pair tensors, so presumably it is a boolean mask
        over pairs (confirm against the data loader).

        Human and object inputs are left-aligned in a
        ``[num_windows, sliding_window, transformer.embedding_dim // 2]``
        buffer. Unused slots:
          - stay zero in the inputs;
          - stay -1 in ``temporal_idx``;
          - are True in ``temporal_padding_masks``.

        A human prepend token (from the window's last frame's batch element)
        is concatenated in front of the human inputs unless
        ``image_cls_type == "None"``. When the union box is active, the
        union-box token is appended to it on the last dim. Object inputs get
        the semantic prepend from ``semantic_extractor.prepend_semantics``.

        :return: (transformer_input_humans, transformer_input_objects,
            temporal_idx, temporal_padding_masks)
        """
        windows = batch["windows"]
        im_idxes = batch["im_idxes"]

        frames_to_batch = self.get_batch_id_per_frame(batch["out_im_idxes"])
        num_sliding_window = len(windows)
        half_dim = self.transformer.embedding_dim // 2

        transformer_input_humans = torch.zeros(num_sliding_window, self.sliding_window, half_dim, device=self.device)
        transformer_input_objects = torch.zeros(num_sliding_window, self.sliding_window, half_dim, device=self.device)

        temporal_idx = -torch.ones((num_sliding_window, self.sliding_window), dtype=torch.long, device=self.device)
        temporal_padding_masks = torch.zeros((num_sliding_window, self.sliding_window), dtype=torch.bool, device=self.device)

        use_ui_box = self.unionwrapper.is_box
        use_semantics = self.semantic_extractor.isused()

        if use_ui_box:
            ui_window_size = self.unionwrapper.UI_box.windows_size
            UI_input_humans = torch.zeros(num_sliding_window, ui_window_size, self.unionwrapper.UI_box.num_patches,
                                          device=self.device)
            UI_input_objects = torch.zeros(num_sliding_window, ui_window_size, self.unionwrapper.UI_box.num_patches,
                                           device=self.device)

        prepend_humans = []
        prepend_object_semantics = []
        # fill everything in each sliding window
        for idx_window, window in enumerate(windows):
            temporal_slice_len = int(torch.sum(window))

            transformer_input_humans[idx_window, :temporal_slice_len, :] = token_humans[window]
            transformer_input_objects[idx_window, :temporal_slice_len, :] = token_objects[window]

            # Frame offsets relative to the first frame of the window.
            window_im_idxes = im_idxes[window]
            temporal_idx[idx_window, :temporal_slice_len] = window_im_idxes - window_im_idxes[0]
            temporal_padding_masks[idx_window, temporal_slice_len:] = True

            # Batch element of the window's last frame.
            batch_id = frames_to_batch[window_im_idxes][-1]

            if use_ui_box:
                # Right-align the masks and replicate the first frame's mask
                # into the leading padding slots, so the padding never
                # overwrites real masks of a partial window.
                # NOTE(review): alignment kept right-aligned as in the original
                # data write; confirm this is what UnionWrapper expects.
                num_pad = ui_window_size - temporal_slice_len
                UI_input_humans[idx_window, num_pad:, :] = bin_masks_humans[window]
                UI_input_humans[idx_window, :num_pad, :] = bin_masks_humans[window][[0]]

                UI_input_objects[idx_window, num_pad:, :] = bin_masks_objects[window]
                UI_input_objects[idx_window, :num_pad, :] = bin_masks_objects[window][[0]]

            prepend_humans.append(prepend_human[batch_id][None, None])
            if use_semantics:
                prepend_object_semantics.append(semantics[window][-1])

        prepend_humans = torch.cat(prepend_humans)

        if use_ui_box:
            ui_box_token = self.unionwrapper(UI_input_humans, UI_input_objects)
            prepend_humans = torch.cat([prepend_humans, ui_box_token[:, None]], axis=-1)

        if use_semantics:
            prepend_object_semantics = torch.cat(prepend_object_semantics)
            if self.bigextractor:
                prepend_object_semantics = self.fc_semantic(prepend_object_semantics)

        if self.image_cls_type != "None":
            transformer_input_humans = torch.cat([prepend_humans, transformer_input_humans], axis=1)

        transformer_input_objects = self.semantic_extractor.prepend_semantics(prepend_object_semantics, transformer_input_objects)
        return transformer_input_humans, transformer_input_objects, temporal_idx, temporal_padding_masks

    def backbone(self, batch):
        """Run feature extraction, window packing and the transformer.

        :return: (mainbranch, None, temporal_idx, temporal_padding_masks).
            The second element is always None here; it is forwarded to
            ``head`` as ``secondbranch``.
        """
        tokens, prepend_human, semantics = self.get_features(batch)

        # Column 0 / 1 of pair_idxes select the human / object entity of each pair.
        token_humans = tokens[batch["pair_idxes"][:, 0]]
        token_objects = tokens[batch["pair_idxes"][:, 1]]

        bin_masks_humans, bin_masks_objects = self.unionwrapper.blender(batch["binary_masks"], batch["pair_idxes"])

        window_humans, window_objects, temporal_idx, temporal_padding_masks = self.prepare_windows(
            batch, token_humans, token_objects, prepend_human, semantics, bin_masks_humans, bin_masks_objects)

        windows = torch.cat([window_humans, window_objects], axis=-1)

        _, windows = self.interaction_head.prepend_token(None, windows)

        mainbranch = self.transformer(windows, temporal_idx, temporal_padding_masks)

        return mainbranch, None, temporal_idx, temporal_padding_masks

    def head(self, batch, mainbranch, secondbranch, temporal_idx, temporal_padding_masks):
        """Apply the optional ``TransformerHead`` and then the ``InteractionHead``; return its output."""
        if self.extend_head_transformer:
            mainbranch = self.transformer_head(mainbranch, secondbranch, temporal_idx, temporal_padding_masks)

        output = self.interaction_head(mainbranch, secondbranch, windows=batch["windows"], windows_out=batch["windows_out"], im_idxes=batch["im_idxes"])
        return output

    def forward(self, batch):
        """Run backbone + head and merge the head output into ``batch`` (updated in place and returned)."""
        mainbranch, secondbranch, temporal_idx, temporal_padding_masks = self.backbone(batch)
        output = self.head(batch, mainbranch, secondbranch, temporal_idx, temporal_padding_masks)
        batch.update(output)
        return batch

