Knowledge Distillation For Image Classification

Knowledge Distillation for Computer Vision

Knowledge distillation is a technique for transferring knowledge from a larger, more complex model (the teacher) to a smaller, simpler model (the student). To distill knowledge from one model to another, we take a pre-trained teacher model trained on a certain task (image classification in this case) and randomly initialize a student model to be trained on the same task. Next, we train the student model to minimize the difference between its outputs and the teacher's outputs, so that it mimics the teacher's behavior. The technique was first introduced in Distilling the Knowledge in a Neural Network by Hinton et al. In this guide, we will do task-specific knowledge distillation, using the beans dataset.

This guide demonstrates how you can distill a fine-tuned ViT model (teacher model) to a MobileNet (student model) using the Trainer API of 🤗 Transformers.

Let's install the libraries needed for distillation and for evaluating the process.

pip install transformers datasets accelerate tensorboard evaluate trackio --upgrade

In this example, we are using the merve/beans-vit-224 model as the teacher model. It's an image classification model based on google/vit-base-patch16-224-in21k and fine-tuned on the beans dataset. We will distill this model to a randomly initialized MobileNetV2.

We will now load the dataset.
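
A minimal sketch of the loading step, assuming the dataset is published on the Hub under the id beans (it may also live under a namespaced id) and that the datasets library is installed. The helper names load_beans and split_sizes are hypothetical and introduced here only for illustration.

```python
def load_beans(dataset_id="beans"):
    """Download the beans dataset from the Hugging Face Hub.

    The default Hub id is an assumption; pass a different id if the dataset
    is published under a namespace.
    """
    # Imported lazily so the helpers below work without `datasets` installed.
    from datasets import load_dataset

    return load_dataset(dataset_id)


def split_sizes(dataset):
    """Map each split name to its number of examples."""
    return {name: len(split) for name, split in dataset.items()}


# Typical use (needs network access):
# dataset = load_beans()
# print(split_sizes(dataset))
```

Printing the split sizes is a quick way to confirm that the training, validation, and test splits were all downloaded before moving on.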

We can use the image processor from either model, since in this case both return the same output at the same resolution. We will use the dataset's map() method to apply the preprocessing to every split of the dataset.
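
The preprocessing step might look like the sketch below. It assumes the images sit in a column named image and that the teacher checkpoint's processor is used. The names make_preprocess_fn and preprocess_dataset are hypothetical helpers, not part of any library.

```python
def make_preprocess_fn(processor):
    """Wrap an image processor in a function suitable for map(batched=True)."""

    def preprocess(examples):
        # With batched=True, examples["image"] is a list of images.
        # The "image" column name is an assumption about the dataset schema.
        return processor(examples["image"])

    return preprocess


def preprocess_dataset(dataset, checkpoint="merve/beans-vit-224"):
    """Apply the checkpoint's image processor to every split of the dataset."""
    # Imported lazily so make_preprocess_fn works without transformers installed.
    from transformers import AutoImageProcessor

    processor = AutoImageProcessor.from_pretrained(checkpoint)
    # Calling map() on a DatasetDict processes each split in turn.
    return dataset.map(make_preprocess_fn(processor), batched=True)


# Typical use (needs network access and a loaded dataset):
# processed_datasets = preprocess_dataset(dataset)
```

Keeping the processor as a parameter of make_preprocess_fn makes it easy to swap in the student's processor if the two ever stop producing identical outputs.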

Essentially, we want the student model (a randomly initialized MobileNet) to mimic the teacher model (a fine-tuned vision transformer). To achieve this, we first get the logits output by the teacher and the student. Then, we divide each of them by the temperature parameter, which controls how soft the resulting probability distributions are: a higher temperature spreads probability mass more evenly across the classes. A parameter called lambda weighs the importance of the distillation loss. In this example, we will use temperature=5 and lambda=0.5. We will use the Kullback-Leibler (KL) divergence loss to compute the divergence between the student and the teacher. Given two probability distributions P and Q, the KL divergence measures how much extra information we need to represent P using Q. If the two are identical, their KL divergence is zero, as no extra information is needed to explain P from Q. This makes KL divergence a natural measure of how closely the student's predictions match the teacher's.
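
To make the roles of temperature and KL divergence concrete, here is a small dependency-free illustration. The logits are made-up values for three classes, and the function names are hypothetical. In actual training the same quantities would be computed on batched tensors rather than Python lists.

```python
import math


def softmax(logits, temperature=1.0):
    """Turn logits into probabilities after dividing by the temperature."""
    scaled = [z / temperature for z in logits]
    peak = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(s - peak) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def kl_divergence(p, q):
    """KL(P || Q) in nats for two discrete distributions of equal length."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


# Made-up logits for a single image, for illustration only.
teacher_logits = [4.0, 1.0, 0.5]
student_logits = [2.0, 1.5, 0.2]

sharp_teacher = softmax(teacher_logits, temperature=1.0)
soft_teacher = softmax(teacher_logits, temperature=5.0)
soft_student = softmax(student_logits, temperature=5.0)

# The teacher plays the role of P (the target), the student the role of Q.
divergence = kl_divergence(soft_teacher, soft_student)

print("teacher at T=1:", sharp_teacher)
print("teacher at T=5:", soft_teacher)
print("KL(teacher || student) at T=5:", divergence)
```

At temperature 5 the teacher's distribution is visibly flatter than at temperature 1, so the student also receives signal about how the teacher ranks the classes it considers less likely.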

We will now log in to the Hugging Face Hub so we can push our model to it through the Trainer.

Let's set the TrainingArguments, the teacher model and the student model.
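
One possible setup is sketched below. The output directory name and every hyperparameter value are illustrative assumptions, not the settings used to produce the results reported in this guide. The helper names training_config, build_training_args, and build_models are hypothetical. The default of three labels reflects the three classes of the beans dataset.

```python
def training_config(output_dir="beans-distilled-mobilenet"):
    """Keyword arguments for TrainingArguments (illustrative values only)."""
    return {
        "output_dir": output_dir,  # hypothetical directory / repository name
        "num_train_epochs": 30,  # assumed value, tune for your setup
        "logging_strategy": "epoch",
        "eval_strategy": "epoch",
        "save_strategy": "epoch",  # must match eval_strategy for the next option
        "load_best_model_at_end": True,
        "metric_for_best_model": "accuracy",
        "push_to_hub": True,  # requires being logged in to the Hub
    }


def build_training_args(**overrides):
    """Create TrainingArguments from the defaults above plus any overrides."""
    # Imported lazily so training_config works without transformers installed.
    from transformers import TrainingArguments

    return TrainingArguments(**{**training_config(), **overrides})


def build_models(teacher_checkpoint="merve/beans-vit-224", num_labels=3):
    """Load the fine-tuned teacher and create a randomly initialized student."""
    from transformers import (
        AutoModelForImageClassification,
        MobileNetV2Config,
        MobileNetV2ForImageClassification,
    )

    teacher = AutoModelForImageClassification.from_pretrained(teacher_checkpoint)
    # Building the student from a config (not from_pretrained) leaves its
    # weights randomly initialized.
    student = MobileNetV2ForImageClassification(MobileNetV2Config(num_labels=num_labels))
    return teacher, student


# Typical use (needs network access):
# training_args = build_training_args()
# teacher_model, student_model = build_models()
```

Building the student from a configuration rather than from a checkpoint is what keeps it randomly initialized, as the guide requires.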

We can use a compute_metrics function to evaluate our model on the test set. This function will be used during the training process to compute the accuracy and F1 score of our model.
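
As a dependency-free sketch, such a function could look like the following. It assumes the Trainer passes a pair of (logits, labels) and that F1 is macro-averaged over the classes. Both choices are assumptions made here, and in practice a metrics library such as evaluate could be used instead.

```python
def argmax(row):
    """Index of the largest value in a sequence of scores."""
    return max(range(len(row)), key=lambda i: row[i])


def compute_metrics(eval_pred):
    """Return accuracy and macro-averaged F1 from (logits, labels)."""
    logits, labels = eval_pred
    predictions = [argmax(row) for row in logits]
    labels = [int(label) for label in labels]

    accuracy = sum(p == y for p, y in zip(predictions, labels)) / len(labels)

    f1_scores = []
    for cls in sorted(set(labels) | set(predictions)):
        tp = sum(p == cls and y == cls for p, y in zip(predictions, labels))
        fp = sum(p == cls and y != cls for p, y in zip(predictions, labels))
        fn = sum(p != cls and y == cls for p, y in zip(predictions, labels))
        denominator = 2 * tp + fp + fn
        f1_scores.append(2 * tp / denominator if denominator else 0.0)

    return {"accuracy": accuracy, "f1": sum(f1_scores) / len(f1_scores)}


# Tiny made-up example: four images, three classes.
example_logits = [[2, 1, 0], [0, 3, 1], [0, 1, 5], [4, 0, 1]]
example_labels = [0, 1, 2, 1]
print(compute_metrics((example_logits, example_labels)))
```

The function returns a dictionary keyed by metric name, which is the shape the Trainer expects when selecting the best checkpoint by a named metric.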

Let's initialize the Trainer with the training arguments we defined. We will also initialize our data collator.
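
The sketch below shows one way a distillation-aware Trainer could be wired together with PyTorch. DistillationTrainer and distillation_loss are hypothetical names. Scaling the soft loss by the squared temperature and weighting the student's own classification loss by (1 - lambda) are common conventions assumed here, not details confirmed by the text above.

```python
import torch
import torch.nn.functional as F
from transformers import DefaultDataCollator, Trainer


def distillation_loss(student_logits, teacher_logits, hard_loss,
                      temperature=5.0, lambda_param=0.5):
    """Blend the student's own loss with a temperature-softened KL term."""
    soft_teacher = F.softmax(teacher_logits / temperature, dim=-1)
    log_soft_student = F.log_softmax(student_logits / temperature, dim=-1)
    # kl_div expects log-probabilities as input and probabilities as target.
    soft_loss = F.kl_div(log_soft_student, soft_teacher, reduction="batchmean")
    soft_loss = soft_loss * temperature**2  # assumed scaling convention
    return (1.0 - lambda_param) * hard_loss + lambda_param * soft_loss


class DistillationTrainer(Trainer):
    """Hypothetical Trainer subclass that trains a student against a teacher."""

    def __init__(self, *args, teacher_model, temperature=5.0, lambda_param=0.5, **kwargs):
        super().__init__(*args, **kwargs)
        # The teacher is frozen: same device as the student, evaluation mode.
        self.teacher = teacher_model.to(self.args.device).eval()
        self.temperature = temperature
        self.lambda_param = lambda_param

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        student_output = model(**inputs)  # .loss is set when labels are in inputs
        with torch.no_grad():
            teacher_output = self.teacher(**inputs)
        loss = distillation_loss(
            student_output.logits, teacher_output.logits, student_output.loss,
            self.temperature, self.lambda_param,
        )
        return (loss, student_output) if return_outputs else loss


data_collator = DefaultDataCollator()

# Usage sketch, assuming the models, arguments, processed splits, and metric
# function have already been created under these (hypothetical) names:
# trainer = DistillationTrainer(
#     model=student_model, args=training_args, teacher_model=teacher_model,
#     train_dataset=processed_datasets["train"],
#     eval_dataset=processed_datasets["validation"],
#     data_collator=data_collator, compute_metrics=compute_metrics,
# )
```

Keeping the loss in a standalone function makes it easy to check on small tensors before starting a full training run.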

We can now train our model.

We can evaluate the model on the test set.

On the test set, our model reaches 72 percent accuracy. As a sanity check on the effectiveness of distillation, we also trained MobileNet on the beans dataset from scratch with the same hyperparameters and observed 63 percent accuracy on the test set. We invite readers to try different pre-trained teacher models, student architectures, and distillation parameters, and to report their findings. The training logs and checkpoints for the distilled model can be found in this repository, and those for MobileNetV2 trained from scratch can be found in this repository.