import torch
import torch.nn as nn
from sparsemax.utils import flatten_all_but_nth_dim, unflatten_all_but_nth_dim


class Sparsemax(nn.Module):
    """
    Sparsemax class as seen in https://arxiv.org/pdf/1602.02068.pdf

    Thin module wrapper that forwards its input and ``dim`` to
    ``SparsemaxFunction.apply``.

    Parameters
    ----------
    dim: The dimension we want to cast the operation over. Default -1
    """

    __constants__ = ["dim"]

    def __init__(self, dim=-1):
        """Store the dimension the operation is applied over."""
        super().__init__()
        self.dim = dim

    def __setstate__(self, state):
        """Restore instance state, supplying a usable ``dim`` if it is absent."""
        self.__dict__.update(state)
        if not hasattr(self, "dim"):
            # SparsemaxFunction.forward compares ``dim`` against the input's
            # number of dimensions, so ``None`` would raise a TypeError there.
            # Fall back to the constructor default instead.
            self.dim = -1

    def forward(self, input):
        """Apply sparsemax to ``input`` along ``self.dim``."""
        return SparsemaxFunction.apply(input, self.dim)

    def extra_repr(self):
        """Return the extra text shown in the module's repr."""
        return f"dim={self.dim}"


class SparsemaxFunction(torch.autograd.Function):
    """Custom autograd function computing sparsemax along one dimension.

    The forward pass shifts the input by its maximum, sorts it, determines
    how many entries stay non-zero, derives a threshold ``tau`` and clips
    ``input - tau`` at zero. The backward pass only needs the forward
    output, from which the set of non-zero entries is recovered.
    """

    @staticmethod
    def forward(ctx, input: torch.Tensor, dim: int = -1):
        """Compute sparsemax of ``input`` along ``dim``.

        Parameters
        ----------
        ctx: autograd context; receives the bookkeeping flags and the
            tensor needed by ``backward``.
        input: tensor with at least one dimension.
        dim: dimension to operate over. Default -1

        Returns
        -------
        Tensor of the same shape as ``input``.

        Raises
        ------
        IndexError: if ``dim`` is outside ``[-input.dim(), input.dim() - 1]``.
        """
        input_dim = input.dim()
        if input_dim <= dim or dim < -input_dim:
            raise IndexError(
                f"Dimension out of range (expected to be in range of [-{input_dim}, {input_dim - 1}], but got {dim})"
            )

        # Save operating dimension to context
        ctx.needs_reshaping = input_dim > 2
        # All computations below run over the last axis. A 2-D input whose
        # target axis is the first one (dim 0 or -2) must therefore be
        # transposed, otherwise ``dim`` would be silently ignored.
        ctx.needs_transpose = input_dim == 2 and dim % input_dim == 0
        ctx.dim = dim

        if ctx.needs_reshaping:
            # Helper from sparsemax.utils (not visible here); presumably
            # moves ``ctx.dim`` to the last axis and flattens the rest to
            # 2-D -- TODO confirm against the helper's implementation.
            ctx, input = flatten_all_but_nth_dim(ctx, input)
        elif ctx.needs_transpose:
            input = input.transpose(0, 1)

        # Translate by max for numerical stability
        input = input - input.max(-1, keepdim=True).values

        zs = input.sort(-1, descending=True).values
        # Positions 1..n, built on the input's device to avoid a host copy.
        ks = torch.arange(1, input.size(-1) + 1, device=input.device).to(input.dtype)

        # Determine sparsity of projection
        in_support = (1 + ks * zs).gt(zs.cumsum(-1)).to(input.dtype)
        k = (in_support * ks).max(-1, keepdim=True).values

        # Compute the threshold from the sum of the supported sorted entries
        taus = ((in_support * zs).sum(-1, keepdim=True) - 1) / k

        output = torch.clamp(input - taus, min=0)

        # Only the output is read in backward; do not keep anything else alive.
        ctx.save_for_backward(output)

        # Reshape back to original shape
        if ctx.needs_reshaping:
            ctx, output = unflatten_all_but_nth_dim(ctx, output)
        elif ctx.needs_transpose:
            output = output.transpose(0, 1)

        return output

    @staticmethod
    def backward(ctx, grad_output):
        """Propagate ``grad_output`` through sparsemax.

        For each slice along the operating dimension, the gradient on the
        non-zero outputs is the incoming gradient minus its mean over those
        non-zero positions; it is zero elsewhere.

        Returns
        -------
        ``(grad_input, None)`` -- ``dim`` receives no gradient.
        """
        (output,) = ctx.saved_tensors

        # Bring the gradient into the same layout the forward pass used.
        if ctx.needs_reshaping:
            ctx, grad_output = flatten_all_but_nth_dim(ctx, grad_output)
        elif ctx.needs_transpose:
            grad_output = grad_output.transpose(0, 1)

        # Compute gradient
        support = torch.ne(output, 0)
        support_size = support.sum(-1, keepdim=True)
        mean_grad = (grad_output * support).sum(-1, keepdim=True) / support_size
        grad_input = support * (grad_output - mean_grad)

        # Reshape back to original shape
        if ctx.needs_reshaping:
            ctx, grad_input = unflatten_all_but_nth_dim(ctx, grad_input)
        elif ctx.needs_transpose:
            grad_input = grad_input.transpose(0, 1)

        return grad_input, None
