Fine-tuning a pre-trained model on a domain-specific image classification task can significantly reduce the time and computational resources needed compared with training a deep neural network from scratch. In this tutorial, I will walk through the steps of fine-tuning a pre-trained ResNet-18 model on a custom dataset using PyTorch.
Step 1: Load the pre-trained model
Let’s start by loading the pre-trained ResNet-18 model from torchvision, which provides a collection of models pre-trained on ImageNet that can be used for various computer vision tasks.
import torch
import torchvision.models as models
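# Load the pre-trained ResNet-18 model (ImageNet weights).
# torchvision 0.13 and later use the `weights` argument; on older versions, use pretrained=True instead.
model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)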
Step 2: Freeze the pre-trained layers
Next, freeze the pre-trained layers of the ResNet-18 model so that only the new final layer is trained for our specific task. Setting requires_grad to False stops gradients from being computed for these parameters, so their weights are not updated during training.
# Freeze all the pre-trained layers
for param in model.parameters():
    param.requires_grad = False
Step 3: Modify the last layer
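The last layer of the pre-trained ResNet-18 model is a fully connected layer that outputs a 1000-dimensional vector, one score per ImageNet class. Since we want to use this model for a domain-specific image classification task, we need to replace the last layer with one that outputs the number of classes in our dataset. The parameters of the new layer have requires_grad set to True by default, so this layer stays trainable even though the rest of the network is frozen.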
# Modify the last layer of the model
num_classes = 10 # replace with the number of classes in your dataset
model.fc = torch.nn.Linear(model.fc.in_features, num_classes)
Step 4: Load the custom dataset
Next, load the custom dataset that you want to use for training the fine-tuned model. In this example, I will assume that the dataset is organized in the following directory structure:
custom_dataset/
├── train/
│ ├── class1/
│ ├── class2/
│ ├── ...
├── val/
│ ├── class1/
│ ├── class2/
│ ├── ...
I will use the torchvision ImageFolder dataset class to load the dataset. It infers each image's label from the name of its parent directory.
from torchvision.datasets import ImageFolder
from torchvision import transforms
# Define the transformations to apply to the images
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    # Normalize with the ImageNet channel means and standard deviations
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
# Load the train and validation datasets
train_dataset = ImageFolder('custom_dataset/train', transform=transform)
val_dataset = ImageFolder('custom_dataset/val', transform=transform)
Step 5: Define the loss function and optimizer
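I will use the cross-entropy loss function and the stochastic gradient descent (SGD) optimizer for training the fine-tuned model. Only the parameters of the new final layer are passed to the optimizer, since the other layers are frozen.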
# Define the loss function and optimizer
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.fc.parameters(), lr=0.001, momentum=0.9)
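As a small worked example of what the loss computes, the sketch below compares CrossEntropyLoss on a single sample with the same quantity computed by hand. The logit values are made up for illustration. The point is that the loss expects raw scores (logits) and integer class indices, and applies the softmax internally.

```python
import math

import torch

criterion = torch.nn.CrossEntropyLoss()

# One sample with raw scores (logits) for three classes; the true class is index 0
raw_scores = [2.0, 0.5, -1.0]
logits = torch.tensor([raw_scores])
label = torch.tensor([0])
loss = criterion(logits, label).item()

# The same value by hand: -log(softmax(logits)[true_class])
exps = [math.exp(v) for v in raw_scores]
manual = -math.log(exps[0] / sum(exps))

print(loss, manual)
```

The two printed values should agree up to floating-point precision, which is why the model's final layer does not need its own softmax.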
Step 6: Train the model
Next, wrap the datasets in PyTorch’s DataLoader utility, which handles batching and shuffling, and then train the fine-tuned model on the custom dataset.
from torch.utils.data import DataLoader
# Create data loaders for the train and validation datasets
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)
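To illustrate how a DataLoader splits a dataset into batches, here is a minimal sketch that uses a synthetic TensorDataset as a stand-in for the image dataset. The sample count of 100 and the small 3x8x8 tensors are assumptions chosen to keep the example light.

```python
import math

import torch
from torch.utils.data import DataLoader, TensorDataset

# Synthetic stand-in for an image dataset: 100 small "images" with labels
images = torch.randn(100, 3, 8, 8)
labels = torch.randint(0, 10, (100,))
dataset = TensorDataset(images, labels)

loader = DataLoader(dataset, batch_size=32, shuffle=False)

# Collect the size of every batch the loader yields
batch_sizes = [inputs.size(0) for inputs, _ in loader]
print(batch_sizes)  # [32, 32, 32, 4]
print(len(loader) == math.ceil(len(dataset) / 32))
```

The last batch is smaller whenever the dataset size is not a multiple of the batch size, which matters when averaging per-batch losses over an epoch.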
Next, define a function to train the fine-tuned model for a specified number of epochs.
def train(model, train_loader, val_loader, criterion, optimizer, num_epochs):
    # Note: relies on the global `device`, which is defined in Step 7 before train() is called
    # Train the model for the specified number of epochs
    for epoch in range(num_epochs):
        # Set the model to train mode
        model.train()
        # Initialize the running loss and accuracy
        running_loss = 0.0
        running_corrects = 0
        # Iterate over the batches of the train loader
        for inputs, labels in train_loader:
            # Move the inputs and labels to the device
            inputs = inputs.to(device)
            labels = labels.to(device)
            # Zero the optimizer gradients
            optimizer.zero_grad()
            # Forward pass
            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)
            loss = criterion(outputs, labels)
            # Backward pass and optimizer step
            loss.backward()
            optimizer.step()
            # Update the running loss and accuracy
            running_loss += loss.item() * inputs.size(0)
            running_corrects += torch.sum(preds == labels).item()
        # Calculate the train loss and accuracy
        train_loss = running_loss / len(train_loader.dataset)
        train_acc = running_corrects / len(train_loader.dataset)
        # Set the model to evaluation mode
        model.eval()
        # Initialize the running loss and accuracy
        running_loss = 0.0
        running_corrects = 0
        # Iterate over the batches of the validation loader
        with torch.no_grad():
            for inputs, labels in val_loader:
                # Move the inputs and labels to the device
                inputs = inputs.to(device)
                labels = labels.to(device)
                # Forward pass
                outputs = model(inputs)
                _, preds = torch.max(outputs, 1)
                loss = criterion(outputs, labels)
                # Update the running loss and accuracy
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels).item()
        # Calculate the validation loss and accuracy
        val_loss = running_loss / len(val_loader.dataset)
        val_acc = running_corrects / len(val_loader.dataset)
        # Print the epoch results
        print('Epoch [{}/{}], train loss: {:.4f}, train acc: {:.4f}, val loss: {:.4f}, val acc: {:.4f}'
              .format(epoch+1, num_epochs, train_loss, train_acc, val_loss, val_acc))
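The train function multiplies each batch loss by the batch size before summing. This is because CrossEntropyLoss returns the mean over the batch by default, and the last batch can be smaller than the others. The sketch below works through the arithmetic with made-up loss values and batch sizes.

```python
# Hypothetical per-batch mean losses and batch sizes (the last batch is smaller)
batch_losses = [0.9, 0.7, 0.6, 2.0]
batch_sizes = [32, 32, 32, 4]

# Weighted by batch size, as in the train function
running_loss = sum(loss * n for loss, n in zip(batch_losses, batch_sizes))
per_sample_mean = running_loss / sum(batch_sizes)

# An unweighted mean of the batch means over-weights the small last batch
mean_of_batch_means = sum(batch_losses) / len(batch_losses)

print(per_sample_mean, mean_of_batch_means)
```

Here the weighted average is about 0.784, while the unweighted mean of batch means is about 1.05, because the four samples in the last batch would otherwise count as much as a full batch of 32.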
Step 7: Fine-tune the model on the custom dataset
Let’s now fine-tune the pre-trained ResNet-18 model on the custom dataset by training the last layer for a few epochs and then unfreezing all the layers and training the entire network for a few more epochs.
# Set the device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
# Fine-tune the last layer for a few epochs
# (this optimizer replaces the one from Step 5 and uses a higher learning rate)
optimizer = torch.optim.SGD(model.fc.parameters(), lr=0.01, momentum=0.9)
train(model, train_loader, val_loader, criterion, optimizer, num_epochs=5)
# Unfreeze all the layers and fine-tune the entire network for a few more epochs
for param in model.parameters():
    param.requires_grad = True
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
train(model, train_loader, val_loader, criterion, optimizer, num_epochs=10)
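The two-phase schedule can be checked on a toy model without downloading anything. The sketch below is an illustration under stated assumptions, not the tutorial's training run: a small two-layer network stands in for ResNet-18, the names backbone, head, and run_steps are hypothetical, and the data is random. It checks that the backbone weights stay fixed while frozen and change once they are unfrozen and the optimizer is rebuilt over all parameters.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Toy stand-in for ResNet-18: a "backbone" plus a classification "head"
backbone = nn.Sequential(nn.Linear(4, 8), nn.ReLU())
head = nn.Linear(8, 3)
model = nn.Sequential(backbone, head)

# Random data standing in for a batch of images and labels
inputs = torch.randn(16, 4)
labels = torch.randint(0, 3, (16,))
criterion = nn.CrossEntropyLoss()


def run_steps(optimizer, steps=5):
    """Run a few full-batch SGD steps on the toy data."""
    for _ in range(steps):
        optimizer.zero_grad()
        loss = criterion(model(inputs), labels)
        loss.backward()
        optimizer.step()


# Phase 1: freeze the backbone and train only the head
for param in backbone.parameters():
    param.requires_grad = False
backbone_before = backbone[0].weight.detach().clone()
run_steps(torch.optim.SGD(head.parameters(), lr=0.01, momentum=0.9))
backbone_frozen = torch.equal(backbone_before, backbone[0].weight.detach())

# Phase 2: unfreeze everything and rebuild the optimizer over all parameters
for param in model.parameters():
    param.requires_grad = True
run_steps(torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9))
backbone_updated = not torch.equal(backbone_before, backbone[0].weight.detach())

print(backbone_frozen, backbone_updated)
```

The second phase uses a lower learning rate, as in the step above, so that the pre-trained weights are adjusted gently rather than overwritten.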
Summary
In this tutorial, I described how to fine-tune a pre-trained ResNet-18 model for a domain-specific image classification task using PyTorch. I loaded the pre-trained model, froze the pre-trained layers, modified the last layer, loaded the custom dataset, defined the loss function and optimizer, and trained the fine-tuned model for a specified number of epochs. I also fine-tuned the last layer for a few epochs and then unfroze all the layers and fine-tuned the entire network for a few more epochs.
Fine-tuning pre-trained models is a powerful technique that can significantly reduce the time and computational resources required for training deep neural networks for specific tasks. It is also a common approach when labeled data is scarce or training time is limited.
The code snippets provided in this tutorial can serve as a starting point for fine-tuning pre-trained models on domain-specific image classification tasks. However, the approach may need to be adapted to different datasets and tasks, and hyperparameters may need to be tuned for optimal performance.