How to Implement Gradient Explanations for a HuggingFace Text Classification Model (TensorFlow 2.0)
Explanations help us make sense of a model's predictions. This is useful both for debugging a model (does the model predict a class for the right reasons?) and as a tool that helps end users interpret the model's predictions and build an appropriate level of trust in its output. For deep learning models (famed for being black boxes), one way to explain predictions is to use gradients (also known as vanilla gradients [3]) to compute the contribution of each input feature to the model's output. This way, we can attribute the output to specific parts of the input.
This post provides code snippets showing how to implement gradient-based explanations for BERT-based HuggingFace text classification models (TensorFlow 2.0).
I recently used this method to debug a simple model I built to classify text as political or not for a specialized dataset (tweets from Nigeria discussing the 2019 presidential elections). Essentially, I wanted to verify that the model correctly focuses on politics-related words/tokens when classifying a text as political.
We can rely on gradients to infer how changes in the output are influenced by changes in the input [1]. Concretely, we compute the gradient of the predicted output class with respect to the input tokens. We can do this with the TensorFlow GradientTape API (tf.GradientTape). This approach is sometimes referred to as vanilla gradients or gradient sensitivity, and the general steps are outlined below:
- Initialize a GradientTape which records operations for automatic differentiation.
- Create a one-hot representation of our input tokens (note that the input corresponds to the token indices generated by the tokenizer). Integer token indices are not differentiable, which is why we need the one-hot form. We will instruct TensorFlow to watch this variable within the gradient tape.
- Multiply the one-hot input by the embedding matrix; this way we can backpropagate the prediction with respect to the input.
- Get prediction for input tokens
- Get the gradient of the predicted class with respect to the input. For a classification model with n classes, we zero out the other n-1 classes, keeping only the predicted class (i.e., the class with the highest logit), and take the gradient with respect to this predicted class.
- Normalize the gradients to the 0-1 range and return them as explanations.
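The steps above can be sketched as follows. This is a minimal sketch under stated assumptions, not a definitive implementation: the function name `vanilla_gradients`, the `logits_fn` callable, and the tiny randomly initialized toy model are all hypothetical stand-ins so that the snippet runs without downloading a pretrained model.

```python
import tensorflow as tf


def vanilla_gradients(token_ids, embedding_matrix, logits_fn):
    """Return (normalized per-token scores, predicted class) for one input.

    token_ids:        list of int token ids produced by a tokenizer
    embedding_matrix: float tensor of shape (vocab_size, hidden_size)
    logits_fn:        callable mapping embeddings (1, seq_len, hidden_size)
                      to logits of shape (1, n_classes)
    """
    vocab_size = embedding_matrix.shape[0]
    # Integer ids are not differentiable, so encode them as one-hot rows.
    one_hot = tf.one_hot(tf.constant([token_ids], dtype=tf.int32), vocab_size)

    with tf.GradientTape(watch_accessed_variables=False) as tape:
        # Track operations on the one-hot input only.
        tape.watch(one_hot)
        # (1, seq_len, vocab) x (vocab, hidden) -> (1, seq_len, hidden)
        inputs_embeds = tf.matmul(one_hot, embedding_matrix)
        logits = logits_fn(inputs_embeds)
        predicted_class = int(tf.argmax(logits, axis=1)[0])
        # Zero out every class except the predicted one.
        class_mask = tf.one_hot([predicted_class], depth=logits.shape[1])
        target = tf.reduce_sum(logits * class_mask)

    # Gradient of the predicted-class logit w.r.t. the one-hot input.
    grads = tape.gradient(target, one_hot)      # (1, seq_len, vocab)
    # Collapse the vocabulary axis to one score per token (L2 norm).
    token_scores = tf.norm(grads, axis=2)[0]    # (seq_len,)
    # Scale so the largest score is 1; the epsilon avoids division by zero.
    normalized = token_scores / (tf.reduce_max(token_scores) + 1e-12)
    return normalized.numpy().tolist(), predicted_class


# Toy stand-in for a fine-tuned classifier (hypothetical, illustration only).
tf.random.set_seed(0)
toy_embeddings = tf.random.normal((10, 4))   # vocab_size=10, hidden_size=4
toy_head = tf.random.normal((4, 2))          # hidden_size=4, n_classes=2


def toy_logits_fn(inputs_embeds):
    # Non-linearity, mean-pool over tokens, then a linear classification head.
    pooled = tf.reduce_mean(tf.tanh(inputs_embeds), axis=1)
    return tf.matmul(pooled, toy_head)


scores, predicted = vanilla_gradients([1, 5, 7], toy_embeddings, toy_logits_fn)
print(scores, predicted)
```

With a HuggingFace TensorFlow BERT classifier, the embedding matrix would typically be the word-embedding weight of the model's embedding layer (for example something like `model.bert.embeddings.weights[0]`), and `logits_fn` would wrap a call such as `model(inputs_embeds=..., attention_mask=...).logits`. Both depend on the model class and library version, so treat them as assumptions to check against your own model.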
We can then plot the resulting token scores with a small helper function.
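A minimal sketch of such a plotting helper is shown below, assuming matplotlib is available. The function name `plot_gradients`, its signature, and the example tokens and scores are illustrative assumptions. It draws one bar per token and uses the normalized score as the bar's opacity, so influential tokens stand out.

```python
import matplotlib.pyplot as plt


def plot_gradients(tokens, gradients, title):
    """Bar chart of per-token attribution scores (expects scores in [0, 1])."""
    fig, ax = plt.subplots(figsize=(max(6, 0.6 * len(tokens)), 3))
    # Plot by position, not by token text, so repeated tokens get separate bars.
    positions = list(range(len(tokens)))
    # Clamp each score to [0, 1] and use it as the alpha channel of the bar.
    colors = [(0.0, 0.0, 1.0, min(max(float(g), 0.0), 1.0)) for g in gradients]
    ax.bar(positions, gradients, color=colors, edgecolor="blue", linewidth=1)
    ax.set_xticks(positions)
    ax.set_xticklabels(tokens, rotation=90)
    ax.set_ylabel("normalized gradient")
    ax.set_title(title)
    fig.tight_layout()
    return fig, ax


# Hypothetical tokens and scores, for illustration only.
fig, ax = plot_gradients(
    ["[CLS]", "vote", "for", "change", "[SEP]"],
    [0.15, 1.0, 0.2, 0.6, 0.1],
    "Predicted: political",
)
```

In a notebook the figure renders inline; in a script, call `plt.show()` or `fig.savefig(...)` afterwards.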
As part of building this baseline model, it was important that I had a way to inspect the model's behaviour over a few examples. Visualizing gradient-based explanations was useful in a few ways:
Adding Custom Tokens: Visualizing tokens and their contributions helped me realize the need to expand the token vocabulary used to train the model. My dataset had a lot of Nigerian names that the standard BERT tokenizer did not contain, leading them to be represented by partial tokens (possibly less informative than full tokens). This led me to add custom tokens to the base pretrained model before finetuning. I used the tokenizer.add_tokens method and resized the model's token embeddings to fix this.
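To illustrate why unseen names end up as partial tokens, here is a toy greedy longest-match splitter in the style of WordPiece. It is a simplified illustration, not the HuggingFace implementation: the `wordpiece` function, the tiny vocabulary, and the example name are all hypothetical, and the real tokenizer handles added tokens through its own mechanism instead of a plain set.

```python
def wordpiece(word, vocab, unk="[UNK]"):
    """Greedy longest-match-first split of one lowercase word, WordPiece style."""
    pieces, start = [], 0
    while start < len(word):
        end = len(word)
        match = None
        # Try the longest remaining substring first, then shrink from the right.
        while start < end:
            candidate = word[start:end]
            if start > 0:
                candidate = "##" + candidate  # continuation-piece prefix
            if candidate in vocab:
                match = candidate
                break
            end -= 1
        if match is None:
            return [unk]  # no piece matches: the whole word becomes unknown
        pieces.append(match)
        start = end
    return pieces


# Hypothetical vocabulary that lacks the full name.
vocab = {"ade", "##bay", "##ba", "##o"}
before = wordpiece("adebayo", vocab)
print(before)  # ['ade', '##bay', '##o']

# Adding the name as a full token yields a single, more informative token.
vocab.add("adebayo")
after = wordpiece("adebayo", vocab)
print(after)   # ['adebayo']
```

With HuggingFace Transformers, the corresponding steps are `tokenizer.add_tokens([...])` followed by `model.resize_token_embeddings(len(tokenizer))`, so that the embedding matrix gains a row for each new token before finetuning.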
Informing strategies for improving the model: Looking at how importance is assigned to each token was useful in rewriting the initial heuristics used to generate a training set (e.g., not matching on some frequently occurring last names to minimize spurious attributions), introducing additional keywords, updating my preprocessing logic, etc.
1. Karen Simonyan, Andrea Vedaldi, and Andrew Zisserman. 2014. Deep inside convolutional networks: Visualising image classification models and saliency maps. https://arxiv.org/pdf/1312.6034.pdf
2. Explaining Machine Learning Models. Talk by Ankur Taly, Fiddler Labs. http://theory.stanford.edu/~ataly/Talks/ExplainingMLModels.pdf
3. Wu, Z. and Ong, D.C., 2021. On explaining your explanations of BERT: An empirical study with sequence classification. arXiv preprint arXiv:2101.00196.