查看原文
其他

使用sklearn随机森林算法实现手写数字识别

gloomyfish OpenCV学堂 2020-02-04

一:随机森林算法是怎么工作的

随机森林(random forest)是2001年提出来同时支持数据的回归与分类预测算法,在具体了解随机森林算法之前,首先看一下决策树算法(Decision Tree)决策树算法通过不断的分支条件筛选,最终预测分类做出决定,举个简单的例子,你去找工作,对方给了你一个offer,下面可能就是你决定是否最终接受或者拒绝offer一系列条件就是内部节点(矩形)最终的决定就是外部节点(叶子-椭圆)

后你自己可能一个人根据上述条件决定接受了offer,但是有时候你还很不确定,你就会去很随机的问问你周围的几个朋友,他们也会根据你的情况与掌握的信息作出一系列的决策,做个形象的比喻,他们就是一棵棵单独存在的决策树,最终你根据这些结果决定接受还是拒绝offer,前一种情况你自己做出接受还是拒绝offer就叫决策树算法,后面一种情况,你一个人拿不定主意,还会随机问你周围的几个朋友一起给你参谋,最终做出接受还是拒绝offer的决定方式,你的那些朋友也是一棵棵单独存在的决策树,他们合在一起做决定,这个就叫做随机森林

当你在使用随机森林做决定时候,有时候分支条件太多,有些不是决定因素的分支条件其实你可以不考虑的,比如在决定是否接受或者拒绝offer的时候你可能不会考虑公司是否有程序员鼓励师(啊!!!!),这个时候需要对这么小分支看成噪声,进行剪枝算法处理生成决策树、最终得到随机森林。同时随机森林的规模越大(决策树越多)、它的决策准确率也越高。随机森林算法在金融风控分析、股票交易数据分析、电子商务等领域均有应用。

二:sklearn中随机森林算法函数使用

基于sklearn中随机森林算法函数创建随机森林实现mnist手写数字识别,完整的代码实现如下:

  1. from sklearn.ensemble import RandomForestClassifier

  2. from sklearn.metrics import accuracy_score

  3. import tensorflow.examples.tutorials.mnist.input_data as input_data

  4. data_dir = 'MNIST_data/'

  5. mnist = input_data.read_data_sets(data_dir, one_hot=False)

  6. batch_size = 50000

  7. batch_x, batch_y = mnist.train.next_batch(batch_size)

  8. test_x = mnist.test.images[:10000]

  9. text_y = mnist.test.labels[:10000]

  10. print("start random forest")

  11. for i in range(10, 200, 10):

  12.    clf_rf = RandomForestClassifier(n_estimators=i)

  13.    clf_rf.fit(batch_x, batch_y)

  14.    y_pred_rf = clf_rf.predict(test_x)

  15.    acc_rf = accuracy_score(text_y, y_pred_rf)

  16.    print("n_estimators = %d, random forest accuracy: %f"%(i, acc_rf))

n_estimators 

表示树的数量,从运行结果可以看出,随着随机森林树的数目增加,预测的准确率也在不断的提升


【推荐阅读】

OpenCV Gabor滤波器实现纹理提取与缺陷分析

OpenCV中如何获得物体的主要方向

tensorflow中实现神经网络训练手写数字数据集mnist

新课程发布 - 《tensorflow零基础入门视频教程》

使用Tensorflow Object Detection API实现对象检测

为山者基于一篑之土,以成千丈之峭 凿井者起于三寸之坎,以就万仞之深


关注【OpenCV学堂】

长按或者扫码二维码即可关注


    您可能也对以下帖子感兴趣

    文章有问题?点此查看未经处理的缓存