# Adaptive Sparse Softmax: An Effective and Efficient Softmax Variant for Text Classification

We propose AS-Softmax (Adaptive Sparse Softmax), a softmax variant that improves model performance by applying a reasonable transformation on top of softmax, one that brings the training objective in line with how the model is used at test time.

## Requirements

To install requirements:
```shell
pip install -r requirements.txt
```

## Examples

### AS-Softmax

```python
In [1]: import torch
   ...: from as_sm import AS_Softmax

In [2]: logits = torch.tensor([
            [1.2,1.5,1.0,-0.35],
            [2.5,3.5,0.4,0.3],
            [2.5,3.5,0.4,0.3]
        ])

In [3]: labels = torch.tensor([1,2,-100])

In [4]: as_softmax = AS_Softmax(delta=0.10)

In [5]: as_logits, labels = as_softmax(logits,labels)

In [6]: as_logits
Out[6]: tensor([[-inf, 1.5000, -inf, -inf],
                [2.5000, 3.5000, 0.4000, 0.3000],
                [2.5000, 3.5000, -inf, -inf]
        ])
```
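
The masking behaviour visible in the first two rows of `Out[6]` can be illustrated with a small, dependency-free sketch. The rule used here is an assumption inferred from that output, not code taken from `as_sm`: a non-target class is dropped (its logit set to `-inf`) when the target's softmax probability exceeds that class's probability by at least `delta`. The helper name `as_softmax_mask` is hypothetical, and the handling of ignored labels (`-100`, the third row above) is deliberately not reproduced.

```python
import math


def as_softmax_mask(logits, label, delta):
    """Return a copy of `logits` with far-from-target classes set to -inf.

    Assumed rule (inferred from the README output, not from as_sm):
    class i is dropped when p_target - p_i >= delta, where p is the
    ordinary softmax of the row. The target class is always kept.
    """
    peak = max(logits)
    exps = [math.exp(z - peak) for z in logits]  # shift for numerical stability
    total = sum(exps)
    probs = [e / total for e in exps]
    p_target = probs[label]
    return [
        z if i == label or p_target - p < delta else -math.inf
        for i, (z, p) in enumerate(zip(logits, probs))
    ]


# The two rows with valid labels from the example above.
row1 = as_softmax_mask([1.2, 1.5, 1.0, -0.35], label=1, delta=0.10)
row2 = as_softmax_mask([2.5, 3.5, 0.4, 0.3], label=2, delta=0.10)
print(row1)  # every non-target class is at least 0.10 below the target
print(row2)  # the target has a low probability, so nothing is dropped
```

Under this assumed rule, a confidently classified sample keeps only its target logit, while a sample whose target still has low probability keeps all of its logits and continues to be trained against every class.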

### Multi-label AS-Softmax

```python
In [1]: import torch
   ...: from as_sm import Multi_label_AS_Softmax

In [2]: logits = torch.tensor([
            [1.2,1.5,1.0,-0.35],
            [2.5,3.5,0.4,0.3],
            [2.5,3.5,0.4,0.3]
        ])

In [3]: multi_labels = torch.tensor([[1,1,0,0], [0,1,1,0], [1,1,0,1]])

In [4]: multi_label_as_softmax = Multi_label_AS_Softmax(delta=0.10)

In [5]: mask_neg,mask_pos = multi_label_as_softmax(logits,multi_labels)

In [6]: loss = multi_label_as_softmax.as_multilabel_categorical_crossentropy(logits, multi_labels, mask_neg, mask_pos)

In [7]: loss
Out[7]: tensor(2.0778)
```

### AS-Speed

```python
from as_sm import compute_accumulation_step

# `model` and `epoch_iterator` are assumed to be defined by the surrounding training script.

# When need_as_steps is 0, the number of accumulation steps is recalculated.
need_as_steps = 0

# Whether to accelerate training with AS-Speed
as_speed_up = True

# Lambda and the maximum number of accumulation steps
lamb = 1
max_accu_steps = 2

as_accumulation_steps = 1

for step, inputs in enumerate(epoch_iterator):
    output = model(**inputs)
    step_loss = output['loss']

    if need_as_steps == 0:
        # A backward pass has just finished (or training has just started),
        # so calculate the new number of gradient accumulation steps.
        _, _, need_as_steps = compute_accumulation_step(output['logits'], lamb, as_speed_up)
        need_as_steps = max(need_as_steps, as_accumulation_steps)

        # Ensure the accumulation steps grow by at most 1 between adjacent updates.
        if need_as_steps - as_accumulation_steps > 1:
            need_as_steps = as_accumulation_steps + 1

        # The accumulation steps cannot exceed the upper limit.
        need_as_steps = min(max_accu_steps, need_as_steps)
        as_accumulation_steps = need_as_steps

        tr_loss_all = step_loss
    else:
        tr_loss_all += step_loss

    need_as_steps -= 1

    # Back propagation once the accumulated steps are complete
    if need_as_steps == 0:
        tr_loss_all = tr_loss_all / as_accumulation_steps
        tr_loss_all.backward()
        tr_loss_all = tr_loss_all.detach()
        # The optimizer step is not shown in this snippet; it would go here,
        # before the gradients are cleared.
        model.zero_grad()
```
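
The three adjustments applied to `need_as_steps` in the loop above (never shrink, grow by at most 1, respect the upper limit) can be isolated in a small pure function to make the schedule easier to reason about. This is only a sketch: `next_accumulation_steps` is a hypothetical helper, not part of `as_sm`, and the list of proposed values stands in for whatever `compute_accumulation_step` would return.

```python
def next_accumulation_steps(proposed, current, max_steps):
    """Clamp a proposed accumulation step count as the training loop does."""
    steps = max(proposed, current)   # never drop below the current value
    steps = min(steps, current + 1)  # grow by at most 1 per update
    return min(steps, max_steps)     # never exceed the upper limit


# Illustrative values only; in training they would come from compute_accumulation_step.
proposals = [1, 4, 4, 4, 2]
current = 1
history = []
for proposed in proposals:
    current = next_accumulation_steps(proposed, current, max_steps=3)
    history.append(current)

print(history)  # [1, 2, 3, 3, 3]
```

With these rules the accumulation steps ramp up gradually and, once raised, do not fall back when a later proposal is smaller.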


## Training and Evaluation

To train a model with AS-Softmax on the SST5 dataset, run this command:

```shell
python ./src/transformers/examples/pytorch/text_classification/run_sst5.py \
    --model_name_or_path bert-base-cased \
    --task_name sst5 \
    --do_train --do_eval --do_predict \
    --seed 42 \
    --max_seq_length 128 \
    --per_device_train_batch_size 16 \
    --learning_rate 2e-5 \
    --num_train_epochs 7 \
    --output_dir results \
    --fp16 \
    --cache_dir sst5 \
    --load_best_model_at_end True \
    --save_strategy steps --logging_strategy steps --evaluation_strategy steps \
    --save_steps 200 --logging_steps 200 --eval_steps 200 \
    --greater_is_better True \
    --metric_for_best_model accuracy \
    --overwrite_output_dir \
    --warmup_ratio 0.1 \
    --initial_as_delta 1.0 \
    --min_as_delta 0.10 \
    --as_speed_up False \
    --as_warm_up False \
    --ratio 0 \
    --accu_steps 2 \
    --accu_lambda 1 \
    --dataloader_num_workers 4 \
    --dataloader_pin_memory True
```

>📋 AS-Softmax has two parameters, δ and r. In the command above, δ is set by `--min_as_delta` and r by `--ratio`. Setting `--as_speed_up True` enables the AS-Speed algorithm; only then do `--accu_steps` and `--accu_lambda` take effect.

## Results

Compared with softmax, our algorithm achieves the following performance:

### Text Classification [ Accuracy ]

| Algorithm / Dataset | SST5  | Clinc_oos |
| ------------------- | ----- | --------- |
| softmax             | 51.90 | 88.60     |
| AS-Softmax          | 53.12 | 89.07     |

### Token Classification [ F1 ]

| Algorithm / Dataset | Conll2003 | SIGHAN2015 |
| ------------------- | --------- | ---------- |
| softmax             | 90.56     | 70.80      |
| AS-Softmax          | 91.03     | 72.81      |

### Multi-label Classification [ Macro-f1 / Micro-f1 ]

| Algorithm / Dataset | Eurlex        | WOS-46985     |
| ------------------- | ------------- | ------------- |
| softmax             | 46.87 / 68.73 | 80.40 / 86.00 |
| AS-Softmax          | 48.30 / 68.99 | 80.94 / 86.39 |

## Contributing

This project is released under the MIT license.
