import math
import torch
import torch.nn.functional as F
import torch_scatter

# Public names exported by ``from <this module> import *``.
__all__ = ['logcosh', 'max_finite', 'sum_finite', 'log_signed',
           'logsumexp_signed', 'logsumexp_signed_signed', 'sumexp_signed', 'logdiffexp',
           'loginvexp', 'trace_batch', 'sinkhorn_normalization', 'logsinkhorn_normalization',
           'scatter', 'segment_coo', 'segment_csr', 'EPSILON', 'repeat_blocks']


# Tiny positive constant. NOTE(review): the helpers shown in this file hard-code
# the literal 1e-40 (added to sums before taking logs) instead of referencing
# this name; presumably it is meant for the same purpose — confirm with callers.
# 1e-40 is below float32's smallest normal value (~1.18e-38), so in float32 it
# is stored as a subnormal number.
EPSILON = 1e-40


def logcosh(x):
    """Numerically stable elementwise ``log(cosh(x))``.

    Uses the identity ``log(cosh(x)) = x + log(1 + exp(-2x)) - log(2)``, where
    ``log(1 + exp(z))`` is evaluated with ``softplus`` so that large ``|x|``
    neither overflows ``cosh`` nor loses precision.
    """
    stable_tail = F.softplus(x * -2.)
    return x + stable_tail - math.log(2.)


def loginvexp(mat, use_double=False):
    """Log-magnitude and sign of the matrix inverse of ``exp(mat)``.

    ``exp`` is applied elementwise, then ``Tensor.inverse`` inverts the result
    as a matrix (over the last two dimensions).

    Args:
        mat: Tensor whose last two dimensions form square matrices.
        use_double: Perform the exp/inverse/log in float64 for accuracy. The
            outputs are cast back to ``mat``'s original dtype.

    Returns:
        ``(log_abs, sign)`` as produced by ``log_signed`` on the inverse, except
        that entries with ``log_abs < -100`` are given sign +1.
    """
    # Remember the caller's dtype so that use_double only affects the internal
    # precision, not the dtype of the returned tensors.
    orig_dtype = mat.dtype
    # Offsets seem to only make things worse.
    if use_double:
        mat = mat.double()
    exp_inv = torch.exp(mat).inverse()

    inv, inv_sign = log_signed(exp_inv)
    if use_double:
        inv = inv.to(orig_dtype)
        inv_sign = inv_sign.to(orig_dtype)

    # Negative entries below machine precision are likely errors.
    # log(2**(-53)) = -36.7, but let's be more lenient than that.
    inv_sign[inv < -100] = 1.

    return inv, inv_sign


def sinkhorn_normalization(mat: torch.FloatTensor, niter: int, mean_one: bool = False, mask: torch.BoolTensor = None):
    """Sinkhorn scaling of a batch of square, non-negative matrices.

    Alternately computes column scalings ``v`` and row scalings ``u`` such that
    ``u[:, :, None] * mat * v[:, None, :]`` has row sums equal to the target
    (exactly after the final update) and column sums approaching the target.
    The target is 1, or ``nrows`` when ``mean_one`` is set.

    Args:
        mat: Tensor of shape (batch, n, n).
        niter: Number of (v, u) update rounds. With 0 rounds both scalings are
            returned as all ones.
        mean_one: Use ``nrows`` instead of 1 as the target row/column sum.
        mask: Optional boolean tensor. For every index ``i`` where it is True,
            1 is added to ``mat[:, 0, i]`` in the copy used for column sums and
            to ``mat[:, i, 0]`` in the copy used for row sums, so all-zero
            rows/columns do not produce a zero denominator. Presumably shape
            (batch, n) with True marking padding — TODO confirm with callers.

    Returns:
        ``(u, v)``, each of shape (batch, n); every update is clamped to 1e10.
    """
    batch_size, nrows, ncols = mat.shape
    assert nrows == ncols

    target_sum = nrows if mean_one else 1

    if mask is None:
        # No padding correction needed: use (views of) the input directly.
        mat_u = mat.transpose(1, 2)
        mat_v = mat
    else:
        # Mask for padded tensors (applied to copies, never to the caller's mat).
        mask_val = mask.to(mat.dtype)
        mat_u = mat.clone().transpose(1, 2)
        mat_u[:, :, 0] += mask_val
        mat_v = mat.clone()
        mat_v[:, :, 0] += mask_val

    u = mat.new_ones([batch_size, nrows])
    # Initialised so that niter == 0 still returns a defined v.
    v = mat.new_ones([batch_size, ncols])
    for _ in range(niter):
        # v_i = target / sum_j mat[j, i] * u_j   (column scaling)
        v = torch.clamp(target_sum / torch.einsum("bij, bj -> bi", mat_u, u), max=1e10)
        # u_i = target / sum_j mat[i, j] * v_j   (row scaling)
        u = torch.clamp(target_sum / torch.einsum("bij, bj -> bi", mat_v, v), max=1e10)

    return u, v


def logsinkhorn_normalization(mat: torch.FloatTensor, niter: int, mean_one: bool = False):
    """Sinkhorn scaling carried out in log space.

    ``mat`` holds log-kernel entries (i.e. ``-cost``). The function alternately
    updates log row scalings ``u`` and log column scalings ``v`` so that
    ``exp(u[:, :, None] + mat + v[:, None, :])`` has column sums equal to the
    target (exactly after the final update) and row sums approaching it. The
    target is 1, or ``nrows`` when ``mean_one`` is set.

    Args:
        mat: Tensor of shape (batch, n, n).
        niter: Number of (u, v) update rounds; with 0 both are all zeros.
        mean_one: Use ``nrows`` instead of 1 as the target row/column sum.
            Defaults to False, matching ``sinkhorn_normalization``.

    Returns:
        ``(u, v)``, each of shape (batch, n).
    """
    batch_size, nrows, ncols = mat.shape
    assert nrows == ncols

    target_sum = math.log(nrows) if mean_one else 0

    u = mat.new_zeros(batch_size, nrows)
    v = mat.new_zeros(batch_size, ncols)
    for _ in range(niter):
        # u_i = target - logsumexp_j (mat_ij + v_j)
        u = target_sum - torch.logsumexp(mat + v[:, None, :], dim=-1)
        # v_j = target - logsumexp_i (mat_ij + u_i)
        v = target_sum - torch.logsumexp(mat + u[:, :, None], dim=-2)

    return u, v


def max_finite(tensor, dim=None):
    """Maximum over the finite entries of ``tensor``.

    Non-finite entries (+inf, -inf and NaN) are replaced by -inf before
    reducing, so they never win. This mirrors ``sum_finite``, which also uses
    ``isfinite``. A slice with no finite entries yields -inf.

    Args:
        tensor: Input tensor.
        dim: If None, reduce over all elements and return a scalar tensor.
            Otherwise reduce along ``dim`` and return ``(values, indices)``
            as ``Tensor.max(dim=...)`` does.
    """
    # isinf alone would let NaN through, and max() propagates NaN.
    non_finite = ~torch.isfinite(tensor)
    tensor_neginf = torch.masked_fill(tensor, non_finite, -math.inf)
    if dim is None:
        return tensor_neginf.max()
    return tensor_neginf.max(dim=dim)


@torch.jit.script
def sum_finite(tensor: torch.Tensor, dim: int):
    """Sum ``tensor`` along ``dim``, counting inf, -inf and NaN entries as zero."""
    is_finite = torch.isfinite(tensor)
    cleaned = tensor.masked_fill(is_finite.logical_not(), 0)
    return cleaned.sum(dim)


@torch.jit.script
def log_signed(mat: torch.Tensor):
    """Split ``mat`` into ``(log|mat|, sign)``.

    ``sign`` is -1 where ``mat < 0`` and +1 elsewhere (including zeros, whose
    log-magnitude is -inf). It has the same dtype as ``mat`` for floating inputs.
    """
    is_negative = torch.lt(mat, 0)
    sign = torch.ones_like(mat) - is_negative.long() * 2
    return torch.log(torch.abs(mat)), sign


@torch.jit.script
def sumexp_signed(mat: torch.Tensor, sign: torch.Tensor, dim: int):
    """Compute ``sum(sign * exp(mat), dim)`` (a plain value, not a log).

    Each slice is shifted by its maximum before exponentiating, so large
    entries of ``mat`` do not overflow in the intermediate terms.

    Args:
        mat: Log-magnitudes.
        sign: Signs (typically +-1), broadcastable against ``mat``.
        dim: Dimension to reduce.
    """
    offset = mat.max(dim).values
    # A slice that is entirely +-inf would give an infinite offset and
    # inf - inf = NaN below. Like torch.logsumexp, fall back to an offset of 0
    # there, so an all -inf slice correctly sums to 0.
    offset = offset.masked_fill(torch.isinf(offset), 0.)
    mat_e = (mat - offset.unsqueeze(dim)).exp()
    mat_sum = (mat_e * sign).sum(dim)
    return mat_sum * offset.exp()


@torch.jit.script
def logsumexp_signed_signed(mat: torch.Tensor, sign: torch.Tensor, dim: int):
    """Signed log-sum-exp: ``(log|S|, sign(S))`` with ``S = sum(sign * exp(mat), dim)``.

    The per-slice shift is clamped to [-1e10, 1e10] so all -inf slices do not
    turn into NaN. A 1e-40 is added to the shifted sum before the log, which
    keeps an exact zero from becoming -inf and reports it with sign +1.
    """
    shift = torch.clamp(mat.max(dim).values, min=-1e10, max=1e10)
    signed_terms = torch.exp(mat - shift.unsqueeze(dim)) * sign
    total = signed_terms.sum(dim) + 1e-40
    log_abs_total, result_sign = log_signed(total)
    return log_abs_total + shift, result_sign


@torch.jit.script
def logsumexp_signed(mat: torch.Tensor, sign: torch.Tensor, dim: int,
                     check_sign: bool = True):
    """Log of ``|sum(sign * exp(mat), dim)|`` for sums expected to be non-negative.

    The per-slice shift is clamped to [-1e10, 1e10], and 1e-40 is added to the
    shifted sum before the log (see ``logsumexp_signed_signed`` for the signed
    variant). With ``check_sign``, an error is raised if any negative sum has a
    magnitude of at least 1e-7 once the shift is applied back.
    """
    shift = torch.clamp(mat.max(dim).values, min=-1e10, max=1e10)
    weighted = torch.exp(mat - shift.unsqueeze(dim)) * sign
    total = weighted.sum(dim) + 1e-40
    result = torch.log(torch.abs(total)) + shift
    if check_sign:
        negative = total < 0
        assert torch.all(torch.exp(result)[negative] < 1e-7), "Negative values in result. Use `logsumexp_signed_signed`."
    return result


@torch.jit.script
def logdiffexp(tensor1: torch.Tensor, tensor2: torch.Tensor, sign2: torch.Tensor):
    """Elementwise signed log of ``exp(tensor1) - sign2 * exp(tensor2)``.

    Returns:
        ``(log|d|, sign(d))`` as produced by ``log_signed``. A 1e-40 is added to
        the shifted difference before the log, so an exact zero gives a large
        negative finite value with sign +1 instead of -inf.
    """
    lse_offset = torch.max(tensor1, tensor2)
    # Clamp like logsumexp_signed(_signed): if both inputs are -inf, an
    # unclamped offset would evaluate exp(-inf - -inf) = exp(nan).
    lse_offset = torch.clamp(lse_offset, min=-1e10, max=1e10)
    diff = torch.exp(tensor1 - lse_offset) - sign2 * torch.exp(tensor2 - lse_offset)
    res_without_offset, res_sign = log_signed(diff + 1e-40)
    return res_without_offset + lse_offset, res_sign


def trace_batch(mat):
    """Trace of every matrix in a batch: sum of the main diagonal of the last two dims."""
    main_diagonal = mat.diagonal(dim1=-2, dim2=-1)
    return main_diagonal.sum(dim=-1)


@torch.jit.script
def scatter(src: torch.Tensor, index: torch.Tensor, dim_size: int, dim: int = -1,
            fill_value: float = math.nan, reduce: str = 'sum'):
    """Scatter-reduce ``src`` into ``dim_size`` slots along ``dim`` via ``torch_scatter.scatter``.

    Args:
        src: Values to scatter.
        index: Target slot of each value along ``dim``; forwarded unchanged to
            ``torch_scatter.scatter``.
        dim_size: Size of the output along ``dim``.
        dim: Dimension to scatter along (negative values work, since it is used
            directly in ``shape[dim] = dim_size``).
        fill_value: NaN (the default) means "not given". For ``reduce='max'`` /
            ``'min'`` it is then replaced by -1e38 / 1e38; for any other
            reduction the result of ``torch_scatter.scatter`` is returned as is.
            A non-NaN value is used to pre-fill the output tensor.
        reduce: Reduction name forwarded to ``torch_scatter.scatter``.

    Returns:
        Output with ``src``'s shape except ``dim_size`` along ``dim`` (the
        pre-filled branch allocates exactly this shape).

    NOTE(review): when a fill value is in effect, the pre-filled tensor is passed
    as ``out=``, so torch_scatter presumably reduces *into* it. The fill value
    then takes part in the reduction for every slot, not only for slots that
    receive no values: e.g. 'sum' would yield fill_value + sum, and 'max' with
    fill_value=0 would clip negative maxima to 0. ``segment_csr`` in this file
    instead overwrites only empty segments. Confirm which semantics callers rely on.
    """
    if torch.isnan(torch.tensor(fill_value)):
        # Stand-ins for the dtype's extreme values, acting as (near-)identity
        # elements for max/min. NOTE(review): 1e38 exceeds the float16 range, so
        # half-precision src may overflow in torch.full below — confirm.
        if reduce == 'max':
            fill_value = -1e38  # torch.finfo(src.dtype).min
        elif reduce == 'min':
            fill_value = 1e38  # torch.finfo(src.dtype).max

    if torch.isnan(torch.tensor(fill_value)):
        # No fill value: defer entirely to torch_scatter's own output allocation.
        return torch_scatter.scatter(src, index, dim=dim, dim_size=dim_size, reduce=reduce)
    else:
        # Pre-fill the output and let torch_scatter write into it in place.
        shape = list(src.shape)
        shape[dim] = dim_size
        out = torch.full(shape, fill_value, dtype=src.dtype, device=src.device)
        torch_scatter.scatter(src, index, dim=dim, out=out, reduce=reduce)
        return out


@torch.jit.script
def segment_coo(src: torch.Tensor, index: torch.Tensor, dim_size: int,
                fill_value: float = math.nan, reduce: str = 'sum'):
    """Segment-reduce ``src`` with sorted COO ``index`` via ``torch_scatter.segment_coo``.

    Args:
        src: Values to reduce.
        index: Sorted segment ids. torch_scatter reduces along ``index``'s last
            dimension, i.e. dimension ``index.dim() - 1`` of ``src``.
        dim_size: Number of output segments along that dimension.
        fill_value: NaN (the default) means "not given". For 'max'/'min' it
            becomes -1e38/1e38; for other reductions torch_scatter's result is
            returned unchanged. A non-NaN value pre-fills the output tensor.
        reduce: Reduction name forwarded to torch_scatter.

    Returns:
        Output shaped like ``src`` except ``dim_size`` along dimension
        ``index.dim() - 1``.

    NOTE(review): the pre-filled output is passed as ``out=``, so (as with
    ``scatter``) the fill value may take part in the reduction for non-empty
    segments too — confirm with callers.
    """
    if torch.isnan(torch.tensor(fill_value)):
        if reduce == 'max':
            fill_value = -1e38  # torch.finfo(src.dtype).min
        elif reduce == 'min':
            fill_value = 1e38  # torch.finfo(src.dtype).max

    if torch.isnan(torch.tensor(fill_value)):
        return torch_scatter.segment_coo(src, index, dim_size=dim_size, reduce=reduce)

    # Replace the reduced dimension (index.dim() - 1), not blindly the last one
    # of src: the two only coincide when index has the same rank as src.
    shape = list(src.shape)
    shape[index.dim() - 1] = dim_size
    out = torch.full(shape, fill_value, dtype=src.dtype, device=src.device)
    torch_scatter.segment_coo(src, index, out=out, reduce=reduce)
    return out


@torch.jit.script
def segment_csr(src: torch.Tensor, indptr: torch.Tensor,
                fill_value: float = math.nan, reduce: str = 'sum'):
    """Segment-reduce ``src`` with CSR pointers ``indptr`` via ``torch_scatter.segment_csr``.

    Args:
        src: Values to reduce.
        indptr: Segment boundaries. Segment ``k`` covers positions
            ``indptr[..., k]:indptr[..., k + 1]`` along dimension
            ``indptr.dim() - 1`` of ``src``; leading dims may be batched.
        fill_value: Value written into empty segments. NaN (the default) means
            "not given": for 'max'/'min' it becomes -1e38/1e38, and for other
            reductions torch_scatter's result is returned unchanged.
        reduce: Reduction name forwarded to torch_scatter.

    Returns:
        The reduced tensor, with empty segments overwritten by the fill value
        when one is in effect.
    """
    if torch.isnan(torch.tensor(fill_value)):
        if reduce == 'max':
            fill_value = -1e38  # torch.finfo(src.dtype).min
        elif reduce == 'min':
            fill_value = 1e38  # torch.finfo(src.dtype).max

    out = torch_scatter.segment_csr(src, indptr, reduce=reduce)
    if torch.isnan(torch.tensor(fill_value)):
        return out

    # A segment is empty when consecutive pointers coincide. Compare along the
    # last dim of indptr (the segment dim), so batched indptr works as well.
    empty = indptr[..., 1:] == indptr[..., :-1]
    # out keeps src's trailing dims beyond indptr's rank. Append singleton dims
    # so the mask broadcasts over them instead of against the segment dim.
    while empty.dim() < out.dim():
        empty = empty.unsqueeze(-1)
    out.masked_fill_(empty, fill_value)
    return out


def repeat_blocks(sizes, repeats, continuous_indexing=True):
    """ Repeat blocks of indices.

    For each group ``g`` of ``sizes[g]`` consecutive indices, emit ``repeats[g]``
    back-to-back copies of that group's block, concatenated over groups.
    With ``continuous_indexing=True`` group ``g``'s block is
    ``offset_g, ..., offset_g + sizes[g] - 1`` with ``offset_g = sizes[:g].sum()``
    (groups with zero repeats still advance the offset). Otherwise every block
    is ``0, ..., sizes[g] - 1``.

    Example: ``sizes=[2, 3], repeats=2`` gives ``[0, 1, 0, 1, 2, 3, 4, 2, 3, 4]``
    (continuous) or ``[0, 1, 0, 1, 0, 1, 2, 0, 1, 2]`` (non-continuous).

    Adapted from https://stackoverflow.com/questions/51154989/numpy-vectorized-function-to-repeat-blocks-of-consecutive-elements

    Args:
        sizes: 1-D integer tensor of group sizes.
        repeats: int, or 1-D integer tensor with one repetition count per group.
        continuous_indexing: See above.

    Returns:
        1-D long tensor of indices (empty if nothing is repeated).
    """
    assert sizes.dim() == 1

    # Remove 0 sizes
    sizes_nonzero = (sizes > 0)
    if not torch.all(sizes_nonzero):
        sizes = torch.masked_select(sizes, sizes_nonzero)
        if isinstance(repeats, torch.Tensor):
            repeats = torch.masked_select(repeats, sizes_nonzero)

    # No non-empty groups: nothing to emit (and repeats[0] below would fail).
    if sizes.numel() == 0:
        return torch.zeros(0, dtype=torch.long, device=sizes.device)

    if isinstance(repeats, torch.Tensor):
        # The offset bookkeeping below assumes the first output element belongs
        # to group 0. If group 0 is not repeated, prepend a dummy group (size 1,
        # repeated once) and strip its element again at the end.
        insert_dummy = (repeats[0] == 0)
        if insert_dummy:
            one = sizes.new_ones(1)
            sizes = torch.cat((one, sizes))
            repeats = torch.cat((one, repeats))
    else:
        insert_dummy = False

    # Get repeats for each group using group lengths/sizes
    r1 = torch.repeat_interleave(torch.arange(len(sizes), device=sizes.device), repeats)

    # Get total size of output array, as needed to initialize output indexing array
    N = int((sizes * repeats).sum())
    if N == 0:
        # e.g. an integer repeats of 0
        return torch.zeros(0, dtype=torch.long, device=sizes.device)

    # Initialize indexing array with ones as we need to setup incremental indexing
    # within each group when cumulatively summed at the final stage.
    # Two steps here:
    # 1. Within each group, we have multiple sequences, so setup the offsetting
    # at each sequence lengths by the seq. lengths preceding those.
    id_ar = torch.ones(N, dtype=torch.long, device=sizes.device)
    id_ar[0] = 0
    insert_index = sizes[r1[:-1]].cumsum(0)
    insert_val = (1-sizes)[r1[:-1]]

    if continuous_indexing:
        if isinstance(repeats, torch.Tensor) and torch.any(repeats == 0):
            # If a group was skipped (repeats=0) we need to add its size
            diffs = r1[1:] - r1[:-1]
            indptr = torch.cat((sizes.new_zeros(1), diffs.cumsum(0)))
            insert_val += torch_scatter.segment_csr(sizes[:r1[-1]], indptr, reduce='sum')
        else:
            # 2. For each group, make sure the indexing starts from the next group's
            # first element. So, simply assign 1s there.
            insert_val[r1[1:] != r1[:-1]] = 1

    # Assign index-offseting values
    id_ar[insert_index] = insert_val

    if insert_dummy:
        id_ar = id_ar[1:]
        # In continuous mode the dummy (size 1) shifted every later offset up by
        # one. In non-continuous mode the first real block already restarts at 0,
        # so no correction applies there. id_ar is empty when all repeats were 0.
        if continuous_indexing and id_ar.numel() > 0:
            id_ar[0] -= 1

    # Finally index into input array for the group repeated o/p
    res = id_ar.cumsum(0)
    return res
