Extreme Classification in Log Memory using Count-Min Sketch: A Case Study of Amazon Search with 50M Products
Abstract: Deep learning algorithms have become ubiquitous solutions to some of the most challenging problems in a myriad of areas. Alongside the growth in data volume and dimensionality, there has been a significant rise in the number of labels of interest in classification tasks. This growth on all fronts makes it prohibitively expensive to scale up existing deep learning classification models, due to the memory bottleneck in the last layer. For example, a reasonable softmax layer for the dataset of interest in this paper can easily reach 50 billion parameters (> 200 GB of memory). To alleviate this problem, we present Merged-Average Classifiers via Hashing (MACH), a generic $K$-class classification algorithm whose memory provably scales as $O(\log K)$ without any assumption on the relation between classes. MACH is a count-min sketch structure in disguise: it uses universal hashing to reduce classification over a large number of classes to a few embarrassingly parallel and independent classification tasks, each over a small (constant) number of classes. MACH thereby provides a natural technique for zero-communication model parallelism with a large number of classes. We experiment with six datasets, some multiclass and some multilabel, and show consistent improvements in precision and recall over the respective baselines. In particular, we train on a private multilabel dataset sampled from a real product search engine, with 70 million queries and 49.46 million documents. We outperform the state-of-the-art extreme classification model Parabel and a standard embedding model by a significant margin. Our training is 5-7x faster, and our memory footprint 2-4x smaller, than those of the best baselines.
CMT Num: 7268
Code Link: https://github.com/Tharun24/MACH
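Below is a minimal, self-contained sketch of the MACH idea in Python, intended only to illustrate the mechanism described in the abstract; it is not the authors' implementation (see the repository linked above for that). Everything concrete here is an assumption for illustration: the toy Gaussian-cluster data, the choice of scikit-learn MLPs as the small meta-classifiers, and the hash parameters. Labels are hashed into B buckets by R independent 2-universal hash functions, one B-way classifier is trained per hash function, and class scores are recovered by averaging bucket probabilities, count-min-sketch style.

import numpy as np
from sklearn.neural_network import MLPClassifier

K, B, R = 200, 16, 4              # classes, buckets per table, hash tables
rng = np.random.default_rng(0)

# 2-universal hash family: h_r(y) = ((a_r * y + b_r) mod P) mod B
P = 2_147_483_647                 # Mersenne prime 2^31 - 1
a = rng.integers(1, P, size=R)
b = rng.integers(0, P, size=R)

def bucket(y, r):
    return int((a[r] * y + b[r]) % P) % B

# Toy data: K well-separated Gaussian clusters, one per class.
n, d = 6_000, 32
centers = 2.0 * rng.normal(size=(K, d))
y = rng.integers(0, K, size=n)
X = centers[y] + rng.normal(size=(n, d))

# Train R independent B-way meta-classifiers on the hashed labels.
# The output layers hold R*B units in total instead of K, which is
# the source of the logarithmic memory scaling for suitable B and R.
models = []
for r in range(R):
    yr = np.array([bucket(c, r) for c in y])
    clf = MLPClassifier(hidden_layer_sizes=(128,), max_iter=300,
                        random_state=0).fit(X, yr)
    models.append(clf)

# Decode: average each class's bucket probability across the R tables
# (the "merged-average" / count-min step). The paper's unbiased
# estimator is an affine map of this mean, so the argmax is unchanged.
def predict(Xq):
    scores = np.zeros((Xq.shape[0], K))
    for r, clf in enumerate(models):
        probs = np.zeros((Xq.shape[0], B))
        probs[:, clf.classes_] = clf.predict_proba(Xq)
        scores += probs[:, [bucket(c, r) for c in range(K)]]
    return (scores / R).argmax(axis=1)

yt = rng.integers(0, K, size=1_000)
Xt = centers[yt] + rng.normal(size=(1_000, d))
print("decoded accuracy over", K, "classes:", (predict(Xt) == yt).mean())

Note the two properties the abstract highlights: the output-layer footprint drops from K*d to R*B*d parameters, and the R meta-classifiers share no state, so they can be trained on separate devices with zero communication.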