第6.2节 基于K近邻垃圾邮件分类
各位朋友大家好,欢迎来到月来客栈,我是掌柜空字符。
本期推送内容目录如下,如果你觉得本期内容对你所有帮助欢迎点赞、关注支持掌柜!
6.2 基于近邻算法的垃圾邮件分类 6.2.1 复用模型 6.2.2 载入原始文本 6.2.3 制作数据集 6.2.4 训练模型与测试 6.2.5 小结
6.2 基于近邻算法的垃圾邮件分类
在第6.1节内容中,笔者介绍了2中简单的词袋模型,接下来我们以第2种词袋模型表示方法为例,通过近邻算法对垃圾邮件进行分类处理。下面用到的是一个中文的邮件分类数据集,包含垃圾邮件和非垃圾邮件两类,即一个二分类任务。其中ham_5000.utf8
和spam_5000.utf8
这两个文件中分别包含5000封正常邮件和垃圾邮件,文件中每行分别表示一封邮件,示例如下:
“我的意中人是一个盖世英雄,有一天他会踩着七色的云彩来娶我,我猜中了前头,可是我猜不着这结局”世间一切美好都有有效期限吧,坦然面对,接受幸福的彩排。
总地来讲,要完成这一文本分类任务,首先需要载入原始文本并对其中的每个样本进行分词处理,接着通过上面介绍的CountVectorizer
类来完成文本的向量化表示,并制作完成每个样本对应的类别以便构成一个完整的数据集,最后根据近邻算法完成分类任务。不过在正式介绍文本分类任务之前,笔者先来介绍如何对训练好的模型进行复用。
6.2.1 复用模型
在实际的运用环境中,不可能每次在对新数据进行预测时都从头开始训练一个模型。通常,模型在第1次训练完成后会被保存下来。只要后续不需要再对模型做任何改动,在对新数据进行预测时,只需载入已有的模型进行复用[1]。完整代码见Book/Chapter06/C05_bag_of_word_cla.py
文件。
1. 保存模型
首先需要定义一个函数来对传入的模型进行保存,代码如下:
1 import joblib
2 def save_model(model, dir='MODEL', MODEL_NAME='model.pkl'):
3 if not os.path.exists(dir):
4 os.mkdir(dir)
5 path = os.path.join(dir, MODEL_NAME)
6 joblib.dump(model,path )
在上述代码中,第3~4行用来判断当前是否存在MODEL
这个目录。如果不存在,则创建;第5行是根据目录名称和模型名称拼接模型的保存路径;第6行用来将传入的模型以MODEL_NAME
的名称保存到MODEL
目录中。
2. 复用模型
在复用模型之前,需要先定义一个函数来对已有的模型进行载入,代码如下:
1 def load_model(dir='MODEL', MODEL_NAME='model.pkl'):
2 path = os.path.join(dir, MODEL_NAME)
3 if not os.path.exists(path):
4 raise FileNotFoundError(f"{path} 模型不存在,请先训练模型!")
5 model = joblib.load(path)
6 return model
在上述代码中,第2~4行用来判断给定的路径中是否存在一个名为MODEL_NAME
的模型文件。如果不存在,则进行提示。第5~6行用来返回载入后的模型。
同时,值得注意的是,保存模型的时候不仅要对最后的分类或者回归模型进行保存,还要对最开始的数据集预处理模型进行保存。因为新输入的数据一般是原始数据,需要对其进行相应的标准化(这里是向量化)处理,因此也就必然会通过在训练集上得到的参数(如此处的词表)来对新数据进行标准化,所以也需要对标准化时的模型进行保存。
下面,笔者将开始完整介绍如何在训练集上训练模型以及在测试数据上复用模型。
6.2.2 载入原始文本
首先需要完成的便是编写一个函数载入本地文本以及构造每个样本对应的标签。同时,为了方便这部分代码在后续其它地方复用,笔者将其放到了utils
下的dataset
模块中(见Book/utils/dataset.py
文件),实现代码如下:
1 DATA_HOME = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'data')
2 def load_spam():
3 data_spam_dir = os.path.join(DATA_HOME, 'spam')
4 def load_spam_data(file_path=None):
5 texts = []
6 with open(file_path, encoding='utf-8') as f:
7 for line in f:
8 line = line.strip('\n')
9 texts.append(clean_str(line))
10 return texts
11 x_pos = load_spam_data(file_path=os.path.join(data_spam_dir, 'ham_5000.utf8'))
12 x_neg = load_spam_data(file_path=os.path.join(data_spam_dir, 'spam_5000.utf8'))
13 y_pos, y_neg = [1] * len(x_pos), [0] * len(x_neg)
14 x, y = x_pos + x_neg, y_pos + y_neg
15 return x, y
在上述代码中,第1行用于获取当前工程目录下data
目录所在的绝对路径(今后本书中所使用到的数据集均会放到该目录下),其中__file__
为Python中的环境变量用于得到当前文件所在的绝对路径,而os.path.dirname
则是根据当前路径取对应的目录。例如在这里DATA_HOME
的结果将形如:
1 D:/wangcheng/gitR/MachineLearningWithMe/Book/data
第2行是定义load_spam()
函数来载入原始的垃圾邮件数据;第3行是拼接得到垃圾邮件数据集所在的目录,形如
1 D:/wangcheng/gitR/MachineLearningWithMe/Book/data/spam
第4~10行是定义一个辅助函数load_spam_data()
来按行读取本地文件中的所有文本,并去掉每行末尾的换行符,其中函数clean_str()
的作用是去掉一个字符串中的所有非中文字符,最后返回处理好的结果。第11~12行则是分别载入垃圾邮件和非垃圾邮件;第13行则是分别构造正负样本对应的样本标签;第14~15是将最后处理完成的结果进行返回。最终,x
为一个列表,每个元素为一个样本(一条文本记录);y
也为一个列表,每个元素为样本对应的标签。
在完成原始数据载入后,需要进一步对每个样本进行分词处理,以便后续通过词袋模型进行向量化处理。因此还需要定义一个辅助函数来完成这部分功能(同样保存于dataset
模块中),实现代码如下: