# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team
"""
Implementation of a compressed sparse tensor. Similar in
functionality to TensorFlow's IndexedSlices implementation.
"""

import torch


class SparseTensor(object):
    """Compressed sparse tensor stored as row indices plus row values.

    Only the rows of the wrapped tensor that contain at least one non-zero
    entry are kept. The methods index ``dense_size[0]`` and ``dense_size[1]``
    and reduce along ``dim=1``, so the wrapped tensor is expected to be 2-D.

    Attributes set by ``__init__``:
        orig_dense_tensor: the tensor passed to the constructor (or ``None``).
        is_sparse: whether the input was a ``torch`` sparse tensor.
        indices: 1-D tensor of stored row indices (``None`` if uninitialised).
        values: tensor holding the stored rows (``None`` if uninitialised).
        dense_size: shape of the input as a list (``None`` if uninitialised).
    """

    def __init__(self, dense_tensor=None):
        """Build the compressed representation of ``dense_tensor``.

        Args:
            dense_tensor: a dense or sparse (COO) ``torch.Tensor``, or ``None``
                to create an uninitialised placeholder.
        """
        self.orig_dense_tensor = dense_tensor
        if dense_tensor is None:
            # Placeholder instance: nothing to compress yet.
            self.is_sparse = False
            self.indices = None
            self.values = None
            self.dense_size = None
            return

        self.is_sparse = dense_tensor.is_sparse
        if dense_tensor.is_sparse:
            # Coalescing merges duplicate indices so indices()/values() are valid.
            dense_tensor = dense_tensor.coalesce()
            self.indices = dense_tensor.indices().flatten()
            self.values = dense_tensor.values()
        else:
            # A row is kept when ANY entry is non-zero. Summing the row instead
            # would drop rows whose entries cancel out (e.g. [1, -1]).
            row_has_data = dense_tensor.ne(0).any(dim=1)
            self.indices = row_has_data.nonzero().flatten()
            self.values = dense_tensor[self.indices]
        self.dense_size = list(dense_tensor.size())

    def to_coo_tensor(self):
        """Return the stored rows as a ``torch`` sparse COO tensor."""
        return torch.sparse_coo_tensor(self.indices.unsqueeze(0), self.values, self.dense_size)

    @staticmethod
    def type():
        """Return the string tag identifying this tensor type."""
        return "deepspeed.SparseTensor"

    def to_dense(self):
        """Materialise the dense tensor of shape ``dense_size``.

        Rows are accumulated, so duplicate indices (e.g. after ``add``) are
        summed into the same output row.
        """
        dense = self.values.new_zeros(self.dense_size)
        # Row-wise accumulate: dense[indices[i]] += values[i].
        return dense.index_add_(0, self.indices, self.values)

    def sparse_size(self):
        """Return ``(sparse_element_count, dense_element_count)``.

        The sparse count is the number of stored indices plus the number of
        stored values; the dense count is ``dense_size[0] * dense_size[1]``.
        """
        sparse_elems = self.indices.numel() + self.values.numel()
        dense_elems = self.dense_size[0] * self.dense_size[1]
        return sparse_elems, dense_elems

    def add(self, b):
        """Append the rows of ``b`` to this tensor in place.

        Indices and values are concatenated without merging duplicates.

        Args:
            b: another ``SparseTensor`` with the same ``dense_size``.

        Raises:
            AssertionError: if the two ``dense_size`` values differ.
        """
        assert self.dense_size == b.dense_size
        self.indices = torch.cat([self.indices, b.indices])
        self.values = torch.cat([self.values, b.values])

    def __str__(self):
        """Return a one-line summary including the size reduction factor."""
        if self.indices is None or self.values is None:
            # Uninitialised placeholder: there are no tensors to describe.
            return "DeepSpeed.SparseTensor(indices_size=None, values_size=None, " \
                   "dense_size=None, device=None, reduction_factor=None)"

        sparse_size, dense_size = self.sparse_size()
        # Nothing stored (all-zero input) would otherwise divide by zero.
        reduction_factor = dense_size / sparse_size if sparse_size else float("inf")
        return "DeepSpeed.SparseTensor(indices_size={}, values_size={}, " \
               "dense_size={}, device={}, reduction_factor={})".format(
            self.indices.size(), self.values.size(), self.dense_size,
            self.indices.get_device(), reduction_factor
        )

    def __repr__(self):
        """Return the same summary as ``__str__``."""
        return self.__str__()
