【教程】Python实现随机森林遥感图像分类
本期以landsat提取植被为例,更新如何用python实现随机森林遥感图像分类,原作者知乎王振庆,如对该文有任何疑问,请点击文章最下方阅读原文,与作者沟通交流。
随机森林,顾名思义是用随机的方式建立一个森林,森林里面有很多的决策树组成,随机森林的每一棵决策树之间是没有关联的。
以landsat提取植被为例,其实不论什么影像分什么类,操作都是一样的。
(1) 制作样本
a. 数字矢量化样本标签图
随机森林属于监督分类,监督分类是一定需要样本的。我们在Arcgis(ENVI也可)中目视解译矢量化一些植被与非植被的典型样本,然后【要素转栅格】将矢量数据转为栅格标签图。其中要注意:植被与非植被的值要设置为不同;转栅格的范围要与遥感图像一致。这样做的目的是为了方便抓取与标签图对应位置的遥感图像各波段值。
图1是landset的真彩色图像,图2是数字化样本并转成栅格的标签图像,标签图为单波段灰度图,为了更好地展示,我进行了RGB渲染。其中绿色的为植被样本,紫色的为非植被样本。
b. 样本数据集制作
样本数据集为txt,格式如图3所示。每行的前7个数为landset的7个波段值,Vegetation和Non-Vegetation表示该数据为植被还是非植被。具体制作过程直接上代码,注释很详细。
图3 样本数据集示意图
import gdal
import os
import random
#读取tif数据集
def readTif(fileName):
dataset = gdal.Open(fileName)
if dataset == None:
print(fileName+"文件无法打开")
return dataset
Landset_Path = r"D:\ROI.tif"
LabelPath = r"D:\label.tif"
txt_Path = r"D:\data.txt"
# 读取图像数据
dataset = readTif(Landset_Path)
Tif_width = dataset.RasterXSize #栅格矩阵的列数
Tif_height = dataset.RasterYSize #栅格矩阵的行数
Tif_bands = dataset.RasterCount #波段数
Tif_geotrans = dataset.GetGeoTransform()#获取仿射矩阵信息
Landset_data = dataset.ReadAsArray(0,0,Tif_width,Tif_height)
dataset = readTif(LabelPath)
Label_data = dataset.ReadAsArray(0,0,Tif_width,Tif_height)
# 写之前,先检验文件是否存在,存在就删掉
if os.path.exists(txt_Path):
os.remove(txt_Path)
# 以写的方式打开文件,如果文件不存在,就会自动创建
file_write_obj = open(txt_Path, 'w')
#首先收集植被类别样本,
#遍历所有像素值,
#为植被的像元全部收集。
count = 0
for i in range(Label_data.shape[0]):
for j in range(Label_data.shape[1]):
# 我设置的植被类别在标签图中像元值为1
if(Label_data[i][j] == 1):
var = ""
for k in range(Landset_data.shape[0]):
var = var + str(Landset_data[k][i][j])+","
var = var + "Vegetation"
file_write_obj.writelines(var)
file_write_obj.write('\n')
count = count + 1
#其次收集非植被类别样本,
#因为非植被样本比植被样本多很多,
#所以采用在所有非植被类别中随机选择非植被样本,
#数量与植被样本数量保持一致。
Threshold = count
count = 0
for i in range(10000000000):
X_random = random.randint(0,Label_data.shape[0]-1)
Y_random = random.randint(0,Label_data.shape[1]-1)
# 我设置的非植被类别在标签图中像元值为0
if(Label_data[X_random][Y_random] == 0):
var = ""
for k in range(Landset_data.shape[0]):
var = var + str(Landset_data[k][X_random][Y_random])+","
var = var + "Non-Vegetation"
file_write_obj.writelines(var)
file_write_obj.write('\n')
count = count + 1
if(count == Threshold):
break
file_write_obj.close()
(2) 模型训练
随机森林模型我们采用sklearn库中自带的随机森林模型RandomForestClassifier。具体训练过程直接上代码,注释很详细。
from sklearn.ensemble import RandomForestClassifier
import numpy as np
from sklearn import model_selection
import pickle
# 定义字典,便于来解析样本数据集txt
def Iris_label(s):
it={b'Vegetation':0, b'Non-Vegetation':1}
return it[s]
path=r"D:\data.txt"
SavePath = r"D:\model.pickle"
# 1.读取数据集
data=np.loadtxt(path, dtype=float, delimiter=',', converters={7:Iris_label} )
# converters={7:Iris_label}中“7”指的是第8列:将第8列的str转化为label(number)
# 2.划分数据与标签
x,y=np.split(data,indices_or_sections=(7,),axis=1) #x为数据,y为标签
x=x[:,0:7] #选取前7个波段作为特征
train_data,test_data,train_label,test_label = model_selection.train_test_split(x,y, random_state=1, train_size=0.9,test_size=0.1)
# 3.用100个树来创建随机森林模型,训练随机森林
classifier = RandomForestClassifier(n_estimators=100,
bootstrap = True,
max_features = 'sqrt')
classifier.fit(train_data, train_label.ravel())#ravel函数拉伸到一维
# 4.计算随机森林的准确率
print("训练集:",classifier.score(train_data,train_label))
print("测试集:",classifier.score(test_data,test_label))
# 5.保存模型
#以二进制的方式打开文件:
file = open(SavePath, "wb")
#将模型写入文件:
pickle.dump(classifier, file)
#最后关闭文件:
file.close()
(3) 模型预测
import numpy as np
import gdal
import pickle
#读取tif数据集
def readTif(fileName):
dataset = gdal.Open(fileName)
if dataset == None:
print(fileName+"文件无法打开")
return dataset
#保存tif文件函数
def writeTiff(im_data,im_geotrans,im_proj,path):
if 'int8' in im_data.dtype.name:
datatype = gdal.GDT_Byte
elif 'int16' in im_data.dtype.name:
datatype = gdal.GDT_UInt16
else:
datatype = gdal.GDT_Float32
if len(im_data.shape) == 3:
im_bands, im_height, im_width = im_data.shape
elif len(im_data.shape) == 2:
im_data = np.array([im_data])
im_bands, im_height, im_width = im_data.shape
#创建文件
driver = gdal.GetDriverByName("GTiff")
dataset = driver.Create(path, int(im_width), int(im_height), int(im_bands), datatype)
if(dataset!= None):
dataset.SetGeoTransform(im_geotrans) #写入仿射变换参数
dataset.SetProjection(im_proj) #写入投影
for i in range(im_bands):
dataset.GetRasterBand(i+1).WriteArray(im_data[i])
del dataset
RFpath = r"D:\model.pickle"
Landset_Path = r"D:\20130514_ROI.tif"
SavePath = r"D:\save.tif"
dataset = readTif(Landset_Path)
Tif_width = dataset.RasterXSize #栅格矩阵的列数
Tif_height = dataset.RasterYSize #栅格矩阵的行数
Tif_geotrans = dataset.GetGeoTransform()#获取仿射矩阵信息
Tif_proj = dataset.GetProjection()#获取投影信息
Landset_data = dataset.ReadAsArray(0,0,Tif_width,Tif_height)
#调用保存好的模型
#以读二进制的方式打开文件
file = open(RFpath, "rb")
#把模型从文件中读取出来
rf_model = pickle.load(file)
#关闭文件
file.close()
#用读入的模型进行预测
# 在与测试前要调整一下数据的格式
data = np.zeros((Landset_data.shape[0],Landset_data.shape[1]*Landset_data.shape[2]))
for i in range(Landset_data.shape[0]):
data[i] = Landset_data[i].flatten()
data = data.swapaxes(0,1)
# 对调整好格式的数据进行预测
pred = rf_model.predict(data)
# 同样地,我们对预测好的数据调整为我们图像的格式
pred = pred.reshape(Landset_data.shape[1],Landset_data.shape[2])*255
pred = pred.astype(np.uint8)
# 将结果写到tif图像里
writeTiff(pred,Tif_geotrans,Tif_proj,SavePath)
-----END-----
社群交流 / 原创投稿 / 商务合作
(请添加下方小助手微信)
来源:生态遥感笔记
推荐阅读
推荐关注
温馨提示:近期,微信公众号信息流改版。每个用户可以设置 常读订阅号,这些订阅号将以大卡片的形式展示。因此,如果不想错过“测绘之家”的文章,你一定要进行以下操作:进入“测绘之家”公众号 → 点击右上角的 ··· 菜单 → 选择「设为星标」
↓↓↓点击下方“阅读原文”查看更多精彩内容...