查看原文
其他

第7.6节 CharRNN模型

空字符 月来客栈 2024-01-21

各位朋友大家好,欢迎来到月来客栈,我是掌柜空字符。

本期推送内容目录如下,如果本期内容对你有所帮助,欢迎点赞、转发支持掌柜!

  • 7.6 CharRNN模型
    • 7.6.1 任务构造原理
    • 7.6.2 数据预处理
    • 7.6.3 古诗生成任务
    • 7.6.4 小结
  • 引用

7.6 CharRNN模型

经过前面几节内容的介绍,我们已经清楚了RNN模型及其变体的相关原理,并且在第7.2节内容中笔者也通过两个实例详细介绍了RNN中多对一任务的构建流程。在本节内容中,笔者将会以古诗词生成为例来介绍了RNN中的多对多任务类型,即图7-3中的第3种情况。

7.6.1 任务构造原理

对于接下来要介绍的古诗生成模型其本质上就是一个简单的RNN模型,也被称为字符级循环神经网络CharRNN[1]。CharRNN通过将序列作为模型输入,将作为标签来训练模型,整个网络结构如图7-12所示。

图 7-12 古诗生成模型网络结构图

如图7-12所示,最下面为原始输入(Src Input),在转换为词表中的索引后便输入到词嵌入层(Embedding Layer)中。简单来讲,词嵌入层是一个包含有列的网络层,其中表示词表中词的数量,表示向量的维度,即词嵌入层的作用是将词表中的每个词通过一个维向量来进行表示。更多关于词嵌入层的内容将在第9.5节中进行介绍。

在经过词嵌入层的处理之后再将该结果输入到循环神经网络中;然后再将循环神经网络输出结果中的每个时刻进行分类处理,其分类类别数便是词表的长度,因为这里的预测结果是词表中的其中一个词;最后将模型的预测结果同正确标签进行损失计算并完成整个模型的训练。

当模型训练完成之后,可以通过给模型输入一个序列片段来循环完成固定长度序列的生成任务,整个预测过程原理如图7-13所示。

图 7-13 古诗生成预测网络结构图

如图7-13所示便是整个模型的预测过程。在图7-13的最左侧为预测的第1个时刻,输入为“墙”预测结果为“角”;中间为第2个时刻,输入为“墙”以及第1个时刻的预测结果“角”组成的序列,预测结果只取最后一个时刻的“数”;最右侧为第3个时刻,输入为“墙”以及第1个和第2个时刻各自的预测结果组成的序列,预测结果同样只取最后一个时刻“枝”。最后继续迭代循环,直到预测生成的整个序列长度满足预设条件为止。

7.6.2 数据预处理

1. 数据集介绍

在清楚整个模型的训练和预测过程后,我们再来如何从零构建模型训练所需要的数据集。这里我们所使用到的是一个全唐诗[2]的数据集,一共有58个json文件共计大约5.8万余首古诗。在每个json文件中,文本内容的存储形式如下所示。

1 [{"author""王安石",
2   "paragraphs": ["墙角数枝梅,凌寒独自开。","遥知不是雪,为有暗香来。"],
3   "title""梅花",
4   "id""ae7391fc-aef5-4f59-ae25-a7e7a9ee0858"},
5  {"author""佚名",
6   "paragraphs": ["自伯东去,首如飞蓬。","岂无膏沐,谁适为容。"],
7   "title""诗经·国风·卫风",
8   "id""0f0b345d-c074-4ec7-bde1-e28438712b7b"}]

从上述结果可以看出,整个json文件的最外层是一个列表,列表中的每个元素便是一个包含有一首古诗的字段,后续我们将只取每首古诗中的paragraphs来构建数据集。

2. 预处理流程

在正式介绍如何构建数据集之前我们先通过一张图了解一下整体的构建流程。假如现在有两个样本构成了一个小批量,那么其整个数据的处理流程如图7-14所示。

图 7-14 数据集构建流程图

注:图7-14中的词表是以整个训练集为语料构建而成,并非只由上述两个样本构建。

如图7-14所示,首先我们需要将原始json格式的语料抽取出出来;然后再以此为基础对句子进行分词(字)并构建词表;接着再将样本句子中的每个词转换为词表中对应的索引序号得到原始输入Src,并同时将原始输入向左平移一位得到真实标签Tgt;最后在输入模型之前再 对其进行填充处理以使得每个小批量中所有样本的长度一致。

3. 格式化样本和 Tokenize

首先,我们定义一个类TangShi并继承自在第7.2.4节中介绍的TouTiaoNews类以复用其中的部分方法,同时初始化原始数据的相关存储路径,示例代码如下所示:

1 class TangShi(TouTiaoNews):
2     DATA_DIR = os.path.join(DATA_HOME, 'peotry_tang')
3     FILE_PATH = [os.path.join(DATA_DIR, 'poet.tang.0-55.json'),  
4                  os.path.join(DATA_DIR, 'poet.tang.56-56.json'), 
5                  os.path.join(DATA_DIR, 'poet.tang.57-57.json')] 
6     def __init__(self, *args, **kwargs):
7         super(TangShi, self).__init__(*args, **kwargs)
8         self.ends = [self.vocab.stoi["。"], self.vocab.stoi["?"]]

在上述代码中,第2行用来指定原始数据存储路径;第3~5行用来指定文件名并同时划分了训练集(poet.tang.0.json~poet.tang.55000.json)、验证集(poet.tang.56000.json~poet.tang.56000.json)和测试集(poet.tang.57000.json~poet.tang.57000.json),后续将解析其中的序号来读取相应的原始文件;第8行用来指定可能的结束符,用于生成序列时的停止条件之一。

进一步,定义load_raw_data方法来完成原始所有数据的载入,示例代码如下所示:

 1     def load_raw_data(self, file_path=None):
 2 
 3         def read_json_data(file_path):
 4             samples, labels = [], []
 5             with open(file_path, encoding='utf-8'as f:
 6                 data = json.loads(f.read())
 7                 for item in data:
 8                     content = "".join(item['paragraphs'])
 9                     if not skip(content):
10                         samples.append(content) 
11                         labels.append(content[1:] + content[-1])  # 向左平移 
12             return samples, labels
13 
14         file_name = file_path.split(os.path.sep)[-1]
15         start, end = file_name.split('.')[2].split('-')
16         all_samples, all_labels = [], []
17         for i in range(int(start), int(end) + 1):
18             file_path = os.path.join(self.DATA_DIR, f'poet.tang.{i * 1000}.json')
19             samples, labels = read_json_data(file_path)  
20             all_samples += samples  
21             all_labels += labels
22         return all_samples, all_labels

在上述代码中,第3~12行是定义一个辅助函数来读取单个的原始json文件,其中第9行为根据相应条件来判断是否将部分内容过滤,第10~11行是构造对应的输入和标签;第14~16行为根据传入的参数提取文件对应序号;第17~21行为根据拼接的文件名循环读取原始json文件;第22行为返回所有格式化后的结果。

在完成上述load_raw_data方法的实现之后,在实例化类TangShi时便可同时根据训练集完成词表的构建,详见第7.2.4节中类TouTiaoNews的初始化方法。

为你认可的知识付费,欢迎订阅本专栏阅读更多优质内容!

4. 转换为索引

在完成词表构建之后,下一步则是需要将原始古诗进行分词处理,并将其转换为词表中对应的索引,示例代码如下所示:

继续滑动看下一个

第7.6节 CharRNN模型

空字符 月来客栈
向上滑动看下一个

您可能也对以下帖子感兴趣

文章有问题?点此查看未经处理的缓存