Llama Guard Finetuning Multiple Violations With Torchtune

Fine-tuning Llama Guard to detect multiple privacy violations

The pre-trained Llama Guard model has a single category for privacy violations, S7. Suppose you want Llama Guard to return multiple violations in a prompt when they exist. First we load Llama Guard and confirm the expected behavior, i.e. the model should return S7 when there is any PII violation.

Dataset used for training & evaluation

We use the ai4privacy/pii-masking-65k dataset for both training and evaluation.

Manual evaluation of Llama Guard on some prompts

**********************************************************************************
Prompt:
[['Dear Carmen, As our National Brand Architect, we need you to conduct an internal audit\\ \nof our Usability operations, specifically focused on our compliance with Administrative Law. Please use the data\\ \nfrom the 109.4.61.213 and 4d36:f8ad:4ebd:1bff:c5eb:62aa:b4d7:dddf IP addresses in the investigation.']]
tensor([   271,  39257,    198,     50,     22, 128009], device='cuda:0')
===================================
Results:
> 

unsafe
S7

==================================
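The raw generation ends with a verdict line ("safe" or "unsafe") and, when unsafe, a comma-separated list of category codes. A small helper along the following lines (a sketch; the function name is ours, not from the notebook's hidden cells) turns that text into something easy to compare against ground truth:

```python
def parse_guard_output(text: str):
    """Split Llama Guard's decoded output into a verdict and a category list."""
    lines = [line.strip() for line in text.strip().splitlines() if line.strip()]
    if not lines or lines[0] == "safe":
        return "safe", []
    # When unsafe, the next line is a comma-separated list such as "S1,S2".
    categories = lines[1].split(",") if len(lines) > 1 else []
    return "unsafe", categories
```

For example, `parse_guard_output("\n\nunsafe\nS7")` returns `("unsafe", ["S7"])`.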

We notice that the model correctly identifies that there is a privacy violation. Let's take this one step further and identify violations for the following categories:

  • Name (S1)
  • IP Address (S2)
  • Phone Number (S3)
  • Address (S4)
  • Credit Card (S5)

Once we fine-tune Llama Guard with these categories, we should expect to see S1 & S2 for the above prompt.

Data Preparation for finetuning

The dataset ai4privacy/pii-masking-65k contains prompts with multiple categories of PII violations. We choose only a subset of the PII categories for this experiment.
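The training label for each prompt is the same string Llama Guard is expected to emit: "safe", or "unsafe" followed by a newline and the comma-separated codes. A sketch of how such a label might be built (the helper name is ours):

```python
def make_label(categories):
    """Build a Llama Guard style training label from a list of S-codes."""
    if not categories:
        return "safe"
    # e.g. ["S1", "S4"] -> "unsafe\nS1,S4"
    return "unsafe\n" + ",".join(categories)
```

These label strings are exactly the keys that appear in the category counts printed by the data-preparation cell.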

Fine tuning data counts of each category is 
Counter({'safe': 9600, 'unsafe\nS1': 5195, 'unsafe\nS4': 2107, 'unsafe\nS2': 1174, 'unsafe\nS5': 813, 'unsafe\nS1,S4': 630, 'unsafe\nS3': 516, 'unsafe\nS1,S2': 448, 'unsafe\nS1,S3': 352, 'unsafe\nS1,S5': 303, 'unsafe\nS4,S3': 115, 'unsafe\nS4,S5': 76, 'unsafe\nS2,S4': 52, 'unsafe\nS2,S1': 50, 'unsafe\nS1,S4,S3': 39, 'unsafe\nS1,S4,S5': 26, 'unsafe\nS5,S3': 26, 'unsafe\nS2,S5': 18, 'unsafe\nS2,S3': 12, 'unsafe\nS1,S4,S2': 11, 'unsafe\nS1,S5,S3': 10, 'unsafe\nS1,S2,S3': 5, 'unsafe\nS1,S2,S5': 4, 'unsafe\nS1,S2,S4': 3, 'unsafe\nS1,S4,S5,S3': 2, 'unsafe\nS2,S1,S3': 1})

Save the created dataset to a JSON file to be used for fine-tuning with torchtune.
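A minimal sketch of the save step, assuming each record pairs a prompt with its label; the actual column names are whatever the dataset section of 8B_guard_full.yaml maps, so "prompt"/"label" here are placeholders:

```python
import json

# Placeholder records; the real notebook builds these from ai4privacy/pii-masking-65k.
records = [
    {"prompt": "Call me at +1 555 0100 about the invoice.", "label": "unsafe\nS3"},
    {"prompt": "The quarterly report is ready for review.", "label": "safe"},
]

# Write the fine-tuning data to the file referenced by the torchtune config.
with open("pii_train.json", "w") as f:
    json.dump(records, f, indent=2)
```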


Fine tuning Llama Guard with torchtune

torchtune is a PyTorch library for easily authoring, post-training, and experimenting with LLMs. It provides:

  • Hackable training recipes for SFT, knowledge distillation, RL and RLHF, and quantization-aware training
  • Simple PyTorch implementations of popular LLMs like Llama, Gemma, Mistral, Phi, Qwen, and more
  • OOTB best-in-class memory efficiency, performance improvements, and scaling, utilizing the latest PyTorch APIs
  • YAML configs for easily configuring training, evaluation, quantization or inference recipes

For installation instructions and to learn more about torchtune, please check the torchtune GitHub repository.

Broadly speaking, there are two main steps:

  • Download the model
  • Finetune the model

The configs needed for fine-tuning are in the torchtune_configs directory.

Install torchtune

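A minimal install sketch; the exact package set may change, so check the torchtune repository for current instructions:

```shell
# torchtune, plus torchao which several recipes depend on
pip install torchtune torchao
```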

Download Llama Guard from Hugging Face

You need to pass your Hugging Face token to download the model.
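The download is a single torchtune CLI call; based on the output below, it looks something like this (flag names may vary across torchtune versions, and HF_TOKEN stands in for your actual token):

```shell
tune download meta-llama/Meta-Llama-Guard-3-8B \
  --output-dir /tmp/Meta-Llama-Guard-3-8B \
  --ignore-patterns "original/consolidated.00.pth" \
  --hf-token $HF_TOKEN
```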

Ignoring files matching the following patterns: original/consolidated.00.pth
Successfully downloaded model repo and wrote to the following locations:
/tmp/Meta-Llama-Guard-3-8B/.cache
/tmp/Meta-Llama-Guard-3-8B/.gitattributes
/tmp/Meta-Llama-Guard-3-8B/LICENSE
/tmp/Meta-Llama-Guard-3-8B/README.md
/tmp/Meta-Llama-Guard-3-8B/USE_POLICY.md
/tmp/Meta-Llama-Guard-3-8B/config.json
/tmp/Meta-Llama-Guard-3-8B/generation_config.json
/tmp/Meta-Llama-Guard-3-8B/llama_guard_3_figure.png
/tmp/Meta-Llama-Guard-3-8B/model-00001-of-00004.safetensors
/tmp/Meta-Llama-Guard-3-8B/model-00002-of-00004.safetensors
/tmp/Meta-Llama-Guard-3-8B/model-00003-of-00004.safetensors
/tmp/Meta-Llama-Guard-3-8B/model-00004-of-00004.safetensors
/tmp/Meta-Llama-Guard-3-8B/model.safetensors.index.json
/tmp/Meta-Llama-Guard-3-8B/original
/tmp/Meta-Llama-Guard-3-8B/special_tokens_map.json
/tmp/Meta-Llama-Guard-3-8B/tokenizer.json
/tmp/Meta-Llama-Guard-3-8B/tokenizer_config.json
/tmp/Meta-Llama-Guard-3-8B/original_repo_id.json

Fine tuning Llama Guard

There are three files needed to fine-tune Llama Guard with the PII dataset. These are in the torchtune_configs directory:

  • The json file (pii_train.json) which contains the prompts and the labels
  • The custom prompt template file (custom_template.py) which defines the instruction to be used for fine-tuning
  • The configuration file (8B_guard_full.yaml) which uses the above 2 files for fine-tuning

Please note that the configuration below has been tested with 4 A100 GPUs. You may need to adjust the batch_size depending on the GPUs you have.
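A sketch of the launch command, assuming torchtune's distributed full-finetune recipe; the recipe name and config path here are our guesses to match the 4-GPU setup described above:

```shell
tune run --nproc_per_node 4 full_finetune_distributed \
  --config torchtune_configs/8B_guard_full.yaml
```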


Manual evaluation of the fine-tuned model on some prompts

We load the fine-tuned model and manually evaluate prompts from the validation dataset.

**********************************************************************************
Prompt:
[['Dear Carmen, As our National Brand Architect, we need you to conduct an internal audit of our Usability operations, specifically focused on our compliance with Administrative Law. Please use the data from the 109.4.61.213 and 4d36:f8ad:4ebd:1bff:c5eb:62aa:b4d7:dddf IP addresses in the investigation.']]
tensor([   271,  39257,    198,     50,     16,  21861,     17, 128009],
       device='cuda:0')
===================================
Results:
> 

unsafe
S1,S2

==================================

**********************************************************************************
Prompt:
[['Your recent hospital data recorded on 29/12/1957 regarding chronic disease management has been encrypted with IPv6 edaf:fd8f:e1e8:cfec:8bab:1afd:6aad:550c for enhanced privacy.']]
tensor([   271,  39257,    198,     50,     17, 128009], device='cuda:0')
===================================
Results:
> 

unsafe
S2

==================================

**********************************************************************************
Prompt:
[['Cardiac consult reminder for Stephen on 10/03/1976 at 4:40 AM. Commute to our clinic at S Broadway 61915, Port Diegostead near [-56.5612,-68.9245].']]
tensor([   271,  39257,    198,     50,     16,  21861,     19, 128009],
       device='cuda:0')
===================================
Results:
> 

unsafe
S1,S4

==================================

We see that the model does a fairly good job of returning multiple violations when applicable. Let's look at some numbers to quantify this.

Benchmarking the finetuned model against the validation dataset

Let's look at how the model is performing in cases where we have a single violation


Confusion Matrix for the PII violations returned by the fine-tuned model (single violations)

The expected codes are S1, S2, S3, S4 & S5. In the table, codes with _M indicate either that a prediction was missed or that the model predicted an additional violation the ground truth didn't contain.
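The weighted-average F1 reported below can be computed per category and weighted by support; the following pure-Python sketch (equivalent to sklearn's `f1_score(average="weighted")` on binarized multilabel targets) makes the computation explicit:

```python
CATEGORIES = ["S1", "S2", "S3", "S4", "S5"]

def weighted_f1(y_true, y_pred):
    """Weighted-average F1 over category codes.

    y_true / y_pred: one set of S-codes per prompt (empty set = safe).
    """
    f1, support = {}, {}
    for cat in CATEGORIES:
        tp = sum(1 for t, p in zip(y_true, y_pred) if cat in t and cat in p)
        fp = sum(1 for t, p in zip(y_true, y_pred) if cat not in t and cat in p)
        fn = sum(1 for t, p in zip(y_true, y_pred) if cat in t and cat not in p)
        support[cat] = tp + fn  # number of prompts whose ground truth has this code
        prec = tp / (tp + fp) if tp + fp else 0.0
        rec = tp / (tp + fn) if tp + fn else 0.0
        f1[cat] = 2 * prec * rec / (prec + rec) if prec + rec else 0.0
    total = sum(support.values())
    return sum(f1[c] * support[c] for c in CATEGORIES) / total if total else 0.0
```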

Prompts: 100%|██████████████████████████████████████████████████████████████████████████████████████| 999/999 [09:32<00:00,  1.75it/s]
F1 score is: 85.41%

Confusion Matrix for the PII violations returned by the fine-tuned model (multiple violations)

The expected codes are S1, S2, S3, S4 & S5. In the table, codes with _M indicate either that a prediction was missed or that the model predicted an additional violation the ground truth didn't contain.

Prompts: 100%|████████████████████████████████████████████████████████████████████████████████████████████████| 999/999 [11:30<00:00,  1.45it/s]
F1 score is: 80.02%

Analysis of Errors

Let's look at some of the prompts where the predictions didn't match the ground truth.

['iter: 4, prompt: [["Dr. Feeney, we\'d like to share a study on Barnstable Town\'s education trends. It\'ll be an interesting reference point. I\'ll email it, Jana_Collier@hotmail.com, with intriguing insights along with the data."]] ,y_true: [\'S1\', \'S4\'], y_pred: [\'S1\']',
 "iter: 5, prompt: [['Could we schedule a meeting at 11:12 PM regarding the property in state Alaska with postal code 99578?']] ,y_true: ['S2', 'S4'], y_pred: ['S4']",
 "iter: 6, prompt: [['We need some clarification regarding expenditure on your travel using the card 1417506813461516, finalized under the IP: 6b3b:c6fc:209a:e5b1:9b4f:ea5b:8a5e:e58e. We request further clarification on the implication of this. Best, Account Team.']] ,y_true: ['S2', 'S5'], y_pred: ['S5']",
 "iter: 9, prompt: [['Attention Ocie Hettinger, a cognitive therapy session has been scheduled on 2/05 at 12:21pm with our Orchestrator at 19381 Albany Road. This will cost you $ 297024.50. You can reach out to us at Augustus.OKon@yahoo.com or call us at +496-50 420 1604 for any queries.']] ,y_true: ['S1', 'S4'], y_pred: ['S1', 'S3', 'S4']",
 "iter: 13, prompt: [['Billing details for Dejuan Mertz from Herminiastead need an update. The changed account number is 08735629. Please handle with discretion.']] ,y_true: ['S1', 'S4'], y_pred: ['S1']"]

We see that the model is probably more accurate than the ground truth in places. It is debatable whether Barnstable Town can be considered an Address. In the second prompt, there is no IP address mentioned, yet the ground truth includes S2. In the last prompt, it is debatable whether Herminiastead is a Name or an Address.

Conclusion

  • In this notebook, we have shown how to fine-tune Llama Guard to identify PII violations, reaching a weighted average F1 score of 85% for single violations.
  • We also show that Llama Guard is good at returning multiple violations, with a weighted average F1 score of 80%.
  • The model's performance can be improved further by reviewing the ground-truth violation labels and correcting them where they are inaccurate.