Abstract: Deep neural networks have made tremendous progress toward human-like intelligence, but still lag behind human competence in strong forms of generalization. One such case is out-of-distribution (OOD) generalization: successful performance on test examples that lie outside the distribution of the training set. Here, we identify how certain properties of processing in the brain can be used to achieve strong OOD generalization, and, as a proof of concept, offer a two-part algorithm that improves the OOD generalization performance of artificial neural networks on multiple tasks widely used in neuroscience. First, we exploit the fact that the mammalian brain represents metric spaces using grid-like representations: abstract representations of relational structure, organized in recurring motifs that cover the representational space. Second, we propose a selection mechanism that operates over these grid representations using a determinantal point process (DPP-S), a transformation that ensures maximum sparseness in the coverage of that space. We show that a loss function combining the standard task-optimized error with DPP-S can exploit the recurring motifs in grid codes, and can be integrated with common architectures to achieve strong OOD generalization on analogy and arithmetic tasks.
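The abstract does not spell out how DPP-S scores grid codes, so the following is only a minimal sketch of the general idea under stated assumptions: each grid frequency yields a set of embeddings (rows) for the training points, a linear DPP kernel $L = CC^\top$ is built per frequency, its log-determinant serves as a measure of how diversely those embeddings cover the space, and a softmax over hypothetical selection logits weights the frequencies. The combined objective (task loss minus $\lambda$ times the weighted log-volumes), the toy plane-wave grid code, and all names (`toy_grid_code`, `dpp_log_volume`, `combined_loss`, `lam`) are illustrative assumptions, not the authors' implementation.

```python
import numpy as np


def toy_grid_code(positions: np.ndarray, freq: float, n_dirs: int = 3) -> np.ndarray:
    """Hypothetical grid-like code: plane waves at a single spatial frequency.

    positions: (n_points, 2) array of 2-D locations.
    Returns an (n_points, 2 * n_dirs) array of cos/sin features. With n_dirs=3
    the wave directions are 0, 60 and 120 degrees, the three-wave construction
    commonly used to model hexagonal grid-cell firing.
    """
    angles = np.arange(n_dirs) * np.pi / n_dirs
    dirs = np.stack([np.cos(angles), np.sin(angles)], axis=1)  # (n_dirs, 2)
    phase = 2 * np.pi * freq * positions @ dirs.T  # (n_points, n_dirs)
    return np.concatenate([np.cos(phase), np.sin(phase)], axis=1)


def dpp_log_volume(codes: np.ndarray, eps: float = 1e-6) -> float:
    """Log-determinant of a linear DPP kernel L = C C^T plus a small jitter.

    Rows of `codes` are items; a larger value means the items are more
    mutually dissimilar, i.e. they span a larger volume.
    """
    kernel = codes @ codes.T + eps * np.eye(codes.shape[0])  # jitter keeps it positive definite
    _, logdet = np.linalg.slogdet(kernel)
    return float(logdet)


def softmax(x: np.ndarray) -> np.ndarray:
    z = np.exp(x - np.max(x))  # shift by the max for numerical stability
    return z / z.sum()


def combined_loss(task_loss: float, codes_per_freq, selection_logits, lam: float = 0.1):
    """Assumed objective: task loss minus lam * selection-weighted DPP log-volumes.

    Minimizing it favours putting selection weight on the frequencies whose
    codes cover the space most diversely.
    """
    weights = softmax(np.asarray(selection_logits, dtype=float))
    volumes = np.array([dpp_log_volume(c) for c in codes_per_freq])
    return task_loss - lam * float(weights @ volumes), weights, volumes


# Usage: score three hypothetical frequencies on five random training points.
rng = np.random.default_rng(0)
points = rng.uniform(0, 10, size=(5, 2))
codes = [toy_grid_code(points, f) for f in (0.05, 0.3, 1.0)]
loss, weights, volumes = combined_loss(1.0, codes, np.zeros(3))
```

In a trainable model the selection logits would be learned jointly with the task loss (e.g. in an autodiff framework); NumPy is used here only to keep the sketch self-contained.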
Submission Length: Long submission (more than 12 pages of main content)
Changes Since Last Submission: - Changed extrapolation to out-of-distribution (OOD) generalization
- Changed attention to selection
- Changed reasoning module to inference module
- Changed DPP-A to DPP-S
- Improved the clarity of the Introduction section
- Added the first two paragraphs of Section 2.2.2
- Revised the narrative to extend to more practical applications, as reflected in the slightly updated Abstract, Introduction, and Discussion and Future Directions sections
- Ablation study on the choice of frequency (Section 7.3, Figure 9)
- Baseline using dynamic attention across frequencies (Section 7.4, Figure 10)
- Ablation study on the number of grid frequencies $N_f$ (Figure 13)
- Regression formulation (Section 7.6)
- Other miscellaneous changes as mentioned in the response to reviewers
Assigned Action Editor: ~Caglar_Gulcehre1
Submission Number: 746