
How to Implement Extractive Summarization with BERT in PyTorch

How to Implement Extractive Summarization (as Sentence Classification) Using HuggingFace BERT Models in PyTorch

In a previous post, we discussed how extractive summarization can be framed as a sentence classification problem. In this post, we will walk through an implementation of a baseline model, covering data preprocessing, model training and export, and inference, using PyTorch and the HuggingFace transformers library.


Model Architecture

Extractive summarization as a classification problem.

The model takes in a pair of inputs X = (sentence, document) and predicts a relevance score y. We need vector representations for our text input; for this, we can use any of the language models in the HuggingFace transformers library. Here we will use a model from the sentence-transformers family, in which a BERT-based model has been fine-tuned to produce semantically meaningful sentence embeddings.

python
import torch
from transformers import AutoTokenizer, AutoModel

sentence_model_name = "sentence-transformers/paraphrase-MiniLM-L3-v2"
tokenizer = AutoTokenizer.from_pretrained(sentence_model_name)

# mean pooling for sentence-BERT models: average token embeddings, weighted by the attention mask
# ref https://www.sbert.net/examples/applications/computing-embeddings/README.html#sentence-embeddings-with-transformers
def mean_pooling(model_output, attention_mask):
    token_embeddings = model_output[0]  # first element of model_output contains all token embeddings
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)
    sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
    return sum_embeddings / sum_mask

class SentenceBertClass(torch.nn.Module):
    def __init__(self, model_name="sentence-transformers/paraphrase-MiniLM-L3-v2", in_features=384):
        super(SentenceBertClass, self).__init__()
        self.l1 = AutoModel.from_pretrained(model_name)
        self.pre_classifier = torch.nn.Linear(in_features * 3, 768)
        self.dropout = torch.nn.Dropout(0.3)
        self.classifier = torch.nn.Linear(768, 1)
        self.classifierSigmoid = torch.nn.Sigmoid()

    def forward(self, sent_ids, doc_ids, sent_mask, doc_mask):
        # embed the candidate sentence and the full document with the same encoder
        sent_output = self.l1(input_ids=sent_ids, attention_mask=sent_mask)
        sentence_embeddings = mean_pooling(sent_output, sent_mask)
        doc_output = self.l1(input_ids=doc_ids, attention_mask=doc_mask)
        doc_embeddings = mean_pooling(doc_output, doc_mask)
        # elementwise product of sentence embeddings and document embeddings
        combined_features = sentence_embeddings * doc_embeddings
        # concatenate both embeddings and their elementwise product
        concat_features = torch.cat((sentence_embeddings, doc_embeddings, combined_features), dim=1)
        pooler = self.pre_classifier(concat_features)
        pooler = torch.nn.ReLU()(pooler)
        pooler = self.dropout(pooler)
        output = self.classifier(pooler)
        output = self.classifierSigmoid(output)  # relevance score in [0, 1]
        return output
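
As a quick sanity check, we can run a toy (sentence, document) pair through an untrained instance of the model and confirm that it produces a single score in [0, 1] per pair. This sketch is illustrative only; the inputs are arbitrary.

python
# sanity check (illustrative): score one toy (sentence, document) pair
sent_inputs = tokenizer(["The cat sat on the mat."], padding="max_length",
                        max_length=512, truncation=True, return_tensors="pt")
doc_inputs = tokenizer(["The cat sat on the mat. It then fell asleep."], padding="max_length",
                       max_length=512, truncation=True, return_tensors="pt")
toy_model = SentenceBertClass(model_name=sentence_model_name)
score = toy_model(sent_inputs["input_ids"], doc_inputs["input_ids"],
                  sent_inputs["attention_mask"], doc_inputs["attention_mask"])
print(score.shape)  # torch.Size([1, 1]) - one relevance score in [0, 1]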

Model Training

Next, we instantiate the model, move it to a GPU if one is available, and set up training: a binary cross-entropy loss and an Adam optimizer with our chosen learning rate.

python
from torch import cuda

device = 'cuda' if cuda.is_available() else 'cpu'
LEARNING_RATE = 1e-5  # example value; tune for your dataset

model = SentenceBertClass(model_name=sentence_model_name)
model.to(device)

loss_function = torch.nn.BCELoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=LEARNING_RATE)
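
The training loop below reads batches from a `training_loader` that yields pre-tokenized (sentence, document) pairs with a relevance label. The loader construction is not shown in this post; a minimal sketch, assuming `train_df`/`test_df` DataFrames with `sentence`, `document`, and `label` columns (the names are illustrative), might look like this:

python
# minimal sketch of the data pipeline (illustrative); assumes train_df/test_df
# DataFrames with "sentence", "document", and "label" columns
import torch
from torch.utils.data import Dataset, DataLoader

class SentencePairDataset(Dataset):
    def __init__(self, df, tokenizer, max_len=512):
        self.df, self.tokenizer, self.max_len = df, tokenizer, max_len

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        encode = lambda text: self.tokenizer(text, max_length=self.max_len,
                                             padding="max_length", truncation=True,
                                             return_tensors="pt")
        sent, doc = encode(row["sentence"]), encode(row["document"])
        return {
            "sent_ids": sent["input_ids"].squeeze(0),
            "sent_mask": sent["attention_mask"].squeeze(0),
            "doc_ids": doc["input_ids"].squeeze(0),
            "doc_mask": doc["attention_mask"].squeeze(0),
            "targets": torch.tensor([row["label"]], dtype=torch.float),
        }

train_params = {"batch_size": 8, "shuffle": True}   # example values
test_params = {"batch_size": 8, "shuffle": False}   # example values
training_loader = DataLoader(SentencePairDataset(train_df, tokenizer), **train_params)
testing_loader = DataLoader(SentencePairDataset(test_df, tokenizer), **test_params)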

Next, we specify the model training loop. For each batch of (already tokenized) data from our data loader, we move the tensors to the device, get predictions, compute the loss and gradients, and update the model weights. We also keep a running sum of performance metrics per batch for later visualization.

python
from tqdm import tqdm

EPOCHS = 2            # example value
print_n_steps = 100   # log metrics every N batches
acc_step_holder, loss_step_holder = [], []

def train(epoch):
    tr_loss = 0
    n_correct = 0
    nb_tr_steps = 0
    nb_tr_examples = 0
    model.train()
    for step, data in tqdm(enumerate(training_loader, 0)):
        # move the batch tensors to the target device
        sent_ids = data['sent_ids'].to(device, dtype=torch.long)
        doc_ids = data['doc_ids'].to(device, dtype=torch.long)
        sent_mask = data['sent_mask'].to(device, dtype=torch.long)
        doc_mask = data['doc_mask'].to(device, dtype=torch.long)
        targets = data['targets'].to(device, dtype=torch.float)

        outputs = model(sent_ids, doc_ids, sent_mask, doc_mask)
        loss = loss_function(outputs, targets)

        # accumulate running metrics
        tr_loss += loss.item()
        n_correct += torch.count_nonzero(targets == (outputs > 0.5)).item()
        nb_tr_steps += 1
        nb_tr_examples += targets.size(0)

        if step % print_n_steps == 0:
            loss_step = tr_loss / nb_tr_steps
            accu_step = (n_correct * 100) / nb_tr_examples
            print(f"{step * train_params['batch_size']}/{len(train_df)} steps. Acc -> {accu_step} Loss -> {loss_step}")
            acc_step_holder.append(accu_step), loss_step_holder.append(loss_step)

        # backpropagate and update weights
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print(f'The Total Accuracy for Epoch {epoch}: {(n_correct*100)/nb_tr_examples}')
    epoch_loss = tr_loss / nb_tr_steps
    epoch_accu = (n_correct * 100) / nb_tr_examples
    print(f"Training Loss Epoch: {epoch_loss}")
    print(f"Training Accuracy Epoch: {epoch_accu}")

for epoch in range(EPOCHS):
    train(epoch)
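
Once training completes, the weights can be exported for later inference. A minimal sketch using `torch.save` (the file name here is illustrative):

python
# save the trained weights for later use
torch.save(model.state_dict(), "sbert_extractive_summarizer.pt")  # path is illustrative

# ...and restore them into a freshly constructed model at inference time
model = SentenceBertClass(model_name=sentence_model_name)
model.load_state_dict(torch.load("sbert_extractive_summarizer.pt", map_location=device))
model.to(device)
model.eval()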

Model Evaluation

To evaluate the model, we can write a loop similar to the training loop: we get predictions for each batch of the test set, but do not update the model weights.

python
def validate_model(model, testing_loader):
    model.eval()
    n_correct = 0; tr_loss = 0; nb_tr_steps = 0; nb_tr_examples = 0
    with torch.no_grad():
        for step, data in enumerate(testing_loader, 0):
            sent_ids = data['sent_ids'].to(device, dtype=torch.long)
            doc_ids = data['doc_ids'].to(device, dtype=torch.long)
            sent_mask = data['sent_mask'].to(device, dtype=torch.long)
            doc_mask = data['doc_mask'].to(device, dtype=torch.long)
            targets = data['targets'].to(device, dtype=torch.float)

            outputs = model(sent_ids, doc_ids, sent_mask, doc_mask)
            loss = loss_function(outputs, targets)

            # accumulate running metrics
            tr_loss += loss.item()
            n_correct += torch.count_nonzero(targets == (outputs > 0.5)).item()
            nb_tr_steps += 1
            nb_tr_examples += targets.size(0)

            if step % print_n_steps == 0:
                loss_step = tr_loss / nb_tr_steps
                accu_step = (n_correct * 100) / nb_tr_examples
                print(f"{step * test_params['batch_size']}/{len(test_df)} steps. Acc -> {accu_step} Loss -> {loss_step}")

    epoch_loss = tr_loss / nb_tr_steps
    epoch_accu = (n_correct * 100) / nb_tr_examples
    print(f"Validation Loss Epoch: {epoch_loss}")
    print(f"Validation Accuracy Epoch: {epoch_accu}")
    return epoch_accu
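
A single call then gives a quick read on generalization, assuming a `testing_loader` built the same way as the training loader:

python
# evaluate on the held-out set
val_accuracy = validate_model(model, testing_loader)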

Model Inference

Inference is implemented in the summarize method in the snippet below. First we split the document into individual sentences using a spaCy language model, then score each sentence against the document, select the top_k sentences by score, reorder them by position in the document, and return them as the summary.

python
import spacy
from tqdm import tqdm

# create spacy model for sentence splitting
nlp = spacy.load('en_core_web_lg')

# tokenize text as required by BERT based models
def get_tokens(text, tokenizer):
    inputs = tokenizer.batch_encode_plus(
        text,
        add_special_tokens=True,
        max_length=512,
        padding="max_length",
        return_token_type_ids=True,
        truncation=True
    )
    ids = inputs['input_ids']
    mask = inputs['attention_mask']
    return ids, mask

# get predictions given an array of sentences and their corresponding document
def predict(model, sents, doc):
    sent_id, sent_mask = get_tokens(sents, tokenizer)
    sent_id, sent_mask = torch.tensor(sent_id, dtype=torch.long), torch.tensor(sent_mask, dtype=torch.long)
    # tokenize the document once and repeat it for each sentence in the batch
    doc_id, doc_mask = get_tokens([doc], tokenizer)
    doc_id, doc_mask = doc_id * len(sents), doc_mask * len(sents)
    doc_id, doc_mask = torch.tensor(doc_id, dtype=torch.long), torch.tensor(doc_mask, dtype=torch.long)
    with torch.no_grad():
        preds = model(sent_id.to(device), doc_id.to(device), sent_mask.to(device), doc_mask.to(device))
    return preds

def summarize(doc, model, min_sentence_length=14, top_k=4, batch_size=3):
    model.eval()
    doc = doc.replace("\n", " ")
    # split the document into sentences, keeping only sufficiently long ones
    doc_sentences = [str(sent) for sent in nlp(doc).sents if len(sent) > min_sentence_length]

    # score each sentence against the document, in batches
    scores = []
    for i in tqdm(range(0, len(doc_sentences), batch_size)):
        preds = predict(model, doc_sentences[i: i + batch_size], doc)
        scores = scores + preds.tolist()

    # rank sentences by score, keep the top_k, then restore document order
    sent_pred_list = [{"sentence": doc_sentences[i], "score": scores[i][0], "index": i} for i in range(len(doc_sentences))]
    sorted_sentences = sorted(sent_pred_list, key=lambda k: k['score'], reverse=True)
    sorted_result = sorted(sorted_sentences[:top_k], key=lambda k: k['index'])
    summary = " ".join(x["sentence"] for x in sorted_result)
    return summary, scores, doc_sentences

Conclusions

In this post, we walked through an implementation of extractive summarization (framed as sentence classification) in PyTorch. You can find a repository with the full code here.

