Towards Improving Adversarial Training of NLP Models

This is the source code for the EMNLP 2021 (Findings) paper “Towards Improving Adversarial Training of NLP Models”.

If you use the code, please cite the paper:

    @inproceedings{yoo2021towards,
      title={Towards Improving Adversarial Training of NLP Models},
      author={Jin Yong Yoo and Yanjun Qi},
      booktitle={Findings of the Association for Computational Linguistics: EMNLP 2021},
      year={2021}
    }


This work relies heavily on the TextAttack package; in fact, the main training code is implemented within TextAttack.

Required packages are listed in the requirements.txt file.

    pip install -r requirements.txt


All of the data used for the paper are available from HuggingFace’s Datasets.

For the IMDB and Yelp datasets, there are no official validation splits, so we randomly sampled 5k and 10k examples, respectively, from the training sets and used them as validation splits. We provide the splits in this Google Drive folder. To use them with the provided code, place each folder (e.g. imdb, yelp, augmented_data) inside ./data (run mkdir data).
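The held-out splits can be reproduced in spirit with a simple random sample from the training set. A minimal sketch (the seed and exact sampling procedure used to produce the released splits are assumptions here; use the provided Google Drive splits to match the paper):

```python
import random

def make_validation_split(train_examples, valid_size, seed=42):
    """Randomly hold out `valid_size` examples from the training set
    to serve as a validation split (e.g. 5k for IMDB, 10k for Yelp)."""
    rng = random.Random(seed)
    held_out = set(rng.sample(range(len(train_examples)), valid_size))
    valid = [ex for i, ex in enumerate(train_examples) if i in held_out]
    train = [ex for i, ex in enumerate(train_examples) if i not in held_out]
    return train, valid

# Toy usage: 25k "IMDB" training examples, hold out 5k for validation.
examples = [f"review_{i}" for i in range(25000)]
train, valid = make_validation_split(examples, 5000)
```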

Also, augmented training data generated using SSMBA and back-translation are available in the same folder.


To train a BERT model on the IMDB dataset with the A2T attack for 4 epochs (1 of them a clean epoch) and a gamma of 0.2:

    --train imdb 
    --eval imdb 
    --model-type bert 
    --model-save-path ./example 
    --num-epochs 4 
    --num-clean-epochs 1 
    --num-adv-examples 0.2 
    --attack-epoch-interval 1 
    --attack a2t 
    --learning-rate 5e-5 
    --num-warmup-steps 100 
    --grad-accumu-steps 1 
    --checkpoint-interval-epochs 1 
    --seed 42
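Here `--num-adv-examples 0.2` is gamma: the fraction of the training set that is attacked to generate adversarial examples in each adversarial epoch. A rough sketch of how a fractional value translates into a per-epoch budget (illustrative arithmetic, not the package's internal code):

```python
def adv_examples_per_epoch(num_adv_examples, train_size):
    """Interpret --num-adv-examples: a value in (0, 1] is treated as a
    fraction (gamma) of the training set; a larger value is an absolute
    count of adversarial examples to generate per adversarial epoch."""
    if 0 < num_adv_examples <= 1:
        return int(num_adv_examples * train_size)
    return int(num_adv_examples)

# With gamma = 0.2 on the 25k-example IMDB training set:
print(adv_examples_per_epoch(0.2, 25000))  # 5000 examples attacked per epoch
```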

You can also pass roberta to --model-type to train a RoBERTa model instead of BERT. To select other datasets from the paper, pass rt (MR), yelp, or snli to --train and --eval.

This script simply runs the Trainer class from the TextAttack package. To see how training is performed, please check out the Trainer class.


To evaluate the accuracy, robustness, and interpretability of our trained model from above, run:

    --dataset imdb 
    --model-type bert 
    --checkpoint-paths ./example 
    --epoch 4 
    --accuracy 
    --robustness 
    --attacks a2t a2t_mlm textfooler bae pwws pso 
    --interpretability 

This takes the last checkpoint model (--epoch 4) and evaluates its accuracy on both the IMDB and Yelp datasets (for cross-domain accuracy). It also evaluates the model’s robustness against the A2T, A2T-MLM, TextFooler, BAE, PWWS, and PSO attacks. Lastly, with the --interpretability flag, AOPC scores are calculated.

Note that you will have to run --robustness and --interpretability with --accuracy (or after you separately evaluate accuracy), since both the robustness and interpretability evaluations rely on the accuracy evaluation to know which samples the model predicted correctly. By default, 1000 samples are attacked to evaluate robustness. Likewise, 1000 samples are used to calculate the AOPC score for interpretability.
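For reference, AOPC (Area Over the Perturbation Curve) measures interpretability as the average drop in the model's predicted-class probability when the top-k most important words are removed. A toy sketch over precomputed probabilities (the actual evaluation uses the model plus an attribution method to rank words; this only shows the final averaging step):

```python
def aopc(orig_prob, perturbed_probs):
    """AOPC for one example: mean drop in predicted-class probability
    after deleting the top-1, top-2, ..., top-K ranked words, where
    perturbed_probs[k] is the probability with the top-(k+1) words removed.
    A higher AOPC means the ranking found words the model truly relies on."""
    K = len(perturbed_probs)
    return sum(orig_prob - p for p in perturbed_probs) / (K + 1)

# Removing increasingly many "important" words steadily lowers confidence:
print(aopc(0.95, [0.80, 0.60, 0.40]))
```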

If you’re evaluating multiple models for comparison, it’s also advised that you provide all the checkpoint paths together to --checkpoint-paths. This is because the samples that are predicted correctly by each model will differ, so we first need to identify the intersection of all correct predictions before using them to evaluate robustness for all the models. This allows a fairer comparison of the models’ robustness than attacking different samples for each model.
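The intersection step can be pictured as a set operation over each model's correctly-predicted sample indices (a schematic, not the evaluation script's actual code):

```python
def common_correct_indices(per_model_correct):
    """Given, for each model, the set of sample indices it predicts
    correctly, return the indices that every model gets right. Only
    these shared samples are attacked, so all models face the same inputs."""
    sets = [set(s) for s in per_model_correct]
    return set.intersection(*sets)

model_a = {0, 1, 2, 4, 7}
model_b = {0, 2, 3, 4, 7, 9}
print(sorted(common_correct_indices([model_a, model_b])))  # [0, 2, 4, 7]
```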

Data Augmentation

Lastly, we also provide the scripts we used to perform data augmentation methods such as SSMBA and back-translation.

The following is an example command for augmenting the IMDB dataset with the SSMBA method.

    --dataset imdb 
    --augmentation ssmba 
    --output-path ./augmented_data 
    --seed 42 

You can also pass backtranslation to --augmentation.
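SSMBA perturbs inputs by masking a fraction of tokens and then reconstructing them with a masked language model. The corruption half of that pipeline can be sketched with plain whitespace tokenization (the reconstruction step, which needs an MLM such as BERT, is omitted; the mask rate here is illustrative, not the paper's setting):

```python
import random

def mask_tokens(sentence, mask_rate=0.15, mask_token="[MASK]", seed=0):
    """Replace a random subset of whitespace tokens with a mask token.
    SSMBA would next feed the corrupted text to a masked language model
    and sample reconstructions to use as augmented training data."""
    rng = random.Random(seed)
    tokens = sentence.split()
    n_mask = max(1, int(len(tokens) * mask_rate))
    for i in rng.sample(range(len(tokens)), n_mask):
        tokens[i] = mask_token
    return " ".join(tokens)

print(mask_tokens("the movie was surprisingly good and well acted"))
```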

