from modules.hoi4abot.hoibot.modules.patch_blender.patch_blender import PatchBlender
import torch.nn as nn
import torch
from einops.layers.torch import Rearrange, Reduce
from einops import rearrange, reduce, repeat


class UnionWrapper(nn.Module):
    """Choose and wrap the pair-mask encoder used for human/object pairs.

    ``use_union_or_intersection`` selects the mode:

    * ``"union"``        -> ``self.UI_box`` is a ``UnionBoxEncoder``
    * ``"intersection"`` -> ``self.UI_box`` is an ``IntersectionBoxEncoder``
    * anything else      -> no box encoder; instead a learnable, zero-initialised
      ``human_pos`` parameter of shape ``(1, embedding_dimension)`` is created.

    ``device`` is accepted but not used by this class (the encoders are built
    with their own default).
    """

    def __init__(self, use_union_or_intersection, windows_size=6, patch_size=16, embedding_dimension=384,
                 device="cuda:0", img_size=(224, 224), drop_rate=0.2, use_mlp=False, only_spatial=False):
        super().__init__()
        self.use_union_or_intersection = use_union_or_intersection
        # Both box encoders share exactly the same configuration.
        encoder_kwargs = dict(
            windows_size=windows_size,
            patch_size=patch_size,
            embedding_dimension=embedding_dimension,
            img_size=img_size,
            only_spatial=only_spatial,
            drop_rate=drop_rate,
            use_mlp=use_mlp,
        )
        if use_union_or_intersection == "union":
            self.UI_box = UnionBoxEncoder(**encoder_kwargs)
        elif use_union_or_intersection == "intersection":
            self.UI_box = IntersectionBoxEncoder(**encoder_kwargs)
        else:
            self.human_pos = nn.Parameter(torch.zeros(1, embedding_dimension))

    def blender(self, binary_masks, pair_idxes):
        """Patchify ``binary_masks`` and gather the human / object rows per pair.

        Returns ``(None, None)`` when no box encoder is configured. Otherwise the
        masks are passed through ``self.UI_box.patch_blender`` and indexed with
        ``pair_idxes[:, 0]`` (humans) and ``pair_idxes[:, 1]`` (objects).
        """
        if not self.is_box:
            return None, None
        patchified = self.UI_box.patch_blender(binary_masks)
        human_rows, object_rows = pair_idxes[:, 0], pair_idxes[:, 1]
        return patchified[human_rows], patchified[object_rows]

    def prepare_human_pos(self, cls_tokens):
        """Append the learned ``human_pos`` vector to every row of ``cls_tokens``.

        In box modes ``cls_tokens`` is returned unchanged. Otherwise ``human_pos``
        is tiled ``cls_tokens.shape[0]`` times along its first axis and
        concatenated to ``cls_tokens`` on the last axis.
        """
        if self.is_box:
            return cls_tokens
        tiled = repeat(self.human_pos, "a b -> (repeat a) b", repeat=cls_tokens.shape[0])
        return torch.cat((cls_tokens, tiled), dim=-1)

    @property
    def is_box(self):
        """True when a union or intersection box encoder is in use."""
        return self.use_union_or_intersection in ("union", "intersection")

    def info_model(self):
        """Print the number of trainable parameters, in millions."""
        trainable = sum(p.numel() for p in self.parameters() if p.requires_grad)
        print(f"[Union ({self.use_union_or_intersection})] - {trainable / 10 ** 6:.2f}M")

    def forward(self, patchified_binary_mask_humans, patchified_binary_mask_objects):
        """Delegate to ``self.UI_box``; only valid in ``union``/``intersection`` mode."""
        return self.UI_box(patchified_binary_mask_humans, patchified_binary_mask_objects)

class UnionBoxEncoder(nn.Module):
    """Encode a human/object pair from the element-wise sum of their patchified masks.

    Args:
        windows_size: time-window length handed to ``TempMLP``.
        patch_size: ``patch_size ** 2`` is the per-frame input width of ``TempMLP``;
            also forwarded to ``PatchBlender``.
        embedding_dimension: output embedding width.
        device: stored on ``self.device``; not otherwise used by this class.
        img_size: forwarded to ``PatchBlender`` and stored as ``original_size``.
        drop_rate, use_mlp, only_spatial: forwarded to ``TempMLP``.
    """

    def __init__(self, windows_size=6, patch_size=16, embedding_dimension=384, device="cuda:0",
                 img_size=(224, 224), drop_rate=0.2, use_mlp=False, only_spatial=False):
        super().__init__()
        self.windows_size = windows_size
        self.num_patches = patch_size ** 2
        self.original_size = img_size
        self.device = device
        # Submodules are registered patch_blender first, then mlp, which fixes
        # their initialisation order and state_dict ordering.
        self.patch_blender = PatchBlender(img_size=img_size, patch_size=patch_size, flatten_embedding=True)
        self.mlp = TempMLP(
            self.num_patches,
            embedding_dimension,
            windows_size,
            drop_rate=drop_rate,
            use_mlp=use_mlp,
            only_spatial=only_spatial,
        )

    def forward(self, patchified_binary_mask_humans, patchified_binary_mask_objects):
        """Embed the union of paired human and object masks.

        The inputs are already aligned per pair, so both are expected to have
        the same shape. They are summed element-wise and embedded by ``self.mlp``.

        Args:
            patchified_binary_mask_humans: patchified masks of the humans.
            patchified_binary_mask_objects: patchified masks of the paired objects.

        Returns:
            The ``self.mlp`` embedding of the summed masks.
        """
        return self.mlp(patchified_binary_mask_humans + patchified_binary_mask_objects)


class IntersectionBoxEncoder(nn.Module):
    """Encode a human/object pair from the element-wise product of their patchified masks.

    Args:
        windows_size: time-window length handed to ``TempMLP``.
        patch_size: ``patch_size ** 2`` is the per-frame input width of ``TempMLP``;
            also forwarded to ``PatchBlender``.
        embedding_dimension: output embedding width.
        device: stored on ``self.device``; not otherwise used by this class.
        img_size: forwarded to ``PatchBlender`` and stored as ``original_size``.
        drop_rate, use_mlp, only_spatial: forwarded to ``TempMLP``.
    """

    def __init__(self, windows_size=6, patch_size=16, embedding_dimension= 384, device ="cuda:0", img_size=(224,224), drop_rate = 0.2, use_mlp=False, only_spatial=False):
        super().__init__()
        self.device = device
        self.original_size = img_size
        self.num_patches = patch_size**2
        self.windows_size = windows_size
        self.patch_blender = PatchBlender(img_size=img_size, patch_size=patch_size, flatten_embedding=True)
        # Forward the caller's ``only_spatial`` (it was previously hard-coded to
        # False, so the argument was silently ignored), matching UnionBoxEncoder.
        self.mlp = TempMLP(self.num_patches, embedding_dimension, windows_size, drop_rate=drop_rate,
                           use_mlp=use_mlp, only_spatial=only_spatial)

    def forward(self, patchified_binary_mask_humans, patchified_binary_mask_objects):
        """Embed the intersection of paired human and object masks.

        The inputs are already aligned per pair, so both are expected to have
        the same shape. They are multiplied element-wise and embedded by ``self.mlp``.

        Args:
            patchified_binary_mask_humans: patchified masks of the humans.
            patchified_binary_mask_objects: patchified masks of the paired objects.

        Returns:
            The ``self.mlp`` embedding of the multiplied masks.
        """
        patchified_intersection_mask = patchified_binary_mask_humans * patchified_binary_mask_objects
        embeddings = self.mlp(patchified_intersection_mask)
        return embeddings

class TempMLP(nn.Module):
    """Project per-frame patch features and optionally pool them over time.

    ``spatial_layer`` maps the last axis from ``num_patches`` to ``embeding_dim``.
    Unless ``only_spatial`` is set, the result is then normalised with ``LN``,
    the time axis (``t`` in ``"b t p"``) is collapsed to a single step by a
    linear layer, and the output is flattened to ``(b, embeding_dim)``.

    Args:
        num_patches: width of the input's last axis.
        embeding_dim: embedding width produced by ``spatial_layer``.
        windows_time: length of the time axis consumed by ``temp_layer``.
        drop_rate: dropout probability (only used when ``use_mlp`` is True).
        use_mlp: if True, ``spatial_layer`` is Linear -> Dropout -> ReLU -> Linear
            instead of a single Linear.
        only_spatial: if True, skip normalisation and temporal pooling and return
            the projected features directly.
    """

    def __init__(self, num_patches, embeding_dim, windows_time, drop_rate=0.2, use_mlp=False, only_spatial=False):
        super().__init__()
        # The input projection is always built first and the optional MLP tail
        # after it; this fixes parameter-init order and the spatial_layer.N keys.
        layers = [nn.Linear(num_patches, embeding_dim)]
        if use_mlp:
            layers.extend([nn.Dropout(drop_rate), nn.ReLU(), nn.Linear(embeding_dim, embeding_dim)])
            self.spatial_layer = nn.Sequential(*layers)
        else:
            self.spatial_layer = layers[0]

        self.only_spatial = only_spatial
        if only_spatial:
            return

        # Swap (b, t, p) -> (b, p, t) so temp_layer mixes the time axis, then
        # flatten the single remaining time step away.
        self.space2time = Rearrange("b t p -> b p t")
        self.time2space = Rearrange("b p t -> b (t p)")
        self.layernorm = LN(embeding_dim)
        self.temp_layer = nn.Linear(windows_time, 1)

    def forward(self, x):
        embedded = self.spatial_layer(x)
        if self.only_spatial:
            return embedded
        normalized = self.layernorm(embedded)
        pooled = self.temp_layer(self.space2time(normalized))
        return self.time2space(pooled)

class LN(nn.Module):
    """Layer normalisation over the last axis with a learnable scale and shift.

    Computes ``(x - mean) / sqrt(var + epsilon) * alpha + beta`` where mean and
    the (biased) variance are taken over the last axis. ``alpha``/``beta`` are
    stored with shape ``(1, 1, dim)`` (kept so existing state_dicts still load)
    but are applied as ``(dim,)`` vectors, so the output always has the same
    shape as the input, for any input rank >= 1.

    Args:
        dim: size of the input's last axis.
        epsilon: added to the variance for numerical stability.
    """

    def __init__(self, dim, epsilon=1e-5):
        super().__init__()
        self.epsilon = epsilon

        self.alpha = nn.Parameter(torch.ones([1, 1, dim]), requires_grad=True)
        self.beta = nn.Parameter(torch.zeros([1, 1, dim]), requires_grad=True)

    def forward(self, x):
        mean = x.mean(dim=-1, keepdim=True)
        var = ((x - mean) ** 2).mean(dim=-1, keepdim=True)
        std = (var + self.epsilon).sqrt()
        y = (x - mean) / std
        # Broadcasting against the raw (1, 1, dim) parameters would promote a
        # 1-D/2-D input to 3-D (e.g. (N, dim) -> (1, N, dim)). Flattening them
        # applies the affine along the last axis only; for rank >= 3 inputs the
        # result is identical to broadcasting the stored shape.
        return y * self.alpha.reshape(-1) + self.beta.reshape(-1)


if __name__ == '__main__':
    # Smoke test: run the union encoder on random masks and print the shapes.
    # NumPatches = 256 matches patch_size ** 2 (16 ** 2), the encoder's input width.
    B, T, NumPatches = 16, 6, 256
    # Both masks are drawn uniformly from [0, 1). The human mask previously used
    # torch.randn, whose negative values are not valid mask entries and were
    # inconsistent with the object mask.
    patchified_binary_mask_humans_ = torch.rand(B, T, NumPatches)
    patchified_binary_mask_objects_ = torch.rand_like(patchified_binary_mask_humans_)
    print("patchified_binary_mask_humans_: ", patchified_binary_mask_humans_.shape)
    print("patchified_binary_mask_objects_: ", patchified_binary_mask_objects_.shape)
    UIbox_union = UnionBoxEncoder(windows_size=T, patch_size=16, embedding_dimension=384, device="cuda:0",
                                  img_size=(224, 224), drop_rate=0.2)
    # Inference-only demo: no autograd graph is needed.
    with torch.no_grad():
        output = UIbox_union(patchified_binary_mask_humans_, patchified_binary_mask_objects_)
    print("output: ", output.shape)
