查看原文
其他

复现经典:《统计学习方法》第1章 统计学习方法概论

机器学习初学者 机器学习初学者 2022-05-16

本文是李航老师的《统计学习方法》[1]一书的代码复现。

作者:黄海广[2]

备注:代码都可以在github[3]中下载。

我将陆续将代码发布在公众号“机器学习初学者”,敬请关注。

代码目录

  • 第 1 章 统计学习方法概论
  • 第 2 章 感知机
  • 第 3 章 k 近邻法
  • 第 4 章 朴素贝叶斯
  • 第 5 章 决策树
  • 第 6 章 逻辑斯谛回归
  • 第 7 章 支持向量机
  • 第 8 章 提升方法
  • 第 9 章 EM 算法及其推广
  • 第 10 章 隐马尔可夫模型
  • 第 11 章 条件随机场
  • 第 12 章 监督学习方法总结

代码参考:wzyonggege[4],WenDesi[5],火烫火烫的[6]

第 1 章 统计学习方法概论

1.统计学习是关于计算机基于数据构建概率统计模型并运用模型对数据进行分析与预测的一门学科。统计学习包括监督学习、非监督学习、半监督学习和强化学习。

2.统计学习方法三要素——模型、策略、算法,对理解统计学习方法起到提纲挈领的作用。

3.本书主要讨论监督学习,监督学习可以概括如下:从给定有限的训练数据出发, 假设数据是独立同分布的,而且假设模型属于某个假设空间,应用某一评价准则,从假设空间中选取一个最优的模型,使它对已给训练数据及未知测试数据在给定评价标准意义下有最准确的预测。

4.统计学习中,进行模型选择或者说提高学习的泛化能力是一个重要问题。如果只考虑减少训练误差,就可能产生过拟合现象。模型选择的方法有正则化与交叉验证。学习方法泛化能力的分析是统计学习理论研究的重要课题。

5.分类问题、标注问题和回归问题都是监督学习的重要问题。本书中介绍的统计学习方法包括感知机、近邻法、朴素贝叶斯法、决策树、逻辑斯谛回归与最大熵模型、支持向量机、提升方法、EM 算法、隐马尔可夫模型和条件随机场。这些方法是主要的分类、标注以及回归方法。它们又可以归类为生成方法与判别方法。

使用最小二乘法拟和曲线

高斯于 1823 年在误差独立同分布的假定下,证明了最小二乘方法的一个最优性质: 在所有无偏的线性估计类中,最小二乘方法是其中方差最小的!对于数据

拟合出函数

有误差,即残差:

此时范数(残差平方和)最小时, 和  相似度最高,更拟合

一般的次的多项式,

为参数

最小二乘法就是要找到一组  ,使得 (残差平方和) 最小

即,求 


举例:我们用目标函数, 加上一个正态分布的噪音干扰,用多项式去拟合【例 1.1 11 页】

import numpy as npimport scipy as spfrom scipy.optimize import leastsqimport matplotlib.pyplot as plt%matplotlib inline
  • ps: numpy.poly1d([1,2,3]) 生成 *
# 目标函数def real_func(x): return np.sin(2*np.pi*x)
# 多项式def fit_func(p, x): f = np.poly1d(p) return f(x)
# 残差def residuals_func(p, x, y): ret = fit_func(p, x) - y return ret
# 十个点x = np.linspace(0, 1, 10)x_points = np.linspace(0, 1, 1000)# 加上正态分布噪音的目标函数的值y_ = real_func(x)y = [np.random.normal(0, 0.1) + y1 for y1 in y_]

def fitting(M=0): """ M 为 多项式的次数 """ # 随机初始化多项式参数 p_init = np.random.rand(M + 1) # 最小二乘法 p_lsq = leastsq(residuals_func, p_init, args=(x, y)) print('Fitting Parameters:', p_lsq[0])
# 可视化 plt.plot(x_points, real_func(x_points), label='real') plt.plot(x_points, fit_func(p_lsq[0], x_points), label='fitted curve') plt.plot(x, y, 'bo', label='noise') plt.legend() return p_lsq

M=0

# M=0p_lsq_0 = fitting(M=0)
Fitting Parameters: [0.02515259]

M=1

# M=1p_lsq_1 = fitting(M=1)
Fitting Parameters: [-1.50626624 0.77828571]

M=3

# M=3p_lsq_3 = fitting(M=3)
Fitting Parameters: [ 2.21147559e+01 -3.34560175e+01 1.13639167e+01 -2.82318048e-02]

M=9

# M=9p_lsq_9 = fitting(M=9)
Fitting Parameters: [-1.70872086e+04 7.01364939e+04 -1.18382087e+05 1.06032494e+05
-5.43222991e+04 1.60701108e+04 -2.65984526e+03 2.12318870e+02
-7.15931412e-02 3.53804263e-02]

当 M=9 时,多项式曲线通过了每个数据点,但是造成了过拟合

正则化

结果显示过拟合, 引入正则化项(regularizer),降低过拟合

回归问题中,损失函数是平方损失,正则化可以是参数向量的 L2 范数,也可以是 L1 范数。

  • L1: regularization*abs(p)

  • L2: 0.5 * regularization * np.square(p)

regularization = 0.0001def residuals_func_regularization(p, x, y): ret = fit_func(p, x) - y ret = np.append(ret, np.sqrt(0.5 * regularization * np.square(p))) # L2范数作为正则化项 return ret
# 最小二乘法,加正则化项p_init = np.random.rand(9 + 1)p_lsq_regularization = leastsq( residuals_func_regularization, p_init, args=(x, y))
plt.plot(x_points, real_func(x_points), label='real')plt.plot(x_points, fit_func(p_lsq_9[0], x_points), label='fitted curve')plt.plot( x_points, fit_func(p_lsq_regularization[0], x_points), label='regularization')plt.plot(x, y, 'bo', label='noise')plt.legend()

参考资料

[1] 《统计学习方法》: https://baike.baidu.com/item/统计学习方法/10430179
[2] 黄海广: https://github.com/fengdu78
[3] github: https://github.com/fengdu78/lihang-code
[4] wzyonggege: https://github.com/wzyonggege/statistical-learning-method
[5] WenDesi: https://github.com/WenDesi/lihang_book_algorithm
[6] 火烫火烫的: https://blog.csdn.net/tudaodiaozhale



关于本站

机器学习初学者”公众号由是黄海广博士创建,黄博个人知乎粉丝23000+,github排名全球前100名(33000+)。本公众号致力于人工智能方向的科普性文章,为初学者提供学习路线和基础资料。原创作品有:吴恩达机器学习个人笔记、吴恩达深度学习笔记等。


往期精彩回顾

备注:加入本站微信群或者qq群,请回复“加群

您可能也对以下帖子感兴趣

文章有问题?点此查看未经处理的缓存