快速学会一个算法,UNet
UNet 是一种专门用于图像分割任务的卷积神经网络(CNN)架构,最早由 Olaf Ronneberger 等人在 2015 年提出。
Unet 模型架构
编码器(收缩路径)
每个卷积层通常包含两次卷积操作(使用 3x3 卷积核),每次卷积操作后接一个 ReLU 激活函数。 之后,采用一个 2x2 的最大池化层(Max Pooling)进行下采样,以减少特征图的空间维度。 每次下采样后,特征图的空间尺寸减小,而通道数增加,以提取更高层次的特征。
解码器(扩展路径)
上采样(Upsampling),通常通过反卷积将特征图的空间分辨率逐步恢复。 上采样后,通过跳跃连接(Skip Connection)将对应层的编码器特征与解码器特征拼接在一起,这样可以保留输入图像的细节。 拼接后的特征图经过两次卷积操作(同样使用 3x3 卷积核)和 ReLU 激活函数进行处理。 最终,经过逐步上采样和卷积,恢复到与输入图像相同的分辨率。
跳跃连接 (Skip Connections)
UNet模型的优点
高效处理小样本数据集
UNet 最初设计用于生物医学图像分割,具有高效利用小样本数据集的能力。 精细的分割结果 通过跳跃连接,UNet 能够很好地保留高分辨率的细节,使得分割结果更为精确。 灵活性强
UNet 结构简单且有效,容易扩展和调整,适应不同类型的分割任务。
案例分享
下面是一个使用 PyTorch 实现 UNet 模型的代码示例。这个示例展示了一个简化版的UNet模型,并应用于图像分割任务。
import torch
import torch.nn as nn
import torch.nn.functional as F
class UNet(nn.Module):
def __init__(self, in_channels, out_channels):
super(UNet, self).__init__()
# 编码器部分
self.encoder1 = self.double_conv(in_channels, 64)
self.encoder2 = self.double_conv(64, 128)
self.encoder3 = self.double_conv(128, 256)
self.encoder4 = self.double_conv(256, 512)
# 最底部的卷积
self.bottleneck = self.double_conv(512, 1024)
# 解码器部分
self.upconv4 = self.upconv(1024, 512)
self.decoder4 = self.double_conv(1024, 512)
self.upconv3 = self.upconv(512, 256)
self.decoder3 = self.double_conv(512, 256)
self.upconv2 = self.upconv(256, 128)
self.decoder2 = self.double_conv(256, 128)
self.upconv1 = self.upconv(128, 64)
self.decoder1 = self.double_conv(128, 64)
# 最终的1x1卷积,用于生成分割图
self.final_conv = nn.Conv2d(64, out_channels, kernel_size=1)
def double_conv(self, in_channels, out_channels):
"""两次卷积操作"""
return nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True),
nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True)
)
def upconv(self, in_channels, out_channels):
"""上采样操作"""
return nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)
def forward(self, x):
# 编码器部分
enc1 = self.encoder1(x)
enc2 = self.encoder2(F.max_pool2d(enc1, kernel_size=2))
enc3 = self.encoder3(F.max_pool2d(enc2, kernel_size=2))
enc4 = self.encoder4(F.max_pool2d(enc3, kernel_size=2))
# Bottleneck
bottleneck = self.bottleneck(F.max_pool2d(enc4, kernel_size=2))
# 解码器部分
dec4 = self.upconv4(bottleneck)
dec4 = torch.cat((dec4, self.crop_tensor(enc4, dec4)), dim=1)
dec4 = self.decoder4(dec4)
dec3 = self.upconv3(dec4)
dec3 = torch.cat((dec3, self.crop_tensor(enc3, dec3)), dim=1)
dec3 = self.decoder3(dec3)
dec2 = self.upconv2(dec3)
dec2 = torch.cat((dec2, self.crop_tensor(enc2, dec2)), dim=1)
dec2 = self.decoder2(dec2)
dec1 = self.upconv1(dec2)
dec1 = torch.cat((dec1, self.crop_tensor(enc1, dec1)), dim=1)
dec1 = self.decoder1(dec1)
# 最后的1x1卷积生成输出
return self.final_conv(dec1)
def crop_tensor(self, encoder_tensor, decoder_tensor):
"""裁剪编码器张量,使其与解码器张量大小匹配"""
_, _, H, W = decoder_tensor.size()
encoder_tensor = self.center_crop(encoder_tensor, H, W)
return encoder_tensor
def center_crop(self, tensor, target_height, target_width):
"""中心裁剪函数"""
_, _, h, w = tensor.size()
crop_y = (h - target_height) // 2
crop_x = (w - target_width) // 2
return tensor[:, :, crop_y:crop_y + target_height, crop_x:crop_x + target_width]
# 使用示例
model = UNet(in_channels=1, out_channels=1) # 输入和输出均为1通道(例如用于灰度图像)
input_image = torch.randn(1, 1, 572, 572) # 随机生成一个输入图像
output = model(input_image)
print(output.shape)
最后
—
「进群方式:加我微信,备注 “python”」
往期回顾
Fashion-MNIST 服装图片分类-Pytorch实现