import functools
import time
from multiprocessing import pool

import torch

'''
The math in this file is not generally correct for sparse tensors.
But under certain assumptions (such as those that fit the situation in this
paper), the results are correct.
Only use this file after prudent examination of your model. Reading our paper
is strongly recommended.
'''


class _Values(torch.autograd.Function):
    @staticmethod
    def forward(ctx, tensor):
        ctx.save_for_backward(tensor._indices())
        ctx.shape = tensor.shape
        return tensor._values()

    @staticmethod
    def backward(ctx, grad_output):
        indices, = ctx.saved_tensors
        # grad_input = grad_output.clone()
        grad = torch.sparse_coo_tensor(
            indices,
            grad_output,
            ctx.shape
        )
        return grad


_values = _Values.apply


class SparseVectorMul(torch.autograd.Function):
    """Scale a sparse COO tensor along one sparse dimension by a vector.

    Every stored value whose index along ``dim`` is ``i`` is multiplied by
    ``vector[i]``; the indices and the shape of the input are kept as-is.

    Args (of ``forward`` / ``SparseVectorMul.apply``):
        tensor: sparse COO tensor.
        vector: tensor indexed along its first dimension by the index values
            found along ``dim``. It must cover every index that occurs there
            (an out-of-range index makes ``index_select`` raise) and live on
            the same device as ``tensor``.
        dim: sparse dimension to scale along. A negative value counts from
            the last sparse dimension, exactly like ``tensor._indices()[dim]``.

    Returns:
        A sparse COO tensor with the same indices and shape as ``tensor``.

    Backward returns a sparse gradient for ``tensor`` (the incoming gradient
    scaled the same way) and a dense gradient for ``vector`` (the elementwise
    product ``tensor * grad_output`` summed over every dimension but ``dim``).
    NOTE(review): as the note at the top of this file says, this math is only
    valid under the assumptions of the accompanying paper.
    """

    @staticmethod
    def _scale_values(values, positions, vector):
        """Return a copy of ``values`` with row k scaled by vector[positions[k]]."""
        # One gather replaces a Python loop that built a boolean mask (and
        # launched a kernel) for every single element of ``vector``.
        scale = vector.index_select(0, positions)
        pad = values.dim() - scale.dim()
        if pad > 0:
            # Right-align the per-entry factors with any trailing (dense)
            # dimensions of ``values`` so the product broadcasts per row.
            scale = scale.reshape(
                (scale.shape[0],) + (1,) * pad + tuple(scale.shape[1:])
            )
        scaled = values.clone()
        # In-place on the copy keeps the dtype of ``values``.
        scaled *= scale
        return scaled

    @staticmethod
    def forward(ctx, tensor, vector, dim):
        indices = tensor._indices()
        if dim < 0:
            # Normalise once so that backward can compare against
            # non-negative dimension numbers.
            dim += indices.shape[0]
        values = SparseVectorMul._scale_values(
            tensor._values(), indices[dim], vector
        )
        ctx.save_for_backward(tensor, vector)
        ctx.dim = dim
        return torch.sparse_coo_tensor(indices, values, tensor.shape)

    @staticmethod
    def backward(ctx, grad_output):
        tensor, vector = ctx.saved_tensors
        dim = ctx.dim
        grad_tensor = None
        grad_vector = None
        # Skip the work for inputs that do not require a gradient.
        if ctx.needs_input_grad[1]:
            other_dims = [_i for _i in range(tensor.ndim) if _i != dim]
            grad_vector = torch.sparse.sum(
                tensor * grad_output, other_dims
            ).to_dense()
        if ctx.needs_input_grad[0]:
            indices = grad_output._indices()
            grad_values = SparseVectorMul._scale_values(
                grad_output._values(), indices[dim], vector
            )
            grad_tensor = torch.sparse_coo_tensor(
                indices, grad_values, grad_output.shape
            )
        # ``dim`` is a plain Python int and gets no gradient.
        return grad_tensor, grad_vector, None


sparse_vectormul = SparseVectorMul.apply


# class Squeeze(torch.autograd.Function):
#     @staticmethod
#     def forward(ctx, tensor, dim):
#         ctx.dim = dim
#         return tensor.select(dim, 0)
#
#     @staticmethod
#     def backward(ctx, grad_output):
#         return grad_output.unsqueeze(ctx.dim)


# class SparseSparseBatchMatmul(torch.autograd.Function):
#     @staticmethod
#     def bmm(tensor1, tensor2, transpose1, transpose2):
#         assert tensor1.shape[:-2] == tensor2.shape[:-2]
#         if tensor1.ndim == 4:
#             output = torch.stack([
#                 torch.stack([
#                     torch.sparse.mm(__a, __b)
#                     for __a, __b in zip(_a, _b)
#                 ])
#                 for _a, _b in zip(
#                     tensor1.transpose(-1, -2) if transpose1 else tensor1,
#                     tensor2.transpose(-1, -2) if transpose2 else tensor2
#                 )
#             ])
#             return output
#         else:
#             raise NotImplementedError()
#
#     @staticmethod
#     def forward(ctx, tensor1, tensor2):
#         ctx.save_for_backward(tensor1, tensor2)
#         output = SparseSparseBatchMatmul.bmm(tensor1, tensor2, False, False)
#         return output
#
#     @staticmethod
#     def backward(ctx, grad_output):
#         tensor1, tensor2 = ctx.saved_tensors
#         grad1 = SparseSparseBatchMatmul.bmm(
#             grad_output, tensor2, False, True)
#         grad2 = SparseSparseBatchMatmul.bmm(
#             tensor1, grad_output, True, False)
#         return grad1, grad2
#
#
# spbmm = SparseSparseBatchMatmul.apply


def spmatmul(a, b, dim=0):
    """Matrix-multiply ``a`` by ``b``, with special handling for sparse ``a``.

    Dispatch, in order:
      * ``a`` is not sparse: plain ``torch.matmul(a, b)``.
      * both operands are 2-D: ``torch.sparse.mm(a, b)``.
      * same number of dimensions and equal leading (batch) shapes: the
        operands are iterated in lockstep, each pair is multiplied
        recursively (``dim`` is not forwarded), and the results are stacked.
      * ``b`` is 2-D: for every column of ``b``, ``a`` is scaled along
        ``dim`` by that column and summed over ``dim``; the per-column
        results are stacked along ``dim``.
      * anything else raises ``NotImplementedError``.

    NOTE(review): per the note at the top of this file, the sparse paths
    are only valid under the assumptions of the accompanying paper.
    """
    if not a.is_sparse:
        return torch.matmul(a, b)

    if a.ndim == 2 and b.ndim == 2:
        return torch.sparse.mm(a, b)

    if a.ndim == b.ndim and a.shape[:-2] == b.shape[:-2]:
        products = [spmatmul(lhs, rhs) for lhs, rhs in zip(a, b)]
        return torch.stack(products)

    if b.ndim == 2:
        per_column = []
        for col in range(b.shape[1]):
            scaled = sparse_vectormul(a, b[:, col], dim)
            per_column.append(torch.sparse.sum(scaled, dim))
        return torch.stack(per_column, dim)

    raise NotImplementedError()


def sparse_scalarmul(a, c):
    """Multiply ``a`` by ``c``.

    A sparse ``a`` is rebuilt from its own indices and shape with its
    values (obtained through ``_values``) multiplied by ``c``; any other
    tensor is simply returned as ``a * c``.
    """
    if not a.is_sparse:
        return a * c
    coords = a._indices()
    scaled = _values(a) * c
    return torch.sparse_coo_tensor(coords, scaled, a.shape)


def sparse_select(tensor, dim, index):
    """Pick the slice of a sparse tensor where coordinate ``dim`` equals ``index``.

    Keeps the stored entries whose index along ``dim`` equals ``index``,
    drops that index row, and returns a sparse COO tensor whose shape is
    the input shape without dimension ``dim``. Values are read through
    ``_values``.

    Raises:
        AssertionError: if ``tensor`` is not sparse.
        ValueError: if ``dim`` is not one of ``0 .. tensor.ndim - 1``
            (raised by ``list.remove``).
    """
    assert tensor.is_sparse
    coords = tensor._indices()
    # Boolean mask over the stored entries that lie in the requested slice.
    hit = coords[dim] == index
    kept_coords = coords[:, hit]
    remaining_axes = [*range(tensor.ndim)]
    remaining_axes.remove(dim)
    picked_values = _values(tensor)[hit]
    out_shape = tensor.shape[:dim] + tensor.shape[dim + 1:]
    return torch.sparse_coo_tensor(
        kept_coords[remaining_axes], picked_values, out_shape
    )


def squeeze(tensor, dim):
    """Remove dimension ``dim`` from ``tensor``.

    Non-sparse tensors are delegated to ``Tensor.squeeze(dim)``. For a
    sparse tensor the index row of ``dim`` is dropped and the tensor is
    rebuilt with the shortened shape, with values read through ``_values``.

    NOTE: the sparse branch does not check that the dimension has size 1;
    it removes the index row unconditionally.
    """
    if not tensor.is_sparse:
        return tensor.squeeze(dim)

    coords = tensor._indices()
    kept_axes = [*range(tensor.ndim)]
    # list.remove raises ValueError when ``dim`` is not a valid
    # non-negative axis number.
    kept_axes.remove(dim)
    kept_coords = coords[kept_axes]
    values = _values(tensor)
    reduced_shape = tensor.shape[:dim] + tensor.shape[dim + 1:]
    return torch.sparse_coo_tensor(kept_coords, values, reduced_shape)


def sparse_unsqueeze(tensor, dim):
    """Insert a size-1 dimension at position ``dim`` of a sparse tensor.

    A row of zeros is spliced into the index matrix at ``dim`` and a ``1``
    into the shape at the same position; values are read through
    ``_values``.
    """
    coords = tensor._indices()
    # Every stored entry sits at index 0 of the new dimension.
    zero_row = torch.zeros_like(coords[:1])
    widened_coords = torch.cat([coords[:dim], zero_row, coords[dim:]])
    values = _values(tensor)
    widened_shape = tensor.shape[:dim] + (1,) + tensor.shape[dim:]
    return torch.sparse_coo_tensor(widened_coords, values, widened_shape)


def get_timed_fn(fn):
    """Return a wrapper around ``fn`` that prints how long each call took.

    After every successful call the wrapper prints ``fn`` followed by the
    elapsed time in seconds with three decimals, terminated by ``";\\t"``
    instead of a newline, and then returns the result of ``fn`` unchanged.
    Nothing is printed when ``fn`` raises.

    Args:
        fn: the callable to time.

    Returns:
        A callable with the same call signature as ``fn``; its metadata
        (``__name__``, ``__doc__``, ...) is copied from ``fn``.
    """
    @functools.wraps(fn)
    def timed_fn(*arg, **kwarg):
        # perf_counter is monotonic and high-resolution; time.time() follows
        # the wall clock and can jump when the system clock is adjusted.
        start = time.perf_counter()
        output = fn(*arg, **kwarg)
        elapsed = time.perf_counter() - start
        print(fn, "{:.3f}".format(elapsed), end=";\t")
        return output
    return timed_fn


# _values = get_timed_fn(_values)
# sparse_vectormul = get_timed_fn(sparse_vectormul)
# spmatmul = get_timed_fn(spmatmul)
# sparse_scalarmul = get_timed_fn(sparse_scalarmul)
# sparse_select = get_timed_fn(sparse_select)
# squeeze = get_timed_fn(squeeze)
# sparse_unsqueeze = get_timed_fn(sparse_unsqueeze)
