Train Text Classification
Train Text Classification Model using BERT and Smart Sifting.
This notebook's CI test result for us-west-2 is as follows. CI test results in other regions can be found at the end of the notebook.
In this notebook we will train a text classification model using BERT (Transformers) and the smart sifting library. BERT is a transformer encoder model pretrained on a large corpus of English data in a self-supervised fashion. It is primarily aimed at being fine-tuned on tasks that use the whole sentence (potentially masked) to make decisions, such as sequence classification, token classification, and question answering. We will build a BERT-based sentiment analysis model using the SST2 dataset.
1. Introduction to Smart Sifting
Smart sifting is a framework to speed up training of PyTorch models. The framework implements a set of algorithms that filter out inconsequential training examples during training, reducing the computational cost and accelerating the training process. It is configuration-driven and extensible, allowing users to add custom logic to transform their training examples into a filterable format. Smart sifting provides a generic utility for any DNN model and can reduce the infrastructure cost of training by up to 35%.

Smart sifting’s task is to sift through your training data during the training process and only feed the more informative samples to the model. During typical training with PyTorch, data is iteratively sent in batches to the training loop and to accelerator devices (e.g. GPUs or Trainium chips) by the PyTorch data loader. Smart sifting is implemented at this data loading stage and is thus independent of any upstream data preprocessing in your training pipeline. Smart sifting uses your live model and a user-specified loss function to do an evaluative forward pass of each data sample as it is loaded. Samples with high loss will materially impact model training and are thus included in the training data, whereas samples with relatively low loss are already well represented by the model and so are set aside and excluded from training. A key input to smart sifting is the proportion of data to exclude: for example, by setting the proportion to 25%, samples in approximately the bottom quartile of loss of each batch will be excluded from training. Once enough high-loss samples have been identified to complete a batch, the data is sent through the full training loop and the model learns and trains normally. Customers don’t need to make any downstream changes to their training loop when smart sifting is enabled.
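To make the batch-level exclusion concrete, here is a minimal pure-Python sketch, not the smart sifting implementation: the `sift_batch` helper and the 25% default are assumptions for illustration only.

```python
# Illustrative sketch only (not the smart sifting library): given per-sample
# losses for a batch, exclude the lowest-loss fraction and keep the rest.
def sift_batch(losses, exclude_proportion=0.25):
    """Return the indices of samples kept for training (the higher-loss ones)."""
    n_exclude = int(len(losses) * exclude_proportion)
    # Sort sample indices by loss, ascending, then drop the first n_exclude.
    order = sorted(range(len(losses)), key=lambda i: losses[i])
    return sorted(order[n_exclude:])

losses = [0.05, 2.1, 0.9, 0.01, 1.4, 0.3, 0.7, 1.8]
print(sift_batch(losses))  # indices 0 and 3 (the two lowest losses) are excluded
```

The kept samples would then be accumulated until a full batch of informative examples is available for the training loop.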
2. Prepare Dataset
For this training we will be using SST2, which consists of positive/negative sentiment texts, with roughly 11k sentences extracted from movie reviews.
Let's start by downloading and extracting the dataset.
We will convert the dataset into a TSV file, which will be uploaded to S3.
Upload the TSV file to S3
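The conversion and upload steps above can be sketched as follows. This is a minimal illustration, not the notebook's exact code: the file name, the sample rows, and the `key_prefix` are placeholders, and the S3 upload (which needs AWS credentials) is shown as a comment.

```python
# Sketch of the TSV conversion step; rows and file name are placeholders.
import csv

def write_tsv(path, rows):
    """Write (sentence, label) pairs to a tab-separated file with a header."""
    with open(path, "w", newline="") as f:
        writer = csv.writer(f, delimiter="\t")
        writer.writerow(["sentence", "label"])
        writer.writerows(rows)

write_tsv("train.tsv", [("a gripping, well-acted film", 1),
                        ("tedious and overlong", 0)])

# Upload to S3, e.g. with the SageMaker Python SDK (requires AWS credentials):
# import sagemaker
# s3_uri = sagemaker.Session().upload_data("train.tsv", key_prefix="sst2")
```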
3. Run Training Job using SageMaker Training
Adding the sifting library to the text classification code involves the following steps:
- Define the loss function. For text classification we use cross-entropy loss, defined as below.
```python
class BertLoss(Loss):
    """
    This is an implementation of the Loss interface for the BERT model
    required for Smart Sifting. Uses cross-entropy loss with 2 classes.
    """

    def __init__(self):
        self.celoss = torch.nn.CrossEntropyLoss(reduction="none")

    def loss(
        self,
        model: torch.nn.Module,
        transformed_batch: SiftingBatch,
        original_batch: Any = None,
    ) -> torch.Tensor:
        # Get the original batch onto the model's device. Note that we assume
        # the model is already on the right device; PyTorch Lightning takes
        # care of this under the hood with the model that is passed in.
        # TODO: ensure batch and model are on the same device in SiftDataloader
        # so that the customer doesn't have to implement this.
        device = next(model.parameters()).device
        batch = [t.to(device) for t in original_batch]

        # Compute per-sample loss
        outputs = model(batch)
        return self.celoss(outputs.logits, batch[2])
```
- Define a transformation function to convert the input batch to sifting format.
```python
class BertListBatchTransform(SiftingBatchTransform):
    """
    This is an implementation of the data transforms for the BERT model
    required for Smart Sifting. Transforms to and from ListBatch.
    """

    def transform(self, batch: Any):
        inputs = []
        for i in range(len(batch[0])):
            inputs.append((batch[0][i], batch[1][i]))
        labels = batch[2].tolist()  # assume the last element is the list of labels
        return ListBatch(inputs, labels)

    def reverse_transform(self, list_batch: ListBatch):
        inputs = list_batch.inputs
        input_ids = [iid for (iid, _) in inputs]
        masks = [mask for (_, mask) in inputs]
        a_batch = [
            torch.stack(input_ids),
            torch.stack(masks),
            torch.tensor(list_batch.labels, dtype=torch.int64),
        ]
        return a_batch
```
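The round trip performed by `transform` and `reverse_transform` can be illustrated without torch. This is a pure-Python sketch with a stand-in `ListBatch` namedtuple (the real class comes from the smart sifting library) and plain lists in place of tensors:

```python
# Pure-Python illustration of the transform round trip; ListBatch here is a
# stand-in for the smart sifting library's class, and plain lists replace
# torch tensors.
from collections import namedtuple

ListBatch = namedtuple("ListBatch", ["inputs", "labels"])  # stand-in

def transform(batch):
    # Pair up input_ids and attention masks; the last element holds the labels.
    inputs = list(zip(batch[0], batch[1]))
    return ListBatch(inputs, list(batch[2]))

def reverse_transform(list_batch):
    input_ids = [iid for (iid, _) in list_batch.inputs]
    masks = [mask for (_, mask) in list_batch.inputs]
    return [input_ids, masks, list_batch.labels]

batch = [[101, 102], [1, 1], [0, 1]]  # toy input_ids, masks, labels
assert reverse_transform(transform(batch)) == batch
```

The library sifts at the level of individual (input, label) pairs, so the transform unpacks the batch into a per-sample list and the reverse transform re-stacks the kept samples into batch form for the model.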
- Define the sifting config. beta_value defines the proportion of samples to keep (the higher the value, the more samples are sifted), and loss_history_length defines the window of samples to include when evaluating relative loss.
```python
sift_config = RelativeProbabilisticSiftConfig(
    beta_value=3,
    loss_history_length=500,
    loss_based_sift_config=LossConfig(
        sift_config=SiftingBaseConfig(sift_delay=10)
    )
)
```
- Wrap the PyTorch data loader with the sifting data loader. As a last step we wrap the PyTorch DataLoader with SiftingDataloader, passing in the config, transformation, and loss functions.
```python
SiftingDataloader(
    sift_config=sift_config,
    orig_dataloader=DataLoader(self.train, self.batch_size, shuffle=True),
    loss_impl=BertLoss(),
    model=self.model,
    batch_transforms=BertListBatchTransform()
)
```
We define a few metrics to be tracked in order to monitor sifting. These are optional metrics, useful for debugging and understanding sifting performance.
We will launch the training job on an ml.g5.2xlarge instance. The sifting library is part of the SageMaker PyTorch Deep Learning Containers starting with version 2.0.1.
Launch the training job with Data in S3
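The launch step can be sketched with the SageMaker Python SDK's PyTorch estimator. This is a config-style sketch, not runnable as-is: the entry point, role ARN, S3 path, and hyperparameters are placeholders, and it requires AWS credentials.

```python
# Sketch only: role, entry point, and S3 path below are placeholders.
from sagemaker.pytorch import PyTorch

estimator = PyTorch(
    entry_point="train.py",        # training script with sifting enabled
    role="<your-sagemaker-role-arn>",
    framework_version="2.0.1",     # first DLC version bundling smart sifting
    py_version="py310",
    instance_count=1,
    instance_type="ml.g5.2xlarge",
)
estimator.fit({"training": "s3://<bucket>/sst2/"})
```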
In this notebook, we looked at how to use the smart sifting library to train a text classification (sentiment analysis) model. Smart sifting helps reduce training time by up to 40% without any reduction in model performance.
Notebook CI Test Results
This notebook was tested in multiple regions. The test results are as follows, except for us-west-2 which is shown at the top of the notebook.