import torch.nn as nn

from modules.hoi4abot.hoibot.modules.patch_blender.MOA_blender import MOABlender
from modules.hoi4abot.hoibot.modules.patch_blender.patch_blender import PatchBlender
from modules.hoi4abot.hoibot.modules.transformer.modules_attn.CrossAttention import CrossAttention
from einops import rearrange, reduce, repeat


class Blender(nn.Module):
    """Pool patch tokens into one feature vector per sample, guided by masks.

    The pooling strategy is fixed at construction time by ``blender_type``:

    * ``"Linear"``    -- masks go through ``PatchBlender``, then a
      ``Linear`` + ``Sigmoid`` layer turns them into per-patch weights; the
      output is the weighted sum of the patch tokens over the patch axis.
    * ``"Attention"`` -- masks go through ``PatchBlender`` and are passed,
      together with the (transposed) patch tokens, to a ``CrossAttention``
      layer; its output is average-pooled and summed over the last axis.
    * ``"MOA"``       -- delegates entirely to ``MOABlender``.
    * ``"weighted"``  -- the ``PatchBlender`` output is used directly as the
      per-patch weights of the weighted sum (no extra blending layer).

    :param blender_type: one of ``"Linear"``, ``"Attention"``, ``"MOA"``,
        ``"weighted"``.
    :param embedding_dimension: token embedding size; only used by ``"MOA"``.
    :param input_image_size: image size forwarded to ``PatchBlender`` /
        ``MOABlender``.
    :param patch_size: patch size forwarded to the patch blenders;
        ``patch_size ** 2`` is also the feature size of the ``"Linear"`` and
        ``"Attention"`` blending layers.
    :param moa_eps: ``eps`` forwarded to ``MOABlender``; only used by ``"MOA"``.
    :param device: device forwarded to ``MOABlender``; only used by ``"MOA"``.
    :raises RuntimeError: if ``blender_type`` is not one of the values above.
    """

    # Blender types that share the PatchBlender mask preprocessing step.
    _PATCH_BLENDER_TYPES = ("Linear", "Attention", "weighted")

    def __init__(self, blender_type, embedding_dimension, input_image_size, patch_size, moa_eps, device):
        super().__init__()
        self.blender_type = blender_type
        if blender_type == "MOA":
            self.patch_blender = MOABlender(embedding_dimension, input_image_size, patch_size, device=device,
                                            eps=moa_eps)
        elif blender_type in self._PATCH_BLENDER_TYPES:
            # Registered before any blending layer, matching the original
            # submodule (and therefore parameter / state_dict) order.
            self.patch_blender = PatchBlender(img_size=input_image_size, patch_size=patch_size,
                                              flatten_embedding=True)
            if blender_type == "Linear":
                self.blending_layer = nn.Sequential(
                    nn.Linear(patch_size ** 2, patch_size ** 2),
                    nn.Sigmoid(),
                )
            elif blender_type == "Attention":
                self.blending_layer = CrossAttention(patch_size ** 2)
                self.blending_ap = nn.AvgPool1d(patch_size ** 2)
            # "weighted" uses the PatchBlender output as-is: no extra layers.
        else:
            raise RuntimeError("Unrecognized blender type")

    def info_model(self):
        """Print the number of trainable parameters (in millions) and return it.

        :return: total element count of parameters with ``requires_grad=True``.
        """
        total_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
        print("[{}] - {:.2f}M".format("Blender", total_params / 10 ** 6))
        return total_params

    @staticmethod
    def _mask_weighted_sum(weights, patch_tokens):
        """Return ``sum_n(weights * patch_tokens)`` as a ``(b, e)`` tensor.

        ``weights`` is multiplied against the tokens in ``(e, b, n)`` layout, so
        it must broadcast against that layout (presumably shape ``(b, n)``).
        """
        tokens = rearrange(patch_tokens, "b n e -> e b n")
        weighted = rearrange(weights * tokens, "e b n -> b n e")
        return reduce(weighted, "b n e -> b e", reduction="sum")

    def forward(self, patch_tokens, bin_masks):
        """Blend the patch features according to the given masks.

        :param patch_tokens: [B, NumPatches, E] (as documented by the author).
        :param bin_masks: [B, T, N, H, W] as documented by the author.
            NOTE(review): for "Linear"/"weighted" the PatchBlender output is
            multiplied against tokens in (e, b, n) layout, so it presumably
            broadcasts to [B, NumPatches] -- confirm against PatchBlender.
        :return: for "Linear"/"weighted", a [B, E] tensor (sum over patches);
            for "Attention", the pooled CrossAttention output summed over its
            last axis; for "MOA", whatever ``MOABlender`` returns.
        :raises RuntimeError: if ``self.blender_type`` is unsupported (only
            possible if it was modified after construction).
        """
        if self.blender_type == "MOA":
            return self.patch_blender(patch_tokens, bin_masks)

        if self.blender_type not in self._PATCH_BLENDER_TYPES:
            # Previously this fell through to an UnboundLocalError on `features`.
            raise RuntimeError("Unrecognized blender type")

        bin_masks = self.patch_blender(bin_masks)

        if self.blender_type == "Attention":
            patch_tokens = rearrange(patch_tokens, "b n e -> b e n")
            attended = self.blending_layer(patch_tokens, bin_masks[:, None])
            features = self.blending_ap(attended)
            return reduce(features, "b e n -> b e", reduction="sum")

        if self.blender_type == "Linear":
            # Learnable per-patch weights squashed into (0, 1) by the Sigmoid.
            bin_masks = self.blending_layer(bin_masks)
        return self._mask_weighted_sum(bin_masks, patch_tokens)
