"""
This class is about swap fusion applications
"""
import torch
from einops import rearrange
from torch import nn, einsum
from einops.layers.torch import Rearrange, Reduce

from opencood.models.base_transformer import FeedForward, PreNormResidual

# swap attention -> max_vit
class Attention(nn.Module):
    """
    Unit Attention class. Todo: mask is not added yet.

    Multi-head self-attention over the tokens of one 3D window
    (agent x window height x window width). A learned relative position
    bias, looked up per pair of tokens, is added to the attention logits.

    Parameters
    ----------
    dim: int
        Input feature dimension. Must be divisible by ``dim_head``.
    dim_head: int
        The head dimension.
    dropout: float
        Dropout rate applied after the output projection.
    agent_size: int
        The agent can be different views, timestamps or vehicles.
    window_size: int
        Side length of the square spatial window.
    """

    def __init__(
            self,
            dim,
            dim_head=32,
            dropout=0.,
            agent_size=6,
            window_size=7
    ):
        super().__init__()
        assert (dim % dim_head) == 0, \
            'dimension should be divisible by dimension per head'

        self.heads = dim // dim_head
        self.scale = dim_head ** -0.5
        self.window_size = [agent_size, window_size, window_size]

        self.to_qkv = nn.Linear(dim, dim * 3, bias=False)
        self.attend = nn.Sequential(
            nn.Softmax(dim=-1)
        )

        self.to_out = nn.Sequential(
            nn.Linear(dim, dim, bias=False),
            nn.Dropout(dropout)
        )

        self.relative_position_bias_table = nn.Embedding(
            (2 * self.window_size[0] - 1) *
            (2 * self.window_size[1] - 1) *
            (2 * self.window_size[2] - 1),
            self.heads)  # 2*Wd-1 * 2*Wh-1 * 2*Ww-1, nH

        # get pair-wise relative position index for
        # each token inside the window
        coords_d = torch.arange(self.window_size[0])
        coords_h = torch.arange(self.window_size[1])
        coords_w = torch.arange(self.window_size[2])
        # 3, Wd*Wh*Ww
        # cartesian_prod enumerates (d, h, w) triples in row-major order,
        # i.e. the same order as flattening an 'ij'-indexed meshgrid, but
        # without relying on torch.meshgrid's deprecated default indexing.
        coords_flatten = torch.cartesian_prod(coords_d, coords_h, coords_w).t()

        # 3, Wd*Wh*Ww, Wd*Wh*Ww
        relative_coords = \
            coords_flatten[:, :, None] - coords_flatten[:, None, :]
        # Wd*Wh*Ww, Wd*Wh*Ww, 3
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()
        # shift to start from 0
        relative_coords[:, :, 0] += self.window_size[0] - 1
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 2] += self.window_size[2] - 1

        # turn the 3 shifted offsets into one flat index into the bias table
        relative_coords[:, :, 0] *= \
            (2 * self.window_size[1] - 1) * (2 * self.window_size[2] - 1)
        relative_coords[:, :, 1] *= (2 * self.window_size[2] - 1)
        relative_position_index = relative_coords.sum(-1)  # Wd*Wh*Ww, Wd*Wh*Ww
        self.register_buffer("relative_position_index",
                             relative_position_index)

    def forward(self, x):
        """
        Apply attention inside every window.

        Parameters
        ----------
        x: torch.Tensor
            7-D tensor laid out as (b, l, x, y, w1, w2, d): batch, agents,
            window grid rows/cols, in-window rows/cols, channels. The number
            of tokens per window (l * w1 * w2) has to match the size of the
            relative position index built in ``__init__``.

        Returns
        -------
        torch.Tensor
            Tensor with the same layout as ``x``.
        """
        # x shape: b, l, h, w, w_h, w_w, c
        batch, agent_size, height, width, window_height, window_width, _ \
            = x.shape
        h = self.heads

        # flatten
        x = rearrange(x, 'b l x y w1 w2 d -> (b x y) (l w1 w2) d')
        # project for queries, keys, values
        q, k, v = self.to_qkv(x).chunk(3, dim=-1)
        # split heads
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h),
                      (q, k, v))
        # scale
        q = q * self.scale

        # sim
        sim = einsum('b h i d, b h j d -> b h i j', q, k)

        # add positional bias
        bias = self.relative_position_bias_table(self.relative_position_index)
        sim = sim + rearrange(bias, 'i j h -> h i j')

        # attention
        attn = self.attend(sim)
        # aggregate
        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        # merge heads
        out = rearrange(out, 'b h (l w1 w2) d -> b l w1 w2 (h d)',
                        l=agent_size, w1=window_height, w2=window_width)

        # combine heads out
        out = self.to_out(out)
        return rearrange(out, '(b x y) l w1 w2 d -> b l x y w1 w2 d',
                         b=batch, x=height, y=width)

class SwapMultiscaleFusionBlock(nn.Module):
    """
    Swap Fusion Block contains window attention and grid attention.

    The input is a 5-D tensor laid out as (b, m, d, H, W), where H and W
    are split by ``window_size``. The first half partitions the map into
    contiguous ``window_size`` x ``window_size`` windows; the second half
    uses the strided (grid) partition. Each half runs a pre-norm residual
    attention followed by a pre-norm residual feed-forward, then restores
    the (b, m, d, H, W) layout.

    Parameters
    ----------
    input_dim: int
        Channel dimension of the features.
    mlp_dim: int
        Hidden dimension handed to the feed-forward layers.
    dim_head: int
        Per-head dimension handed to the attention layers.
    window_size: int
        Side length of the window / grid partition.
    agent_size: int
        Number of agents handed to the attention layers.
    drop_out: float
        Dropout rate for attention and feed-forward layers.
    """

    def __init__(self,
                 input_dim,
                 mlp_dim,
                 dim_head,
                 window_size,
                 agent_size,
                 drop_out):
        super().__init__()

        partition = dict(w1=window_size, w2=window_size)

        def attention_then_mlp():
            # One pre-norm residual attention + one pre-norm residual MLP.
            attention = PreNormResidual(
                input_dim,
                Attention(input_dim, dim_head, drop_out,
                          agent_size, window_size))
            mlp = PreNormResidual(
                input_dim,
                FeedForward(input_dim, mlp_dim, drop_out))
            return attention, mlp

        # b = batch * max_cav
        stages = [
            # window partition: neighbouring pixels share a window
            Rearrange('b m d (x w1) (y w2) -> b m x y w1 w2 d', **partition),
            *attention_then_mlp(),
            Rearrange('b m x y w1 w2 d -> b m d (x w1) (y w2)'),

            # grid partition: pixels of a window are strided over the map
            Rearrange('b m d (w1 x) (w2 y) -> b m x y w1 w2 d', **partition),
            *attention_then_mlp(),
            Rearrange('b m x y w1 w2 d -> b m d (w1 x) (w2 y)'),
        ]
        self.block = nn.Sequential(*stages)

    def forward(self, x):
        """Run window attention then grid attention on ``x``."""
        # todo: add mask operation later for mulit-agents
        return self.block(x)


class SwapMultiscaleFusionEncoder(nn.Module):
    """
    Data rearrange -> swap block -> mlp_head

    Runs ``depth`` swap fusion stages. Before each stage the running feature
    is concatenated with one entry of ``pts_features`` along the agent
    dimension (dim 1). After every stage except the last, a 1x1x1 Conv3d
    maps the agent dimension from ``agent_size`` to ``conv_agent`` and the
    spatial resolution is doubled.

    Parameters
    ----------
    args: dict
        Required keys: 'depth', 'input_dim', 'mlp_dim', 'agent_size',
        'window_size', 'drop_out', 'dim_head', 'rsu_num'.
        Optional keys: 'is_lidar2cam' (treated as False when absent) and
        'mask' (defaults to False).
    """

    def __init__(self, args):
        super().__init__()

        self.layers = nn.ModuleList([])
        self.depth = args['depth']

        # block related
        input_dim = args['input_dim']
        mlp_dim = args['mlp_dim']
        agent_size = args['agent_size']
        window_size = args['window_size']
        drop_out = args['drop_out']
        dim_head = args['dim_head']
        rsu_num = args['rsu_num']

        # A missing key means "not lidar2cam". The explicit key check keeps
        # unrelated errors visible instead of silently swallowing them.
        # Equality with True (rather than truthiness) is kept on purpose so
        # that values such as non-empty strings do not enable this mode.
        is_lidar2cam = \
            'is_lidar2cam' in args and args['is_lidar2cam'] == True  # noqa: E712
        # number of agents kept by the 1x1x1 conv between stages
        conv_agent = rsu_num if is_lidar2cam else agent_size - rsu_num

        self.mask = False
        if 'mask' in args:
            self.mask = args['mask']

        for _ in range(self.depth):
            block = SwapMultiscaleFusionBlock(input_dim,
                                              mlp_dim,
                                              dim_head,
                                              window_size,
                                              agent_size,
                                              drop_out)
            self.layers.append(block)

        # One (agent-reduction, upsample) pair per transition between stages.
        upsample_layers = []
        conv1x1_layers = []
        for _ in range(self.depth - 1):
            upsample_block = nn.Sequential(
                nn.Conv2d(input_dim, input_dim, 3, 1, 1),
                nn.BatchNorm2d(input_dim),
                nn.ReLU(True),
                nn.Upsample(scale_factor=2, mode='nearest'),
                nn.Conv2d(input_dim, input_dim, 3, 1, 1),
                nn.BatchNorm2d(input_dim),
                nn.ReLU(True)
            )

            # The agent axis is used as the Conv3d channel axis.
            conv1x1_block = nn.Conv3d(in_channels=agent_size,
                                      out_channels=conv_agent,
                                      kernel_size=(1, 1, 1),
                                      stride=(1, 1, 1),
                                      padding=(0, 0, 0))
            upsample_layers.append(upsample_block)
            conv1x1_layers.append(conv1x1_block)

        self.upsample_layers = nn.ModuleList(upsample_layers)
        self.conv1x1_layers = nn.ModuleList(conv1x1_layers)
        # mlp head
        self.mlp_head = nn.Sequential(
            Reduce('b m d h w -> b d h w', 'mean'),
            Rearrange('b d h w -> b h w d'),
            nn.LayerNorm(input_dim),
            nn.Linear(input_dim, input_dim),
            Rearrange('b h w d -> b d h w')
        )

    def forward(self, pts_features, img_feature):
        """
        Fuse the image feature with the multiscale point features.

        Parameters
        ----------
        pts_features: sequence of torch.Tensor
            Indexed in reverse: stage ``i`` consumes
            ``pts_features[depth - 1 - i]``. Each entry must be
            concatenable with the running feature along dim 1.
            NOTE(review): presumably ordered from highest to lowest
            resolution, each (b, m, d, h, w) -- confirm against callers.
        img_feature: torch.Tensor
            Initial feature, concatenated with the last used entry of
            ``pts_features`` before the first stage.

        Returns
        -------
        torch.Tensor
            Output of ``mlp_head``: the agent dimension is averaged out,
            giving a (b, d, h, w) tensor.
        """
        x = img_feature
        last_stage = self.depth - 1
        for i, stage in enumerate(self.layers):
            pts_feature = pts_features[last_stage - i]

            x = torch.cat([x, pts_feature], dim=1)
            x = stage(x)

            # x = self.mlp_head(x)
            if i < last_stage:  # TODO: The last layer does not need upsample
                x = self.conv1x1_layers[i](x)
                b, n = x.shape[:2]
                # fold agents into the batch so the 2D upsample block applies
                x = rearrange(x, 'b n c h w -> (b n) c h w')
                x = self.upsample_layers[i](x)
                x = rearrange(x, '(b n) c h w -> b n c h w', b=b, n=n)

        return self.mlp_head(x)

if __name__ == "__main__":
    # Smoke test: build the encoder and push random features through it.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    args = {'input_dim': 128,
            'mlp_dim': 256,
            'agent_size': 2,
            # Required by SwapMultiscaleFusionEncoder. With agent_size=2 the
            # inter-stage conv keeps 1 agent, so concatenating one point
            # feature brings the agent count back to 2 for the next stage.
            'rsu_num': 1,
            'window_size': 8,
            'dim_head': 32,
            'drop_out': 0.1,
            'depth': 3,
            'mask': False,
            # not read by SwapMultiscaleFusionEncoder in this file
            'upsample_conv': {
                'in_channels': 128,
                'out_channels': 128,
                'kernel_size': 2,
                'stride': 2,
                'padding': 0
            }}

    fusion_net = SwapMultiscaleFusionEncoder(args)
    fusion_net.to(device)

    img_features = torch.rand(4, 1, 128, 32, 32)
    # The encoder consumes this list from the last entry to the first.
    pts_features = [torch.rand(4, 1, 128, 128, 128),
                    torch.rand(4, 1, 128, 64, 64),
                    torch.rand(4, 1, 128, 32, 32)]
    img_features = img_features.to(device)
    pts_features = [pts_feature.to(device) for pts_feature in pts_features]

    output = fusion_net(pts_features, img_features)
    print(output)
