tensorflow.compat.v1.test.mock.Mock

Here are the examples of the Python API tensorflow.compat.v1.test.mock.Mock taken from open source projects. By voting up, you can indicate which examples are most useful and appropriate. A minimal usage sketch is shown below, ahead of the examples.

75 Examples
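
Before the examples, here is a minimal, self-contained sketch of the patterns that recur in the snippets below: creating a Mock with return_value= or spec=, patching an attribute with patch.object as a context manager, and asserting on how the mock was called. It assumes tf.compat.v1.test.mock is TensorFlow's re-export of the standard mock module (unittest.mock on recent releases); the Greeter class and test_mock_sketch function are hypothetical names introduced only for this illustration, not taken from the projects above.

import tensorflow as tf

# tf.compat.v1.test.mock is expected to be an alias of unittest.mock
# (or the `mock` backport on older TensorFlow 1.x releases).
mock = tf.compat.v1.test.mock


class Greeter(object):
  """Hypothetical collaborator used only in this sketch."""

  def build_message(self, name):
    return 'Hello, %s!' % name

  def greet(self, name):
    return self.build_message(name)


def test_mock_sketch():
  # A plain Mock with a canned return value.
  mock_fn = mock.Mock(return_value='custom-cell')
  assert mock_fn(units=11) == 'custom-cell'
  mock_fn.assert_called_once_with(units=11)

  # spec= restricts the mock to the attributes of the real class, so a typo
  # such as mock_greeter.gret raises AttributeError instead of passing silently.
  mock_greeter = mock.Mock(spec=Greeter)
  mock_greeter.greet.return_value = 'hi'
  assert mock_greeter.greet('world') == 'hi'
  mock_greeter.greet.assert_called_with('world')

  # patch.object as a context manager, mirroring the patch.object calls below;
  # mock.ANY can stand in for arguments you do not care about.
  greeter = Greeter()
  with mock.patch.object(Greeter, 'build_message',
                         return_value='patched') as mock_build:
    assert greeter.greet('world') == 'patched'
    mock_build.assert_called_once_with('world')


if __name__ == '__main__':
  test_mock_sketch()
  print('sketch passed')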

3 Source : rnn_test.py
with Apache License 2.0
from tensorflow

  def testSpecificLayerTypeArguments(self):
    """Tests arguments for specific layer types (GRU and LSTM)."""
    mock_layer_type = tf.compat.v1.test.mock.Mock()
    with tf.compat.v1.test.mock.patch.object(rnn, '_CELL_TYPE_TO_LAYER_MAPPING',
                                             {'custom-type': mock_layer_type}):
      _make_rnn_layer(
          cell_type='custom-type',
          units=11,
          return_sequences='return-seq-value')
      mock_layer_type.assert_called_once_with(
          units=11, return_sequences='return-seq-value')

  @tf.compat.v1.test.mock.patch.object(tf.keras.layers, 'RNN')

3 Source : rnn_test.py
with Apache License 2.0
from tensorflow

  def testCustomCellProvided(self, mock_rnn_layer_type):
    """Tests behavior when a custom cell type is provided."""
    mock_custom_cell = tf.compat.v1.test.mock.Mock()
    _make_rnn_layer(
        units=[10],
        cell_type=lambda units: mock_custom_cell,
        return_sequences='return-seq-value')
    mock_rnn_layer_type.assert_called_once_with(
        cell=mock_custom_cell, return_sequences='return-seq-value')

  def testMultipleCellsProvided(self):

3 Source : rnn_test.py
with Apache License 2.0
from tensorflow

  def testCustomCellFnProvided(self, mock_rnn_layer_type):
    """Tests behavior when a custom cell function is provided."""
    mock_cell_fn = tf.compat.v1.test.mock.Mock(return_value='custom-cell')
    _make_rnn_layer(
        rnn_cell_fn=mock_cell_fn, return_sequences='return-seq-value')
    mock_rnn_layer_type.assert_called_once_with(
        cell='custom-cell', return_sequences='return-seq-value')


def _mock_logits_layer(kernel, bias):

3 Source : rnn_test.py
with Apache License 2.0
from tensorflow

def _get_mock_head():
  mock_head = multi_head_lib.MultiClassHead(3)
  mock_head.create_estimator_spec = tf.compat.v1.test.mock.Mock(
      return_value=model_fn.EstimatorSpec(None))
  return mock_head


@test_util.run_all_in_graph_and_eager_modes

3 Source : training_test.py
with Apache License 2.0
from tensorflow

  def test_run_task(self):
    mock_est = tf.compat.v1.test.mock.Mock(spec=estimator_lib.Estimator)
    mock_train_spec = tf.compat.v1.test.mock.Mock(spec=training.TrainSpec)
    mock_eval_spec = tf.compat.v1.test.mock.Mock(spec=training.EvalSpec)

    with tf.compat.v1.test.mock.patch.object(
        training, '_TrainingExecutor') as mock_executor:
      mock_executor_instance = tf.compat.v1.test.mock.Mock()
      mock_executor.return_value = mock_executor_instance
      training.train_and_evaluate(mock_est, mock_train_spec, mock_eval_spec)
      mock_executor.assert_called_with(
          estimator=mock_est,
          train_spec=mock_train_spec,
          eval_spec=mock_eval_spec)
      self.assertTrue(mock_executor_instance.run.called)

  def test_error_out_if_evaluator_task_id_is_non_zero(self):

3 Source : training_test.py
with Apache License 2.0
from tensorflow

  def test_invalid_estimator(self):
    invalid_estimator = object()
    mock_train_spec = tf.compat.v1.test.mock.Mock(spec=training.TrainSpec)
    mock_eval_spec = tf.compat.v1.test.mock.Mock(spec=training.EvalSpec)

    with self.assertRaisesRegexp(TypeError, _INVALID_ESTIMATOR_MSG):
      training.train_and_evaluate(invalid_estimator, mock_train_spec,
                                  mock_eval_spec)

  def test_fail_fast_if_invalid_eval_spec(self):

3 Source : training_test.py
with Apache License 2.0
from tensorflow

  def test_fail_fast_if_invalid_eval_spec(self):
    mock_est = tf.compat.v1.test.mock.Mock(spec=estimator_lib.Estimator)
    mock_train_spec = tf.compat.v1.test.mock.Mock(spec=training.TrainSpec)
    invalid_eval_spec = object()

    with tf.compat.v1.test.mock.patch.object(
        training, '_TrainingExecutor') as mock_executor:
      with self.assertRaisesRegexp(TypeError, _INVALID_EVAL_SPEC_MSG):
        training.train_and_evaluate(mock_est, mock_train_spec,
                                    invalid_eval_spec)

      mock_executor.assert_not_called()


class TrainingExecutorConstructorTest(tf.test.TestCase):

3 Source : training_test.py
with Apache License 2.0
from tensorflow

  def test_train_with_train_hooks(self, unused_mock_server, unused_mock_sleep):
    mock_est = tf.compat.v1.test.mock.Mock(spec=estimator_lib.Estimator)
    mock_est.config = self._run_config
    train_spec = training.TrainSpec(
        input_fn=lambda: 1, max_steps=2, hooks=[_FakeHook()])
    mock_eval_spec = tf.compat.v1.test.mock.Mock(spec=training.EvalSpec)
    extra_hooks = [_FakeHook()]

    executor = training._TrainingExecutor(
        mock_est, train_spec, mock_eval_spec, train_hooks=extra_hooks)
    self._run_task(executor)

    mock_est.train.assert_called_with(
        input_fn=train_spec.input_fn,
        max_steps=train_spec.max_steps,
        hooks=list(train_spec.hooks) + extra_hooks,
        saving_listeners=tf.compat.v1.test.mock.ANY)

  @tf.compat.v1.test.mock.patch.object(time, 'sleep')

3 Source : training_test.py
with Apache License 2.0
from tensorflow

  def test_no_server_startup_in_google(self, mock_server, unused_mock_sleep):
    mock_est = tf.compat.v1.test.mock.Mock(spec=estimator_lib.Estimator)
    mock_est.config = self._run_config
    mock_train_spec = tf.compat.v1.test.mock.Mock(
        spec=training.TrainSpec, hooks=[])
    mock_eval_spec = tf.compat.v1.test.mock.Mock(spec=training.EvalSpec)

    executor = training._TrainingExecutor(mock_est, mock_train_spec,
                                          mock_eval_spec)
    tf_config = {'TF_CONFIG': json.dumps(_TF_CONFIG_FOR_GOOGLE)}
    with tf.compat.v1.test.mock.patch.dict('os.environ', tf_config):
      self._run_task(executor)
      mock_server.assert_not_called()

  def test_fail_with_empty_cluster_spec(self):

3 Source : training_test.py
with Apache License 2.0
from tensorflow

  def test_fail_with_empty_cluster_spec(self):
    mock_est = tf.compat.v1.test.mock.Mock(spec=estimator_lib.Estimator)
    mock_train_spec = tf.compat.v1.test.mock.Mock(spec=training.TrainSpec)
    mock_eval_spec = tf.compat.v1.test.mock.Mock(spec=training.EvalSpec)

    mock_est.config = tf.compat.v1.test.mock.PropertyMock(
        spec=run_config_lib.RunConfig)
    mock_est.config.cluster_spec = None
    mock_est.config.master = 'grpc://...'
    mock_est.config.task_type = 'worker'
    mock_est.config.task_id = 2

    with self.assertRaisesRegexp(RuntimeError,
                                 _INVALID_CONFIG_FOR_STD_SERVER_MSG):
      self._run_task(
          training._TrainingExecutor(mock_est, mock_train_spec, mock_eval_spec))

  def test_fail_with_empty_master(self):

3 Source : training_test.py
with Apache License 2.0
from tensorflow

  def test_fail_with_empty_master(self):
    mock_est = tf.compat.v1.test.mock.Mock(spec=estimator_lib.Estimator)
    mock_train_spec = tf.compat.v1.test.mock.Mock(spec=training.TrainSpec)
    mock_eval_spec = tf.compat.v1.test.mock.Mock(spec=training.EvalSpec)

    mock_est.config = tf.compat.v1.test.mock.PropertyMock(
        spec=run_config_lib.RunConfig)
    mock_est.config.cluster_spec = tf.train.ClusterSpec(
        {'worker': ['dummy', 'dummy1']})
    mock_est.config.master = ''
    mock_est.config.task_type = 'worker'
    mock_est.config.task_id = 2

    with self.assertRaisesRegexp(RuntimeError,
                                 _INVALID_CONFIG_FOR_STD_SERVER_MSG):
      self._run_task(
          training._TrainingExecutor(mock_est, mock_train_spec, mock_eval_spec))

  @tf.compat.v1.test.mock.patch.object(time, 'sleep')

3 Source : training_test.py
with Apache License 2.0
from tensorflow

  def test_fail_with_empty_task_type(self):
    mock_est = tf.compat.v1.test.mock.Mock(spec=estimator_lib.Estimator)
    mock_train_spec = tf.compat.v1.test.mock.Mock(spec=training.TrainSpec)
    mock_eval_spec = tf.compat.v1.test.mock.Mock(spec=training.EvalSpec)

    mock_est.config = tf.compat.v1.test.mock.PropertyMock(
        spec=run_config_lib.RunConfig)
    mock_est.config.cluster_spec = tf.train.ClusterSpec({'worker': ['dummy']})
    mock_est.config.master = 'grpc://...'
    mock_est.config.task_type = ''
    mock_est.config.task_id = 2

    with self.assertRaisesRegexp(RuntimeError,
                                 _INVALID_CONFIG_FOR_STD_SERVER_MSG):
      self._run_task(
          training._TrainingExecutor(mock_est, mock_train_spec, mock_eval_spec))

  def test_fail_with_none_task_id(self):

3 Source : training_test.py
with Apache License 2.0
from tensorflow

  def test_fail_with_none_task_id(self):
    mock_est = tf.compat.v1.test.mock.Mock(spec=estimator_lib.Estimator)
    mock_train_spec = tf.compat.v1.test.mock.Mock(spec=training.TrainSpec)
    mock_eval_spec = tf.compat.v1.test.mock.Mock(spec=training.EvalSpec)

    mock_est.config = tf.compat.v1.test.mock.PropertyMock(
        spec=run_config_lib.RunConfig)
    mock_est.config.cluster_spec = tf.train.ClusterSpec({'worker': ['dummy']})
    mock_est.config.master = 'grpc://...'
    mock_est.config.task_type = 'worker'
    mock_est.config.task_id = None

    with self.assertRaisesRegexp(RuntimeError,
                                 _INVALID_CONFIG_FOR_STD_SERVER_MSG):
      self._run_task(
          training._TrainingExecutor(mock_est, mock_train_spec, mock_eval_spec))


class TrainingExecutorRunWorkerTest(_TrainingExecutorTrainingTest,

3 Source : training_test.py
with Apache License 2.0
from tensorflow

  def test_delay_for_worker(self, _):
    mock_est = tf.compat.v1.test.mock.Mock(spec=estimator_lib.Estimator)
    mock_est.config = self._run_config
    mock_train_spec = tf.compat.v1.test.mock.Mock(
        spec=training.TrainSpec, hooks=[])
    mock_eval_spec = tf.compat.v1.test.mock.Mock(spec=training.EvalSpec)

    executor = training._TrainingExecutor(mock_est, mock_train_spec,
                                          mock_eval_spec)

    expected_secs = (self._run_config.task_id + 1) * _DELAY_SECS_PER_WORKER
    with tf.compat.v1.test.mock.patch.object(time, 'sleep') as mock_sleep:
      mock_sleep.side_effect = lambda s: self.assertEqual(expected_secs, s)
      self._run_task(executor)
      self.assertTrue(mock_sleep.called)

  @tf.compat.v1.test.mock.patch.object(server_lib, 'Server')

3 Source : training_test.py
with Apache License 2.0
from tensorflow

  def test_delay_disabled_for_worker(self, _):
    mock_est = tf.compat.v1.test.mock.Mock(spec=estimator_lib.Estimator)
    mock_est.config = self._run_config.replace(
        experimental_max_worker_delay_secs=0)
    mock_train_spec = tf.compat.v1.test.mock.Mock(
        spec=training.TrainSpec, hooks=[])
    mock_eval_spec = tf.compat.v1.test.mock.Mock(spec=training.EvalSpec)

    executor = training._TrainingExecutor(mock_est, mock_train_spec,
                                          mock_eval_spec)

    with tf.compat.v1.test.mock.patch.object(time, 'sleep') as mock_sleep:
      self._run_task(executor)
      self.assertFalse(mock_sleep.called)


class TrainingExecutorRunChiefTest(_TrainingExecutorTrainingTest,

3 Source : training_test.py
with Apache License 2.0
from tensorflow

  def test_no_delay_for_chief(self, _):
    mock_est = tf.compat.v1.test.mock.Mock(spec=estimator_lib.Estimator)
    mock_est.config = self._run_config
    mock_train_spec = tf.compat.v1.test.mock.Mock(
        spec=training.TrainSpec, hooks=[])
    mock_eval_spec = tf.compat.v1.test.mock.Mock(spec=training.EvalSpec)

    executor = training._TrainingExecutor(mock_est, mock_train_spec,
                                          mock_eval_spec)

    with tf.compat.v1.test.mock.patch.object(time, 'sleep') as mock_sleep:
      self._run_task(executor)
      mock_sleep.assert_not_called()


class TrainingExecutorRunMasterTest(tf.test.TestCase):

3 Source : training_test.py
with Apache License 2.0
from tensorflow

  def test_train_with_no_eval_spec_fails(self, mock_server, unused_mock_sleep):
    mock_est = tf.compat.v1.test.mock.Mock(spec=estimator_lib.Estimator)
    mock_est.evaluate = lambda *args, **kw: {
        tf.compat.v1.GraphKeys.GLOBAL_STEP: 123
    }
    mock_est.config = self._run_config
    train_spec = training.TrainSpec(
        input_fn=lambda: 1, max_steps=2, hooks=[_FakeHook()])
    eval_spec = None

    executor = training._TrainingExecutor(mock_est, train_spec, eval_spec)
    with self.assertRaisesRegexp(TypeError, _INVALID_EVAL_SPEC_MSG):
      executor.run_master()

  @tf.compat.v1.test.mock.patch.object(time, 'sleep')

3 Source : training_test.py
with Apache License 2.0
from tensorflow

  def test_evaluate_with_no_eval_spec_fails(self):
    mock_est = tf.compat.v1.test.mock.Mock(spec=estimator_lib.Estimator)
    mock_est.latest_checkpoint.return_value = 'latest_it_is'
    mock_train_spec = tf.compat.v1.test.mock.Mock(spec=training.TrainSpec)
    self._set_up_mock_est_to_train_and_evaluate_once(mock_est, mock_train_spec)

    eval_spec = None

    executor = training._TrainingExecutor(mock_est, mock_train_spec, eval_spec)

    with self.assertRaisesRegexp(TypeError, _INVALID_EVAL_SPEC_MSG):
      executor.run_evaluator()

  def test_evaluate_with_train_hooks(self):

3 Source : training_test.py
with Apache License 2.0
from tensorflow

  def test_errors_out_if_evaluate_returns_empty_dict(self):
    mock_est = tf.compat.v1.test.mock.Mock(spec=estimator_lib.Estimator)
    train_spec = training.TrainSpec(input_fn=lambda: 1)
    eval_spec = training.EvalSpec(
        input_fn=(lambda: 1), start_delay_secs=0, throttle_secs=0)
    mock_est.evaluate.return_value = {}

    executor = training._TrainingExecutor(mock_est, train_spec, eval_spec)
    with self.assertRaisesRegexp(ValueError, _INVALID_EMPTY_EVAL_RESULT_ERR):
      executor.run_evaluator()

  def test_errors_out_if_evaluate_returns_non_dict(self):

3 Source : training_test.py
with Apache License 2.0
from tensorflow

  def test_errors_out_if_evaluate_returns_non_dict(self):
    mock_est = tf.compat.v1.test.mock.Mock(spec=estimator_lib.Estimator)
    train_spec = training.TrainSpec(input_fn=lambda: 1)
    eval_spec = training.EvalSpec(
        input_fn=(lambda: 1), start_delay_secs=0, throttle_secs=0)
    mock_est.evaluate.return_value = 123

    executor = training._TrainingExecutor(mock_est, train_spec, eval_spec)
    with self.assertRaisesRegexp(TypeError, _INVALID_EVAL_RESULT_TYPE_ERR):
      executor.run_evaluator()

  def test_errors_out_if_evaluate_returns_dict_without_global_step(self):

3 Source : training_test.py
with Apache License 2.0
from tensorflow

  def test_errors_out_if_evaluate_returns_dict_without_global_step(self):
    mock_est = tf.compat.v1.test.mock.Mock(spec=estimator_lib.Estimator)
    train_spec = training.TrainSpec(input_fn=lambda: 1)
    eval_spec = training.EvalSpec(
        input_fn=(lambda: 1), start_delay_secs=0, throttle_secs=0)
    mock_est.evaluate.return_value = {'loss': 123}

    executor = training._TrainingExecutor(mock_est, train_spec, eval_spec)
    with self.assertRaisesRegexp(ValueError,
                                 _MISSING_GLOBAL_STEP_IN_EVAL_RESULT_ERR):
      executor.run_evaluator()


class TrainingExecutorRunPsTest(tf.test.TestCase):

3 Source : training_test.py
with Apache License 2.0
from tensorflow

  def test_fail_with_empty_task_type(self):
    mock_est = tf.compat.v1.test.mock.Mock(spec=estimator_lib.Estimator)
    mock_train_spec = tf.compat.v1.test.mock.Mock(spec=training.TrainSpec)
    mock_eval_spec = tf.compat.v1.test.mock.Mock(spec=training.EvalSpec)

    mock_est.config = tf.compat.v1.test.mock.PropertyMock(
        spec=run_config_lib.RunConfig)
    mock_est.config.cluster_spec = tf.train.ClusterSpec({'ps': ['dummy']})
    mock_est.config.master = 'grpc://...'
    mock_est.config.task_type = ''
    mock_est.config.task_id = 2

    with self.assertRaisesRegexp(RuntimeError,
                                 _INVALID_CONFIG_FOR_STD_SERVER_MSG):
      training._TrainingExecutor(mock_est, mock_train_spec,
                                 mock_eval_spec).run_ps()

  def test_fail_with_none_task_id(self):

3 Source : training_test.py
with Apache License 2.0
from tensorflow

  def test_fail_with_none_task_id(self):
    mock_est = tf.compat.v1.test.mock.Mock(spec=estimator_lib.Estimator)
    mock_train_spec = tf.compat.v1.test.mock.Mock(spec=training.TrainSpec)
    mock_eval_spec = tf.compat.v1.test.mock.Mock(spec=training.EvalSpec)

    mock_est.config = tf.compat.v1.test.mock.PropertyMock(
        spec=run_config_lib.RunConfig)
    mock_est.config.cluster_spec = tf.train.ClusterSpec({'ps': ['dummy']})
    mock_est.config.master = 'grpc://...'
    mock_est.config.task_type = 'ps'
    mock_est.config.task_id = None

    with self.assertRaisesRegexp(RuntimeError,
                                 _INVALID_CONFIG_FOR_STD_SERVER_MSG):
      training._TrainingExecutor(mock_est, mock_train_spec,
                                 mock_eval_spec).run_ps()


class StopAtSecsHookTest(tf.test.TestCase):

3 Source : training_test.py
with Apache License 2.0
from tensorflow

  def test_train_with_no_eval_spec_fails(self):
    mock_est = tf.compat.v1.test.mock.Mock(spec=estimator_lib.Estimator)
    train_spec = training.TrainSpec(
        input_fn=lambda: 1, max_steps=300, hooks=[_FakeHook()])
    eval_spec = None

    executor = training._TrainingExecutor(mock_est, train_spec, eval_spec)

    with self.assertRaisesRegexp(TypeError, _INVALID_EVAL_SPEC_MSG):
      executor.run_local()

  def test_train_hooks(self):

3 Source : training_test.py
with Apache License 2.0
from tensorflow

  def test_errors_out_if_evaluate_returns_empty_dict(self):
    est = estimator_lib.Estimator(
        model_fn=self._model_fn,
        config=run_config_lib.RunConfig(save_checkpoints_steps=2))
    mock_est = tf.compat.v1.test.mock.Mock(
        spec=estimator_lib.Estimator, wraps=est)
    train_spec = training.TrainSpec(input_fn=self._input_fn)
    eval_spec = training.EvalSpec(
        input_fn=lambda: self._input_fn(repeat=False), throttle_secs=0)
    mock_est.evaluate.return_value = {}

    executor = training._TrainingExecutor(mock_est, train_spec, eval_spec)
    with self.assertRaisesRegexp(ValueError, _INVALID_EMPTY_EVAL_RESULT_ERR):
      executor.run_local()

  def test_errors_out_if_evaluate_returns_non_dict(self):

3 Source : training_test.py
with Apache License 2.0
from tensorflow

  def test_errors_out_if_evaluate_returns_non_dict(self):
    est = estimator_lib.Estimator(
        model_fn=self._model_fn,
        config=run_config_lib.RunConfig(save_checkpoints_steps=2))
    mock_est = tf.compat.v1.test.mock.Mock(
        spec=estimator_lib.Estimator, wraps=est)
    train_spec = training.TrainSpec(input_fn=self._input_fn)
    eval_spec = training.EvalSpec(
        input_fn=lambda: self._input_fn(repeat=False), throttle_secs=0)
    mock_est.evaluate.return_value = 123
    executor = training._TrainingExecutor(mock_est, train_spec, eval_spec)
    with self.assertRaisesRegexp(TypeError, _INVALID_EVAL_RESULT_TYPE_ERR):
      executor.run_local()

  def test_errors_out_if_evaluate_returns_dict_without_global_step(self):

3 Source : training_test.py
with Apache License 2.0
from tensorflow

  def test_errors_out_if_evaluate_returns_dict_without_global_step(self):
    est = estimator_lib.Estimator(
        model_fn=self._model_fn,
        config=run_config_lib.RunConfig(save_checkpoints_steps=2))
    mock_est = tf.compat.v1.test.mock.Mock(
        spec=estimator_lib.Estimator, wraps=est)
    train_spec = training.TrainSpec(input_fn=self._input_fn)
    eval_spec = training.EvalSpec(
        input_fn=lambda: self._input_fn(repeat=False), throttle_secs=0)
    mock_est.evaluate.return_value = {'loss': 123}

    executor = training._TrainingExecutor(mock_est, train_spec, eval_spec)
    with self.assertRaisesRegexp(ValueError,
                                 _MISSING_GLOBAL_STEP_IN_EVAL_RESULT_ERR):
      executor.run_local()

  def test_train_and_evaluate_return_metrics(self):

3 Source : training_test.py
with Apache License 2.0
from tensorflow

  def test_invalid_task_type(self):
    mock_est = tf.compat.v1.test.mock.Mock(spec=estimator_lib.Estimator)
    mock_est.config = tf.compat.v1.test.mock.Mock()
    mock_train_spec = tf.compat.v1.test.mock.Mock(spec=training.TrainSpec)
    mock_eval_spec = tf.compat.v1.test.mock.Mock(spec=training.EvalSpec)

    mock_est.config = tf.compat.v1.test.mock.Mock()
    mock_est.config.cluster_spec = tf.train.ClusterSpec({'1': ['dummy']})
    mock_est.config.task_type = ''

    executor = training._TrainingExecutor(mock_est, mock_train_spec,
                                          mock_eval_spec)
    with self.assertRaisesRegexp(ValueError, _INVALID_TASK_TYPE):
      executor.run()


class TrainAndEvaluateIntegrationTest(tf.test.TestCase):

3 Source : driver_test.py
with Apache License 2.0
from tensorflow

  def setUp(self):
    super().setUp()

    self._test_dir = os.path.join(
        os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
        self._testMethodName)

    # Mock metadata and create driver.
    self._mock_metadata = tf.compat.v1.test.mock.Mock()
    self._file_based_driver = driver.FileBasedDriver(self._mock_metadata)

  def testResolveExecProperties(self):

3 Source : publisher_test.py
with Apache License 2.0
from tensorflow

  def setUp(self):
    super().setUp()
    self._mock_metadata = tf.compat.v1.test.mock.Mock()
    self._mock_metadata.publish_execution = tf.compat.v1.test.mock.Mock()
    self._output_dict = {
        'output_data': [_OutputType()],
    }
    self._exec_properties = {'k': 'v'}
    self._pipeline_info = data_types.PipelineInfo(
        pipeline_name='my_pipeline', pipeline_root='/tmp', run_id='my_run_id')
    self._component_info = data_types.ComponentInfo(
        component_type='a.b.c',
        component_id='my_component',
        pipeline_info=self._pipeline_info)

  def testPrepareExecutionComplete(self):

0 Source : training_test.py
with Apache License 2.0
from tensorflow

  def test_error_out_if_evaluator_task_id_is_non_zero(self):
    tf_config = {
        'cluster': {
            run_config_lib.TaskType.CHIEF: ['host0:0'],
        },
        'task': {
            'type': run_config_lib.TaskType.EVALUATOR,
            'index': 1
        }
    }

    mock_est = tf.compat.v1.test.mock.Mock(spec=estimator_lib.Estimator)
    mock_est.config = _create_run_config_with_cluster_spec(tf_config)
    mock_train_spec = tf.compat.v1.test.mock.Mock(spec=training.TrainSpec)
    mock_eval_spec = tf.compat.v1.test.mock.Mock(spec=training.EvalSpec)

    with self.assertRaisesRegexp(ValueError, _INVALID_EVAL_TASK_ID_ERR):
      training.train_and_evaluate(mock_est, mock_train_spec, mock_eval_spec)

  def test_invalid_estimator(self):

0 Source : training_test.py
with Apache License 2.0
from tensorflow

  def test_train_with_train_spec(self, mock_server, unused_mock_sleep):
    mock_est = tf.compat.v1.test.mock.Mock(spec=estimator_lib.Estimator)
    mock_est.config = self._run_config
    train_spec = training.TrainSpec(
        input_fn=lambda: 1, max_steps=2, hooks=[_FakeHook()])
    mock_eval_spec = tf.compat.v1.test.mock.Mock(spec=training.EvalSpec)
    mock_server_instance = mock_server.return_value

    executor = training._TrainingExecutor(mock_est, train_spec, mock_eval_spec)
    self._run_task(executor)

    mock_server.assert_called_with(
        mock_est.config.cluster_spec,
        job_name=mock_est.config.task_type,
        task_index=mock_est.config.task_id,
        config=tf.compat.v1.test.mock.ANY,
        protocol=None,
        start=False)

    self.assertTrue(mock_server_instance.start.called)

    mock_est.train.assert_called_with(
        input_fn=train_spec.input_fn,
        max_steps=train_spec.max_steps,
        hooks=list(train_spec.hooks),
        saving_listeners=tf.compat.v1.test.mock.ANY)
    mock_est.evaluate.assert_not_called()
    mock_est.export_saved_model.assert_not_called()

  @tf.compat.v1.test.mock.patch.object(time, 'sleep')

0 Source : training_test.py
with Apache License 2.0
from tensorflow

  def test_train_with_no_eval_spec(self, mock_server, unused_mock_sleep):
    mock_est = tf.compat.v1.test.mock.Mock(spec=estimator_lib.Estimator)
    mock_est.config = self._run_config
    train_spec = training.TrainSpec(
        input_fn=lambda: 1, max_steps=2, hooks=[_FakeHook()])
    eval_spec = None
    mock_server_instance = mock_server.return_value

    executor = training._TrainingExecutor(mock_est, train_spec, eval_spec)
    self._run_task(executor)

    mock_server.assert_called_with(
        mock_est.config.cluster_spec,
        job_name=mock_est.config.task_type,
        task_index=mock_est.config.task_id,
        config=tf.compat.v1.test.mock.ANY,
        protocol=None,
        start=False)

    self.assertTrue(mock_server_instance.start.called)

    mock_est.train.assert_called_with(
        input_fn=train_spec.input_fn,
        max_steps=train_spec.max_steps,
        hooks=list(train_spec.hooks),
        saving_listeners=tf.compat.v1.test.mock.ANY)
    mock_est.evaluate.assert_not_called()
    mock_est.export_saved_model.assert_not_called()

  @tf.compat.v1.test.mock.patch.object(time, 'sleep')

0 Source : training_test.py
with Apache License 2.0
from tensorflow

  def test_single_worker_node_with_empty_tf_master(self, mock_server,
                                                   unused_mock_sleep):
    mock_est = tf.compat.v1.test.mock.Mock(spec=estimator_lib.Estimator)
    mock_train_spec = tf.compat.v1.test.mock.Mock(
        spec=training.TrainSpec, hooks=[])
    mock_eval_spec = tf.compat.v1.test.mock.Mock(spec=training.EvalSpec)

    mock_est.config = tf.compat.v1.test.mock.PropertyMock(
        spec=run_config_lib.RunConfig)
    # Single node cluster.
    mock_est.config.cluster_spec = tf.train.ClusterSpec({'worker': ['dummy']})
    mock_est.config.master = ''
    mock_est.config.task_type = 'worker'
    mock_est.config.task_id = 2

    self._run_task(
        training._TrainingExecutor(mock_est, mock_train_spec, mock_eval_spec))
    self.assertTrue(mock_est.train.called)
    mock_server.assert_not_called()

  def test_fail_with_empty_task_type(self):

0 Source : training_test.py
with Apache License 2.0
from tensorflow

  def test_no_delay_for_master(self, _):
    mock_est = tf.compat.v1.test.mock.Mock(spec=estimator_lib.Estimator)
    mock_est.evaluate = lambda *args, **kw: {
        tf.compat.v1.GraphKeys.GLOBAL_STEP: 123
    }
    mock_est.config = self._run_config
    mock_train_spec = tf.compat.v1.test.mock.Mock(
        spec=training.TrainSpec, max_steps=123, hooks=[])
    mock_eval_spec = tf.compat.v1.test.mock.Mock(
        spec=training.EvalSpec, exporters=[])

    mock_train_spec.saving_listeners = tuple([])

    executor = training._TrainingExecutor(mock_est, mock_train_spec,
                                          mock_eval_spec)

    with tf.compat.v1.test.mock.patch.object(time, 'sleep') as mock_sleep:
      executor.run_master()
      mock_sleep.assert_not_called()

  @tf.compat.v1.test.mock.patch.object(time, 'sleep')

0 Source : training_test.py
with Apache License 2.0
from tensorflow

  def test_train_with_train_spec(self, mock_server, unused_mock_sleep):
    mock_est = tf.compat.v1.test.mock.Mock(spec=estimator_lib.Estimator)
    mock_est.evaluate = lambda *args, **kw: {
        tf.compat.v1.GraphKeys.GLOBAL_STEP: 123
    }
    mock_est.config = self._run_config
    train_spec = training.TrainSpec(
        input_fn=lambda: 1, max_steps=2, hooks=[_FakeHook()])
    mock_eval_spec = tf.compat.v1.test.mock.Mock(
        spec=training.EvalSpec, exporters=[])
    mock_server_instance = mock_server.return_value

    executor = training._TrainingExecutor(mock_est, train_spec, mock_eval_spec)
    executor.run_master()

    mock_server.assert_called_with(
        mock_est.config.cluster_spec,
        job_name=mock_est.config.task_type,
        task_index=mock_est.config.task_id,
        config=tf.compat.v1.test.mock.ANY,
        protocol=None,
        start=False)

    self.assertTrue(mock_server_instance.start.called)

    mock_est.train.assert_called_with(
        input_fn=train_spec.input_fn,
        max_steps=train_spec.max_steps,
        hooks=list(train_spec.hooks),
        saving_listeners=tf.compat.v1.test.mock.ANY)
    mock_est.export_saved_model.assert_not_called()

  @tf.compat.v1.test.mock.patch.object(time, 'sleep')

0 Source : training_test.py
with Apache License 2.0
from tensorflow

  def test_train_with_train_hooks(self, mock_server, unused_mock_sleep):
    mock_est = tf.compat.v1.test.mock.Mock(spec=estimator_lib.Estimator)
    mock_est.evaluate = lambda *args, **kw: {
        tf.compat.v1.GraphKeys.GLOBAL_STEP: 123
    }
    mock_est.config = self._run_config
    train_spec = training.TrainSpec(
        input_fn=lambda: 1, max_steps=2, hooks=[_FakeHook()])
    mock_eval_spec = tf.compat.v1.test.mock.Mock(
        spec=training.EvalSpec, exporters=[])
    extra_hooks = [_FakeHook()]

    executor = training._TrainingExecutor(
        mock_est, train_spec, mock_eval_spec, train_hooks=extra_hooks)
    executor.run_master()

    mock_est.train.assert_called_with(
        input_fn=train_spec.input_fn,
        max_steps=train_spec.max_steps,
        hooks=list(train_spec.hooks) + extra_hooks,
        saving_listeners=tf.compat.v1.test.mock.ANY)

  @tf.compat.v1.test.mock.patch.object(time, 'sleep')

0 Source : training_test.py
with Apache License 2.0
from tensorflow

  def test_no_server_startup_in_google(self, mock_server, unused_mock_sleep):
    mock_est = tf.compat.v1.test.mock.Mock(spec=estimator_lib.Estimator)
    mock_est.evaluate = lambda *args, **kw: {
        tf.compat.v1.GraphKeys.GLOBAL_STEP: 123
    }
    mock_est.config = self._run_config
    mock_train_spec = tf.compat.v1.test.mock.Mock(
        spec=training.TrainSpec, max_steps=123, hooks=[])
    mock_eval_spec = tf.compat.v1.test.mock.Mock(
        spec=training.EvalSpec, exporters=[])

    mock_train_spec.saving_listeners = tuple([])

    executor = training._TrainingExecutor(mock_est, mock_train_spec,
                                          mock_eval_spec)
    tf_config = {'TF_CONFIG': json.dumps(_TF_CONFIG_FOR_GOOGLE)}
    with tf.compat.v1.test.mock.patch.dict('os.environ', tf_config):
      executor.run_master()
      mock_server.assert_not_called()

  def test_fail_with_empty_cluster_spec(self):

0 Source : training_test.py
with Apache License 2.0
from tensorflow

  def test_fail_with_empty_cluster_spec(self):
    mock_est = tf.compat.v1.test.mock.Mock(spec=estimator_lib.Estimator)
    mock_train_spec = tf.compat.v1.test.mock.Mock(spec=training.TrainSpec)
    mock_eval_spec = tf.compat.v1.test.mock.Mock(spec=training.EvalSpec)

    mock_est.config = tf.compat.v1.test.mock.PropertyMock(
        spec=run_config_lib.RunConfig)
    mock_est.config.cluster_spec = None
    mock_est.config.master = 'grpc://...'
    mock_est.config.task_type = 'master'
    mock_est.config.task_id = 2

    mock_train_spec.saving_listeners = tuple([])

    with self.assertRaisesRegexp(RuntimeError,
                                 _INVALID_CONFIG_FOR_STD_SERVER_MSG):
      training._TrainingExecutor(mock_est, mock_train_spec,
                                 mock_eval_spec).run_master()

  def test_fail_with_empty_master(self):

0 Source : training_test.py
with Apache License 2.0
from tensorflow

  def test_fail_with_empty_master(self):
    mock_est = tf.compat.v1.test.mock.Mock(spec=estimator_lib.Estimator)
    mock_train_spec = tf.compat.v1.test.mock.Mock(spec=training.TrainSpec)
    mock_eval_spec = tf.compat.v1.test.mock.Mock(spec=training.EvalSpec)

    mock_est.config = tf.compat.v1.test.mock.PropertyMock(
        spec=run_config_lib.RunConfig)
    mock_est.config.cluster_spec = tf.train.ClusterSpec({
        'master': ['dummy'],
        'worker': ['dummy1']
    })
    mock_est.config.master = ''
    mock_est.config.task_type = 'master'
    mock_est.config.task_id = 0

    mock_train_spec.saving_listeners = tuple([])

    with self.assertRaisesRegexp(RuntimeError,
                                 _INVALID_CONFIG_FOR_STD_SERVER_MSG):
      training._TrainingExecutor(mock_est, mock_train_spec,
                                 mock_eval_spec).run_master()

  @tf.compat.v1.test.mock.patch.object(time, 'sleep')

0 Source : training_test.py
with Apache License 2.0
from tensorflow

  def test_single_master_node_with_empty_tf_master(self, mock_server,
                                                   unused_mock_sleep):
    mock_est = tf.compat.v1.test.mock.Mock(spec=estimator_lib.Estimator)
    mock_est.evaluate = lambda *args, **kw: {
        tf.compat.v1.GraphKeys.GLOBAL_STEP: 123
    }

    mock_train_spec = tf.compat.v1.test.mock.Mock(
        spec=training.TrainSpec, max_steps=123, hooks=[])
    mock_eval_spec = tf.compat.v1.test.mock.Mock(
        spec=training.EvalSpec, exporters=[])

    mock_est.config = tf.compat.v1.test.mock.PropertyMock(
        spec=run_config_lib.RunConfig)
    mock_est.config.cluster_spec = tf.train.ClusterSpec({'master': ['dummy']})
    mock_est.config.master = ''
    mock_est.config.task_type = 'master'
    mock_est.config.task_id = 0

    mock_train_spec.saving_listeners = tuple([])

    executor = training._TrainingExecutor(mock_est, mock_train_spec,
                                          mock_eval_spec)
    executor.run_master()

    mock_server.assert_not_called()
    self.assertTrue(mock_est.train.called)

  def test_fail_with_empty_task_type(self):

0 Source : training_test.py
with Apache License 2.0
from tensorflow

  def test_fail_with_empty_task_type(self):
    mock_est = tf.compat.v1.test.mock.Mock(spec=estimator_lib.Estimator)
    mock_train_spec = tf.compat.v1.test.mock.Mock(spec=training.TrainSpec)
    mock_eval_spec = tf.compat.v1.test.mock.Mock(spec=training.EvalSpec)

    mock_est.config = tf.compat.v1.test.mock.PropertyMock(
        spec=run_config_lib.RunConfig)
    mock_est.config.cluster_spec = tf.train.ClusterSpec({'master': ['dummy']})
    mock_est.config.master = 'grpc://...'
    mock_est.config.task_type = ''
    mock_est.config.task_id = 2

    mock_train_spec.saving_listeners = tuple([])

    with self.assertRaisesRegexp(RuntimeError,
                                 _INVALID_CONFIG_FOR_STD_SERVER_MSG):
      training._TrainingExecutor(mock_est, mock_train_spec,
                                 mock_eval_spec).run_master()

  def test_fail_with_none_task_id(self):

0 Source : training_test.py
with Apache License 2.0
from tensorflow

  def test_fail_with_none_task_id(self):
    mock_est = tf.compat.v1.test.mock.Mock(spec=estimator_lib.Estimator)
    mock_train_spec = tf.compat.v1.test.mock.Mock(spec=training.TrainSpec)
    mock_eval_spec = tf.compat.v1.test.mock.Mock(spec=training.EvalSpec)

    mock_est.config = tf.compat.v1.test.mock.PropertyMock(
        spec=run_config_lib.RunConfig)
    mock_est.config.cluster_spec = tf.train.ClusterSpec({'master': ['dummy']})
    mock_est.config.master = 'grpc://...'
    mock_est.config.task_type = 'master'
    mock_est.config.task_id = None

    mock_train_spec.saving_listeners = tuple([])

    with self.assertRaisesRegexp(RuntimeError,
                                 _INVALID_CONFIG_FOR_STD_SERVER_MSG):
      training._TrainingExecutor(mock_est, mock_train_spec,
                                 mock_eval_spec).run_master()

  @tf.compat.v1.test.mock.patch.object(server_lib, 'Server')

0 Source : training_test.py
with Apache License 2.0
from tensorflow

  def test_run_master_triggers_evaluate_and_export(self, _):

    def estimator_train(saving_listeners, *args, **kwargs):
      # There shall be a saving_listener. Estimator is going to call
      # `after_save`.
      del args, kwargs
      saving_listeners[0].begin()
      saving_listeners[0].after_save(session=None, global_step_value=0)
      saving_listeners[0].after_save(session=None, global_step_value=10)

    mock_est = tf.compat.v1.test.mock.Mock(
        spec=estimator_lib.Estimator, model_dir='path/', train=estimator_train)
    mock_est.latest_checkpoint.return_value = 'checkpoint_path/'
    mock_est.config = self._run_config

    exporter = tf.compat.v1.test.mock.PropertyMock(spec=exporter_lib.Exporter)
    exporter.name = 'see_whether_export_is_called'

    train_spec = training.TrainSpec(input_fn=lambda: 1, max_steps=300)
    eval_spec = training.EvalSpec(
        input_fn=lambda: 1, steps=2, exporters=exporter)
    eval_result = {_GLOBAL_STEP_KEY: train_spec.max_steps}
    mock_est.evaluate.return_value = eval_result

    executor = training._TrainingExecutor(mock_est, train_spec, eval_spec)
    executor.run_master()

    mock_est.evaluate.assert_called_with(
        name=eval_spec.name,
        input_fn=eval_spec.input_fn,
        steps=eval_spec.steps,
        checkpoint_path='checkpoint_path/',
        hooks=eval_spec.hooks)
    self.assertEqual(1, exporter.export.call_count)
    exporter.export.assert_called_with(
        estimator=mock_est,
        export_path=os.path.join('path/', 'export', exporter.name),
        checkpoint_path='checkpoint_path/',
        eval_result=eval_result,
        is_the_final_export=True)

  @tf.compat.v1.test.mock.patch.object(basic_session_run_hooks,

0 Source : training_test.py
with Apache License 2.0
from tensorflow

  def test_run_master_throttle_eval(self, _, mock_timer_class):
    mock_est = tf.compat.v1.test.mock.Mock(
        spec=estimator_lib.Estimator, model_dir='path/')

    mock_timer = tf.compat.v1.test.mock.Mock()
    mock_timer_class.return_value = mock_timer

    def estimator_train(saving_listeners, *args, **kwargs):
      del args, kwargs
      saving_listeners[0].begin()

      # Call four times.
      mock_timer.should_trigger_for_step.return_value = True
      saving_listeners[0].after_save(session=None, global_step_value=None)

      mock_timer.should_trigger_for_step.return_value = True
      saving_listeners[0].after_save(session=None, global_step_value=None)

      mock_timer.should_trigger_for_step.return_value = False
      saving_listeners[0].after_save(session=None, global_step_value=None)

      mock_timer.should_trigger_for_step.return_value = True
      saving_listeners[0].after_save(session=None, global_step_value=None)

    mock_est.train = estimator_train
    mock_est.latest_checkpoint.side_effect = ['ckpt1', 'ckpt2']
    mock_est.config = self._run_config

    exporter = tf.compat.v1.test.mock.PropertyMock(spec=exporter_lib.Exporter)
    exporter.name = 'see_whether_export_is_called'

    train_spec = training.TrainSpec(input_fn=lambda: 1, max_steps=300)
    eval_spec = training.EvalSpec(
        input_fn=lambda: 1, steps=2, exporters=exporter, throttle_secs=10)

    mock_est.evaluate.side_effect = [{
        _GLOBAL_STEP_KEY: train_spec.max_steps // 2
    }, {
        _GLOBAL_STEP_KEY: train_spec.max_steps
    }]

    executor = training._TrainingExecutor(mock_est, train_spec, eval_spec)
    executor.run_master()

    self.assertEqual(2, mock_est.evaluate.call_count)
    self.assertEqual(2, exporter.export.call_count)

    is_final_export_list = [
        call[1]['is_the_final_export']
        for call in exporter.export.call_args_list
    ]
    self.assertEqual([False, True], is_final_export_list)

  @tf.compat.v1.test.mock.patch.object(basic_session_run_hooks,

0 Source : training_test.py
with Apache License 2.0
from tensorflow

  def test_run_master_throttle_eval_which_skips_final_ckpt(
      self, _, mock_timer_class):
    mock_est = tf.compat.v1.test.mock.Mock(
        spec=estimator_lib.Estimator, model_dir='path/')

    mock_timer = tf.compat.v1.test.mock.Mock()
    mock_timer_class.return_value = mock_timer

    def estimator_train(saving_listeners, *args, **kwargs):
      del args, kwargs
      saving_listeners[0].begin()

      # Call three times (one for the first saving).
      mock_timer.should_trigger_for_step.return_value = True
      saving_listeners[0].after_save(session=None, global_step_value=0)

      mock_timer.should_trigger_for_step.return_value = True
      saving_listeners[0].after_save(session=None, global_step_value=125)

      mock_timer.should_trigger_for_step.return_value = False
      saving_listeners[0].after_save(session=None, global_step_value=250)

      # At the end evaluate should be called even if throttle secs prevents it.
      mock_timer.should_trigger_for_step.return_value = False
      saving_listeners[0].end(session=None, global_step_value=300)

    mock_est.train = estimator_train
    mock_est.latest_checkpoint.side_effect = ['ckpt1', 'ckpt2']
    mock_est.config = self._run_config

    exporter = tf.compat.v1.test.mock.PropertyMock(spec=exporter_lib.Exporter)
    exporter.name = 'see_whether_export_is_called'

    train_spec = training.TrainSpec(input_fn=lambda: 1, max_steps=300)
    eval_spec = training.EvalSpec(
        input_fn=lambda: 1, steps=2, exporters=exporter, throttle_secs=10)

    mock_est.evaluate.side_effect = [{
        _GLOBAL_STEP_KEY: train_spec.max_steps // 2
    }, {
        _GLOBAL_STEP_KEY: train_spec.max_steps
    }]

    executor = training._TrainingExecutor(mock_est, train_spec, eval_spec)
    executor.run_master()

    self.assertEqual(2, mock_est.evaluate.call_count)
    self.assertEqual(2, exporter.export.call_count)

    is_final_export_list = [
        call[1]['is_the_final_export']
        for call in exporter.export.call_args_list
    ]
    self.assertEqual([False, True], is_final_export_list)


class TrainingExecutorRunEvaluatorTest(tf.test.TestCase):

0 Source : training_test.py
with Apache License 2.0
from tensorflow

  def test_evaluate_with_evaluate_spec(self):
    mock_est = tf.compat.v1.test.mock.Mock(spec=estimator_lib.Estimator)
    mock_est.latest_checkpoint.return_value = 'latest_it_is'
    mock_train_spec = tf.compat.v1.test.mock.Mock(spec=training.TrainSpec)
    self._set_up_mock_est_to_train_and_evaluate_once(mock_est, mock_train_spec)

    eval_spec = training.EvalSpec(
        input_fn=lambda: 1,
        steps=2,
        hooks=[_FakeHook()],
        name='cont_eval',
        start_delay_secs=0,
        throttle_secs=0)

    executor = training._TrainingExecutor(mock_est, mock_train_spec, eval_spec)
    executor.run_evaluator()

    mock_est.evaluate.assert_called_with(
        name='cont_eval',
        input_fn=eval_spec.input_fn,
        steps=eval_spec.steps,
        checkpoint_path='latest_it_is',
        hooks=eval_spec.hooks)
    self.assertFalse(mock_est.train.called)

  def test_evaluate_with_no_eval_spec_fails(self):

0 Source : training_test.py
with Apache License 2.0
from tensorflow

  def test_evaluate_with_train_hooks(self):
    mock_est = tf.compat.v1.test.mock.Mock(spec=estimator_lib.Estimator)
    mock_est.latest_checkpoint.return_value = 'latest_it_is'
    mock_train_spec = tf.compat.v1.test.mock.Mock(spec=training.TrainSpec)
    self._set_up_mock_est_to_train_and_evaluate_once(mock_est, mock_train_spec)

    eval_spec = training.EvalSpec(
        input_fn=lambda: 1,
        steps=2,
        hooks=[_FakeHook()],
        name='cont_eval',
        start_delay_secs=0,
        throttle_secs=0)

    # The train_hooks will not be called during eval.
    mock_hook = tf.compat.v1.test.mock.Mock(
        spec=tf.compat.v1.train.SessionRunHook)
    executor = training._TrainingExecutor(
        mock_est, mock_train_spec, eval_spec, train_hooks=[mock_hook])
    executor.run_evaluator()

    mock_hook.begin.assert_not_called()

  def test_evaluate_multiple_times(self):

0 Source : training_test.py
with Apache License 2.0
from tensorflow

  def test_evaluate_multiple_times(self):
    training_max_step = 200

    mock_est = tf.compat.v1.test.mock.Mock(spec=estimator_lib.Estimator)
    mock_est.model_dir = tf.compat.as_bytes(tf.compat.v1.test.get_temp_dir())
    mock_est.evaluate.side_effect = [{
        _GLOBAL_STEP_KEY: training_max_step // 2
    }, {
        _GLOBAL_STEP_KEY: training_max_step
    }]
    mock_est.latest_checkpoint.side_effect = ['path_1', 'path_2']

    mock_train_spec = tf.compat.v1.test.mock.Mock(spec=training.TrainSpec)
    mock_train_spec.max_steps = training_max_step

    exporter = tf.compat.v1.test.mock.PropertyMock(spec=exporter_lib.Exporter)
    exporter.name = 'see_how_many_times_export_is_called'

    mock_est.times_export_was_called = 0
    mock_est.times_final_export_was_called = 0

    def export(estimator, export_path, checkpoint_path, eval_result,
               is_the_final_export):
      del export_path, checkpoint_path, eval_result
      estimator.times_export_was_called += 1
      # The final export happens at the end.
      self.assertEqual(0, estimator.times_final_export_was_called)
      if is_the_final_export:
        estimator.times_final_export_was_called += 1

    exporter.export = export

    eval_spec = training.EvalSpec(
        input_fn=lambda: 1,
        start_delay_secs=0,
        throttle_secs=0,
        exporters=exporter)

    executor = training._TrainingExecutor(mock_est, mock_train_spec, eval_spec)
    executor.run_evaluator()

    self.assertEqual(2, mock_est.evaluate.call_count)
    self.assertEqual(2, mock_est.times_export_was_called)
    self.assertEqual(1, mock_est.times_final_export_was_called)

  def test_evaluate_listener_before_eval(self):

0 Source : training_test.py
with Apache License 2.0
from tensorflow

  def test_evaluate_listener_before_eval(self):
    training_max_step = 200

    mock_est = tf.compat.v1.test.mock.Mock(spec=estimator_lib.Estimator)
    mock_est.model_dir = tf.compat.as_bytes(tf.compat.v1.test.get_temp_dir())
    # Without early stopping, this eval will be run twice.
    mock_est.evaluate.side_effect = [{
        _GLOBAL_STEP_KEY: training_max_step // 2
    }, {
        _GLOBAL_STEP_KEY: training_max_step
    }]
    mock_est.latest_checkpoint.side_effect = ['path_1', 'path_2']

    mock_train_spec = tf.compat.v1.test.mock.Mock(
        spec=training.TrainSpec, hooks=[])
    mock_train_spec.max_steps = training_max_step

    class _Listener(training._ContinuousEvalListener):

      def __init__(self):
        self.call_count = 0

      def before_eval(self):
        self.call_count += 1
        return self.call_count == 1

    listener = _Listener()

    eval_spec = training.EvalSpec(
        input_fn=lambda: 1, start_delay_secs=0, throttle_secs=0)

    training._TrainingExecutor(
        mock_est, mock_train_spec, eval_spec,
        continuous_eval_listener=listener).run_evaluator()

    # before_eval returns False the second time, so evaluate will only be
    # called once.
    self.assertEqual(1, mock_est.evaluate.call_count)
    self.assertEqual(2, listener.call_count)

  def test_evaluate_listener_after_eval(self):
