当前位置: 首页 > news >正文

线性回归 C++ 实现

线性回归 C++ 实现

参考链接

https://www.cnblogs.com/eat-too-much/p/16796533.html

代码

#include <iostream>
#include <vector>
#include <random>
#include <cmath>
#include <algorithm>using namespace std;
// 计算点积
float dot_product(const vector<float>& a, const vector<float>& b) {float sum = 0.0;for (size_t i = 0; i < a.size(); ++i) {sum += a[i] * b[i];}return sum;
}
// 生成数据
void synthetic_data(const vector<float> & w, float b, int num_examples, vector<vector<float>>  & X, vector<float> & y) {random_device rd;mt19937 gen(rd());normal_distribution<float> normal(0, 1);X.resize(num_examples, vector<float>(w.size()));y.resize(num_examples);for(int i=0; i<num_examples; ++i) {for(int j=0; j<w.size(); j++) {X[i][j] = normal(gen);}float err = normal(gen) * 0.01;y[i] = (float)(dot_product(X[i], w) + b + err);}
}// 批量数据迭代器
void data_iter(int batch_size, const vector<vector<float>>& features, const vector<float>& labels, vector<pair<vector<vector<float>>, vector<float>>>& batches) {vector<int> indices(features.size());iota(indices.begin(), indices.end(), 0);random_shuffle(indices.begin(), indices.end());batches.clear();for (int i = 0; i < features.size(); i += batch_size) {vector<vector<float>> X_batch;vector<float> y_batch;int batch_end = min(i + batch_size, static_cast<int>(features.size()));X_batch.resize(batch_end - i);y_batch.resize(batch_end - i);for (int j = i; j < batch_end; ++j) {X_batch[j - i] = features[indices[j]];y_batch[j - i] = labels[indices[j]];}batches.push_back({X_batch, y_batch});}
}// 线性回归模型
vector<float> linreg(const vector<vector<float>>& X, const vector<float>& w, float b) {vector<float> y_hat(X.size());for (size_t i = 0; i < X.size(); ++i) {y_hat[i] = dot_product(X[i], w) + b;}return y_hat;
}// 均方损失
float squared_loss(const vector<float>& y_hat, const vector<float>& y) {float loss = 0.0;for (size_t i = 0; i < y.size(); ++i) {loss += pow(y_hat[i] - y[i], 2) / 2;}return loss / y.size();
}// 小批量随机梯度下降
void sgd(vector<float>& params, const vector<float>& grads, float lr, int batch_size) {for (size_t i = 0; i < params.size(); ++i) {params[i] -= lr * grads[i] / batch_size;}
}void sgd_bias(float& b, float grad_b, float lr, int batch_size) {b -= lr * grad_b / batch_size;
}int main() {// 真实参数vector<float> true_w = {2.0, -3.4};float true_b = 4.2;// 生成数据int num_examples = 1000;vector<vector<float>> features;vector<float> labels;synthetic_data(true_w, true_b, num_examples, features, labels);// 初始化模型参数vector<float> w = {0.0, 0.0};float b = 0.0;// 超参数float lr = 0.0001;int num_epochs = 1000;int batch_size = 10;// 训练过程for (int epoch = 0; epoch < num_epochs; ++epoch) {vector<pair<vector<vector<float>>, vector<float>>> batches;data_iter(batch_size, features, labels, batches);for (auto& batch : batches) {vector<vector<float>> X_batch = batch.first;vector<float> y_batch = batch.second;// 前向传播vector<float> y_hat = linreg(X_batch, w, b);// 计算损失  没有必要 python 用于 自动微分计算才需要//float l = squared_loss(y_hat, y_batch);// 计算梯度vector<float> grad_w(w.size(), 0.0);float grad_b = 0.0;// 计算梯度(for (size_t i = 0; i < y_batch.size(); ++i) {float err = y_hat[i] - y_batch[i];for (size_t j = 0; j < w.size(); ++j) {grad_w[j] += err * X_batch[i][j];  // 无除法,变为 sum}grad_b += err;  // 无除法,变为 sum}// 更新参数sgd(w, grad_w, lr, batch_size);sgd_bias(b, grad_b, lr, batch_size);}// 打印损失vector<float> all_y_hat = linreg(features, w, b);float train_l = squared_loss(all_y_hat, labels);cout << "epoch " << epoch + 1 << ", loss: " << train_l << endl;}// 输出估计误差cout << "w真实值: " << endl;for (size_t i = 0; i < w.size(); ++i) {cout << true_w[i]<< " ";}cout << endl;cout << "w的估计误差: ";for (size_t i = 0; i < w.size(); ++i) {cout << true_w[i] - w[i] << " ";}cout << endl;cout << "w的估计: ";for (size_t i = 0; i < w.size(); ++i) {cout << w[i] << " ";}cout << endl;cout << "b的真实: " << true_b << endl;cout << "b的估计误差: " << true_b - b << endl;cout << "b的估计: " << b << endl;return 0;
}
http://www.hskmm.com/?act=detail&tid=29388

相关文章:

  • 内存分区
  • Spring Data JPA学习笔记
  • P1112 波浪数 题解
  • 20232411 2025-2026-1 《网络与系统攻防技术》实验一实验报告
  • 使用 Pascal 实现英文数字验证码识别系统
  • PWN手的成长之路-15-jarvisoj_level2_x64
  • 2025.10.12——1绿
  • 价值博弈场的工程实现:构建数字文明的价值免疫系统——声明Ai生成
  • 基于 Rust 的英文数字验证码识别系统设计与实现
  • 2025年两联供室内机厂家最新权威推荐榜:技术实力与市场口碑
  • 2025武汉商铺装修防水厂家最新权威推荐榜:专业施工与品质保
  • 2025铝合金微弧氧化厂家权威推荐榜:表面处理技术实力深度解
  • 2025杉木木方厂家最新权威推荐榜:优质木材与稳定供应口碑之
  • 2025年厂房保养厂家最新权威推荐榜:专业维护与成本控制优选
  • 使用C语言实现重写stm32的启动文件
  • 2025中医师承权威推荐榜:名师带徒与临床实践深度解析
  • 让我们开始 CSS 的学习之旅
  • 2025液压无损扒胎机厂家权威推荐榜:高效无损与耐用性能深度
  • Linux环境下的UDEV机制及其与守护进程的关联
  • 在Red Hat Enterprise Linux 9上使用Docker快速安装并部署
  • [转] 基于ubuntu-base进行根文件系统的修改与打包
  • jquery 字符串转数字
  • AI圈每日技术学习---紧跟时代脚步(N8n工作流)
  • B -识别浮点常量问题
  • 小C语言--词法分析程序-A
  • 查寻MySQL或SQL Server的连接数,并配置超时时间和最大连接量
  • 2025宿舍上下床厂家权威推荐榜:耐用设计与空间优化口碑之选
  • 梳理一下Java中为保证线程安全提供了那些方案 - 浪矢
  • 处理Git错误:“invalid object [hash]”和“unable to read tree [hash]”
  • 2025厂房恒温恒湿设备厂家权威推荐榜:精准控温与节能技术深