ModernBERT in Radiology Part 3: Fine-tuning a Classifier
Fine-tune ModernBERT Classifier to predict the UMLS CUIs given a radiology report
In Part 3 of the ModernBERT in Radiology series, we will fine-tune a ModernBERT Classifier to predict the UMLS CUIs given a radiology report. It builds on our fine-tuning from Part 2 to produce a better classifier than the simple scikit-learn Logistic Regression from Part 1.
You can follow along with the associated Colab Notebook for Part 3🔥!
The ModernBERT in Radiology Series
Part 1: Simple Classifier using Hidden States.
Build a multi-label classifier using a simple scikit-learn Logistic Regression model on top of the pre-trained ModernBERT body.
Part 2: Fine-tuning a Masked Language Model (MLM).
ModernBERT is pretrained as a Masked Language Model; we perform a full fine-tuning on radiology text using AutoModelForMaskedLM.
Part 3: Fine-tuning for Classification. 👈 This Post
Combining Part 1 and Part 2, we build a ModernBERT classifier by fine-tuning the entire model using AutoModelForSequenceClassification.
Objective
Whereas Part 1 used the hidden states of the ModernBERT body to train a simple classifier, we are going to put a proper Hugging Face neural network classification head on the ModernBERT body and fine-tune it on unsloth/Radiology_mini to perform multi-label classification from radiology text to UMLS CUIs (concept IDs). Finally, we will publish the model to Hugging Face as johnpaulett/ModernRadBERT-cui-classifier.
We will follow parts of the excellent Natural Language Processing with Transformers book, with code available on GitHub. Part 3 follows Chapter 2 of that book. However, our problem deviates from the chapter since we are performing multi-label classification (i.e., each text can have one or more CUI labels).
WARNING: Since the cui concepts were generated via MedCAT, we will effectively be learning to reproduce MedCAT's predictions rather than human-annotated ground truth.
Code
See Colab for the full Notebook: https://colab.research.google.com/drive/11hpCvNb4g65Igcmz1ePyuqdNmWwYwj4g?usp=sharing
Setup
In Part 3, I used a Colab NVIDIA L4 GPU. We use Hugging Face 🤗 Transformers' AutoModelForSequenceClassification to load the pre-trained ModernBERT for full fine-tuning.
pip install datasets evaluate wandb triton
# Flash Attention only works on Ampere or newer GPUs (e.g., not the T4)
pip install flash-attn
# ModernBERT support arrives in the transformers 4.48.0 release; until then, install from the main branch
pip install git+https://github.com/huggingface/transformers.git
model_id = (
    "answerdotai/ModernBERT-base"
    # "answerdotai/ModernBERT-large"
)
dataset_name = (
    # "eltorio/ROCOv2-radiology"
    "unsloth/Radiology_mini"  # 0.33% of ROCOv2-radiology, for a quicker demo
)
push_to_hub = True
output_dir = "ModernRadBERT-cui-classifier"
Load & Transform the Dataset
See Part 1 for details on the dataset.
Load the dataset and re-split it, carving a validation set out of the original training split and holding the test split back.
from datasets import load_dataset, DatasetDict

original_dataset = load_dataset(dataset_name)
print(f"Training Size: {original_dataset['train'].size_in_bytes / (1024 * 1024 * 1024):.2f} GB")

# Use 15% of the overall dataset size for validation
validation_size = int(0.15 * (len(original_dataset['train']) + len(original_dataset['test'])))

dataset = DatasetDict({
    'train': original_dataset['train'].shuffle(seed=42).select(range(validation_size, len(original_dataset['train']))),
    'validation': original_dataset['train'].shuffle(seed=42).select(range(validation_size)),
    # Keep the test set -- we'll hold this back for comparison between models
    'test': original_dataset['test']
})
dataset = dataset.remove_columns(['image'])

# 'dataset' now contains the re-split train and validation sets plus the held-back test set
print(f"training set size: {len(dataset['train'])}")
print(f"validation set size: {len(dataset['validation'])}")
print(f"test set size: {len(dataset['test'])}")
Convert the caption text into tokens using ModernBERT's tokenizer:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(model_id)

def tokenize_function(examples):
    return tokenizer(
        examples["caption"],
        padding="max_length",
        truncation=True,
        # ModernBERT allows sequence lengths up to 8192, versus 512 in the original BERT!
        # The longest caption in the train set is 934 characters, so roughly 934/4 ~= 233 tokens,
        # and checking the longest attention_mask shows the true maximum is 206 tokens.
        # Setting max_length too high wastes memory, since every example is padded to this length.
        max_length=256,
    )

dataset = dataset.map(
    tokenize_function, batched=True
)
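As a quick, optional sanity check, we can confirm that max_length=256 comfortably covers the longest caption by summing the attention masks (which count non-padding tokens):
# Count non-padding tokens per training example via the attention mask
max_tokens = max(sum(mask) for mask in dataset["train"]["attention_mask"])
print(f"Longest tokenized caption in the train set: {max_tokens} tokens")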
Since cui is a multi-label field (each report can carry several CUIs), we use scikit-learn's MultiLabelBinarizer to encode each label set as a binary vector:
from sklearn.preprocessing import MultiLabelBinarizer
import numpy as np

mlb = MultiLabelBinarizer()
mlb.fit(dataset['train']['cui'])

def transform_labels(example):
    # Transform a single example's CUIs to a binary vector
    binary_labels = mlb.transform([example['cui']])[0]  # [0] to get the single example's labels
    # Convert to float32 for BCEWithLogitsLoss
    example['labels'] = binary_labels.astype(np.float32).tolist()
    example['num_labels'] = sum(binary_labels)
    return example

dataset = dataset.map(
    transform_labels,
    desc="Transforming labels to binary vectors",
    num_proc=4,
)
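To confirm the encoding behaves as expected, we can round-trip a single example through the binarizer (an illustrative check, not part of the original notebook):
# Round-trip the first training example: CUIs -> binary vector -> CUIs
example = dataset['train'][0]
print(f"Distinct CUIs in the train split: {len(mlb.classes_)}")
print(f"Original CUIs: {example['cui']}")
recovered = mlb.inverse_transform(np.array([example['labels']], dtype=int))[0]
print(f"Recovered from the binary vector: {list(recovered)}")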
Train the Classifier
We use AutoModelForSequenceClassification with the number of distinct labels, making sure to set problem_type="multi_label_classification":
from transformers import AutoModelForSequenceClassification
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

model = AutoModelForSequenceClassification.from_pretrained(
    model_id,
    num_labels=len(mlb.classes_),
    problem_type="multi_label_classification"
).to(device)
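It is worth verifying that the freshly initialized classification head matches the label space and that the multi-label problem type is set, since that is what makes the model use BCEWithLogitsLoss (a quick check, not in the original notebook):
# The classification head should have one output per CUI class
print(f"num_labels: {model.config.num_labels} (expected {len(mlb.classes_)})")
print(f"problem_type: {model.config.problem_type}")  # multi_label_classification -> BCEWithLogitsLoss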
Prepare the F1 score (and related metrics) to compute at each evaluation epoch:
from sklearn.metrics import precision_recall_fscore_support, accuracy_score, hamming_loss

# We use sklearn's metrics instead of `evaluate`, due to evaluate's f1 only wanting a single value, not a one-hot array
def compute_metrics(eval_pred):
    predictions, labels = eval_pred

    # Apply sigmoid activation and threshold at 0.5
    predictions = 1 / (1 + np.exp(-predictions))  # sigmoid
    predictions = (predictions > 0.5).astype(int)

    # Calculate micro-averaged metrics
    precision_micro, recall_micro, f1_micro, _ = precision_recall_fscore_support(
        labels, predictions, average='micro', zero_division=0
    )

    # Calculate macro-averaged metrics
    precision_macro, recall_macro, f1_macro, _ = precision_recall_fscore_support(
        labels, predictions, average='macro', zero_division=0
    )

    # Calculate subset accuracy (exact match)
    subset_accuracy = accuracy_score(labels, predictions)

    # Calculate Hamming loss
    ham_loss = hamming_loss(labels, predictions)

    # Calculate per-label accuracy (element-wise)
    label_wise_accuracy = np.mean((predictions == labels).astype(float))

    results = {
        # Micro-averaged metrics
        "precision_micro": precision_micro,
        "recall_micro": recall_micro,
        "f1": f1_micro,
        # Macro-averaged metrics
        "precision_macro": precision_macro,
        "recall_macro": recall_macro,
        "f1_macro": f1_macro,
        # Other metrics
        "exact_match": subset_accuracy,
        "hamming_loss": ham_loss,
        "label_accuracy": label_wise_accuracy
    }
    return results
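A quick way to convince yourself the metric logic is correct is to run it on a couple of hand-made logit/label rows (toy values, purely illustrative):
# Toy self-test: two examples, three labels; the thresholded predictions
# match the labels exactly, so every metric should come out perfect
dummy_logits = np.array([[2.0, -2.0, 0.5], [-1.0, 3.0, -3.0]])
dummy_labels = np.array([[1, 0, 1], [0, 1, 0]], dtype=np.float32)
print(compute_metrics((dummy_logits, dummy_labels)))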
We then conduct 20 epochs of training:
from transformers import TrainingArguments, Trainer

training_args = TrainingArguments(
    output_dir=output_dir,
    num_train_epochs=20,
    learning_rate=2e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    weight_decay=0.01,
    logging_strategy="steps",  # Log metrics every n steps
    logging_steps=100,         # Log every 100 steps
    eval_strategy="epoch",
    metric_for_best_model="f1",
    greater_is_better=True,    # higher f1 is better
    save_strategy="epoch",
    load_best_model_at_end=True,
    save_total_limit=3,        # Limit how many checkpoints are kept on disk
    push_to_hub=push_to_hub,
    report_to="none",          # Comment out to enable wandb
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset["train"],
    eval_dataset=dataset["validation"],
    compute_metrics=compute_metrics,
)

trainer.train()
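Because the test split went through the same tokenization and label binarization, one natural follow-up (not shown in the notebook) is to score the best checkpoint on it; note that any CUIs appearing only in the test split were ignored, with a warning, by the MultiLabelBinarizer fitted on the train split:
# Evaluate the best checkpoint on the held-back test split
test_metrics = trainer.evaluate(dataset["test"], metric_key_prefix="test")
print(test_metrics)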
Use Model to Predict:
def predict(text, threshold=0.5):
    # Tokenize input
    inputs = tokenizer(
        text,
        padding=True,
        truncation=True,
        max_length=256,
        return_tensors="pt"
    )
    # Move inputs to device
    inputs = {k: v.to(device) for k, v in inputs.items()}

    # Get predictions
    with torch.no_grad():
        outputs = model(**inputs)
        logits = outputs.logits

    # Apply sigmoid to get probabilities
    probs = torch.sigmoid(logits).cpu().numpy()

    # Get binary predictions
    predictions = (probs > threshold).astype(int)

    # Convert binary predictions back to CUI labels
    predicted_labels = mlb.inverse_transform(predictions)[0]

    # Create dictionary of label probabilities
    label_probs = {
        label: float(prob)  # Convert to Python float for JSON serialization
        for label, prob in zip(mlb.classes_, probs[0])
        if prob > threshold
    }
    # Sort by probability
    label_probs = dict(sorted(label_probs.items(), key=lambda x: x[1], reverse=True))

    # return {
    #     "labels": list(predicted_labels),
    #     "probabilities": label_probs
    # }

    # cui_map maps each CUI code to a human-readable name (built earlier in the full notebook)
    print(f"{text}: {[cui_map.get(label) for label in list(predicted_labels)]}")
In the second example below, the model also predicts Anterior-Posterior, which is often related to such studies but not explicitly stated in the text. Training on the full ROCOv2 dataset would likely have performed better.
predict("CT of Chest with pneumothorax")
# CT of Chest with pneumothorax: ['X-Ray Computed Tomography (C0040405)']
predict("Abdomen x-ray with small bowel obstruction")
# Abdomen x-ray with small bowel obstruction: ['Abdomen (C0000726)', 'Plain x-ray (C1306645)', 'Anterior-Posterior (C1999039)']
Push to Hugging Face
Finally, we upload the fine-tuned weights to Hugging Face: https://huggingface.co/johnpaulett/ModernRadBERT-cui-classifier
trainer.push_to_hub()
Conclusion
Fine-tuning a radiology transformer for classification tasks on report text is incredibly powerful. The same approach could support:
Billing tasks, such as extracting CPT and ICD codes
Classification of specific findings, such as normal vs. abnormal or severity scoring
Quality assessments, such as follow-up recommendations and critical result communication
Explore the Fine-Tuned Model
You can pull down this fine-tuned model (WARNING: it is trained on a small dataset as a demo, so do not use it for any real problems):
from transformers import pipeline

# This is a sequence-classification model, so we use the text-classification pipeline
pipe = pipeline("text-classification", model="johnpaulett/ModernRadBERT-cui-classifier")
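An illustrative call is sketched below; since the notebook above did not register an id2label mapping, the pipeline will likely report generic LABEL_N names rather than CUIs, and the scores are per-label sigmoid probabilities:
# Example usage; top_k=5 returns the five highest-scoring labels
preds = pipe("CT of Chest with pneumothorax", top_k=5)
print(preds)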
Citations
Warner, B., Chaffin, A., Clavié, B., Weller, O., Hallström, O., Taghadouini, S., Gallagher, A., Biswas, R., Ladhak, F., Aarsen, T., Cooper, N., Adams, G., Howard, J., & Poli, I. (2024). Smarter, Better, Faster, Longer: A Modern Bidirectional Encoder for Fast, Memory Efficient, and Long Context Finetuning and Inference. arXiv preprint arXiv:2412.13663.
Ronan, L. M. (2024). ROCOv2-radiology [Dataset]. Hugging Face. https://doi.org/10.57967/hf/3489
Tunstall, L., von Werra, L., & Wolf, T. (2022). Natural Language Processing with Transformers: Building Language Applications with Hugging Face. O'Reilly Media.