## Resolving Training Biases via Influence-based Data Relabeling

This repository is the official implementation of our paper "Resolving Training Biases via Influence-based Data Relabeling".

#### Brief Introduction

RDIA aims to reuse harmful training samples to improve model performance. To achieve
this, we use influence functions to estimate how relabeling a training sample would
affect the model's test performance, and we further develop a novel relabeling function R to update the labels of the identified harmful samples. Our framework is illustrated below:

![RDIA framework](./Figure/Framework.png)
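To make the influence estimate more concrete, here is a minimal sketch for L2-regularised binary logistic regression with labels in {0, 1}. It uses the standard first-order influence-function approximation: replacing a training sample (x_i, y_i) by a relabeled copy (x_i, y'_i) shifts the parameters by roughly -(1/n) H^{-1}(∇ℓ(x_i, y'_i) - ∇ℓ(x_i, y_i)), where H is the Hessian of the training objective, and the change in test loss then follows from the chain rule. The function names, the Newton solver, and the choice to flip every label (y' = 1 - y) are assumptions made for illustration; this is not the repository's `RDIA.py` and it does not implement the paper's relabeling function R.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def fit_logreg(X, y, lam=1e-2, iters=50):
    """Fit L2-regularised logistic regression (labels in {0, 1}) with Newton's method.

    Objective: mean log-loss + lam / 2 * ||theta||^2.
    """
    n, d = X.shape
    theta = np.zeros(d)
    for _ in range(iters):
        p = sigmoid(X @ theta)
        grad = X.T @ (p - y) / n + lam * theta
        hess = (X.T * (p * (1 - p))) @ X / n + lam * np.eye(d)
        theta -= np.linalg.solve(hess, grad)
    return theta


def relabel_influence(X, y, X_test, y_test, theta, lam=1e-2):
    """First-order estimate of the change in mean test loss if each training
    label y_i were flipped to 1 - y_i (one sample at a time).

    `theta` must be the minimiser returned by `fit_logreg` with the same `lam`.
    """
    n, d = X.shape
    # Hessian of the training objective at the optimum.
    p = sigmoid(X @ theta)
    hess = (X.T * (p * (1 - p))) @ X / n + lam * np.eye(d)
    # Gradient of the mean test loss with respect to theta.
    p_test = sigmoid(X_test @ theta)
    g_test = X_test.T @ (p_test - y_test) / len(y_test)
    # s_test = H^{-1} g_test is shared by all training samples.
    s_test = np.linalg.solve(hess, g_test)
    # For the log-loss, grad l(x_i, y') - grad l(x_i, y_i) = (y_i - y'_i) * x_i.
    y_new = 1 - y
    delta_grad = (y - y_new)[:, None] * X
    # Replacing (x_i, y_i) by (x_i, y'_i) ~ up-weighting the new point and
    # down-weighting the old one by epsilon = 1/n.
    return -(delta_grad @ s_test) / n
```

Under this sketch, a negative estimate means that flipping that sample's label is predicted to lower the test loss, so sorting the output in ascending order is one plausible way to shortlist relabeling candidates.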

#### Requirements

To install requirements:

* Python 3.7

```
pip install -r requirement.txt
```

#### Training & Evaluation

##### Experiments of RDIA

To train the model, run one of the following commands:

```
python RDIA.py --dataset realsim  --noise 0.8 --alpha 0.0002
python RDIA.py --dataset mnist    --noise 0.8 --alpha 0.0002
python RDIA.py --dataset cancer   --noise 0.8 --alpha 0.0002
```

The hyperparameter α should be tuned for each dataset and noise rate. For example, the results on **MNIST** for different values of α are shown below:

|      RDIA      | Noise=0.0 | Noise=0.2 | Noise=0.5 | Noise=0.8 |
| :------------: | :-------: | :-------: | :-------: | :-------: |
|   α = 0.01     |  0.0235   |  0.0443   |  0.0903   |  0.1009   |
|   α = 0.002    |  0.0207   |  0.0315   |  0.0519   |  0.0465   |
|   α = 0.0002   |  0.0903   |  0.0392   |  0.0410   |  0.0405   |
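Because α has to be tuned per dataset and noise rate, a small sweep script can save some typing. The sketch below only assembles `RDIA.py` invocations from the three flags used in the commands above (`--dataset`, `--noise`, `--alpha`) and, unless `dry_run=False`, just prints them. The grid values mirror the datasets and table above, and the helper names `sweep_commands` and `run_sweep` are hypothetical; how `RDIA.py` reports or saves its results is not covered here.

```python
import itertools
import subprocess
import sys

# Grid taken from the commands and the MNIST table above.
DATASETS = ["realsim", "mnist", "cancer"]
NOISES = [0.0, 0.2, 0.5, 0.8]
ALPHAS = [0.01, 0.002, 0.0002]


def sweep_commands(datasets=DATASETS, noises=NOISES, alphas=ALPHAS):
    """Return one RDIA.py argument list per (dataset, noise, alpha) combination."""
    return [
        [sys.executable, "RDIA.py",
         "--dataset", ds, "--noise", str(noise), "--alpha", str(alpha)]
        for ds, noise, alpha in itertools.product(datasets, noises, alphas)
    ]


def run_sweep(commands, dry_run=True):
    """Print each command; actually launch it only when dry_run is False."""
    for cmd in commands:
        print(" ".join(cmd))
        if not dry_run:
            subprocess.run(cmd, check=True)
```

For instance, `run_sweep(sweep_commands(datasets=["mnist"]))` prints the twelve MNIST runs; passing `dry_run=False` executes them sequentially from the repository root.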

|   Method    |  Noise=0.2   |  Noise=0.5   |  Noise=0.8   |
| :---------: | :----------: | :----------: | :----------: |
|  Standard   | 35.14 ± 0.44 | 16.97 ± 0.40 | 4.41 ± 0.14  |
| Co-teaching | 43.73 ± 0.16 | 34.96 ± 0.50 | 15.15 ± 0.46 |
|    SIGUA    | 45.52 ± 0.21 | 35.28 ± 0.42 | 14.31 ± 0.02 |
|    Ours     | 50.24 ± 0.15 | 40.50 ± 0.17 | 20.21 ± 0.04 |



|   Method    |  Noise=0.2   |  Noise=0.5   |
| :---------: | :----------: | :----------: |
|  Standard   | 35.56 ± 0.31 | 19.58 ± 0.21 |
| Co-teaching | 45.60 ± 0.74 | 37.09 ± 0.53 |
|    SIGUA    | 45.79 ± 0.15 | 36.65 ± 0.37 |
|    Ours     | 48.63 ± 0.35 | 40.20 ± 0.15 |



|   Standard   | Co-teaching  |    SIGUA     |     Ours     |
| :----------: | :----------: | :----------: | :----------: |
| 64.54 ± 1.05 | 68.36 ± 0.35 | 69.35 ± 0.41 | 69.64 ± 0.14 |

Due to the size limit on supplementary material, we only uploaded three datasets. For other datasets, we provide a simple tool that converts a dataset from raw text into a processed *scipy.sparse* matrix; it supports very large, high-dimensional datasets in practice (more than 10 million features):

```
python -u process_data.py -p 2 -b 10 -n 1000 -f fm data/XXX.txt
Args:
-p: # of threads used in processing
-b: # of lines processed in a thread
-n: the maximum # of features for the raw data set
-f: format of the raw text data, either "fm" or "ffm"; "fm" stores one sample per line as "feature_id:value" pairs, while "ffm" uses "field_id:feature_id:value".
```
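To show what the conversion amounts to, the sketch below parses "fm" or "ffm" lines into the three arrays of a CSR matrix and wraps them in `scipy.sparse.csr_matrix`. It assumes LIBSVM-style lines (a label followed by space-separated tokens) with 1-based feature ids, and it simply drops the field id of "ffm" tokens. It is single-threaded and is not the repository's `process_data.py`, whose exact output format may differ; the function names are hypothetical.

```python
def parse_line(line, fmt="fm"):
    """Split one line into (label, [(feature_id, value), ...])."""
    tokens = line.split()
    label = float(tokens[0])
    feats = []
    for tok in tokens[1:]:
        parts = tok.split(":")
        if fmt == "fm" and len(parts) == 2:
            fid, val = parts
        elif fmt == "ffm" and len(parts) == 3:
            _field, fid, val = parts  # the field id is ignored in this sketch
        else:
            raise ValueError(f"bad token {tok!r} for format {fmt!r}")
        feats.append((int(fid), float(val)))
    return label, feats


def lines_to_csr(lines, n_features, fmt="fm", index_base=1):
    """Build CSR arrays (data, indices, indptr) plus labels from text lines."""
    labels, data, indices, indptr = [], [], [], [0]
    for line in lines:
        if not line.strip():
            continue  # skip blank lines
        label, feats = parse_line(line, fmt)
        labels.append(label)
        for fid, val in feats:
            col = fid - index_base  # shift 1-based ids to 0-based columns
            if not 0 <= col < n_features:
                raise ValueError(f"feature id {fid} out of range")
            indices.append(col)
            data.append(val)
        indptr.append(len(indices))  # row boundary
    return labels, data, indices, indptr


def to_scipy(data, indices, indptr, n_features):
    """Wrap the CSR arrays in a scipy.sparse matrix (scipy imported lazily)."""
    from scipy.sparse import csr_matrix
    return csr_matrix((data, indices, indptr), shape=(len(indptr) - 1, n_features))
```

Typical use would be to read `data/XXX.txt` inside a `with open(...)` block, pass the file object to `lines_to_csr(f, n_features=1000)`, and then call `to_scipy` on the returned arrays.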

All the datasets can be found [here](https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/).

##### Experiments of RDIA-LS

To train the model, run this command:

```
python RDIA-LS.py --result_dir <path to your dir>  --dataset cifar10 --noise_rate 0.8 --weight_decay 0      --gamma 0.05 --gpu 1

python RDIA-LS.py --result_dir <path to your dir>  --dataset mnist   --noise_rate 0.8 --weight_decay 0.00001 --gamma 0.2  --gpu 1
```

The hyperparameter γ should be tuned for each dataset and noise rate. The average test accuracy on MNIST and CIFAR10 with an 80% noise rate is:

|  Model  |    MNIST     |   CIFAR10    |
| :-----: | :----------: | :----------: |
| RDIA-LS | 87.85 ± 0.21 | 25.35 ± 0.17 |

We also provide the test accuracy vs. number of epochs below:

![Test accuracy vs. epoch](./Figure/Acc%20vs%20Epoch.png)

## Results

All experimental results can be reproduced with the code above. Some of them are already shown in the 'Training & Evaluation' section.