Deep Learning Approaches in Medical Image Segmentation
Medical imaging has been revolutionized by the adoption of deep learning techniques. This branch of machine learning has ushered in a new era of precision and efficiency in medical image segmentation, a central analytical process in modern healthcare diagnostics and treatment planning. By harnessing neural networks, deep learning algorithms can detect and delineate anomalies in medical images with remarkable accuracy.
This technological leap is reshaping how we approach medical image analysis. From improving early disease detection to facilitating personalized treatment strategies, deep learning in medical image segmentation is paving the way for more targeted and effective patient care. In this article, we will delve into the methods that deep learning brings to medical image segmentation, exploring how these algorithms are pushing the boundaries of what’s possible in medical imaging and, by extension, in healthcare itself.
Introduction to Medical Image Segmentation
Medical image segmentation involves dividing an image into regions, each representing a specific structure or feature, such as an organ or a tumor. This process is central to understanding and analyzing medical images: it helps doctors diagnose diseases more accurately, plan treatments, and track how a patient’s condition changes over time.
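To make this concrete, a segmentation result can be viewed as a per-pixel label map. Below is a toy example in Python, with hypothetical class codes (0 = background, 1 = organ, 2 = tumor) chosen purely for illustration:

import numpy as np

# A toy 4x4 segmentation mask: each pixel carries a class label
# (hypothetical classes: 0 = background, 1 = organ, 2 = tumor)
mask = np.array([
    [0, 0, 1, 1],
    [0, 1, 1, 1],
    [0, 1, 2, 1],
    [0, 0, 1, 0],
])
print(np.unique(mask))  # classes present in the mask: [0 1 2]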
Common Deep Learning Architectures for Image Segmentation
Let’s start by looking at some of the most common deep learning architectures for image segmentation.
1. U-Net
U-Net has a “U” shape with an encoder that captures context and a decoder that enables precise localization. Skip connections pass fine-grained detail from each encoder layer directly to the corresponding decoder layer, preserving information that would otherwise be lost to downsampling. U-Net is widely used to segment organs, brain tumors, lung nodules, and other key structures in MRI and CT scans.
2. Fully Convolutional Networks (FCNs)
FCNs replace the fully connected layers of a classification network with convolutional layers, enabling the model to produce dense segmentation maps. Up-sampling restores the spatial dimensions of the input image so that each pixel can be classified individually. For example, FCNs are used to find brain tumors in MRI scans and to delineate the liver in CT images.
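As a rough illustration (a minimal sketch, not the original FCN-8s architecture, and with illustrative layer sizes), an FCN-style head in Keras classifies every spatial position with a 1×1 convolution and then up-samples the coarse score map back to the input resolution:

from tensorflow.keras import layers, Model, Input

inputs = Input((256, 256, 1))
x = layers.Conv2D(32, 3, activation='relu', padding='same')(inputs)
x = layers.MaxPooling2D(4)(x)                # coarse feature map (64x64)
x = layers.Conv2D(3, 1)(x)                   # 1x1 conv: per-position class scores
x = layers.Conv2DTranspose(3, 8, strides=4, padding='same')(x)  # back to 256x256
outputs = layers.Activation('softmax')(x)    # per-pixel class probabilities
fcn = Model(inputs, outputs)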
3. SegNet
SegNet balances performance with computational efficiency. Its encoder-decoder design first reduces the image size and then enlarges it again to create detailed segmentation maps. SegNet stores max-pooling indices during encoding and reuses them during decoding to improve accuracy. It is used for segmenting retinal blood vessels, lung fields in X-rays, and other structures where efficiency is important.
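To illustrate the pooling-index idea, here is a minimal TensorFlow sketch (a toy example, not the full SegNet): tf.nn.max_pool_with_argmax records where each maximum came from, and scattering the pooled values back to those positions performs the index-based up-sampling that a SegNet decoder reuses.

import tensorflow as tf

# Toy feature map: batch of 1, 4x4, one channel (batch size 1 keeps the
# default flat indexing of max_pool_with_argmax valid for this sketch)
x = tf.random.normal([1, 4, 4, 1])

# Pool and record the flat index of each maximum
pooled, argmax = tf.nn.max_pool_with_argmax(x, ksize=2, strides=2, padding='SAME')

# "Unpool" by scattering pooled values back to their recorded positions
flat_size = tf.size(x, out_type=tf.int64)
unpooled = tf.scatter_nd(tf.reshape(argmax, [-1, 1]),
                         tf.reshape(pooled, [-1]),
                         [flat_size])
unpooled = tf.reshape(unpooled, tf.shape(x))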
4. DeepLab
DeepLab uses atrous (dilated) convolutions to expand the receptive field without sacrificing spatial resolution. Its atrous spatial pyramid pooling (ASPP) module captures features at multiple scales, which helps the model segment structures of varying sizes. DeepLab is used for tasks such as finding brain tumors, liver lesions, and cardiac structures in MRI scans.
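As a sketch, atrous convolutions are available in Keras through the dilation_rate argument; an ASPP-style block (simplified here, with an illustrative filter count) runs several of them in parallel and concatenates the results:

from tensorflow.keras import layers

def aspp_block(x, filters=64):
    # Parallel atrous convolutions sample context at multiple scales;
    # the dilation rates follow the commonly cited 6/12/18 pattern
    branches = [
        layers.Conv2D(filters, 3, dilation_rate=rate, padding='same',
                      activation='relu')(x)
        for rate in (1, 6, 12, 18)
    ]
    return layers.Concatenate()(branches)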
Example: U-Net Lung Tumor Segmentation
Let’s now have a look at a step-by-step example of lung tumor segmentation using a U-Net model.
1. Mount Google Drive
First we will mount Google Drive to access files stored in it.
from google.colab import drive
drive.mount('/content/drive')
2. Define Folder Paths
Now we set the paths for the folders containing images and labels in Google Drive.
# Define paths to the folders in Google Drive
image_folder_path = "/content/drive/My Drive/Dataset/Lung dataset"
label_folder_path = "/content/drive/My Drive/Dataset/Ground truth"
3. Collect PNG Files
Next, define a function to gather and sort all PNG file paths from a specified folder.
import os

# Function to collect and sort PNG file paths from a folder
def collect_png_from_folder(folder_path):
    png_files = []
    # Walk the folder tree and keep every .png file path
    for root, _, files in os.walk(folder_path):
        for file in files:
            if file.endswith(".png"):
                png_files.append(os.path.join(root, file))
    return sorted(png_files)
4. Load and Preprocess Dataset
Next we will define a function to load and preprocess images and labels from their respective folders. This function ensures that the images and labels are correctly matched and resized.
import cv2
import numpy as np

# Function to load images and labels directly
def load_images_and_labels(image_folder_path, label_folder_path, target_size=(256, 256), filter_size=3):
    # Collect file paths
    image_files = collect_png_from_folder(image_folder_path)
    label_files = collect_png_from_folder(label_folder_path)
    # Ensure images and labels are sorted and match in number
    if len(image_files) != len(label_files):
        raise ValueError("Number of images and labels do not match.")

    # Load a scan as grayscale, resize, denoise, and scale to [0, 1]
    def load_image(image_path):
        image = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
        if image is None:
            raise ValueError(f"Unable to load image: {image_path}")
        image = cv2.resize(image, target_size)
        image = cv2.medianBlur(image, filter_size)
        return image.astype('float32') / 255.0

    # Load a label map as a color image and resize it
    def load_label(label_path):
        label = cv2.imread(label_path, cv2.IMREAD_COLOR)
        if label is None:
            raise ValueError(f"Unable to load label image: {label_path}")
        return cv2.resize(label, target_size)

    images = np.array([load_image(path) for path in image_files])
    labels = np.array([load_label(path) for path in label_files])
    return images, labels
5. Display Images and Labels
Now we will define a function to display a specified number of images and their corresponding labels side by side. We then use the previously defined function to load the images and labels, and display a few samples for visualization. The blue spots represent the tumor labels.
import matplotlib.pyplot as plt

# Function to display images and labels side by side
def display_images_and_labels(images, labels, num_samples=5):
    num_samples = min(num_samples, len(images))
    plt.figure(figsize=(15, 3 * num_samples))
    for i in range(num_samples):
        plt.subplot(num_samples, 2, 2 * i + 1)
        plt.title(f'Image {i + 1}')
        plt.imshow(images[i], cmap='gray')
        plt.axis('off')
        plt.subplot(num_samples, 2, 2 * i + 2)
        plt.title(f'Label {i + 1}')
        plt.imshow(labels[i])
        plt.axis('off')
    plt.tight_layout()
    plt.show()

# Load images and labels
images, labels = load_images_and_labels(image_folder_path, label_folder_path)

# Display a few samples
display_images_and_labels(images, labels, num_samples=5)
6. Define the U-Net model
Now it’s time to define the U-Net model. The model is compiled with the Adam optimizer, categorical cross-entropy as the loss function, and accuracy as the evaluation metric.
from tensorflow.keras.models import Model
from tensorflow.keras.layers import (Input, Conv2D, Conv2DTranspose,
                                     MaxPooling2D, Dropout, concatenate)

# Define the U-Net model
def unet_model(input_size=(256, 256, 1), num_classes=3):
    inputs = Input(input_size)
    # Encoder (Downsampling Path)
    c1 = Conv2D(64, (3, 3), activation='relu', kernel_initializer="he_normal", padding='same')(inputs)
    c1 = Dropout(0.1)(c1)
    c1 = Conv2D(64, (3, 3), activation='relu', kernel_initializer="he_normal", padding='same')(c1)
    p1 = MaxPooling2D((2, 2))(c1)
    c2 = Conv2D(128, (3, 3), activation='relu', kernel_initializer="he_normal", padding='same')(p1)
    c2 = Dropout(0.1)(c2)
    c2 = Conv2D(128, (3, 3), activation='relu', kernel_initializer="he_normal", padding='same')(c2)
    p2 = MaxPooling2D((2, 2))(c2)
    c3 = Conv2D(256, (3, 3), activation='relu', kernel_initializer="he_normal", padding='same')(p2)
    c3 = Dropout(0.2)(c3)
    c3 = Conv2D(256, (3, 3), activation='relu', kernel_initializer="he_normal", padding='same')(c3)
    p3 = MaxPooling2D((2, 2))(c3)
    c4 = Conv2D(512, (3, 3), activation='relu', kernel_initializer="he_normal", padding='same')(p3)
    c4 = Dropout(0.2)(c4)
    c4 = Conv2D(512, (3, 3), activation='relu', kernel_initializer="he_normal", padding='same')(c4)
    p4 = MaxPooling2D(pool_size=(2, 2))(c4)
    # Bottleneck
    c5 = Conv2D(1024, (3, 3), activation='relu', kernel_initializer="he_normal", padding='same')(p4)
    c5 = Dropout(0.3)(c5)
    c5 = Conv2D(1024, (3, 3), activation='relu', kernel_initializer="he_normal", padding='same')(c5)
    # Decoder (Upsampling Path) with skip connections from the encoder
    u6 = Conv2DTranspose(512, (2, 2), strides=(2, 2), padding='same')(c5)
    u6 = concatenate([u6, c4])
    c6 = Conv2D(512, (3, 3), activation='relu', kernel_initializer="he_normal", padding='same')(u6)
    c6 = Dropout(0.2)(c6)
    c6 = Conv2D(512, (3, 3), activation='relu', kernel_initializer="he_normal", padding='same')(c6)
    u7 = Conv2DTranspose(256, (2, 2), strides=(2, 2), padding='same')(c6)
    u7 = concatenate([u7, c3])
    c7 = Conv2D(256, (3, 3), activation='relu', kernel_initializer="he_normal", padding='same')(u7)
    c7 = Dropout(0.2)(c7)
    c7 = Conv2D(256, (3, 3), activation='relu', kernel_initializer="he_normal", padding='same')(c7)
    u8 = Conv2DTranspose(128, (2, 2), strides=(2, 2), padding='same')(c7)
    u8 = concatenate([u8, c2])
    c8 = Conv2D(128, (3, 3), activation='relu', kernel_initializer="he_normal", padding='same')(u8)
    c8 = Dropout(0.1)(c8)
    c8 = Conv2D(128, (3, 3), activation='relu', kernel_initializer="he_normal", padding='same')(c8)
    u9 = Conv2DTranspose(64, (2, 2), strides=(2, 2), padding='same')(c8)
    u9 = concatenate([u9, c1], axis=3)
    c9 = Conv2D(64, (3, 3), activation='relu', kernel_initializer="he_normal", padding='same')(u9)
    c9 = Dropout(0.1)(c9)
    c9 = Conv2D(64, (3, 3), activation='relu', kernel_initializer="he_normal", padding='same')(c9)
    # Output layer: per-pixel softmax over the classes
    outputs = Conv2D(num_classes, (1, 1), activation='softmax')(c9)
    model = Model(inputs=[inputs], outputs=[outputs])
    # Compile the model
    model.compile(optimizer="adam",
                  loss="categorical_crossentropy",
                  metrics=['accuracy'])
    return model
7. Training the U-Net Model
Here we prepare the inputs, split the data into training, validation, and test sets, train the U-Net model with early stopping, and save it to a file. Training and validation accuracy and loss over epochs are then plotted to visualize the model’s performance. The saved model can later be used for testing on new data.
from sklearn.model_selection import train_test_split
from tensorflow.keras.callbacks import EarlyStopping

# Prepare model inputs and targets (assumption: the three label channels
# are scaled to [0, 1] and used as per-pixel targets for the 3-class softmax)
X = images[..., np.newaxis]            # add a channel axis: (N, 256, 256, 1)
Y = labels.astype('float32') / 255.0   # scale label maps to [0, 1]

# Split the data into training, validation, and test sets (60/20/20)
X_train, X_temp, y_train, y_temp = train_test_split(X, Y, test_size=0.4, random_state=42)
X_val, X_test, y_val, y_test = train_test_split(X_temp, y_temp, test_size=0.5, random_state=42)

# Build the model and define the EarlyStopping callback
model = unet_model()
early_stopping = EarlyStopping(monitor="val_loss", patience=3, restore_best_weights=True)

# Train the model with EarlyStopping
history = model.fit(X_train, y_train,
                    epochs=50,
                    batch_size=16,
                    validation_data=(X_val, y_val),
                    callbacks=[early_stopping])

# Save the trained model
model.save('/content/unet_real_data.h5')
# Function to plot training and validation accuracy
def plot_accuracy(history):
    epochs = range(1, len(history.history['accuracy']) + 1)
    plt.figure(figsize=(6, 4))
    plt.plot(epochs, history.history['accuracy'], 'bo-', label="Training Accuracy")
    plt.plot(epochs, history.history['val_accuracy'], 'ro-', label="Validation Accuracy")
    plt.title('Training and Validation Accuracy')
    plt.xlabel('Epochs')
    plt.ylabel('Accuracy')
    plt.legend()
    plt.tight_layout()
    plt.show()

# Function to plot training and validation loss
def plot_loss(history):
    epochs = range(1, len(history.history['loss']) + 1)
    plt.figure(figsize=(6, 4))
    plt.plot(epochs, history.history['loss'], 'bo-', label="Training Loss")
    plt.plot(epochs, history.history['val_loss'], 'ro-', label="Validation Loss")
    plt.title('Training and Validation Loss')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend()
    plt.tight_layout()
    plt.show()

# Call the functions to plot accuracy and loss
plot_accuracy(history)
plot_loss(history)
Figure 1: Training and validation accuracy plot
Figure 2: Training and validation loss plot
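Finally, since the trained model was saved above, testing on held-out data could look like the following sketch (reusing X_test and y_test from the split in step 7):

from tensorflow.keras.models import load_model

# Reload the saved model and evaluate it on the held-out test set
model = load_model('/content/unet_real_data.h5')
test_loss, test_acc = model.evaluate(X_test, y_test)
print(f"Test accuracy: {test_acc:.3f}")

# Predict on one test image and take the most likely class per pixel
pred = model.predict(X_test[:1])
pred_mask = np.argmax(pred[0], axis=-1)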
Advantages of Deep Learning in Medical Segmentation
The advantages of deep learning in medical segmentation are numerous. Here are a few of the important ones:
- Improved Accuracy: Deep learning models segment medical images with high accuracy. They can find and outline small or subtle structures that older methods might miss.
- Efficiency and Speed: These models can process and analyze large numbers of images quickly. They speed up the segmentation workflow and reduce the manual effort required.
- Handling Complex Data: Deep learning models can work with complex 3D images from CT or MRI scans. They can handle different types of images and adapt to various imaging techniques.
Challenges of Deep Learning in Medical Image Segmentation
Just as there are advantages, we must also keep in mind the challenges of using the technology.
- Limited Data: There aren’t always enough labeled medical images to train deep learning models, because creating these labels is time-consuming and requires skilled experts. This makes it hard to gather sufficient training data.
- Privacy Concerns: Medical images include sensitive patient information, so there are strict rules to keep this data private. This means there might not be as much data available for research and training.
- Interpretability: Deep learning models can be tricky to understand. This makes it difficult to trust and validate their results.
Conclusion
In summary, deep learning has made medical image segmentation much better. Architectures like U-Net, FCNs, SegNet, and DeepLab improve how we analyze images. This leads to more accurate diagnoses and better patient care.
Jayita Gulati is a machine learning enthusiast and technical writer driven by her passion for building machine learning models. She holds a Master’s degree in Computer Science from the University of Liverpool.