import math
import pdb
import time

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule
from perceptron.utils.ParameterizeLane import CustomParameterizeLane
from mmdet3d.models import ResNet, build_backbone, build_neck
from mmdet3d.models.builder import BACKBONES
from mmdet.models.backbones.resnet import BasicBlock
from mmdet.models.utils import ResLayer
from perceptron_ops.voxel_pooling import voxel_pooling
from scipy.optimize import linear_sum_assignment

from perceptron.utils import torch_dist as dist
from perceptron.models.lane3d.depthnet.aspp_depthnet import AsppConv2d, DCNAspp
from perceptron.models.lane3d.depthnet.bevdepth_depthnet import ASPP
from .backbone.convnext import convnext_base
from .backbone.efficientnet import EfficientNet
from .backbone.vovnet import VoVNet

#total_time = 0
#total_cnt = 0

class BasicLayers(nn.Module):
    """Optional ASPP block followed by a chain of residual stages.

    Args:
        cfg: Config object read by attribute. Fields used here:
            ``bev_aspp`` -- ``'aspp'`` builds an ``ASPP`` block that is applied
            before the residual stages; ``'none'`` disables it.
            ``in_channels`` -- channel count fed to the first stage (and used
            as both input and output width of the ASPP block).
            ``channels`` / ``num_blocks`` / ``strides`` -- per-stage output
            channels, block counts and strides, indexed in lockstep.
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg

        # The ASPP block only exists for the 'aspp' setting; forward() checks
        # for its presence before using it.
        if cfg.bev_aspp == 'aspp':
            self.bev_aspp = ASPP(cfg.in_channels, cfg.in_channels)

        self.res_layers = nn.ModuleList()
        inplanes = cfg.in_channels
        for i, planes in enumerate(cfg.channels):
            self.res_layers.append(
                ResLayer(
                    block=BasicBlock,
                    inplanes=inplanes,
                    planes=planes,
                    num_blocks=cfg.num_blocks[i],
                    stride=cfg.strides[i],
                )
            )
            # Each stage consumes the previous stage's output width.
            inplanes = planes

    def forward(self, x):
        """Run the optional ASPP block, then every residual stage in order.

        Returns:
            list: the output of each residual stage, first stage first.
        """
        # Previously this applied ``self.bev_aspp`` for every value other than
        # 'none', which raised AttributeError when the constructor had not
        # built the block (it is only built for 'aspp'). Apply it only when it
        # actually exists.
        aspp = getattr(self, 'bev_aspp', None)
        if self.cfg.bev_aspp != 'none' and aspp is not None:
            x = aspp(x)

        out = []
        for layer in self.res_layers:
            x = layer(x)
            out.append(x)
        return out


class LSSViewTransformer(nn.Module):
    """Lift image features into a depth-weighted volume and pool it into voxels.

    Args:
        cfg: Config object read by attribute. Fields used here:
            ``voxel_range`` -- six values laid out as
            ``[lo0, hi0, lo1, hi1, lo2, hi2]`` (read with stride 2);
            ``voxel_size`` -- per-axis voxel edge length (three values);
            ``depth_range`` / ``depth_step`` -- define the number of depth bins;
            ``in_channels`` / ``out_channels`` -- feature widths;
            ``head_type`` -- ``'aspp_conv'``, ``'dcn_aspp'`` or anything else
            for a plain 1x1 convolution.
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg

        # Stored as frozen parameters so they follow the module across devices.
        lower_bounds = cfg.voxel_range[::2]
        self.voxel_lb = nn.Parameter(torch.Tensor(lower_bounds), requires_grad=False)
        self.voxel_size = nn.Parameter(torch.Tensor(cfg.voxel_size), requires_grad=False)

        # NOTE(review): floor division on float ranges/sizes can land one cell
        # short through rounding error -- confirm the configured values divide
        # exactly.
        cells_per_axis = []
        for axis in range(3):
            low = cfg.voxel_range[2 * axis]
            high = cfg.voxel_range[2 * axis + 1]
            cells_per_axis.append((high - low) // cfg.voxel_size[axis])
        self.num_voxel = nn.Parameter(torch.LongTensor(cells_per_axis), requires_grad=False)

        depth_span = cfg.depth_range[1] - cfg.depth_range[0]
        self.depth_channels = math.ceil(depth_span / cfg.depth_step)
        self.out_channels = cfg.out_channels

        # The lifting head predicts depth logits and context features together.
        head_type = self.cfg.head_type
        if head_type == 'aspp_conv':
            self.lift_conv = AsppConv2d(cfg.in_channels, self.out_channels, self.depth_channels)
        elif head_type == 'dcn_aspp':
            self.lift_conv = DCNAspp(cfg.in_channels, self.out_channels, self.depth_channels)
        else:
            self.lift_conv = nn.Conv2d(
                cfg.in_channels,
                self.depth_channels + self.out_channels,
                kernel_size=1,
                padding=0,
            )

    def lift(self, x):
        """Turn a (B, C_in, H, W) feature map into a (B, D, H, W, C) volume.

        The head output is split into ``depth_channels`` depth logits and
        ``out_channels`` context features; the softmaxed depth distribution
        weights the context at every depth bin.
        """
        batch, _, height, width = x.shape
        head_out = self.lift_conv(x)
        depth_logits, context = torch.split(head_out, [self.depth_channels, self.out_channels], dim=1)
        depth_prob = depth_logits.softmax(dim=1).type(depth_logits.dtype)

        # Outer product over the depth and channel axes: (B, C, D, H, W).
        volume = depth_prob.unsqueeze(1) * context.unsqueeze(2)
        volume = volume.view(batch, self.out_channels, self.depth_channels, height, width)
        return volume.permute(0, 2, 3, 4, 1)  # [B, D, H, W, C]

    def splat(self, x, frustum_coord):
        """Pool the lifted volume into the voxel grid.

        ``frustum_coord`` is permuted to [B, D, H, W, 3] (per the original
        inline comment), shifted by the grid's lower bound and divided by the
        voxel size to obtain integer voxel indices.
        """
        coords = frustum_coord.permute(0, 3, 1, 2, 4)  # [B, D, H, W, 3]
        # NOTE(review): ``.int()`` truncates toward zero, so coordinates just
        # below the lower bound map to index 0 -- confirm this is intended.
        voxel_index = ((coords - self.voxel_lb) / self.voxel_size).int().contiguous()
        pooled = voxel_pooling(voxel_index, x.contiguous(), self.num_voxel)
        return pooled.contiguous()

    def forward(self, x, frustum_coord):
        """Lift then splat; result shape is (B, C, bev_y, bev_x) per the original comment."""
        return self.splat(self.lift(x), frustum_coord)

class LaneShareConv2D(nn.Module):
    """Apply one small convolution, with shared weights, to every instance's channel group.

    The input channels are treated as ``n_instance`` consecutive groups. Each
    group is moved into the batch dimension, passed through the same
    ``nn.Conv2d`` and moved back, so all instances share one set of weights.

    Args:
        total_in_ch: total input channels; must be divisible by ``n_instance``.
        total_out_ch: total output channels; must be divisible by ``n_instance``.
        kernel_size: kernel size of the shared convolution (no padding is applied).
        groups: accepted so the signature matches how ``nn.Conv2d`` is called
            elsewhere in this file; it is not used.
        n_instance: number of channel groups that share the convolution.
    """

    def __init__(self, total_in_ch, total_out_ch, kernel_size, groups, n_instance = 16):
        super().__init__()
        assert total_in_ch % n_instance == 0 and total_out_ch % n_instance == 0
        self.n_instance = n_instance

        self.conv_in_ch = total_in_ch // n_instance
        self.conv_out_ch = total_out_ch // n_instance
        self.conv2d = nn.Conv2d(self.conv_in_ch, self.conv_out_ch, kernel_size=kernel_size)

    def forward(self, x):
        """Map (B, total_in_ch, H, W) to (B, total_out_ch, H', W')."""
        b, c, h, w = x.shape

        # reshape (not view): inputs such as channel slices or channels_last
        # tensors cannot merge batch and instance dims without a copy, and
        # view would raise on them.
        per_instance = x.reshape(b * self.n_instance, c // self.n_instance, h, w)
        y = self.conv2d(per_instance)
        # Spatial size is read back from the conv output because the kernel
        # may shrink it.
        return y.reshape(b, self.n_instance * self.conv_out_ch, y.shape[-2], y.shape[-1])

class SoftChannelReorder(nn.Module):
    """Mix channels with an input-dependent, softmax-normalised C x C matrix.

    A global descriptor of the input (conv + global max pool) is projected to
    a (C, C) matrix per sample; softmax is taken over the input-channel axis
    and the matrix is applied to the channel vector at every spatial location.

    Args:
        in_ch: number of input (and output) channels.
    """

    def __init__(self, in_ch):
        super().__init__()
        self.in_ch = in_ch

        self.hw_reducer = nn.Sequential(
            ConvModule(in_ch, in_ch, kernel_size=3, norm_cfg=dict(type="BN2d")),
            nn.AdaptiveMaxPool2d((1, 1))
        )
        self.fc = nn.Linear(in_ch, in_ch * in_ch)

    def forward(self, x):
        """Return the channel-mixed feature map.

        Args:
            x: tensor of shape (B, C, H, W) with ``C == in_ch``.

        Returns:
            Tensor of shape (B, C, H, W). The original implementation computed
            the mixed tensor but never returned it (the module yielded None);
            the result is returned here in the input's layout.
        """
        b = x.shape[0]

        a_x = self.hw_reducer(x).squeeze(-1).squeeze(-1)    # a_x shape: (B, C)
        f_x = self.fc(a_x).view(b, self.in_ch, self.in_ch).softmax(dim = 1) # Left shape: (B, C, C)

        # (B, H, W, 1, C) @ (B, 1, 1, C, C) -> (B, H, W, 1, C)
        y = x.permute(0, 2, 3, 1).unsqueeze(-2) @ f_x[:, None, None]
        # Drop the helper axis and restore channels-first layout.
        return y.squeeze(-2).permute(0, 3, 1, 2).contiguous()

class RowAnchorHead(nn.Module):
    def __init__(self, cfg):
        """Build the shared convolution and the per-quantity prediction heads.

        Args:
            cfg: dict-like config that supports both key and attribute access.
                Fields read here: ``in_channels``, ``share_conv_channel``,
                ``head_conv``, ``n_instances``, ``n_classes``,
                ``use_dual_dire_head``, and the optional flags ``group_conv``,
                ``share_final_conv``, ``scls_head`` and ``pos_emb``
                (``y_steps`` / ``x_steps`` are read when the latter two are on).

        With ``use_dual_dire_head`` off a single ``branch_head`` dict holds the
        ``obj`` / ``cls`` / ``vis`` / ``row`` / ``reg_x`` / ``reg_z`` heads.
        With it on, two dicts (``hy_branch_head`` and ``hx_branch_head``) hold
        the same set of heads with ``hy_`` / ``hx_`` prefixes, and the shared
        convolution outputs twice as many channels.
        """
        super().__init__()
        self.cfg = cfg

        self.ParameterizeLane = CustomParameterizeLane(method = 'bezier_Endpointfixed', method_para = dict(n_control=5))

        def enabled(key):
            # Optional flags may be missing from the config; missing means off.
            return key in cfg.keys() and cfg[key]

        conv_group = cfg['n_instances'] if enabled('group_conv') else 1
        # NOTE(review): LaneShareConv2D ignores ``groups`` and is built with its
        # default ``n_instance`` -- confirm that matches ``cfg.n_instances``.
        final_conv = LaneShareConv2D if enabled('share_final_conv') else nn.Conv2d

        # The flag is interpreted by truthiness everywhere (here, in match(),
        # loss() and post_process()). The previous ``== False`` / ``== True``
        # comparisons silently built no heads for values such as None.
        dual = bool(cfg.use_dual_dire_head)
        share_conv_out_ch = 2 * cfg.share_conv_channel if dual else cfg.share_conv_channel

        self.share_conv = ConvModule(
            cfg.in_channels,
            share_conv_out_ch,
            kernel_size=3,
            padding=1,
            groups = conv_group,
            norm_cfg=dict(type="BN2d"),
        )

        def build_head(out_channels, pool=None):
            # 1x1 ConvModule -> optional adaptive max pool -> final 1x1 conv.
            layers = [
                ConvModule(cfg.share_conv_channel, cfg.head_conv, kernel_size=1, groups = conv_group, norm_cfg=dict(type="BN2d")),
            ]
            if pool is not None:
                layers.append(nn.AdaptiveMaxPool2d(pool))
            layers.append(final_conv(cfg.head_conv, out_channels, kernel_size=1, groups = conv_group))
            return nn.Sequential(*layers)

        n_inst = cfg.n_instances
        n_cls_out = cfg.n_instances * cfg.n_classes

        if not dual:
            self.branch_head = nn.ModuleDict()
            self.branch_head.add_module("obj", build_head(n_inst, pool=(1, 1)))

            if enabled('scls_head'):
                # Spatial classification map; it is reduced later by cls_head_mix.
                self.branch_head.add_module("cls", build_head(n_cls_out))
                self.cls_head_mix = nn.Linear(self.cfg['y_steps'].shape[0], 1)   # It can also be implemented with Conv2D. We use Linear for convenience.
            else:
                self.branch_head.add_module("cls", build_head(n_cls_out, pool=(1, 1)))

            # Visibility keeps the first spatial axis and pools the second.
            self.branch_head.add_module("vis", build_head(n_inst, pool=(None, 1)))
            for k in ["row", "reg_x", "reg_z"]:
                self.branch_head.add_module(k, build_head(n_inst))
        else:
            # hy
            self.hy_branch_head = nn.ModuleDict()
            self.hy_branch_head.add_module("hy_obj", build_head(n_inst, pool=(1, 1)))
            self.hy_branch_head.add_module("hy_cls", build_head(n_cls_out, pool=(1, 1)))
            self.hy_branch_head.add_module("hy_vis", build_head(n_inst, pool=(None, 1)))
            for k in ["hy_row", "hy_reg_x", "hy_reg_z"]:
                self.hy_branch_head.add_module(k, build_head(n_inst))

            # hx
            self.hx_branch_head = nn.ModuleDict()
            self.hx_branch_head.add_module("hx_obj", build_head(n_inst, pool=(1, 1)))
            self.hx_branch_head.add_module("hx_cls", build_head(n_cls_out, pool=(1, 1)))
            # hx visibility pools the first spatial axis and keeps the second.
            self.hx_branch_head.add_module("hx_vis", build_head(n_inst, pool=(1, None)))
            for k in ["hx_row", "hx_reg_y", "hx_reg_z"]:
                self.hx_branch_head.add_module(k, build_head(n_inst))

        if enabled('pos_emb'):
            self.pos_emb = nn.Embedding(self.cfg['in_channels'], len(self.cfg['y_steps']) * len(self.cfg['x_steps']))

    @torch.no_grad()
    def match(self, output, targets):
        indices = []
        for b, tgt in enumerate(targets):
            n_tgt = len(tgt["cls"])
            if not self.cfg.use_dual_dire_head:
                n_pred = output['obj'].shape[1]
            else:
                n_pred = output['hy_obj'].shape[1]

            if n_tgt == 0:
                if not self.cfg.use_dual_dire_head:
                    indices.append([[], []])
                    continue
                else:
                    indices.append([[], [], [], []])
                    continue
            
            if not self.cfg.use_dual_dire_head:
                cost = 0
                obj = output["obj"][b].flatten(1)   # obj shape: (n_instance, 1)
                tgt_obj = torch.ones_like(obj)
                cost = cost + F.binary_cross_entropy_with_logits(obj, tgt_obj, reduction="none")

                if 'scls_head' in self.cfg.keys() and self.cfg['scls_head']:
                    out_cls = output['cls'][b][:, None, ...].expand([-1, n_tgt, -1, -1]).reshape(self.cfg.n_instances, \
                        self.cfg.n_classes, n_tgt, output['cls'].shape[-2], output['cls'].shape[-1])    # Left shape: (n_instance, n_cls, n_tgt, H, W)
                    cls_tgt_row = tgt["row"][None, None].expand(self.cfg.n_instances, self.cfg.n_classes, -1, -1).unsqueeze(-1)
                    out_cls = torch.gather(out_cls, -1, cls_tgt_row).squeeze(-1)    # Left shape: (n_instance, n_cls, n_tgt, H)
                    out_cls = self.cls_head_mix(out_cls).squeeze(-1)    # Left shape: (n_instance, n_cls, n_tgt)
                    tgt_cls = targets[b]['cls'][None, ...].expand([self.cfg.n_instances, -1])   # tgt_cls shape: (n_instance, n_tgt)
                    cost = cost + F.cross_entropy(out_cls, tgt_cls, reduction="none")
                else:
                    cls = output["cls"][b].reshape(-1, self.cfg.n_classes, 1).expand([-1, -1, n_tgt])   # Left shape: (n_instance, n_cls, n_tgt)   
                    tgt_cls = tgt["cls"][None, ...].expand([self.cfg.n_instances, -1])  # tgt_cls shape: (n_instance, n_tgt)
                    cost = cost + F.cross_entropy(cls, tgt_cls, reduction="none")   # Right F.cross_entropy shape: (n_instance, n_tgt)

                vis = output["vis"][b][:, None, :, 0].expand([-1, n_tgt, -1])   # Left shape: (n_instance, n_tgt, H)
                tgt_vis = tgt["vis"][None, ...].expand([self.cfg.n_instances, -1, -1]).float()  # tgt_vis shape: (n_instance, n_tgt, H)
                cost = cost + F.binary_cross_entropy_with_logits(vis, tgt_vis, reduction="none").mean(-1)

                row = output["row"][b][:, None, ...].expand([-1, n_tgt, -1, -1])    # Left shape: (n_instance, n_tgt, H, W)
                tgt_row = tgt["row"][None, ...].expand([self.cfg.n_instances, -1, -1])  # Left shape: (n_instance, n_tgt, H)
                cost_row_all = F.cross_entropy(row.flatten(0, 2), tgt_row.flatten(), reduction="none").reshape(
                    self.cfg.n_instances, n_tgt, -1
                )   # row.flatten(0, 2) shape: (n_instance * n_tgt * H, W), tgt_row shape: (n_instance * n_tgt * H), cost_row_all shape: (n_instance, n_tgt, H)
                cost = cost + (cost_row_all * tgt_vis).sum(-1) / tgt_vis.sum(-1)

                for reg_key in ["reg_x", "reg_z"]:
                    reg = (
                        output[reg_key][b][:, None, ...].expand([-1, n_tgt, -1, -1]).gather(-1, tgt_row[..., None])[..., 0]
                    )
                    tgt_reg = targets[b][reg_key][None, ...].expand([self.cfg.n_instances, -1, -1])
                    cost = cost + (F.l1_loss(reg, tgt_reg) * tgt_vis).sum(-1) / tgt_vis.sum(-1)

                indices.append(linear_sum_assignment(cost.detach().cpu().numpy()))

            else:
                # Obj
                hx_cost = 0
                hx_obj = output["hx_obj"][b].flatten(1)   # obj shape: (n_instance, 1)
                hx_tgt_obj = torch.ones_like(hx_obj)
                hx_cost = hx_cost + F.binary_cross_entropy_with_logits(hx_obj, hx_tgt_obj, reduction="none")

                hy_cost = 0
                hy_obj = output["hy_obj"][b].flatten(1)   # obj shape: (n_instance, 1)
                hy_tgt_obj = torch.ones_like(hy_obj)
                hy_cost = hy_cost + F.binary_cross_entropy_with_logits(hy_obj, hy_tgt_obj, reduction="none")

                # Classification
                hx_cls = output["hx_cls"][b].reshape(-1, self.cfg.n_classes, 1).expand([-1, -1, n_tgt])   # Left shape: (n_instance, n_cls, n_tgt)   
                tgt_cls = tgt["cls"][None, ...].expand([self.cfg.n_instances, -1])  # tgt_cls shape: (n_instance, n_tgt)
                hx_cost = hx_cost + F.cross_entropy(hx_cls, tgt_cls, reduction="none")

                hy_cls = output["hy_cls"][b].reshape(-1, self.cfg.n_classes, 1).expand([-1, -1, n_tgt])   # Left shape: (n_instance, n_cls, n_tgt)   
                hy_cost = hy_cost + F.cross_entropy(hx_cls, tgt_cls, reduction="none")

                # Visibility
                hx_vis = output["hx_vis"][b][:, None, 0, :].expand([-1, n_tgt, -1])   # Left shape: (n_instance, n_tgt, bev_x)
                hx_tgt_vis = tgt["hx_vis"][None, ...].expand([n_pred, -1, -1]).float()  # tgt_vis shape: (n_instance, n_tgt, bev_x)
                hx_cost = hx_cost + F.binary_cross_entropy_with_logits(hx_vis, hx_tgt_vis, reduction="none").mean(-1)

                hy_vis = output["hy_vis"][b][:, None, :, 0].expand([-1, n_tgt, -1])   # Left shape: (n_instance, n_tgt, bev_y)
                hy_tgt_vis = tgt["hy_vis"][None, ...].expand([n_pred, -1, -1]).float()  # tgt_vis shape: (n_instance, n_tgt, bev_y)
                hy_cost = hy_cost + F.binary_cross_entropy_with_logits(hy_vis, hy_tgt_vis, reduction="none").mean(-1)

                # Row
                hx_row = output["hx_row"][b][:, None, ...].expand([-1, n_tgt, -1, -1]).permute(0, 1, 3, 2)    # Left shape: (n_instance, n_tgt, bev_x, bev_y)
                hx_tgt_row = tgt["hx_row"][None, ...].expand([n_pred, -1, -1])  # Left shape: (n_instance, n_tgt, bev_x)
                hx_cost_row_all = F.cross_entropy(hx_row.flatten(0, 2), hx_tgt_row.flatten(), reduction="none")\
                    .reshape(n_pred, n_tgt, -1)   # hx_cost_row_all shape: (n_instance, n_tgt, bev_x)
                hx_cost = hx_cost + (hx_cost_row_all * hx_tgt_vis).sum(-1) / hx_tgt_vis.sum(-1)
                
                hy_row = output["hy_row"][b][:, None, ...].expand([-1, n_tgt, -1, -1])    # Left shape: (n_instance, n_tgt, bev_y, bev_x)
                hy_tgt_row = tgt["hy_row"][None, ...].expand([n_pred, -1, -1])  # Left shape: (n_instance, n_tgt, bev_y)
                hy_cost_row_all = F.cross_entropy(hy_row.flatten(0, 2), hy_tgt_row.flatten(), reduction="none")\
                    .reshape(n_pred, n_tgt, -1)   # hy_cost_row_all shape: (n_instance, n_tgt, bev_y)
                hy_cost = hy_cost + (hy_cost_row_all * hy_tgt_vis).sum(-1) / hy_tgt_vis.sum(-1)

                # reg
                for reg_key in ["hx_reg_y", "hx_reg_z"]:
                    reg = (output[reg_key][b][:, None, ...].expand([-1, n_tgt, -1, -1]).permute(0, 1, 3, 2)\
                        .gather(-1, hx_tgt_row[..., None])[..., 0])   # Left shape: (n_instance, n_tgt, bev_x)
                    tgt_reg = tgt[reg_key][None, ...].expand([n_pred, -1, -1])   # Left shape: (n_instance, n_tgt, reg_x)
                    hx_cost = hx_cost + (F.l1_loss(reg, tgt_reg) * hx_tgt_vis).sum(-1) / hx_tgt_vis.sum(-1)

                for reg_key in ["hy_reg_x", "hy_reg_z"]:
                    reg = (output[reg_key][b][:, None, ...].expand([-1, n_tgt, -1, -1])\
                        .gather(-1, hy_tgt_row[..., None])[..., 0])   # Left shape: (n_instance, n_tgt, reg_x)
                    tgt_reg = tgt[reg_key][None, ...].expand([n_pred, -1, -1])   # Left shape: (n_instance, n_tgt, reg_y)
                    hy_cost = hy_cost + (F.l1_loss(reg, tgt_reg) * hy_tgt_vis).sum(-1) / hy_tgt_vis.sum(-1)

                hx_cost = torch.nan_to_num(hx_cost.detach().cpu())
                hx_assign = linear_sum_assignment(hx_cost)
                hx_sort_idx = np.argsort(hx_assign[1])
                hx_assign = (hx_assign[0][hx_sort_idx], hx_assign[1][hx_sort_idx])
                hx_valid_point_num = tgt['hx_vis'].float().sum(dim = 1)

                hy_cost = torch.nan_to_num(hy_cost.detach().cpu())
                hy_assign = linear_sum_assignment(hy_cost)
                hy_sort_idx = np.argsort(hy_assign[1])
                hy_assign = (hy_assign[0][hy_sort_idx], hy_assign[1][hy_sort_idx])
                hy_valid_point_num = tgt['hy_vis'].float().sum(dim = 1)

                use_hx_flag = (hx_valid_point_num >= hy_valid_point_num).detach().cpu().numpy()
                hx_assign = (hx_assign[0][use_hx_flag], hx_assign[1][use_hx_flag])
                use_hy_flag = (hx_valid_point_num < hy_valid_point_num).detach().cpu().numpy()
                hy_assign = (hy_assign[0][use_hy_flag], hy_assign[1][use_hy_flag])
                indices.append(hx_assign + hy_assign)

        if not self.cfg.use_dual_dire_head:
            return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices]
        else:
            return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64), torch.as_tensor(k, dtype=torch.int64),\
                torch.as_tensor(h, dtype=torch.int64)) for i, j, k, h in indices]

    def loss(self, output, targets):
        """Compute the per-term training losses for a batch.

        Predictions are paired with targets either by ``self.match`` or, when
        the config has ``hungarian_match`` set to False, by identity order
        (prediction k with target k, capped at 15 lanes).

        Args:
            output: dict of batched head outputs (same keys as used by ``match``).
            targets: one dict of ground-truth tensors per sample.

        Returns:
            dict mapping each loss name to its summed value divided by
            ``max(n_lane, 1)``, where ``n_lane`` is the matched-lane count
            after ``dist.reduce_mean`` (presumably an average over distributed
            workers -- confirm against that helper).
        """
        def masked_mean(values, mask):
            # Mean of ``values`` over the last axis restricted to ``mask``.
            # A lane with no visible samples contributes 0 instead of the NaN
            # that 0/0 produced before; positive denominators are untouched.
            denom = mask.sum(-1)
            denom = torch.where(denom > 0, denom, torch.ones_like(denom))
            return (values * mask).sum(-1) / denom

        n_lane = 0

        if 'hungarian_match' in self.cfg.keys() and self.cfg['hungarian_match'] == False:
            # NOTE(review): this path yields 2-tuples, so it only fits the
            # single-direction branch below; the dual branch unpacks 4 values.
            indices = []
            for tgt in targets:
                valid_tgt_len = min(tgt['cls'].shape[0], 15)
                indice = torch.arange(valid_tgt_len, dtype=torch.int64)
                indices.append((indice, indice))
        else:
            indices = self.match(output, targets)

        if not self.cfg.use_dual_dire_head:
            loss_dict = {k: 0.0 for k in ["obj", "cls", "vis", "row", "reg_x", "reg_z"]}

            for b, (i, j) in enumerate(indices):
                # Objectness: matched predictions are positives, the rest negatives.
                obj = output["obj"][b].flatten()
                tgt_obj = torch.zeros_like(obj)
                tgt_obj[i] = 1.0
                loss_dict["obj"] += F.binary_cross_entropy_with_logits(obj, tgt_obj, reduction="sum")

                tgt_cls = targets[b]["cls"][j]
                if 'scls_head' in self.cfg.keys() and self.cfg['scls_head']:
                    cls_tgt_row = targets[b]["row"][j][:, None, :, None].expand(-1, self.cfg.n_classes, -1, -1)  # Left shape: (n_tgt, n_cls, H, 1)
                    out_cls = output["cls"][b].reshape(self.cfg.n_instances, self.cfg.n_classes, self.cfg.y_steps.shape[0], \
                        self.cfg.x_steps.shape[0])[i]   # Left shape: (n_tgt, n_cls, H, W)
                    # Read class logits at each target's column, then mix over rows.
                    out_cls = torch.gather(out_cls, -1, cls_tgt_row).squeeze(-1)    # Left shape: (n_tgt, n_cls, H)
                    out_cls = self.cls_head_mix(out_cls).squeeze(-1)    # Left shape: (n_tgt, n_cls)
                    loss_dict["cls"] += F.cross_entropy(out_cls, tgt_cls, reduction="sum")
                else:
                    cls = output["cls"][b].reshape(-1, self.cfg.n_classes)[i]   # output["cls"][b] shape: (n_instance, n_cls)
                    loss_dict["cls"] += F.cross_entropy(cls, tgt_cls, reduction="sum")

                vis = output["vis"][b, i].flatten(1)
                tgt_vis = targets[b]["vis"][j]
                loss_vis = F.binary_cross_entropy_with_logits(vis, tgt_vis.float(), reduction="none")
                loss_dict["vis"] += loss_vis.mean(-1).sum()

                row = output["row"][b, i]   # Left shape: (n_tgt, H, W)
                tgt_row = targets[b]["row"][j]
                N, H, W = row.shape
                loss_row = F.cross_entropy(row.reshape(-1, W), tgt_row.flatten(), reduction="none").reshape([N, H])
                loss_dict["row"] += masked_mean(loss_row, tgt_vis).sum()

                for reg_key in ["reg_x", "reg_z"]:
                    # Regression is read at the ground-truth column of every row.
                    reg = output[reg_key][b, i].gather(-1, tgt_row[..., None])[..., 0]
                    tgt_reg = targets[b][reg_key][j]
                    loss_reg = F.l1_loss(reg, tgt_reg, reduction="none")
                    loss_dict[reg_key] += masked_mean(loss_reg, tgt_vis).sum()

                n_lane += len(j)
        else:
            loss_dict = {k: 0.0 for k in ["hx_obj", "hx_cls", "hx_vis", "hx_row", "hx_reg_y", "hx_reg_z", "hy_obj", "hy_cls", "hy_vis",\
                "hy_row", "hy_reg_x", "hy_reg_z", ]}

            for b, (hx_i, hx_j, hy_i, hy_j) in enumerate(indices):
                hx_obj = output["hx_obj"][b].flatten()    # Left shape: (n_instance,)
                hx_tgt_obj = torch.zeros_like(hx_obj) # Left shape: (n_instance,)
                hx_tgt_obj[hx_i] = 1.0
                loss_dict["hx_obj"] += F.binary_cross_entropy_with_logits(hx_obj, hx_tgt_obj, reduction="sum")

                hy_obj = output["hy_obj"][b].flatten()    # Left shape: (n_instance,)
                hy_tgt_obj = torch.zeros_like(hy_obj) # Left shape: (n_instance,)
                hy_tgt_obj[hy_i] = 1.0
                loss_dict["hy_obj"] += F.binary_cross_entropy_with_logits(hy_obj, hy_tgt_obj, reduction="sum")

                hx_cls = output["hx_cls"][b].reshape(-1, self.cfg.n_classes)[hx_i]   # (n_matched, n_cls)
                hx_tgt_cls = targets[b]["cls"][hx_j]
                loss_dict["hx_cls"] += F.cross_entropy(hx_cls, hx_tgt_cls, reduction="sum")

                hy_cls = output["hy_cls"][b].reshape(-1, self.cfg.n_classes)[hy_i]   # (n_matched, n_cls)
                hy_tgt_cls = targets[b]["cls"][hy_j]
                loss_dict["hy_cls"] += F.cross_entropy(hy_cls, hy_tgt_cls, reduction="sum")

                hx_vis = output["hx_vis"][b, hx_i].flatten(1)
                hx_tgt_vis = targets[b]["hx_vis"][hx_j]
                hx_loss_vis = F.binary_cross_entropy_with_logits(hx_vis, hx_tgt_vis.float(), reduction="none")
                loss_dict["hx_vis"] += hx_loss_vis.mean(-1).sum()

                hy_vis = output["hy_vis"][b, hy_i].flatten(1)
                hy_tgt_vis = targets[b]["hy_vis"][hy_j]
                hy_loss_vis = F.binary_cross_entropy_with_logits(hy_vis, hy_tgt_vis.float(), reduction="none")
                loss_dict["hy_vis"] += hy_loss_vis.mean(-1).sum()

                hx_row = output["hx_row"][b, hx_i]   # Left shape: (n_tgt, bev_y, bev_x)
                hx_tgt_row = targets[b]["hx_row"][hx_j] # Left shape: (n_tgt, bev_x)
                n_tgt, bev_y, bev_x = hx_row.shape
                # hx classifies along bev_y for every bev_x sample, hence the transpose.
                hx_loss_row = F.cross_entropy(hx_row.permute(0, 2, 1).reshape(-1, bev_y), hx_tgt_row.flatten(), reduction="none").reshape([n_tgt, bev_x])
                loss_dict["hx_row"] += masked_mean(hx_loss_row, hx_tgt_vis).sum()

                hy_row = output["hy_row"][b, hy_i]   # Left shape: (n_tgt, bev_y, bev_x)
                hy_tgt_row = targets[b]["hy_row"][hy_j] # Left shape: (n_tgt, bev_y)
                n_tgt, bev_y, bev_x = hy_row.shape
                hy_loss_row = F.cross_entropy(hy_row.view(-1, bev_x), hy_tgt_row.flatten(), reduction="none").reshape([n_tgt, bev_y])
                loss_dict["hy_row"] += masked_mean(hy_loss_row, hy_tgt_vis).sum()

                for reg_key in ["hx_reg_y", "hx_reg_z"]:
                    hx_reg = output[reg_key][b, hx_i].permute(0, 2, 1).gather(-1, hx_tgt_row[..., None])[..., 0]
                    hx_tgt_reg = targets[b][reg_key][hx_j]
                    hx_loss_reg = F.l1_loss(hx_reg, hx_tgt_reg, reduction="none")
                    loss_dict[reg_key] += masked_mean(hx_loss_reg, hx_tgt_vis).sum()

                for reg_key in ["hy_reg_x", "hy_reg_z"]:
                    hy_reg = output[reg_key][b, hy_i].gather(-1, hy_tgt_row[..., None])[..., 0]
                    hy_tgt_reg = targets[b][reg_key][hy_j]
                    hy_loss_reg = F.l1_loss(hy_reg, hy_tgt_reg, reduction="none")
                    loss_dict[reg_key] += masked_mean(hy_loss_reg, hy_tgt_vis).sum()

                n_lane = n_lane + len(hx_j) + len(hy_j)

        n_lane = dist.reduce_mean(torch.Tensor([n_lane]).cuda()).item()

        # Normalise by the lane count; the floor of 1 avoids dividing by zero
        # when nothing was matched.
        for k in loss_dict:
            loss_dict[k] /= max(n_lane, 1)

        return loss_dict

    def post_process(self, output):
        if not self.cfg.use_dual_dire_head:
            score = output["obj"].reshape(-1, self.cfg.n_instances).sigmoid()
            if 'scls_head' in self.cfg.keys() and self.cfg['scls_head']:
                cls_row = output['row'][:, :, None].expand(-1, -1, self.cfg.n_classes, -1, -1).argmax(-1).unsqueeze(-1) # Left shape: (bs, n_instance, n_cls, H, 1)
                out_cls = output['cls'].view(output['cls'].shape[0], self.cfg.n_instances, self.cfg.n_classes, output['cls'].shape[-2], output['cls'].shape[-1])
                out_cls = torch.gather(out_cls, -1, cls_row).squeeze(-1)    # Left shape: (bs, n_instance, n_cls, H)
                out_cls = self.cls_head_mix(out_cls).squeeze(-1) # Left shape: (bs, n_instance, n_cls)
                category = out_cls.argmax(-1)    # Left shape: (bs, n_instance)
            else:
                category = output["cls"].reshape(-1, self.cfg.n_instances, self.cfg.n_classes).argmax(-1)   # Left shape: (bs, n_instance)
            vis = output["vis"].flatten(-2) > 0
            batch_pred = []
            for i in range(len(score)): # bs index
                pred = dict(
                    lanes=[],
                    scores=[],
                    categories=[],
                )
                for j in range(self.cfg.n_instances):   # instance index
                    if hasattr(self.cfg.post_process, "score_thresh") and score[i, j] < self.cfg.post_process.score_thresh:
                        continue
                    v = vis[i, j].cpu().detach().numpy()
                    if getattr(self.cfg.post_process, "pad_end", False):
                        v_idx = np.where(v)[0]
                        if len(v_idx) > 0:
                            v[v_idx[0] :] = True
                    rid = output["row"][i, j][v].argmax(-1).detach().cpu().numpy()
                    cid = np.arange(len(rid))
                    x = self.cfg.x_steps[rid]
                    x = x + output["reg_x"][i, j][v].detach().cpu().numpy()[cid, rid] * self.cfg.grid_size[0]
                    y = self.cfg.y_steps[v]
                    z = output["reg_z"][i, j][v].detach().cpu().numpy()[cid, rid]
                    z = z * (self.cfg.z_range[1] - self.cfg.z_range[0]) + self.cfg.z_range[0]
                    pred_lane = np.stack([x, y, z], axis = -1)  # Left shape: (num_points, 3)

                    pred["lanes"].append(pred_lane)
                    pred["scores"].append(score[i, j].item())
                    pred["categories"].append(category[i, j].item())
                batch_pred.append(pred)
        else:
            B, n_pred, bev_y, bev_x = output['hx_row'].shape
            
            # hx confidence
            hx_score = output["hx_obj"].reshape(B, n_pred).sigmoid()   # Left shape: (B, n_instance)
            hx_vis = output["hx_vis"].flatten(-2) > 0 # (B, n_instance, bev_x)
            hx_category = output["hx_cls"].reshape(-1, self.cfg.n_instances, self.cfg.n_classes).argmax(-1)

            # hy confidence
            hy_score = output["hy_obj"].reshape(B, n_pred).sigmoid()   # Left shape: (B, n_instance)
            hy_vis = output["hy_vis"].flatten(-2) > 0 # (B, n_instance, bev_y)
            hy_category = output["hy_cls"].reshape(-1, self.cfg.n_instances, self.cfg.n_classes).argmax(-1)

            batch_pred = []
            for i in range(len(hy_score)): # bs index
                pred = dict(
                    lanes=[],
                    scores=[],
                    categories=[],
                )
                # hy predictions
                for j in range(self.cfg.n_instances):   # instance index
                    if hasattr(self.cfg.post_process, "score_thresh") and hy_score[i, j] < self.cfg.post_process.score_thresh:
                        continue
                    hy_v = hy_vis[i, j].cpu().detach().numpy()
                    if hy_v.sum() < 2: continue
                    hy_rid = output["hy_row"][i, j][hy_v].argmax(-1).detach().cpu().numpy()  # Row ID (select along the x-axis)
                    hy_cid = np.arange(len(hy_rid))
                    hy_x_ref = self.cfg.x_steps[hy_rid]
                    hy_x = hy_x_ref + output["hy_reg_x"][i, j][hy_v].detach().cpu().numpy()[hy_cid, hy_rid] * self.cfg.grid_size[0]
                    hy_y = self.cfg.y_steps[hy_v]
                    hy_z = output["hy_reg_z"][i, j][hy_v].detach().cpu().numpy()[hy_cid, hy_rid]
                    hy_z = hy_z * (self.cfg.z_range[1] - self.cfg.z_range[0]) + self.cfg.z_range[0]
                    pred_lane = np.stack([hy_x, hy_y, hy_z]).T
                    pred["lanes"].append(pred_lane)
                    pred["scores"].append(hy_score[i, j].item())
                    pred["categories"].append(hy_category[i, j].item())
                # hx predictions
                for j in range(self.cfg.n_instances):   # instance index
                    if hasattr(self.cfg.post_process, "score_thresh") and hx_score[i, j] < self.cfg.post_process.score_thresh:
                        continue
                    hx_v = hx_vis[i, j].cpu().detach().numpy()
                    if hx_v.sum() < 2: continue
                    hx_rid = output["hx_row"][i, j].transpose(1, 0)[hx_v].argmax(-1).detach().cpu().numpy()  # Row ID (select along the y-axis)        
                    hx_cid = np.arange(len(hx_rid))
                    hx_y_ref = self.cfg.y_steps[hx_rid]
                    hx_y = hx_y_ref + output["hx_reg_y"][i, j].transpose(1, 0)[hx_v].detach().cpu().numpy()[hx_cid, hx_rid] * self.cfg.grid_size[1]
                    hx_x = self.cfg.x_steps[hx_v]
                    hx_z = output["hx_reg_z"][i, j].transpose(1, 0)[hx_v].detach().cpu().numpy()[hx_cid, hx_rid]
                    hx_z = hx_z * (self.cfg.z_range[1] - self.cfg.z_range[0]) + self.cfg.z_range[0]  
                    pred_lane = np.stack([hx_x, hx_y, hx_z]).T
                    pred["lanes"].append(pred_lane)
                    pred["scores"].append(hx_score[i, j].item())
                    pred["categories"].append(hx_category[i, j].item())
         
                batch_pred.append(pred)
        
        return batch_pred

    def forward(self, x):
        """Run the shared conv stem and every prediction branch on ``x``.

        Args:
            x: Feature tensor passed to ``self.share_conv``. When the
                ``pos_emb`` config flag is set, it must be addable to a
                tensor of shape
                ``(1, in_channels, len(y_steps), len(x_steps))``.

        Returns:
            dict: Branch name -> raw branch output. In dual-direction mode
            the outputs of ``hy_branch_head`` and ``hx_branch_head`` are
            merged into a single dict; on a key clash the ``hx`` entry wins.
        """
        # Optional positional embedding: ``self.pos_emb.weight`` is
        # reinterpreted with ``view`` (no permutation) as a (1, C, Y, X) map.
        if 'pos_emb' in self.cfg.keys() and self.cfg['pos_emb']:
            pos_emb = self.pos_emb.weight[None].view(
                1,
                self.cfg['in_channels'],
                len(self.cfg['y_steps']),
                len(self.cfg['x_steps']),
            )
            x = x + pos_emb

        feat = self.share_conv(x)

        # The explicit comparison is intentional: only a flag that compares
        # equal to False selects the single-head path; every other value
        # (including None) takes the dual-direction path below.
        if self.cfg.use_dual_dire_head == False:  # noqa: E712
            return {k: m(feat) for k, m in self.branch_head.items()}

        # Dual-direction mode: the shared feature is split along the channel
        # axis into exactly two chunks, one per direction.
        hy_feat, hx_feat = torch.split(feat, self.cfg.share_conv_channel, dim=1)
        output = {k: m(hy_feat) for k, m in self.hy_branch_head.items()}
        output.update({k: m(hx_feat) for k, m in self.hx_branch_head.items()})
        return output


class Mlp(nn.Module):
    """Two-layer perceptron: Linear -> activation -> Dropout -> Linear -> Dropout.

    Args:
        in_features: Size of the last dimension of the input.
        hidden_features: Width of the hidden layer; ``in_features`` is used
            when this is falsy.
        out_features: Size of the last dimension of the output;
            ``in_features`` is used when this is falsy.
        act_layer: Zero-argument callable that builds the activation module.
        drop: Dropout probability applied after each linear layer.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.ReLU, drop=0.1):
        super().__init__()
        if not out_features:
            out_features = in_features
        if not hidden_features:
            hidden_features = in_features

        # Sub-modules are registered in the order they are applied.
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop)
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop2 = nn.Dropout(drop)

    def forward(self, x):
        """Apply the five layers in sequence and return the result."""
        hidden = self.drop1(self.act(self.fc1(x)))
        return self.drop2(self.fc2(hidden))


class BEVLaneDet(nn.Module):
    """Lane detector that moves image features into BEV and predicts there.

    Pipeline, as executed by ``forward``: image backbone -> image neck ->
    view transformer -> BEV backbone -> BEV neck -> optional BEV positional
    encoding -> lane detection head.

    Args:
        cfg: Config object exposing the sub-configs ``img_backbone``,
            ``img_neck``, ``view_transformer``, ``bev_backbone``,
            ``bev_neck`` and ``lanedet_head``, plus an optional ``bev_pe``.
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg

        self.img_backbone = self._configure_img_backbone()
        self.img_neck = self._configure_img_neck()
        self.view_transformer = self._configure_view_transformer()
        # The BEV positional encoding is optional; ``forward`` checks for the
        # attribute rather than for the config entry.
        if hasattr(cfg, "bev_pe"):
            self.bev_pe = self._configure_bev_pe()
        self.bev_backbone = self._configure_bev_backbone()
        self.bev_neck = self._configure_bev_neck()
        self.lanedet_head = self._configure_lanedet_head()

    def _configure_img_backbone(self):
        """Build the image backbone selected by ``cfg.img_backbone['type']``.

        ``VoVNet``, ``ConvNext-Base`` and ``EfficientNet-B7`` are constructed
        directly. Any other type is delegated to ``build_backbone`` and has
        ``init_weights`` called on it.
        """
        backbone_cfg = self.cfg.img_backbone
        backbone_type = backbone_cfg['type']

        if backbone_type == 'VoVNet':
            return VoVNet(
                spec_name=backbone_cfg.spec_name,
                norm_eval=backbone_cfg.norm_eval,
                frozen_stages=backbone_cfg.frozen_stages,
                input_ch=backbone_cfg.input_ch,
                out_features=backbone_cfg.out_features,
                pretrained=backbone_cfg.pretrained,
            )
        if backbone_type == 'ConvNext-Base':
            # NOTE: ``pretrained`` and ``in_22k`` are hard-coded to True here;
            # only ``out_indices`` is read from the config.
            return convnext_base(pretrained=True, in_22k=True, out_indices=backbone_cfg['out_indices'])
        if backbone_type == 'EfficientNet-B7':
            return EfficientNet(architecture='EfficientNet-B7', lv6=True, lv5=True, lv4=True, lv3=True)

        m = build_backbone(backbone_cfg)
        m.init_weights()
        return m

    def _configure_img_neck(self):
        """Build the image neck from ``cfg.img_neck`` and initialise its weights."""
        m = build_neck(self.cfg.img_neck)
        m.init_weights()
        return m

    def _configure_view_transformer(self):
        """Build the view transformer; only the ``"LSS"`` type is supported.

        Raises:
            NotImplementedError: If ``cfg.view_transformer.type`` is not ``"LSS"``.
        """
        transformer_type = self.cfg.view_transformer.type
        if transformer_type == "LSS":
            return LSSViewTransformer(self.cfg.view_transformer)
        raise NotImplementedError(f"Unsupported view_transformer type: {transformer_type!r}")

    def _configure_bev_pe(self):
        """Build the BEV positional encoding; only ``"learnable"`` is supported.

        The learnable variant is an ``Mlp`` with 2 input features and
        ``cfg.lanedet_head.in_channels`` output features.

        Raises:
            NotImplementedError: If ``cfg.bev_pe.type`` is not ``"learnable"``.
        """
        pe_type = self.cfg.bev_pe.type
        if pe_type == "learnable":
            return Mlp(2, out_features=self.cfg.lanedet_head.in_channels)
        raise NotImplementedError(f"Unsupported bev_pe type: {pe_type!r}")

    def _configure_bev_backbone(self):
        """Build the BEV backbone.

        ``"BasicLayers"`` is constructed directly; any other type is delegated
        to ``build_backbone`` and has ``init_weights`` called on it.
        """
        if self.cfg.bev_backbone.type == "BasicLayers":
            return BasicLayers(self.cfg.bev_backbone)
        m = build_backbone(self.cfg.bev_backbone)
        m.init_weights()
        return m

    def _configure_bev_neck(self):
        """Build the BEV neck from ``cfg.bev_neck`` and initialise its weights."""
        m = build_neck(self.cfg.bev_neck)
        m.init_weights()
        return m

    def _configure_lanedet_head(self):
        """Build the lane detection head; only ``"RowAnchorHead"`` is supported.

        Raises:
            NotImplementedError: If ``cfg.lanedet_head.type`` is not
                ``"RowAnchorHead"``.
        """
        head_type = self.cfg.lanedet_head.type
        if head_type == "RowAnchorHead":
            return RowAnchorHead(self.cfg.lanedet_head)
        raise NotImplementedError(f"Unsupported lanedet_head type: {head_type!r}")

    def forward(self, batch):
        """Run the full pipeline on one batch.

        Args:
            batch: Mapping with the keys ``"img"`` and ``"frustum_coord"``;
                ``"bev_coord"`` is read when a BEV positional encoding is
                configured, and ``"targets"`` is read in training mode.

        Returns:
            In training mode, the result of ``lanedet_head.loss``; otherwise
            the result of ``lanedet_head.post_process``.
        """
        x = self.img_backbone(batch["img"])

        # For EfficientNet-B7 the first returned feature level is dropped
        # before the neck.
        if self.cfg.img_backbone['type'] == 'EfficientNet-B7':
            x = x[1:]

        # Only the first output of each neck is used.
        x = self.img_neck(x)[0]
        x = self.view_transformer(x, batch["frustum_coord"])
        # The BEV neck receives the view-transformer output followed by every
        # BEV backbone output.
        x = [x, *self.bev_backbone(x)]
        x = self.bev_neck(x)[0]

        if hasattr(self, "bev_pe"):
            # The encoding's last axis is moved to position 1 before the add
            # (presumably channels-last -> channels-first; confirm with the
            # producer of ``bev_coord``).
            x = x + self.bev_pe(batch["bev_coord"]).permute(0, 3, 1, 2)

        output = self.lanedet_head(x)
        if self.training:
            return self.lanedet_head.loss(output, batch["targets"])
        return self.lanedet_head.post_process(output)
