lin_reg
Typer | Posted on | |
# ------------------------------
# 导入库
# ------------------------------
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
from sklearn.linear_model import LinearRegression
# ------------------------------
# 生成随机数据
# ------------------------------
pool = np.random.RandomState(10) # 创建随机数生成器
x = 5 * pool.rand(30) # 在区间 [0, 5) 上生成30个随机数
y = 3 * x - 2 + pool.randn(30) # 构造近似线性关系,加上高斯噪声
# ------------------------------
# 创建线性回归模型并训练
# ------------------------------
lregr = LinearRegression(fit_intercept=False) # 不拟合截距项
X = x[:, np.newaxis] # 将一维数据转换为二维列向量
lregr.fit(X, y) # 训练模型
# ------------------------------
# 生成预测结果
# ------------------------------
lspace = np.linspace(0, 5) # 生成绘图用的横坐标范围
X_regr = lspace[:, np.newaxis] # 转换为二维列向量
y_regr = lregr.predict(X_regr) # 计算对应的预测值
# ------------------------------
# 绘图展示
# ------------------------------
plt.scatter(x, y) # 原始散点图
plt.plot(X_regr, y_regr) # 回归直线
plt.show()