当前位置: 首页 > news >正文

详细介绍:图像分割:PyTorch从零开始实现SegFormer语义分割

详细介绍:图像分割:PyTorch从零开始实现SegFormer语义分割

图像分割:PyTorch从零开始实现SegFormer语义分割

  • 前言
  • 环境要求
  • 相关介绍
  • SegFormer核心模块:编码器(MiT)和解码器(All-MLP解码器)。
    • 编码器(MiT):
      • 分层结构,产生多尺度特征(通常有4个阶段,每个阶段特征图尺寸递减)。
      • 每个阶段由多个Transformer块组成,每个块包含:
        • 重叠块嵌入(Overlapped Patch Embedding)
        • 高效自注意力(Efficient Self-Attention)
        • 混合前馈网络(Mix FeedForward Network)
    • 解码器(All-MLP):
      • 将多尺度特征上采样到相同尺寸并拼接。
      • 通过多层感知机(MLP)得到分割结果。
  • 具体实现
    • 导入相关库
    • 准备数据集
    • 定义网络模型
    • 训练验证
    • 推理预测
    • 主函数
    • 输出结果
    • 完整代码
  • 参考

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

前言

环境要求

Package                Version      Editable project location
---------------------- ------------ ----------------------------------------------
addict                 2.4.0
aliyun-python-sdk-core 2.16.0
aliyun-python-sdk-kms  2.16.5
certifi                2025.8.3
cffi                   2.0.0
charset-normalizer     3.4.3
click                  8.3.0
colorama               0.4.6
contourpy              1.3.2
crcmod                 1.7
cryptography           46.0.1
cycler                 0.12.1
einops                 0.8.1
filelock               3.14.0
fonttools              4.60.0
fsspec                 2025.9.0
ftfy                   6.3.1
huggingface-hub        0.35.1
idna                   3.10
jmespath               0.10.0
kiwisolver             1.4.9
Markdown               3.9
markdown-it-py         4.0.0
matplotlib             3.10.6
mdurl                  0.1.2
mmcv                   2.1.0
mmcv-full              1.2.7
mmengine               0.10.7
mmsegmentation         0.11.0
model-index            0.1.11
numpy                  1.26.3
opencv-python          4.6.0.66
opendatalab            0.0.10
openmim                0.3.9
openxlab               0.1.2
ordered-set            4.1.0
oss2                   2.17.0
packaging              24.2
pandas                 2.3.2
pillow                 11.3.0
pip                    23.0.1
platformdirs           4.4.0
polars                 1.33.1
prettytable            3.16.0
psutil                 7.1.0
pycparser              2.23
pycryptodome           3.23.0
Pygments               2.19.2
pyparsing              3.2.5
python-dateutil        2.9.0.post0
pytz                   2023.4
pywin32                311
PyYAML                 6.0.3
regex                  2025.9.18
requests               2.28.2
rich                   13.4.2
safetensors            0.6.2
scipy                  1.15.3
setuptools             60.2.0
six                    1.17.0
tabulate               0.9.0
termcolor              3.1.0
terminaltables         3.1.10
timm                   1.0.20
tomli                  2.2.1
torch                  1.13.1+cu116
torchaudio             0.13.1+cu116
torchvision            0.14.1+cu116
tqdm                   4.65.2
typing_extensions      4.15.0
tzdata                 2025.2
ultralytics            8.3.203
ultralytics-thop       2.0.17
urllib3                1.26.20
wcwidth                0.2.14
yapf                   0.43.0

相关介绍

  • Python是一种跨平台的计算机程序设计语言。是一个高层次的结合了解释性、编译性、互动性和面向对象的脚本语言。最初被设计用于编写自动化脚本(shell),随着版本的不断更新和语言新功能的添加,越多被用于独立的、大型项目的开发。
  • PyTorch 是一个深度学习框架,封装好了很多网络和深度学习相关的工具方便我们调用,而不用我们一个个去单独写了。它分为 CPU 和 GPU 版本,其他框架还有 TensorFlow、Caffe 等。PyTorch 是由 Facebook 人工智能研究院(FAIR)基于 Torch 推出的,它是一个基于 Python 的可续计算包,提供两个高级功能:1、具有强大的 GPU 加速的张量计算(如 NumPy);2、构建深度神经网络时的自动微分机制。
  • SegFormer 是一个简单、高效但功能强大的语义分割框架,它将 Transformers 与轻量级多层感知器 (MLP) 解码器结合在一起。
  • SegFormer 有两个吸引人的特点:
    1. SegFormer 包含一个新颖的分层结构变换器编码器,可输出多尺度特征。它不需要位置编码,从而避免了位置编码的插值,当测试分辨率与训练分辨率不同时,插值会导致性能下降。
    2. SegFormer 避免了复杂的解码器。所提出的 MLP 解码器汇聚了来自不同层的信息,从而将局部注意力和全局注意力结合起来,呈现出强大的表征。
  • 这种简单轻便的设计是在 Transformers 上实现高效分割的关键。通过扩展,获得了从 SegFormer-B0 到 SegFormer-B5 的一系列模型,其性能和效率明显优于之前的同类产品。
  • 例如,SegFormer-B4 在 64M 参数的 ADE20K 上实现了 50.3% 的 mIoU,比之前的最佳方法小 5 倍,好 2.2%。最佳模型 SegFormer-B5 在 Cityscapes 验证集上实现了 84.0% 的 mIoU,并在 Cityscapes-C 上显示了出色的零点稳健性。
  • 官方源代码: https://github.com/NVlabs/SegFormer.git
  • Enze Xie, Wenhai Wang, Zhiding Yu, Anima Anandkumar, Jose M. Alvarez, Ping Luo. SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers. 2021
    在这里插入图片描述

SegFormer核心模块:编码器(MiT)和解码器(All-MLP解码器)。

在这里插入图片描述

class Segformer(nn.Module):
def __init__(
self,
*,
dims=(32, 64, 160, 256),
heads=(1, 2, 5, 8),
ff_expansion=(8, 8, 4, 4),
reduction_ratio=(8, 4, 2, 1),
num_layers=2,
channels=3,
decoder_dim=256,
num_classes=4
):
super().__init__()
dims, heads, ff_expansion, reduction_ratio, num_layers = map(partial(cast_tuple, depth=4), (dims, heads, ff_expansion, reduction_ratio, num_layers))
assert all([*map(lambda t: len(t) == 4, (dims, heads, ff_expansion, reduction_ratio, num_layers))]), 'only four stages are allowed, all keyword arguments must be either a single value or a tuple of 4 values'
self.mit = MiT(
channels=channels,
dims=dims,
heads=heads,
ff_expansion=ff_expansion,
reduction_ratio=reduction_ratio,
num_layers=num_layers
)
self.to_fused = nn.ModuleList([nn.Sequential(
nn.Conv2d(dim, decoder_dim, 1),
nn.Upsample(scale_factor=2 ** i)
) for i, dim in enumerate(dims)])
self.to_segmentation = nn.Sequential(
nn.Conv2d(4 * decoder_dim, decoder_dim, 1),
nn.Conv2d(decoder_dim, num_classes, 1),
)
def forward(self, x):
H, W = x.shape[-2:]  # 原始输入高宽
layer_outputs = self.mit(x, return_layer_outputs=True)
fused = [to_fused(output) for output, to_fused in zip(layer_outputs, self.to_fused)]
fused = torch.cat(fused, dim=1)
out = self.to_segmentation(fused)
# 关键修复:上采样到原始输入尺寸
out = F.interpolate(out, size=(H, W), mode='bilinear', align_corners=False)
return out

编码器(MiT):

  • 论文中的MiT:
    • 分层设计的Transformer编码器
    • 4个阶段,每个阶段下采样2倍
    • 使用重叠块嵌入(Overlapped Patch Embedding)
class MiT(nn.Module):
def __init__(
self,
*,
channels,
dims,
heads,
ff_expansion,
reduction_ratio,
num_layers
):
super().__init__()
# 四个阶段的下采样配置
stage_kernel_stride_pad = ((7, 4, 3), (3, 2, 1), (3, 2, 1), (3, 2, 1))
# 对应论文中的阶段1-4
dims = (channels, *dims)
dim_pairs = list(zip(dims[:-1], dims[1:]))
self.stages = nn.ModuleList([])
for (dim_in, dim_out), (kernel, stride, padding), num_layers, ff_expansion, heads, reduction_ratio in zip(dim_pairs, stage_kernel_stride_pad, num_layers, ff_expansion, heads, reduction_ratio):
get_overlap_patches = nn.Unfold(kernel, stride=stride, padding=padding)
overlap_patch_embed = nn.Conv2d(dim_in * kernel ** 2, dim_out, 1)
layers = nn.ModuleList([])
for _ in range(num_layers):
layers.append(nn.ModuleList([
PreNorm(dim_out, EfficientSelfAttention(dim=dim_out, heads=heads, reduction_ratio=reduction_ratio)),
PreNorm(dim_out, MixFeedForward(dim=dim_out, expansion_factor=ff_expansion)),
]))
self.stages.append(nn.ModuleList([
get_overlap_patches,
overlap_patch_embed,
layers
]))
def forward(
self,
x,
return_layer_outputs=False
):
h, w = x.shape[-2:]
layer_outputs = []
for (get_overlap_patches, overlap_embed, layers) in self.stages:
x = get_overlap_patches(x)
num_patches = x.shape[-1]
ratio = int(sqrt((h * w) / num_patches))
x = rearrange(x, 'b c (h w) -> b c h w', h=h // ratio)
x = overlap_embed(x)
for (attn, ff) in layers:
x = attn(x) + x
x = ff(x) + x
layer_outputs.append(x)
ret = x if not return_layer_outputs else layer_outputs
return ret

分层结构,产生多尺度特征(通常有4个阶段,每个阶段特征图尺寸递减)。

class MiT(nn.Module):
def __init__(self, *, channels, dims, heads, ff_expansion, reduction_ratio, num_layers):
# 四个阶段的下采样配置
stage_kernel_stride_pad = ((7, 4, 3), (3, 2, 1), (3, 2, 1), (3, 2, 1))
# 对应论文中的阶段1-4

每个阶段由多个Transformer块组成,每个块包含:

重叠块嵌入(Overlapped Patch Embedding)
get_overlap_patches = nn.Unfold(kernel, stride=stride, padding=padding)
overlap_patch_embed = nn.Conv2d(dim_in * kernel ** 2, dim_out, 1)
高效自注意力(Efficient Self-Attention)
  • 论文创新点:
    • 序列缩减机制,降低计算复杂度
    • 使用reduction_ratio对K,V进行下采样
class EfficientSelfAttention(nn.Module):
def __init__(
self,
*,
dim,
heads,
reduction_ratio
):
super().__init__()
self.scale = (dim // heads) ** -0.5
self.heads = heads
self.to_q = nn.Conv2d(dim, dim, 1, bias=False)
self.to_kv = nn.Conv2d(dim, dim * 2, reduction_ratio, stride=reduction_ratio, bias=False) # 关键:序列缩减
self.to_out = nn.Conv2d(dim, dim, 1, bias=False)
def forward(self, x):
h, w = x.shape[-2:]
heads = self.heads
q, k, v = (self.to_q(x), *self.to_kv(x).chunk(2, dim=1))
q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> (b h) (x y) c', h=heads), (q, k, v))
sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
attn = sim.softmax(dim=-1)
out = einsum('b i j, b j d -> b i d', attn, v)
out = rearrange(out, '(b h) (x y) c -> b (h c) x y', h=heads, x=h, y=w)
return self.to_out(out)
混合前馈网络(Mix FeedForward Network)
  • 论文创新点:
    • 使用3×3深度可分离卷积增强局部特征提取
    • 替换标准MLP
class MixFeedForward(nn.Module):
def __init__(
self,
*,
dim,
expansion_factor
):
super().__init__()
hidden_dim = dim * expansion_factor
self.net = nn.Sequential(
nn.Conv2d(dim, hidden_dim, 1), # 升维
DsConv2d(hidden_dim, hidden_dim, 3, padding=1), # 深度可分离卷积
nn.GELU(),
nn.Conv2d(hidden_dim, dim, 1) # 降维
)
def forward(self, x):
return self.net(x)

解码器(All-MLP):

  • 论文创新点:
    • 简单的MLP结构,无需复杂设计
    • 多尺度特征融合

将多尺度特征上采样到相同尺寸并拼接。

# 多尺度特征融合
self.to_fused = nn.ModuleList([nn.Sequential(
nn.Conv2d(dim, decoder_dim, 1), # 统一通道数
nn.Upsample(scale_factor=2 ** i) # 上采样到1/4尺度
) for i, dim in enumerate(dims)])

通过多层感知机(MLP)得到分割结果。

self.to_segmentation = nn.Sequential(
nn.Conv2d(4 * decoder_dim, decoder_dim, 1), # 特征融合
nn.Conv2d(decoder_dim, num_classes, 1), # 分类头
)

具体实现

导入相关库

import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
import torchvision.transforms as transforms
from torchvision import models
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import cv2
from tqdm import tqdm
import torch.nn.functional as F
from math import sqrt
from functools import partial
import torch
from torch import nn, einsum
import torch.nn.functional as F
from einops import rearrange, reduce
from einops.layers.torch import Rearrange
# helpers
def exists(val):
return val is not None
def cast_tuple(val, depth):
return val if isinstance(val, tuple) else (val,) * depth

准备数据集

# ============== MockSegmentationDataset ==============
class MockSegmentationDataset(Dataset):
def __init__(self, size=256, num_samples=1000, num_classes=4):
self.size = size
self.num_samples = num_samples
self.num_classes = num_classes
# 图像变换
self.image_transform = transforms.Compose([
transforms.Resize((size, size)),
transforms.ToTensor(),
])
def __len__(self):
return self.num_samples
def __getitem__(self, idx):
# 使用固定模式而不是完全随机,让模型容易学习
rng = np.random.RandomState(idx)  # 固定随机种子,让数据可重复
# 生成更结构化的背景
img = np.full((self.size, self.size, 3), 128, dtype=np.uint8)  # 固定灰色背景
seg_map = np.zeros((self.size, self.size), dtype=np.uint8)
# 固定位置和尺寸的形状,减少随机性
positions = [
(self.size//4, self.size//4),      # 左上
(3*self.size//4, self.size//4),    # 右上  
(self.size//4, 3*self.size//4),    # 左下
(3*self.size//4, 3*self.size//4),  # 右下
]
# 为每个样本固定选择2个形状,确保类别平衡
shape_indices = [idx % 3 + 1, (idx + 1) % 3 + 1]  # 循环使用类别1,2,3
for i, cls in enumerate(shape_indices[:2]):  # 只画2个形状
pos = positions[i]
if cls == 1:  # 圆形
cv2.circle(seg_map, pos, 25, int(cls), -1)
cv2.circle(img, pos, 25, (255, 0, 0), -1)  # 红色
elif cls == 2:  # 矩形
pt1 = (pos[0]-25, pos[1]-20)
pt2 = (pos[0]+25, pos[1]+20)
cv2.rectangle(seg_map, pt1, pt2, int(cls), -1)
cv2.rectangle(img, pt1, pt2, (0, 255, 0), -1)  # 绿色
elif cls == 3:  # 椭圆
cv2.ellipse(seg_map, pos, (30, 15), 45, 0, 360, int(cls), -1)
cv2.ellipse(img, pos, (30, 15), 45, 0, 360, (0, 0, 255), -1)  # 蓝色
# 应用图像变换
img = Image.fromarray(img)
img = self.image_transform(img)
# 直接转换为tensor,不应用与图像相同的变换
seg_map = torch.from_numpy(seg_map).long()
return img, seg_map

定义网络模型

# ============== SegFormer模型定义 ==============
class DsConv2d(nn.Module):
def __init__(self, dim_in, dim_out, kernel_size, padding, stride=1, bias=True):
super().__init__()
self.net = nn.Sequential(
nn.Conv2d(dim_in, dim_in, kernel_size=kernel_size, padding=padding, groups=dim_in, stride=stride, bias=bias),
nn.Conv2d(dim_in, dim_out, kernel_size=1, bias=bias)
)
def forward(self, x):
return self.net(x)
class LayerNorm(nn.Module):
def __init__(self, dim, eps=1e-5):
super().__init__()
self.eps = eps
self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
self.b = nn.Parameter(torch.zeros(1, dim, 1, 1))
def forward(self, x):
std = torch.var(x, dim=1, unbiased=False, keepdim=True).sqrt()
mean = torch.mean(x, dim=1, keepdim=True)
return (x - mean) / (std + self.eps) * self.g + self.b
class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.fn = fn
self.norm = LayerNorm(dim)
def forward(self, x):
return self.fn(self.norm(x))
class EfficientSelfAttention(nn.Module):
def __init__(
self,
*,
dim,
heads,
reduction_ratio
):
super().__init__()
self.scale = (dim // heads) ** -0.5
self.heads = heads
self.to_q = nn.Conv2d(dim, dim, 1, bias=False)
self.to_kv = nn.Conv2d(dim, dim * 2, reduction_ratio, stride=reduction_ratio, bias=False) # 关键:序列缩减
self.to_out = nn.Conv2d(dim, dim, 1, bias=False)
def forward(self, x):
h, w = x.shape[-2:]
heads = self.heads
q, k, v = (self.to_q(x), *self.to_kv(x).chunk(2, dim=1))
q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> (b h) (x y) c', h=heads), (q, k, v))
sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
attn = sim.softmax(dim=-1)
out = einsum('b i j, b j d -> b i d', attn, v)
out = rearrange(out, '(b h) (x y) c -> b (h c) x y', h=heads, x=h, y=w)
return self.to_out(out)
class MixFeedForward(nn.Module):
def __init__(
self,
*,
dim,
expansion_factor
):
super().__init__()
hidden_dim = dim * expansion_factor
self.net = nn.Sequential(
nn.Conv2d(dim, hidden_dim, 1), # 升维
DsConv2d(hidden_dim, hidden_dim, 3, padding=1), # 深度可分离卷积
nn.GELU(),
nn.Conv2d(hidden_dim, dim, 1) # 降维
)
def forward(self, x):
return self.net(x)
class MiT(nn.Module):
def __init__(
self,
*,
channels,
dims,
heads,
ff_expansion,
reduction_ratio,
num_layers
):
super().__init__()
# 四个阶段的下采样配置
stage_kernel_stride_pad = ((7, 4, 3), (3, 2, 1), (3, 2, 1), (3, 2, 1))
# 对应论文中的阶段1-4
dims = (channels, *dims)
dim_pairs = list(zip(dims[:-1], dims[1:]))
self.stages = nn.ModuleList([])
for (dim_in, dim_out), (kernel, stride, padding), num_layers, ff_expansion, heads, reduction_ratio in zip(dim_pairs, stage_kernel_stride_pad, num_layers, ff_expansion, heads, reduction_ratio):
get_overlap_patches = nn.Unfold(kernel, stride=stride, padding=padding)
overlap_patch_embed = nn.Conv2d(dim_in * kernel ** 2, dim_out, 1)
layers = nn.ModuleList([])
for _ in range(num_layers):
layers.append(nn.ModuleList([
PreNorm(dim_out, EfficientSelfAttention(dim=dim_out, heads=heads, reduction_ratio=reduction_ratio)),
PreNorm(dim_out, MixFeedForward(dim=dim_out, expansion_factor=ff_expansion)),
]))
self.stages.append(nn.ModuleList([
get_overlap_patches,
overlap_patch_embed,
layers
]))
def forward(
self,
x,
return_layer_outputs=False
):
h, w = x.shape[-2:]
layer_outputs = []
for (get_overlap_patches, overlap_embed, layers) in self.stages:
x = get_overlap_patches(x)
num_patches = x.shape[-1]
ratio = int(sqrt((h * w) / num_patches))
x = rearrange(x, 'b c (h w) -> b c h w', h=h // ratio)
x = overlap_embed(x)
for (attn, ff) in layers:
x = attn(x) + x
x = ff(x) + x
layer_outputs.append(x)
ret = x if not return_layer_outputs else layer_outputs
return ret
class Segformer(nn.Module):
def __init__(
self,
*,
dims=(32, 64, 160, 256),
heads=(1, 2, 5, 8),
ff_expansion=(8, 8, 4, 4),
reduction_ratio=(8, 4, 2, 1),
num_layers=2,
channels=3,
decoder_dim=256,
num_classes=4
):
super().__init__()
dims, heads, ff_expansion, reduction_ratio, num_layers = map(partial(cast_tuple, depth=4), (dims, heads, ff_expansion, reduction_ratio, num_layers))
assert all([*map(lambda t: len(t) == 4, (dims, heads, ff_expansion, reduction_ratio, num_layers))]), 'only four stages are allowed, all keyword arguments must be either a single value or a tuple of 4 values'
self.mit = MiT(
channels=channels,
dims=dims,
heads=heads,
ff_expansion=ff_expansion,
reduction_ratio=reduction_ratio,
num_layers=num_layers
)
# 多尺度特征融合
self.to_fused = nn.ModuleList([nn.Sequential(
nn.Conv2d(dim, decoder_dim, 1), # 统一通道数
nn.Upsample(scale_factor=2 ** i) # 上采样到1/4尺度
) for i, dim in enumerate(dims)])
self.to_segmentation = nn.Sequential(
nn.Conv2d(4 * decoder_dim, decoder_dim, 1), # 特征融合
nn.Conv2d(decoder_dim, num_classes, 1), # 分类头
)
def forward(self, x):
H, W = x.shape[-2:]  # 原始输入高宽
layer_outputs = self.mit(x, return_layer_outputs=True)
fused = [to_fused(output) for output, to_fused in zip(layer_outputs, self.to_fused)]
fused = torch.cat(fused, dim=1)
out = self.to_segmentation(fused)
# 关键修复:上采样到原始输入尺寸
out = F.interpolate(out, size=(H, W), mode='bilinear', align_corners=False)
return out

训练验证

# ============== 训练函数 ==============
def get_segformer(model_name='b0', num_classes=4, decoder_dim=256):
config = {
'b0': dict(dims=(32, 64, 160, 256), num_layers=(2, 2, 2, 2)),
'b1': dict(dims=(64, 128, 320, 512), num_layers=(2, 2, 2, 2)),
'b2': dict(dims=(64, 128, 320, 512), num_layers=(3, 4, 6, 3)),
'b3': dict(dims=(64, 128, 320, 512), num_layers=(3, 4, 18, 3)),
'b4': dict(dims=(64, 128, 320, 512), num_layers=(3, 8, 27, 3)),
'b5': dict(dims=(64, 128, 320, 512), num_layers=(3, 6, 40, 3)),
}
if model_name not in config:
raise ValueError(f"Unsupported model: {model_name}")
cfg = config[model_name]
ff_expansion = (4, 4, 4, 4) if model_name == 'b5' else (8, 8, 4, 4)
return Segformer(
dims=cfg['dims'],
heads=(1, 2, 5, 8),
ff_expansion=ff_expansion,
reduction_ratio=(8, 4, 2, 1),
num_layers=cfg['num_layers'],
decoder_dim=decoder_dim,
num_classes=num_classes
)
def train_model(model, train_loader, val_loader, num_epochs=10, learning_rate=1e-4, device='cuda'):
model = model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
train_losses = []
val_losses = []
for epoch in range(num_epochs):
# Training
model.train()
train_loss = 0.0
for images, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs} - Training", leave=False):
images, labels = images.to(device), labels.to(device)
optimizer.zero_grad()
outputs = model(images)
# 确保输出和标签维度匹配
# outputs: [batch, num_classes, H, W]
# labels: [batch, H, W]
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
train_loss += loss.item() * images.size(0)
train_loss /= len(train_loader.dataset)
train_losses.append(train_loss)
# Validation
model.eval()
val_loss = 0.0
with torch.no_grad():
for images, labels in tqdm(val_loader, desc=f"Epoch {epoch+1}/{num_epochs} - Validation", leave=False):
images, labels = images.to(device), labels.to(device)
outputs = model(images)
loss = criterion(outputs, labels)
val_loss += loss.item() * images.size(0)
val_loss /= len(val_loader.dataset)
val_losses.append(val_loss)
# 每2个epoch可视化一次训练样本的预测
if epoch % 2 == 0 or epoch == num_epochs - 1:
model.eval()
with torch.no_grad():
# 取一个训练样本
sample_img, sample_label = next(iter(train_loader))
sample_img, sample_label = sample_img[:1].to(device), sample_label[:1].to(device)
output = model(sample_img)
pred = torch.argmax(output, dim=1).squeeze(0).cpu().numpy()
# 可视化
plt.figure(figsize=(12, 4))
plt.subplot(1, 3, 1)
plt.imshow(sample_img[0].cpu().permute(1, 2, 0))
plt.title('Input')
plt.axis('off')
plt.subplot(1, 3, 2)
plt.imshow(sample_label[0].cpu(), cmap='jet', vmin=0, vmax=3)
plt.title('Ground Truth')
plt.axis('off')
plt.subplot(1, 3, 3)
plt.imshow(pred, cmap='jet', vmin=0, vmax=3)
plt.title(f'Prediction Epoch {epoch}')
plt.axis('off')
plt.savefig(f'train_debug_epoch_{epoch}.png', dpi=100, bbox_inches='tight')
plt.close()
print(f"Debug visualization saved to train_debug_epoch_{epoch}.png")
print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")
return model, train_losses, val_losses

推理预测

# ============== 推理函数 ==============
def load_model(model_path, model_name='b0', num_classes=4, device='cuda'):
"""加载训练好的模型"""
model = get_segformer(model_name=model_name, num_classes=num_classes)
# 加载模型权重
checkpoint = torch.load(model_path, map_location=device)
model.load_state_dict(checkpoint)
model.to(device)
model.eval()
print(f"模型加载成功,参数数量: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
return model
print(f"模型加载成功,参数数量: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
return model
def predict(model, image_path, device='cuda'):
model = model.to(device)
model.eval()
# 加载并预处理图像
transform = transforms.Compose([
transforms.Resize((256, 256)),
transforms.ToTensor(),
])
image = Image.open(image_path).convert('RGB')
image = transform(image).unsqueeze(0)  # Add batch dimension
# Move to device
image = image.to(device)
# Predict
with torch.no_grad():
output = model(image)
# Get prediction (argmax along channels)
pred = torch.argmax(output, dim=1).squeeze(0).cpu().numpy()
return image.squeeze(0).cpu().numpy(), pred
# ============== 可视化函数 ==============
def visualize_results(original, prediction, save_path=None):
plt.figure(figsize=(12, 6))
plt.subplot(1, 2, 1)
plt.imshow(np.transpose(original, (1, 2, 0)))
plt.title('Original Image')
plt.axis('off')
plt.subplot(1, 2, 2)
plt.imshow(prediction, cmap='jet', vmin=0, vmax=3)
plt.title('Segmentation Prediction')
plt.axis('off')
if save_path:
plt.savefig(save_path, dpi=100, bbox_inches='tight')
print(f"Visualization saved to {save_path}")
else:
plt.show()
def generate_sample_image_and_label(save_img_path="sample_image.png", save_label_path=None, size=256):
"""
生成一张带几何形状的模拟图像和对应的标签图(可选保存)。
- 背景: 类别 0
- 红色圆: 类别 1
- 绿色矩形: 类别 2
- 蓝色椭圆: 类别 3
"""
# 创建灰色背景图像
img = np.full((size, size, 3), 128, dtype=np.uint8)
label = np.zeros((size, size), dtype=np.uint8)
# 1. 红色圆(类别 1)
center1 = (80, 80)
radius1 = 25
cv2.circle(img, center1, radius1, (255, 0, 0), -1)  # 红色
cv2.circle(label, center1, radius1, 1, -1)
# 2. 绿色矩形(类别 2)
pt1 = (150, 60)
pt2 = (200, 110)
cv2.rectangle(img, pt1, pt2, (0, 255, 0), -1)  # 绿色
cv2.rectangle(label, pt1, pt2, 2, -1)
# 3. 蓝色椭圆(类别 3)
center2 = (120, 180)
axes = (30, 15)
cv2.ellipse(img, center2, axes, 45, 0, 360, (0, 0, 255), -1)  # 蓝色
cv2.ellipse(label, center2, axes, 45, 0, 360, 3, -1)
# 保存图像
cv2.imwrite(save_img_path, cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
print(f"Sample image saved to {save_img_path}")
if save_label_path:
# 保存标签为可视化灰度图(0~255 映射)
label_vis = (label * 60).astype(np.uint8)  # 0,60,120,180 便于肉眼区分
cv2.imwrite(save_label_path, label_vis)
print(f"Label visualization saved to {save_label_path}")
return img, label

主函数

# ============== 主程序 ==============
if __name__ == "__main__":
# 设置随机种子
torch.manual_seed(42)
np.random.seed(42)
# 创建模拟数据集
dataset = MockSegmentationDataset()
# 划分训练集和验证集
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
# 创建DataLoader
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=4, shuffle=False, num_workers=2)
# # 初始化模型b0
# model = Segformer(
#     dims=(32, 64, 160, 256), # 各阶段通道数 [C1, C2, C3, C4]
#     heads=(1, 2, 5, 8), # 各阶段注意力头数
#     ff_expansion=(8, 8, 4, 4), # FFN扩展因子
#     reduction_ratio=(8, 4, 2, 1), # 序列缩减比例
#     num_layers=2, # 各阶段层数
#     decoder_dim=256, # 解码器统一维度
#     num_classes=4 # 分割类别数
# )
model_name = 'b0'  # 可选 'b0', 'b1', 'b2', 'b3', 'b4', 'b5'
model = get_segformer(model_name, num_classes=4)
os.makedirs(model_name, exist_ok=True)
print(f"Model parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
# 训练模型
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")
model, train_losses, val_losses = train_model(
model,
train_loader,
val_loader,
num_epochs=5,
# num_epochs=10,  # 增加到10个epoch
learning_rate=1e-4,
device=device
)
# 保存模型
torch.save(model.state_dict(), f'{model_name}/segformer_model.pth')
print(f"Model saved to '{model_name}/segformer_model.pth'")
# 测试推理
print("\nTesting inference on a sample image...")
# 生成一个结构清晰的模拟图像用于推理
sample_img, sample_label = generate_sample_image_and_label(
save_img_path="sample_image.png",
save_label_path="sample_label.png",  # 可选:保存标签用于对比
size=256
)
sample_img_path = "sample_image.png"
# 加载模型
model = load_model(f'{model_name}/segformer_model.pth', model_name=model_name, num_classes=4, device=device)
# 进行预测
original, prediction = predict(model, sample_img_path, device=device)
# 可视化结果
visualize_results(original, prediction, save_path=f"{model_name}/segmentation_result.png")
print(f"Inference completed. Result saved to '{model_name}/segmentation_result.png'")

输出结果

Model parameters: 7718244
Using device: cuda
Debug visualization saved to train_debug_epoch_0.png
Epoch 1/5, Train Loss: 0.1226, Val Loss: 0.0077
Epoch 2/5, Train Loss: 0.0052, Val Loss: 0.0037
Debug visualization saved to train_debug_epoch_2.png
Epoch 3/5, Train Loss: 0.0031, Val Loss: 0.0026
Epoch 4/5, Train Loss: 0.0022, Val Loss: 0.0019
Debug visualization saved to train_debug_epoch_4.png
Epoch 5/5, Train Loss: 0.0017, Val Loss: 0.0015
Model saved to 'b0/segformer_model.pth'
Testing inference on a sample image...
Sample image saved to sample_image.png
Label visualization saved to sample_label.png
模型加载成功,参数数量: 7718244
Visualization saved to b0/segmentation_result.png
Inference completed. Result saved to 'b0/segmentation_result.png'

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

完整代码

import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
import torchvision.transforms as transforms
from torchvision import models
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import cv2
from tqdm import tqdm
import torch.nn.functional as F
from math import sqrt
from functools import partial
import torch
from torch import nn, einsum
import torch.nn.functional as F
from einops import rearrange, reduce
from einops.layers.torch import Rearrange
# helpers
def exists(val):
return val is not None
def cast_tuple(val, depth):
return val if isinstance(val, tuple) else (val,) * depth
# ============== MockSegmentationDataset ==============
class MockSegmentationDataset(Dataset):
def __init__(self, size=256, num_samples=1000, num_classes=4):
self.size = size
self.num_samples = num_samples
self.num_classes = num_classes
# 图像变换
self.image_transform = transforms.Compose([
transforms.Resize((size, size)),
transforms.ToTensor(),
])
def __len__(self):
return self.num_samples
def __getitem__(self, idx):
# 使用固定模式而不是完全随机,让模型容易学习
rng = np.random.RandomState(idx)  # 固定随机种子,让数据可重复
# 生成更结构化的背景
img = np.full((self.size, self.size, 3), 128, dtype=np.uint8)  # 固定灰色背景
seg_map = np.zeros((self.size, self.size), dtype=np.uint8)
# 固定位置和尺寸的形状,减少随机性
positions = [
(self.size//4, self.size//4),      # 左上
(3*self.size//4, self.size//4),    # 右上  
(self.size//4, 3*self.size//4),    # 左下
(3*self.size//4, 3*self.size//4),  # 右下
]
# 为每个样本固定选择2个形状,确保类别平衡
shape_indices = [idx % 3 + 1, (idx + 1) % 3 + 1]  # 循环使用类别1,2,3
for i, cls in enumerate(shape_indices[:2]):  # 只画2个形状
pos = positions[i]
if cls == 1:  # 圆形
cv2.circle(seg_map, pos, 25, int(cls), -1)
cv2.circle(img, pos, 25, (255, 0, 0), -1)  # 红色
elif cls == 2:  # 矩形
pt1 = (pos[0]-25, pos[1]-20)
pt2 = (pos[0]+25, pos[1]+20)
cv2.rectangle(seg_map, pt1, pt2, int(cls), -1)
cv2.rectangle(img, pt1, pt2, (0, 255, 0), -1)  # 绿色
elif cls == 3:  # 椭圆
cv2.ellipse(seg_map, pos, (30, 15), 45, 0, 360, int(cls), -1)
cv2.ellipse(img, pos, (30, 15), 45, 0, 360, (0, 0, 255), -1)  # 蓝色
# 应用图像变换
img = Image.fromarray(img)
img = self.image_transform(img)
# 直接转换为tensor,不应用与图像相同的变换
seg_map = torch.from_numpy(seg_map).long()
return img, seg_map
# ============== SegFormer模型定义 ==============
class DsConv2d(nn.Module):
def __init__(self, dim_in, dim_out, kernel_size, padding, stride=1, bias=True):
super().__init__()
self.net = nn.Sequential(
nn.Conv2d(dim_in, dim_in, kernel_size=kernel_size, padding=padding, groups=dim_in, stride=stride, bias=bias),
nn.Conv2d(dim_in, dim_out, kernel_size=1, bias=bias)
)
def forward(self, x):
return self.net(x)
class LayerNorm(nn.Module):
def __init__(self, dim, eps=1e-5):
super().__init__()
self.eps = eps
self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
self.b = nn.Parameter(torch.zeros(1, dim, 1, 1))
def forward(self, x):
std = torch.var(x, dim=1, unbiased=False, keepdim=True).sqrt()
mean = torch.mean(x, dim=1, keepdim=True)
return (x - mean) / (std + self.eps) * self.g + self.b
class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.fn = fn
self.norm = LayerNorm(dim)
def forward(self, x):
return self.fn(self.norm(x))
class EfficientSelfAttention(nn.Module):
def __init__(
self,
*,
dim,
heads,
reduction_ratio
):
super().__init__()
self.scale = (dim // heads) ** -0.5
self.heads = heads
self.to_q = nn.Conv2d(dim, dim, 1, bias=False)
self.to_kv = nn.Conv2d(dim, dim * 2, reduction_ratio, stride=reduction_ratio, bias=False) # 关键:序列缩减
self.to_out = nn.Conv2d(dim, dim, 1, bias=False)
def forward(self, x):
h, w = x.shape[-2:]
heads = self.heads
q, k, v = (self.to_q(x), *self.to_kv(x).chunk(2, dim=1))
q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> (b h) (x y) c', h=heads), (q, k, v))
sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
attn = sim.softmax(dim=-1)
out = einsum('b i j, b j d -> b i d', attn, v)
out = rearrange(out, '(b h) (x y) c -> b (h c) x y', h=heads, x=h, y=w)
return self.to_out(out)
class MixFeedForward(nn.Module):
def __init__(
self,
*,
dim,
expansion_factor
):
super().__init__()
hidden_dim = dim * expansion_factor
self.net = nn.Sequential(
nn.Conv2d(dim, hidden_dim, 1), # 升维
DsConv2d(hidden_dim, hidden_dim, 3, padding=1), # 深度可分离卷积
nn.GELU(),
nn.Conv2d(hidden_dim, dim, 1) # 降维
)
def forward(self, x):
return self.net(x)
class MiT(nn.Module):
def __init__(
self,
*,
channels,
dims,
heads,
ff_expansion,
reduction_ratio,
num_layers
):
super().__init__()
# 四个阶段的下采样配置
stage_kernel_stride_pad = ((7, 4, 3), (3, 2, 1), (3, 2, 1), (3, 2, 1))
# 对应论文中的阶段1-4
dims = (channels, *dims)
dim_pairs = list(zip(dims[:-1], dims[1:]))
self.stages = nn.ModuleList([])
for (dim_in, dim_out), (kernel, stride, padding), num_layers, ff_expansion, heads, reduction_ratio in zip(dim_pairs, stage_kernel_stride_pad, num_layers, ff_expansion, heads, reduction_ratio):
get_overlap_patches = nn.Unfold(kernel, stride=stride, padding=padding)
overlap_patch_embed = nn.Conv2d(dim_in * kernel ** 2, dim_out, 1)
layers = nn.ModuleList([])
for _ in range(num_layers):
layers.append(nn.ModuleList([
PreNorm(dim_out, EfficientSelfAttention(dim=dim_out, heads=heads, reduction_ratio=reduction_ratio)),
PreNorm(dim_out, MixFeedForward(dim=dim_out, expansion_factor=ff_expansion)),
]))
self.stages.append(nn.ModuleList([
get_overlap_patches,
overlap_patch_embed,
layers
]))
def forward(
self,
x,
return_layer_outputs=False
):
h, w = x.shape[-2:]
layer_outputs = []
for (get_overlap_patches, overlap_embed, layers) in self.stages:
x = get_overlap_patches(x)
num_patches = x.shape[-1]
ratio = int(sqrt((h * w) / num_patches))
x = rearrange(x, 'b c (h w) -> b c h w', h=h // ratio)
x = overlap_embed(x)
for (attn, ff) in layers:
x = attn(x) + x
x = ff(x) + x
layer_outputs.append(x)
ret = x if not return_layer_outputs else layer_outputs
return ret
class Segformer(nn.Module):
def __init__(
self,
*,
dims=(32, 64, 160, 256),
heads=(1, 2, 5, 8),
ff_expansion=(8, 8, 4, 4),
reduction_ratio=(8, 4, 2, 1),
num_layers=2,
channels=3,
decoder_dim=256,
num_classes=4
):
super().__init__()
dims, heads, ff_expansion, reduction_ratio, num_layers = map(partial(cast_tuple, depth=4), (dims, heads, ff_expansion, reduction_ratio, num_layers))
assert all([*map(lambda t: len(t) == 4, (dims, heads, ff_expansion, reduction_ratio, num_layers))]), 'only four stages are allowed, all keyword arguments must be either a single value or a tuple of 4 values'
self.mit = MiT(
channels=channels,
dims=dims,
heads=heads,
ff_expansion=ff_expansion,
reduction_ratio=reduction_ratio,
num_layers=num_layers
)
# 多尺度特征融合
self.to_fused = nn.ModuleList([nn.Sequential(
nn.Conv2d(dim, decoder_dim, 1), # 统一通道数
nn.Upsample(scale_factor=2 ** i) # 上采样到1/4尺度
) for i, dim in enumerate(dims)])
self.to_segmentation = nn.Sequential(
nn.Conv2d(4 * decoder_dim, decoder_dim, 1), # 特征融合
nn.Conv2d(decoder_dim, num_classes, 1), # 分类头
)
def forward(self, x):
H, W = x.shape[-2:]  # 原始输入高宽
layer_outputs = self.mit(x, return_layer_outputs=True)
fused = [to_fused(output) for output, to_fused in zip(layer_outputs, self.to_fused)]
fused = torch.cat(fused, dim=1)
out = self.to_segmentation(fused)
# 关键修复:上采样到原始输入尺寸
out = F.interpolate(out, size=(H, W), mode='bilinear', align_corners=False)
return out
# ============== 训练函数 ==============
def get_segformer(model_name='b0', num_classes=4, decoder_dim=256):
config = {
'b0': dict(dims=(32, 64, 160, 256), num_layers=(2, 2, 2, 2)),
'b1': dict(dims=(64, 128, 320, 512), num_layers=(2, 2, 2, 2)),
'b2': dict(dims=(64, 128, 320, 512), num_layers=(3, 4, 6, 3)),
'b3': dict(dims=(64, 128, 320, 512), num_layers=(3, 4, 18, 3)),
'b4': dict(dims=(64, 128, 320, 512), num_layers=(3, 8, 27, 3)),
'b5': dict(dims=(64, 128, 320, 512), num_layers=(3, 6, 40, 3)),
}
if model_name not in config:
raise ValueError(f"Unsupported model: {model_name}")
cfg = config[model_name]
ff_expansion = (4, 4, 4, 4) if model_name == 'b5' else (8, 8, 4, 4)
return Segformer(
dims=cfg['dims'],
heads=(1, 2, 5, 8),
ff_expansion=ff_expansion,
reduction_ratio=(8, 4, 2, 1),
num_layers=cfg['num_layers'],
decoder_dim=decoder_dim,
num_classes=num_classes
)
def train_model(model, train_loader, val_loader, num_epochs=10, learning_rate=1e-4, device='cuda'):
model = model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
train_losses = []
val_losses = []
for epoch in range(num_epochs):
# Training
model.train()
train_loss = 0.0
for images, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs} - Training", leave=False):
images, labels = images.to(device), labels.to(device)
optimizer.zero_grad()
outputs = model(images)
# 确保输出和标签维度匹配
# outputs: [batch, num_classes, H, W]
# labels: [batch, H, W]
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
train_loss += loss.item() * images.size(0)
train_loss /= len(train_loader.dataset)
train_losses.append(train_loss)
# Validation
model.eval()
val_loss = 0.0
with torch.no_grad():
for images, labels in tqdm(val_loader, desc=f"Epoch {epoch+1}/{num_epochs} - Validation", leave=False):
images, labels = images.to(device), labels.to(device)
outputs = model(images)
loss = criterion(outputs, labels)
val_loss += loss.item() * images.size(0)
val_loss /= len(val_loader.dataset)
val_losses.append(val_loss)
# 每2个epoch可视化一次训练样本的预测
if epoch % 2 == 0 or epoch == num_epochs - 1:
model.eval()
with torch.no_grad():
# 取一个训练样本
sample_img, sample_label = next(iter(train_loader))
sample_img, sample_label = sample_img[:1].to(device), sample_label[:1].to(device)
output = model(sample_img)
pred = torch.argmax(output, dim=1).squeeze(0).cpu().numpy()
# 可视化
plt.figure(figsize=(12, 4))
plt.subplot(1, 3, 1)
plt.imshow(sample_img[0].cpu().permute(1, 2, 0))
plt.title('Input')
plt.axis('off')
plt.subplot(1, 3, 2)
plt.imshow(sample_label[0].cpu(), cmap='jet', vmin=0, vmax=3)
plt.title('Ground Truth')
plt.axis('off')
plt.subplot(1, 3, 3)
plt.imshow(pred, cmap='jet', vmin=0, vmax=3)
plt.title(f'Prediction Epoch {epoch}')
plt.axis('off')
plt.savefig(f'train_debug_epoch_{epoch}.png', dpi=100, bbox_inches='tight')
plt.close()
print(f"Debug visualization saved to train_debug_epoch_{epoch}.png")
print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")
return model, train_losses, val_losses
# ============== 推理函数 ==============
def load_model(model_path, model_name='b0', num_classes=4, device='cuda'):
"""加载训练好的模型"""
model = get_segformer(model_name=model_name, num_classes=num_classes)
# 加载模型权重
checkpoint = torch.load(model_path, map_location=device)
model.load_state_dict(checkpoint)
model.to(device)
model.eval()
print(f"模型加载成功,参数数量: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
return model
def predict(model, image_path, device='cuda'):
model = model.to(device)
model.eval()
# 加载并预处理图像
transform = transforms.Compose([
transforms.Resize((256, 256)),
transforms.ToTensor(),
])
image = Image.open(image_path).convert('RGB')
image = transform(image).unsqueeze(0)  # Add batch dimension
# Move to device
image = image.to(device)
# Predict
with torch.no_grad():
output = model(image)
# Get prediction (argmax along channels)
pred = torch.argmax(output, dim=1).squeeze(0).cpu().numpy()
return image.squeeze(0).cpu().numpy(), pred
# ============== 可视化函数 ==============
def visualize_results(original, prediction, save_path=None):
plt.figure(figsize=(12, 6))
plt.subplot(1, 2, 1)
plt.imshow(np.transpose(original, (1, 2, 0)))
plt.title('Original Image')
plt.axis('off')
plt.subplot(1, 2, 2)
plt.imshow(prediction, cmap='jet', vmin=0, vmax=3)
plt.title('Segmentation Prediction')
plt.axis('off')
if save_path:
plt.savefig(save_path, dpi=100, bbox_inches='tight')
print(f"Visualization saved to {save_path}")
else:
plt.show()
def generate_sample_image_and_label(save_img_path="sample_image.png", save_label_path=None, size=256):
"""
生成一张带几何形状的模拟图像和对应的标签图(可选保存)。
- 背景: 类别 0
- 红色圆: 类别 1
- 绿色矩形: 类别 2
- 蓝色椭圆: 类别 3
"""
# 创建灰色背景图像
img = np.full((size, size, 3), 128, dtype=np.uint8)
label = np.zeros((size, size), dtype=np.uint8)
# 1. 红色圆(类别 1)
center1 = (80, 80)
radius1 = 25
cv2.circle(img, center1, radius1, (255, 0, 0), -1)  # 红色
cv2.circle(label, center1, radius1, 1, -1)
# 2. 绿色矩形(类别 2)
pt1 = (150, 60)
pt2 = (200, 110)
cv2.rectangle(img, pt1, pt2, (0, 255, 0), -1)  # 绿色
cv2.rectangle(label, pt1, pt2, 2, -1)
# 3. 蓝色椭圆(类别 3)
center2 = (120, 180)
axes = (30, 15)
cv2.ellipse(img, center2, axes, 45, 0, 360, (0, 0, 255), -1)  # 蓝色
cv2.ellipse(label, center2, axes, 45, 0, 360, 3, -1)
# 保存图像
cv2.imwrite(save_img_path, cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
print(f"Sample image saved to {save_img_path}")
if save_label_path:
# 保存标签为可视化灰度图(0~255 映射)
label_vis = (label * 60).astype(np.uint8)  # 0,60,120,180 便于肉眼区分
cv2.imwrite(save_label_path, label_vis)
print(f"Label visualization saved to {save_label_path}")
return img, label
# ============== 主程序 ==============
if __name__ == "__main__":
# 设置随机种子
torch.manual_seed(42)
np.random.seed(42)
# 创建模拟数据集
dataset = MockSegmentationDataset()
# 划分训练集和验证集
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
# 创建DataLoader
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=4, shuffle=False, num_workers=2)
# # 初始化模型b0
# model = Segformer(
#     dims=(32, 64, 160, 256), # 各阶段通道数 [C1, C2, C3, C4]
#     heads=(1, 2, 5, 8), # 各阶段注意力头数
#     ff_expansion=(8, 8, 4, 4), # FFN扩展因子
#     reduction_ratio=(8, 4, 2, 1), # 序列缩减比例
#     num_layers=2, # 各阶段层数
#     decoder_dim=256, # 解码器统一维度
#     num_classes=4 # 分割类别数
# )
model_name = 'b0'  # 可选 'b0', 'b1', 'b2', 'b3', 'b4', 'b5'
model = get_segformer(model_name, num_classes=4)
os.makedirs(model_name, exist_ok=True)
print(f"Model parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
# 训练模型
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")
model, train_losses, val_losses = train_model(
model,
train_loader,
val_loader,
num_epochs=5,
# num_epochs=10,  # 增加到10个epoch
learning_rate=1e-4,
device=device
)
# 保存模型
torch.save(model.state_dict(), f'{model_name}/segformer_model.pth')
print(f"Model saved to '{model_name}/segformer_model.pth'")
# 测试推理
print("\nTesting inference on a sample image...")
# 生成一个结构清晰的模拟图像用于推理
sample_img, sample_label = generate_sample_image_and_label(
save_img_path="sample_image.png",
save_label_path="sample_label.png",  # 可选:保存标签用于对比
size=256
)
sample_img_path = "sample_image.png"
# 加载模型
model = load_model(f'{model_name}/segformer_model.pth', model_name=model_name, num_classes=4, device=device)
# 进行预测
original, prediction = predict(model, sample_img_path, device=device)
# 可视化结果
visualize_results(original, prediction, save_path=f"{model_name}/segmentation_result.png")
print(f"Inference completed. Result saved to '{model_name}/segmentation_result.png'")

参考

[1] Enze Xie, Wenhai Wang, Zhiding Yu, Anima Anandkumar, Jose M. Alvarez, Ping Luo. SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers. 2021
[2] https://github.com/NVlabs/SegFormer.git
[3] https://github.com/bubbliiiing/segformer-pytorch.git

http://www.hskmm.com/?act=detail&tid=24054

相关文章:

  • 深入解析:Playwright同步、异步、并行、串行执行效率比较
  • 2025十一集训——Day2模拟赛
  • 2025十一集训——Day模拟赛
  • Qt纯代码实现智能安防集中管理平台/楼宇对讲管理系统/门禁管理/视频监控
  • 汉文博士词典库源文件已在 github 开放
  • 读人形机器人30未来20年
  • Flutter + Ollama:开启本地AI的全平台新纪元 —— 从零剖析一款现代化AI客户端的技能奥秘
  • 股票资料API接口全解析:从技术原理到多语言实战(含实时行情、MACD、KDJ等技术指标数据与API文档详解)
  • 产业园区招商团队快躺平了 - 智慧园区
  • 洛谷 P3545
  • 题解:AT_wtf22_day2_b The Greatest Two
  • 威胁狩猎实战:终端攻击行为分析与检测
  • 实用指南:基于Hadoop+Spark的人体体能数据分析与可视化系统开源实现
  • 英语_阅读_Water Sliding_待读
  • 实用指南:ArcGIS JSAPI 高级教程 - 高亮效果优化之开启使用多高亮样式
  • const在for用不了
  • about me
  • 10月北京中学集训随笔
  • 使用100%缩放比例重新启动Visual Studio 界面模糊的解决方案
  • 某工程师入职华为,职级比较高,但还看不懂代码,有点尴尬
  • 使用Silobase在几分钟内快速部署后端API
  • 【光照】[各向异性]在UnityURP中的实现
  • 基于HAL库和中断的LED流水灯
  • 从衡阳麻衣事件到AI元人文:用户端元人文实践的进化路径研究——声明ai研究
  • 5_flutter UI框架选型
  • 4_查询flutter版本信息
  • 3_flutter简单教程
  • 如何给 Claude 中的网页做截图
  • 2_gradle配置加速
  • AI元人文:岐金兰《悬鉴》起源