import numpy as np
import math
import warnings
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmdet.models import HEADS
from mmcv.runner import force_fp32, auto_fp16
from mmcv.cnn import xavier_init
from mmdet.models.builder import build_loss
from mmdet.models.utils import build_transformer
from .matcher import HungarianMatcher
from .loss_utils import CE_ssc_loss, lovasz_softmax, get_voxel_decoder_loss_input

from einops.layers.torch import Rearrange
from einops import rearrange, repeat

from .world_model_hd import RSSM
from .losses import WMLoss

# Per-class frequency table (18 entries). FlowADHeadFlow.__init__ converts it
# into cross-entropy class weights as 1 / log(freq + 0.001), so classes with
# smaller counts receive larger weights.
# NOTE(review): presumably voxel counts per nuScenes occupancy class, listed in
# class-id order -- confirm against the dataset tooling that produced them.
NUSC_CLASS_FREQ = np.array([
    944004, 1897170, 152386, 2391677, 16957802, 724139, 189027, 2074468, 413451, 2384460,
    5916653, 175883646, 4275424, 51393615, 61411620, 105975596, 116424404, 1892500630
])


def _trunc_normal_(tensor, mean, std, a, b):
    # Cut & paste from PyTorch official master until it's in a few official releases - RW
    # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
    def norm_cdf(x):
        # Computes standard normal cumulative distribution function
        return (1. + math.erf(x / math.sqrt(2.))) / 2.

    if (mean < a - 2 * std) or (mean > b + 2 * std):
        warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
                      "The distribution of values may be incorrect.",
                      stacklevel=2)

    # Values are generated by using a truncated uniform distribution and
    # then using the inverse CDF for the normal distribution.
    # Get upper and lower cdf values
    l = norm_cdf((a - mean) / std)
    u = norm_cdf((b - mean) / std)

    # Uniformly fill tensor with values from [l, u], then translate to
    # [2l-1, 2u-1].
    tensor.uniform_(2 * l - 1, 2 * u - 1)

    # Use inverse cdf transform for normal distribution to get truncated
    # standard normal
    tensor.erfinv_()

    # Transform to proper mean, std
    tensor.mul_(std * math.sqrt(2.))
    tensor.add_(mean)

    # Clamp to ensure it's in the proper range
    tensor.clamp_(min=a, max=b)
    return tensor


def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
    r"""Fill ``tensor`` in place from a truncated normal distribution.

    Samples follow :math:`\mathcal{N}(\text{mean}, \text{std}^2)` restricted to
    the interval :math:`[a, b]`. The actual sampling is delegated to
    ``_trunc_normal_`` and runs under ``torch.no_grad()`` so the fill is not
    recorded by autograd. The method works best when
    :math:`a \leq \text{mean} \leq b`.

    NOTE: ``a`` and ``b`` are absolute cutoffs applied to the already
    shifted/scaled distribution, so they should be chosen to match the range
    implied by ``mean`` and ``std``.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        mean: the mean of the normal distribution
        std: the standard deviation of the normal distribution
        a: the minimum cutoff value
        b: the maximum cutoff value

    Returns:
        The input ``tensor``, filled in place.

    Examples:
        >>> w = torch.empty(3, 5)
        >>> trunc_normal_(w, std=.02)
    """
    with torch.no_grad():
        filled = _trunc_normal_(tensor, mean, std, a, b)
    return filled


class FeedForward(nn.Module):
    """Pre-norm two-layer MLP applied over the last dimension.

    Layer order: LayerNorm -> Linear(dim, hidden_dim) -> GELU -> Dropout ->
    Linear(hidden_dim, dim) -> Dropout. The output has the same last-dimension
    size as the input; no residual connection is added here.

    Args:
        dim: size of the input/output feature dimension.
        hidden_dim: size of the intermediate feature dimension.
        dropout: dropout probability used after the activation and the output.
    """

    def __init__(self, dim, hidden_dim, dropout=0.):
        super().__init__()
        stages = [
            nn.LayerNorm(dim),
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout),
        ]
        # Kept as a Sequential named ``net`` so parameter names stay ``net.<i>.*``.
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through the MLP and return the result."""
        projected = self.net(x)
        return projected

class FlowAttention(nn.Module):
    """Self-attention block with two operating modes.

    * ``channel_fuse=False``: multi-head attention. Projections have width
      ``heads * dim_head`` and are split into heads with einops, so the input
      is expected to be 3-D ``(b, n, dim)``.
    * ``channel_fuse=True``: projections have width ``dim_head`` and no head
      split is performed; attention is computed over the second-to-last axis
      of the input, whatever leading axes it has.

    Queries, keys and values each get their own LayerNorm before projection.
    When ``with_res`` is set the result is ``LayerNorm(x + FeedForward(out + x))``.

    Args:
        dim: feature size of the input (last dimension).
        heads: number of heads (used only when ``channel_fuse`` is False).
        dim_head: per-head width; also fixes the attention scale ``dim_head ** -0.5``.
        dropout: dropout probability on attention weights and output projection.
        with_res: enable the residual FeedForward + LayerNorm epilogue.
        channel_fuse: select the no-head-split mode described above.
    """

    def __init__(self, dim, heads = 4, dim_head = 64, dropout = 0., with_res=False, channel_fuse=False):
        super().__init__()
        if channel_fuse:
            inner_dim = dim_head
        else:
            inner_dim = dim_head * heads
        # The output projection is skipped only for a single head whose width
        # already equals ``dim``.
        skip_projection = heads == 1 and dim_head == dim

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.norm_q = nn.LayerNorm(dim)
        self.norm_k = nn.LayerNorm(dim)
        self.norm_v = nn.LayerNorm(dim)

        self.attend = nn.Softmax(dim=-1)
        self.dropout = nn.Dropout(dropout)

        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_k = nn.Linear(dim, inner_dim, bias=False)
        self.to_v = nn.Linear(dim, inner_dim, bias=False)

        if skip_projection:
            self.to_out = nn.Identity()
        else:
            self.to_out = nn.Sequential(
                nn.Linear(inner_dim, dim),
                nn.Dropout(dropout),
            )

        self.channel_fuse = channel_fuse
        self.with_res = with_res
        if with_res:
            self.res_mlp = FeedForward(dim, inner_dim, dropout)
            self.res_norm = nn.LayerNorm(dim)

    def forward(self, x, pos=None):
        """Attend ``x`` to itself.

        Args:
            x: input features; last dimension must be ``dim``.
            pos: optional positional term added to the query and key inputs
                (never to the values); must broadcast against ``x``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # Positional information only steers the attention pattern.
        if pos is None:
            query_in, key_in = x, x
        else:
            query_in, key_in = x + pos, x + pos

        q = self.to_q(self.norm_q(query_in))
        k = self.to_k(self.norm_k(key_in))
        v = self.to_v(self.norm_v(x))

        split_heads = not self.channel_fuse
        if split_heads:
            q, k, v = (rearrange(t, 'b n (h d) -> b h n d', h=self.heads) for t in (q, k, v))

        scores = torch.matmul(q, k.transpose(-1, -2)) * self.scale
        weights = self.dropout(self.attend(scores))
        mixed = torch.matmul(weights, v)
        if split_heads:
            mixed = rearrange(mixed, 'b h n d -> b n (h d)')
        out = self.to_out(mixed)

        if self.with_res:
            out = self.res_norm(x + self.res_mlp(out + x))
        return out

class FlowCrossAttention(nn.Module):
    """Cross-attention block: queries from one tensor, keys/values from another.

    * ``channel_fuse=False``: multi-head attention on 3-D inputs ``(b, n, dim)``;
      projections have width ``heads * dim_head`` and are split into heads.
    * ``channel_fuse=True``: inputs are 4-D ``(b, n, h, dim)`` where the third
      axis must have length ``heads``; that axis is treated as the head axis
      and attention runs over ``n``. Projections have width ``dim_head``.

    When ``with_res`` is set the result is
    ``LayerNorm(q_ini + FeedForward(out + q_ini))``.

    Args:
        dim: feature size of the inputs (last dimension).
        heads: number of heads (length of the head axis in channel-fuse mode).
        dim_head: per-head width; also fixes the attention scale ``dim_head ** -0.5``.
        dropout: dropout probability on attention weights and output projection.
        with_res: enable the residual FeedForward + LayerNorm epilogue.
        channel_fuse: select the 4-D mode described above.
    """

    def __init__(self, dim, heads = 4, dim_head = 64, dropout = 0., with_res=False, channel_fuse=False):
        super().__init__()
        if channel_fuse:
            inner_dim = dim_head
        else:
            inner_dim = dim_head * heads
        # The output projection is skipped only for a single head whose width
        # already equals ``dim``.
        skip_projection = heads == 1 and dim_head == dim

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.norm_q = nn.LayerNorm(dim)
        self.norm_k = nn.LayerNorm(dim)
        self.norm_v = nn.LayerNorm(dim)

        self.attend = nn.Softmax(dim=-1)
        self.dropout = nn.Dropout(dropout)

        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_k = nn.Linear(dim, inner_dim, bias=False)
        self.to_v = nn.Linear(dim, inner_dim, bias=False)

        if skip_projection:
            self.to_out = nn.Identity()
        else:
            self.to_out = nn.Sequential(
                nn.Linear(inner_dim, dim),
                nn.Dropout(dropout),
            )

        self.channel_fuse = channel_fuse
        self.with_res = with_res
        if with_res:
            self.res_mlp = FeedForward(dim, inner_dim, dropout)
            self.res_norm = nn.LayerNorm(dim)

    def forward(self, q_ini, kv_ini, q_pos=None, k_pos=None, flatten=True):
        """Let ``q_ini`` attend to ``kv_ini``.

        Args:
            q_ini: query source.
            kv_ini: key/value source.
            q_pos: optional positional term for the queries.
            k_pos: optional positional term for the keys.
            flatten: channel-fuse mode only -- merge the last two axes of the
                result (``flatten(2, 3)``) before returning. Ignored otherwise.

        Returns:
            The attended features.
        """
        # Positional terms are applied only when BOTH are given; supplying just
        # one of them leaves queries and keys without any positional term.
        if q_pos is None or k_pos is None:
            query_in, key_in = q_ini, kv_ini
        else:
            query_in, key_in = q_ini + q_pos, kv_ini + k_pos

        q = self.to_q(self.norm_q(query_in))
        k = self.to_k(self.norm_k(key_in))
        v = self.to_v(self.norm_v(kv_ini))

        # The two modes differ only in how the head axis is exposed/merged.
        if self.channel_fuse:
            split_pattern = 'b n h d -> b h n d'
            merge_pattern = 'b h n d -> b n h d'
        else:
            split_pattern = 'b n (h d) -> b h n d'
            merge_pattern = 'b h n d -> b n (h d)'

        q, k, v = (rearrange(t, split_pattern, h=self.heads) for t in (q, k, v))

        scores = torch.matmul(q, k.transpose(-1, -2)) * self.scale
        weights = self.dropout(self.attend(scores))
        mixed = rearrange(torch.matmul(weights, v), merge_pattern)
        out = self.to_out(mixed)

        if self.with_res:
            out = self.res_norm(q_ini + self.res_mlp(out + q_ini))

        if self.channel_fuse and flatten:
            return out.flatten(2, 3)
        return out


@HEADS.register_module()
class FlowADHeadFlow(nn.Module): 
    def __init__(self,
                 transformer=None,
                 class_names=None,
                 embed_dims=None,
                 occ_size=None,
                 pc_range=None,
                 loss_cfgs=None,
                 panoptic=False,
                 flow_patches=[8,4,2,1],
                 flow_ids=[0,1,2,3],
                 flow_attn_ids=[2,3],
                 flow_wm_ids=[2,3],
                 flow_pad_single=1,
                 flow_num=22,
                 spat_size=6,
                 loss_wm=dict(
                     type='WMLoss',
                     loss_weight=1.0,
                 ),
                 loss_wm_spatial=dict(
                     type='WMLoss',
                     loss_weight=0.25,
                 ),
                 **kwargs):
        """Build the flow / world-model feature enhancer and the occupancy transformer.

        Args:
            transformer: config handed to ``build_transformer``.
            class_names: sequence of class names; its length sets ``num_classes``.
            embed_dims: channel width used by every sub-module built here.
            occ_size: stored unchanged on the instance.
            pc_range: stored unchanged on the instance.
            loss_cfgs: mapping ``name -> loss config``; each value goes through
                ``build_loss`` and is stored in ``self.criterions``.
            panoptic: stored unchanged on the instance.
            flow_patches: patch width per feature level, indexed by level id.
            flow_ids: feature-level ids that receive flow processing.
            flow_attn_ids: level ids that additionally get a ``FlowAttention``.
            flow_wm_ids: ids for which temporal/spatial world models are built.
            flow_pad_single: number of neighbouring patches padded on each side.
            flow_num: number of flow points per camera view.
            spat_size: group size used when splitting the spatial flow sequence.
            loss_wm: config for the temporal world-model loss.
            loss_wm_spatial: config for the spatial world-model loss.
            **kwargs: accepted for config compatibility and ignored.
        """
        super(FlowADHeadFlow, self).__init__()
        self.num_classes = len(class_names)
        self.class_names = class_names
        self.pc_range = pc_range
        self.occ_size = occ_size
        self.embed_dims = embed_dims
        self.score_threshold = 0.3
        self.overlap_threshold = 0.8
        self.panoptic = panoptic

        self.transformer = build_transformer(transformer)
        # NOTE(review): a plain dict is not registered with nn.Module, so loss
        # modules stored here are invisible to .to()/.state_dict(). Left as is
        # to keep checkpoint keys unchanged -- confirm this is intended.
        self.criterions = {k: build_loss(loss_cfg) for k, loss_cfg in loss_cfgs.items()}
        self.matcher = HungarianMatcher(cost_class=2.0, cost_mask=5.0, cost_dice=5.0)

        # Inverse-log-frequency class weights for the semantic CE loss.
        self.class_weights = torch.from_numpy(1 / np.log(NUSC_CLASS_FREQ + 0.001))

        self.flow_num = flow_num
        self.spat_size = spat_size

        # Copy the id/patch lists: the signature defaults are shared list
        # objects, so storing them by reference would let an in-place change on
        # one instance leak into every later instance.
        self.flow_patches = list(flow_patches)
        self.flow_ids = list(flow_ids)
        self.flow_attn_ids = list(flow_attn_ids)
        self.flow_wm_ids = list(flow_wm_ids)
        self.flow_pad_single = flow_pad_single
        # Width, in patches, of the window seen by each flow point.
        self.flow_pad_sum = 1 + 2 * flow_pad_single
        # Row i = (view whose edge pads view i on the left, ... on the right).
        self.ids_pad_left_right = torch.tensor([[5,1],[0,2],[1,3],[2,4],[3,5],[4,0]], dtype=torch.long, requires_grad=False)
        self.frame_num = 8

        in_channels = embed_dims

        def _coord_mlp(in_features):
            # Two-layer ReLU MLP + LayerNorm lifting raw coordinates to embed_dims.
            return nn.Sequential(
                nn.Linear(in_features, self.embed_dims // 2),
                nn.ReLU(inplace=True),
                nn.Linear(self.embed_dims // 2, self.embed_dims),
                nn.ReLU(inplace=True),
                nn.LayerNorm(self.embed_dims)
            )

        # NOTE(review): the three xavier_init calls below receive an
        # nn.Sequential, which has no ``weight``/``bias`` attribute of its own.
        # If mmcv's xavier_init only touches ``module.weight`` these calls do
        # nothing and the Linear layers keep PyTorch's default init -- confirm
        # before changing, since it alters training from scratch.
        self.flow_query_xys = self.generate_flow_xys()
        self.flow_query_embed = _coord_mlp(2)
        xavier_init(self.flow_query_embed, distribution='uniform', bias=0.)

        self.flow_spatial_pe_map = _coord_mlp(2)
        xavier_init(self.flow_spatial_pe_map, distribution='uniform', bias=0.)

        # 18 = length of the per-frame ``ego_pose`` vector fed in by forward().
        self.flow_temporal_pe_map = _coord_mlp(18)
        xavier_init(self.flow_temporal_pe_map, distribution='uniform', bias=0.)

        # One Linear per flow level, squeezing the padded window back to one patch.
        self.flow_mlps = nn.ModuleList([
            nn.Linear(self.flow_pad_sum*self.flow_patches[flow_id], self.flow_patches[flow_id]) for flow_id in self.flow_ids
        ])
        self.flow_attns = nn.ModuleList([
            FlowAttention(dim=in_channels, heads=4, dim_head=in_channels//4, dropout=0.) for _ in self.flow_attn_ids
        ])

        self.world_models_temporal = nn.ModuleList([
            RSSM(embedding_dim=in_channels) for _ in self.flow_wm_ids
        ])
        self.world_models_spatial = nn.ModuleList([
            RSSM(embedding_dim=in_channels) for _ in self.flow_wm_ids
        ])
        self.patch_selfattn = FlowAttention(dim=in_channels, heads=self.flow_num, dim_head=in_channels, dropout=0., with_res=False, channel_fuse=True)
        self.patch_crossattn_local = FlowCrossAttention(dim=in_channels, heads=self.flow_num, dim_head=in_channels, dropout=0., with_res=False, channel_fuse=True)

        self.patch_crossattn_temporal = FlowCrossAttention(dim=in_channels, heads=self.flow_num, dim_head=in_channels, dropout=0., with_res=False, channel_fuse=True)
        self.patch_crossattn_spatial = FlowCrossAttention(dim=in_channels, heads=self.frame_num, dim_head=in_channels, dropout=0., with_res=False, channel_fuse=True)

        # Fuse [original, world-model prediction] (2C channels) back to C.
        self.flow_wm_convs = nn.ModuleList([
            nn.Conv2d(in_channels*2, in_channels, kernel_size=3, padding=1) for _ in self.flow_wm_ids
        ])
        self.flow_wm_convs_spatial = nn.ModuleList([
            nn.Conv2d(in_channels*2, in_channels, kernel_size=3, padding=1) for _ in self.flow_wm_ids
        ])

        self.loss_wm_temporal = build_loss(loss_wm)
        self.loss_wm_spatial = build_loss(loss_wm_spatial)
    
    def generate_flow_xys(self,):
        pc_range = 51.2
        tan_30 = math.tan(math.radians(30))
        flow_items = self.flow_num
        p_x = pc_range * tan_30

        f_range_x = [-p_x, p_x]
        f_range_y = [pc_range, pc_range]

        fr_range_x = [p_x, pc_range]
        fr_range_y = [pc_range, 0]

        br_range_x = [pc_range, p_x]
        br_range_y = [0, -pc_range]

        b_range_x = [-p_x, p_x]
        b_range_y = [-pc_range, -pc_range]

        bl_range_x = [-p_x, -pc_range]
        bl_range_y = [-pc_range, 0]

        fl_range_x = [-pc_range, -p_x]
        fl_range_y = [0, pc_range]

        flow_xys_all = []
        for range_x, range_y in zip([f_range_x, fr_range_x, br_range_x, b_range_x, bl_range_x, fl_range_x],
                                    [f_range_y, fr_range_y, br_range_y, b_range_y, bl_range_y, fl_range_y]):
            step_x = (range_x[1]-range_x[0])/flow_items
            step_y = (range_y[1]-range_y[0])/flow_items
            xys = torch.stack([
                torch.arange(range_x[0], range_x[1], step_x)+step_x/2 if step_x != 0 else torch.ones(flow_items)*range_x[0],
                torch.arange(range_y[0], range_y[1], step_y)+step_y/2 if step_y != 0 else torch.ones(flow_items)*range_y[0],
                                ], dim=-1)
            flow_xys_all.append(xys)

        flow_xys_all = torch.stack(flow_xys_all, dim=0)
        return flow_xys_all
    
    def get_forward_point(self, sdc_trajs):
        ego_xys = [one[1] for one in sdc_trajs]
        forward_points = []
        # tan_30 = math.tan(math.radians(30))
        tan_60 = math.tan(math.radians(60))
        half = 51.2 / tan_60
        step = half / self.flow_num * 2
        for ego_xy in ego_xys:
            if sum(ego_xy) == 0:
                forward_points.append(11)
                continue
            if ego_xy[0] >= 0:
                left_right = 1
            else:
                left_right = 0
            tan = abs(ego_xy[1]/ego_xy[0])
            if tan >= tan_60:
                forward_point = (51.2/tan + half)//step
            else:
                if left_right:
                    x = 51.2 * tan_60 / (tan + tan_60)
                    y = x * tan
                    delta = math.sqrt((x-51.2/tan_60)**2 + (y-51.2)**2)
                    forward_point = 22 + delta//step
                else:
                    x = 51.2 * tan_60 / (-tan - tan_60)
                    y = - x * tan
                    delta = math.sqrt((x+51.2/tan_60)**2 + (y-51.2)**2)
                    forward_point = - delta//step
            forward_points.append(int(forward_point))
        return forward_points[0]

    def init_weights(self):
        self.transformer.init_weights()

    # NOTE(review): ('mlvl_feats') is a parenthesised str, not a 1-tuple (there is
    # no trailing comma) -- confirm the decorator treats it as intended.
    @auto_fp16(apply_to=('mlvl_feats'))
    def forward(self, mlvl_feats, img_metas):
        """Enhance multi-level image features with flow / world-model cues, then decode.

        Pipeline, as implemented below:
          1. Build spatial PEs from the flow-point coordinates and temporal PEs
             from the per-frame ``ego_pose``.
          2. For every level listed in ``flow_ids``: pad each view with its two
             neighbouring views, unfold into ``flow_num`` column patches,
             optionally self-attend inside the padded window, and squeeze the
             window back to one patch with ``flow_mlps``.
          3. Fuse the per-level patch summaries (``patch_selfattn``) and derive
             local / temporal / spatial flow queries by cross-attention.
          4. For the world-model levels: run the temporal and spatial RSSMs,
             concatenate their predictions to the features and fuse with a conv.
          5. Feed the resulting feature maps to ``self.transformer``.

        Args:
            mlvl_feats: sequence of 5-D tensors, each unpacked as
                ``(B, T*N, C, H, W)``. ``N`` is hard-coded to 6 views.
            img_metas: sequence of dicts; ``'ego_pose'`` and ``'sdc_traj'`` are
                read from each entry.

        Returns:
            dict with ``occ_preds``, ``mask_preds``, ``class_preds`` (outputs of
            ``self.transformer``) and ``wm_temporal`` / ``wm_spatial`` (lists of
            the first output of each world model, one entry per processed level).

        NOTE(review): ``patch_crossattn_spatial`` is built with
        ``heads=self.frame_num`` (8) and its head axis here is ``T``, so ``T``
        presumably has to equal 8 -- confirm.
        NOTE(review): levels not in ``flow_ids`` are appended to the output list
        during the first loop, i.e. before all flow levels; level order is only
        preserved when every level is a flow level.
        """
        device = mlvl_feats[0].device
        B, TN, C, _, _ = mlvl_feats[0].shape
        N = 6  # number of camera views (hard-coded)
        T = TN // N
        flow_spatial_pe = self.flow_spatial_pe_map(self.flow_query_xys.to(device)).unsqueeze(0)  # [1, N, 22, C]
        flow_spatial_pe_local = rearrange(flow_spatial_pe.repeat(B*T, 1, 1, 1), 'b n y c -> (b n) y c')  # [BTN, 22, C]

        # pad spatial PE for 3*patch self-attn: each view borrows the last
        # flow point(s) of its left neighbour and the first of its right one
        flow_spatial_pe_3p = flow_spatial_pe_local.transpose(1,2).reshape(B, T, N, C, 1, self.flow_num)
        pe_pad_left = flow_spatial_pe_3p[:,:,self.ids_pad_left_right[:,0],:,:,-self.flow_pad_single:]
        pe_pad_right = flow_spatial_pe_3p[:,:,self.ids_pad_left_right[:,1],:,:,:self.flow_pad_single]
        # [B, T, N, C, 1, 24]
        flow_spatial_pe_3p = torch.cat([pe_pad_left, flow_spatial_pe_3p, pe_pad_right], dim=-1)
        # [BTN, C*1*3, 1*22]
        flow_spatial_pe_3p_unfold = F.unfold(flow_spatial_pe_3p.flatten(0,2), kernel_size=(1, self.flow_pad_sum), stride=1)
        # [BTN, 22, 3, C]
        flow_spatial_pe_3p = rearrange(flow_spatial_pe_3p_unfold, 'b (c l) y -> b y l c', c=C, l=self.flow_pad_sum)

        # temporal PE from the stacked per-sample ego poses
        ego_pose_temporal = torch.from_numpy(
            np.stack([each['ego_pose'] for each in img_metas], axis=0)).to(device).to(flow_spatial_pe.dtype)
        flow_temporal_pe = self.flow_temporal_pe_map(ego_pose_temporal)  # [B, T, C]
        flow_temporal_pe_local = flow_temporal_pe.reshape(B, T, 1, 1, C).repeat(1, 1, N, self.flow_num, 1).flatten(0,2)  # [BTN, 22, C]

        outs = {}  # NOTE(review): never read afterwards
        outputs_wm_temporal = []
        outputs_wm_spatial = []

        mlvl_feats_flow = []
        mlvl_feats_flow_unfold = []
        mlvl_feats_flow_unfold_fuse = []

        # (H, W, patch width) of every processed flow level, in processing order
        H_W_Ps = []
        for idx, mlvl_feat in enumerate(mlvl_feats):
            if idx not in self.flow_ids:
                # pass-through level: forwarded to the transformer untouched
                mlvl_feats_flow.append(mlvl_feat)
                continue
            flow_patch = self.flow_patches[idx]
            B, TN, C, H, W = mlvl_feat.shape
            H_W_Ps.append([H,W,flow_patch])
            mlvl_feat_reshaped = mlvl_feat.reshape(B, T, N, C, H, W)

            # pad two sides of each view with the adjacent columns of its neighbours
            feat_pad_left = mlvl_feat_reshaped[:,:,self.ids_pad_left_right[:,0],:,:,-self.flow_pad_single*flow_patch:]
            feat_pad_right = mlvl_feat_reshaped[:,:,self.ids_pad_left_right[:,1],:,:,:self.flow_pad_single*flow_patch]
            mlvl_feat_pad_reshaped = torch.cat([feat_pad_left, mlvl_feat_reshaped, feat_pad_right], dim=-1)

            # [BTN, C*H*3P, 1*22]
            mlvl_feat_pad_unfold = F.unfold(mlvl_feat_pad_reshaped.flatten(0,2), kernel_size=(H, self.flow_pad_sum*flow_patch), stride=flow_patch)
            # [BTN, C, H, 3P, 22]
            mlvl_feat_flow_unfold = mlvl_feat_pad_unfold.reshape(B*TN, C, H, self.flow_pad_sum*flow_patch, W//flow_patch)

            if idx in self.flow_attn_ids:
                # interaction inside 3 patches
                mlvl_feat_flow_unfold = rearrange(mlvl_feat_flow_unfold, 'b c h p y -> (b h y) p c')
                # [BTN, 22, 3, C] -> [BTN * H * 22, 3P, C]
                flow_spatial_pe_3p_selfattn = flow_spatial_pe_3p.unsqueeze(3).unsqueeze(1).repeat(1,H,1,1,flow_patch,1).flatten(0,2).flatten(1,2)
                # [BTN * H * 22, 3P, C]
                mlvl_feat_flow_unfold = self.flow_attns[self.flow_attn_ids.index(idx)](mlvl_feat_flow_unfold, pos=flow_spatial_pe_3p_selfattn)
                # [BTN, C, H, 3P, 22]
                mlvl_feat_flow_unfold = rearrange(mlvl_feat_flow_unfold, '(b h y) p c -> b c h p y', b=B*TN, h=H, y=W//flow_patch)

            # [BTN, C, H, 22, 3P] -> [BTN, C, H, 22, P]
            mlvl_feat_flow_unfold = self.flow_mlps[self.flow_ids.index(idx)](mlvl_feat_flow_unfold.transpose(3,4))
            mlvl_feats_flow_unfold.append(mlvl_feat_flow_unfold)
            # averaged over H: [BTN, C, 22, P]
            mlvl_feats_flow_unfold_fuse.append(mlvl_feat_flow_unfold.mean(dim=2))
        
        # [BTN, C, 22, 8+4+2+1]
        mlvl_feats_flow_unfold_fuse = torch.cat(mlvl_feats_flow_unfold_fuse, dim=-1)
        # [BTN, 22, C] -> [BTN, 1, 22, C]
        flow_spatial_pe_selfattn = flow_spatial_pe_local.unsqueeze(1)
        # [BTN, 8+4+2+1, 22, C]
        mlvl_feats_flow_unfold_fuse = self.patch_selfattn(rearrange(mlvl_feats_flow_unfold_fuse, 'b c y p -> b p y c'), pos=flow_spatial_pe_selfattn)

        # NOTE(review): mlvl_feat, W and flow_patch below are the values left
        # over from the LAST iteration of the loop above.
        # [N, 22, C] -> [BTN, 22, C]
        flow_query = self.flow_query_embed(self.flow_query_xys.to(mlvl_feat.device)).unsqueeze(0).repeat(B*T, 1, 1, 1).flatten(0,1)
        # get local messages from [8+4+2+1] of [BTN, 8+4+2+1, 22, C]
        flow_query = self.patch_crossattn_local(flow_query.unsqueeze(1), mlvl_feats_flow_unfold_fuse).reshape(B*TN, W//flow_patch, C)

        # get temporal messages from [T] of [BTN, 8+4+2+1, 22, C]
        # prompt for temporal world model
        flow_temporal_pe_crossattn = rearrange(flow_temporal_pe_local, '(b t n) y c -> (b n) t y c', b=B, t=T, n=6)
        # [BN, T, 22, C]
        flow_query_temporal = self.patch_crossattn_temporal(
            rearrange(flow_query.squeeze(1), '(b t n) y c -> (b n) t y c', b=B, t=T, n=6),
            rearrange(mlvl_feats_flow_unfold_fuse.mean(dim=1), '(b t n) y c -> (b n) t y c', b=B, t=T, n=6, y=self.flow_num),
            q_pos=flow_temporal_pe_crossattn,
            k_pos=flow_temporal_pe_crossattn,
            flatten=False
        )

        # get spatial messages from [22] of [BTN, 8+4+2+1, 22, C]
        # prompt for spatial world model
        flow_spatial_pe_crossattn = rearrange(flow_spatial_pe_local, '(b t n) y c -> (b n) y t c', b=B, t=T, n=6)
        # [BN, 22, T, C]
        flow_query_spatial = self.patch_crossattn_spatial(
            rearrange(flow_query.squeeze(1), '(b t n) y c -> (b n) y t c', b=B, t=T, n=6),
            rearrange(mlvl_feats_flow_unfold_fuse.mean(dim=1), '(b t n) y c -> (b n) y t c', b=B, t=T, n=6, y=self.flow_num),
            q_pos=flow_spatial_pe_crossattn,
            k_pos=flow_spatial_pe_crossattn,
            flatten=False
        )

        # split left/right spatial flow from ego forward_point: the ring of
        # N*22 flow points is cut at forward_point into two halves of 3*22
        # points; the left half is flipped so both halves start at the cut
        sdc_trajs = [each['sdc_traj'] for each in img_metas]
        forward_point = self.get_forward_point(sdc_trajs)
        flow_query_spatial = rearrange(flow_query_spatial, '(b n) y t c -> (b t n) y c', b=B, t=T, n=6)
        flow_query_spat = rearrange(flow_query_spatial, '(b t n) y c -> b (n y) t c', b=B, t=T, n=6)
        if forward_point >= 0:
            flow_query_spat_l = torch.cat([flow_query_spat[:, forward_point+3*self.flow_num:], flow_query_spat[:, :forward_point]], dim=1).flip(dims=[1])
            flow_query_spat_r = flow_query_spat[:, forward_point:forward_point+3*self.flow_num]
        else:
            # negative forward_point indexes from the end of the ring
            flow_query_spat_r = torch.cat([flow_query_spat[:, forward_point:], flow_query_spat[:, :forward_point+3*self.flow_num:]], dim=1)
            flow_query_spat_l = flow_query_spat[:, forward_point+3*self.flow_num:forward_point].flip(dims=[1])
        # [B, 2, 66, T, C]
        flow_query_spat = torch.stack([flow_query_spat_l, flow_query_spat_r], dim=1)
        flow_query_spat = rearrange(flow_query_spat, 'b n (y p) t c -> (b n p) y t c', y=self.flow_num*3//self.spat_size, p=self.spat_size)

        # NOTE(review): idx here counts processed flow levels (position in
        # mlvl_feats_flow_unfold), not original level ids; it matches
        # flow_wm_ids only when flow_ids is a contiguous prefix 0..k -- confirm.
        for idx, (mlvl_feat_flow_unfold, (H, W, flow_patch)) in enumerate(zip(mlvl_feats_flow_unfold, H_W_Ps)):
            if idx in self.flow_wm_ids:
                layer_idx = self.flow_wm_ids.index(idx)

                # temporal world model
                # input: per-flow-point features [BN, T, 22, C], time axis reversed
                output_wm, mlvl_feat_flow_unfold_fut = self.world_models_temporal[layer_idx](
                    rearrange(mlvl_feat_flow_unfold.mean(dim=2).mean(dim=-1), '(b t n) c y -> (b n) t y c', b=B, t=T, n=6).flip(dims=[1]),
                    flow_query_temporal.flip(dims=[1]))
                outputs_wm_temporal.append(output_wm)
                
                # enhance temporal messages with temporal prediction
                # (prediction is broadcast over H and patch, then concatenated on channels)
                mlvl_feat_flow_unfold_fuse = torch.cat([rearrange(mlvl_feat_flow_unfold, 'b c h y p -> b h y p c'),
                                                        rearrange(mlvl_feat_flow_unfold_fut.flip(dims=[1]), '(b n) t y c -> (b t n) y c', b=B, n=6, y=W//flow_patch).unsqueeze(2).unsqueeze(1).repeat(1,H,1,flow_patch,1)], dim=-1)
                mlvl_feat_flow_unfold = rearrange(mlvl_feat_flow_unfold_fuse, 'b h y p c -> b c h y p', b=B*TN, h=H, y=W//flow_patch)  # [BTN, 2C, H, W/patch, patch]
                mlvl_feat_flow_fold = F.fold(mlvl_feat_flow_unfold.transpose(3,4).flatten(1,3), output_size=(H,W), kernel_size=(H, flow_patch), stride=flow_patch)
                mlvl_feat_flow_fold = self.flow_wm_convs[layer_idx](mlvl_feat_flow_fold)

                mlvl_feat_flow_unfold = F.unfold(mlvl_feat_flow_fold, kernel_size=(H, flow_patch), stride=flow_patch).reshape(B*TN, C, H, flow_patch, W//flow_patch).transpose(3,4)  # [BTN, C, H, W/patch, patch]
                
                # split left/right spatial flow from ego forward_point
                # (same cut as applied to flow_query_spat above)
                mlvl_feat_flow_unfold_spatial = rearrange(mlvl_feat_flow_unfold.mean(dim=2).mean(dim=-1), '(b t n) c y -> b (n y) t c', b=B, t=T, n=6)
                if forward_point >= 0:
                    mlvl_feat_flow_unfold_spatial_l = torch.cat([mlvl_feat_flow_unfold_spatial[:, forward_point+3*self.flow_num:], mlvl_feat_flow_unfold_spatial[:, :forward_point]], dim=1).flip(dims=[1])
                    mlvl_feat_flow_unfold_spatial_r = mlvl_feat_flow_unfold_spatial[:, forward_point:forward_point+3*self.flow_num]
                else:
                    mlvl_feat_flow_unfold_spatial_r = torch.cat([mlvl_feat_flow_unfold_spatial[:, forward_point:], mlvl_feat_flow_unfold_spatial[:, :forward_point+3*self.flow_num]], dim=1)
                    mlvl_feat_flow_unfold_spatial_l = mlvl_feat_flow_unfold_spatial[:, forward_point+3*self.flow_num:forward_point].flip(dims=[1])
                # [B, 2, 66, T, C]
                mlvl_feat_flow_unfold_spatial = torch.stack([mlvl_feat_flow_unfold_spatial_l, mlvl_feat_flow_unfold_spatial_r], dim=1)
                # split with spatial patch (len 6)
                mlvl_feat_flow_unfold_spatial = rearrange(mlvl_feat_flow_unfold_spatial, 'b n (y p) t c -> (b n p) y t c', y=self.flow_num*3//self.spat_size, p=self.spat_size)
                
                # spatial world model
                output_wm_spatial, mlvl_feat_flow_unfold_fut_spatial = self.world_models_spatial[layer_idx](
                    mlvl_feat_flow_unfold_spatial,
                    flow_query_spat)
                outputs_wm_spatial.append(output_wm_spatial)

                # enhance spatial messages with spatial prediction
                # undo the left/right split: un-flip the left half, append it
                # to the right half, then rotate by forward_point so the ring
                # starts at index 0 again
                mlvl_feat_flow_unfold_fut_spatial = rearrange(mlvl_feat_flow_unfold_fut_spatial, '(b n p) y t c -> b n (y p) t c', b=B, n=2, p=self.spat_size)
                mlvl_feat_flow_unfold_fut_spatial = torch.cat([mlvl_feat_flow_unfold_fut_spatial[:,1], mlvl_feat_flow_unfold_fut_spatial[:,0].flip(dims=[1])], dim=1)
                mlvl_feat_flow_unfold_fut_spatial = torch.cat([mlvl_feat_flow_unfold_fut_spatial[:,-forward_point:], mlvl_feat_flow_unfold_fut_spatial[:,:-forward_point]], dim=1)
                mlvl_feat_flow_unfold_fut_spatial = rearrange(mlvl_feat_flow_unfold_fut_spatial, 'b (n y) t c -> (b t n) y c', n=6, y=self.flow_num)
                mlvl_feat_flow_unfold_fuse = torch.cat([rearrange(mlvl_feat_flow_unfold, 'b c h y p -> b h y p c'),
                                                        mlvl_feat_flow_unfold_fut_spatial.unsqueeze(2).unsqueeze(1).repeat(1,H,1,flow_patch,1)], dim=-1)
                mlvl_feat_flow_unfold = rearrange(mlvl_feat_flow_unfold_fuse, 'b h y p c -> b c h y p', b=B*TN, h=H, y=W//flow_patch)  # [BTN, 2C, H, W/patch, patch]
                mlvl_feat_flow_fold = F.fold(mlvl_feat_flow_unfold.transpose(3,4).flatten(1,3), output_size=(H,W), kernel_size=(H, flow_patch), stride=flow_patch)
                mlvl_feat_flow_fold = self.flow_wm_convs_spatial[layer_idx](mlvl_feat_flow_fold)
            else:
                # no world model for this level: just fold patches back to [BTN, C, H, W]
                mlvl_feat_flow_fold = F.fold(mlvl_feat_flow_unfold.transpose(3,4).flatten(1,3), output_size=(H,W), kernel_size=(H, flow_patch), stride=flow_patch)
            mlvl_feats_flow.append(mlvl_feat_flow_fold.reshape(B, TN, C, H, W))

        occ_preds, mask_preds, class_preds = self.transformer(mlvl_feats_flow, img_metas=img_metas)

        return {
            'occ_preds': occ_preds, 
            'mask_preds': mask_preds, 
            'class_preds': class_preds,
            'wm_temporal': outputs_wm_temporal,
            'wm_spatial': outputs_wm_spatial,
        }

    # apply_to must be an iterable of argument names; the trailing comma makes
    # this a 1-tuple. A bare parenthesised string would turn the membership
    # test into a substring match against 'preds_dicts'.
    @force_fp32(apply_to=('preds_dicts',))
    def loss(self, voxel_semantics, voxel_instances, instance_class_ids, preds_dicts, mask_camera=None):
        """Compute the training losses (``preds_dicts`` is cast by ``force_fp32``).

        Thin wrapper that forwards every argument unchanged to ``loss_single``.

        Args:
            voxel_semantics: ground-truth semantic labels.
            voxel_instances: ground-truth instance ids.
            instance_class_ids: class id of each ground-truth instance.
            preds_dicts: prediction dict produced by ``forward``.
            mask_camera: optional boolean visibility mask.

        Returns:
            Whatever ``loss_single`` returns.
        """
        return self.loss_single(voxel_semantics, voxel_instances, instance_class_ids, preds_dicts, mask_camera)

    def loss_single(self, voxel_semantics, voxel_instances, instance_class_ids, preds_dicts, mask_camera=None):
        """Compute the training losses for one batch of predictions.

        Three groups of terms are written into the returned dictionary:

        * per entry ``i`` of ``preds_dicts['occ_preds']``: Lovasz-softmax and
          cross-entropy on the sparse semantic predictions, plus the optional
          ``loss_geo_scal`` / ``loss_sem_scal`` criteria when they are
          configured, each averaged over the batch and stored as ``<name>_<i>``;
        * per entry ``i`` of ``preds_dicts['mask_preds']``: ``loss_mask_<i>``,
          ``loss_dice_mask_<i>`` and ``loss_class_<i>`` returned by
          ``self.criterions['loss_mask2former']`` for the assignment computed by
          ``self.matcher``;
        * world-model terms: ``l<idx>.loss_wm_t`` / ``l<idx>.loss_wm_s`` for each
          entry of ``preds_dicts['wm_temporal']`` / ``preds_dicts['wm_spatial']``
          and their sums ``loss_wm_t`` / ``loss_wm_s``.

        Args:
            voxel_semantics: Ground-truth semantic labels. Sparse targets with a
                label of 255 or above are excluded from the cross-entropy and
                scal terms (the Lovasz term receives the unfiltered targets).
            voxel_instances: Ground-truth instance ids, indexed here as
                ``[batch, x, y, z]``; id 255 denotes empty space.
            instance_class_ids: Per-sample sequence; entry ``b`` holds one class
                id per ground-truth instance of sample ``b``. The caller's
                sequence is not modified.
            preds_dicts (dict): Must provide ``'occ_preds'`` (sequence of
                5-tuples ``(occ_loc, _, seg_pred, _, scale)``), ``'mask_preds'``
                and ``'class_preds'`` (index-aligned sequences), and
                ``'wm_temporal'`` / ``'wm_spatial'`` (sequences of dicts with
                ``'prior'`` and ``'posterior'`` entries).
            mask_camera: Optional bool tensor with the same shape as
                ``voxel_semantics``; sampled at the final occupancy locations
                and forwarded to the matcher and the mask criterion.

        Returns:
            dict: Mapping from loss name to loss value.
        """
        loss_dict = {}
        B = voxel_instances.shape[0]

        if mask_camera is not None:
            assert mask_camera.shape == voxel_semantics.shape
            assert mask_camera.dtype == torch.bool

        # ---- semantic losses on every sparse occupancy level -----------------
        for i, (occ_loc_i, _, seg_pred_i, _, scale) in enumerate(preds_dicts['occ_preds']):
            loss_dict_i = {}
            for b in range(B):
                loss_dict_i_b = {}
                seg_pred_i_sparse, voxel_semantics_sparse, sparse_mask = get_voxel_decoder_loss_input(
                    voxel_semantics[b:b + 1],
                    occ_loc_i[b:b + 1],
                    seg_pred_i[b:b + 1] if seg_pred_i is not None else None,
                    scale,
                    self.num_classes
                )

                loss_dict_i_b['loss_sem_lovasz'] = lovasz_softmax(torch.softmax(seg_pred_i_sparse, dim=1), voxel_semantics_sparse)

                valid_mask = (voxel_semantics_sparse < 255)
                seg_pred_i_sparse = seg_pred_i_sparse[valid_mask].transpose(0, 1).unsqueeze(0)  # [K, CLS] -> [B, CLS, K]
                voxel_semantics_sparse = voxel_semantics_sparse[valid_mask].unsqueeze(0)  # [K] -> [B, K]

                if 'loss_geo_scal' in self.criterions:
                    loss_dict_i_b['loss_geo_scal'] = self.criterions['loss_geo_scal'](seg_pred_i_sparse, voxel_semantics_sparse)
                if 'loss_sem_scal' in self.criterions:
                    loss_dict_i_b['loss_sem_scal'] = self.criterions['loss_sem_scal'](seg_pred_i_sparse, voxel_semantics_sparse)

                loss_dict_i_b['loss_sem_ce'] = CE_ssc_loss(seg_pred_i_sparse, voxel_semantics_sparse, self.class_weights.type_as(seg_pred_i_sparse))

                # accumulate the batch mean of every term
                for loss_key, loss_value in loss_dict_i_b.items():
                    loss_dict_i[loss_key] = loss_dict_i.get(loss_key, 0) + loss_value / B

            for loss_key, loss_value in loss_dict_i.items():
                loss_dict[f'{loss_key}_{i}'] = loss_value

        # ---- sample the ground truth at the final predicted locations --------
        occ_loc = preds_dicts['occ_preds'][-1][0]

        batch_idx = torch.arange(B, device=occ_loc.device)[:, None, None].expand(B, occ_loc.shape[1], 1)
        occ_loc = occ_loc.reshape(-1, 3)
        # Advanced indexing returns a new tensor, so the caller's
        # ``voxel_instances`` is not touched by the remapping below.
        voxel_instances = voxel_instances[batch_idx.reshape(-1), occ_loc[..., 0], occ_loc[..., 1], occ_loc[..., 2]]
        voxel_instances = voxel_instances.reshape(B, -1)  # [B, N]

        if mask_camera is not None:
            mask_camera = mask_camera[batch_idx.reshape(-1), occ_loc[..., 0], occ_loc[..., 1], occ_loc[..., 2]]
            mask_camera = mask_camera.reshape(B, -1)  # [B, N]

        # drop instances if it has no positive voxels
        # Work on a shallow copy: entries are replaced below and the caller's
        # sequence must stay intact so that repeated calls see the same input.
        instance_class_ids = list(instance_class_ids)
        for b in range(B):
            instance_count = instance_class_ids[b].shape[0]
            # ``minlength=256`` guarantees that slot 255 exists even when none of
            # the sampled voxels is empty, and that the per-instance slice below
            # is never shorter than ``instance_count``.
            instance_voxel_counts = torch.bincount(voxel_instances[b].long(), minlength=256)
            # compact the surviving ids to 0..K-1
            id_map = torch.cumsum(instance_voxel_counts > 0, dim=0) - 1
            id_map[255] = 255  # empty space still has an id of 255
            voxel_instances[b] = id_map[voxel_instances[b].long()]
            instance_class_ids[b] = instance_class_ids[b][instance_voxel_counts[:instance_count] > 0]

        # ---- mask / dice / classification losses per decoder layer -----------
        for i, pred in enumerate(preds_dicts['mask_preds']):
            indices = self.matcher(pred, preds_dicts['class_preds'][i], voxel_instances, instance_class_ids, mask_camera)
            loss_mask, loss_dice, loss_class = self.criterions['loss_mask2former'](
                pred, preds_dicts['class_preds'][i], voxel_instances, instance_class_ids, indices, mask_camera)
            loss_dict[f'loss_mask_{i:d}'] = loss_mask
            loss_dict[f'loss_dice_mask_{i:d}'] = loss_dice
            loss_dict[f'loss_class_{i:d}'] = loss_class

        # ---- world-model losses ----------------------------------------------
        # NOTE(review): the per-layer entries and their sum all carry "loss" in
        # the key. If the training loop totals every such entry (mmdet's
        # ``_parse_losses`` appears to), these terms are counted twice -
        # confirm that this weighting is intended.
        outputs_wm_temporal = preds_dicts['wm_temporal']
        losses_wm_temporal = []
        for idx, output_wm_temporal in enumerate(outputs_wm_temporal):
            # NaN/inf values are replaced so that they cannot poison the total
            value = torch.nan_to_num(self.loss_wm_temporal(output_wm_temporal['prior'], output_wm_temporal['posterior']))
            loss_dict[f'l{idx}.loss_wm_t'] = value
            losses_wm_temporal.append(value)
        loss_dict['loss_wm_t'] = sum(losses_wm_temporal)

        outputs_wm_spatial = preds_dicts['wm_spatial']
        losses_wm_spatial = []
        for idx, output_wm_spatial in enumerate(outputs_wm_spatial):
            value = torch.nan_to_num(self.loss_wm_spatial(output_wm_spatial['prior'], output_wm_spatial['posterior']))
            loss_dict[f'l{idx}.loss_wm_s'] = value
            losses_wm_spatial.append(value)
        loss_dict['loss_wm_s'] = sum(losses_wm_spatial)
        return loss_dict
    
    def merge_occ_pred(self, outs):
        """Attach the merged semantic (and optionally panoptic) results to ``outs``.

        The last entries of ``outs['class_preds']`` and ``outs['mask_preds']``
        are passed through a sigmoid and fused by ``merge_semseg``; when
        ``self.panoptic`` is set, ``merge_panoseg`` is run on the same tensors
        afterwards.

        Args:
            outs (dict): Head outputs with ``'class_preds'``, ``'mask_preds'``
                and ``'occ_preds'`` entries. Modified in place.

        Returns:
            dict: The same ``outs`` object with ``'sem_pred'`` and ``'occ_loc'``
            added, plus ``'pano_inst'`` and ``'pano_sem'`` in panoptic mode.
        """
        class_scores = outs['class_preds'][-1].sigmoid()
        mask_scores = outs['mask_preds'][-1].sigmoid()
        final_occ_loc = outs['occ_preds'][-1][0]

        # semantic merge runs first; the panoptic merge reuses the same tensors
        outs['sem_pred'] = self.merge_semseg(class_scores, mask_scores)
        outs['occ_loc'] = final_occ_loc

        if not self.panoptic:
            return outs

        outs['pano_inst'], outs['pano_sem'] = self.merge_panoseg(class_scores, mask_scores)
        return outs
    
    # https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/mask_former_model.py#L242
    def merge_semseg(self, mask_cls, mask_pred):
        """Fuse per-query class scores and mask scores into one label per point.

        Queries whose best class score does not exceed ``self.score_threshold``
        are ignored. The remaining queries vote through
        ``sum_q mask_cls[b, q, c] * mask_pred[b, q, n]``; if the class axis has
        ``self.num_classes`` entries, the last one is dropped before the argmax.
        Points whose winning score is below 0.01 are assigned
        ``self.num_classes - 1``.

        Args:
            mask_cls (Tensor): Class scores of shape ``[B, Q, C]``. Not modified.
            mask_pred (Tensor): Mask scores of shape ``[B, Q, N]``.

        Returns:
            Tensor: Class indices of shape ``[B, N]``.
        """
        valid_mask = mask_cls.max(dim=-1).values > self.score_threshold
        # Out-of-place on purpose: the caller keeps using ``mask_cls`` after this
        # call, so its contents must not be altered as a side effect.
        mask_cls = mask_cls.masked_fill(~valid_mask.unsqueeze(-1), 0.0)

        semseg = torch.einsum("bqc,bqn->bcn", mask_cls, mask_pred)
        if semseg.shape[1] == self.num_classes:
            semseg = semseg[:, :-1]

        cls_score, cls_id = torch.max(semseg, dim=1)
        # points that no query claims with enough confidence
        cls_id[cls_score < 0.01] = self.num_classes - 1
        return cls_id  # [B, N]
    
    def merge_panoseg(self, mask_cls, mask_pred):
        """Run ``merge_panoseg_single`` on every sample and stack the results.

        Args:
            mask_cls (Tensor): Per-query class scores, batch on dim 0.
            mask_pred (Tensor): Per-query mask scores, batch on dim 0.

        Returns:
            tuple: ``(pano_inst, pano_sem)``, each the concatenation along dim 0
            of the per-sample outputs.
        """
        batch_size = mask_cls.shape[0]
        # each sample is passed as a size-1 batch slice
        per_sample = [
            self.merge_panoseg_single(mask_cls[idx:idx + 1], mask_pred[idx:idx + 1])
            for idx in range(batch_size)
        ]

        inst_parts = [inst for inst, _ in per_sample]
        sem_parts = [sem for _, sem in per_sample]
        return torch.cat(inst_parts, dim=0), torch.cat(sem_parts, dim=0)

    # https://github.com/facebookresearch/Mask2Former/blob/main/mask2former/maskformer_model.py#L286
    def merge_panoseg_single(self, mask_cls, mask_pred):
        """Build instance and semantic label maps for a single sample.

        Queries are kept when their best class is not ``self.num_classes - 1``
        and its score exceeds ``self.score_threshold``. Every point is assigned
        to the kept query with the highest ``score * mask`` value; a query then
        labels the points it owns whose mask value is at least 0.5, provided the
        ratio of owned points to its thresholded mask area reaches
        ``self.overlap_threshold``. Segments of the same non-"thing" class share
        one segment id.

        Args:
            mask_cls (Tensor): Class scores of shape ``[1, Q, C]``.
            mask_pred (Tensor): Mask scores of shape ``[1, Q, N]``.

        Returns:
            tuple: ``(instance_seg, semantic_seg)``, both int32 of shape
            ``[1, N]``. Unassigned points have segment id 0 and semantic label
            ``self.num_classes - 1``.

        Raises:
            AssertionError: If the batch dimension is not 1.
        """
        assert mask_cls.shape[0] == 1, "bs != 1"
        scores, labels = mask_cls.max(-1)

        # filter out low score and background instances
        keep = labels.ne(self.num_classes - 1) & (scores > self.score_threshold)
        cur_scores = scores[keep]
        cur_classes = labels[keep]
        cur_masks = mask_pred[keep]

        cur_prob_masks = cur_scores.view(-1, 1) * cur_masks

        N = cur_masks.shape[-1]
        instance_seg = torch.zeros((N,), dtype=torch.int32, device=cur_masks.device)
        semantic_seg = torch.full((N,), self.num_classes - 1, dtype=torch.int32, device=cur_masks.device)

        # skip all process if no mask is detected
        if cur_masks.shape[0] == 0:
            return instance_seg.unsqueeze(0), semantic_seg.unsqueeze(0)

        # moving objects are treated as instances
        thing_classes = (
            'car', 'truck', 'construction_vehicle', 'bus',
            'trailer', 'motorcycle', 'bicycle', 'pedestrian'
        )

        current_segment_id = 0
        # class id -> segment id shared by all segments of that stuff class
        stuff_memory_list = {self.num_classes - 1: 0}

        # take argmax: the query that owns each point
        cur_mask_ids = cur_prob_masks.argmax(0)  # [N]
        for k in range(cur_classes.shape[0]):
            pred_class = cur_classes[k].item()
            is_thing = self.class_names[pred_class] in thing_classes

            owned = cur_mask_ids == k
            confident = cur_masks[k] >= 0.5
            mask_area = owned.sum().item()
            original_area = confident.sum().item()
            mask = owned & confident

            if mask_area == 0 or original_area == 0 or mask.sum().item() == 0:
                continue
            if mask_area / original_area < self.overlap_threshold:
                continue

            # merge stuff regions
            if not is_thing:
                if pred_class in stuff_memory_list:
                    instance_seg[mask] = stuff_memory_list[pred_class]
                    # The semantic label must be written as well; otherwise the
                    # merged points keep the default "empty" class while
                    # carrying a non-zero segment id.
                    semantic_seg[mask] = pred_class
                    continue
                stuff_memory_list[pred_class] = current_segment_id + 1

            current_segment_id += 1
            instance_seg[mask] = current_segment_id
            semantic_seg[mask] = pred_class

        instance_seg = instance_seg.unsqueeze(0)
        semantic_seg = semantic_seg.unsqueeze(0)

        return instance_seg, semantic_seg  # [B, N]
