import numpy as np
class Sp_tensor:
    """Sparse tensor stored in coordinate (COO) form.

    Entry ``n`` is the index row ``coords[n]`` paired with ``values[n]``.
    ``coord_to_value`` mirrors the same entries as a dict for lookup by
    coordinate.

    Attributes:
        coords: int64 array of indices, one row per stored entry. A flat 1-D
            array is also accepted for a first-order tensor.
        values: float64 array, one value per row of ``coords``.
        tensor_size: the size argument exactly as passed (an int, or a
            sequence with one length per mode).
        tensor_dim: number of modes; 1 when ``tensor_size`` is a scalar.
        nnz: number of stored entries.
        coord_to_value: dict from coordinate to value. Keys are tuples of
            indices, or bare scalar indices when ``tensor_dim == 1``.
    """

    def __init__(self, coords, values, tensor_size, normalize=False, check_empty=True, sort=True):
        """Build the tensor from parallel coordinate and value sequences.

        Args:
            coords: sequence of index rows (or a flat sequence of indices for
                a first-order tensor); converted to int64.
            values: sequence of numbers, one per coordinate; converted to
                float64.
            tensor_size: int, or sequence giving the length of each mode.
            normalize: if True, rescale values so they sum to 1.
            check_empty: if True, run ``see_empty`` on the result.
            sort: if True, reorder entries lexicographically by coordinate,
                with mode 0 as the primary key.

        Raises:
            AssertionError: if ``coords`` and ``values`` differ in length, or
                if ``check_empty`` finds a mode whose distinct-index count
                differs from its size.
        """
        # np.array (not asarray) always copies: normalize() divides in place
        # and must never touch the caller's buffer. Converting `values` up
        # front also lets plain lists be reordered by an index array below.
        coords = np.array(coords, dtype=np.int64)
        values = np.array(values, dtype=np.float64)
        assert len(coords) == len(values), "#coord and #values need to be same"
        if sort:
            # np.lexsort treats its LAST key as primary, so the columns are
            # reversed to make mode 0 the most significant. The permutation
            # is computed once and applied to both arrays.
            order = np.lexsort(self._as_columns(coords).T[::-1])
            coords = coords[order]
            values = values[order]

        self.coords = coords
        self.values = values
        self.tensor_size = tensor_size
        self.nnz = len(self.values)

        # A scalar size (Python int or NumPy integer) means a first-order tensor.
        if np.ndim(tensor_size) == 0:
            self.tensor_dim = 1
        else:
            self.tensor_dim = len(tensor_size)

        if normalize:
            self.normalize()

        if check_empty:
            self.see_empty()

        self.update_coord_to_value()

    @staticmethod
    def _as_columns(coords):
        """Return ``coords`` as a 2-D (entries, modes) array; flat input becomes one column."""
        return coords if coords.ndim == 2 else coords.reshape(-1, 1)

    def update_coord_to_value(self):
        """Rebuild ``coord_to_value`` from the current ``coords`` and ``values``.

        The dict is not refreshed automatically. Call this after modifying
        ``coords``/``values``, or after calling ``normalize()`` on an already
        constructed tensor.
        """
        if self.tensor_dim == 1:
            # Use scalar keys: a 1-element row array would be unhashable.
            keys = self.coords if self.coords.ndim == 1 else self.coords[:, 0]
            self.coord_to_value = dict(zip(keys, self.values))
        else:
            self.coord_to_value = {tuple(row): v for row, v in zip(self.coords, self.values)}

    def see_empty(self):
        """Assert that each mode uses exactly as many distinct indices as its size.

        Only the count of distinct indices is compared. Indices are not
        separately range-checked.

        Raises:
            AssertionError: "empty label exists" when some mode's count of
                distinct indices differs from its size.
        """
        sizes = (self.tensor_size,) if np.ndim(self.tensor_size) == 0 else self.tensor_size
        columns = self._as_columns(self.coords)
        for d in range(self.tensor_dim):
            assert len(np.unique(columns[:, d])) == sizes[d], "empty label exists"

    def normalize(self):
        """Scale ``values`` in place so that they sum to 1.

        A zero total produces inf/NaN entries (NumPy division semantics).
        ``coord_to_value`` is not updated here.
        """
        self.values /= np.sum(self.values)


def dense_to_sparse(X):
    """Convert a dense array into an ``Sp_tensor`` that stores every element.

    All elements are stored, including zeros, so the result holds
    ``X.size`` entries. Coordinates are enumerated in C (row-major) order,
    matching ``X.reshape(-1)``.

    Args:
        X: array-like; converted with ``np.asarray``.

    Returns:
        Sp_tensor built from all coordinates of ``X``, with
        ``tensor_size == X.shape``.
    """
    X = np.asarray(X)
    # np.indices returns an (ndim, *shape) grid. Flattening it in C order and
    # transposing gives one index row per element, in the same order as
    # X.reshape(-1). Passing X.size explicitly (not -1) keeps 0-d and empty
    # shapes well-defined.
    coords = np.indices(X.shape).reshape(X.ndim, X.size).T
    return Sp_tensor(coords, X.reshape(-1), X.shape)