查看原文
其他

在 HIGGS 数据集中对希格斯玻色子过程进行分类

Google TensorFlow 2019-02-15

HIGGS Data Set HIGGS 数据集包含有 1100 万个样本,具有 28 个特征,用于分类问题,来区分产生希格斯玻色子的信号过程和不产生希格斯玻色子的后台过程。


我们使用 Gradient Boosted Trees 算法来区分这两个类别。


代码示例使用高级别的 tf.estimator.Estimator 和 tf.data.Dataset。这些 API 非常适合快速迭代,无需进行重大代码检修就能快速使模型适应您自己的数据集。它允许您从单一工作者训练转到分布式训练,并且可以轻松导出模型二进制文件来进行预测。这里,为了进一步简化和更迅速地执行,我们使用实用函数 tf.contrib.estimator.boosted_trees_classifier_train_in_memory。当输入作为内存数据集(如 numpy 数组)提供时,此实用程序功能尤其出彩。


Estimator 的输入函数通常使用 tf.data.Dataset API,它可以处理各种数据控制,诸如流,批处理,转换和重排。但是,boosted_trees_classifier_train_in_memory() 实用程序函数要求将整个数据作为单个批处理提供(即不使用 batch() API)。因此,在本实例中,只使用 Dataset.from_tensors() 将 numpy 数组转换为结构化张量,Dataset.zip() 用于将特征和标签结合在一起。有关数据集的更多参考内容,请在 

https://www.tensorflow.org/guide/datasets 阅读更多内容。



运行代码

首先,确保已将 models 文件夹添加到 Python 路径中 added the models folder to your Python path; 否则,您可能会遭遇到 ImportError:没有名为 official.boosted_trees 的模块,类似这样的错误。


设置

HIGGS Data Set HIGGS 数据集用于训练的示例由 UC Irvine 机器学习库 UC Irvine Machine Learning Repository 托管。我们已经提供了一个下载和清理必要文件的脚本。

python data_download.py


打开此链接 https://archive.ics.uci.edu/ml/datasets/HIGGS 下载文件并将处理过的文件存储在 --data_dir 指定的目录下(默认为 / tmp / higgs_data /)。要更改目标目录,请设置 --data_dir 标志。该目录可以是 TensorFlow 支持的网络存储(如 Google Cloud Storage,gs:// <bucket> / <path> /)。下载到本地临时文件夹的文件大约为 2.8 GB,处理过的文件大约是 0.8 GB,因此必须预留足够的存储空间。


训练

本示例在训练期间会使用大约 3 GB 的 RAM。 您可以在本地运行代码,如下所示:

python train_higgs.py


该模型默认保存为 / tmp / higgs_model,可以使用 --model_dir 标志更改。请注意,每次训练开始之前 model_dir 都会被清理。


模型参数可以通过标志来调整,例如 --n_trees, -  max_depth,--learning_rate 等。查看代码了解详细信息。


当使用默认参数训练时,最终的精度将在 74% 左右,并且在 eval 集上的损失大约为 0.516。


默认情况下,1100 万个样本中前 100 万个会被用于训练,最后 100 万个会用于评估。可以通过标志 --train_start, -  train_count, -  eval_start, -  eval_count 等选择训练 / 评估数据作为索引范围。


TensorBoard

运行 TensorBoard,检查有关图表和训练进度的详细信息。

tensorboard --logdir=/tmp/higgs_model  # set logdir as --model_dir set during training.


使用 SavedModel 进行推理

您可以使用参数 --export_dir 将模型导出为 TensorFlow SavedModel 格式:

python train_higgs.py --export_dir /tmp/higgs_boosted_trees_saved_model


模型完成训练后,使用 saved_model_cli 检查并执行 SavedModel。


请尝试以下命令来检查 SavedModel:

将 $ {TIMESTAMP} 替换为生成的文件夹(例如1524249124)

# List possible tag_sets. Only one metagraph is saved, so there will be one option.
saved_model_cli show --dir /tmp/higgs_boosted_trees_saved_model/${TIMESTAMP}/

# Show SignatureDefs for tag_set=serve. SignatureDefs define the outputs to show.
saved_model_cli show --dir /tmp/higgs_boosted_trees_saved_model/${TIMESTAMP}/ \
   --tag_set serve --all


推理

让我们用这个模型来预测两个示例的收入组。请注意,此模型使用自定义解析模块导出 SavedModel,该模块接受 csv 行作为特征。(每行是一个包含 28 列的示例;这与训练数据不同,请注意不要添加标签列。)

saved_model_cli run --dir /tmp/boosted_trees_higgs_saved_model/${TIMESTAMP}/ \
   --tag_set serve --signature_def="predict" \
   --input_exprs='inputs=["0.869293,-0.635082,0.225690,0.327470,-0.689993,0.754202,-0.248573,-1.092064,0.0,1.374992,-0.653674,0.930349,1.107436,1.138904,-1.578198,-1.046985,0.0,0.657930,-0.010455,-0.045767,3.101961,1.353760,0.979563,0.978076,0.920005,0.721657,0.988751,0.876678", "1.595839,-0.607811,0.007075,1.818450,-0.111906,0.847550,-0.566437,1.581239,2.173076,0.755421,0.643110,1.426367,0.0,0.921661,-1.190432,-1.615589,0.0,0.651114,-0.654227,-1.274345,3.101961,0.823761,0.938191,0.971758,0.789176,0.430553,0.961357,0.957818"]'


这就会打印出预测的类别和类别概率。类似如下所示:

Result for output key class_ids:
[[1]
[0]]
Result for output key classes:
[['1']
['0']]
Result for output key logistic:
[[0.6440273 ]
[0.10902369]]
Result for output key logits:
[[ 0.59288704]
[-2.1007526 ]]
Result for output key probabilities:
[[0.3559727 0.6440273]
[0.8909763 0.1090237]]


请注意,“预测” signature_def 给出的结果和 “classification分类” 或 “serving_default” 不同(更详细)。



更多 AI 相关阅读:



    您可能也对以下帖子感兴趣

    文章有问题?点此查看未经处理的缓存