第4.6节 实例分析手写体识别
各位朋友大家好,欢迎来到月来客栈,我是掌柜空字符。
本期推送内容目录如下,如果你觉得本期内容对你所有帮助欢迎点个赞、关个注、下回更新不迷路。
4.6 实例分析手写体识别 4.6.1数据预处理 4.6.2 模型选择 4.6.3 模型测试 4.6.4 小结
4.6 实例分析手写体识别
经过前面几节的介绍,我们对模型的改善与泛化已经有了一定的认识。下面笔者就通过一个实际的手写体分类任务进行示范,介绍一下常见的操作流程。同时顺便介绍一下sklearn和matplotlib中常见方法的使用,完整代码见Book/Chapter04/05_digits_classification.py
文件。
4.6.1数据预处理
在4.1.3节中,笔者详细介绍了为何需要对输入特征进行标准化操作,以及一种常见的标准化方法。接下来,就来看一下标准化在模型训练过程中的具体流程。
1. 载入数据集
首先,需要载入在训练模型时所用到的数据集。这里以sklearn中常见的手写体数据集load_digits
为例,代码如下:
def load_data():
data = load_digits()
x, y = data.data, data.target
load_digits
数据集一共包含1797个样本共10个类别,每个样本包含64个特征维度。在载入完成后还可以对其进行可视化,如图4-23所示。
2. 划分数据集
其次,在开始进行标准化之前,需要将数据集分割成训练集和测试集两部分,这里可以借助sklearn中的train_test_split
方法来完成,代码如下:
from sklearn.model_selection import train_test_split
def load_data():
#此处接“1.载入数据集”中代码
x_train, x_test, y_train, y_test = \
train_test_split(x, y, test_size=0.3, random_state=20)
在上述代码中,第5行中test_size=0.3
表示测试集的比例为30%,random_state=20
表示设置一种状态值,它的作用是使每次划分的结果都一样,同时也可以设置其他值。同时,在sklearn中对于包含随机操作的函数或者方法,一般都有这个参数,固定下来的目的是便于其他人复现你的结果。
3. 对训练集标准化
然后,对训练集进行标准化,并保存标准化过程中计算得到的相关参数。例如在以4.2.3节中的方法进行标准化时,就需要保存每个维度对应的均值μ和标准差σ。这里可以借助sklearn中的StandardScaler
方法来完成,代码如下:
from sklearn.preprocessing import StandardScaler
def load_data():
#此处接“2.划分数据集”中代码
ss = StandardScaler()
x_train = ss.fit_transform(x_train)
在上述代码中,第4行用来定义4.2.3节中的标准化方法,第5行先计算每个维度需要用到的均值和方差,然后对每个维度进行标准化,同时,第5行也可以分开来写,代码如下:
def load_data():
#此处接“2.划分数据集”中代码
ss = StandardScaler()
ss.fit(x_train) #先计算每3个维度需要用到的均值和方差
x_train = ss.transform(x_train)#再对每个维度进行标准化
4. 对测试集标准化
最后,利用在训练集上计算得到的参数,对测试集(及未来的新数据)进行标准化,代码如下: