
How to Implement Gradient Explanations for a HuggingFace Text Classification Model (TensorFlow 2.0)

Attribution-based explanations via vanilla gradients for HuggingFace BERT text classification models.
[Interactive figure: for a selected example sentence, the model's prediction (e.g., "political" for "The results of the elections appear to favour candidate Atiku") and the normalized gradient for each token.]
The interactive visualization above uses gradient attributions to show which input tokens have the most influence on the model's predictions.

Explanations help us make sense of a model's predictions. They are useful both for debugging a model (does the model predict a class for the right reason?) and as a tool that helps end users interpret the model's predictions and build appropriate levels of trust in its output. For deep learning models (famed for being black boxes), one way to explain predictions is to use gradients (also known as vanilla gradients [3]) to compute the contribution of each input feature to the model's output. This way, we can attribute the output to aspects of the input.
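In symbols (the notation here is mine, not taken from the cited sources): if f_k is the model's logit for class k and x_i is the input representation of token i (a one-hot vector over the vocabulary in the implementation below), the score computed for each token is the L2 norm of the gradient of the predicted class's logit, rescaled so that the largest score is 1. The second line is the chain rule step that lets a one-hot input pass through the embedding matrix E.

```latex
c = \arg\max_k f_k(\mathbf{x}_1, \dots, \mathbf{x}_T), \qquad
a_i = \left\lVert \frac{\partial f_c}{\partial \mathbf{x}_i} \right\rVert_2, \qquad
\tilde{a}_i = \frac{a_i}{\max_j a_j} \in [0, 1]

% with one-hot inputs and embeddings \mathbf{e}_i = \mathbf{x}_i \mathbf{E}:
\frac{\partial f_c}{\partial \mathbf{x}_i} = \frac{\partial f_c}{\partial \mathbf{e}_i}\, \mathbf{E}^{\top}
```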

In a previous post, I discussed gradient-based explanations for BERT question/answer models and why they are an effective method for explaining transformer-based models.

This post provides code snippets showing how to implement gradient-based explanations for HuggingFace BERT text classification models (TensorFlow 2.0).

I recently used this method to debug a simple model I built to classify text as political or not for a specialized dataset (tweets from Nigeria discussing the 2019 presidential elections). Essentially, I wanted to verify that the model correctly focuses on politics-related words/tokens when classifying a text as political.

Vanilla Gradients: Implementation Steps

We can rely on gradients to infer how changes in the output are influenced by changes in the input [1]. Concretely, we compute the gradient of the predicted output class with respect to the input tokens, which we can do with TensorFlow's tf.GradientTape API. This approach is sometimes referred to as vanilla gradients or gradient sensitivity, and the general steps are outlined below:

  • Initialize a GradientTape, which records operations for automatic differentiation.
  • Create a one-hot representation of the input tokens (the input here is the sequence of token indices produced by the tokenizer). We instruct TensorFlow to watch this tensor within the gradient tape.
  • Multiply the one-hot input by the embedding matrix; this way we can backpropagate from the prediction to the input.
  • Get the prediction for the input tokens.
  • Get the gradient of the predicted class with respect to the input. For a classification model with n classes, we zero out the other n-1 classes and keep only the predicted class (i.e., the class with the highest logit), then compute the gradient with respect to this predicted class.
  • Take the norm of the gradient for each token, normalize the scores to the 0-1 range, and return them as explanations.
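To make the steps concrete without BERT or TensorFlow, here is a minimal numpy sketch of the same recipe on a made-up toy classifier (embed, tanh, mean-pool, linear layer). The functions toy_logits and toy_vanilla_gradients, and the random matrices E and W, are hypothetical stand-ins for illustration only; the backward pass is written out by hand with the chain rule instead of using GradientTape.

```python
import numpy as np

def toy_logits(X, E, W):
    """Toy classifier: embed (X @ E), apply tanh, mean-pool over tokens, project to class logits."""
    return np.tanh(X @ E).mean(axis=0) @ W

def toy_vanilla_gradients(token_ids, E, W):
    """Follow the steps above on the toy classifier, with backprop done by hand."""
    T, V = len(token_ids), E.shape[0]
    X = np.eye(V)[token_ids]               # one-hot inputs, shape (T, V)
    h = np.tanh(X @ E)                     # multiply by embedding matrix, then tanh: (T, d)
    logits = h.mean(axis=0) @ W            # prediction, shape (n_classes,)
    c = int(np.argmax(logits))             # predicted class (highest logit)
    mask = np.zeros_like(logits)
    mask[c] = 1.0                          # zero out every class except c
    score = float(np.sum(logits * mask))   # equals logits[c]
    # Chain rule, from the masked score back to the one-hot inputs:
    d_pooled = W @ mask                    # d score / d pooled = W[:, c], shape (d,)
    d_h = np.tile(d_pooled / T, (T, 1))    # mean-pooling spreads it evenly: (T, d)
    d_embeds = d_h * (1.0 - h ** 2)        # tanh'(z) = 1 - tanh(z)**2
    d_X = d_embeds @ E.T                   # back through X @ E: (T, V)
    raw = np.linalg.norm(d_X, axis=1)      # L2 norm per token, shape (T,)
    return raw / raw.max(), d_X, c, score  # scores normalized to [0, 1]

rng = np.random.default_rng(0)
E = rng.normal(size=(6, 4))   # toy vocabulary of 6 tokens, 4-dim embeddings
W = rng.normal(size=(4, 2))   # 2 classes
scores, d_X, c, score = toy_vanilla_gradients([0, 3, 5], E, W)
print("predicted class:", c, "normalized attributions:", np.round(scores, 3))
```

Because the backward pass is hand-written, it can be sanity-checked against finite differences of toy_logits. In the TensorFlow version below, GradientTape does this bookkeeping for the real model.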
```python
import numpy as np
import tensorflow as tf

def get_gradients(text, model, tokenizer):
    """Return normalized vanilla-gradient attributions for each token in `text`,
    the token strings, and the predicted label."""

    def get_correct_span_mask(correct_index, token_size):
        # one-hot mask of shape (1, token_size) that selects the predicted class
        span_mask = np.zeros((1, token_size))
        span_mask[0, correct_index] = 1
        span_mask = tf.constant(span_mask, dtype='float32')
        return span_mask

    # word-embedding matrix, shape (vocab_size, hidden_size)
    embedding_matrix = model.bert.embeddings.weights[0]
    encoded_tokens = tokenizer(text, return_tensors="tf")
    token_ids = list(encoded_tokens["input_ids"].numpy()[0])
    vocab_size = embedding_matrix.get_shape()[0]

    # convert token ids to one hot. We can't differentiate with respect to
    # integer token ids, hence the need for a (float) one-hot representation
    token_ids_tensor = tf.constant([token_ids], dtype='int32')
    token_ids_tensor_one_hot = tf.one_hot(token_ids_tensor, vocab_size)

    with tf.GradientTape(watch_accessed_variables=False) as tape:
        # (i) watch input variable
        tape.watch(token_ids_tensor_one_hot)

        # multiply the one-hot input by the embedding matrix; allows us to backprop to the one-hot input
        inputs_embeds = tf.matmul(token_ids_tensor_one_hot, embedding_matrix)

        # (ii) get prediction
        pred_scores = model({"inputs_embeds": inputs_embeds,
                             "attention_mask": encoded_tokens["attention_mask"]}).logits
        max_class = tf.argmax(pred_scores, axis=1).numpy()[0]

        # get mask for the predicted class
        score_mask = get_correct_span_mask(max_class, pred_scores.shape[1])

        # zero out all predictions outside of the predicted class; we want gradients for just this class
        predict_correct_class = tf.reduce_sum(pred_scores * score_mask)

    # (iii) get gradient of the predicted class with respect to the input;
    # the norm over axis 2 (the vocabulary axis) gives one score per token
    gradient_non_normalized = tf.norm(
        tape.gradient(predict_correct_class, token_ids_tensor_one_hot), axis=2)

    # (iv) normalize gradient scores to 0-1 and return them as "explanations"
    gradient_tensor = (
        gradient_non_normalized /
        tf.reduce_max(gradient_non_normalized)
    )
    gradients = gradient_tensor[0].numpy().tolist()
    token_words = tokenizer.convert_ids_to_tokens(token_ids)

    # assumes class index 1 is "political" in this model's label mapping
    prediction_label = "political" if max_class == 1 else "general"
    return gradients, token_words, prediction_label
```
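The one-hot trick works because multiplying a one-hot matrix by the embedding matrix selects exactly the rows an embedding lookup would. The short numpy check below illustrates this with a random stand-in matrix and made-up token ids (not BERT's real weights). In the HuggingFace TF BERT implementation, passing inputs_embeds replaces only the word-embedding lookup; the model still adds its position and token-type embeddings afterwards.

```python
import numpy as np

rng = np.random.default_rng(42)
vocab_size, hidden_size = 10, 3
E = rng.normal(size=(vocab_size, hidden_size))  # stand-in for the word-embedding matrix
token_ids = np.array([2, 7, 7, 0])               # made-up token ids

one_hot = np.eye(vocab_size)[token_ids]  # float one-hot rows, shape (4, 10)
via_matmul = one_hot @ E                 # analogous to tf.matmul(one_hot, embedding_matrix)
via_lookup = E[token_ids]                # analogous to the model's own id -> embedding lookup

print(via_matmul.shape)                     # (4, 3)
print(np.allclose(via_matmul, via_lookup))  # True
```

The difference is that one_hot is a float tensor, so a gradient with respect to it is defined. By the chain rule it equals the gradient with respect to the embeddings multiplied by the transpose of the embedding matrix, which is why the code takes the norm over the vocabulary axis.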

We can then plot the results with the following helper function:

```python
import matplotlib.pyplot as plt

def plot_gradients(tokens, gradients, title):
    """Plot explanations: one bar per token, shaded by its normalized gradient."""
    plt.figure()
    # suffix each token with its position so repeated tokens get distinct bars
    xvals = [x + str(i) for i, x in enumerate(tokens)]
    # blue bars whose alpha is the normalized gradient (must lie in [0, 1])
    colors = [(0, 0, 1, c) for c in gradients]
    plt.tick_params(axis='both', which='minor', labelsize=29)
    plt.bar(xvals, gradients, color=colors, linewidth=1)
    # label the bars with the plain token strings
    plt.xticks(ticks=range(len(tokens)), labels=tokens, fontsize=12, rotation=90)
    plt.title(title)
    plt.show()

texts = ["The results of the elections appear to favour candidate Atiku",
         "The sky is green and beautiful",
         "The fool doth think he is wise, but the wise man knows himself to be a fool.",
         "Oby ezekwesili was talking about results of the polls in today's briefing",
         "Which party ran the most effective campaign strategy? APC or PDP"]

# assumes `model` and `tokenizer` are the fine-tuned classifier and its tokenizer
examples = []
for text in texts:
    gradients, words, label = get_gradients(text, model, tokenizer)
    plot_gradients(words, gradients, f"Prediction: {label.upper()} | {text} ")
    print(label, text)
    examples.append({"sentence": text,
                     "words": words,
                     "label": label,
                     "gradients": gradients})
```


As part of building this baseline model, it was important to have a way to inspect the model's behaviour on a few examples. Visualizing gradient-based explanations was useful in a few ways:

  • Adding Custom Tokens: Visualizing tokens and their contributions helped me realize the need to expand the token vocabulary used to train the model. My dataset had a lot of Nigerian names that the standard BERT tokenizer vocabulary did not contain, so they were split into partial (subword) tokens, which are possibly less informative than full tokens. This led me to add custom tokens to the base pretrained model before fine-tuning (see here). I used the tokenizer.add_tokens method and resized the model's embedding matrix to match the new vocabulary size.

  • Informing strategies for improving the model: Looking at how importance is assigned to each token was useful in rewriting the initial heuristics used to generate a training set (e.g., not matching on some frequently occurring last names to minimize spurious attributions), introducing additional keywords, updating my preprocessing logic, etc.
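To see why unfamiliar names end up as partial tokens, here is a simplified sketch of WordPiece-style greedy longest-match-first splitting, the scheme BERT's tokenizer is based on. The wordpiece function and the tiny vocabulary are hypothetical, chosen only to illustrate the effect; the real bert-base vocabulary will split the name differently, and the real tokenizer has extra rules (lowercasing, punctuation splitting, a maximum word length).

```python
def wordpiece(word, vocab, unk="[UNK]"):
    """Greedy longest-match-first split of one word; '##' marks a continuation piece."""
    pieces, start = [], 0
    while start < len(word):
        end, piece = len(word), None
        # try the longest remaining substring first, shrinking until a vocab hit
        while start < end:
            candidate = word[start:end]
            if start > 0:
                candidate = "##" + candidate
            if candidate in vocab:
                piece = candidate
                break
            end -= 1
        if piece is None:      # no piece matches: the whole word becomes [UNK]
            return [unk]
        pieces.append(piece)
        start = end
    return pieces

vocab = {"ez", "##ek", "##wes", "##ili", "results"}  # hypothetical toy vocabulary
print(wordpiece("ezekwesili", vocab))  # ['ez', '##ek', '##wes', '##ili']
vocab.add("ezekwesili")                # "add a custom token"
print(wordpiece("ezekwesili", vocab))  # ['ezekwesili']
```

In HuggingFace Transformers, the corresponding fix is tokenizer.add_tokens([...]) followed by model.resize_token_embeddings(len(tokenizer)), so the embedding matrix gains one row per new token. Tokens added this way are matched before WordPiece runs and kept whole, which is the effect the toy mimics by adding the name to its vocabulary.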

Note that there are other methods and variations of gradient-based attribution; however, vanilla gradients are particularly straightforward to implement and yield fairly similar results [4].


  1. Karen Simonyan, Andrea Vedaldi, and Andrew Zisserman. 2014. Deep inside convolutional networks: Visualising image classification models and saliency maps. https://arxiv.org/pdf/1312.6034.pdf
  2. Ankur Taly. Explaining Machine Learning Models. Talk, Fiddler Labs. http://theory.stanford.edu/~ataly/Talks/ExplainingMLModels.pdf
  3. Zhengxuan Wu and Desmond C. Ong. 2021. On explaining your explanations of BERT: An empirical study with sequence classification. arXiv preprint arXiv:2101.00196.