查看原文
其他

第4.6节 实例分析手写体识别

空字符 月来客栈 2024-01-19

各位朋友大家好,欢迎来到月来客栈,我是掌柜空字符。

本期推送内容目录如下,如果你觉得本期内容对你所有帮助欢迎点个赞、关个注、下回更新不迷路

  • 4.6 实例分析手写体识别
    • 4.6.1数据预处理
    • 4.6.2 模型选择
    • 4.6.3 模型测试
    • 4.6.4 小结

4.6 实例分析手写体识别

经过前面几节的介绍,我们对模型的改善与泛化已经有了一定的认识。下面笔者就通过一个实际的手写体分类任务进行示范,介绍一下常见的操作流程。同时顺便介绍一下sklearn和matplotlib中常见方法的使用,完整代码见Book/Chapter04/05_digits_classification.py文件。

4.6.1数据预处理

在4.1.3节中,笔者详细介绍了为何需要对输入特征进行标准化操作,以及一种常见的标准化方法。接下来,就来看一下标准化在模型训练过程中的具体流程。

1. 载入数据集

首先,需要载入在训练模型时所用到的数据集。这里以sklearn中常见的手写体数据集load_digits为例,代码如下:

def load_data():
    data = load_digits()
    x, y = data.data, data.target

load_digits数据集一共包含1797个样本共10个类别,每个样本包含64个特征维度。在载入完成后还可以对其进行可视化,如图4-23所示。

图 4-23 手写体可视化

2. 划分数据集

其次,在开始进行标准化之前,需要将数据集分割成训练集和测试集两部分,这里可以借助sklearn中的train_test_split方法来完成,代码如下:

from sklearn.model_selection import train_test_split
def load_data():
    #此处接“1.载入数据集”中代码
    x_train, x_test, y_train, y_test = \
            train_test_split(x, y, test_size=0.3, random_state=20)

在上述代码中,第5行中test_size=0.3表示测试集的比例为30%,random_state=20表示设置一种状态值,它的作用是使每次划分的结果都一样,同时也可以设置其他值。同时,在sklearn中对于包含随机操作的函数或者方法,一般都有这个参数,固定下来的目的是便于其他人复现你的结果。

3. 对训练集标准化

然后,对训练集进行标准化,并保存标准化过程中计算得到的相关参数。例如在以4.2.3节中的方法进行标准化时,就需要保存每个维度对应的均值μ和标准差σ。这里可以借助sklearn中的StandardScaler方法来完成,代码如下:

from sklearn.preprocessing import StandardScaler
def load_data():
    #此处接“2.划分数据集”中代码
    ss = StandardScaler()
    x_train = ss.fit_transform(x_train)

在上述代码中,第4行用来定义4.2.3节中的标准化方法,第5行先计算每个维度需要用到的均值和方差,然后对每个维度进行标准化,同时,第5行也可以分开来写,代码如下:

def load_data():
    #此处接“2.划分数据集”中代码
    ss = StandardScaler()
    ss.fit(x_train) #先计算每3个维度需要用到的均值和方差 
    x_train = ss.transform(x_train)#再对每个维度进行标准化

4. 对测试集标准化

最后,利用在训练集上计算得到的参数,对测试集(及未来的新数据)进行标准化,代码如下:

继续滑动看下一个

第4.6节 实例分析手写体识别

空字符 月来客栈
向上滑动看下一个

您可能也对以下帖子感兴趣

文章有问题?点此查看未经处理的缓存