import torch
import torch.nn as nn
from einops import rearrange, reduce, repeat


class InteractionHead(nn.Module):
    """Turn transformer tokens into per-entry HOI action/spatial logits.

    ``forward`` picks one feature vector for every output entry of every window,
    according to ``hoi_feature``:

    * ``"learnable"`` / ``"learnable_head"``: token 0 of the window (where
      ``prepend_token`` places the learnable interaction token).
    * ``"linear"``: tokens from ``idx_windows`` onwards, fused over time by a
      ``Linear(sliding_window, 1)`` (so there must be exactly ``sliding_window``
      such tokens).
    * ``"pooling"``: mean of the tokens from ``idx_windows`` onwards.
    * ``"last"``: the tokens aligned with the last ``sum(window_out)`` frames.
    * ``"last_both"``: like ``"last"``, concatenating ``mainbranch`` and
      ``secondbranch`` features.

    The gathered features feed linear (or MLP) action and spatial heads.
    """

    def __init__(self, hoi_feature, cls_token_to, embedding_dimension, sliding_window, num_action_classes, num_spatial_classes, do_regression=False, no_reduction=False, just_the_passed=False,
                 idx_windows=1, as_mlp=False, half_prepend=False, multiple_heads=None):
        """Build the head.

        Args:
            hoi_feature: feature-selection mode (see class docstring).
            cls_token_to: ``"objects"`` prepends the learnable token to the
                object sequence; any other value prepends it to the human one.
            embedding_dimension: base token width. It is multiplied by 2 (or 3
                with ``no_reduction``) unless ``just_the_passed`` is True.
            sliding_window: number of frame tokens fused in ``"linear"`` mode.
            num_action_classes: output width of the action head(s).
            num_spatial_classes: output width of the spatial head(s).
            do_regression: build a 2-output ``regression_layer``.
            no_reduction: use ``3 * embedding_dimension`` instead of ``2 *``.
            just_the_passed: use ``embedding_dimension`` as-is.
            idx_windows: number of leading non-frame tokens in each window
                sequence. Frame tokens start at this index.
            as_mlp: use ``MLP`` heads instead of single linear layers.
            half_prepend: give the learnable token half the embedding width.
            multiple_heads: optional iterable of keys. Each key gets its own
                action/spatial linear head.
        """
        super().__init__()
        self.hoi_feature = hoi_feature
        self.cls_token_to = cls_token_to
        self.index_window = idx_windows
        self.sliding_window = sliding_window
        self.num_action_classes = num_action_classes
        self.num_spatial_classes = num_spatial_classes
        self.multiple_heads = multiple_heads

        if just_the_passed:
            embed_dim = embedding_dimension
        else:
            embed_dim = 3 * embedding_dimension if no_reduction else 2 * embedding_dimension
        # Width before any mode-specific widening. return_params reports this
        # because it also sets just_the_passed=True and the constructor
        # re-applies the "last_both" doubling.
        self._passed_embedding_dim = embed_dim

        if hoi_feature == "learnable":
            hoi_tok_emb = embed_dim // 2 if half_prepend else embed_dim
            self.interaction_token = nn.Parameter(torch.zeros(1, 1, hoi_tok_emb))
        elif hoi_feature == "learnable_head":
            self.interaction_token = None
        elif hoi_feature == "linear":
            # Fuses the time axis (sliding_window frame tokens) into one token.
            self.fuse_interaction = nn.Linear(in_features=sliding_window, out_features=1)
        elif hoi_feature in ("pooling", "last"):
            self.fuse_interaction = nn.Identity()
        elif hoi_feature == "last_both":
            # mainbranch and secondbranch features are concatenated.
            embed_dim = embed_dim * 2

        self.embedding_dim = embed_dim

        if as_mlp:
            # NOTE(review): do_heads indexes the heads by key when multiple_heads
            # is set, which these MLP heads do not support.
            self.action_head = MLP(embed_dim=embed_dim, out_features=num_action_classes)
            self.spatial_head = MLP(embed_dim=embed_dim, out_features=num_spatial_classes)
            self.regression_layer = MLP(embed_dim=embed_dim, out_features=2) if do_regression else nn.Identity()
        elif multiple_heads is None:
            self.action_head = nn.Linear(in_features=embed_dim, out_features=num_action_classes)
            self.spatial_head = nn.Linear(in_features=embed_dim, out_features=num_spatial_classes)
            self.regression_layer = nn.Linear(in_features=embed_dim, out_features=2) if do_regression else nn.Identity()
        else:
            self.action_head = nn.ModuleDict({f"{future_num}": nn.Linear(in_features=embed_dim, out_features=num_action_classes) for future_num in multiple_heads})
            self.spatial_head = nn.ModuleDict({f"{future_num}": nn.Linear(in_features=embed_dim, out_features=num_spatial_classes) for future_num in multiple_heads})
            self.regression_layer = nn.Linear(in_features=embed_dim, out_features=2) if do_regression else nn.Identity()

        self.do_regression = do_regression

    def return_params(self):
        """Return constructor kwargs describing this head's main configuration.

        ``embedding_dimension`` is the width before the ``"last_both"`` doubling,
        so ``InteractionHead(**head.return_params())`` rebuilds heads of the same
        size. Other options (``do_regression``, ``idx_windows``, ``as_mlp``,
        ``half_prepend``, ``multiple_heads``) are not included.
        """
        return {"hoi_feature": self.hoi_feature,
                "cls_token_to": self.cls_token_to,
                "embedding_dimension": self._passed_embedding_dim,
                "sliding_window": self.sliding_window,
                "num_action_classes": self.num_action_classes,
                "num_spatial_classes": self.num_spatial_classes,
                "just_the_passed": True,
                }

    def info_model(self):
        """Print the number of trainable parameters in millions."""
        total_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
        print(f"[Interaction (Head)] - {total_params / 10 ** 6:.2f}M")

    def prepend_token(self, window_objects, window_humans):
        """Prepend the learnable token to the object or human sequence (per ``cls_token_to``)."""
        if self.cls_token_to == "objects":
            window_objects = self.learnable_hoi(window_objects)
        else:
            window_humans = self.learnable_hoi(window_humans)
        return window_objects, window_humans

    def learnable_hoi(self, transformer_input):
        """Prepend ``interaction_token`` along dim 1 in ``"learnable"`` mode; otherwise return the input unchanged."""
        if self.hoi_feature == "learnable":
            batch = transformer_input.shape[0]
            prepend = repeat(self.interaction_token, "b t e -> (repeat b) t e", repeat=batch)
            transformer_input = torch.cat([prepend, transformer_input], dim=1)
        return transformer_input

    @staticmethod
    def _gather_outputs(reference, feature_dim, windows, windows_out, im_idxes, select):
        """Scatter per-window features onto output entries and keep only those entries.

        ``select(idx_window, window, window_out)`` returns a tensor that
        broadcasts to ``(window_out.sum(), feature_dim)``. If several windows
        claim the same entry, the last one wins. The result keeps entry order.
        """
        # new_zeros keeps the buffer on the same device/dtype as the features.
        hoi_token = reference.new_zeros((len(im_idxes), feature_dim))
        output_mask = torch.zeros_like(im_idxes, dtype=torch.bool, device=reference.device)
        for idx_window, (window, window_out) in enumerate(zip(windows, windows_out)):
            hoi_token[window_out, :] = select(idx_window, window, window_out)
            output_mask = output_mask | window_out
        return hoi_token[output_mask]

    def _last_tokens(self, branch, idx_window, window, window_out):
        """Return ``branch`` tokens for the last ``sum(window_out)`` frames of a window.

        Frame tokens start at ``self.index_window``, so the window's final frame
        is at position ``sum(window) + self.index_window - 1``.
        """
        slice_len = torch.sum(window) + self.index_window
        out_len = torch.sum(window_out)
        return branch[idx_window, slice_len - out_len: slice_len]

    def forward(self, mainbranch, secondbranch, windows=None, windows_out=None, im_idxes=None):
        """Gather one feature per predicted entry and run the heads.

        Args:
            mainbranch: transformer output indexed as ``mainbranch[window, token]``
                (presumably ``(num_windows, tokens, dim)`` — TODO confirm).
            secondbranch: second output with the same layout. Only read in
                ``"last_both"`` mode.
            windows: one boolean mask per window over the entries of
                ``im_idxes``, marking the entries in that window.
            windows_out: one boolean mask per window marking the entries it
                predicts.
            im_idxes: tensor with one element per entry. Only its length and
                shape are used.

        Returns:
            dict with ``"spatial_head"`` and ``"action_head"`` outputs for the
            entries covered by any ``windows_out`` mask. Each output is a
            tensor, or a dict keyed like ``multiple_heads``. ``"object_pos":
            None`` is added when ``do_regression`` is False.

        Raises:
            ValueError: if ``hoi_feature`` is not a supported mode.
        """
        feature_dim = mainbranch.shape[-1]
        if self.hoi_feature in ("learnable", "learnable_head"):
            hoi_token = self._gather_outputs(
                mainbranch, feature_dim, windows, windows_out, im_idxes,
                lambda idx_window, window, window_out: mainbranch[idx_window, 0])
        elif self.hoi_feature == "linear":
            # Skip the leading (semantic/cls) tokens, then fuse over time:
            # (windows, t, e) -> (windows, e, t) -> Linear(t -> 1) -> (windows, e).
            frames = mainbranch[:, self.index_window:].transpose(1, 2)
            fused = self.fuse_interaction(frames)[..., 0]
            hoi_token = self._gather_outputs(
                mainbranch, feature_dim, windows, windows_out, im_idxes,
                lambda idx_window, window, window_out: fused[idx_window])
        elif self.hoi_feature == "pooling":
            # Skip the leading (semantic/cls) tokens and average over time.
            # NOTE(review): every token from index_window onwards is averaged.
            pooled = mainbranch[:, self.index_window:].mean(dim=1)
            hoi_token = self._gather_outputs(
                mainbranch, feature_dim, windows, windows_out, im_idxes,
                lambda idx_window, window, window_out: pooled[idx_window])
        elif self.hoi_feature == "last":
            hoi_token = self._gather_outputs(
                mainbranch, feature_dim, windows, windows_out, im_idxes,
                lambda idx_window, window, window_out: self._last_tokens(mainbranch, idx_window, window, window_out))
        elif self.hoi_feature == "last_both":
            hoi_token = self._gather_outputs(
                mainbranch, 2 * feature_dim, windows, windows_out, im_idxes,
                lambda idx_window, window, window_out: torch.cat(
                    [self._last_tokens(mainbranch, idx_window, window, window_out),
                     self._last_tokens(secondbranch, idx_window, window, window_out)], dim=-1))
        else:
            raise ValueError(f"Unsupported hoi_feature: {self.hoi_feature!r}")

        action_hoi, spatial_hoi = self.do_heads(hoi_token)

        # NOTE(review): regression_layer is never applied, and with do_regression
        # the "object_pos" key is omitted rather than filled — confirm intent.
        if self.do_regression:
            return {"spatial_head": spatial_hoi, "action_head": action_hoi}
        return {"spatial_head": spatial_hoi, "action_head": action_hoi, "object_pos": None}

    def do_heads(self, hoi_token):
        """Apply the action and spatial heads (one pair per ``multiple_heads`` key when set)."""
        if self.multiple_heads is None:
            return self.action_head(hoi_token), self.spatial_head(hoi_token)
        action_hoi = {}
        spatial_hoi = {}
        for future_num in self.multiple_heads:
            action_hoi[future_num] = self.action_head[f"{future_num}"](hoi_token)
            # Spatial logits must come from the spatial heads, not the action heads.
            spatial_hoi[future_num] = self.spatial_head[f"{future_num}"](hoi_token)
        return action_hoi, spatial_hoi

class MLP(nn.Module):
    """Two-layer perceptron: Linear -> ReLU -> Dropout -> Linear."""

    def __init__(self, embed_dim, out_features, embed_dim_intermediate = 512, drop_rate=0.2):
        super().__init__()
        layers = [
            nn.Linear(in_features=embed_dim, out_features=embed_dim_intermediate),
            nn.ReLU(),
            nn.Dropout(drop_rate),
            nn.Linear(in_features=embed_dim_intermediate, out_features=out_features),
        ]
        # A single Sequential named ``mlp`` keeps state-dict keys as mlp.0.* / mlp.3.*.
        self.mlp = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the perceptron to the last dimension of ``x``."""
        stack = self.mlp
        return stack(x)

if __name__ == '__main__':
    # Smoke test: run the "last" head on synthetic, well-formed windows and
    # print the output shapes.
    embedding_dimension = 768
    sliding_window = 6
    num_action_classes = 52
    num_spatial_classes = 8
    hoi_feature = "last"
    cls_token_to = "humans"
    P = 138  # number of windows
    t = sliding_window
    S = P * t  # total entries; window p owns entries [p * t, (p + 1) * t)

    i_head = InteractionHead(hoi_feature, cls_token_to, embedding_dimension, sliding_window, num_action_classes, num_spatial_classes)
    i_head.info_model()
    # Without just_the_passed the head expects 2 * embedding_dimension features.
    E = i_head.embedding_dim

    # forward expects boolean masks over entries, one row per window.
    W = torch.zeros(P, S, dtype=torch.bool)
    WO = torch.zeros(P, S, dtype=torch.bool)
    for p in range(P):
        W[p, p * t:(p + 1) * t] = True
        WO[p, (p + 1) * t - 1] = True  # predict only the last frame of each window
    T = torch.randn(P, t + 1, E)  # one leading token + t frame tokens per window
    im_idxes = torch.arange(S)

    out = i_head(T, None, windows=W, windows_out=WO, im_idxes=im_idxes)
    print({key: None if value is None else tuple(value.shape) for key, value in out.items()})
