当前位置: 首页 > news >正文

第七章 手写数字识别V2

#这个版本考虑将隐藏层、输出层及对应的forward和backward封装到对于类中
#减少了步长,实测lr>0.02就会溢出,可能是哪里没有优化好
# 导入必要的库
import numpy as np
import os
import struct
import matplotlib.pyplot as plt# 定义导入函数
def load_images(path):with open(path, "rb") as f:data = f.read()magic_number, num_items, rows, cols = struct.unpack(">iiii", data[:16])return np.asanyarray(bytearray(data[16:]), dtype=np.uint8).reshape(num_items, 28, 28)def load_labels(file):with open(file, "rb") as f:data = f.read()return np.asanyarray(bytearray(data[8:]), dtype=np.int32)# 定义sigmoid函数
def sigmoid(x):result = np.zeros_like(x)positive_mask = x >= 0result[positive_mask] = 1 / (1 + np.exp(-x[positive_mask]))negative_mask = x < 0exp_x = np.exp(x[negative_mask])result[negative_mask] = exp_x / (1 + exp_x)return result# 定义softmax函数
def softmax(x):max_x = np.max(x, axis=-1, keepdims=True)x = x - max_xex = np.exp(x)sum_ex = np.sum(ex, axis=1, keepdims=True)result = ex / sum_exresult = np.clip(result, 1e-10, 1e10)return result# 定义独热编码函数
def make_onehot(labels, class_num):result = np.zeros((labels.shape[0], class_num))for idx, cls in enumerate(labels):result[idx, cls] = 1return result# 定义dataset类
class Dataset:def __init__(self, all_images, all_labels):self.all_images = all_imagesself.all_labels = all_labelsdef __getitem__(self, index):image = self.all_images[index]label = self.all_labels[index]return image, labeldef __len__(self):return len(self.all_images)# 定义dataloader类
class DataLoader:def __init__(self, dataset, batch_size, shuffle=True):self.dataset = datasetself.batch_size = batch_sizeself.shuffle = shuffleself.idx = np.arange(len(self.dataset))def __iter__(self):# 如果需要打乱,则在每个 epoch 开始时重新排列索引if self.shuffle:np.random.shuffle(self.idx)self.cursor = 0return selfdef __next__(self):if self.cursor >= len(self.dataset):raise StopIteration# 使用索引来获取数据batch_idx = self.idx[self.cursor : min(self.cursor + self.batch_size, len(self.dataset))]batch_images = self.dataset.all_images[batch_idx]batch_labels = self.dataset.all_labels[batch_idx]self.cursor += self.batch_sizereturn batch_images, batch_labels# 定义linear类
class Linear:def __init__(self, in_features, out_features):# 将这部分置于Linear类中# # 权重和偏置# w1 = np.random.normal(0, 1, size=(784, 512))  # 第一层权重# w3 = np.random.normal(0, 1, size=(512, 256))  # 第三层权重# w4 = np.random.normal(0, 1, size=(256, 10))  # 第四层权重# b1 = np.random.normal(0, 1, size=(1, 512))  # 第一层偏置# b3 = np.random.normal(0, 1, size=(1, 256))  # 第三层偏置# b4 = np.random.normal(0, 1, size=(1, 10))  # 第四层偏置self.w = np.random.normal(0, 1, size=(in_features, out_features))self.b = np.random.normal(0, 1, size=(1, out_features))def forward(self, x):# H1 = np.dot(batch_images, w1) + b1  # 第一层输出# H3 = np.dot(H2, w3) + b3  # 第三层输出# H4 = np.dot(H3, w4) + b4  # 第四层输出self.x = xreturn np.dot(x, self.w) + self.bdef backward(self, G):# dw4 = np.dot(H3.T, G4)  # 第四层权重梯度# dw3 = np.dot(H2.T, G3)  # 第三层权重梯度# dw1 = np.dot(batch_images.T, G1)  # 第一层权重梯度dw = np.dot(self.x.T, G)# db4 = np.mean(G4, axis=0, keepdims=True)  # 第四层偏置梯度# db3 = np.mean(G3, axis=0, keepdims=True)  # 第三层偏置梯度# db1 = np.mean(G1, axis=0, keepdims=True)  # 第一层偏置梯度db = np.sum(G, axis=0, keepdims=True)# 更新权重和偏置# w1 -= lr * dw1# b1 -= lr * db1# w3 -= lr * dw3# b3 -= lr * db3# w4 -= lr * dw4# b4 -= lr * db4self.w -= lr * dwself.b -= lr * db# G3 = np.dot(G4, w4.T)  # 第三层误差# G2 = np.dot(G3, w3.T)  # 第二层误差return np.dot(G, self.w.T)# 定义Sigmoid类
class Sigmoid:def __init__(self):passdef forward(self, x):# H2 = sigmoid(H1)  # 第二层输出,使用sigmoid激活函数self.result = sigmoid(x)return self.resultdef backward(self, G):# G2 = G * H2 * (1 - H2)  # 第二层误差return G * self.result * (1 - self.result)# 定义Softmax类
class Softmax:def __init__(self):passdef forward(self, x):# p = softmax(H4)  # 输出层输出,使用softmax激活函数self.p = softmax(x)return self.pdef backward(self, G):# G4 = G * H4 * (1 - H4)  # 第四层误差G = (self.p - G) / len(G)return G# 主函# 主函数
if __name__ == "__main__":# 加载训练集图片、标签train_images = (load_images(os.path.join("Python", "NLP basic", "data", "minist", "train-images.idx3-ubyte"))/ 255)train_labels = make_onehot(load_labels(os.path.join("Python", "NLP basic", "data", "minist", "train-labels.idx1-ubyte")),10,)# 加载测试集图片、标签dev_images = (load_images(os.path.join("Python", "NLP basic", "data", "minist", "t10k-images.idx3-ubyte"))/ 255)dev_labels = load_labels(os.path.join("Python", "NLP basic", "data", "minist", "t10k-labels.idx1-ubyte"))# 设置超参数epochs = 10lr = 0.008#V2版本调整了学习率batch_size = 200shuffle = True# 展开图片数据train_images = train_images.reshape(60000, 784)dev_images = dev_images.reshape(-1, 784)# 调用dataset类和dataloader类train_dataset = Dataset(train_images, train_labels)train_dataloader = DataLoader(train_dataset, batch_size, shuffle)dev_dataset = Dataset(dev_images, dev_labels)dev_dataloader = DataLoader(dev_dataset, batch_size, shuffle)# #隐藏层*4 输出层*1# layer1=Linear(784,512)# layer2=Sigmoid()# layer3=Linear(512,256)# layer4=Linear(256,10)# layer5=Softmax()# 转化为layers数组layers = [Linear(784, 512), Sigmoid(), Linear(512, 256), Linear(256, 10), Softmax()]# 训练集训练过程for e in range(epochs):for x, l in train_dataloader:# 前向传播for layer in layers:x=layer.forward(x)# 计算损失loss = -np.mean(l * np.log(x))# 反向传播for layer in layers[::-1]:l = layer.backward(l)# 验证集验证并输出预测准确率right_num = 0for x, batch_labels in dev_dataloader:# H1 = np.dot(batch_images, w1) + b1  # 第一层输出# H2 = sigmoid(H1)  # 第二层输出,使用sigmoid激活函数# H3 = np.dot(H2, w3) + b3  # 第三层输出# H4 = np.dot(H3, w4) + b4  # 第四层输出# p = softmax(H4)  # 输出层输出,使用softmax激活函数for layer in layers:x = layer.forward(x)pre_idx = np.argmax(x, axis=-1)  # 预测类别right_num += np.sum(pre_idx == batch_labels)  # 统计正确个数acc = right_num / len(dev_images)  # 计算准确率print(f"Epoch {e}, Acc: {acc:.4f}")

image

http://www.hskmm.com/?act=detail&tid=15042

相关文章:

  • 常用软件下载
  • 实用指南:S 4.1深度学习--自然语言处理NLP--理论
  • PyTorch图神经网络(五)
  • java
  • Jordan块新解
  • [CSP-S 2024] 染色
  • Kerberos 安装和使用
  • 第一次个人编程任务
  • 概率期望总结
  • redis实现秒杀下单的业务逻辑
  • 关于边缘网络+数据库(1)边缘网络数据库模式及选型
  • 题解:B4357 [GESP202506 二级] 幂和数
  • 2025年9月23日 - 20243867孙堃2405
  • 2025.9.23
  • 软件工程学习日志2025.9.23
  • markdown 使用指南
  • 第6.2节 Android Agent制作<三>
  • LVS 服务器 知识
  • 07-django+DRF项目中统一json返回格式 - 详解
  • 软工第二次作业——个人项目
  • 近十年 CSP-J 复赛知识点分布表
  • AT_arc181_d [ARC181D] Prefix Bubble Sort
  • 【MySQL】使用C/C++链接mysql数据库 - 指南
  • 枚举子集
  • cv-css 快捷方式,将指定节点的计算样式获取下拉 获取tailwind网页样式成原生样式
  • day002
  • PyTorch图神经网络(四)
  • 软件工程:构建数字世界的基石
  • Avalonia 学习笔记07. Control Themes(控件主题)
  • matter 协议的架构;