查看原文
其他

【译】R包介绍:Online Random Forest

顾全 R语言中文社区 2019-04-22

作者:顾全,浙江大学软件工程硕士,现任桃树科技算法工程师

地址:

https://github.com/ZJUguquan/OnlineRandomForest

参与:Cynthia

翻译:本文为天善智能编译,未经容许,禁止转载


介绍

Online Random Forest(ORF) 是由Amir Saffari等人最先提出。之后,Arthur Lui使用Python实现了算法。非常感谢他们的工作。在论文内容和Lui的算法的基础上,我通过R和R6包重构了代码。此外,ORF在此包中的实现,与randomForest结合,使它同时支持增量学习和批量学习,例如:在ORF的基础上构建树,然后通过ORF进行更新。通过这种方法,它将比以前快得多。

安装

if(!require(devtools)) install.packages("devtools")
devtools::install_github("ZJUguquan/OnlineRandomForest")

快速启动

  • 最小举例:增量学习

library(OnlineRandomForest)
param <- list('minSamples'= 1, 'minGain'= 0.1, 'numClasses'= 3, 'x.rng'= dataRange(iris[1:4]))
orf <- ORF$new(param, numTrees = 10)
for (i in 1:150) orf$update(iris[i, 1:4], as.integer(iris[i, 5]))
cat("Mean depth of trees in the forest is:", orf$meanTreeDepth(), "\n")
orf$forest[[2]]$draw()

## Mean depth of trees in the forest is: 3

## Root X4 < 1.21
## |----L: X3 < 2.38
##      |----L: Leaf 1
##      |----R: Leaf 2
## |----R: X4 < 2.15
##      |----L: X1 < 4.92
##           |----L: Leaf 3
##           |----R: Leaf 3
##      |----R: Leaf 3

  • 分类举例

library(OnlineRandomForest)

# data preparation
dat <- iris; dat[,5] <- as.integer(dat[,5])
x.rng <- dataRange(dat[1:4])
param <- list('minSamples'= 2, 'minGain'= 0.2, 'numClasses'= 3, 'x.rng'= x.rng)
ind.gen <- sample(1:150,30) # for generate ORF
ind.updt <- sample(setdiff(1:150, ind.gen), 100) # for uodate ORF
ind.test <- setdiff(setdiff(1:150, ind.gen), ind.updt) # for test

# construct ORF and update
rf <- randomForest::randomForest(factor(Species) ~ ., data = dat[ind.gen, ], maxnodes = 2, ntree = 100)
orf <- ORF$new(param)
orf$generateForest(rf, df.train = dat[ind.gen, ], y.col = "Species")
cat("Mean size of trees in the forest is:", orf$meanTreeSize(), "\n")


## Mean size of trees in the forest is: 3


for (i in ind.updt) {
 orf$update(dat[i, 1:4], dat[i, 5])
}
cat("After update, mean size of trees in the forest is:", orf$meanTreeSize(), "\n")


## After update, mean size of trees in the forest is: 11.9


# predict
orf$confusionMatrix(dat[ind.test, 1:4], dat[ind.test, 5], pretty = T)


##
##  
##    Cell Contents
## |-------------------------|
## |                       N |
## |           N / Row Total |
## |           N / Col Total |
## |-------------------------|
##
##  
## Total Observations in Table:  20
##
##  
##              | actual
##   prediction |         1 |         2 |         3 | Row Total |
## -------------|-----------|-----------|-----------|-----------|
##            1 |         4 |         0 |         0 |         4 |
##              |     1.000 |     0.000 |     0.000 |     0.200 |
##              |     1.000 |     0.000 |     0.000 |           |
## -------------|-----------|-----------|-----------|-----------|
##            2 |         0 |         9 |         2 |        11 |
##              |     0.000 |     0.818 |     0.182 |     0.550 |
##              |     0.000 |     1.000 |     0.286 |           |
## -------------|-----------|-----------|-----------|-----------|
##            3 |         0 |         0 |         5 |         5 |
##              |     0.000 |     0.000 |     1.000 |     0.250 |
##              |     0.000 |     0.000 |     0.714 |           |
## -------------|-----------|-----------|-----------|-----------|
## Column Total |         4 |         9 |         7 |        20 |
##              |     0.200 |     0.450 |     0.350 |           |
## -------------|-----------|-----------|-----------|-----------|
##
##


# compare
table(predict(rf, newdata = dat[ind.test,]) == dat[ind.test, 5])


## FALSE  TRUE
##     9    11


table(orf$predicts(X = dat[ind.test,]) == dat[ind.test, 5])


## FALSE  TRUE
##     2    18


  • 回归举例

# data preparation
if(!require(ggplot2)) install.packages("ggplot2")
data("diamonds", package = "ggplot2")
dat <- as.data.frame(diamonds[sample(1:53000,1000), c(1:6,8:10,7)])
for (col in c("cut","color","clarity")) dat[[col]] <- as.integer(dat[[col]]) # Don't forget this
x.rng <- dataRange(dat[1:9])
param <- list('minSamples'= 10, 'minGain'= 1, 'maxDepth' = 10, 'x.rng'= x.rng)
ind.gen <- sample(1:1000, 800)
ind.updt <- sample(setdiff(1:1000, ind.gen), 100)
ind.test <- setdiff(setdiff(1:1000, ind.gen), ind.updt)


# construct ORF
rf <- randomForest::randomForest(price ~ ., data = dat[ind.gen, ], maxnodes = 20, ntree = 100)
orf <- ORF$new(param)
orf$generateForest(rf, df.train = dat[ind.gen, ], y.col = "price")
orf$meanTreeSize()


## [1] 39


# and update
for (i in ind.updt) {
 orf$update(dat[i, 1:9], dat[i, 10])

}
orf$meanTreeSize()


## [1] 105.7


# predict and compare
if(!require(Metrics)) install.packages("Metrics")
preds.rf <- predict(rf, newdata = dat[ind.test,])
Metrics::rmse(preds.rf, dat$price[ind.test])


## [1] 988.8055


preds <- orf$predicts(dat[ind.test, 1:9])
Metrics::rmse(preds, dat$price[ind.test]) # make progress


## [1] 869.9613


其他用途

  • 在 Tree 类中

ta <- Tree$new("abc", NULL, NULL)
tb <- Tree$new(1, Tree$new(36), Tree$new(3))
tc <- Tree$new(89, tb, ta)
tc$draw()

# update tc
tc$right$updateChildren(Tree$new("666"), Tree$new(999))
tc$right$right$updateChildren(Tree$new("666"), Tree$new(999))
tc$draw()


  • 通过random Forest包配置一个Online random Tree,并升级

# data preparation
library(randomForest)
dat1 <- iris; dat1[,5] <- as.integer(dat1[,5])
rf <- randomForest(factor(Species) ~ ., data = dat1, maxnodes = 3)
treemat1 <- getTree(rf, 1, labelVar=F)
treemat1 <- cbind(treemat1, node.ind = 1:nrow(treemat1))
x.rng1 <- dataRange(dat1[1:4])
param1 <- list('minSamples'= 5, 'minGain'= 0.1, 'numClasses'= 3, 'x.rng'= x.rng1)
ind.gen <- sample(1:150,50) # for generate ORT
ind.updt <- setdiff(1:150, ind.gen) # for update ORT

# origin
ort2 <- ORT$new(param1)
ort2$draw()


## Root 1
##  Leaf 1


# generate a tree


ort2$generateTree(treemat1, df.node = dat1[ind.gen,])
ort2$draw()


## Root X3 < 2.45
## |----L: Leaf 1
## |----R: X3 < 4.75
##      |----L: Leaf 2
##      |----R: Leaf 3


# update this tree
for(i in ind.updt) {
 ort2$update(dat1[i,1:4], dat1[i,5])
}
ort2$draw()


## Root X3 < 2.45
## |----L: Leaf 1
## |----R: X3 < 4.75
##      |----L: Leaf 2
##      |----R: X4 < 2.19
##           |----L: X2 < 3.68
##                |----L: X1 < 7.12
##                     |----L: X3 < 4.06
##                          |----L: Leaf 1
##                          |----R: Leaf 3
##                     |----R: Leaf 3
##                |----R: Leaf 1
##           |----R: Leaf 3


大家都在看

2017年R语言发展报告(国内)

R语言中文社区历史文章整理(作者篇)

R语言中文社区历史文章整理(类型篇)


公众号后台回复关键字即可学习

回复 R                  R语言快速入门及数据挖掘 
回复 Kaggle案例  Kaggle十大案例精讲(连载中)
回复 文本挖掘      手把手教你做文本挖掘
回复 可视化          R语言可视化在商务场景中的应用 
回复 大数据         大数据系列免费视频教程 
回复 量化投资      张丹教你如何用R语言量化投资 
回复 用户画像      京东大数据,揭秘用户画像
回复 数据挖掘     常用数据挖掘算法原理解释与应用
回复 机器学习     人工智能系列之机器学习与实践
回复 爬虫            R语言爬虫实战案例分享

    您可能也对以下帖子感兴趣

    文章有问题?点此查看未经处理的缓存