查看原文
其他

教程 | 如何使用TensorFlow中的高级API:Estimator、Experiment和Dataset

2017-09-09 机器之心

选自Medium

作者:Peter Roelants

机器之心编译

参与:李泽南、黄小天


近日,背景调查公司 Onfido 研究主管 Peter Roelants 在 Medium 上发表了一篇题为《Higher-Level APIs in TensorFlow》的文章,通过实例详细介绍了如何使用 TensorFlow 中的高级 API(Estimator、Experiment 和 Dataset)训练模型。值得一提的是 Experiment 和 Dataset 可以独立使用。这些高级 API 已被最新发布的 TensorFlow1.3 版收录。


TensorFlow 中有许多流行的库,如 Keras、TFLearn 和 Sonnet,它们可以让你轻松训练模型,而无需接触哪些低级别函数。目前,Keras API 正倾向于直接在 TensorFlow 中实现,TensorFlow 也在提供越来越多的高级构造,其中的一些已经被最新发布的 TensorFlow1.3 版收录。


在本文中,我们将通过一个例子来学习如何使用一些高级构造,其中包括 Estimator、Experiment 和 Dataset。阅读本文需要预先了解有关 TensorFlow 的基本知识。



Experiment、Estimator 和 DataSet 框架和它们的相互作用(以下将对这些组件进行说明)


在本文中,我们使用 MNIST 作为数据集。它是一个易于使用的数据集,可以通过 TensorFlow 访问。你可以在这个 gist 中找到完整的示例代码。使用这些框架的一个好处是我们不需要直接处理图形和会话。


Estimator


Estimator(评估器)类代表一个模型,以及这些模型被训练和评估的方式。我们可以这样构建一个评估器:


  1. return tf.estimator.Estimator(

  2.    model_fn=model_fn,  # First-class function

  3.    params=params,  # HParams

  4.    config=run_config  # RunConfig

  5. )


为了构建一个 Estimator,我们需要传递一个模型函数,一个参数集合以及一些配置。


  • 参数应该是模型超参数的集合,它可以是一个字典,但我们将在本示例中将其表示为 HParams 对象,用作 namedtuple。

  • 该配置指定如何运行训练和评估,以及如何存出结果。这些配置通过 RunConfig 对象表示,该对象传达 Estimator 需要了解的关于运行模型的环境的所有内容。

  • 模型函数是一个 Python 函数,它构建了给定输入的模型(见后文)。


模型函数


模型函数是一个 Python 函数,它作为第一级函数传递给 Estimator。稍后我们就会看到,TensorFlow 也会在其他地方使用第一级函数。模型表示为函数的好处在于模型可以通过实例化函数不断重新构建。该模型可以在训练过程中被不同的输入不断创建,例如:在训练期间运行验证测试。


模型函数将输入特征作为参数,相应标签作为张量。它还有一种模式来标记模型是否正在训练、评估或执行推理。模型函数的最后一个参数是超参数的集合,它们与传递给 Estimator 的内容相同。模型函数需要返回一个 EstimatorSpec 对象——它会定义完整的模型。


EstimatorSpec 接受预测,损失,训练和评估几种操作,因此它定义了用于训练,评估和推理的完整模型图。由于 EstimatorSpec 采用常规 TensorFlow Operations,因此我们可以使用像 TF-Slim 这样的框架来定义自己的模型。


Experiment


Experiment(实验)类是定义如何训练模型,并将其与 Estimator 进行集成的方式。我们可以这样创建一个实验类:


  1. experiment = tf.contrib.learn.Experiment(

  2.    estimator=estimator,  # Estimator

  3.    train_input_fn=train_input_fn,  # First-class function

  4.    eval_input_fn=eval_input_fn,  # First-class function

  5.    train_steps=params.train_steps,  # Minibatch steps

  6.    min_eval_frequency=params.min_eval_frequency,  # Eval frequency

  7.    train_monitors=[train_input_hook],  # Hooks for training

  8.    eval_hooks=[eval_input_hook],  # Hooks for evaluation

  9.    eval_steps=None  # Use evaluation feeder until its empty

  10. )


Experiment 作为输入:


  • 一个 Estimator(例如上面定义的那个)。

  • 训练和评估数据作为第一级函数。这里用到了和前述模型函数相同的概念,通过传递函数而非操作,如有需要,输入图可以被重建。我们会在后面继续讨论这个概念。

  • 训练和评估钩子(hooks)。这些钩子可以用于监视或保存特定内容,或在图形和会话中进行一些操作。例如,我们将通过操作来帮助初始化数据加载器。

  • 不同参数解释了训练时间和评估时间。


一旦我们定义了 experiment,我们就可以通过 learn_runner.run 运行它来训练和评估模型:


  1. learn_runner.run(

  2.    experiment_fn=experiment_fn,  # First-class function

  3.    run_config=run_config,  # RunConfig

  4.    schedule="train_and_evaluate",  # What to run

  5.    hparams=params  # HParams

  6. )


与模型函数和数据函数一样,函数中的学习运算符将创建 experiment 作为参数。


Dataset


我们将使用 Dataset 类和相应的 Iterator 来表示我们的训练和评估数据,并创建在训练期间迭代数据的数据馈送器。在本示例中,我们将使用 TensorFlow 中可用的 MNIST 数据,并在其周围构建一个 Dataset 包装器。例如,我们把训练的输入数据表示为:


  1. # Define the training inputs

  2. def get_train_inputs(batch_size, mnist_data):

  3.    """Return the input function to get the training data.

  4.    Args:

  5.        batch_size (int): Batch size of training iterator that is returned

  6.                          by the input function.

  7.        mnist_data (Object): Object holding the loaded mnist data.

  8.    Returns:

  9.        (Input function, IteratorInitializerHook):

  10.            - Function that returns (features, labels) when called.

  11.            - Hook to initialise input iterator.

  12.    """

  13.    iterator_initializer_hook = IteratorInitializerHook()

  14.    def train_inputs():

  15.        """Returns training set as Operations.

  16.        Returns:

  17.            (features, labels) Operations that iterate over the dataset

  18.            on every evaluation

  19.        """

  20.        with tf.name_scope('Training_data'):

  21.            # Get Mnist data

  22.            images = mnist_data.train.images.reshape([-1, 28, 28, 1])

  23.            labels = mnist_data.train.labels

  24.            # Define placeholders

  25.            images_placeholder = tf.placeholder(

  26.                images.dtype, images.shape)

  27.            labels_placeholder = tf.placeholder(

  28.                labels.dtype, labels.shape)

  29.            # Build dataset iterator

  30.            dataset = tf.contrib.data.Dataset.from_tensor_slices(

  31.                (images_placeholder, labels_placeholder))

  32.            dataset = dataset.repeat(None)  # Infinite iterations

  33.            dataset = dataset.shuffle(buffer_size=10000)

  34.            dataset = dataset.batch(batch_size)

  35.            iterator = dataset.make_initializable_iterator()

  36.            next_example, next_label = iterator.get_next()

  37.            # Set runhook to initialize iterator

  38.            iterator_initializer_hook.iterator_initializer_func = \

  39.                lambda sess: sess.run(

  40.                    iterator.initializer,

  41.                    feed_dict={images_placeholder: images,

  42.                               labels_placeholder: labels})

  43.            # Return batched (features, labels)

  44.            return next_example, next_label

  45.    # Return function and hook

  46.    return train_inputs, iterator_initializer_hook


调用这个 get_train_inputs 会返回一个一级函数,它在 TensorFlow 图中创建数据加载操作,以及一个 Hook 初始化迭代器。


本示例中,我们使用的 MNIST 数据最初表示为 Numpy 数组。我们创建一个占位符张量来获取数据,再使用占位符来避免数据被复制。接下来,我们在 from_tensor_slices 的帮助下创建一个切片数据集。我们将确保该数据集运行无限长时间(experiment 可以考虑 epoch 的数量),让数据得到清晰,并分成所需的尺寸。


为了迭代数据,我们需要在数据集的基础上创建迭代器。因为我们正在使用占位符,所以我们需要在 NumPy 数据的相关会话中初始化占位符。我们可以通过创建一个可初始化的迭代器来实现。创建图形时,我们将创建一个自定义的 IteratorInitializerHook 对象来初始化迭代器:


  1. class IteratorInitializerHook(tf.train.SessionRunHook):

  2.    """Hook to initialise data iterator after Session is created."""

  3.    def __init__(self):

  4.        super(IteratorInitializerHook, self).__init__()

  5.        self.iterator_initializer_func = None

  6.    def after_create_session(self, session, coord):

  7.        """Initialise the iterator after the session has been created."""

  8.        self.iterator_initializer_func(session)


IteratorInitializerHook 继承自 SessionRunHook。一旦创建了相关会话,这个钩子就会调用 call after_create_session,并用正确的数据初始化占位符。这个钩子会通过 get_train_inputs 函数返回,并在创建时传递给 Experiment 对象。


train_inputs 函数返回的数据加载操作是 TensorFlow 操作,每次评估时都会返回一个新的批处理。


运行代码


现在我们已经定义了所有的东西,我们可以用以下命令运行代码:


  1. python mnist_estimator.py --model_dir ./mnist_training --data_dir ./mnist_data


如果你不传递参数,它将使用文件顶部的默认标志来确定保存数据和模型的位置。训练将在终端输出全局步长、损失、精度等信息。除此之外,实验和估算器框架将记录 TensorBoard 可以显示的某些统计信息。如果我们运行:


  1. tensorboard --logdir='./mnist_training'


我们就可以看到所有训练统计数据,如训练损失、评估准确性、每步时间和模型图。



评估精度在 TensorBoard 中的可视化


在 TensorFlow 中,有关 Estimator、Experiment 和 Dataset 框架的示例很少,这也是本文存在的原因。希望这篇文章可以向大家介绍这些架构工作的原理,它们应该采用哪些抽象方法,以及如何使用它们。如果你对它们很感兴趣,以下是其他相关文档。


关于 Estimator、Experiment 和 Dataset 的注释


  • 论文《TensorFlow Estimators: Managing Simplicity vs. Flexibility in High-Level Machine Learning Frameworks》:https://terrytangyuan.github.io/data/papers/tf-estimators-kdd-paper.pdf

  • Using the Dataset API for TensorFlow Input Pipelines:https://www.tensorflow.org/versions/r1.3/programmers_guide/datasets

  • tf.estimator.Estimator:https://www.tensorflow.org/api_docs/python/tf/estimator/Estimator

  • tf.contrib.learn.RunConfig:https://www.tensorflow.org/api_docs/python/tf/contrib/learn/RunConfig

  • tf.estimator.DNNClassifier:https://www.tensorflow.org/api_docs/python/tf/estimator/DNNClassifier

  • tf.estimator.DNNRegressor:https://www.tensorflow.org/api_docs/python/tf/estimator/DNNRegressor

  • Creating Estimators in tf.estimator:https://www.tensorflow.org/extend/estimators

  • tf.contrib.learn.Head:https://www.tensorflow.org/api_docs/python/tf/contrib/learn/Head

  • 本文用到的 Slim 框架:https://github.com/tensorflow/models/tree/master/slim


完整示例


  1. """Script to illustrate usage of tf.estimator.Estimator in TF v1.3"""

  2. import tensorflow as tf

  3. from tensorflow.examples.tutorials.mnist import input_data as mnist_data

  4. from tensorflow.contrib import slim

  5. from tensorflow.contrib.learn import ModeKeys

  6. from tensorflow.contrib.learn import learn_runner

  7. # Show debugging output

  8. tf.logging.set_verbosity(tf.logging.DEBUG)

  9. # Set default flags for the output directories

  10. FLAGS = tf.app.flags.FLAGS

  11. tf.app.flags.DEFINE_string(

  12.    flag_name='model_dir', default_value='./mnist_training',

  13.    docstring='Output directory for model and training stats.')

  14. tf.app.flags.DEFINE_string(

  15.    flag_name='data_dir', default_value='./mnist_data',

  16.    docstring='Directory to download the data to.')

  17. # Define and run experiment ###############################

  18. def run_experiment(argv=None):

  19.    """Run the training experiment."""

  20.    # Define model parameters

  21.    params = tf.contrib.training.HParams(

  22.        learning_rate=0.002,

  23.        n_classes=10,

  24.        train_steps=5000,

  25.        min_eval_frequency=100

  26.    )

  27.    # Set the run_config and the directory to save the model and stats

  28.    run_config = tf.contrib.learn.RunConfig()

  29.    run_config = run_config.replace(model_dir=FLAGS.model_dir)

  30.    learn_runner.run(

  31.        experiment_fn=experiment_fn,  # First-class function

  32.        run_config=run_config,  # RunConfig

  33.        schedule="train_and_evaluate",  # What to run

  34.        hparams=params  # HParams

  35.    )

  36. def experiment_fn(run_config, params):

  37.    """Create an experiment to train and evaluate the model.

  38.    Args:

  39.        run_config (RunConfig): Configuration for Estimator run.

  40.        params (HParam): Hyperparameters

  41.    Returns:

  42.        (Experiment) Experiment for training the mnist model.

  43.    """

  44.    # You can change a subset of the run_config properties as

  45.    run_config = run_config.replace(

  46.        save_checkpoints_steps=params.min_eval_frequency)

  47.    # Define the mnist classifier

  48.    estimator = get_estimator(run_config, params)

  49.    # Setup data loaders

  50.    mnist = mnist_data.read_data_sets(FLAGS.data_dir, one_hot=False)

  51.    train_input_fn, train_input_hook = get_train_inputs(

  52.        batch_size=128, mnist_data=mnist)

  53.    eval_input_fn, eval_input_hook = get_test_inputs(

  54.        batch_size=128, mnist_data=mnist)

  55.    # Define the experiment

  56.    experiment = tf.contrib.learn.Experiment(

  57.        estimator=estimator,  # Estimator

  58.        train_input_fn=train_input_fn,  # First-class function

  59.        eval_input_fn=eval_input_fn,  # First-class function

  60.        train_steps=params.train_steps,  # Minibatch steps

  61.        min_eval_frequency=params.min_eval_frequency,  # Eval frequency

  62.        train_monitors=[train_input_hook],  # Hooks for training

  63.        eval_hooks=[eval_input_hook],  # Hooks for evaluation

  64.        eval_steps=None  # Use evaluation feeder until its empty

  65.    )

  66.    return experiment

  67. # Define model ############################################

  68. def get_estimator(run_config, params):

  69.    """Return the model as a Tensorflow Estimator object.

  70.    Args:

  71.         run_config (RunConfig): Configuration for Estimator run.

  72.         params (HParams): hyperparameters.

  73.    """

  74.    return tf.estimator.Estimator(

  75.        model_fn=model_fn,  # First-class function

  76.        params=params,  # HParams

  77.        config=run_config  # RunConfig

  78.    )

  79. def model_fn(features, labels, mode, params):

  80.    """Model function used in the estimator.

  81.    Args:

  82.        features (Tensor): Input features to the model.

  83.        labels (Tensor): Labels tensor for training and evaluation.

  84.        mode (ModeKeys): Specifies if training, evaluation or prediction.

  85.        params (HParams): hyperparameters.

  86.    Returns:

  87.        (EstimatorSpec): Model to be run by Estimator.

  88.    """

  89.    is_training = mode == ModeKeys.TRAIN

  90.    # Define model's architecture

  91.    logits = architecture(features, is_training=is_training)

  92.    predictions = tf.argmax(logits, axis=-1)

  93.    # Loss, training and eval operations are not needed during inference.

  94.    loss = None

  95.    train_op = None

  96.    eval_metric_ops = {}

  97.    if mode != ModeKeys.INFER:

  98.        loss = tf.losses.sparse_softmax_cross_entropy(

  99.            labels=tf.cast(labels, tf.int32),

  100.            logits=logits)

  101.        train_op = get_train_op_fn(loss, params)

  102.        eval_metric_ops = get_eval_metric_ops(labels, predictions)

  103.    return tf.estimator.EstimatorSpec(

  104.        mode=mode,

  105.        predictions=predictions,

  106.        loss=loss,

  107.        train_op=train_op,

  108.        eval_metric_ops=eval_metric_ops

  109.    )

  110. def get_train_op_fn(loss, params):

  111.    """Get the training Op.

  112.    Args:

  113.         loss (Tensor): Scalar Tensor that represents the loss function.

  114.         params (HParams): Hyperparameters (needs to have `learning_rate`)

  115.    Returns:

  116.        Training Op

  117.    """

  118.    return tf.contrib.layers.optimize_loss(

  119.        loss=loss,

  120.        global_step=tf.contrib.framework.get_global_step(),

  121.        optimizer=tf.train.AdamOptimizer,

  122.        learning_rate=params.learning_rate

  123.    )

  124. def get_eval_metric_ops(labels, predictions):

  125.    """Return a dict of the evaluation Ops.

  126.    Args:

  127.        labels (Tensor): Labels tensor for training and evaluation.

  128.        predictions (Tensor): Predictions Tensor.

  129.    Returns:

  130.        Dict of metric results keyed by name.

  131.    """

  132.    return {

  133.        'Accuracy': tf.metrics.accuracy(

  134.            labels=labels,

  135.            predictions=predictions,

  136.            name='accuracy')

  137.    }

  138. def architecture(inputs, is_training, scope='MnistConvNet'):

  139.    """Return the output operation following the network architecture.

  140.    Args:

  141.        inputs (Tensor): Input Tensor

  142.        is_training (bool): True iff in training mode

  143.        scope (str): Name of the scope of the architecture

  144.    Returns:

  145.         Logits output Op for the network.

  146.    """

  147.    with tf.variable_scope(scope):

  148.        with slim.arg_scope(

  149.                [slim.conv2d, slim.fully_connected],

  150.                weights_initializer=tf.contrib.layers.xavier_initializer()):

  151.            net = slim.conv2d(inputs, 20, [5, 5], padding='VALID',

  152.                              scope='conv1')

  153.            net = slim.max_pool2d(net, 2, stride=2, scope='pool2')

  154.            net = slim.conv2d(net, 40, [5, 5], padding='VALID',

  155.                              scope='conv3')

  156.            net = slim.max_pool2d(net, 2, stride=2, scope='pool4')

  157.            net = tf.reshape(net, [-1, 4 * 4 * 40])

  158.            net = slim.fully_connected(net, 256, scope='fn5')

  159.            net = slim.dropout(net, is_training=is_training,

  160.                               scope='dropout5')

  161.            net = slim.fully_connected(net, 256, scope='fn6')

  162.            net = slim.dropout(net, is_training=is_training,

  163.                               scope='dropout6')

  164.            net = slim.fully_connected(net, 10, scope='output',

  165.                                       activation_fn=None)

  166.        return net

  167. # Define data loaders #####################################

  168. class IteratorInitializerHook(tf.train.SessionRunHook):

  169.    """Hook to initialise data iterator after Session is created."""

  170.    def __init__(self):

  171.        super(IteratorInitializerHook, self).__init__()

  172.        self.iterator_initializer_func = None

  173.    def after_create_session(self, session, coord):

  174.        """Initialise the iterator after the session has been created."""

  175.        self.iterator_initializer_func(session)

  176. # Define the training inputs

  177. def get_train_inputs(batch_size, mnist_data):

  178.    """Return the input function to get the training data.

  179.    Args:

  180.        batch_size (int): Batch size of training iterator that is returned

  181.                          by the input function.

  182.        mnist_data (Object): Object holding the loaded mnist data.

  183.    Returns:

  184.        (Input function, IteratorInitializerHook):

  185.            - Function that returns (features, labels) when called.

  186.            - Hook to initialise input iterator.

  187.    """

  188.    iterator_initializer_hook = IteratorInitializerHook()

  189.    def train_inputs():

  190.        """Returns training set as Operations.

  191.        Returns:

  192.            (features, labels) Operations that iterate over the dataset

  193.            on every evaluation

  194.        """

  195.        with tf.name_scope('Training_data'):

  196.            # Get Mnist data

  197.            images = mnist_data.train.images.reshape([-1, 28, 28, 1])

  198.            labels = mnist_data.train.labels

  199.            # Define placeholders

  200.            images_placeholder = tf.placeholder(

  201.                images.dtype, images.shape)

  202.            labels_placeholder = tf.placeholder(

  203.                labels.dtype, labels.shape)

  204.            # Build dataset iterator

  205.            dataset = tf.contrib.data.Dataset.from_tensor_slices(

  206.                (images_placeholder, labels_placeholder))

  207.            dataset = dataset.repeat(None)  # Infinite iterations

  208.            dataset = dataset.shuffle(buffer_size=10000)

  209.            dataset = dataset.batch(batch_size)

  210.            iterator = dataset.make_initializable_iterator()

  211.            next_example, next_label = iterator.get_next()

  212.            # Set runhook to initialize iterator

  213.            iterator_initializer_hook.iterator_initializer_func = \

  214.                lambda sess: sess.run(

  215.                    iterator.initializer,

  216.                    feed_dict={images_placeholder: images,

  217.                               labels_placeholder: labels})

  218.            # Return batched (features, labels)

  219.            return next_example, next_label

  220.    # Return function and hook

  221.    return train_inputs, iterator_initializer_hook

  222. def get_test_inputs(batch_size, mnist_data):

  223.    """Return the input function to get the test data.

  224.    Args:

  225.        batch_size (int): Batch size of training iterator that is returned

  226.                          by the input function.

  227.        mnist_data (Object): Object holding the loaded mnist data.

  228.    Returns:

  229.        (Input function, IteratorInitializerHook):

  230.            - Function that returns (features, labels) when called.

  231.            - Hook to initialise input iterator.

  232.    """

  233.    iterator_initializer_hook = IteratorInitializerHook()

  234.    def test_inputs():

  235.        """Returns training set as Operations.

  236.        Returns:

  237.            (features, labels) Operations that iterate over the dataset

  238.            on every evaluation

  239.        """

  240.        with tf.name_scope('Test_data'):

  241.            # Get Mnist data

  242.            images = mnist_data.test.images.reshape([-1, 28, 28, 1])

  243.            labels = mnist_data.test.labels

  244.            # Define placeholders

  245.            images_placeholder = tf.placeholder(

  246.                images.dtype, images.shape)

  247.            labels_placeholder = tf.placeholder(

  248.                labels.dtype, labels.shape)

  249.            # Build dataset iterator

  250.            dataset = tf.contrib.data.Dataset.from_tensor_slices(

  251.                (images_placeholder, labels_placeholder))

  252.            dataset = dataset.batch(batch_size)

  253.            iterator = dataset.make_initializable_iterator()

  254.            next_example, next_label = iterator.get_next()

  255.            # Set runhook to initialize iterator

  256.            iterator_initializer_hook.iterator_initializer_func = \

  257.                lambda sess: sess.run(

  258.                    iterator.initializer,

  259.                    feed_dict={images_placeholder: images,

  260.                               labels_placeholder: labels})

  261.            return next_example, next_label

  262.    # Return function and hook

  263.    return test_inputs, iterator_initializer_hook

  264. # Run script ##############################################

  265. if __name__ == "__main__":

  266.    tf.app.run(

  267.        main=run_experiment

  268.    )


推理训练模式


在训练模型后,我们可以运行 estimateator.predict 来预测给定图像的类别。可使用以下代码示例。 


  1. """Script to illustrate inference of a trained tf.estimator.Estimator.

  2. NOTE: This is dependent on mnist_estimator.py which defines the model.

  3. mnist_estimator.py can be found at:

  4. https://gist.github.com/peterroelants/9956ec93a07ca4e9ba5bc415b014bcca

  5. """

  6. import numpy as np

  7. import skimage.io

  8. import tensorflow as tf

  9. from mnist_estimator import get_estimator

  10. # Set default flags for the output directories

  11. FLAGS = tf.app.flags.FLAGS

  12. tf.app.flags.DEFINE_string(

  13.    flag_name='saved_model_dir', default_value='./mnist_training',

  14.    docstring='Output directory for model and training stats.')

  15. # MNIST sample images

  16. IMAGE_URLS = [

  17.    'https://i.imgur.com/SdYYBDt.png',  # 0

  18.    'https://i.imgur.com/Wy7mad6.png',  # 1

  19.    'https://i.imgur.com/nhBZndj.png',  # 2

  20.    'https://i.imgur.com/V6XeoWZ.png',  # 3

  21.    'https://i.imgur.com/EdxBM1B.png',  # 4

  22.    'https://i.imgur.com/zWSDIuV.png',  # 5

  23.    'https://i.imgur.com/Y28rZho.png',  # 6

  24.    'https://i.imgur.com/6qsCz2W.png',  # 7

  25.    'https://i.imgur.com/BVorzCP.png',  # 8

  26.    'https://i.imgur.com/vt5Edjb.png',  # 9

  27. ]

  28. def infer(argv=None):

  29.    """Run the inference and print the results to stdout."""

  30.    params = tf.contrib.training.HParams()  # Empty hyperparameters

  31.    # Set the run_config where to load the model from

  32.    run_config = tf.contrib.learn.RunConfig()

  33.    run_config = run_config.replace(model_dir=FLAGS.saved_model_dir)

  34.    # Initialize the estimator and run the prediction

  35.    estimator = get_estimator(run_config, params)

  36.    result = estimator.predict(input_fn=test_inputs)

  37.    for r in result:

  38.        print(r)

  39. def test_inputs():

  40.    """Returns training set as Operations.

  41.    Returns:

  42.        (features, ) Operations that iterate over the test set.

  43.    """

  44.    with tf.name_scope('Test_data'):

  45.        images = tf.constant(load_images(), dtype=np.float32)

  46.        dataset = tf.contrib.data.Dataset.from_tensor_slices((images,))

  47.        # Return as iteration in batches of 1

  48.        return dataset.batch(1).make_one_shot_iterator().get_next()

  49. def load_images():

  50.    """Load MNIST sample images from the web and return them in an array.

  51.    Returns:

  52.        Numpy array of size (10, 28, 28, 1) with MNIST sample images.

  53.    """

  54.    images = np.zeros((10, 28, 28, 1))

  55.    for idx, url in enumerate(IMAGE_URLS):

  56.        images[idx, :, :, 0] = skimage.io.imread(url)

  57.    return images

  58. # Run script ##############################################

  59. if __name__ == "__main__":

  60.    tf.app.run(main=infer)


原文链接:https://medium.com/onfido-tech/higher-level-apis-in-tensorflow-67bfb602e6c0


本文为机器之心编译,转载请联系本公众号获得授权。

✄------------------------------------------------

加入机器之心(全职记者/实习生):hr@jiqizhixin.com

投稿或寻求报道:content@jiqizhixin.com

广告&商务合作:bd@jiqizhixin.com

您可能也对以下帖子感兴趣

文章有问题?点此查看未经处理的缓存