How to Implement Gradient Explanations for a HuggingFace Text Classification Model (TensorFlow 2.0)
The interactive visualization above uses gradient attributions to explain which of the input tokens have the most influence on the predictions made by a model.
Explanations help us make sense of a model's predictions. This is useful both for debugging a model (does it predict a class for the right reasons?) and as a tool that lets end users make sense of the model's predictions and build an appropriate level of trust in its output. For deep learning models (famed for being black boxes), one way to explain predictions is to use gradients (also known as vanilla gradients[^3]) to compute the contribution of input features to the model's output. This way, we can attribute the output to aspects of the input.
This post provides code snippets showing how to implement gradient-based explanations for a BERT-based HuggingFace text classification model (TensorFlow 2.0).
I recently used this method to debug a simple model I built to classify text as political or not on a specialized dataset (tweets from Nigeria discussing the 2019 presidential elections). Essentially, I wanted to verify that the model correctly focuses on politics-related words/tokens when classifying a text as political.
Vanilla Gradients: Implementation Steps
We can rely on gradients to infer how changes in the output are influenced by changes in the input[^1]. Concretely, we compute the gradient of the predicted output class with respect to the input tokens; we can do this with the TensorFlow GradientTape API. This approach is sometimes referred to as vanilla gradients or gradient sensitivity, and the general steps are outlined below:
- Initialize a GradientTape, which records operations for automatic differentiation.
- Create a one-hot vector that represents each input token (note that the input corresponds to the token indices generated by the tokenizer). We instruct TensorFlow to watch this variable within the gradient tape.
- Multiply the one-hot input by the embedding matrix; this lets us backpropagate the prediction with respect to the input.
- Get the prediction for the input tokens.
- Get the gradient of the input with respect to the predicted class. For a classification model with *n* classes, we zero out the other *n - 1* classes, keeping only the predicted class (i.e., the class with the highest prediction logit), and take the gradient with respect to it.
- Normalize the gradients (0-1) and return them as explanations.
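In symbols: given the score $S_c(x)$ for the predicted class $c$, each token $x_i$ gets the attribution

$$a_i = \frac{\lVert \nabla_{x_i} S_c(x) \rVert_2}{\max_j \lVert \nabla_{x_j} S_c(x) \rVert_2},$$

which is just the normalized gradient norm from the last two steps. Before the full implementation, here is a minimal sketch of the GradientTape pattern it relies on; the tensors below are toy stand-ins, not real model weights:

```python
import tensorflow as tf

# Toy stand-ins for the one-hot input and the embedding matrix.
x = tf.constant([[1.0, 0.0, 0.0]])       # "one-hot" input (a constant, not a Variable)
w = tf.constant([[0.5], [0.1], [-0.3]])  # "embedding matrix"

with tf.GradientTape(watch_accessed_variables=False) as tape:
    tape.watch(x)        # constants are not tracked unless explicitly watched
    y = tf.matmul(x, w)  # stand-in forward pass

print(tape.gradient(y, x).numpy())  # [[ 0.5  0.1 -0.3]] -- d(y)/d(x)
```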
```python
import numpy as np
import tensorflow as tf

def get_gradients(text, model, tokenizer):

    def get_correct_span_mask(correct_index, token_size):
        span_mask = np.zeros((1, token_size))
        span_mask[0, correct_index] = 1
        span_mask = tf.constant(span_mask, dtype="float32")
        return span_mask

    embedding_matrix = model.bert.embeddings.weights[0]
    encoded_tokens = tokenizer(text, return_tensors="tf")
    token_ids = list(encoded_tokens["input_ids"].numpy()[0])
    vocab_size = embedding_matrix.get_shape()[0]

    # convert token ids to one hot; we can't differentiate wrt integer token ids,
    # hence the need for a one hot representation
    token_ids_tensor = tf.constant([token_ids], dtype="int32")
    token_ids_tensor_one_hot = tf.one_hot(token_ids_tensor, vocab_size)

    with tf.GradientTape(watch_accessed_variables=False) as tape:
        # (i) watch the input variable
        tape.watch(token_ids_tensor_one_hot)

        # multiply input by the embedding matrix; allows us to backprop wrt the one hot input
        inputs_embeds = tf.matmul(token_ids_tensor_one_hot, embedding_matrix)

        # (ii) get the prediction
        pred_scores = model(
            {"inputs_embeds": inputs_embeds,
             "attention_mask": encoded_tokens["attention_mask"]}).logits
        max_class = tf.argmax(pred_scores, axis=1).numpy()[0]

        # get a mask for the predicted class, then zero out all other predictions;
        # we want gradients wrt just this class
        score_mask = get_correct_span_mask(max_class, pred_scores.shape[1])
        predict_correct_class = tf.reduce_sum(pred_scores * score_mask)

    # (iii) get the gradient of the input with respect to the predicted class
    gradient_non_normalized = tf.norm(
        tape.gradient(predict_correct_class, token_ids_tensor_one_hot), axis=2)

    # (iv) normalize gradient scores (0-1) and return them as "explanations"
    gradient_tensor = gradient_non_normalized / tf.reduce_max(gradient_non_normalized)
    gradients = gradient_tensor[0].numpy().tolist()

    token_words = tokenizer.convert_ids_to_tokens(token_ids)
    prediction_label = "political" if max_class == 1 else "general"
    return gradients, token_words, prediction_label
```
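The function above assumes a fine-tuned `model` and matching `tokenizer` are already in scope. For reference, a minimal setup might look like the following; the checkpoint name and `num_labels` are placeholders, since the post's actual model is a custom fine-tuned binary classifier:

```python
from transformers import BertTokenizerFast, TFBertForSequenceClassification

# Placeholder checkpoint; substitute your own fine-tuned political/general model.
tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
model = TFBertForSequenceClassification.from_pretrained(
    "bert-base-uncased", num_labels=2)  # assumption: 0 = general, 1 = political

gradients, words, label = get_gradients(
    "Which party ran the most effective campaign strategy?", model, tokenizer)
```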
We can then plot the results with the following helper function:
```python
import matplotlib.pyplot as plt

def plot_gradients(tokens, gradients, title):
    """Plot explanations as a bar chart of per-token gradient scores."""
    plt.figure(figsize=(21, 3))
    xvals = [x + str(i) for i, x in enumerate(tokens)]
    colors = [(0, 0, 1, c) for c in gradients]
    plt.tick_params(axis="both", which="minor", labelsize=29)
    p = plt.bar(xvals, gradients, color=colors, linewidth=1)
    plt.title(title)
    p = plt.xticks(ticks=[i for i in range(len(tokens))], labels=tokens,
                   fontsize=12, rotation=90)
```
texts = ["The results of the elections appear to favour candidate Atiku","The sky is green and beautiful","The fool doth think he is wise, but the wise man knows himself to be a fool.","Oby ezekwesili was talking about results of the polls in today's briefing","Which party ran the most effective campaign strategy? APC or PDP"]# texts = sorted(texts, key=len)examples = []for text in texts:gradients, words, label = get_gradients(text, model, tokenizer)plot_gradients(words, gradients, f"Prediction: {label.upper()} | {text} ")print(label, text)examples.append({"sentence": text,"words": words,"label": label,"gradients": gradients})
Conclusions
As part of building this baseline model, it was important to have a way to inspect the model's behaviour on a few examples. Visualizing gradient-based explanations was useful in a few ways:
- **Adding custom tokens**: Visualizing tokens and their contributions helped me realize the need to expand the token vocabulary used to train the model. My dataset had a lot of Nigerian names that the standard BERT tokenizer did not contain, so they were represented by partial tokens (likely less informative than full tokens). This led me to add custom tokens to the base pretrained model before fine-tuning (see here). I used the `tokenizer.add_tokens` method and resized the model's embedding matrix to fix this (a sketch follows this list).
- **Informing strategies for improving the model**: Looking at how importance is assigned to each token was useful in rewriting the initial heuristics used to generate the training set (e.g., not matching on some frequently occurring last names to minimize spurious attributions), introducing additional keywords, updating my preprocessing logic, etc.
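For reference, here is a minimal sketch of that vocabulary expansion; the token list is illustrative, not the one actually used for the model:

```python
# Hypothetical domain-specific tokens; replace with names/terms from your dataset.
new_tokens = ["atiku", "buhari", "apc", "pdp"]
num_added = tokenizer.add_tokens(new_tokens)

if num_added > 0:
    # Grow the embedding matrix so the new token ids map to (trainable) rows.
    model.resize_token_embeddings(len(tokenizer))
```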
Note that there are other methods/variations of gradient-based attribution; however, vanilla gradients are particularly straightforward to implement and yield fairly similar results[^4].
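For example, Gradient×Input[^2] weights the gradient by the input embedding itself. As a sketch (not the method used in this post), the norm-based attribution in steps (iii)-(iv) of `get_gradients` could be replaced with something like:

```python
# Sketch only: Gradient x Input variation, placed after the `with tape:` block.
# inputs_embeds is an intermediate tensor recorded on the tape, so we can
# differentiate the class score with respect to it directly.
grads = tape.gradient(predict_correct_class, inputs_embeds)  # (1, seq_len, hidden)
grad_x_input = tf.reduce_sum(grads * inputs_embeds, axis=2)  # (1, seq_len)
```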
References
[^1]: Explaining Machine Learning Models. Talk by Ankur Taly, Fiddler Labs. http://theory.stanford.edu/~ataly/Talks/ExplainingMLModels.pdf
[^2]: Gradient*Input as Explanation. https://towardsdatascience.com/basics-gradient-input-as-explanation-bca79bb80de0
[^3]: Karen Simonyan, Andrea Vedaldi, and Andrew Zisserman. 2014. Deep Inside Convolutional Networks: Visualising Image Classification Models and Saliency Maps. https://arxiv.org/pdf/1312.6034.pdf
[^4]: Wu, Z. and Ong, D.C., 2021. On Explaining Your Explanations of BERT: An Empirical Study with Sequence Classification. arXiv preprint arXiv:2101.00196.