# Inter-Task Dynamics in Deep Linear Multi-Task Networks

This repository contains the official implementation of the paper _Inter-Task Dynamics in Deep Linear Multi-Task Networks._
The repository will be published on GitHub upon acceptance.

## Content
The following files are included in the repository:
* **Requierments.py**: Contains the specification of the dependencies needed to run the experiments and code to install these requirements in Python.
* **Datasets.py**: Contains the code to produce the synthetic Multi-Task Datasets used in the experiments of the paper.
* **Analytical_MTL.py**: Contains code to simulate the analytical solutions for QQ(t) described in the paper. 
* **Balanced_Inits.py**: Contains a function which ensures balanced initialisation. 
* **MTL_Neural_Networks.py**: Code that creates the deep linear Multi-Task Networks. 
* **Train_MTL_ANA.py**: The main function of the repository. It runs both the analytical solutions and the actual neural networks for a specific dataset. More information about this function is provided below.
* **Figures_ITD_MTL.ipynb**: Notebook containing the code to reproduce figures of the same kind as those presented in the main paper.
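
For readers unfamiliar with balanced initialisation, the following is a minimal NumPy sketch of the condition as it is commonly defined for two-layer linear networks, namely `W2.T @ W2 == W1 @ W1.T`. It is not taken from Balanced_Inits.py: the function name `balanced_two_layer_init`, its arguments and the way the singular values are drawn are all assumptions made for illustration, and the repository's own function may use a different construction.

```python
import numpy as np


def balanced_two_layer_init(input_dim, hidden_dim, output_dim, scale=0.01, seed=0):
    """Hypothetical helper: return (W1, W2) with W2.T @ W2 == W1 @ W1.T.

    Requires hidden_dim <= input_dim and hidden_dim <= output_dim.
    """
    rng = np.random.default_rng(seed)
    # Reduced QR of a tall random matrix gives orthonormal columns.
    V, _ = np.linalg.qr(rng.standard_normal((input_dim, hidden_dim)))
    U, _ = np.linalg.qr(rng.standard_normal((output_dim, hidden_dim)))
    # Square orthogonal matrix shared by both layers.
    R, _ = np.linalg.qr(rng.standard_normal((hidden_dim, hidden_dim)))
    # Positive singular values of the product W2 @ W1 (an arbitrary choice here).
    s = scale * rng.uniform(0.5, 1.5, size=hidden_dim)
    sqrt_s = np.diag(np.sqrt(s))
    W1 = R @ sqrt_s @ V.T  # shape (hidden_dim, input_dim)
    W2 = U @ sqrt_s @ R.T  # shape (output_dim, hidden_dim)
    return W1, W2


# Stacking both task heads into one output of size task1_dim + task2_dim
# (an assumption about how the multi-task case maps onto two layers).
W1, W2 = balanced_two_layer_init(input_dim=10, hidden_dim=5, output_dim=40)
print(np.allclose(W2.T @ W2, W1 @ W1.T))
```

Both Gram matrices equal `R @ diag(s) @ R.T` by construction, which is why the check holds up to floating-point error.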


## Run Experiments

To run experiments, run **Train_MTL_ANA.py** and choose the correct parameters as explained below. 

**Parameters Multi-Task Model**:
* input_dim : sets the input dimension (default=10).
* task1_dim : sets the dimension of task 1 (default=20).
* task2_dim: sets the dimension of task 2 (default=20).
* shared_dim: sets the dimension of the shared hidden layer. Note that for some experiments the possible values of shared_dim are restricted (see Appendix D) (default=5).
* sigmaweights: sets the variance of the weights at initialisation (default=0.01).
* learning_rate: sets the learning rate for the analytical solution and the optimizer of the neural network (default=0.01).
* batchsize: sets the number of datapoints which are generated and used in the simulation/training (default=100).
* epochs: sets the number of epochs (default=150).
* TW2: sets the task-specific loss weight for task 2 (default=2).
* deeperMTL: determines whether to use a deeper linear MTL network with more than two layers (default=False).
* nshared: sets the number of shared layers. A value greater than one is only possible if deeperMTL=True (default=1).
* ntaskL: sets the number of task-specific layers for each task. A value greater than one is only possible if deeperMTL=True (default=1).
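
The defaults and the deeperMTL constraint listed above can be summarised in a short sketch. The `ModelParams` dataclass and its `validate` method are hypothetical and not part of the repository; only the parameter names and default values come from the list above, and how Train_MTL_ANA.py actually receives them is not specified here.

```python
from dataclasses import dataclass


@dataclass
class ModelParams:
    """Hypothetical container mirroring the documented defaults."""
    input_dim: int = 10
    task1_dim: int = 20
    task2_dim: int = 20
    shared_dim: int = 5
    sigmaweights: float = 0.01
    learning_rate: float = 0.01
    batchsize: int = 100
    epochs: int = 150
    TW2: float = 2
    deeperMTL: bool = False
    nshared: int = 1
    ntaskL: int = 1

    def validate(self):
        # More than one shared or task-specific layer requires deeperMTL=True.
        if not self.deeperMTL and (self.nshared > 1 or self.ntaskL > 1):
            raise ValueError("nshared > 1 or ntaskL > 1 requires deeperMTL=True")
        return self


# Default two-layer setup, and a deeper variant with two shared layers.
default_params = ModelParams().validate()
deeper_params = ModelParams(deeperMTL=True, nshared=2).validate()
```

Checking the constraint before training starts surfaces an invalid combination such as nshared=2 with deeperMTL=False immediately.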


**Parameters concerning the type of learning statistics collected**:
* seed: fixes the seed for reproducibility (default=300).
* rememberweights: determines whether the weights are collected for every epoch or not (default=True).
* onlyfinalweights: determines whether only the weights in the last epoch are collected (default=False).
* onlyNN: determines if only the data from the Neural Network is collected (and not the analytical solutions) (default=False).

**Parameters concerning the specific dataset used for the experiment**:

The handcrafted tasks: 
* For the perfectly Aligned tasks: set ali_sim=True (and ortho_sim to False)
* For the perfectly Orthogonal tasks: set ortho_sim=True (and ali_sim to False)
* Choose the variance of the tasks by adjusting sigmaX (default=1)

The Teacher-Student tasks:
* Set teacher_student=True (default=False)
* For the aligned tasks, set aligned=True (default=True)
* For the conflicting tasks, set aligned=False (default=True)
* Choose the variance of the tasks by adjusting sigmaweights_TS (default=0.025)

The Random Regression tasks:
* Set random_regression=True (default=False)
* Scale_1: determines the scale of task 1 (alpha_1 in the paper) (default=1)
* Scale_2: determines the scale of task 2 (alpha_2 in the paper) (default=1)
* Set aligned=False (default=True)

The Multi-MNIST tasks:
* Set Mnist=True (default=False)
* aligned=True: for the baseline experiments. The shift range is set to its default of 4.
* aligned=False: for the permuted experiments. The shift range is still set to its default of 4.
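
As a rough illustration of how the dataset flags above combine, the sketch below checks that exactly one dataset family is selected. The function `select_dataset` is hypothetical and not part of the repository. Two things are assumptions: that exactly one family should be active per run (the list above only states that ali_sim and ortho_sim exclude each other), and that ali_sim and ortho_sim default to False, since no defaults are documented for them.

```python
def select_dataset(ali_sim=False, ortho_sim=False, teacher_student=False,
                   random_regression=False, Mnist=False):
    """Hypothetical helper: return the name of the single active dataset flag."""
    flags = {
        "ali_sim": ali_sim,
        "ortho_sim": ortho_sim,
        "teacher_student": teacher_student,
        "random_regression": random_regression,
        "Mnist": Mnist,
    }
    chosen = [name for name, enabled in flags.items() if enabled]
    # Assumption: a run uses exactly one dataset family.
    if len(chosen) != 1:
        raise ValueError(f"expected exactly one dataset flag, got {chosen}")
    return chosen[0]


# Perfectly orthogonal handcrafted tasks: ortho_sim=True, ali_sim=False.
print(select_dataset(ortho_sim=True))
```

The aligned flag is left out on purpose, because it modifies the teacher-student, random regression and Multi-MNIST setups rather than selecting a dataset by itself.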