Learning Interpretable Models Using an Oracle

TMLR Paper822 Authors

31 Jan 2023 (modified: 17 Sept 2024) | Rejected by TMLR | CC BY 4.0
Abstract: We look at a specific aspect of model interpretability: models often need to be constrained in size to be considered interpretable, e.g., a decision tree of depth 5 is easier to interpret than one of depth 50. But smaller models also tend to have high bias. This suggests a trade-off between interpretability and accuracy. Our work addresses this by: (a) showing that learning a training distribution can often increase the accuracy of small models, and therefore may be used as a strategy to compensate for small sizes, and (b) providing a model-agnostic algorithm to learn such training distributions. We also present a surprising artifact: the learned training distribution may be different from the test distribution. We pose the distribution learning problem as one of optimizing the parameters of an Infinite Beta Mixture Model based on a Dirichlet Process, so that the held-out accuracy of a model trained on a sample from this distribution is maximized. To make computation tractable, we project the training data onto one dimension: prediction uncertainty scores as provided by a highly accurate oracle model. A Bayesian Optimizer is used for learning the parameters. We present empirical results using multiple real-world datasets, various oracles, and interpretable models with different notions of model size. We observe significant relative improvements in the F1-score in most cases, occasionally seeing improvements greater than $100\%$ over baselines.
Additionally, we show that the proposed algorithm provides the following benefits: (a) it is a framework that allows for flexibility in implementation, (b) it can be used across feature spaces, e.g., we show that the text classification accuracy of a Decision Tree using character n-grams improves when using a Gated Recurrent Unit as an oracle, which uses a sequence of characters as its input, (c) it can be used to train models that have a non-differentiable training loss, e.g., Decision Trees, and (d) reasonable defaults exist for most parameters of the algorithm, which makes it convenient to use.
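The projection and sampling step described in the abstract can be illustrated with a minimal sketch. It is not the authors' implementation and rests on several assumptions: uncertainty is taken to be one minus the margin between the oracle's two highest class probabilities, a single Beta density stands in for the Infinite Beta Mixture Model, and the function names (`uncertainty_scores`, `beta_pdf`, `sample_training_indices`) are hypothetical. The Bayesian optimization of the distribution parameters against held-out accuracy is omitted.

```python
import math
import random


def uncertainty_scores(oracle_probs):
    """Project each training point onto one dimension: oracle uncertainty.

    Assumption: uncertainty = 1 - (top probability - runner-up probability),
    so 0 means a confident oracle and 1 means a tie between two classes.
    """
    scores = []
    for probs in oracle_probs:
        top, runner_up = sorted(probs, reverse=True)[:2]
        scores.append(1.0 - (top - runner_up))
    return scores


def beta_pdf(x, a, b):
    """Density of Beta(a, b) at x, with x clipped away from 0 and 1."""
    x = min(max(x, 1e-6), 1.0 - 1e-6)
    log_norm = math.lgamma(a + b) - math.lgamma(a) - math.lgamma(b)
    return math.exp(log_norm + (a - 1) * math.log(x) + (b - 1) * math.log(1 - x))


def sample_training_indices(scores, a, b, n_samples, seed=0):
    """Draw training indices with probability proportional to the Beta
    density at each point's uncertainty score (sampling with replacement).

    Simplification: one Beta component replaces the mixture in the paper.
    """
    rng = random.Random(seed)
    weights = [beta_pdf(s, a, b) for s in scores]
    return rng.choices(range(len(scores)), weights=weights, k=n_samples)


# Hypothetical oracle outputs for three training points.
oracle_probs = [[0.9, 0.1], [0.5, 0.5], [0.2, 0.3, 0.5]]
scores = uncertainty_scores(oracle_probs)

# Beta(5, 1) puts most mass near 1, favouring points the oracle is unsure about.
indices = sample_training_indices(scores, a=5.0, b=1.0, n_samples=1000)
```

In a full pipeline, the small interpretable model would be trained on the sampled points, and an outer optimizer would adjust `a` and `b` (or the mixture parameters) to maximize held-out accuracy.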
Submission Length: Long submission (more than 12 pages of main content)
Assigned Action Editor: ~Manzil_Zaheer1
Submission Number: 822