from typing import *
from torch import Tensor, LongTensor
from torch.nn import Module

import numpy as np
import random
import time

import torch
import torch.nn as nn
import torch.nn.functional as F


def nonzero_last(t: Tensor) -> Tensor:
    """Return the position of the last nonzero entry of every non-empty row.

    Args:
        t: A 2-D tensor.

    Returns:
        A long tensor of shape ``(k, 2)`` on the same device as ``t``, where
        ``k`` is the number of rows containing at least one nonzero value.
        Each output row is ``(row_index, column_index_of_last_nonzero)``,
        ordered by ascending row index. Rows of ``t`` that are entirely zero
        do not appear in the result.

    Raises:
        AssertionError: If ``t`` is not 2-D.
    """
    assert 2 == len(t.shape), "nonzero_last expects a 2-D tensor"
    # nonzero() lists coordinates in row-major order, so all entries that
    # belong to one row are contiguous and sorted by column.
    coords = t.nonzero()
    rows = coords[:, 0]
    # An entry is the last one of its row when the following entry belongs
    # to a different row; the final entry overall is always a "last".
    is_last = torch.ones_like(rows, dtype=torch.bool)
    is_last[:-1] = rows[1:] != rows[:-1]
    return coords[is_last]


def nonzero_first(t: Tensor) -> Tensor:
    """Return the position of the first nonzero entry of every non-empty row.

    Args:
        t: A 2-D tensor.

    Returns:
        A long tensor of shape ``(k, 2)`` on the same device as ``t``, where
        ``k`` is the number of rows containing at least one nonzero value.
        Each output row is ``(row_index, column_index_of_first_nonzero)``,
        ordered by ascending row index. Rows of ``t`` that are entirely zero
        do not appear in the result.

    Raises:
        AssertionError: If ``t`` is not 2-D.
    """
    assert 2 == len(t.shape), "nonzero_first expects a 2-D tensor"
    # nonzero() lists coordinates in row-major order, so all entries that
    # belong to one row are contiguous and sorted by column.
    coords = t.nonzero()
    rows = coords[:, 0]
    # An entry is the first one of its row when the preceding entry belongs
    # to a different row; the very first entry overall is always a "first".
    is_first = torch.ones_like(rows, dtype=torch.bool)
    is_first[1:] = rows[1:] != rows[:-1]
    return coords[is_first]


# Batched index_select
def batched_index_select(t, dim, inds):
    """Gather slices of ``t`` along ``dim`` using a separate index per batch.

    ``inds`` is a 2-D index tensor ``(b, e)``; it is given a trailing axis and
    broadcast to ``(b, e, f)`` with ``f = t.size(2)``, then handed to
    ``Tensor.gather``. The result has shape ``b x e x f``.

    NOTE(review): with ``dim == 1`` this yields ``out[i, j] = t[i, inds[i, j]]``,
    which looks like the intended use -- confirm against callers before
    relying on other values of ``dim``.
    """
    # View of the indices repeated across the feature axis (no data copy).
    gather_index = inds[:, :, None].expand(-1, -1, t.size(2))
    return t.gather(dim, gather_index)
