# import torch
# import copy
# from torch import nn, Tensor
# import os
#
# import math
# import torch.nn.functional as F
# from torch import nn
#
#
# class MLP(nn.Module):
#     """ Very simple multi-layer perceptron (also called FFN)"""
#
#     def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
#         super().__init__()
#         self.num_layers = num_layers
#         h = [hidden_dim] * (num_layers - 1)
#         self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
#
#     def forward(self, x):
#         for i, layer in enumerate(self.layers):
#             x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
#         return x
#
#
# def inverse_sigmoid(x, eps=1e-5):
#     x = x.clamp(min=0, max=1)
#     x1 = x.clamp(min=eps)
#     x2 = (1 - x).clamp(min=eps)
#     return torch.log(x1/x2)
#
#
# def gen_encoder_output_proposals(memory:Tensor, memory_padding_mask:Tensor, spatial_shapes:Tensor):
#     """
#     Input:
#         - memory: bs, \sum{hw}, d_model
#         - memory_padding_mask: bs, \sum{hw}
#         - spatial_shapes: nlevel, 2
#     Output:
#         - output_memory: bs, \sum{hw}, d_model
#         - output_proposals: bs, \sum{hw}, 4
#     """
#     N_, S_, C_ = memory.shape
#     base_scale = 4.0
#     proposals = []
#     _cur = 0
#     for lvl, (H_, W_) in enumerate(spatial_shapes):
#         mask_flatten_ = memory_padding_mask[:, _cur:(_cur + H_ * W_)].view(N_, H_, W_, 1)
#         valid_H = torch.sum(~mask_flatten_[:, :, 0, 0], 1)
#         valid_W = torch.sum(~mask_flatten_[:, 0, :, 0], 1)
#
#         grid_y, grid_x = torch.meshgrid(torch.linspace(0, H_ - 1, H_, dtype=torch.float32, device=memory.device),
#                                         torch.linspace(0, W_ - 1, W_, dtype=torch.float32, device=memory.device))
#         grid = torch.cat([grid_x.unsqueeze(-1), grid_y.unsqueeze(-1)], -1)
#
#         scale = torch.cat([valid_W.unsqueeze(-1), valid_H.unsqueeze(-1)], 1).view(N_, 1, 1, 2)
#         grid = (grid.unsqueeze(0).expand(N_, -1, -1, -1) + 0.5) / scale
#         wh = torch.ones_like(grid) * 0.05 * (2.0 ** lvl)
#         proposal = torch.cat((grid, wh), -1).view(N_, -1, 4)
#         proposals.append(proposal)
#         _cur += (H_ * W_)
#     output_proposals = torch.cat(proposals, 1)
#     output_proposals_valid = ((output_proposals > 0.01) & (output_proposals < 0.99)).all(-1, keepdim=True)
#     output_proposals = torch.log(output_proposals / (1 - output_proposals))
#     output_proposals = output_proposals.masked_fill(memory_padding_mask.unsqueeze(-1), float('inf'))
#     output_proposals = output_proposals.masked_fill(~output_proposals_valid, float('inf'))
#
#     output_memory = memory
#     output_memory = output_memory.masked_fill(memory_padding_mask.unsqueeze(-1), float(0))
#     output_memory = output_memory.masked_fill(~output_proposals_valid, float(0))
#     return output_memory, output_proposals
#
#
# def gen_sineembed_for_position(pos_tensor):
#     # n_query, bs, _ = pos_tensor.size()
#     # sineembed_tensor = torch.zeros(n_query, bs, 256)
#     scale = 2 * math.pi
#     dim_t = torch.arange(128, dtype=torch.float32, device=pos_tensor.device)
#     dim_t = 10000 ** (2 * (dim_t // 2) / 128)
#     x_embed = pos_tensor[:, :, 0] * scale
#     y_embed = pos_tensor[:, :, 1] * scale
#     pos_x = x_embed[:, :, None] / dim_t
#     pos_y = y_embed[:, :, None] / dim_t
#     pos_x = torch.stack((pos_x[:, :, 0::2].sin(), pos_x[:, :, 1::2].cos()), dim=3).flatten(2)
#     pos_y = torch.stack((pos_y[:, :, 0::2].sin(), pos_y[:, :, 1::2].cos()), dim=3).flatten(2)
#     if pos_tensor.size(-1) == 2:
#         pos = torch.cat((pos_y, pos_x), dim=2)
#     elif pos_tensor.size(-1) == 4:
#         w_embed = pos_tensor[:, :, 2] * scale
#         pos_w = w_embed[:, :, None] / dim_t
#         pos_w = torch.stack((pos_w[:, :, 0::2].sin(), pos_w[:, :, 1::2].cos()), dim=3).flatten(2)
#
#         h_embed = pos_tensor[:, :, 3] * scale
#         pos_h = h_embed[:, :, None] / dim_t
#         pos_h = torch.stack((pos_h[:, :, 0::2].sin(), pos_h[:, :, 1::2].cos()), dim=3).flatten(2)
#
#         pos = torch.cat((pos_y, pos_x, pos_w, pos_h), dim=2)
#     else:
#         raise ValueError("Unknown pos_tensor shape(-1):{}".format(pos_tensor.size(-1)))
#     return pos
#
#
# def _get_activation_fn(activation):
#     """Return an activation function given a string"""
#     if activation == "relu":
#         return F.relu
#     if activation == "gelu":
#         return F.gelu
#     if activation == "glu":
#         return F.glu
#     if activation == "prelu":
#         return nn.PReLU()
#     if activation == "selu":
#         return F.selu
#     raise RuntimeError(F"activation should be relu/gelu, not {activation}.")
#
#
# def _get_clones(module, N, layer_share=False):
#     # import ipdb; ipdb.set_trace()
#     if layer_share:
#         return nn.ModuleList([module for i in range(N)])
#     else:
#         return nn.ModuleList([copy.deepcopy(module) for i in range(N)])