from inspect import signature

import math
import warnings
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F

from mmcv.cnn import xavier_init
from mmcv.runner import force_fp32, auto_fp16
from mmcv.utils import build_from_cfg
from mmcv.cnn.bricks.registry import PLUGIN_LAYERS
from mmdet.models import (
    DETECTORS,
    BaseDetector,
    build_backbone,
    build_head,
    build_neck,
    build_loss,
)
from einops.layers.torch import Rearrange
from einops import rearrange, repeat

from .grid_mask import GridMask
# from .attention import FlowAttention, FlowCrossAttention, trunc_normal_
from .world_model_hd import RSSM

# Optional compiled op: ``feature_maps_format`` is only needed when the
# detector is built with ``use_deformable_func=True``.  The import stays
# best-effort, but only ``Exception`` subclasses are swallowed so that
# KeyboardInterrupt / SystemExit still propagate.
try:
    from ..ops import feature_maps_format
    DAF_VALID = True
except Exception:
    DAF_VALID = False


def _trunc_normal_(tensor, mean, std, a, b):
    """Fill ``tensor`` in place with truncated-normal samples and return it.

    Inverse-CDF sampling: uniform values are drawn between the (rescaled)
    normal CDF values of the two cut-offs, pushed through ``erfinv``,
    scaled/shifted to ``mean``/``std`` and finally clamped to ``[a, b]``.
    Method based on
    https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf

    Args:
        tensor: tensor that is overwritten in place.
        mean: mean of the underlying normal distribution.
        std: standard deviation of the underlying normal distribution.
        a: lower cut-off.
        b: upper cut-off.

    Returns:
        The same ``tensor`` object.
    """
    sqrt_two = math.sqrt(2.)

    def _std_normal_cdf(value):
        # CDF of the standard normal distribution.
        return (1. + math.erf(value / sqrt_two)) / 2.

    mean_too_low = mean < a - 2 * std
    mean_too_high = mean > b + 2 * std
    if mean_too_low or mean_too_high:
        warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
                      "The distribution of values may be incorrect.",
                      stacklevel=2)

    # CDF values of the two cut-offs under N(mean, std^2).
    cdf_low = _std_normal_cdf((a - mean) / std)
    cdf_high = _std_normal_cdf((b - mean) / std)

    # uniform in [2l-1, 2u-1] -> erfinv -> scale by std*sqrt(2) -> shift by
    # mean -> clamp.  Every op is in place and returns ``tensor`` itself.
    tensor.uniform_(2 * cdf_low - 1, 2 * cdf_high - 1) \
        .erfinv_() \
        .mul_(std * sqrt_two) \
        .add_(mean) \
        .clamp_(min=a, max=b)
    return tensor


def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
    r"""Fill ``tensor`` in place from a truncated normal distribution.

    Samples follow :math:`\mathcal{N}(\text{mean}, \text{std}^2)` restricted
    to :math:`[a, b]`.  The work is delegated to ``_trunc_normal_`` and runs
    under ``torch.no_grad()`` so the in-place writes are not tracked by
    autograd.  The sampling works best when
    :math:`a \leq \text{mean} \leq b`.

    NOTE: the bounds ``[a, b]`` are applied in the same units as ``mean`` and
    ``std`` (they are not expressed in standard deviations), so adjust them
    together with ``mean``/``std``.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        mean: the mean of the normal distribution
        std: the standard deviation of the normal distribution
        a: the minimum cutoff value
        b: the maximum cutoff value

    Returns:
        The input ``tensor``, filled in place.

    Examples:
        >>> w = torch.empty(3, 5)
        >>> trunc_normal_(w)
    """
    with torch.no_grad():
        filled = _trunc_normal_(tensor, mean, std, a, b)
    return filled


class FeedForward(nn.Module):
    """Two-layer MLP with a leading LayerNorm.

    Layer order: LayerNorm -> Linear(dim, hidden_dim) -> GELU -> Dropout ->
    Linear(hidden_dim, dim) -> Dropout.  The output has the same trailing
    dimension as the input.
    """

    def __init__(self, dim, hidden_dim, dropout=0.):
        super().__init__()
        # Order matters: it fixes the ``net.<idx>`` keys in the state dict.
        stages = [
            nn.LayerNorm(dim),
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Apply the MLP to the last dimension of ``x``."""
        transformed = self.net(x)
        return transformed

class FlowAttention(nn.Module):
    """Self-attention over the second-to-last axis of the input.

    Two modes, selected by ``channel_fuse``:

    * ``channel_fuse=False``: multi-head attention.  The projections have
      width ``heads * dim_head`` and are split into ``heads`` heads, so the
      input must be 3-D ``(b, n, dim)`` for the einops pattern to match.
    * ``channel_fuse=True``: the projections have width ``dim_head`` and are
      NOT split into heads; attention is a plain batched matmul over the last
      two axes, whatever leading axes the input has.

    Args:
        dim: channel width of the input (and of the output).
        heads: number of heads (only used for splitting when
            ``channel_fuse=False``).
        dim_head: width per head.
        dropout: dropout probability on attention weights and projections.
        with_res: if True, add a residual FeedForward + LayerNorm on top.
        channel_fuse: see above.
    """

    def __init__(self, dim, heads = 4, dim_head = 64, dropout = 0., with_res=False, channel_fuse=False):
        super().__init__()
        self.heads = heads
        self.scale = dim_head ** -0.5
        self.channel_fuse = channel_fuse
        self.with_res = with_res

        inner_dim = dim_head if channel_fuse else dim_head * heads
        # An output projection is skipped only for a single head whose width
        # already equals ``dim``.
        needs_projection = not (heads == 1 and dim_head == dim)

        # NOTE: sub-modules are created in a fixed order so that state-dict
        # layout and RNG consumption stay stable.
        self.norm_q = nn.LayerNorm(dim)
        self.norm_k = nn.LayerNorm(dim)
        self.norm_v = nn.LayerNorm(dim)

        self.attend = nn.Softmax(dim = -1)
        self.dropout = nn.Dropout(dropout)

        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_k = nn.Linear(dim, inner_dim, bias = False)
        self.to_v = nn.Linear(dim, inner_dim, bias = False)

        if needs_projection:
            self.to_out = nn.Sequential(
                nn.Linear(inner_dim, dim),
                nn.Dropout(dropout)
            )
        else:
            self.to_out = nn.Identity()

        if with_res:
            self.res_mlp = FeedForward(dim, inner_dim, dropout)
            self.res_norm = nn.LayerNorm(dim)

    def _split_heads(self, t):
        # (b, n, heads*d) -> (b, heads, n, d)
        return rearrange(t, 'b n (h d) -> b h n d', h = self.heads)

    def forward(self, x, pos=None):
        """Run self-attention on ``x``.

        Args:
            x: input features; last axis has size ``dim``.
            pos: optional positional term added to the query and key inputs
                (not to the values).

        Returns:
            Tensor with the same shape as ``x``.
        """
        if pos is None:
            q_in, k_in, v_in = x, x, x
        else:
            q_in = x + pos
            k_in = x + pos
            v_in = x

        q = self.to_q(self.norm_q(q_in))
        k = self.to_k(self.norm_k(k_in))
        v = self.to_v(self.norm_v(v_in))

        multi_head = not self.channel_fuse
        if multi_head:
            q, k, v = (self._split_heads(t) for t in (q, k, v))

        scores = torch.matmul(q, k.transpose(-1, -2)) * self.scale
        weights = self.dropout(self.attend(scores))
        out = torch.matmul(weights, v)

        if multi_head:
            out = rearrange(out, 'b h n d -> b n (h d)')
        out = self.to_out(out)

        if self.with_res:
            out = self.res_norm(x + self.res_mlp(out+x))
        return out

class FlowCrossAttention(nn.Module):
    """Cross-attention from a query tensor to a separate key/value tensor.

    Two modes, selected by ``channel_fuse``:

    * ``channel_fuse=False``: standard multi-head attention on 3-D inputs
      ``(b, n, dim)``; projections of width ``heads * dim_head`` are split
      into ``heads`` heads.
    * ``channel_fuse=True``: inputs are 4-D ``(b, n, h, dim)`` whose third
      axis must have length ``heads``; that axis is moved in front of ``n``
      and attention runs over ``n`` independently for every ``h``.
      Projections have width ``dim_head``.

    Args:
        dim: channel width of the inputs (and of the output).
        heads: number of heads / required length of the ``h`` axis.
        dim_head: width per head.
        dropout: dropout probability on attention weights and projections.
        with_res: if True, add a residual FeedForward + LayerNorm on top.
        channel_fuse: see above.
    """

    def __init__(self, dim, heads = 4, dim_head = 64, dropout = 0., with_res=False, channel_fuse=False):
        super().__init__()
        self.heads = heads
        self.scale = dim_head ** -0.5
        self.channel_fuse = channel_fuse
        self.with_res = with_res

        inner_dim = dim_head if channel_fuse else dim_head * heads
        # An output projection is skipped only for a single head whose width
        # already equals ``dim``.
        needs_projection = not (heads == 1 and dim_head == dim)

        # NOTE: sub-modules are created in a fixed order so that state-dict
        # layout and RNG consumption stay stable.
        self.norm_q = nn.LayerNorm(dim)
        self.norm_k = nn.LayerNorm(dim)
        self.norm_v = nn.LayerNorm(dim)

        self.attend = nn.Softmax(dim=-1)
        self.dropout = nn.Dropout(dropout)

        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_k = nn.Linear(dim, inner_dim, bias = False)
        self.to_v = nn.Linear(dim, inner_dim, bias = False)

        if needs_projection:
            self.to_out = nn.Sequential(
                nn.Linear(inner_dim, dim),
                nn.Dropout(dropout)
            )
        else:
            self.to_out = nn.Identity()

        if with_res:
            self.res_mlp = FeedForward(dim, inner_dim, dropout)
            self.res_norm = nn.LayerNorm(dim)

    def forward(self, q_ini, kv_ini, q_pos=None, k_pos=None, flatten=True):
        """Attend from ``q_ini`` to ``kv_ini``.

        Args:
            q_ini: query features.
            kv_ini: features used for both keys and values.
            q_pos: optional positional term for the queries.
            k_pos: optional positional term for the keys.  Positional terms
                are applied only when BOTH ``q_pos`` and ``k_pos`` are given.
            flatten: only honoured when ``channel_fuse=True``; merges axes 2
                and 3 of the result.

        Returns:
            The attended queries (flattened as described above if requested).
        """
        use_pos = q_pos is not None and k_pos is not None
        if use_pos:
            q = q_ini + q_pos
            k = kv_ini + k_pos
        else:
            q, k = q_ini, kv_ini
        v = kv_ini

        q = self.to_q(self.norm_q(q))
        k = self.to_k(self.norm_k(k))
        v = self.to_v(self.norm_v(v))

        # The two modes differ only in how the head axis is exposed/merged.
        if self.channel_fuse:
            split_pattern = 'b n h d -> b h n d'
            merge_pattern = 'b h n d -> b n h d'
        else:
            split_pattern = 'b n (h d) -> b h n d'
            merge_pattern = 'b h n d -> b n (h d)'
        q, k, v = (rearrange(t, split_pattern, h = self.heads) for t in (q, k, v))

        scores = torch.matmul(q, k.transpose(-1, -2)) * self.scale
        weights = self.dropout(self.attend(scores))

        out = rearrange(torch.matmul(weights, v), merge_pattern)
        out = self.to_out(out)
        if self.with_res:
            out = self.res_norm(q_ini + self.res_mlp(out+q_ini))

        if self.channel_fuse and flatten:
            return out.flatten(2,3)
        return out


@DETECTORS.register_module()
class FlowAD(BaseDetector):
    """Multi-view detector with column-patch ("flow") feature refinement.

    Image features from the backbone/neck are cut into vertical column
    patches per camera view, mixed with neighbouring views through attention
    and small MLPs, enriched by temporal and spatial RSSM world models, and
    finally folded back into feature maps that are handed to ``head``.

    Args:
        img_backbone: config passed to ``build_backbone``.
        head: config passed to ``build_head``.
        img_neck: optional config passed to ``build_neck``.
        init_cfg: forwarded to ``BaseDetector.__init__``.
        train_cfg: accepted for config compatibility; unused in this class.
        test_cfg: accepted for config compatibility; unused in this class.
        pretrained: if not None, stored under ``img_backbone['pretrained']``
            before the backbone is built.
        use_grid_mask: apply ``GridMask`` to the images in ``extract_feat``.
        use_deformable_func: convert the output feature maps with
            ``feature_maps_format`` (requires the compiled ops).
        depth_branch: optional plugin-layer config for a depth branch.
        embed_dims: channel width of the attention / embedding modules.
            ``extract_feat`` reshapes positional encodings with the feature
            channel count, so feature maps are expected to have this many
            channels.
        flow_patches: patch width (in feature columns), indexed by feature
            level.
        flow_ids: feature-level indices that go through the flow operation.
        flow_attn_ids: feature-level indices that additionally run
            self-attention inside each padded patch window.
        flow_pad_single: number of neighbouring patches padded on each side.
        flow_num: number of column patches per view.
        spat_size: number of chunks each left/right spatial sequence is split
            into for the spatial world model.
        ego_dim: input width of the temporal positional-encoding MLP, which
            is applied to ``metas['ego_pose']``.
        loss_wm: config of the temporal world-model loss.
        loss_wm_spatial: config of the spatial world-model loss.
        flow_wm_ids: indices of flow levels that get world models (see the
            review note in ``extract_feat`` about how they are matched).
    """

    def __init__(
        self,
        img_backbone,
        head,
        img_neck=None,
        init_cfg=None,
        train_cfg=None,
        test_cfg=None,
        pretrained=None,
        use_grid_mask=True,
        use_deformable_func=False,
        depth_branch=None,
        embed_dims=256,
        flow_patches=[8,4,2,1],
        flow_ids=[0,1,2,3],
        flow_attn_ids=[2,3],
        flow_pad_single=1,
        flow_num=22,
        spat_size=6,
        ego_dim=18,
        loss_wm=dict(
            type='WMLoss',
            loss_weight=0.25,
        ),
        loss_wm_spatial=dict(
            type='WMLoss',
            loss_weight=0.25,
        ),
        flow_wm_ids = [2,3],
    ):
        super(FlowAD, self).__init__(init_cfg=init_cfg)
        if pretrained is not None:
            # Item assignment works for plain dicts and mmcv ConfigDict alike.
            img_backbone["pretrained"] = pretrained
        self.img_backbone = build_backbone(img_backbone)
        # Always define ``img_neck`` so ``extract_feat`` can test it for None.
        if img_neck is not None:
            self.img_neck = build_neck(img_neck)
        else:
            self.img_neck = None
        self.head = build_head(head)
        self.use_grid_mask = use_grid_mask
        if use_deformable_func:
            assert DAF_VALID, "deformable_aggregation needs to be set up."
        self.use_deformable_func = use_deformable_func
        if depth_branch is not None:
            self.depth_branch = build_from_cfg(depth_branch, PLUGIN_LAYERS)
        else:
            self.depth_branch = None
        if use_grid_mask:
            self.grid_mask = GridMask(
                True, True, rotate=1, offset=False, ratio=0.5, mode=1, prob=0.7
            )

        # Copy the list arguments so the shared default lists are never
        # aliased by an instance.
        self.flow_patches = list(flow_patches)
        self.flow_ids = list(flow_ids)
        self.flow_attn_ids = list(flow_attn_ids)
        self.flow_wm_ids = list(flow_wm_ids)
        self.flow_pad_single = flow_pad_single
        # Window width in patches: the patch itself plus padding on both sides.
        self.flow_pad_sum = 1 + 2 * flow_pad_single
        # Row i holds (view used for left padding, view used for right
        # padding) of view i; six views arranged in a ring.
        self.ids_pad_left_right = torch.tensor([[5,1],[0,2],[1,3],[2,4],[3,5],[4,0]], dtype=torch.long, requires_grad=False)
        self.flow_num = flow_num
        self.spat_size = spat_size
        # Used as the head count of ``patch_crossattn_spatial``; in
        # channel-fuse mode that forces the temporal axis to this length.
        self.frame_num = 1  # 8
        self.embed_dims = embed_dims

        # Per flow level: squeeze the padded window (pad_sum * P) back to P.
        self.flow_mlps = nn.ModuleList([
            nn.Linear(self.flow_pad_sum*self.flow_patches[flow_id], self.flow_patches[flow_id]) for flow_id in self.flow_ids
        ])
        self.flow_attns = nn.ModuleList([
            FlowAttention(dim=embed_dims, heads=4, dim_head=embed_dims//4, dropout=0.) for flow_attn_id in self.flow_attn_ids
        ])

        self.world_models_temporal = nn.ModuleList([
            RSSM(embedding_dim=embed_dims) for flow_wm_id in self.flow_wm_ids
        ])
        self.world_models_spatial = nn.ModuleList([
            RSSM(embedding_dim=embed_dims) for flow_wm_id in self.flow_wm_ids
        ])
        self.patch_selfattn = FlowAttention(dim=embed_dims, heads=self.flow_num, dim_head=embed_dims, dropout=0., with_res=False, channel_fuse=True)
        self.patch_crossattn_local = FlowCrossAttention(dim=embed_dims, heads=self.flow_num, dim_head=embed_dims, dropout=0., with_res=False, channel_fuse=True)

        self.patch_crossattn_temporal = FlowCrossAttention(dim=embed_dims, heads=self.flow_num, dim_head=embed_dims, dropout=0., with_res=False, channel_fuse=True)
        self.patch_crossattn_spatial = FlowCrossAttention(dim=embed_dims, heads=self.frame_num, dim_head=embed_dims, dropout=0., with_res=False, channel_fuse=True)

        # Fuse [features, world-model prediction] (2C channels) back to C.
        self.flow_wm_convs = nn.ModuleList([
            nn.Conv2d(embed_dims*2, embed_dims, kernel_size=3, padding=1) for flow_wm_id in self.flow_wm_ids
        ])
        self.flow_wm_convs_spatial = nn.ModuleList([
            nn.Conv2d(embed_dims*2, embed_dims, kernel_size=3, padding=1) for flow_wm_id in self.flow_wm_ids
        ])

        # Plain tensor attribute (not a buffer); moved to the feature device
        # on every forward pass.
        self.flow_query_xys = self.generate_flow_xys()
        # NOTE(review): ``xavier_init`` is handed the whole ``nn.Sequential``
        # in the three calls below; confirm against mmcv that this actually
        # reaches the inner Linear layers.
        self.flow_query_embed = nn.Sequential(
            nn.Linear(2, self.embed_dims // 2),
            nn.ReLU(inplace=True),
            nn.Linear(self.embed_dims // 2, self.embed_dims),
            nn.ReLU(inplace=True),
            nn.LayerNorm(self.embed_dims)
        )
        xavier_init(self.flow_query_embed, distribution='uniform', bias=0.)

        self.flow_spatial_pe_map = nn.Sequential(
            nn.Linear(2, self.embed_dims // 2),
            nn.ReLU(inplace=True),
            nn.Linear(self.embed_dims // 2, self.embed_dims),
            nn.ReLU(inplace=True),
            nn.LayerNorm(self.embed_dims)
        )
        xavier_init(self.flow_spatial_pe_map, distribution='uniform', bias=0.)

        self.flow_temporal_pe_map = nn.Sequential(
            nn.Linear(ego_dim, self.embed_dims // 2),
            nn.ReLU(inplace=True),
            nn.Linear(self.embed_dims // 2, self.embed_dims),
            nn.ReLU(inplace=True),
            nn.LayerNorm(self.embed_dims)
        )
        xavier_init(self.flow_temporal_pe_map, distribution='uniform', bias=0.)

        self.loss_wm_temporal = build_loss(loss_wm)
        self.loss_wm_spatial = build_loss(loss_wm_spatial)

    def generate_flow_xys(self,):
        """Build the 2-D reference point of every column patch.

        Six line segments (one per view) are laid out on the boundary of a
        hexagon-like shape with half-extent 51.2; each segment is sampled at
        the midpoints of ``flow_num`` equal sub-intervals.

        The midpoints are computed as ``start + (i + 0.5) * step`` from an
        integer index, which always yields exactly ``flow_num`` points (a
        float-step ``torch.arange`` can be off by one through rounding).

        Returns:
            Tensor of shape ``[6, flow_num, 2]`` holding ``(x, y)`` pairs in
            the default floating dtype.
        """
        pc_range = 51.2
        p_x = pc_range * math.tan(math.radians(30))
        flow_items = self.flow_num

        # ((x_start, x_end), (y_start, y_end)) per segment; presumably front,
        # front-right, back-right, back, back-left, front-left -- TODO confirm
        # against the camera order of the dataset.
        segments = [
            ((-p_x, p_x), (pc_range, pc_range)),
            ((p_x, pc_range), (pc_range, 0)),
            ((pc_range, p_x), (0, -pc_range)),
            ((-p_x, p_x), (-pc_range, -pc_range)),
            ((-p_x, -pc_range), (-pc_range, 0)),
            ((-pc_range, -p_x), (0, pc_range)),
        ]

        out_dtype = torch.get_default_dtype()
        # Sub-interval midpoints in units of one step: 0.5, 1.5, ...
        midpoints = torch.arange(flow_items, dtype=torch.float64) + 0.5

        flow_xys_all = []
        for (x_start, x_end), (y_start, y_end) in segments:
            step_x = (x_end - x_start) / flow_items
            step_y = (y_end - y_start) / flow_items
            # A zero step (constant coordinate) simply yields ``start``.
            xs = x_start + midpoints * step_x
            ys = y_start + midpoints * step_y
            flow_xys_all.append(torch.stack([xs, ys], dim=-1).to(out_dtype))

        return torch.stack(flow_xys_all, dim=0)

    @auto_fp16(apply_to=("img",), out_fp32=True)
    def extract_feat(self, img, return_depth=False, metas=None):
        """Extract image features and refine them with the flow operation.

        Shape legend used in the comments: ``B`` batch, ``T`` frames, ``N``
        views (6), ``C`` channels, ``F`` = ``flow_num``, ``P`` patch width of
        the current level, ``Y`` = ``W // P`` patches per view.

        Args:
            img: image tensor; a 5-D input is treated as multi-view
                ``(bs, num_cams, ...)`` and flattened over the first two axes.
            return_depth: also return the depth-branch output.
            metas: dict that must contain ``'ego_pose'`` (fed to the temporal
                positional-encoding MLP); ``'focal'`` is read when the depth
                branch runs.  Required despite the ``None`` default.

        Returns:
            ``(feature_maps, depths, outputs_wm_temporal, outputs_wm_spatial)``
            if ``return_depth`` else
            ``(feature_maps, outputs_wm_temporal, outputs_wm_spatial)``.
        """
        bs = img.shape[0]
        if img.dim() == 5:  # multi-view
            num_cams = img.shape[1]
            img = img.flatten(end_dim=1)
        else:
            num_cams = 1
        if self.use_grid_mask:
            img = self.grid_mask(img)
        if "metas" in signature(self.img_backbone.forward).parameters:
            feature_maps = self.img_backbone(img, num_cams, metas=metas)
        else:
            feature_maps = self.img_backbone(img)
        if self.img_neck is not None:
            feature_maps = self.img_neck(feature_maps)
        # Levels are re-assigned in place below, so a list is required.
        feature_maps = list(feature_maps)
        for i, feat in enumerate(feature_maps):
            feature_maps[i] = torch.reshape(
                feat, (bs, num_cams) + feat.shape[1:]
            )

        # ---------------- flow operation ----------------
        device = feature_maps[0].device
        B, TN, C, _, _ = feature_maps[0].shape
        N = 6  # number of views; must match the rows of ids_pad_left_right
        T = TN // N
        flow_spatial_pe = self.flow_spatial_pe_map(self.flow_query_xys.to(device)).unsqueeze(0)  # [1, N, F, C]
        flow_spatial_pe_local = rearrange(flow_spatial_pe.repeat(B*T, 1, 1, 1), 'b n y c -> (b n) y c')  # [BTN, F, C]

        # Pad the spatial PE with the neighbouring views' border patches so
        # that every patch sees a window of ``flow_pad_sum`` patches.
        flow_spatial_pe_3p = flow_spatial_pe_local.transpose(1,2).reshape(B, T, N, C, 1, self.flow_num)
        pe_pad_left = flow_spatial_pe_3p[:,:,self.ids_pad_left_right[:,0],:,:,-self.flow_pad_single:]
        pe_pad_right = flow_spatial_pe_3p[:,:,self.ids_pad_left_right[:,1],:,:,:self.flow_pad_single]
        # [B, T, N, C, 1, F + 2*flow_pad_single]
        flow_spatial_pe_3p = torch.cat([pe_pad_left, flow_spatial_pe_3p, pe_pad_right], dim=-1)
        # [BTN, C*flow_pad_sum, F]
        flow_spatial_pe_3p_unfold = F.unfold(flow_spatial_pe_3p.flatten(0,2), kernel_size=(1, self.flow_pad_sum), stride=1)
        # [BTN, F, flow_pad_sum, C]
        flow_spatial_pe_3p = rearrange(flow_spatial_pe_3p_unfold, 'b (c l) y -> b y l c', c=C, l=self.flow_pad_sum)

        ego_pose_temporal = metas['ego_pose'].to(flow_spatial_pe.dtype)
        # The reshape below requires B*T*C elements, i.e. one embedding per frame.
        flow_temporal_pe = self.flow_temporal_pe_map(ego_pose_temporal)
        flow_temporal_pe_local = flow_temporal_pe.reshape(B, T, 1, 1, C).repeat(1, 1, N, self.flow_num, 1).flatten(0,2)  # [BTN, F, C]

        outputs_wm_temporal = []
        outputs_wm_spatial = []

        mlvl_feats_flow = []
        mlvl_feats_flow_unfold = []
        mlvl_feats_flow_unfold_fuse = []

        H_W_Ps = []
        for idx, mlvl_feat in enumerate(feature_maps):
            if idx not in self.flow_ids:
                # NOTE(review): skipped levels are appended here, before any
                # flow level is appended in the second loop, so the output
                # level order differs from the input order unless every level
                # is a flow level.
                mlvl_feats_flow.append(mlvl_feat)
                continue
            flow_patch = self.flow_patches[idx]
            B, TN, C, H, W = mlvl_feat.shape
            H_W_Ps.append([H,W,flow_patch])
            mlvl_feat_reshaped = mlvl_feat.reshape(B, T, N, C, H, W)

            # Pad both sides of each view with the border columns of its
            # neighbouring views.
            feat_pad_left = mlvl_feat_reshaped[:,:,self.ids_pad_left_right[:,0],:,:,-self.flow_pad_single*flow_patch:]
            feat_pad_right = mlvl_feat_reshaped[:,:,self.ids_pad_left_right[:,1],:,:,:self.flow_pad_single*flow_patch]
            mlvl_feat_pad_reshaped = torch.cat([feat_pad_left, mlvl_feat_reshaped, feat_pad_right], dim=-1)

            # [BTN, C*H*pad_sum*P, Y]
            mlvl_feat_pad_unfold = F.unfold(mlvl_feat_pad_reshaped.flatten(0,2), kernel_size=(H, self.flow_pad_sum*flow_patch), stride=flow_patch)
            # [BTN, C, H, pad_sum*P, Y]
            mlvl_feat_flow_unfold = mlvl_feat_pad_unfold.reshape(B*TN, C, H, self.flow_pad_sum*flow_patch, W//flow_patch)

            if idx in self.flow_attn_ids:
                # Self-attention among the columns inside each padded window.
                mlvl_feat_flow_unfold = rearrange(mlvl_feat_flow_unfold, 'b c h p y -> (b h y) p c')
                # [BTN, F, pad_sum, C] -> [BTN*H*F, pad_sum*P, C]
                flow_spatial_pe_3p_selfattn = flow_spatial_pe_3p.unsqueeze(3).unsqueeze(1).repeat(1,H,1,1,flow_patch,1).flatten(0,2).flatten(1,2)
                # [BTN*H*Y, pad_sum*P, C]
                mlvl_feat_flow_unfold = self.flow_attns[self.flow_attn_ids.index(idx)](mlvl_feat_flow_unfold, pos=flow_spatial_pe_3p_selfattn)
                # [BTN, C, H, pad_sum*P, Y]
                mlvl_feat_flow_unfold = rearrange(mlvl_feat_flow_unfold, '(b h y) p c -> b c h p y', b=B*TN, h=H, y=W//flow_patch)

            # [BTN, C, H, Y, pad_sum*P] -> [BTN, C, H, Y, P]
            mlvl_feat_flow_unfold = self.flow_mlps[self.flow_ids.index(idx)](mlvl_feat_flow_unfold.transpose(3,4))
            mlvl_feats_flow_unfold.append(mlvl_feat_flow_unfold)
            # Average over height: [BTN, C, Y, P]
            mlvl_feats_flow_unfold_fuse.append(mlvl_feat_flow_unfold.mean(dim=2))

        # [BTN, C, Y, sum(P)]
        mlvl_feats_flow_unfold_fuse = torch.cat(mlvl_feats_flow_unfold_fuse, dim=-1)
        # [BTN, F, C] -> [BTN, 1, F, C]
        flow_spatial_pe_selfattn = flow_spatial_pe_local.unsqueeze(1)
        # [BTN, sum(P), Y, C]
        mlvl_feats_flow_unfold_fuse = self.patch_selfattn(rearrange(mlvl_feats_flow_unfold_fuse, 'b c y p -> b p y c'), pos=flow_spatial_pe_selfattn)

        # [N, F, C] -> [BTN, F, C]
        flow_query = self.flow_query_embed(self.flow_query_xys.to(device)).unsqueeze(0).repeat(B*T, 1, 1, 1).flatten(0,1)
        # Gather local information from the sum(P) columns of every patch.
        # The cross-attention returns [BTN, 1, F*C]; F is fixed by its head
        # count (flow_num), so reshape with flow_num rather than with loop
        # leftovers.
        flow_query = self.patch_crossattn_local(flow_query.unsqueeze(1), mlvl_feats_flow_unfold_fuse).reshape(B*TN, self.flow_num, C)

        # Prompt for the temporal world model: attend across the T axis.
        flow_temporal_pe_crossattn = rearrange(flow_temporal_pe_local, '(b t n) y c -> (b n) t y c', b=B, t=T, n=N)
        # [BN, T, F, C]
        flow_query_temporal = self.patch_crossattn_temporal(
            rearrange(flow_query.squeeze(1), '(b t n) y c -> (b n) t y c', b=B, t=T, n=N),
            rearrange(mlvl_feats_flow_unfold_fuse.mean(dim=1), '(b t n) y c -> (b n) t y c', b=B, t=T, n=N, y=self.flow_num),
            q_pos=flow_temporal_pe_crossattn,
            k_pos=flow_temporal_pe_crossattn,
            flatten=False
        )

        # Prompt for the spatial world model: attend across the F axis.
        flow_spatial_pe_crossattn = rearrange(flow_spatial_pe_local, '(b t n) y c -> (b n) y t c', b=B, t=T, n=N)
        # [BN, F, T, C]
        flow_query_spatial = self.patch_crossattn_spatial(
            rearrange(flow_query.squeeze(1), '(b t n) y c -> (b n) y t c', b=B, t=T, n=N),
            rearrange(mlvl_feats_flow_unfold_fuse.mean(dim=1), '(b t n) y c -> (b n) y t c', b=B, t=T, n=N, y=self.flow_num),
            q_pos=flow_spatial_pe_crossattn,
            k_pos=flow_spatial_pe_crossattn,
            flatten=False
        )

        # Split the ring of N*F patches into a left and a right sequence of
        # 3*F patches each, both starting at ``forward_point``.
        forward_point = 11
        flow_query_spatial = rearrange(flow_query_spatial, '(b n) y t c -> (b t n) y c', b=B, t=T, n=N)
        flow_query_spat = rearrange(flow_query_spatial, '(b t n) y c -> b (n y) t c', b=B, t=T, n=N)
        flow_query_spat_l = torch.cat([flow_query_spat[:, forward_point+3*self.flow_num:], flow_query_spat[:, :forward_point]], dim=1).flip(dims=[1])
        flow_query_spat_r = flow_query_spat[:, forward_point:forward_point+3*self.flow_num]
        # [B, 2, 3F, T, C]
        flow_query_spat = torch.stack([flow_query_spat_l, flow_query_spat_r], dim=1)
        # [B*2*spat_size, 3F/spat_size, T, C]
        flow_query_spat = rearrange(flow_query_spat, 'b n (y p) t c -> (b n p) y t c', y=self.flow_num*3//self.spat_size, p=self.spat_size)

        # NOTE(review): ``idx`` below is the position within the flow levels,
        # not the original feature-level index; the two coincide only when
        # ``flow_ids`` is 0..k-1.  Confirm which one ``flow_wm_ids`` means.
        for idx, (mlvl_feat_flow_unfold, (H, W, flow_patch)) in enumerate(zip(mlvl_feats_flow_unfold, H_W_Ps)):
            if idx in self.flow_wm_ids:
                layer_idx = self.flow_wm_ids.index(idx)

                # ---- temporal world model ----
                # Input is the per-patch mean feature, [BN, T, Y, C], with the
                # time axis reversed.
                output_wm, mlvl_feat_flow_unfold_fut = self.world_models_temporal[layer_idx](
                    rearrange(mlvl_feat_flow_unfold.mean(dim=2).mean(dim=-1), '(b t n) c y -> (b n) t y c', b=B, t=T, n=N).flip(dims=[1]),
                    flow_query_temporal.flip(dims=[1]))
                outputs_wm_temporal.append(output_wm)

                # Concatenate the temporal prediction (broadcast over H and P)
                # to the features along the channel axis.
                mlvl_feat_flow_unfold_fuse = torch.cat([rearrange(mlvl_feat_flow_unfold, 'b c h y p -> b h y p c'),
                                                        rearrange(mlvl_feat_flow_unfold_fut.flip(dims=[1]), '(b n) t y c -> (b t n) y c', b=B, n=N, y=W//flow_patch).unsqueeze(2).unsqueeze(1).repeat(1,H,1,flow_patch,1)], dim=-1)
                # [BTN, 2C, H, Y, P]
                mlvl_feat_flow_unfold = rearrange(mlvl_feat_flow_unfold_fuse, 'b h y p c -> b c h y p', b=B*TN, h=H, y=W//flow_patch)
                # Fold the patches back into a [BTN, 2C, H, W] map and reduce to C channels.
                mlvl_feat_flow_fold = F.fold(mlvl_feat_flow_unfold.transpose(3,4).flatten(1,3), output_size=(H,W), kernel_size=(H, flow_patch), stride=flow_patch)
                mlvl_feat_flow_fold = self.flow_wm_convs[layer_idx](mlvl_feat_flow_fold)

                # Back to patches: [BTN, C*H*P, Y] -> [BTN, C, H, Y, P]
                mlvl_feat_flow_unfold = F.unfold(mlvl_feat_flow_fold, kernel_size=(H, flow_patch), stride=flow_patch).reshape(B*TN, C, H, flow_patch, W//flow_patch).transpose(3,4)

                # Same left/right split as for ``flow_query_spat`` above.
                mlvl_feat_flow_unfold_spatial = rearrange(mlvl_feat_flow_unfold.mean(dim=2).mean(dim=-1), '(b t n) c y -> b (n y) t c', b=B, t=T, n=N)
                mlvl_feat_flow_unfold_spatial_l = torch.cat([mlvl_feat_flow_unfold_spatial[:, forward_point+3*self.flow_num:], mlvl_feat_flow_unfold_spatial[:, :forward_point]], dim=1).flip(dims=[1])
                mlvl_feat_flow_unfold_spatial_r = mlvl_feat_flow_unfold_spatial[:, forward_point:forward_point+3*self.flow_num]
                # [B, 2, 3F, T, C]
                mlvl_feat_flow_unfold_spatial = torch.stack([mlvl_feat_flow_unfold_spatial_l, mlvl_feat_flow_unfold_spatial_r], dim=1)
                # Split each sequence into ``spat_size`` chunks.
                mlvl_feat_flow_unfold_spatial = rearrange(mlvl_feat_flow_unfold_spatial, 'b n (y p) t c -> (b n p) y t c', y=self.flow_num*3//self.spat_size, p=self.spat_size)

                # ---- spatial world model ----
                output_wm_spatial, mlvl_feat_flow_unfold_fut_spatial = self.world_models_spatial[layer_idx](
                    mlvl_feat_flow_unfold_spatial,
                    flow_query_spat)
                outputs_wm_spatial.append(output_wm_spatial)

                # Undo the chunking, the left-side flip and the shift by
                # ``forward_point`` to get back to per-view patch order.
                mlvl_feat_flow_unfold_fut_spatial = rearrange(mlvl_feat_flow_unfold_fut_spatial, '(b n p) y t c -> b n (y p) t c', b=B, n=2, p=self.spat_size)
                mlvl_feat_flow_unfold_fut_spatial = torch.cat([mlvl_feat_flow_unfold_fut_spatial[:,1], mlvl_feat_flow_unfold_fut_spatial[:,0].flip(dims=[1])], dim=1)
                mlvl_feat_flow_unfold_fut_spatial = torch.cat([mlvl_feat_flow_unfold_fut_spatial[:,-forward_point:], mlvl_feat_flow_unfold_fut_spatial[:,:-forward_point]], dim=1)
                mlvl_feat_flow_unfold_fut_spatial = rearrange(mlvl_feat_flow_unfold_fut_spatial, 'b (n y) t c -> (b t n) y c', n=N, y=self.flow_num)
                # Concatenate the spatial prediction (broadcast over H and P)
                # along the channel axis.
                mlvl_feat_flow_unfold_fuse = torch.cat([rearrange(mlvl_feat_flow_unfold, 'b c h y p -> b h y p c'),
                                                        mlvl_feat_flow_unfold_fut_spatial.unsqueeze(2).unsqueeze(1).repeat(1,H,1,flow_patch,1)], dim=-1)
                # [BTN, 2C, H, Y, P]
                mlvl_feat_flow_unfold = rearrange(mlvl_feat_flow_unfold_fuse, 'b h y p c -> b c h y p', b=B*TN, h=H, y=W//flow_patch)
                mlvl_feat_flow_fold = F.fold(mlvl_feat_flow_unfold.transpose(3,4).flatten(1,3), output_size=(H,W), kernel_size=(H, flow_patch), stride=flow_patch)
                mlvl_feat_flow_fold = self.flow_wm_convs_spatial[layer_idx](mlvl_feat_flow_fold)
            else:
                # No world model for this level: just fold the patches back.
                mlvl_feat_flow_fold = F.fold(mlvl_feat_flow_unfold.transpose(3,4).flatten(1,3), output_size=(H,W), kernel_size=(H, flow_patch), stride=flow_patch)
            mlvl_feats_flow.append(mlvl_feat_flow_fold.reshape(B, TN, C, H, W))
        feature_maps = mlvl_feats_flow

        if return_depth and self.depth_branch is not None:
            depths = self.depth_branch(feature_maps, metas.get("focal"))
        else:
            depths = None
        if self.use_deformable_func:
            feature_maps = feature_maps_format(feature_maps)
        if return_depth:
            return feature_maps, depths, outputs_wm_temporal, outputs_wm_spatial
        return feature_maps, outputs_wm_temporal, outputs_wm_spatial

    @force_fp32(apply_to=("img",))
    def forward(self, img, **data):
        """Dispatch to ``forward_train`` or ``forward_test`` by ``self.training``."""
        if self.training:
            return self.forward_train(img, **data)
        else:
            return self.forward_test(img, **data)

    def forward_train(self, img, **data):
        """Compute the training losses.

        Returns:
            The dict produced by ``head.loss`` extended with
            ``'loss_dense_depth'`` (when a depth branch output and
            ``data['gt_depth']`` exist), ``'loss_wm_t'`` and ``'loss_wm_s'``.
        """
        feature_maps, depths, outputs_wm_temporal, outputs_wm_spatial = self.extract_feat(img, True, data)
        model_outs = self.head(feature_maps, data)
        output = self.head.loss(model_outs, data)
        if depths is not None and "gt_depth" in data:
            output["loss_dense_depth"] = self.depth_branch.loss(
                depths, data["gt_depth"]
            )

        losses_wm_temporal = []
        losses_wm_spatial = []
        for output_wm_temporal, output_wm_spatial in zip(outputs_wm_temporal, outputs_wm_spatial):
            # NaNs in a world-model loss are replaced (by 0, per
            # torch.nan_to_num defaults) instead of poisoning the total loss.
            value_t = torch.nan_to_num(self.loss_wm_temporal(output_wm_temporal['prior'], output_wm_temporal['posterior']))
            losses_wm_temporal.append(value_t)
            value_s = torch.nan_to_num(self.loss_wm_spatial(output_wm_spatial['prior'], output_wm_spatial['posterior']))
            losses_wm_spatial.append(value_s)
        output['loss_wm_t'] = sum(losses_wm_temporal)
        output['loss_wm_s'] = sum(losses_wm_spatial)
        return output

    def forward_test(self, img, **data):
        """Run ``aug_test`` for a list of images, ``simple_test`` otherwise."""
        if isinstance(img, list):
            return self.aug_test(img, **data)
        else:
            return self.simple_test(img, **data)

    def simple_test(self, img, **data):
        """Run inference and wrap every result as ``{"img_bbox": result}``."""
        # World-model outputs are only needed for the training losses.
        feature_maps, _, _ = self.extract_feat(img, False, data)

        model_outs = self.head(feature_maps, data)
        results = self.head.post_process(model_outs, data)
        output = [{"img_bbox": result} for result in results]
        return output

    def aug_test(self, img, **data):
        """Fake test-time augmentation: use only the first entry of every list."""
        # Re-assigning existing keys while iterating is safe (dict size is unchanged).
        for key, value in data.items():
            if isinstance(value, list):
                data[key] = value[0]
        return self.simple_test(img[0], **data)
