# META-PRIOR: META LEARNING FOR ADAPTIVE INVERSE PROBLEM SOLVERS

This folder contains the code used in our submission.

## Requirements

This code relies on the following packages:

```text
torch
torchvision
numpy
jax
optax
matplotlib
scipy
```

Because PyTorch and JAX must each be installed as a CUDA-enabled build (CUDA 11 here), we used the following commands to install the packages:

```bash
conda install pytorch torchvision torchaudio pytorch-cuda=11.8 -c pytorch -c nvidia
conda install numpy
python3 -m pip install jaxlib==0.4.10+cuda11.cudnn86 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
python3 -m pip install jax==0.4.10 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
python3 -m pip install jaxopt
```
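After installation, one quick way to confirm that the packages listed above are present is to query the installed distributions. The sketch below is an illustrative check, not part of the repository. It uses only the standard library (`importlib.metadata`) and takes its package names from the requirements list and install commands above. It does not verify that the CUDA builds actually see a GPU.

```python
"""Report which of the required packages are installed, and their versions."""
from importlib.metadata import PackageNotFoundError, version

# Distribution names taken from the requirements list and install commands above.
REQUIRED = ["torch", "torchvision", "numpy", "jax", "jaxlib", "optax",
            "jaxopt", "matplotlib", "scipy"]


def check_packages(names=REQUIRED):
    """Map each distribution name to its installed version, or None if missing."""
    report = {}
    for name in names:
        try:
            report[name] = version(name)
        except PackageNotFoundError:
            report[name] = None
    return report


if __name__ == "__main__":
    for name, ver in check_packages().items():
        print(f"{name:12s} {ver if ver else 'MISSING'}")
```

Once every package is found, `python -c "import torch, jax; print(torch.cuda.is_available(), jax.devices())"` shows whether both frameworks can use the GPU.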


## Reproducing results

The notebook `SR_test.ipynb` reproduces the results from Figure 4.
The notebook `MRI_test.ipynb` reproduces the results from Figure 6.
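To regenerate both figures without opening Jupyter, the notebooks can be executed with `jupyter nbconvert`. The helper below is a minimal sketch. It assumes that Jupyter (with `nbconvert`) is installed in the same environment, which the requirements above do not list. It also assumes the script is run from the folder that contains the notebooks.

```python
"""Execute the result notebooks headlessly with `jupyter nbconvert`."""
import shutil
import subprocess

# Figure -> notebook mapping, as stated in the README.
NOTEBOOKS = {
    "Figure 4": "SR_test.ipynb",
    "Figure 6": "MRI_test.ipynb",
}


def nbconvert_command(notebook):
    """Build the command that executes `notebook` and writes the outputs back into it."""
    return ["jupyter", "nbconvert", "--to", "notebook", "--execute",
            "--inplace", notebook]


def run_all():
    if shutil.which("jupyter") is None:
        raise RuntimeError("jupyter is not on PATH; install jupyter/nbconvert first")
    for figure, notebook in NOTEBOOKS.items():
        print(f"Reproducing {figure} with {notebook} ...")
        # check=True stops at the first notebook whose execution fails.
        subprocess.run(nbconvert_command(notebook), check=True)


if __name__ == "__main__":
    run_all()
```

`--inplace` overwrites the outputs stored in each notebook. To keep the originals untouched, replace it with `--output <new_name>`.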

## Training models

The models were trained on natural images with the following command:

```bash
python train_imaml_inverse_problem.py --model_name='uPDNet' --results_folder='results_uPDNet_1ch_imeta_120_supervised/' --channels=1 --num_inner_steps=1 --supervised=1 --num_layers=120
```

## Finetuning models

Finetuning can be performed with the following commands, which launch the jobs through SLURM's `srun`. For the MRI problem:

```bash
srun python train_finetune_fastMRI.py --model_name='uPDNetMRI' --num_iter=120 --supervised=1
```

For the super-resolution (SR) problem:

```bash
srun python train_finetune_SR.py --model_name='uPDNetSR' --num_iter=120 --supervised=1
```
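The two finetuning commands share the same flags and differ only in the script and model name. The sketch below uses a hypothetical `finetune_command` helper that is not part of the repository. It builds either command and keeps the `srun` prefix only when SLURM's `srun` is on the `PATH`. It is an assumption, not something this README states, that the scripts also run correctly without `srun` outside a SLURM cluster.

```python
"""Launch the MRI or SR finetuning script, prefixing srun only when available."""
import shutil
import subprocess
import sys

# Script and model name for each problem, copied from the commands above.
PROBLEMS = {
    "MRI": ("train_finetune_fastMRI.py", "uPDNetMRI"),
    "SR": ("train_finetune_SR.py", "uPDNetSR"),
}


def finetune_command(problem, num_iter=120, supervised=1, use_srun=None):
    """Return the argv list for finetuning on `problem` ("MRI" or "SR")."""
    script, model_name = PROBLEMS[problem]
    if use_srun is None:
        # Default: use srun only if SLURM's launcher is installed (assumption).
        use_srun = shutil.which("srun") is not None
    cmd = ["srun"] if use_srun else []
    # No shell quoting is needed: subprocess passes each flag through unchanged.
    cmd += ["python", script,
            f"--model_name={model_name}",
            f"--num_iter={num_iter}",
            f"--supervised={supervised}"]
    return cmd


if __name__ == "__main__":
    problem = sys.argv[1] if len(sys.argv) > 1 else "MRI"
    subprocess.run(finetune_command(problem), check=True)
```

For example, `python launch_finetune.py SR` (hypothetical file name) runs the SR command shown above.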
