3D Brain Tumor Segmentation
Brain tumor 3D segmentation with MONAI and Weights & Biases
This tutorial shows how to construct a training workflow for a multi-label 3D brain tumor segmentation task using MONAI, and how to use the experiment tracking and data visualization features of Weights & Biases. The tutorial covers the following:
- Initialize a Weights & Biases run and synchronize all configs associated with the run for reproducibility.
- MONAI transform API:
  - MONAI transforms for dictionary-format data.
  - How to define a new transform according to the MONAI `transforms` API.
  - How to randomly adjust intensity for data augmentation.
- Data loading and visualization:
  - Load NIfTI images with metadata, load a list of images, and stack them.
  - Cache IO and transforms to accelerate training and validation.
  - Visualize the data using `wandb.Table` and the interactive segmentation overlay on Weights & Biases.
- Training a 3D `SegResNet` model:
  - Using the `networks`, `losses`, and `metrics` APIs from MONAI.
  - Training the 3D `SegResNet` model using a PyTorch training loop.
  - Tracking the training experiment using Weights & Biases.
  - Logging and versioning model checkpoints as model artifacts on Weights & Biases.
- Visualize and compare the predictions on the validation dataset using `wandb.Table` and the interactive segmentation overlay on Weights & Biases.
🌴 Setup and Installation
First, let us install the latest versions of both MONAI and Weights & Biases.
We will then authenticate this Colab instance to use W&B.
🌳 Initialize a W&B Run
We will start a new W&B run to begin tracking our experiment.
Using a proper config system is a recommended best practice for reproducible machine learning, and W&B lets us track the hyperparameters of every experiment.
We also need to set the random seed for the relevant modules, which lets us enable or disable deterministic training.
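A minimal sketch of this step, assuming the hyperparameters live in a plain Python dict. The keys and values are hypothetical placeholders, not the tutorial's actual settings. In the tutorial, such a dict would be passed as `wandb.init(config=config)`, and MONAI's `monai.utils.set_determinism(seed=...)` would handle seeding; to stay runnable without MONAI or PyTorch, this sketch seeds only Python's `random` module and NumPy.

```python
import random
from typing import Optional

import numpy as np

# Hypothetical hyperparameters; the tutorial's real config may differ.
config = {
    "seed": 0,
    "batch_size": 1,
    "initial_learning_rate": 1e-4,
    "max_train_epochs": 25,
}


def set_seed(seed: Optional[int]) -> None:
    """Seed Python's and NumPy's RNGs; pass None to leave them unseeded."""
    if seed is None:
        return  # non-deterministic run
    random.seed(seed)
    np.random.seed(seed)


set_seed(config["seed"])
```

Passing `None` instead of an integer is one simple way to switch deterministic behavior off for a run.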
💿 Data Loading and Transformation
Here we use the `monai.transforms` API to create a custom transform that converts the multi-class labels into a multi-label segmentation target in one-hot format.
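A NumPy sketch of what such a conversion might compute. It assumes the label grouping used in MONAI's own BraTS tutorial (tumor core = labels 2 and 3, whole tumor = labels 1, 2 and 3, enhancing tumor = label 2); check the label definitions of your data before reusing it. In MONAI itself, the dictionary version would subclass `monai.transforms.MapTransform` and implement `__call__`; the function names here are hypothetical.

```python
import numpy as np


def convert_to_multi_channel(label: np.ndarray) -> np.ndarray:
    """Map an integer label volume to a (3, *label.shape) multi-label mask.

    Assumed label meanings (as in MONAI's BraTS tutorial): 1 = peritumoral
    edema, 2 = GD-enhancing tumor, 3 = necrotic and non-enhancing tumor core.
    """
    tumor_core = np.logical_or(label == 2, label == 3)   # TC: labels 2 and 3
    whole_tumor = np.logical_or(tumor_core, label == 1)  # WT: labels 1, 2 and 3
    enhancing_tumor = label == 2                         # ET: label 2 only
    # Channels may overlap: ET lies inside TC, which lies inside WT.
    return np.stack([tumor_core, whole_tumor, enhancing_tumor], axis=0).astype(np.float32)


def convert_labels_dict(data: dict, keys=("label",)) -> dict:
    """Dictionary-style wrapper in the spirit of a MONAI MapTransform."""
    d = dict(data)  # shallow copy so the input dict is not mutated
    for key in keys:
        d[key] = convert_to_multi_channel(np.asarray(d[key]))
    return d
```

Operating on a dict keyed by `"image"` and `"label"` mirrors MONAI's dictionary transforms: each transform touches only the keys it is given and passes the rest through.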
Next, we set up transforms for training and validation datasets respectively.
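MONAI provides dictionary transforms such as `NormalizeIntensityd` (with `nonzero=True, channel_wise=True`), `RandScaleIntensityd`, and `RandShiftIntensityd` for intensity normalization and augmentation. The NumPy sketch below only approximates what such transforms compute; it is not MONAI's code, and the function names and the 0.1 factor/offset ranges are assumptions.

```python
import numpy as np


def normalize_nonzero_channelwise(image: np.ndarray) -> np.ndarray:
    """Z-score each channel using only its nonzero voxels; zeros stay zero."""
    out = image.astype(np.float32)  # astype returns a copy
    for c in range(out.shape[0]):
        channel = out[c]  # view into `out`, so writes below update it
        mask = channel != 0
        if mask.any():
            values = channel[mask]
            std = values.std()
            channel[mask] = (values - values.mean()) / (std if std > 0 else 1.0)
    return out


def rand_scale_shift_intensity(image, rng, factors=0.1, offsets=0.1, prob=1.0):
    """With probability `prob`, return image * (1 + s) + o.

    s ~ U(-factors, factors) and o ~ U(-offsets, offsets); otherwise the
    input is returned unchanged.
    """
    if rng.random() >= prob:
        return image
    scale = 1.0 + rng.uniform(-factors, factors)
    shift = rng.uniform(-offsets, offsets)
    return image * scale + shift
```

Only the training pipeline would include the random step; validation keeps the deterministic normalization alone so that its metrics are comparable across epochs.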
🍁 The Dataset
The dataset that we will use for this experiment comes from http://medicaldecathlon.com/. We will use multimodal multisite MRI data (FLAIR, T1w, T1gd, T2w) to segment gliomas, necrotic/active tumour, and oedema. The dataset consists of 750 4D volumes (484 for training and 266 for testing).
We will use `DecathlonDataset` to automatically download and extract the dataset. It inherits from MONAI's `CacheDataset`, which lets us set `cache_num=N` to cache N items for training and use the default arguments to cache all items for validation, depending on the available memory.
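The caching idea can be sketched in plain Python. This toy class is a hypothetical illustration, not MONAI's implementation: it runs the expensive deterministic part of the pipeline once for the first `cache_num` items, serves later requests from memory, and applies any random transform to a copy on each access so that augmentation still varies between epochs. With MONAI, the equivalent is a call along the lines of `DecathlonDataset(root_dir=..., task="Task01_BrainTumour", section="training", transform=..., download=True, cache_num=...)`.

```python
import copy
from typing import Any, Callable, Optional, Sequence


class SimpleCacheDataset:
    """Toy cache: deterministic work is done once for the first `cache_num` items."""

    def __init__(
        self,
        data: Sequence[Any],
        deterministic_transform: Callable[[Any], Any],
        random_transform: Callable[[Any], Any] = lambda x: x,
        cache_num: Optional[int] = None,
    ):
        self.data = list(data)
        self.deterministic_transform = deterministic_transform
        self.random_transform = random_transform
        n = len(self.data) if cache_num is None else min(cache_num, len(self.data))
        # Expensive part (e.g. loading and normalizing a volume) computed up front
        self._cache = [deterministic_transform(item) for item in self.data[:n]]

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, index: int) -> Any:
        if 0 <= index < len(self._cache):
            # Copy so that an in-place random transform cannot corrupt the cache
            item = copy.deepcopy(self._cache[index])
        else:
            item = self.deterministic_transform(self.data[index])  # not cached
        return self.random_transform(item)
```

The trade-off is memory: caching all items makes every epoch fast but holds every preprocessed volume in RAM, which is why a smaller `cache_num` may be needed for the larger training split.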
Note: Instead of applying `train_transform` to `train_dataset`, we apply `val_transform` to both the training and validation datasets. This is because, before training, we visualize samples from both splits of the dataset, and we want those samples free of random augmentation.
📸 Visualizing the Dataset
Weights & Biases supports images, video, audio, and more. We can log rich media to explore our results and visually compare our runs, models, and datasets. We use the segmentation mask overlay system to visualize our data volumes. To log segmentation masks in tables, we need to provide a `wandb.Image` object for each row in the table.
An example is provided in the code snippet below:
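```python
import wandb

# Assumes wandb.init() has already been called, and that `ids`, `images`,
# `labels`, and `class_labels` (a dict mapping mask values to class names)
# are already defined.
table = wandb.Table(columns=["ID", "Image"])
for sample_id, img, label in zip(ids, images, labels):
    # Attach the label as a segmentation mask overlay on the image
    mask_img = wandb.Image(
        img,
        masks={
            "ground_truth": {"mask_data": label, "class_labels": class_labels}
            # ...
        },
    )
    # Add the wandb.Image (with its masks), not the raw image, to the row
    table.add_data(sample_id, mask_img)
wandb.log({"Table": table})
```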
Let us now write a simple utility function that takes a sample image, its label, a `wandb.Table` object, and some associated metadata, and populates the rows of a table that will be logged to our Weights & Biases dashboard.
Next, we define the `wandb.Table` object and its columns so that we can populate it with our data visualizations.
Then we loop over `train_dataset` and `val_dataset` to generate visualizations of the data samples and populate the rows of the table, which we log to our dashboard.
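The row-building logic can be sketched without W&B itself. The helper names, the mask naming scheme, and the channel order (tumor core, whole tumor, enhancing tumor) are assumptions. Each volume is taken to have shape `(channels, H, W, D)`, one table row is produced per slice, and each class gets its own mask with a distinct class id. In the tutorial, every `(image, masks)` pair would become `wandb.Image(image, masks=masks)`, and each row would be passed to `table.add_data(...)`.

```python
from typing import Dict, Iterator, Tuple

import numpy as np

# Assumed order of the label channels (hypothetical naming).
CLASS_NAMES = ["Tumor Core", "Whole Tumor", "Enhancing Tumor"]
# One image column per MRI modality channel (FLAIR, T1w, T1gd, T2w).
COLUMNS = ["Split", "Data Index", "Slice Index"] + [f"Image-Channel-{c}" for c in range(4)]


def slice_masks(label: np.ndarray, slice_idx: int) -> Dict[str, dict]:
    """One overlay mask per class for a single slice of a (C, H, W, D) label."""
    masks = {}
    for class_idx, name in enumerate(CLASS_NAMES):
        class_id = class_idx + 1  # distinct id so each class gets its own color
        masks["ground_truth_" + name.lower().replace(" ", "_")] = {
            "mask_data": label[class_idx, :, :, slice_idx].astype(np.uint8) * class_id,
            "class_labels": {0: "background", class_id: name},
        }
    return masks


def table_rows(image: np.ndarray, label: np.ndarray, split: str, data_idx: int) -> Iterator[Tuple]:
    """Yield (split, data_idx, slice_idx, (channel_image, masks) per channel) per slice."""
    num_channels, _, _, num_slices = image.shape
    for slice_idx in range(num_slices):
        masks = slice_masks(label, slice_idx)
        # With W&B: wandb.Image(image[c, :, :, slice_idx], masks=masks)
        channel_views = [(image[c, :, :, slice_idx], masks) for c in range(num_channels)]
        yield (split, data_idx, slice_idx, *channel_views)
```

Keeping one mask per class, rather than merging the classes into a single mask, lets the overlay toggle each tumor region independently.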
The data appears on our W&B dashboard in an interactive tabular format. Each row shows every channel of a particular slice from a data volume, overlaid with the respective segmentation mask. Let us write Weave queries to filter the data in our table and focus on one particular row.

Let us now open an image and check how we can interact with each of the segmentation masks using the interactive overlay.

Note: In the multi-label format, the masks of different classes can overlap (a voxel may belong to more than one tumor region), hence they were logged as separate masks in the overlay.
🛫 Loading the Data
We create the PyTorch dataloaders for loading the data from the datasets. Note that before creating the dataloaders, we set the transform of `train_dataset` to `train_transform`, so that the training data is preprocessed and augmented for training.
🤖 Creating the Model, Loss, and Optimizer
In this tutorial, we train a `SegResNet` model based on the paper "3D MRI brain tumor segmentation using autoencoder regularization". `SegResNet` is implemented as a PyTorch module in the `monai.networks` API. We also create our optimizer and learning rate scheduler.
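The paragraph does not name the scheduler. Assuming cosine annealing, as provided by PyTorch's `torch.optim.lr_scheduler.CosineAnnealingLR`, the learning rate at epoch t is lr(t) = min_lr + (base_lr - min_lr) * (1 + cos(pi * t / T_max)) / 2. The dependency-free sketch below evaluates that closed form; unlike PyTorch's scheduler, it simply holds the minimum after `max_epochs`. The model itself would be created with something like `SegResNet(in_channels=4, out_channels=3, init_filters=16)`, where the four input channels match the four MRI modalities and the three outputs match the label channels; the filter count is an assumption.

```python
import math


def cosine_annealing_lr(epoch: int, max_epochs: int, base_lr: float, min_lr: float = 0.0) -> float:
    """Closed-form cosine annealing: base_lr at epoch 0, min_lr at max_epochs."""
    progress = min(epoch, max_epochs) / max_epochs  # clamp after the schedule ends
    return min_lr + (base_lr - min_lr) * (1 + math.cos(math.pi * progress)) / 2
```

The learning rate decays slowly at first, fastest around the midpoint, and flattens out near the end, which tends to give smoother late-stage convergence than a fixed step decay.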
We define our loss as a multi-label `DiceLoss` using the `monai.losses` API, and the corresponding Dice metrics using the `monai.metrics` API.
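Both quantities reduce to the Dice coefficient. For predicted probabilities p and binary targets g, the soft Dice loss is 1 - (2 * sum(p * g) + smooth_nr) / (sum(p) + sum(g) + smooth_dr), computed per channel and averaged, with a sigmoid applied per channel because the labels are multi-label rather than mutually exclusive. The NumPy sketch below illustrates the math behind MONAI's `DiceLoss` (used with `sigmoid=True`) and `DiceMetric`; it is not their implementation, and the smoothing defaults and the handling of empty masks are assumptions.

```python
import numpy as np


def sigmoid(x: np.ndarray) -> np.ndarray:
    return 1.0 / (1.0 + np.exp(-x))


def soft_dice_loss(logits, target, smooth_nr=0.0, smooth_dr=1e-5) -> float:
    """Multi-label soft Dice loss over arrays of shape (B, C, ...), averaged over B and C."""
    probs = sigmoid(logits)  # independent probability per channel
    axes = tuple(range(2, logits.ndim))  # spatial axes
    intersection = (probs * target).sum(axis=axes)
    denominator = probs.sum(axis=axes) + target.sum(axis=axes)
    dice = (2.0 * intersection + smooth_nr) / (denominator + smooth_dr)
    return float(1.0 - dice.mean())


def dice_metric(logits, target, threshold=0.5) -> np.ndarray:
    """Hard Dice score per (batch, channel) after thresholding sigmoid outputs."""
    pred = sigmoid(logits) >= threshold
    target = target.astype(bool)
    axes = tuple(range(2, logits.ndim))
    intersection = np.logical_and(pred, target).sum(axis=axes)
    denominator = pred.sum(axis=axes) + target.sum(axis=axes)
    # In this sketch, an empty prediction on an empty target scores 1.0
    return np.where(denominator > 0, 2.0 * intersection / np.maximum(denominator, 1), 1.0)
```

The loss uses soft probabilities so that it stays differentiable for training, while the metric thresholds the predictions to score the binary masks that would actually be reported.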
🚝 Training and Validation
Before we start training, let us define some metric properties which will later be logged with wandb.log() for tracking our training and validation experiments.
🍭 Execute Standard PyTorch Training Loop
Instrumenting our code with `wandb.log` not only lets us track all the metrics associated with our training and validation process, but also all system metrics (our CPU and GPU in this case) on our W&B dashboard.
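A framework-agnostic sketch of this instrumentation, with hypothetical names throughout: `train_step` stands in for the forward pass, backward pass, and optimizer step and returns the batch loss, and `log` defaults to `print` so the sketch runs without W&B. In the tutorial, `log` would be `wandb.log`, while W&B collects system metrics in the background for the active run. The metric keys are assumptions.

```python
from typing import Callable, Dict, Iterable, List


def train(
    num_epochs: int,
    make_batches: Callable[[], Iterable],
    train_step: Callable[[object], float],
    log: Callable[[Dict], None] = print,
) -> List[float]:
    """Run `num_epochs` epochs, logging each batch loss and each epoch's mean loss."""
    epoch_means = []
    global_step = 0
    for epoch in range(num_epochs):
        batch_losses = []
        for batch in make_batches():
            loss = train_step(batch)  # forward + backward + optimizer step
            batch_losses.append(loss)
            log({"batch/batch_step": global_step, "batch/train_loss": loss})
            global_step += 1
        mean_loss = sum(batch_losses) / max(len(batch_losses), 1)
        epoch_means.append(mean_loss)
        log({"epoch/epoch_step": epoch, "epoch/mean_train_loss": mean_loss})
    return epoch_means
```

Logging the per-batch and per-epoch values under separate prefixes keeps the two granularities on separate charts in the dashboard.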

If we navigate to the artifacts tab in the W&B run dashboard, we will be able to access the different versions of model checkpoint artifacts that we logged during training.

🔱 Inference
Using the artifacts interface, we can select which version of the artifact is the best model checkpoint, in this case the one with the lowest mean epoch-wise training loss. We can also explore the entire lineage of the artifact and use the version that we need.

Let us fetch the version of the model artifact with the best epoch-wise mean training loss and load the checkpoint state dictionary to the model.
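Selecting the best version amounts to taking the minimum over the logged loss. The hypothetical helper below assumes that each checkpoint version carries its epoch-wise mean training loss in a metadata dict (W&B artifacts do expose a `metadata` attribute, but the key name here is an assumption). With W&B, the chosen version would then be fetched with `run.use_artifact("<artifact-name>:<version>")` and downloaded with `.download()` before the state dictionary is loaded into the model.

```python
from typing import Dict, List


def best_checkpoint_version(versions: List[Dict]) -> Dict:
    """Return the version record with the lowest epoch-wise mean training loss."""
    if not versions:
        raise ValueError("no checkpoint versions to choose from")
    return min(versions, key=lambda v: v["metadata"]["epoch/mean_train_loss"])
```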
📸 Visualizing Predictions and Comparing with the Ground Truth Labels
To visualize the predictions of the trained model and compare them with the corresponding ground-truth segmentation masks using the interactive segmentation mask overlay, let us create another utility function.
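A sketch of the prediction side, assuming the model outputs one logit channel per class. A sigmoid followed by a 0.5 threshold turns the logits into binary masks (MONAI offers `Activations` and `AsDiscrete` transforms for this kind of post-processing), and each class then gets both a prediction mask and a ground-truth mask so that the two can be toggled in the overlay. The helper names, mask keys, and channel order are hypothetical.

```python
import numpy as np

# Assumed order of the output/label channels (hypothetical naming).
CLASS_NAMES = ["Tumor Core", "Whole Tumor", "Enhancing Tumor"]


def logits_to_masks(logits: np.ndarray, threshold: float = 0.5) -> np.ndarray:
    """Sigmoid + threshold: (C, H, W, D) logits -> (C, H, W, D) binary uint8 masks."""
    return (1.0 / (1.0 + np.exp(-logits)) >= threshold).astype(np.uint8)


def comparison_masks(pred: np.ndarray, label: np.ndarray, slice_idx: int) -> dict:
    """Prediction and ground-truth overlay masks for one slice, one pair per class."""
    masks = {}
    for class_idx, name in enumerate(CLASS_NAMES):
        class_id = class_idx + 1  # distinct id per class
        key = name.lower().replace(" ", "_")
        class_labels = {0: "background", class_id: name}
        masks[f"prediction_{key}"] = {
            "mask_data": pred[class_idx, :, :, slice_idx] * class_id,
            "class_labels": class_labels,
        }
        masks[f"ground_truth_{key}"] = {
            "mask_data": label[class_idx, :, :, slice_idx].astype(np.uint8) * class_id,
            "class_labels": class_labels,
        }
    return masks
```

Giving a prediction mask and its ground-truth mask the same class id and name makes the two easy to match when switching between them in the overlay.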
Let us see how we can analyze and compare the predicted segmentation masks and the ground-truth labels for each class using the interactive segmentation mask overlay.

You can also check out the report Brain Tumor Segmentation using MONAI and WandB for more details regarding training a brain-tumor segmentation model using MONAI and W&B.