当前位置: 首页 > news >正文

多元线性回归-梯度下降法-吴恩达机器学习

0.工具

import copy, math, sys
import numpy as np

1.线性回归模型

def f_wb(x,w,b):return np.dot(w,x) + b

2.成本函数

def compute_cost(X, y, w, b):m,_ = X.shapeJ_wb = 0.0for i in range(m):J_wb += (f_wb(X[i],w,b) - y[i])**2J_wb /= 2*mreturn J_wb

3.梯度计算

def compute_gradient(X, y, w, b):m,n = X.shapedj_dw = np.zeros(n)dj_db = 0.0for i in range(m):err = f_wb(X[i],w,b) - y[i]for j in range(n):dj_dw[j] += err*X[i,j]dj_db += errfor j in range(n):dj_dw[j] /= mdj_db /= mreturn dj_dw, dj_db

4.梯度下降法

def gradient_descend(X, y, w_init, b_init, alpha, num_iters):m,n = X.shapeJ_history = []p_history = []w = copy.deepcopy(w_init)b = b_initfor t in range(num_iters):dj_dw, dj_db = compute_gradient(X, y, w, b)w_temp = w - alpha * dj_dwb_temp = b - alpha * dj_dbw = w_tempb = b_tempJ_history.append(compute_cost(X, y, w, b))if t% math.ceil(num_iters / 10) == 0:print(f"Iteration {t:4d}: Cost {J_history[-1]:8.2f}   ")return w, b, J_history

5.测试算法正确性

X_train = np.array([[2104, 5, 1, 45], [1416, 3, 2, 40], [852, 2, 1, 35]])
y_train = np.array([460, 232, 178])
m,n = X_train.shapew_initial = np.zeros(n)
b_initial = 0.0iterations = 1000
alpha = 1e-7w_final, b_final, J_hist = gradient_descend(X_train, y_train, w_initial, b_initial, alpha, iterations)print(f"b,w found by gradient descent: {b_final:0.2f},{w_final} ")
for i in range(m):print(f"prediction: {np.dot(X_train[i], w_final) + b_final:0.2f}, target value: {y_train[i]}")

输出结果如下

Iteration    0: Cost 28989.11   
Iteration  100: Cost   696.86   
Iteration  200: Cost   696.65   
Iteration  300: Cost   696.43   
Iteration  400: Cost   696.21   
Iteration  500: Cost   696.00   
Iteration  600: Cost   695.78   
Iteration  700: Cost   695.57   
Iteration  800: Cost   695.36   
Iteration  900: Cost   695.14   
b,w found by gradient descent: -0.00,[ 0.20253263  0.00112386 -0.00213202 -0.00933401] 
prediction: 425.71, target value: 460
prediction: 286.41, target value: 232
prediction: 172.23, target value: 178
http://www.hskmm.com/?act=detail&tid=26358

相关文章:

  • 概率论小测试
  • AI 产品研发的一些思考
  • 3.模块化与MVVM设计模式
  • 2025舒适轮胎厂家、静音轮胎厂家企业品牌权威推荐榜:静音技术与驾乘体验口碑之选
  • 幻想是最廉价的止疼药
  • 20251005 耳朵龙字符串
  • 玩转树莓派屏幕之五:自定义LCD屏幕显示
  • AtCoder ARC207 总结
  • 2025.10.7模拟赛
  • 详细介绍:ZLG ZCANPro,ECU刷新,bug分享
  • 好好学习, 天天向上
  • 2.洋葱开发法
  • OpenStack搭建
  • OpenStack实验过程
  • 2025.10.7+7
  • oppoR9m刷Linux系统:VCOM模式备份系统与基带IMEI/NVRAM/QCN
  • 两个开源中国象棋引擎的编译
  • 推荐一款Swift开发框架- Aquarius
  • 1.如何导入Aquarius开发框架
  • 课程作业(10月8日)
  • 帮宣——可控核聚变
  • 浅谈导数
  • 洛谷P5304 [GXOI/GZOI2019] 旅行者(二进制分类技巧)
  • 英语_阅读_AI Robot_待读
  • 【C++】AVL树的概念及完成(万字图文超详解)
  • 打造自主学习的AI Agent:强化学习+LangGraph代码示例
  • 关于二分
  • NKOJ全TJ计划——NP11721
  • 印度全球能力中心2030年展望与技术基建规划
  • NOI Linux 食用教程