from selectors import EpollSelector
import torch
import copy
import torch.nn as nn
import torch.nn.functional as F
from .ops.modules import MSDeformAttnNew


def inverse_sigmoid(x, eps=1e-5):
    """Numerically safe logit: log(x / (1 - x)).

    The input is first clamped to [0, 1]; both the numerator and the denominator are then
    floored at ``eps`` so the result stays finite at the interval ends.
    """
    prob = torch.clamp(x, 0, 1)
    numerator = torch.clamp(prob, min=eps)
    denominator = torch.clamp(1 - prob, min=eps)
    return torch.log(numerator / denominator)


def _get_clones(module, N):
    return nn.ModuleList([copy.deepcopy(module) for i in range(N)])


def _get_activation_fn(activation):
    """Return an activation function given a string"""
    if activation == "relu":
        return F.relu
    if activation == "gelu":
        return F.gelu
    if activation == "glu":
        return F.glu
    raise RuntimeError(F"activation should be relu/gelu, not {activation}.")


class BEVDecoderLayer(nn.Module):
    """BEV decoder layer built around multi-camera deformable cross-attention.

    Each BEV query is replicated once per camera. Its 3D location is projected into every
    camera with ``coordinate_translation`` to obtain reference points, and the copies are
    attended with ``MSDeformAttnNew``. Outputs are zeroed where the projection is invalid.
    The per-camera results are concatenated and fused back to ``d_model`` by a linear layer,
    followed by a residual connection and LayerNorm.

    NOTE(review): ``self_atten``/``dropout2``/``norm2`` and the FFN sub-layer
    (``linear1``/``linear2``/``dropout3``/``dropout4``/``norm3``, see ``forward_ffn``) are
    constructed but not applied in ``forward``, so they receive no gradients. Presumably they
    are kept for checkpoint compatibility -- confirm. Under DDP this may require
    ``find_unused_parameters=True``.
    """

    def __init__(self, img_size, d_model, d_bound, d_ffn, dropout, activation, n_levels, n_heads, n_points, n_cameras):
        super().__init__()
        self.n_cameras = n_cameras

        # cross attention (project-local deformable attention; d_bound is forwarded as-is)
        self.proj_atten = MSDeformAttnNew(d_model, n_levels, n_heads, n_points, d_bound)
        self.dropout1 = nn.Dropout(dropout)
        self.norm1 = nn.LayerNorm(d_model)

        # self attention (currently unused by forward)
        self.self_atten = nn.MultiheadAttention(d_model, n_heads, dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.norm2 = nn.LayerNorm(d_model)

        # ffn (currently unused by forward)
        self.linear1 = nn.Linear(d_model, d_ffn)
        self.activation = _get_activation_fn(activation)
        self.dropout3 = nn.Dropout(dropout)
        self.linear2 = nn.Linear(d_ffn, d_model)
        self.dropout4 = nn.Dropout(dropout)
        self.norm3 = nn.LayerNorm(d_model)

        # Stored only; forward() uses its own img_size argument.
        self.img_size = img_size
        # Fuses the concatenated per-camera features back down to d_model.
        self.fuse_view_feats = nn.Linear(d_model * n_cameras, d_model)

    @staticmethod
    def with_pos_embed(tensor, pos):
        """Add the positional embedding ``pos`` to ``tensor`` (no-op when ``pos`` is None)."""
        return tensor if pos is None else tensor + pos

    def forward_ffn(self, tgt):
        """Position-wise FFN sub-layer with residual connection and LayerNorm.

        NOTE: with activation ``"glu"`` the hidden size is halved, so ``linear2``
        (which expects ``d_ffn`` inputs) would fail; only relu/gelu fit this layout.
        """
        tgt2 = self.linear2(self.dropout3(self.activation(self.linear1(tgt))))
        tgt = tgt + self.dropout4(tgt2)
        tgt = self.norm3(tgt)
        return tgt

    def forward(self, tgt, query_pos, xx, yy, zz, pc_range, img_size, src_views, src_views_with_rayembed, src_spatial_shapes, level_start_index, lidar2imgs, src_padding_mask=None):
        """Run multi-camera cross-attention for the BEV queries ``tgt``.

        Args:
            tgt: query features; must be 3-D ``(bs, nbins, d_model)`` (the per-camera expand
                below requires it).
            query_pos: positional embedding; only used by the disabled self-attention path
                and may be None.
            xx, yy, zz, pc_range, lidar2imgs, img_size: forwarded to ``coordinate_translation``
                to build reference points and the per-camera validity mask.
            src_views: presumably one entry per feature level (its length sets n_levels).
            src_views_with_rayembed, src_spatial_shapes, level_start_index: forwarded to
                ``proj_atten``.
            src_padding_mask: optional sequence of masks, concatenated along dim 1; None is
                passed straight through to ``proj_atten``.

        Returns:
            Updated ``tgt`` of the same shape.
        """
        # Sizes come from tgt: every reshape below acts on tgt-derived tensors, and
        # query_pos may be None (BEVDecoder's default).
        bs, nbins = tgt.shape[:2]
        n_levels = len(src_views)

        # cross attention: one copy of every query per camera -> (bs * n_cameras, nbins, d_model)
        tgt_expand = tgt.unsqueeze(1).expand(-1, self.n_cameras, -1, -1).flatten(0, 1)
        reference_points, mask = coordinate_translation(xx, yy, zz, pc_range, lidar2imgs, img_size)
        # Use the same reference point at every feature level.
        reference_points = reference_points.unsqueeze(-2).expand(-1, -1, n_levels, -1)
        if src_padding_mask is not None:
            src_padding_mask = torch.cat(src_padding_mask, dim=1)
        tgt2 = self.proj_atten(tgt_expand, reference_points, src_views, src_views_with_rayembed, src_spatial_shapes, level_start_index, src_padding_mask)
        tgt2 = tgt2.view(bs * self.n_cameras, nbins, -1)
        # Zero contributions from cameras where the point projects behind or outside the image.
        tgt2 = tgt2 * mask
        # reduce n_cameras: (bs, n_cam, nbins, d) -> (bs, nbins, n_cam * d) -> d_model
        tgt2 = tgt2.view(bs, self.n_cameras, nbins, -1).permute(0, 2, 1, 3).contiguous().view(bs, nbins, -1)
        tgt2 = self.fuse_view_feats(tgt2)
        tgt = tgt + self.dropout1(tgt2)
        tgt = self.norm1(tgt)

        # NOTE: self-attention and FFN are currently disabled. To re-enable:
        # q = k = self.with_pos_embed(tgt, query_pos)
        # tgt2 = self.self_atten(q.transpose(0, 1), k.transpose(0, 1), tgt.transpose(0, 1))[0].transpose(0, 1)
        # tgt = tgt + self.dropout2(tgt2)
        # tgt = self.norm2(tgt)
        # tgt = self.forward_ffn(tgt)
        return tgt
        
        
        
def coordinate_translation(xx_cord, yy_cord, zz_cord, pc_range, lidar2imgs, img_size):
    """Project normalized 3D query points into every camera's normalized image plane.

    The x/y grids are shared across the batch; the heights ``zz_cord`` are per sample.
    All three are de-normalized with ``pc_range``, which this function reads as
    ``[x_min, x_max, y_min, y_max, z_min, z_max]``. They are then made homogeneous and
    multiplied by every 4x4 matrix in ``lidar2imgs``.

    Args:
        xx_cord, yy_cord: x / y grids (presumably in [0, 1]); repeated over the batch and
            flattened to ``(bs, num_query, 1)``.
        zz_cord: 3-D heights with ``bs`` leading; must line up with the flattened grids as
            ``(bs, num_query, 1)``.
        pc_range: point-cloud range indexed as described above.
        lidar2imgs: projection matrices, reshaped to ``(bs, n_cameras, 1, 4, 4)``.
        img_size: ``img_size[0]`` divides the first pixel coordinate and ``img_size[1]`` the
            second (presumably (W, H) -- confirm against the caller).

    Returns:
        ``(points, mask)``. ``points`` has shape ``(bs * n_cameras, num_query, 2)``, with
        coordinates scaled so the image spans (-1, 1). ``mask`` is a boolean
        ``(bs * n_cameras, num_query, 1)`` tensor that is True where depth > 1e-5 and both
        coordinates lie strictly inside (-1, 1).
    """
    batch, _, _ = zz_cord.shape

    # Share the x/y grids across the batch: (bs, num_query, 1).
    xs = xx_cord.unsqueeze(0).repeat(batch, 1, 1).view(batch, -1, 1)
    ys = yy_cord.unsqueeze(0).repeat(batch, 1, 1).view(batch, -1, 1)

    # De-normalize into metric coordinates.
    xs = xs * (pc_range[1] - pc_range[0]) + pc_range[0]
    ys = ys * (pc_range[3] - pc_range[2]) + pc_range[2]
    zs = zz_cord * (pc_range[5] - pc_range[4]) + pc_range[4]

    # Homogeneous coordinates (x, y, z, 1).
    points = torch.cat([xs, ys, zs], dim=-1)
    points = torch.cat((points, torch.ones_like(points[..., :1])), -1)
    batch, n_query = points.size()[:2]
    n_cams = lidar2imgs.size(1)

    # Pair every point with every camera matrix, then project: (bs * n_cams, n_query, 4).
    points = points.view(batch, 1, n_query, 4).repeat(1, n_cams, 1, 1).unsqueeze(-1)
    projections = lidar2imgs.view(batch, n_cams, 1, 4, 4).repeat(1, 1, n_query, 1, 1)
    cam_points = torch.matmul(projections, points).squeeze(-1).view(batch * n_cams, n_query, 4)

    eps = 1e-5
    depth = cam_points[..., 2:3]
    in_front = depth > eps
    # Perspective divide. Flooring depth at eps avoids division by zero or negative depth;
    # those points are masked out via in_front.
    uv = cam_points[..., 0:2] / torch.maximum(depth, torch.ones_like(depth) * eps)
    for axis in (0, 1):
        uv[..., axis] /= img_size[axis]
    uv = (uv - 0.5) * 2  # [0, 1] -> [-1, 1]

    # Original author's note (translated): consider whether a bound of 1.1 or 1.0 is better.
    u, v = uv[..., 0:1], uv[..., 1:2]
    mask = in_front & (u > -1.) & (u < 1.) & (v > -1.) & (v < 1.)
    return uv, mask


class BEVDecoder(nn.Module):
    """Stack of cloned decoder layers with per-layer height refinement.

    After every layer, the query height ``zz`` is refined in inverse-sigmoid (logit) space
    by ``height_embed[lid]`` and squashed back with a sigmoid. Both the layer output and
    the refined height are collected for every layer.

    ``height_embed`` starts as ``None``. It must be assigned before ``forward`` is called,
    as an indexable collection with one module per layer.
    """

    def __init__(self, decoder_layer, num_layers):
        super().__init__()
        self.layers = _get_clones(decoder_layer, num_layers)
        self.num_layers = num_layers
        # Per-layer height refinement heads; assigned externally after construction.
        self.height_embed = None

    def forward(self, tgt, xx, yy, zz, pc_range, img_size, src_views, src_views_with_rayembed, lidar2imgs, src_spatial_shapes, src_level_start_index, src_valid_ratios, query_pos=None, src_padding_mask=None):
        """Run all layers, refining ``zz`` after each one.

        ``src_valid_ratios`` is accepted for interface compatibility but is not used.
        The refined height is not detached, so gradients flow through it into later layers.

        Returns:
            ``(outputs, heights)``: the per-layer outputs and the per-layer refined heights,
            each stacked along a new leading dimension of size ``len(self.layers)``.

        Raises:
            RuntimeError: if ``height_embed`` has not been assigned.
        """
        # Validate once, up front. A bare assert would be stripped under ``python -O``.
        if self.height_embed is None:
            raise RuntimeError("BEVDecoder.height_embed must be set before calling forward().")

        output = tgt
        intermediate = []
        intermediate_reference_points = []
        for lid, layer in enumerate(self.layers):
            reference_zz = zz
            output = layer(output, query_pos, xx, yy, reference_zz, pc_range, img_size, src_views, src_views_with_rayembed, src_spatial_shapes, src_level_start_index, lidar2imgs, src_padding_mask)

            # Refine the height in logit space, then map back through a sigmoid.
            delta = self.height_embed[lid](output)
            zz = (delta + inverse_sigmoid(reference_zz)).sigmoid()

            intermediate.append(output)
            intermediate_reference_points.append(zz)

        return torch.stack(intermediate), torch.stack(intermediate_reference_points)