"""
Swap fusion (window / grid attention) and cross-attention fusion modules.
"""
import torch
from einops import rearrange
from torch import nn, einsum
from einops.layers.torch import Rearrange, Reduce

from opencood.models.base_transformer import FeedForward, PreNormResidual

# swap attention -> max_vit
class Attention(nn.Module):
    """
    Multi-head self-attention over the tokens of one 3D window
    (agent x window height x window width) with a learned relative
    position bias. Todo: mask is not added yet.

    Parameters
    ----------
    dim: int
        Input feature dimension. Must be divisible by ``dim_head``.
    dim_head: int
        The head dimension.
    dropout: float
        Dropout rate applied after the output projection.
    agent_size: int
        The agent can be different views, timestamps or vehicles.
    window_size: int
        Side length of the square spatial window.
    """

    def __init__(
            self,
            dim,
            dim_head=32,
            dropout=0.,
            agent_size=6,
            window_size=7
    ):
        super().__init__()
        assert (dim % dim_head) == 0, \
            'dimension should be divisible by dimension per head'

        self.heads = dim // dim_head
        self.scale = dim_head ** -0.5
        self.window_size = [agent_size, window_size, window_size]

        self.to_qkv = nn.Linear(dim, dim * 3, bias=False)
        self.attend = nn.Sequential(
            nn.Softmax(dim=-1)
        )

        self.to_out = nn.Sequential(
            nn.Linear(dim, dim, bias=False),
            nn.Dropout(dropout)
        )

        # one learnable bias per head for every possible 3D offset
        self.relative_position_bias_table = nn.Embedding(
            (2 * self.window_size[0] - 1) *
            (2 * self.window_size[1] - 1) *
            (2 * self.window_size[2] - 1),
            self.heads)  # 2*Wd-1 * 2*Wh-1 * 2*Ww-1, nH

        self.register_buffer(
            "relative_position_index",
            self._build_relative_position_index(self.window_size))

    @staticmethod
    def _build_relative_position_index(window_size):
        """
        Build the pair-wise relative position index for each token
        inside a window.

        Parameters
        ----------
        window_size: sequence of 3 ints
            (Wd, Wh, Ww) extent of the window.

        Returns
        -------
        torch.Tensor
            Integer tensor of shape (Wd*Wh*Ww, Wd*Wh*Ww) whose entry
            [i, j] is the row of the bias table for the offset between
            token i and token j.
        """
        size_d, size_h, size_w = window_size
        # Wd*Wh*Ww, 3 -- rows are enumerated with the last axis varying
        # fastest, i.e. the same order as a flattened 'ij' meshgrid.
        # cartesian_prod is used instead of torch.meshgrid so that no
        # `indexing` deprecation warning is raised on recent PyTorch.
        coords = torch.cartesian_prod(torch.arange(size_d),
                                      torch.arange(size_h),
                                      torch.arange(size_w))
        # Wd*Wh*Ww, Wd*Wh*Ww, 3
        relative_coords = coords[:, None, :] - coords[None, :, :]
        # shift every axis so that it starts from 0
        offsets = torch.tensor([size_d - 1, size_h - 1, size_w - 1])
        # mixed-radix strides turn the 3 shifted offsets into one index
        strides = torch.tensor([(2 * size_h - 1) * (2 * size_w - 1),
                                2 * size_w - 1,
                                1])
        # Wd*Wh*Ww, Wd*Wh*Ww
        return ((relative_coords + offsets) * strides).sum(-1)

    def forward(self, x):
        """
        Parameters
        ----------
        x: torch.Tensor
            Shape (b, l, h, w, w_h, w_w, c).

        Returns
        -------
        torch.Tensor
            Attended features with the same shape as ``x``.
        """
        batch, agent_size, height, width, window_height, window_width, _ \
            = x.shape

        # flatten: every (b, x, y) window becomes one sequence of tokens
        x = rearrange(x, 'b l x y w1 w2 d -> (b x y) (l w1 w2) d')
        # project for queries, keys, values
        q, k, v = self.to_qkv(x).chunk(3, dim=-1)
        # split heads
        q, k, v = map(
            lambda t: rearrange(t, 'b n (h d) -> b h n d', h=self.heads),
            (q, k, v))
        # scale
        q = q * self.scale

        # sim
        sim = einsum('b h i d, b h j d -> b h i j', q, k)

        # add positional bias
        bias = self.relative_position_bias_table(self.relative_position_index)
        sim = sim + rearrange(bias, 'i j h -> h i j')

        # attention
        attn = self.attend(sim)
        # aggregate
        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        # merge heads
        out = rearrange(out, 'b h (l w1 w2) d -> b l w1 w2 (h d)',
                        l=agent_size, w1=window_height, w2=window_width)

        # combine heads out
        out = self.to_out(out)
        return rearrange(out, '(b x y) l w1 w2 d -> b l x y w1 w2 d',
                         b=batch, x=height, y=width)

class CrossAttention(nn.Module):
    """
    Unit Cross Attention class: multi-head attention between separately
    supplied queries, keys and values inside one 3D window
    (agent x window height x window width), with a learned relative
    position bias.

    Parameters
    ----------
    dim: int
        Input feature dimension. Must be divisible by ``dim_head``.
    dim_head: int
        The head dimension.
    dropout: float
        Dropout rate applied after the output projection.
    agent_size: int
        The agent can be different views, timestamps or vehicles.
    window_size: int
        Side length of the square spatial window.
    """

    def __init__(
            self,
            dim,
            dim_head=32,
            dropout=0.,
            agent_size=6,
            window_size=7
    ):
        super().__init__()
        assert (dim % dim_head) == 0, \
            'dimension should be divisible by dimension per head'

        self.heads = dim // dim_head
        self.scale = dim_head ** -0.5
        self.window_size = [agent_size, window_size, window_size]

        # NOTE: forward() never applies this projection (q, k and v are
        # used as given). The layer is kept so that the parameter set and
        # the state_dict keys of this module stay unchanged.
        self.to_qkv = nn.Linear(dim, dim * 3, bias=False)
        self.attend = nn.Sequential(
            nn.Softmax(dim=-1)
        )

        self.to_out = nn.Sequential(
            nn.Linear(dim, dim, bias=False),
            nn.Dropout(dropout)
        )

        # one learnable bias per head for every possible 3D offset
        self.relative_position_bias_table = nn.Embedding(
            (2 * self.window_size[0] - 1) *
            (2 * self.window_size[1] - 1) *
            (2 * self.window_size[2] - 1),
            self.heads)  # 2*Wd-1 * 2*Wh-1 * 2*Ww-1, nH

        self.register_buffer(
            "relative_position_index",
            self._build_relative_position_index(self.window_size))

    @staticmethod
    def _build_relative_position_index(window_size):
        """
        Build the pair-wise relative position index for each token
        inside a window.

        Parameters
        ----------
        window_size: sequence of 3 ints
            (Wd, Wh, Ww) extent of the window.

        Returns
        -------
        torch.Tensor
            Integer tensor of shape (Wd*Wh*Ww, Wd*Wh*Ww) whose entry
            [i, j] is the row of the bias table for the offset between
            token i and token j.
        """
        size_d, size_h, size_w = window_size
        # Wd*Wh*Ww, 3 -- rows are enumerated with the last axis varying
        # fastest, i.e. the same order as a flattened 'ij' meshgrid.
        # cartesian_prod is used instead of torch.meshgrid so that no
        # `indexing` deprecation warning is raised on recent PyTorch.
        coords = torch.cartesian_prod(torch.arange(size_d),
                                      torch.arange(size_h),
                                      torch.arange(size_w))
        # Wd*Wh*Ww, Wd*Wh*Ww, 3
        relative_coords = coords[:, None, :] - coords[None, :, :]
        # shift every axis so that it starts from 0
        offsets = torch.tensor([size_d - 1, size_h - 1, size_w - 1])
        # mixed-radix strides turn the 3 shifted offsets into one index
        strides = torch.tensor([(2 * size_h - 1) * (2 * size_w - 1),
                                2 * size_w - 1,
                                1])
        # Wd*Wh*Ww, Wd*Wh*Ww
        return ((relative_coords + offsets) * strides).sum(-1)

    def forward(self, q, k, v):
        """
        Parameters
        ----------
        q, k, v: torch.Tensor
            Each of shape (b, l, h, w, w_h, w_w, c); q, k, v maintain
            the same shape.

        Returns
        -------
        torch.Tensor
            Attended features with the same shape as ``q``.
        """
        batch, agent_size, height, width, window_height, window_width, _ \
            = q.shape

        # flatten queries, keys, values: every (b, x, y) window becomes
        # one sequence of tokens
        flatten = 'b l x y w1 w2 d -> (b x y) (l w1 w2) d'
        q = rearrange(q, flatten)
        k = rearrange(k, flatten)
        v = rearrange(v, flatten)

        # split heads
        q, k, v = map(
            lambda t: rearrange(t, 'b n (h d) -> b h n d', h=self.heads),
            (q, k, v))
        # scale
        q = q * self.scale

        # sim
        sim = einsum('b h i d, b h j d -> b h i j', q, k)

        # add positional bias
        bias = self.relative_position_bias_table(self.relative_position_index)
        sim = sim + rearrange(bias, 'i j h -> h i j')

        # attention
        attn = self.attend(sim)
        # aggregate
        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        # merge heads
        out = rearrange(out, 'b h (l w1 w2) d -> b l w1 w2 (h d)',
                        l=agent_size, w1=window_height, w2=window_width)

        # combine heads out
        out = self.to_out(out)
        return rearrange(out, '(b x y) l w1 w2 d -> b l x y w1 w2 d',
                         b=batch, x=height, y=width)


class CrossAttentionBlock(nn.Module):
    """
    Two cross-attention stages applied to a query feature map against a
    key feature map: the first partitions the map into contiguous
    windows, the second into a strided grid. Each stage is followed by
    a feed-forward layer, both with residual connections.

    Parameters
    ----------
    input_dim: int
        Feature (channel) dimension.
    mlp_dim: int
        Hidden dimension passed to the feed-forward layers.
    dim_head: int
        The head dimension.
    window_size: int
        Side length of the square partition window.
    agent_size: int
        Size of the agent axis of the inputs.
    drop_out: float
        Dropout rate.
    """

    def __init__(
            self,
            input_dim,
            mlp_dim,
            dim_head,
            window_size,
            agent_size,
            drop_out
    ):
        super().__init__()
        self.window_size = window_size
        # a single LayerNorm is shared by every normalisation below
        self.norm = nn.LayerNorm(input_dim)
        self.cross_attn1 = CrossAttention(input_dim, dim_head, drop_out, agent_size, window_size)
        self.ff1 = FeedForward(input_dim, mlp_dim, drop_out)
        self.cross_attn2 = CrossAttention(input_dim, dim_head, drop_out, agent_size, window_size)
        self.ff2 = FeedForward(input_dim, mlp_dim, drop_out)

    def forward(self, q, k):
        """
        Parameters
        ----------
        q: torch.Tensor
            Query features, shape (b, m, d, H, W) with H and W divisible
            by the window size.
        k: torch.Tensor
            Key features, same layout as ``q``. They are also used as
            the values.

        Returns
        -------
        torch.Tensor
            Updated query features, shape (b, m, d, H, W).
        """
        ws = self.window_size

        # First layer: attention inside contiguous windows.
        window_split = 'b m d (x w1) (y w2) -> b m x y w1 w2 d'
        q = rearrange(q, window_split, w1=ws, w2=ws)
        # Keys and values are the same data, so they are partitioned and
        # normalised once and passed twice (no clone needed: nothing
        # modifies them in place).
        kv = self.norm(rearrange(k, window_split, w1=ws, w2=ws))
        q = self.cross_attn1(self.norm(q), kv, kv) + q
        q = self.ff1(self.norm(q)) + q
        q = rearrange(q, 'b m x y w1 w2 d -> b m d (x w1) (y w2)')

        # Second layer: attention across a strided grid. The keys are
        # re-partitioned straight from the untouched input ``k``.
        grid_split = 'b m d (w1 x) (w2 y) -> b m x y w1 w2 d'
        q = rearrange(q, grid_split, w1=ws, w2=ws)
        kv = self.norm(rearrange(k, grid_split, w1=ws, w2=ws))
        x = self.cross_attn2(self.norm(q), kv, kv) + q
        x = self.ff2(self.norm(x)) + x
        x = rearrange(x, 'b m x y w1 w2 d -> b m d (w1 x) (w2 y)')

        return x


class SwapFusionBlock(nn.Module):
    """
    Swap Fusion Block: a window-attention stage followed by a
    grid-attention stage, each made of attention + feed-forward wrapped
    in pre-norm residuals.
    """

    def __init__(self,
                 input_dim,
                 mlp_dim,
                 dim_head,
                 window_size,
                 agent_size,
                 drop_out):
        super(SwapFusionBlock, self).__init__()

        def make_stage(split_pattern, merge_pattern):
            # partition -> attention -> feed forward -> merge back
            return [
                Rearrange(split_pattern, w1=window_size, w2=window_size),
                PreNormResidual(
                    input_dim,
                    Attention(input_dim, dim_head, drop_out,
                              agent_size, window_size)),
                PreNormResidual(
                    input_dim,
                    FeedForward(input_dim, mlp_dim, drop_out)),
                Rearrange(merge_pattern),
            ]

        # b = batch * max_cav
        # window attention: tokens come from contiguous patches
        window_stage = make_stage(
            'b m d (x w1) (y w2) -> b m x y w1 w2 d',
            'b m x y w1 w2 d -> b m d (x w1) (y w2)')
        # grid attention: tokens come from a strided grid
        grid_stage = make_stage(
            'b m d (w1 x) (w2 y) -> b m x y w1 w2 d',
            'b m x y w1 w2 d -> b m d (w1 x) (w2 y)')

        self.block = nn.Sequential(*window_stage, *grid_stage)

    def forward(self, x):
        # todo: add mask operation later for multi-agents
        return self.block(x)


class CrossAttnFusion(nn.Module):
    """
    Data rearrange -> swap block -> mlp_head

    The first stage is a CrossAttentionBlock between the point features
    and the image feature; every later stage concatenates the running
    result with the next point feature along the agent axis and runs a
    SwapFusionBlock. Between stages the agent axis is reduced by a 1x1x1
    convolution (from the second stage on) and the spatial resolution is
    doubled.

    Parameters
    ----------
    args: dict
        Must contain 'depth', 'input_dim', 'mlp_dim', 'agent_size',
        'window_size', 'drop_out' and 'dim_head'. 'mask' is optional
        (default False); it is stored but not used by this class.
        'agent_size' is the agent-axis size seen by the swap blocks,
        i.e. after concatenation; the cross-attention stage is built
        with ``agent_size // 2``.
    """

    def __init__(self, args):
        super(CrossAttnFusion, self).__init__()

        self.layers = nn.ModuleList([])
        self.depth = args['depth']

        # block related
        input_dim = args['input_dim']
        mlp_dim = args['mlp_dim']
        agent_size = args['agent_size']
        window_size = args['window_size']
        drop_out = args['drop_out']
        dim_head = args['dim_head']

        self.mask = False
        if 'mask' in args:
            self.mask = args['mask']

        for i in range(self.depth):
            if i == 0:  # We use cross attention for the first layer
                block = CrossAttentionBlock(input_dim, mlp_dim,
                                            dim_head, window_size,
                                            agent_size // 2, drop_out)
            else:
                block = SwapFusionBlock(input_dim,
                                        mlp_dim,
                                        dim_head,
                                        window_size,
                                        agent_size,
                                        drop_out)
            self.layers.append(block)

        # mlp head
        self.mlp_head = nn.Sequential(
            Reduce('b m d h w -> b d h w', 'mean'),
            Rearrange('b d h w -> b h w d'),
            nn.LayerNorm(input_dim),
            nn.Linear(input_dim, input_dim),
            Rearrange('b h w d -> b d h w')
        )
        # Squeezes the concatenated agent axis (agent_size) back to the
        # per-source size (agent_size // 2); the agent axis is treated
        # as the Conv3d channel axis. The sizes were previously fixed to
        # 4 -> 2, which is what agent_size == 4 still produces.
        self.conv1x1 = nn.Conv3d(in_channels=agent_size,
                                 out_channels=agent_size // 2,
                                 kernel_size=(1, 1, 1),
                                 stride=(1, 1, 1),
                                 padding=(0, 0, 0))
        self.upsample_opt = nn.Sequential(
            nn.Conv2d(input_dim, input_dim, 3, 1, 1),
            nn.BatchNorm2d(input_dim),
            nn.ReLU(True),
            nn.Upsample(scale_factor=2, mode='nearest'),
            nn.Conv2d(input_dim, input_dim, 3, 1, 1),
            nn.BatchNorm2d(input_dim),
            nn.ReLU(True)
        )

    def forward(self, pts_features, img_feature):
        """
        Parameters
        ----------
        pts_features: sequence of torch.Tensor
            Point features, one per stage, each of shape (b, m, c, h, w).
            They are consumed back to front: stage ``i`` reads
            ``pts_features[depth - 1 - i]``.
        img_feature: torch.Tensor
            Image feature of shape (b, m, c, h, w), used as key/value of
            the first (cross-attention) stage.

        Returns
        -------
        torch.Tensor
            Fused feature of shape (b, c, h, w), averaged over the agent
            axis by the mlp head.
        """
        x = img_feature
        last_stage = self.depth - 1
        for i, stage in enumerate(self.layers):
            pts_feature = pts_features[last_stage - i]

            if i == 0:
                x = stage(pts_feature, img_feature)
            else:
                x = torch.cat((x, pts_feature), dim=1)
                x = stage(x)

            if i < last_stage:  # TODO: The last layer does not need upsample
                if i > 0:
                    # undo the doubling of the agent axis caused by cat
                    x = self.conv1x1(x)
                b, n = x.shape[:2]
                # fold agents into the batch for the 2D upsampling convs
                x = rearrange(x, 'b n c h w -> (b n) c h w')
                x = self.upsample_opt(x)
                x = rearrange(x, '(b n) c h w -> b n c h w', b=b, n=n)

        return self.mlp_head(x)


if __name__ == "__main__":
    # Smoke test: run CrossAttnFusion once on random features.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # agent_size is the agent-axis size after the point and fused
    # features are concatenated (2 + 2 here); it also matches the
    # 4 -> 2 channel reduction between stages.
    args = {'input_dim': 128,
            'mlp_dim': 256,
            'agent_size': 4,
            'window_size': 8,
            'dim_head': 32,
            'drop_out': 0.1,
            'depth': 3,
            'mask': False,
            'upsample_conv': {
                'in_channels': 128,
                'out_channels': 128,
                'kernel_size': 2,
                'stride': 2,
                'padding': 0
            }}

    # the fusion module defined in this file
    fusion_net = CrossAttnFusion(args)
    fusion_net.to(device)

    # (batch, agents, channels, height, width); point features are
    # listed from the highest to the lowest resolution because the
    # network consumes them back to front
    img_features = torch.rand(4, 2, 128, 32, 32)
    pts_features = [torch.rand(4, 2, 128, 128, 128),
                    torch.rand(4, 2, 128, 64, 64),
                    torch.rand(4, 2, 128, 32, 32)]
    img_features = img_features.to(device)
    pts_features = [pts_feature.to(device) for pts_feature in pts_features]

    output = fusion_net(pts_features, img_features)
    print(output)
