import torch
import torch.nn as nn


def drop_path(x, drop_prob: float = 0., training: bool = False):
    """Randomly zero entire samples of ``x`` (stochastic depth).

    Each sample along the first dimension is kept with probability
    ``1 - drop_prob`` and, when kept, rescaled by ``1 / (1 - drop_prob)`` so the
    expected value of the output equals ``x``.

    Args:
        x: Input tensor; its first dimension is treated as the sample axis.
        drop_prob: Probability of dropping a sample. ``None`` is treated as 0
            (``DropPath`` uses ``None`` as its default).
        training: Samples are dropped only when True; otherwise ``x`` is
            returned unchanged.

    Returns:
        Tensor with the same shape and dtype as ``x``.
    """
    if drop_prob is None or drop_prob == 0. or not training:
        return x
    if drop_prob >= 1.:
        # Every sample is dropped. The generic path below would compute
        # x / 0 * 0, which yields NaN instead of zeros. Multiplying by 0 keeps
        # the autograd graph connected, just like the p < 1 path.
        return x * 0.
    keep_prob = 1 - drop_prob
    # One random draw per sample, broadcast over all remaining dims, so this
    # works for tensors of any rank (not just 4-D conv activations).
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    random_tensor = keep_prob + \
        torch.rand(shape, dtype=x.dtype, device=x.device)
    random_tensor.floor_()  # binarize: 1 with probability keep_prob, else 0
    return x.div(keep_prob) * random_tensor


class DropPath(nn.Module):
    """Stochastic-depth module: drops whole samples of its input while training.

    Meant for the main path of residual blocks. All of the work is delegated to
    ``drop_path``, using this module's ``training`` flag to decide whether any
    sample is dropped.
    """

    def __init__(self, drop_prob=None):
        super().__init__()
        # Probability of dropping a sample; passed through to drop_path as-is.
        self.drop_prob = drop_prob

    def forward(self, x):
        is_training = self.training
        return drop_path(x, drop_prob=self.drop_prob, training=is_training)


class CrossAttention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., compression=2, supervised=False):
        """
            Cross-attention module.

            Queries are projected from ``x`` and keys/values from ``context``.
            The q/k/v projections reduce the feature size from ``dim`` to
            ``dim // compression``, and ``proj`` maps the attended features back
            to ``dim``.

            Args:
                dim: Feature size (last dimension) of x and context.
                num_heads: Number of attention heads.
                qkv_bias: Whether the q and kv projections use a bias.
                qk_scale: Optional override of the attention-logit scale.
                attn_drop: Dropout probability on the attention weights
                    (unsupervised mode only).
                proj_drop: Dropout probability after the output projection.
                compression: Reduction factor for the q/k/v feature size.
                supervised: If True, ``forward`` builds the attention matrix
                    from group ids instead of from q·k similarity.
        """
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # NOTE(review): the scale uses dim // num_heads, not the compressed
        # per-head size actually used by q/k ((dim // num_heads) // compression).
        # Kept unchanged so existing checkpoints behave the same.
        self.scale = qk_scale or head_dim ** -0.5
        self.compression = compression

        self.proj_kv = nn.Linear(
            dim, (dim // self.compression) * 2, bias=qkv_bias)
        self.proj_q = nn.Linear(dim, dim // self.compression, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim // self.compression, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self.supervised = supervised

    def _supervised_attention(self, n_query, n_support, n_pos, support_group_ids, query_group_ids, like):
        """Build the fixed attention matrix used by supervised CXDA.

        A query attends only to the spatial positions of support examples that
        share its group id, each with weight
        1 / (n_support * n_pos / number_of_groups). Every other entry is 0.
        The matrix is allocated with ``like``'s dtype and device, so it can be
        multiplied with the value tensor directly (float64/fp16 models included).

        Args:
            n_query: Number of queries (B).
            n_support: Number of support examples (S).
            n_pos: Spatial positions per example (H * W).
            support_group_ids: 1-D tensor with the S support group ids.
            query_group_ids: 1-D tensor with the B query group ids.
            like: Tensor whose dtype/device the result adopts (the values v).

        Returns:
            Tensor of shape (n_query, num_heads, n_pos, n_support * n_pos).
        """
        device = like.device
        attn = like.new_zeros(n_query, self.num_heads, n_pos, n_support * n_pos)
        # Put the masks on the same device as the tensors they index. Building
        # the row on CPU while the ids live on the GPU would fail.
        support_group_ids = support_group_ids.to(device)
        query_group_ids = query_group_ids.to(device)
        group_ids = support_group_ids.unique().tolist()
        if not group_ids:
            # No support examples: nothing to attend to.
            return attn
        # NOTE(review): each row sums to 1 only when all groups contain the same
        # number of support examples. Kept unchanged to preserve behavior.
        attn_val = 1.0 / (n_support * n_pos / len(group_ids))
        for group_id in group_ids:
            # Expand the per-example mask to one entry per spatial position (S*H*W).
            support_mask = (support_group_ids == group_id).repeat_interleave(n_pos)
            row = like.new_zeros(n_support * n_pos)
            row[support_mask] = attn_val
            attn[query_group_ids == group_id] = row
        return attn

    def forward(self, x, context, train=True, support_group_ids=None, query_group_ids=None):
        """
            Perform image-to-image cross-attention between query (x) and support examples (context).

            Information about query and support domains (groups) is only used for supervised CXDA.

            Args:
                x: Query features of shape (B, H, W, C) with C == dim
                    (BlockCXDA passes Q, 1, 1, HWC).
                context: Support features of shape (S, H, W, C), with the same
                    H, W and C as x.
                train: Unused; kept for interface compatibility.
                support_group_ids: 1-D tensor of S group ids (supervised mode only).
                query_group_ids: 1-D tensor of B group ids (supervised mode only).

            Returns:
                (out, attn). out has shape (B, H, W, C); attn has shape
                (B, num_heads, H*W, S*H*W).

            Raises:
                ValueError: In supervised mode, when either group-id tensor is None.
        """
        B, H, W, C = x.shape
        S = context.shape[0]
        # Per-head feature size after compression.
        head_c = (C // self.num_heads) // self.compression

        context = context.flatten(start_dim=0, end_dim=2)  # SHW, C
        kv = self.proj_kv(context)  # SHW, 2 * (C // compression)
        kv = kv.reshape(S * H * W, 2, self.num_heads, head_c)
        kv = kv.permute(1, 2, 0, 3)  # 2, heads, SHW, c
        k, v = kv[0], kv[1]  # heads, SHW, c

        if self.supervised:
            # Supervised CXDA ignores q·k similarity and derives the attention
            # matrix from the group labels instead.
            if support_group_ids is None or query_group_ids is None:
                raise ValueError(
                    'supervised CrossAttention requires support_group_ids and query_group_ids')
            attn = self._supervised_attention(
                B, S, H * W, support_group_ids, query_group_ids, like=v)
        else:
            q = self.proj_q(x.flatten(start_dim=1, end_dim=2))  # B, HW, C // compression
            q = q.reshape(B, H * W, self.num_heads, head_c).transpose(1, 2)  # B, heads, HW, c
            attn = (q @ k.transpose(-2, -1)) * self.scale  # B, heads, HW, SHW
            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)

        out = (attn @ v).transpose(1, 2).reshape(B, H, W, C // self.compression)
        out = self.proj(out)
        out = self.proj_drop(out)

        return out, attn


class BlockCXDA(nn.Module):
    def __init__(self, dim, spatial_dim, num_heads=8, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, has_mlp=False, supervised=False, compression=2):
        """
            CXDA block that performs image-to-image cross-attention between query (x) and support examples (context).

            Every (C, H, W) feature map is flattened into a single vector of
            length dim * spatial_dim ** 2, so each query image attends to whole
            support images rather than to individual positions.

            Args:
                dim: Channel count of the feature maps.
                spatial_dim: Spatial size used to compute the flattened length
                    dim * spatial_dim ** 2 (inputs must satisfy C*H*W == that length).
                num_heads, qkv_bias, qk_scale, attn_drop, compression, supervised:
                    Forwarded to CrossAttention.
                drop: Forwarded to CrossAttention as proj_drop.
                drop_path: Stochastic-depth probability on the attention branch;
                    0 disables it.
                norm_layer: Normalization applied, with shared weights, to both
                    flattened inputs.
                mlp_ratio, act_layer, has_mlp: Accepted but not used by this block.
        """
        super().__init__()
        flat_dim = dim * spatial_dim ** 2  # length of one flattened C*H*W map
        self.norm1 = norm_layer(flat_dim)
        self.attn = CrossAttention(
            flat_dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
            compression=compression,
            supervised=supervised,
        )
        if drop_path > 0.:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()

    def forward(self, x, context, return_attention=False, train=True, support_group_ids=None, query_group_ids=None):
        """
            Perform image-to-image cross-attention between query (x) and support examples (context).

            Information about query and support domains (groups) is only used for supervised CXDA.

            Args:
                x: Query feature maps, (Q, C, H, W).
                context: Support feature maps, (S, C, H, W).
                return_attention: If True, return only the attention weights.
                train, support_group_ids, query_group_ids: Passed to self.attn unchanged.

            Returns:
                The attention weights if return_attention is True; otherwise a
                (Q, C, H, W) tensor equal to x + drop_path(attention output).
        """
        # Channels-last views: (Q, H, W, C) and (S, H, W, C).
        query = x.permute(0, 2, 3, 1)
        support = context.permute(0, 2, 3, 1)

        # One token per image: (Q, 1, 1, HWC) and (S, 1, 1, HWC). The unit
        # dims play the role of height and width inside CrossAttention.
        query_tokens = query.flatten(1)[:, None, None, :]
        support_tokens = support.flatten(1)[:, None, None, :]

        y, attn = self.attn(
            self.norm1(query_tokens),
            self.norm1(support_tokens),
            train=train,
            support_group_ids=support_group_ids,
            query_group_ids=query_group_ids,
        )
        if return_attention:
            return attn

        # Back to (Q, H, W, C) for the residual sum.
        residual = query_tokens.view(query.shape)
        update = y.reshape(query.shape)
        merged = residual + self.drop_path(update)
        return merged.permute(0, 3, 1, 2)  # Q, C, H, W


if __name__ == '__main__':
    # Smoke test: 20 query and 100 support feature maps of shape (128, 4, 4).
    x = torch.rand(20, 128, 4, 4)
    context = torch.rand(100, 128, 4, 4)
    tf_block = BlockCXDA(128, 4)
    with torch.no_grad():
        out = tf_block(x, context)
    # The block is residual, so it must preserve the input shape.
    assert out.shape == x.shape, (tuple(out.shape), tuple(x.shape))
    print('BlockCXDA output shape:', tuple(out.shape))
