当前位置: 首页 > news >正文

深入解析:Day43 Python打卡训练营

深入解析:Day43 Python打卡训练营

作业:

kaggle找到一个图像数据集,用cnn网络进行训练并且用grad-cam做可视化

进阶:并拆分成多个文件

选取Kaggle上的CIFAR-10数据集进行CNN训练,并使用Grad-CAM进行可视化,代码将拆分为多个文件以保持模块化。CIFAR-10是一个包含60,000张32x32彩色图像的数据集,分为10个类别。

项目结构

cifar10_cnn_gradcam/├── data_loader.py         # 数据加载和预处理├── model.py              # CNN模型定义├── gradcam.py            # Grad-CAM实现├── train.py              # 模型训练逻辑├── visualize.py          # 可视化Grad-CAM结果├── main.py               # 主执行脚本└── requirements.txt      # 依赖库

1. 数据加载(data_loader.py)

此文件负责加载和预处理CIFAR-10数据集,并进行训练、验证、测试集划分。

import tensorflow as tffrom sklearn.model_selection import train_test_split def load_cifar10_data():    # 加载CIFAR-10数据集    (x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()        # 归一化像素值到[0, 1]    x_train = x_train.astype('float32') / 255.0    x_test = x_test.astype('float32') / 255.0        # 将训练集进一步拆分为训练和验证集(80%训练,20%验证)    x_train, x_val, y_train, y_val = train_test_split(        x_train, y_train, test_size=0.2, random_state=42    )        # 类名    class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer',                   'dog', 'frog', 'horse', 'ship', 'truck']        return (x_train, y_train), (x_val, y_val), (x_test, y_test), class_names

2. 模型定义 (model.py)

此文件定义一个简单的CNN模型,适合CIFAR-10分类任务。

import tensorflow as tffrom tensorflow.keras import layers, models def build_cnn_model():    model = models.Sequential([        layers.Conv2D(32, (3, 3), activation='relu', input_shape=(32, 32, 3), padding='same'),        layers.MaxPooling2D((2, 2)),        layers.Conv2D(64, (3, 3), activation='relu', padding='same'),        layers.MaxPooling2D((2, 2)),        layers.Conv2D(128, (3, 3), activation='relu', padding='same'),        layers.MaxPooling2D((2, 2)),        layers.Flatten(),        layers.Dense(128, activation='relu'),        layers.Dropout(0.5),        layers.Dense(10, activation='softmax')    ])        model.compile(optimizer='adam',                  loss='sparse_categorical_crossentropy',                  metrics=['accuracy'])        return model

3. Grad-CAM实现 (gradcam.py)

此文件实现Grad-CAM算法,用于生成CNN的注意力热图。

import tensorflow as tfimport numpy as npimport cv2 class GradCAM:    def __init__(self, model, layer_name):        self.model = model        self.layer_name = layer_name        self.grad_model = tf.keras.models.Model(            [model.inputs], [model.get_layer(layer_name).output, model.output]        )     def generate_heatmap(self, image, class_idx):        image = tf.cast(image, tf.float32)        with tf.GradientTape() as tape:            conv_output, predictions = self.grad_model(image)            loss = predictions[:, class_idx]         grads = tape.gradient(loss, conv_output)        pooled_grads = tf.reduce_mean(grads, axis=(0, 1, 2))        conv_output = conv_output[0]        heatmap = tf.reduce_mean(tf.multiply(conv_output, pooled_grads), axis=-1)        heatmap = tf.maximum(heatmap, 0) / tf.math.reduce_max(heatmap)        return heatmap.numpy()     def superimpose_heatmap(self, image, heatmap, alpha=0.4):        heatmap = cv2.resize(heatmap, (image.shape[1], image.shape[0]))        heatmap = np.uint8(255 * heatmap)        heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)        image = np.uint8(255 * image)        superimposed_img = heatmap * alpha + image        superimposed_img = np.clip(superimposed_img, 0, 255).astype(np.uint8)        return superimposed_img

4. 模型训练 (train.py)

此文件包含训练逻辑,使用数据增强以提高模型鲁棒性。

import tensorflow as tffrom tensorflow.keras import layersfrom model import build_cnn_model def train_model(x_train, y_train, x_val, y_val, epochs=25, batch_size=32):    model = build_cnn_model()        # 数据增强    data_augmentation = tf.keras.Sequential([        layers.RandomFlip("horizontal"),        layers.RandomRotation(0.1),        layers.RandomZoom(0.1),    ])        # 训练模型    history = model.fit(        data_augmentation(x_train), y_train,        validation_data=(x_val, y_val),        epochs=epochs,        batch_size=batch_size,        verbose=1    )        model.save('cifar10_cnn_model.h5')    return model, history

5. 可视化Grad-CAM结果 (visualize.py)

此文件负责生成和保存Grad-CAM可视化结果。

import numpy as npimport matplotlib.pyplot as pltfrom gradcam import GradCAM def visualize_gradcam(model, x_test, y_test, class_names, num_images=5):    gradcam = GradCAM(model, layer_name='conv2d_2')  # 选择最后一层卷积层        plt.figure(figsize=(15, 10))    for i in range(num_images):        img = x_test[i:i+1]        true_label = y_test[i][0]        pred = model.predict(img)        pred_label = np.argmax(pred, axis=1)[0]                # 生成热图        heatmap = gradcam.generate_heatmap(img, pred_label)        superimposed_img = gradcam.superimpose_heatmap(img[0], heatmap)                # 可视化        plt.subplot(num_images, 3, i*3 + 1)        plt.imshow(img[0])        plt.title(f'True: {class_names[true_label]}')        plt.axis('off')                plt.subplot(num_images, 3, i*3 + 2)        plt.imshow(heatmap, cmap='jet')        plt.title('Heatmap')        plt.axis('off')                plt.subplot(num_images, 3, i*3 + 3)        plt.imshow(superimposed_img)        plt.title(f'Pred: {class_names[pred_label]}')        plt.axis('off')        plt.tight_layout()    plt.savefig('gradcam_visualization.png')    plt.close()

6. 主执行脚本 (main.py)

此文件协调整个流程,调用其他模块执行数据加载、训练和可视化。

from data_loader import load_cifar10_datafrom train import train_modelfrom visualize import visualize_gradcam def main():    # 加载数据    (x_train, y_train), (x_val, y_val), (x_test, y_test), class_names = load_cifar10_data()        # 训练模型    model, history = train_model(x_train, y_train, x_val, y_val, epochs=25, batch_size=32)        # 可视化Grad-CAM    visualize_gradcam(model, x_test, y_test, class_names, num_images=5)        # 评估模型    test_loss, test_acc = model.evaluate(x_test, y_test, verbose=0)    print(f"Test accuracy: {test_acc:.4f}") if __name__ == "__main__":    main()

7. 依赖文件 (requirements.txt)

列出项目所需的Python库。

tensorflow==2.10.0 numpy scikit-learn matplotlib opencv-python

@浙大疏锦行

http://www.hskmm.com/?act=detail&tid=26260

相关文章:

  • 用 Perl 实现验证码图像识别
  • 实用指南:【结构型模式】代理模式
  • cnblog Test
  • 云数据仓库十年架构演进与技术突破
  • 20251007 模拟测 总结
  • 2025国庆Day6
  • Claude 封杀中国后,我终于找到了平替!
  • [退役感言]You are my only one.
  • Mortal
  • python,shell,linux,bash概念的不同和对比联系 - 指南
  • 制作局域网连接打印机exe文件
  • 深入解析:linux——账号和权限的管理
  • pandoc使用
  • c#造个轮子--GIF录制工具
  • netdata
  • 关于Elment-plus的el-table组件无法通过原生JS监听scroll事件
  • arc3.2语言sort的时候报错:(sort < `(2 9 3 7 5 1)) 得写成此种:(sort > (pair (list 3 2)))
  • 噬菌体展示技术:从诺奖成果到疫苗研发,这一 “表型 - 基因型统一” 工具如何颠覆生物研究?
  • 从零开始学Flink:实时流处理实战
  • 高质量同人动画整理回顾记录的方式
  • 斑马打印机基础知识
  • 加拿大加密货币牌照:合规化加速数字资产成功
  • 深入解析:实时通信RTC与传统直播的异同
  • Exp2-后门原理与实践
  • 【Hexo】4.Hexo 博客文章进行加密 - 实践
  • 思考的动力
  • Software Foundations Vol.I : 多态与高阶函数(Poly)
  • 数学之美感悟。
  • 基于DeploySharp 的深度学习模型部署测试平台:支持YOLO全系列模型
  • 复制别人的vmware虚拟机无法联网ubuntu2204