import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable


class FocalLoss(nn.Module):

    def __init__(self, gamma=2.0, eps=1e-7):
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        self.eps = eps

    def one_hot(self, index, classes):
        size = index.size() + (classes,)
        view = index.size() + (1,)

        mask = torch.Tensor(*size).fill_(0).to(index.device)

        index = index.view(*view)
        ones = 1.

        if isinstance(index, Variable):
            ones = Variable(torch.Tensor(index.size()).fill_(1).to(index.device))
            mask = Variable(mask, volatile=index.volatile)

        return mask.scatter_(1, index, ones)

    def forward(self, input, target):
        y = self.one_hot(target, input.size(-1))
        logit = F.softmax(input, dim=-1)
        logit = logit.clamp(self.eps, 1. - self.eps)

        loss = -1 * y * torch.log(logit) # cross entropy
        loss = loss * (1 - logit) ** self.gamma # focal loss

        return loss.mean()

def sort_by_indices(features, indices, features_add=None):
    """
        To sort the sparse features with its indices in a convenient manner.
        Args:
            features: [N, C], sparse features
            indices: [N, 4], indices of sparse features
            features_add: [N, C], additional features to sort
    """
    idx = indices[:, 1:]
    idx_sum = idx.select(1, 0) * idx[:, 1].max() * idx[:, 2].max() + idx.select(1, 1) * idx[:, 2].max() + idx.select(1, 2)
    _, ind = idx_sum.sort()
    features = features[ind]
    indices = indices[ind]
    if not features_add is None:
        features_add = features_add[ind]
    return features, indices, features_add

def check_repeat(features, indices, features_add=None, sort_first=True, flip_first=True):
    """
        Check that whether there are replicate indices in the sparse features, 
        remove the replicate features if any.
    """
    if sort_first:
        features, indices, features_add = sort_by_indices(features, indices, features_add)

    if flip_first:
        features, indices = features.flip([0]), indices.flip([0])

    if not features_add is None:
        features_add=features_add.flip([0])

    idx = indices[:, 1:].int()
    idx_sum = torch.add(torch.add(idx.select(1, 0) * idx[:, 1].max() * idx[:, 2].max(), idx.select(1, 1) * idx[:, 2].max()), idx.select(1, 2))
    _unique, inverse, counts = torch.unique_consecutive(idx_sum, return_inverse=True, return_counts=True, dim=0)
    
    if _unique.shape[0] < indices.shape[0]:
        perm = torch.arange(inverse.size(0), dtype=inverse.dtype, device=inverse.device)
        features_new = torch.zeros((_unique.shape[0], features.shape[-1]), device=features.device)
        features_new.index_add_(0, inverse.long(), features)
        features = features_new
        perm_ = inverse.new_empty(_unique.size(0)).scatter_(0, inverse, perm)
        indices = indices[perm_].int()

        if not features_add is None:
            features_add_new = torch.zeros((_unique.shape[0],), device=features_add.device)
            features_add_new.index_add_(0, inverse.long(), features_add)
            features_add = features_add_new / counts
    return features, indices, features_add


def split_voxels(x, b, imps_3d, voxels_3d, kernel_offsets, mask_multi=True, topk=True, threshold=0.5):
    """
        Generate and split the voxels into foreground and background sparse features, based on the predicted importance values.
        Args:
            x: [N, C], input sparse features
            b: int, batch size id
            imps_3d: [N, kernelsize**3], the prediced importance values
            voxels_3d: [N, 3], the 3d positions of voxel centers 
            kernel_offsets: [kernelsize**3, 3], the offset coords in an kernel
            mask_multi: bool, whether to multiply the predicted mask to features
            topk: bool, whether to use topk or threshold for selection
            threshold: float, threshold value
    """
    index = x.indices[:, 0]
    batch_index = index==b
    indices_ori = x.indices[batch_index]
    features_ori = x.features[batch_index]
    mask_voxel = imps_3d[batch_index, -1].sigmoid()
    mask_kernel = imps_3d[batch_index, :-1].sigmoid()

    if mask_multi:
        features_ori *= mask_voxel.unsqueeze(-1)

    if topk:
        _, indices = mask_voxel.sort(descending=True)
        indices_fore = indices[:int(mask_voxel.shape[0]*threshold)]
        indices_back = indices[int(mask_voxel.shape[0]*threshold):]
    else:
        indices_fore = mask_voxel > threshold
        indices_back = mask_voxel <= threshold

    features_fore = features_ori[indices_fore]
    coords_fore = indices_ori[indices_fore]

    mask_kernel_fore = mask_kernel[indices_fore]
    mask_kernel_bool = mask_kernel_fore>=threshold
    voxel_kerels_imp = kernel_offsets.unsqueeze(0).repeat(mask_kernel_bool.shape[0],1, 1)
    mask_kernel_fore = mask_kernel[indices_fore][mask_kernel_bool]
    indices_fore_kernels = coords_fore[:, 1:].unsqueeze(1).repeat(1, kernel_offsets.shape[0], 1)
    indices_with_imp = indices_fore_kernels + voxel_kerels_imp
    selected_indices = indices_with_imp[mask_kernel_bool]
    spatial_indices = (selected_indices[:, 0] >0) * (selected_indices[:, 1] >0) * (selected_indices[:, 2] >0)  * \
                        (selected_indices[:, 0] < x.spatial_shape[0]) * (selected_indices[:, 1] < x.spatial_shape[1]) * (selected_indices[:, 2] < x.spatial_shape[2])
    selected_indices = selected_indices[spatial_indices]
    mask_kernel_fore = mask_kernel_fore[spatial_indices]
    selected_indices = torch.cat([torch.ones((selected_indices.shape[0], 1), device=features_fore.device)*b, selected_indices], dim=1)

    selected_features = torch.zeros((selected_indices.shape[0], features_ori.shape[1]), device=features_fore.device)

    features_fore_cat = torch.cat([features_fore, selected_features], dim=0)
    coords_fore = torch.cat([coords_fore, selected_indices], dim=0)
    mask_kernel_fore = torch.cat([torch.ones(features_fore.shape[0], device=features_fore.device), mask_kernel_fore], dim=0)

    features_fore, coords_fore, mask_kernel_fore = check_repeat(features_fore_cat, coords_fore, features_add=mask_kernel_fore)

    features_back = features_ori[indices_back]
    coords_back = indices_ori[indices_back]

    return features_fore, coords_fore, features_back, coords_back, mask_kernel_fore
