Prompt Guard Tutorial
Prompt Guard Tutorial
The goal of this tutorial is to give an overview of several practical aspects of using the Prompt Guard model. We go over:
- The model's scope and what sort of risks it can guardrail against;
- Code for loading and executing the model, and the expected latency on CPU and GPU;
- The limitations of the model on new datasets and the process of fine-tuning the model to adapt to them.
Prompt Guard is a simple classifier model. The most straightforward way to load the model is with the transformers library:
The output of the model is logits that can be scaled to get a score in the range :
The model's positive label (1) corresponds to an input that contains a jailbreaking technique. These are techniques that are intended to override prior instructions or the model's safety conditioning, and in general are directed towards maliciously overriding the intended use of an LLM by application developers.
Detecting Direct Malicious attacks
The model can be used to detect if jailbreaking techniques are being used in direct chats with a model. These are typically users trying to directly override the model's safety conditioning.
Jailbreak Score (benign): 0.001
Jailbreak Score (malicious): 1.000
Detecting Indirect attacks.
We can also check for jailbreaking techniques used in arbitrary data that might be ingested by an LLM, beyond just prompts. This makes sense for scanning content from untrusted third party sources, like tools, web searches, or APIs.
These are often the highest-risk scenarios for jailbreaking techniques, as these attacks can target the users of an application and exploit a model's priveleged access to a user's data, rather than just being a content safety issue.
Inference Latency
The model itself is small and can run quickly on CPU or GPU.
Execution time: 0.088 seconds
GPU can provide a further significant speedup which can be key for enabling low-latency and high-throughput LLM applications.
Fine-tuning Prompt Guard on new datasets for specialized applications
Every LLM-powered application will see a different distribution of prompts, both benign and malicious, when deployed into production. While Prompt Guard can be very useful for flagging malicious inputs out-of-the-box, much more accurate results can be achieved by fitting the model directly to the distribution of datapoints expected. This can be critical to reduce risk for applications while not producing a significant number of regrettable false positives. Fine-tuning also allows LLM application developers to have granular control over the types of queries considered benign or malicious by the application that they choose to filter.
Let's test out Prompt Guard on an external dataset not involved in the training process. For this example, we pull a publicly licensed dataset of "synthetic" prompt injection datapoints from huggingface:
This dataset has LLM-generated examples of attacks and benign prompts, and looks significantly different from the human-written examples the model was trained on:
Let's evaluate the model on this dataset:
Evaluating: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 16/16 [01:03<00:00, 3.98s/it]
Looking at the plots below, The model definitely has some predictive power over this new dataset, but the results are far from the .99 AUC we see on the original test set.
(Fortunately this is a particularly challenging dataset, and typically we've seen an out-of-distribution AUC of ~.98-.99 on datasets of more realistic attacks and queries. But this dataset is useful to illustrate the challenge of adapting the model to a new distribution of attacks).
Now, let's fine-tune the prompt injection model to match the new distribution, on the training dataset. By doing this, we take advantage of the latent understanding of historical injection attacks the base injection model has developed, while making the model much more precise in it's results on this specific dataset.
Note that to do this we replace the final layer of the model classifier (a linear layer producing the 3 logits corresponding to the output probabilities) with one that produces two logits, to obtain a binary classifier model.
Epoch 1: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 157/157 [34:32<00:00, 13.20s/it]
Average loss in epoch 1: 0.33445613684168285
Training this model is not computationally intensive either (on 5000 datapoints, which is plenty for a solid classifier, this takes ~40 minutes running on a Mac CPU, and only a few seconds running on an NVIDIA GPU.)
Looking at the results, we see a much better fit!
Evaluating: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 16/16 [01:01<00:00, 3.86s/it]
One good way to quickly obtain labeled training data for a use case is to use the original, non-fine tuned model itself to highlight risky examples to label, while drawing random negatives from below a score threshold. This helps address the class imbalance (attacks and risky prompts can be a very small percentage of all prompts) and includes false positive examples (which tend to be very valuable to train on) in the dataset. Generating synthetic fine-tuning data for specific use cases can also be an effective strategy.