Generalizing Neural Additive Models via Statistical Multimodal Analysis

Published: 01 Feb 2024, Last Modified: 01 Feb 2024. Accepted by TMLR.
Abstract: Interpretable models are gaining increasing attention in the machine learning community, and significant progress is being made toward simple, interpretable, yet powerful deep learning approaches. Generalized Additive Models (GAM) and Neural Additive Models (NAM) are prime examples. Despite their great potential and popularity in critical applications such as medicine, these methods fail to generalize to distributions with more than one mode (multimodal; in this paper, "multimodal" refers to distributions that possess more than one mode). The main reason for this limitation is that these "one-size-fits-all" models collapse multiple relationships because they are forced to fit the data unimodally. We address this critical limitation by proposing interpretable multimodal network frameworks capable of learning a Mixture of Neural Additive Models (MNAM). The proposed MNAM learns relationships between input features and outputs in a multimodal fashion and assigns a probability to each mode. The proposed method shares similarities with Mixture Density Networks (MDN) while keeping the interpretability that characterizes GAM and NAM. We demonstrate how the proposed MNAM balances rich representations and interpretability through numerous empirical observations and pedagogical studies. We present and discuss different training alternatives and provide an extensive practical evaluation of the proposed framework. The code is available at https://github.com/youngkyungkim93/MNAM.
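The additive mixture structure described in the abstract can be illustrated with a small sketch. This is our own assumption-laden illustration, not the authors' implementation: each feature gets a hypothetical shape function (a plain Python callable standing in for a per-feature neural subnetwork) that returns, for each of K modes, a contribution to the mode mean, to the mode log-standard-deviation, and to the mode logit. Contributions are summed over features (a per-mode bias is omitted for brevity), a softmax turns the summed logits into mode probabilities, and the data likelihood is a Gaussian mixture. The names `mnam_forward`, `mixture_log_likelihood`, and the log-std/logit parameterization are hypothetical; the repository linked above defines the actual design.

```python
import math
from typing import Callable, List, Sequence, Tuple

# Hypothetical per-feature shape function (stands in for a small neural
# subnetwork): maps one scalar feature value to K contributions each for the
# mode means, the mode log-standard-deviations, and the mode logits.
ShapeFn = Callable[[float], Tuple[List[float], List[float], List[float]]]


def softmax(logits: Sequence[float]) -> List[float]:
    """Turn summed mode logits into mode probabilities."""
    top = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - top) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def mnam_forward(
    x: Sequence[float], shape_fns: Sequence[ShapeFn], k: int
) -> Tuple[List[float], List[float], List[float]]:
    """Sum per-feature contributions into K Gaussian modes (illustrative sketch)."""
    means = [0.0] * k
    log_stds = [0.0] * k
    logits = [0.0] * k
    for value, shape_fn in zip(x, shape_fns):
        mu_j, log_std_j, logit_j = shape_fn(value)
        for m in range(k):
            means[m] += mu_j[m]
            log_stds[m] += log_std_j[m]
            logits[m] += logit_j[m]
    stds = [math.exp(s) for s in log_stds]  # exp keeps every std positive
    return means, stds, softmax(logits)


def mixture_log_likelihood(
    y: float, means: Sequence[float], stds: Sequence[float], probs: Sequence[float]
) -> float:
    """log sum_m pi_m * N(y; mu_m, sigma_m^2), computed with log-sum-exp."""
    comps = [
        math.log(p) - math.log(s) - 0.5 * math.log(2 * math.pi) - 0.5 * ((y - mu) / s) ** 2
        for mu, s, p in zip(means, stds, probs)
    ]
    top = max(comps)
    return top + math.log(sum(math.exp(c - top) for c in comps))


if __name__ == "__main__":
    # Two features, two modes: feature 0 pushes the two modes apart,
    # feature 1 shifts both modes and moves probability mass between them.
    fns = [
        lambda v: ([v, -v], [0.0, 0.0], [0.0, 0.0]),
        lambda v: ([0.5 * v, 0.5 * v], [0.0, 0.0], [v, -v]),
    ]
    means, stds, probs = mnam_forward([1.0, 0.0], fns, k=2)
    print(means, stds, probs)  # -> [1.0, -1.0] [1.0, 1.0] [0.5, 0.5]
    print(mixture_log_likelihood(1.0, means, stds, probs))
```

Because each feature's contribution to every mode's mean, spread, and weight depends on that feature alone, the per-feature curves can be plotted one at a time as in a NAM, which is the sense in which this kind of mixture keeps additive interpretability.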
License: Creative Commons Attribution 4.0 International (CC BY 4.0)
Submission Length: Long submission (more than 12 pages of main content)
Previous TMLR Submission Url: https://openreview.net/forum?id=PCHenaSkyx&referrer=%5BAuthor%20Console%5D(%2Fgroup%3Fid%3DTMLR%2FAuthors%23your-submissions)
Changes Since Last Submission: Dear Editor and Reviewers,
Thank you for the thoughtful feedback, which was immensely helpful in improving our manuscript. In this resubmission, we have addressed the reviewers' comments and incorporated their recommendations. We should note that we misunderstood and miscommunicated some of these comments in the original submission; addressing them now taught us a lot and significantly improved the paper. Below, we respond to the reviewers and describe the main changes in this version (main changes in the manuscript are highlighted in blue text). Because of the word limit, we summarize the reviewers' comments.
### 1. Reviewer: Implement "Mixture Density Network" (MDN) as a baseline
Thank you for pointing us to the "Mixture Density Network" (MDN) reference, which is highly relevant and allowed us to improve the clarity of our discussion of the method and its contributions. We now include MDN results and compare them with our method (Section 3.4), and we discuss the connections and differences between the proposed ideas and MDN throughout the updated manuscript. It is worth noting that MDN is more challenging to implement and tune on real-world datasets (as one might anticipate from the original MDN publication). This has been documented in several studies that we have now incorporated into our manuscript [Choi et al. (2018); Makansi et al. (2019)]. These challenges may stem from the training algorithm used for MDN, which we illustrate and discuss in Section 3.2.
### 2. Reviewer: Explain the likelihood metrics more clearly and fit a global variance on the original NAM for comparison among models
The data likelihood metric represents the mean likelihood; we have included its description in Section 3.1.2.
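For deterministic baselines such as the original NAM, a likelihood score needs a predictive variance. The following is a minimal sketch, assuming the reported metric is the average per-sample Gaussian log-likelihood (Section 3.1.2 of the paper gives the exact definition) and that the global variance is a single shared (homoscedastic) parameter. For fixed predictions, the variance that maximizes the training likelihood is the mean squared residual, which is the value gradient training of such a parameter would converge to. The function names are hypothetical; for a mixture model, the per-sample term would be a mixture log-likelihood instead.

```python
import math
from typing import Sequence


def fit_global_variance(y_true: Sequence[float], y_pred: Sequence[float]) -> float:
    """Closed-form MLE of one shared variance for fixed predictions: mean squared residual."""
    residuals = [t - p for t, p in zip(y_true, y_pred)]
    return sum(r * r for r in residuals) / len(residuals)


def mean_gaussian_log_likelihood(
    y_true: Sequence[float], y_pred: Sequence[float], variance: float
) -> float:
    """Average of log N(y_i; y_pred_i, variance) over the samples."""
    if variance <= 0:
        raise ValueError("variance must be positive")
    log_norm = -0.5 * math.log(2 * math.pi * variance)
    per_sample = [log_norm - 0.5 * (t - p) ** 2 / variance for t, p in zip(y_true, y_pred)]
    return sum(per_sample) / len(per_sample)


if __name__ == "__main__":
    # Toy deterministic predictions; in practice fit on training data, score on test data.
    train_y, train_pred = [1.0, 2.0, 3.0], [1.5, 2.0, 2.5]
    sigma2 = fit_global_variance(train_y, train_pred)
    print(sigma2)
    print(mean_gaussian_log_likelihood(train_y, train_pred, sigma2))
```

When the fitted variance is evaluated on the same data it was fitted on, the mean log-likelihood reduces to −½ log(2πσ̂²) − ½, so the score of a deterministic model is governed entirely by its residual spread.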
As recommended, we trained a global variance parameter on the original NAM and other deterministic models to compute likelihood scores for comparison, as shown in Table 1 and Table 4.
### 3. Reviewer: Compute Mean Absolute Error (MAE) for pNAM
As suggested, we have added the MAE score for pNAM to both Table 1 and Table 4.
### 4. Reviewer: Implement the training algorithm from the original MDN paper and compare it with the authors' algorithm
Thank you for the clarification. We have added a comparison among the hard-thresholding, soft-thresholding, and original MDN training algorithms in Section 3.2. Our findings indicate that the hard-thresholding approach exhibited better numerical stability and better likelihood scores than the other training algorithms. The training algorithm from the original MDN paper suffered from numerical instability and mode collapse (a problem previously reported in the literature, as discussed above). Although this algorithm represented the mean of the data reasonably well, it represented multimodal distributions poorly, possibly because all modes collapsed around the mean.
### 5. Reviewer: Clarify how the authors will use LIME
We have added more clarification in Section 5. In summary, we train a black-box model on a classification dataset. Within the LIME framework, MNAM is the interpretable component approximating the black-box model's predictions, which are continuous values between zero and one.
### 6. Reviewer: Explain and clarify the issues that papers such as Stirn et al. identify when training a probabilistic model for heteroskedastic regression
We have clarified in Section 5 the limitations that could arise in heteroskedastic regression problems.
We now note that the model is biased toward learning in regions with low variance, because the variance appears in the denominator of the gradient with respect to the mean, as discussed by Stirn et al. (2023). We intend to explore methods to address this issue in future work.
### Citations:
Sungjoon Choi, Kyungjae Lee, Sungbin Lim, and Songhwai Oh. Uncertainty-aware learning from demonstration using mixture density networks with sampling-free variance modeling. In 2018 IEEE International Conference on Robotics and Automation (ICRA), pp. 6915–6922. IEEE, 2018.
Osama Makansi, Eddy Ilg, Ozgun Cicek, and Thomas Brox. Overcoming limitations of mixture density networks: A sampling and fitting framework for multimodal future prediction. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 7144–7153, 2019.
Andrew Stirn, Harm Wessels, Megan Schertzer, Laura Pereira, Neville Sanjana, and David Knowles. Faithful heteroscedastic regression with neural networks. In Francisco Ruiz, Jennifer Dy, and Jan-Willem van de Meent (eds.), Proceedings of The 26th International Conference on Artificial Intelligence and Statistics, volume 206 of Proceedings of Machine Learning Research, pp. 5593–5613. PMLR, 25–27 Apr 2023. URL https://proceedings.mlr.press/v206/stirn23a.html.
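The low-variance bias mentioned in response 6 follows from a standard derivation: the per-sample Gaussian negative log-likelihood ℓ and its gradient with respect to the predicted mean μ are shown below for a single Gaussian. In a mixture, each mode's mean gradient carries the same 1/σ_m² factor, additionally scaled by that mode's posterior responsibility.

```latex
\ell(\mu, \sigma^2; y) = \tfrac{1}{2}\log\!\left(2\pi\sigma^2\right) + \frac{(y-\mu)^2}{2\sigma^2},
\qquad
\frac{\partial \ell}{\partial \mu} = -\frac{y-\mu}{\sigma^2}.
```

For the same residual y − μ = 1, a point whose predicted variance is σ² = 0.1 contributes a gradient of magnitude 10, whereas one with σ² = 10 contributes 0.1, a 100-fold difference; points the model already treats as low-noise therefore dominate the updates to the mean.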
Code: https://github.com/youngkyungkim93/MNAM
Assigned Action Editor: ~Yingzhen_Li1
Submission Number: 1595