利用人口普查收入数据集预测收入
大家好,今天的文章主要介绍如何利用人口普查收入数据集(Census Income Data Set)预测收入。
人口普查收入数据集(Census Income Data Set)包含 48,000 多个样本,其属性包含年龄,职业,教育和收入(二元标签,> 50K 或 <= 50K)。该数据集分为大约 32,000 个训练样本和 16,000 个测试样本。
这里,我们使用 wide and deep model 来预测收入标签。Wide 模型能够记忆具有大量特征的数据之间的交互,但无法在新数据上概括这些学习的交互。Deep 模型可以很好地扩展,但无法在数据中对一些异常情况进行学习。Wide and deep 模型将两种模型结合起来,能够在对异常情况进行学习的时候进行概括。
本示例代码的目的是允许模型在合理的时间内对人口普查收入数据集(Census Income Data Set)进行训练。您会注意到 Deep 模型的性能几乎与此数据集上的 wide and deep 模型一样好。对于具有高基数特征的大型数据集,其中每个特征具有数以百万计甚至数十亿的唯一可能值(这是 Wide 模型的特征),wide and deep 模型确实是一个福音。
最后还有一个关键点。作为建模者和开发人员,请考虑如何使用此数据集以及模型预测可能带来的潜在益处和危害。像这样的一个模型可能会加剧社会上的偏见和差异。 这是一个与您想要解决的问题息息相关的功能,还是会带入偏见? 有关更多信息,请阅读 ML fairness(https://developers.google.com/machine-learning/fairness-overview/)。
本文中的代码示例使用高级别 tf.estimator.Estimator API。此 API 非常适合快速迭代,无需进行重大代码检修,就能快速使模型适用于您自己的数据集。它允许您从单一工作者训练状态转移到分布式训练,并且可以轻松导出模型二进制文件来进行预测。
Estimator 的输入函数使用 tf.contrib.data.TextLineDataset,它会创建一个 Dataset 对象。Dataset API 可以轻松地将转换(map,batch,shuffle 等)应用于数据。请阅读这里以获取更多信息 Read more here(https://www.tensorflow.org/guide/datasets)。
Estimator 和 Dataset API 受到高度鼓励,可用于快速开发和高效训练。
运行代码
首先请确保已将 models 文件夹添加到 Python 路径中 added the models folder to your Python path; 否则您可能会遇到类似 ImportError:No module named official.wide_deep 之类的错误。
设置
此 Census Income Data Set 人口普查收入数据集训练样本由 UC Irvine Machine Learning Repository 托管。我们已提供了一个下载和清理必要文件的脚本。
python census_dataset.py
这会将文件下载到 / tmp / census_data。想要更改目录,请设置 --data_dir 标志。
训练
您可以在本地运行代码,如下所示:
python census_main.py
默认情况下,模型会保存到 / tmp / census_model,可以使用 --model_dir 标志进行更改。
要运行 wide 或者 deep-only 模型,请将 --model_type 标志设置为 wide 或 deep。 其他标志也是可配置的;请看 census_main.py 了解详情。
对于三种模型中的任意一种模型,其最终精度应超过 83%。
您还可以尝试使用 -inter 和 -intra 标志来探索帧间 / 帧内并行性,以获得更好的性能,如下所示:
python census_main.py --inter=<int> --intra=<int>
TensorBoard
运行 TensorBoard 以检查有关图表和训练进度的详细信息。
tensorboard --logdir=/tmp/census_model
使用 SavedModel 进行推理
您可以使用参数 --export_dir 将模型导出为 TensorFlow SavedModel 格式:
python census_main.py --export_dir /tmp/wide_deep_saved_model
模型完成训练后,使用 saved_model_cli 检查并执行 SavedModel。
请尝试以下命令来检查 SavedModel:
将 $ {TIMESTAMP} 替换为生成的文件夹(例如 1524249124)
# List possible tag_sets. Only one metagraph is saved, so there will be one option.
saved_model_cli show --dir /tmp/wide_deep_saved_model/${TIMESTAMP}/
# Show SignatureDefs for tag_set=serve. SignatureDefs define the outputs to show.
saved_model_cli show --dir /tmp/wide_deep_saved_model/${TIMESTAMP}/ \
--tag_set serve --all
推理
让我们用这个模型来预测两个示例的收入组:
saved_model_cli run --dir /tmp/wide_deep_saved_model/${TIMESTAMP}/ \
--tag_set serve --signature_def="predict" \
--input_examples='examples=[{"age":[46.], "education_num":[10.], "capital_gain":[7688.], "capital_loss":[0.], "hours_per_week":[38.]}, {"age":[24.], "education_num":[13.], "capital_gain":[0.], "capital_loss":[0.], "hours_per_week":[50.]}]'
这就会打印出预测的类和类的概率。0 级别是 <= 50k 组,1 级别是 > 50k 组。
更多 AI 相关阅读: