联合项目
https://github.com/loki-keroro/SAMbase_segmentation?tab=readme-ov-file
模型会根据不同的提示文本生成不同的掩码,可通过修改main.py中的category_cfg变量来自定义提示文本。
- landcover_prompts 为地物分类的提示,在全景图场景下一般用于分割区域连续或新增的类别
- cityobject_prompts 作为实例分割的提示,在全景图场景下一般用于图像内区域不连续的对象类别
- landcover_prompts_cn和cityobject_prompts_cn分别为上述两个列表中每个类别的中文名称,必须与对应的英文列表按顺序一一对应(长度相同)
# Example prompt configuration. Each *_cn list holds the Chinese name of the
# category at the same position in the matching English list.
category_cfg = {
    "landcover_prompts": [
        'building', 'low vegetation', 'tree', 'river',
        'shed', 'road', 'lake', 'bare soil',
    ],
    "landcover_prompts_cn": [
        '建筑', '低矮植被', '树木', '河流',
        '棚屋', '道路', '湖泊', '裸土',
    ],
    "cityobject_prompts": [
        'car', 'truck', 'bus', 'train', 'ship', 'boat',
    ],
    "cityobject_prompts_cn": [
        '轿车', '卡车', '巴士', '列车', '船(舰)', '船(舶)',
    ],
}
import os

from inference import PSAM

# Model config file and weight files.
model_cfg = {
    "DINO_WEIGHT_PATH": "weights/GSA_weights/groundingdino_swinb_cogcoor.pth",
    "DINO_CFG_PATH": "groundingdino/config/GroundingDINO_SwinB.py",
    "SAM_WEIGHT_PATH": "weights/GSA_weights/sam_vit_h_4b8939.pth",
    "CLIP_WEIGHT_DIR": "weights/CLIP_weights/",
}

# Prompt lists; the category lists can be customised.
# The model generates a different mask for each prompt.
# Each *_cn list must line up index-for-index with its English list,
# because PSAM pairs them by position.
# Alternative configuration:
# category_cfg = {
#     "landcover_prompts": ['building', 'low vegetation', 'tree', 'water', 'shed', 'road', 'lake', 'bare soil',],
#     "landcover_prompts_cn": ['建筑', '低矮植被', '树木', '水体', '棚屋', '道路', '湖泊', '裸土'],
#     "cityobject_prompts": ['car', 'truck', 'bus', 'train', 'ship', 'boat'],
#     "cityobject_prompts_cn": ['轿车', '卡车', '巴士', '列车', '船(舰)', '船(舶)']
# }
category_cfg = {
    "landcover_prompts": ['building', 'water', 'tree', 'road', 'shed', 'cropland', 'grassland',
                          'Agricultural Fields', 'bare soil'],
    "landcover_prompts_cn": ['建筑', '水体', '树木', '道路', '棚屋', '农田', '草地', '农用地', '裸土'],
    "cityobject_prompts": ['car', 'truck', 'train'],
    "cityobject_prompts_cn": ['轿车', '货车', '火车'],
}

gpus = ["1"]

# Chinese text rendering in matplotlib.
cn_style = False  # whether to draw Chinese labels
# Path to a CJK-capable font; list installed fonts with the `fc-list` command.
font_style_path = '/usr/share/fonts/wqy-microhei/wqy-microhei.ttc'

# File extensions treated as input images. Other files found while walking the
# input directory (lists, metadata, thumbnails, ...) are skipped instead of
# crashing the batch inside Image.open.
IMAGE_EXTENSIONS = {'.png', '.jpg', '.jpeg', '.bmp', '.tif', '.tiff', '.webp'}

if __name__ == "__main__":
    psam = PSAM(model_cfg, category_cfg, gpus)
    file_path = '/home/piesat/data/无人机全景图/panorama01-04/match_imgs/CD_dataset/cwptys_tmp/A'
    save_path = '/home/piesat/media/ljh/pycharm_project__ljh/panorama_sam/photos/croplands/'
    # savefig does not create missing directories, so make sure the output dir exists.
    os.makedirs(save_path, exist_ok=True)

    for root, _dirs, filenames in os.walk(file_path):
        for filename in filenames:
            if os.path.splitext(filename)[1].lower() not in IMAGE_EXTENSIONS:
                continue
            in_img_path = os.path.join(root, filename)
            # NOTE: outputs are flattened into save_path, so same-named files in
            # different sub-directories overwrite each other.
            out_img_path = os.path.join(save_path, filename)
            psam.load_image(in_img_path)
            panoptic_inds = psam.generate_panoptic_mask()
            psam.plt_draw_image(cn_style, font_style_path, out_img_path)
            print(panoptic_inds.shape)  # panoptic_inds: single-channel class-index mask
import numpy as np
from PIL import Image
import cv2
import matplotlib.pyplot as plt
from matplotlib.font_manager import FontProperties

import torch

from utils.data_utils import generate_color_list
from utils.load_models import load_clip_model, load_dino_model, load_sam_model
from utils.func_utils import dino_detection, sam_masks_from_dino_boxes, clipseg_segmentation, \
    clip_and_shrink_preds, sample_points_based_on_preds, sam_mask_from_points, preds_to_semantic_inds


class PSAM(object):
    """Panoptic segmentation built from GroundingDINO, SAM and CLIPSeg.

    "City object" prompts are detected with GroundingDINO and segmented by SAM
    from the detected boxes. "Land cover" prompts are segmented with CLIPSeg and
    refined by SAM from points sampled inside the CLIPSeg predictions. Both are
    merged into a single-channel class-index mask.

    Class id layout: 0 is "background", then the land cover categories, then the
    city object categories, in the order given in category_cfg.
    """

    def __init__(self, model_cfg, category_cfg, gpu_ids):
        """Load the three models and build the category lookup tables.

        Args:
            model_cfg: dict with keys "DINO_CFG_PATH", "DINO_WEIGHT_PATH",
                "SAM_WEIGHT_PATH" and "CLIP_WEIGHT_DIR".
            category_cfg: dict with keys "landcover_prompts", "cityobject_prompts",
                "landcover_prompts_cn" and "cityobject_prompts_cn". Each *_cn list
                is paired by position with its English list.
            gpu_ids: sequence of GPU ids. Only the first is used; the CPU is used
                when CUDA is unavailable or the sequence is empty.
        """
        # Initialise the GroundingDINO, SAM and CLIPSeg models.
        use_cuda = torch.cuda.is_available() and len(gpu_ids) > 0
        self.device = torch.device("cuda:%s" % gpu_ids[0] if use_cuda else "cpu")
        self.groundingdino_model = load_dino_model(model_cfg["DINO_CFG_PATH"], model_cfg["DINO_WEIGHT_PATH"],
                                                   self.device)
        self.sam_predictor = load_sam_model(model_cfg["SAM_WEIGHT_PATH"], self.device)
        self.clipseg_processor, self.clipseg_model = load_clip_model(model_cfg["CLIP_WEIGHT_DIR"], self.device)

        self.landcover_categories = category_cfg["landcover_prompts"]
        self.cityobject_categories = category_cfg["cityobject_prompts"]
        self.category_names = ["background"] + self.landcover_categories + self.cityobject_categories
        self.category_name_to_id = {name: i for i, name in enumerate(self.category_names)}
        self.category_id_to_name = {i: name for i, name in enumerate(self.category_names)}
        # Palette indexed by class id (each entry is unpacked as r, g, b in plt_draw_image).
        self.color_map = generate_color_list(len(self.category_names))

        # Chinese display names, aligned with the English ids by position.
        self.landcover_categories_cn = category_cfg["landcover_prompts_cn"]
        self.cityobject_categories_cn = category_cfg["cityobject_prompts_cn"]
        self.category_names_cn = ["背景"] + self.landcover_categories_cn + self.cityobject_categories_cn
        self.category_id_to_name_cn = {i: name for i, name in enumerate(self.category_names_cn)}

    def load_image(self, img_path):
        """Read an image as RGB and run SAM's image encoding on it.

        Sets self.image (PIL RGB image) and self.image_array (numpy array of it),
        which generate_panoptic_mask and plt_draw_image read.
        """
        # convert() returns an independent, fully loaded image, so the file can be
        # closed immediately.
        with Image.open(img_path) as image:
            self.image = image.convert("RGB")
        self.image_array = np.asarray(self.image)
        self.sam_predictor.set_image(self.image_array)

    def generate_panoptic_mask(self, dino_box_threshold=0.2, dino_text_threshold=0.20,
                               segmentation_background_threshold=0.1, shrink_kernel_size=10,
                               num_samples_factor=300):
        """Segment the image loaded by load_image into a class-index mask.

        Args:
            dino_box_threshold, dino_text_threshold: passed to dino_detection.
            segmentation_background_threshold: passed to clipseg_segmentation and
                preds_to_semantic_inds.
            shrink_kernel_size: passed to clip_and_shrink_preds.
            num_samples_factor: each land cover category gets
                int(relative_size * num_samples_factor) SAM point prompts.

        Returns:
            self.panoptic_inds, a uint8 numpy array of class ids (see class
            docstring). Without land cover categories it has shape
            (H, W) of the image; otherwise its shape comes from
            preds_to_semantic_inds (presumably also (H, W) — confirm).
        """
        # 1. City objects: GroundingDINO detection, then SAM segmentation from the boxes.
        cityobject_category_ids = []
        cityobject_masks = torch.empty(0)
        cityobject_boxes = []
        if len(self.cityobject_categories) > 0:
            cityobject_boxes, cityobject_category_ids, _ = dino_detection(
                self.groundingdino_model,
                self.image,
                self.cityobject_categories,
                self.category_name_to_id,
                dino_box_threshold,
                dino_text_threshold,
                self.device,
            )
            if len(cityobject_boxes) > 0:
                cityobject_masks = sam_masks_from_dino_boxes(self.sam_predictor, self.image_array,
                                                             cityobject_boxes, self.device)

        # 2. Land cover: CLIPSeg segmentation, refined by SAM from sampled points.
        if len(self.landcover_categories) > 0:
            clipseg_preds, clipseg_semantic_inds = clipseg_segmentation(
                self.clipseg_processor,
                self.clipseg_model,
                self.image,
                self.landcover_categories,
                segmentation_background_threshold,
                self.device,
            )
            # Pixels covered by any city object are reset to background (id 0)
            # before the land cover predictions are clipped.
            clipseg_semantic_inds_without_cityobject = clipseg_semantic_inds.clone()
            if len(cityobject_boxes) > 0:
                combined_cityobject_mask = torch.any(cityobject_masks, dim=0)
                clipseg_semantic_inds_without_cityobject[combined_cityobject_mask[0]] = 0
            clipsed_clipped_preds, relative_sizes = clip_and_shrink_preds(
                clipseg_semantic_inds_without_cityobject,
                clipseg_preds,
                shrink_kernel_size,
                len(self.landcover_categories) + 1,
            )
            sam_preds = torch.zeros_like(clipsed_clipped_preds)
            for i in range(clipsed_clipped_preds.shape[0]):
                clipseg_pred = clipsed_clipped_preds[i]
                # Larger categories get proportionally more SAM point prompts.
                num_samples = int(relative_sizes[i] * num_samples_factor)
                if num_samples == 0:
                    continue
                points = sample_points_based_on_preds(clipseg_pred.cpu().numpy(), num_samples)
                if len(points) == 0:
                    continue
                sam_preds[i] = sam_mask_from_points(self.sam_predictor, self.image_array, points)
            sam_semantic_inds = preds_to_semantic_inds(sam_preds, segmentation_background_threshold)

        # 3. Merge the land cover and city object results.
        if len(self.landcover_categories) > 0:
            # Morphological opening then closing to remove speckles / fill holes.
            # NOTE(review): on a label image these are min/max filters over class
            # ids, so boundary pixels can take a neighbouring id chosen by numeric
            # order rather than by class — confirm this is acceptable.
            self.panoptic_inds = sam_semantic_inds.cpu().numpy().astype(np.uint8)
            kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (7, 7))
            self.panoptic_inds = cv2.morphologyEx(self.panoptic_inds, cv2.MORPH_OPEN, kernel)
            self.panoptic_inds = cv2.morphologyEx(self.panoptic_inds, cv2.MORPH_CLOSE, kernel)
        else:
            self.panoptic_inds = np.zeros((self.image_array.shape[0], self.image_array.shape[1]), dtype=np.uint8)

        # City objects are painted last, so they override land cover labels.
        for mask_cid in range(cityobject_masks.shape[0]):
            ind = cityobject_category_ids[mask_cid]
            # presumably a boolean (H, W) mask after squeezing — confirm
            mask_bool = cityobject_masks[mask_cid].squeeze(dim=0).cpu().numpy()
            self.panoptic_inds[mask_bool] = ind
        return self.panoptic_inds

    def plt_draw_image(self, cn_style=False, font_style_path=None, save_file_path=None):
        """Render a 2x2 summary figure of the last generate_panoptic_mask result.

        Panels: original image, colour-coded mask, 70/30 image/mask blend, and a
        bar chart of per-category pixel counts.

        Args:
            cn_style: use Chinese category names and titles; only honoured when
                font_style_path is also given.
            font_style_path: path to a font file able to render Chinese text.
            save_file_path: where to write the figure. When None the figure is
                shown with plt.show() instead.
        """
        # Chinese labels need an explicit CJK font, so only enable them when a font path is given.
        if cn_style and font_style_path is not None:
            cn_style = True
            font = FontProperties(fname=font_style_path)
            id_to_name = self.category_id_to_name_cn
        else:
            cn_style = False
            font = FontProperties()
            id_to_name = self.category_id_to_name

        # Pixel count of every class id present in the mask. Parallel lists keep
        # names, counts and colours aligned even if two categories share a name.
        unique_values, counts = np.unique(self.panoptic_inds, return_counts=True)
        x = []           # category display names
        y = []           # pixel counts
        bar_colors = []  # RGBA (0-1 floats) matching the mask palette
        for value, count in zip(unique_values, counts):
            x.append(id_to_name[value])
            y.append(count)
            r, g, b = self.color_map[value]
            bar_colors.append((r / 255, g / 255, b / 255, 1.0))

        fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(15, 12))

        # Top-left: original image.
        axes[0, 0].imshow(self.image)

        # Top-right: colour-coded mask (palette lookup by class id).
        cm = np.array([list(t) for t in self.color_map]).astype('uint8')
        label_img = cm[self.panoptic_inds]
        axes[0, 1].imshow(Image.fromarray(label_img))

        # Bottom-left: blend of image (70%) and mask (30%).
        draw_image = cv2.addWeighted(np.array(self.image), 0.7, label_img, 0.3, 0)
        axes[1, 0].imshow(Image.fromarray(draw_image))

        # Bottom-right: bar chart with the count written above each bar.
        bar_ax = axes[1, 1]
        bar_ax.bar(range(len(x)), y, label=x, color=bar_colors)
        for pos, pixel_count in zip(range(len(x)), y):
            bar_ax.text(pos, pixel_count, pixel_count, ha='center', va='bottom', fontproperties=font)

        if cn_style:
            bar_ax.set_title('统计每个类别占用的像素', fontproperties=font)
            bar_ax.set_xlabel('类别', fontproperties=font)
            bar_ax.set_ylabel('像素', fontproperties=font)
        else:
            bar_ax.set_title('Pixel Count', fontproperties=font)
            bar_ax.set_xlabel('Category', fontproperties=font)
            bar_ax.set_ylabel('Pixel', fontproperties=font)
        # Category names are in the legend, so hide the numeric x tick labels.
        bar_ax.tick_params(axis='x', labelbottom=False)
        bar_ax.legend(prop=font)

        fig.subplots_adjust(wspace=0.15, hspace=0.2)
        if save_file_path is not None:
            fig.savefig(save_file_path)
        else:
            plt.show()
        # pyplot keeps every figure alive until it is closed; without this,
        # calling the method once per image in a loop leaks memory.
        plt.close(fig)