使用sklearn随机森林算法实现手写数字识别
一:随机森林算法是怎么工作的
随机森林(random forest)是2001年提出来同时支持数据的回归与分类预测算法,在具体了解随机森林算法之前,首先看一下决策树算法(Decision Tree)决策树算法通过不断的分支条件筛选,最终预测分类做出决定,举个简单的例子,你去找工作,对方给了你一个offer,下面可能就是你决定是否最终接受或者拒绝offer一系列条件就是内部节点(矩形)最终的决定就是外部节点(叶子-椭圆)
后你自己可能一个人根据上述条件决定接受了offer,但是有时候你还很不确定,你就会去很随机的问问你周围的几个朋友,他们也会根据你的情况与掌握的信息作出一系列的决策,做个形象的比喻,他们就是一棵棵单独存在的决策树,最终你根据这些结果决定接受还是拒绝offer,前一种情况你自己做出接受还是拒绝offer就叫决策树算法,后面一种情况,你一个人拿不定主意,还会随机问你周围的几个朋友一起给你参谋,最终做出接受还是拒绝offer的决定方式,你的那些朋友也是一棵棵单独存在的决策树,他们合在一起做决定,这个就叫做随机森林
当你在使用随机森林做决定时候,有时候分支条件太多,有些不是决定因素的分支条件其实你可以不考虑的,比如在决定是否接受或者拒绝offer的时候你可能不会考虑公司是否有程序员鼓励师(啊!!!!),这个时候需要对这么小分支看成噪声,进行剪枝算法处理生成决策树、最终得到随机森林。同时随机森林的规模越大(决策树越多)、它的决策准确率也越高。随机森林算法在金融风控分析、股票交易数据分析、电子商务等领域均有应用。
二:sklearn中随机森林算法函数使用
基于sklearn中随机森林算法函数创建随机森林实现mnist手写数字识别,完整的代码实现如下:
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score
import tensorflow.examples.tutorials.mnist.input_data as input_data
data_dir = 'MNIST_data/'
mnist = input_data.read_data_sets(data_dir, one_hot=False)
batch_size = 50000
batch_x, batch_y = mnist.train.next_batch(batch_size)
test_x = mnist.test.images[:10000]
text_y = mnist.test.labels[:10000]
print("start random forest")
for i in range(10, 200, 10):
clf_rf = RandomForestClassifier(n_estimators=i)
clf_rf.fit(batch_x, batch_y)
y_pred_rf = clf_rf.predict(test_x)
acc_rf = accuracy_score(text_y, y_pred_rf)
print("n_estimators = %d, random forest accuracy: %f"%(i, acc_rf))
n_estimators
表示树的数量,从运行结果可以看出,随着随机森林树的数目增加,预测的准确率也在不断的提升
【推荐阅读】
tensorflow中实现神经网络训练手写数字数据集mnist
使用Tensorflow Object Detection API实现对象检测
为山者基于一篑之土,以成千丈之峭 凿井者起于三寸之坎,以就万仞之深
关注【OpenCV学堂】
长按或者扫码二维码即可关注