Here are examples of the Python API `tensorflow.compat.v1.test.mock.Mock`, taken from open-source projects. By voting up, you can indicate which examples are most useful and appropriate.
75 Examples
3
Source: rnn_test.py
with Apache License 2.0
from tensorflow
def testSpecificLayerTypeArguments(self):
  """Checks kwargs reach the layer class registered for a cell type."""
  tf_mock = tf.compat.v1.test.mock
  fake_layer_cls = tf_mock.Mock()
  # Map the 'custom-type' key to the fake layer class for this test only.
  fake_mapping = {'custom-type': fake_layer_cls}
  with tf_mock.patch.object(rnn, '_CELL_TYPE_TO_LAYER_MAPPING', fake_mapping):
    _make_rnn_layer(
        cell_type='custom-type',
        units=11,
        return_sequences='return-seq-value')
    # Built exactly once, with the forwarded keyword arguments.
    fake_layer_cls.assert_called_once_with(
        units=11, return_sequences='return-seq-value')
@tf.compat.v1.test.mock.patch.object(tf.keras.layers, 'RNN')
3
Source: rnn_test.py
with Apache License 2.0
from tensorflow
def testCustomCellProvided(self, mock_rnn_layer_type):
  """Checks a callable cell_type's returned cell is handed to the RNN layer."""
  custom_cell = tf.compat.v1.test.mock.Mock()

  def cell_factory(units):
    del units  # The factory ignores the requested units.
    return custom_cell

  _make_rnn_layer(
      units=[10],
      cell_type=cell_factory,
      return_sequences='return-seq-value')
  # The patched RNN layer class receives the factory-built cell.
  mock_rnn_layer_type.assert_called_once_with(
      cell=custom_cell, return_sequences='return-seq-value')
def testMultipleCellsProvided(self):
3
Source: rnn_test.py
with Apache License 2.0
from tensorflow
def testCustomCellFnProvided(self, mock_rnn_layer_type):
  """Checks the value returned by `rnn_cell_fn` becomes the RNN cell."""
  cell_fn = tf.compat.v1.test.mock.Mock(return_value='custom-cell')
  _make_rnn_layer(rnn_cell_fn=cell_fn, return_sequences='return-seq-value')
  expected_kwargs = dict(
      cell='custom-cell', return_sequences='return-seq-value')
  mock_rnn_layer_type.assert_called_once_with(**expected_kwargs)
def _mock_logits_layer(kernel, bias):
3
Source: rnn_test.py
with Apache License 2.0
from tensorflow
def _get_mock_head():
  """Returns a 3-class MultiClassHead with a stubbed create_estimator_spec.

  The stub returns `model_fn.EstimatorSpec(None)` on every call.
  """
  head = multi_head_lib.MultiClassHead(3)
  canned_spec = model_fn.EstimatorSpec(None)
  head.create_estimator_spec = tf.compat.v1.test.mock.Mock(
      return_value=canned_spec)
  return head
@test_util.run_all_in_graph_and_eager_modes
3
Source: training_test.py
with Apache License 2.0
from tensorflow
def test_run_task(self):
  """Checks train_and_evaluate builds a _TrainingExecutor and runs it."""
  tf_mock = tf.compat.v1.test.mock
  est = tf_mock.Mock(spec=estimator_lib.Estimator)
  train_spec = tf_mock.Mock(spec=training.TrainSpec)
  eval_spec = tf_mock.Mock(spec=training.EvalSpec)
  with tf_mock.patch.object(training, '_TrainingExecutor') as executor_cls:
    executor_instance = tf_mock.Mock()
    executor_cls.return_value = executor_instance
    training.train_and_evaluate(est, train_spec, eval_spec)
    executor_cls.assert_called_with(
        estimator=est, train_spec=train_spec, eval_spec=eval_spec)
    self.assertTrue(executor_instance.run.called)
def test_error_out_if_evaluator_task_id_is_non_zero(self):
3
Source: training_test.py
with Apache License 2.0
from tensorflow
def test_invalid_estimator(self):
  """Checks a non-Estimator object is rejected with a TypeError."""
  invalid_estimator = object()
  mock_train_spec = tf.compat.v1.test.mock.Mock(spec=training.TrainSpec)
  mock_eval_spec = tf.compat.v1.test.mock.Mock(spec=training.EvalSpec)
  # assertRaisesRegex: the assertRaisesRegexp alias was removed in Python 3.12.
  with self.assertRaisesRegex(TypeError, _INVALID_ESTIMATOR_MSG):
    training.train_and_evaluate(invalid_estimator, mock_train_spec,
                                mock_eval_spec)
def test_fail_fast_if_invalid_eval_spec(self):
3
Source: training_test.py
with Apache License 2.0
from tensorflow
def test_fail_fast_if_invalid_eval_spec(self):
  """Checks an invalid eval_spec raises before any executor is created."""
  mock_est = tf.compat.v1.test.mock.Mock(spec=estimator_lib.Estimator)
  mock_train_spec = tf.compat.v1.test.mock.Mock(spec=training.TrainSpec)
  invalid_eval_spec = object()
  with tf.compat.v1.test.mock.patch.object(
      training, '_TrainingExecutor') as mock_executor:
    # assertRaisesRegex: the assertRaisesRegexp alias was removed in
    # Python 3.12.
    with self.assertRaisesRegex(TypeError, _INVALID_EVAL_SPEC_MSG):
      training.train_and_evaluate(mock_est, mock_train_spec,
                                  invalid_eval_spec)
    # Validation must fail before _TrainingExecutor is constructed.
    mock_executor.assert_not_called()
class TrainingExecutorConstructorTest(tf.test.TestCase):
3
Source: training_test.py
with Apache License 2.0
from tensorflow
def test_train_with_train_hooks(self, unused_mock_server, unused_mock_sleep):
  """Checks extra train_hooks are appended after the TrainSpec hooks."""
  tf_mock = tf.compat.v1.test.mock
  est = tf_mock.Mock(spec=estimator_lib.Estimator)
  est.config = self._run_config
  train_spec = training.TrainSpec(
      input_fn=lambda: 1, max_steps=2, hooks=[_FakeHook()])
  eval_spec = tf_mock.Mock(spec=training.EvalSpec)
  extra_hooks = [_FakeHook()]
  executor = training._TrainingExecutor(
      est, train_spec, eval_spec, train_hooks=extra_hooks)
  self._run_task(executor)
  expected_hooks = list(train_spec.hooks) + extra_hooks
  est.train.assert_called_with(
      input_fn=train_spec.input_fn,
      max_steps=train_spec.max_steps,
      hooks=expected_hooks,
      saving_listeners=tf_mock.ANY)
@tf.compat.v1.test.mock.patch.object(time, 'sleep')
3
Source: training_test.py
with Apache License 2.0
from tensorflow
def test_no_server_startup_in_google(self, mock_server, unused_mock_sleep):
  """Checks no server is created when TF_CONFIG is _TF_CONFIG_FOR_GOOGLE."""
  tf_mock = tf.compat.v1.test.mock
  est = tf_mock.Mock(spec=estimator_lib.Estimator)
  est.config = self._run_config
  train_spec = tf_mock.Mock(spec=training.TrainSpec, hooks=[])
  eval_spec = tf_mock.Mock(spec=training.EvalSpec)
  executor = training._TrainingExecutor(est, train_spec, eval_spec)
  env_overrides = {'TF_CONFIG': json.dumps(_TF_CONFIG_FOR_GOOGLE)}
  with tf_mock.patch.dict('os.environ', env_overrides):
    self._run_task(executor)
  mock_server.assert_not_called()
def test_fail_with_empty_cluster_spec(self):
3
Source: training_test.py
with Apache License 2.0
from tensorflow
def test_fail_with_empty_cluster_spec(self):
  """Checks a worker config with cluster_spec=None raises RuntimeError."""
  mock_est = tf.compat.v1.test.mock.Mock(spec=estimator_lib.Estimator)
  mock_train_spec = tf.compat.v1.test.mock.Mock(spec=training.TrainSpec)
  mock_eval_spec = tf.compat.v1.test.mock.Mock(spec=training.EvalSpec)
  # RunConfig-spec'd mock standing in for the estimator config.
  mock_est.config = tf.compat.v1.test.mock.PropertyMock(
      spec=run_config_lib.RunConfig)
  mock_est.config.cluster_spec = None
  mock_est.config.master = 'grpc://...'
  mock_est.config.task_type = 'worker'
  mock_est.config.task_id = 2
  # assertRaisesRegex: the assertRaisesRegexp alias was removed in Python 3.12.
  with self.assertRaisesRegex(RuntimeError,
                              _INVALID_CONFIG_FOR_STD_SERVER_MSG):
    self._run_task(
        training._TrainingExecutor(mock_est, mock_train_spec, mock_eval_spec))
def test_fail_with_empty_master(self):
3
Source: training_test.py
with Apache License 2.0
from tensorflow
def test_fail_with_empty_master(self):
  """Checks a multi-worker config with an empty master raises RuntimeError."""
  mock_est = tf.compat.v1.test.mock.Mock(spec=estimator_lib.Estimator)
  mock_train_spec = tf.compat.v1.test.mock.Mock(spec=training.TrainSpec)
  mock_eval_spec = tf.compat.v1.test.mock.Mock(spec=training.EvalSpec)
  # RunConfig-spec'd mock standing in for the estimator config.
  mock_est.config = tf.compat.v1.test.mock.PropertyMock(
      spec=run_config_lib.RunConfig)
  mock_est.config.cluster_spec = tf.train.ClusterSpec(
      {'worker': ['dummy', 'dummy1']})
  mock_est.config.master = ''
  mock_est.config.task_type = 'worker'
  mock_est.config.task_id = 2
  # assertRaisesRegex: the assertRaisesRegexp alias was removed in Python 3.12.
  with self.assertRaisesRegex(RuntimeError,
                              _INVALID_CONFIG_FOR_STD_SERVER_MSG):
    self._run_task(
        training._TrainingExecutor(mock_est, mock_train_spec, mock_eval_spec))
@tf.compat.v1.test.mock.patch.object(time, 'sleep')
3
Source: training_test.py
with Apache License 2.0
from tensorflow
def test_fail_with_empty_task_type(self):
  """Checks a worker cluster config with an empty task_type raises."""
  mock_est = tf.compat.v1.test.mock.Mock(spec=estimator_lib.Estimator)
  mock_train_spec = tf.compat.v1.test.mock.Mock(spec=training.TrainSpec)
  mock_eval_spec = tf.compat.v1.test.mock.Mock(spec=training.EvalSpec)
  # RunConfig-spec'd mock standing in for the estimator config.
  mock_est.config = tf.compat.v1.test.mock.PropertyMock(
      spec=run_config_lib.RunConfig)
  mock_est.config.cluster_spec = tf.train.ClusterSpec({'worker': ['dummy']})
  mock_est.config.master = 'grpc://...'
  mock_est.config.task_type = ''
  mock_est.config.task_id = 2
  # assertRaisesRegex: the assertRaisesRegexp alias was removed in Python 3.12.
  with self.assertRaisesRegex(RuntimeError,
                              _INVALID_CONFIG_FOR_STD_SERVER_MSG):
    self._run_task(
        training._TrainingExecutor(mock_est, mock_train_spec, mock_eval_spec))
def test_fail_with_none_task_id(self):
3
Source: training_test.py
with Apache License 2.0
from tensorflow
def test_fail_with_none_task_id(self):
  """Checks a worker cluster config with task_id=None raises RuntimeError."""
  mock_est = tf.compat.v1.test.mock.Mock(spec=estimator_lib.Estimator)
  mock_train_spec = tf.compat.v1.test.mock.Mock(spec=training.TrainSpec)
  mock_eval_spec = tf.compat.v1.test.mock.Mock(spec=training.EvalSpec)
  # RunConfig-spec'd mock standing in for the estimator config.
  mock_est.config = tf.compat.v1.test.mock.PropertyMock(
      spec=run_config_lib.RunConfig)
  mock_est.config.cluster_spec = tf.train.ClusterSpec({'worker': ['dummy']})
  mock_est.config.master = 'grpc://...'
  mock_est.config.task_type = 'worker'
  mock_est.config.task_id = None
  # assertRaisesRegex: the assertRaisesRegexp alias was removed in Python 3.12.
  with self.assertRaisesRegex(RuntimeError,
                              _INVALID_CONFIG_FOR_STD_SERVER_MSG):
    self._run_task(
        training._TrainingExecutor(mock_est, mock_train_spec, mock_eval_spec))
class TrainingExecutorRunWorkerTest(_TrainingExecutorTrainingTest,
3
Source: training_test.py
with Apache License 2.0
from tensorflow
def test_delay_for_worker(self, _):
  """Checks each time.sleep call uses (task_id + 1) * _DELAY_SECS_PER_WORKER."""
  tf_mock = tf.compat.v1.test.mock
  est = tf_mock.Mock(spec=estimator_lib.Estimator)
  est.config = self._run_config
  train_spec = tf_mock.Mock(spec=training.TrainSpec, hooks=[])
  eval_spec = tf_mock.Mock(spec=training.EvalSpec)
  executor = training._TrainingExecutor(est, train_spec, eval_spec)
  delay_secs = (self._run_config.task_id + 1) * _DELAY_SECS_PER_WORKER

  def check_sleep(secs):
    # Fails the test on any sleep duration other than the expected delay.
    self.assertEqual(delay_secs, secs)

  with tf_mock.patch.object(time, 'sleep') as patched_sleep:
    patched_sleep.side_effect = check_sleep
    self._run_task(executor)
    self.assertTrue(patched_sleep.called)
@tf.compat.v1.test.mock.patch.object(server_lib, 'Server')
3
Source: training_test.py
with Apache License 2.0
from tensorflow
def test_delay_disabled_for_worker(self, _):
  """Checks no sleep occurs when experimental_max_worker_delay_secs is 0."""
  tf_mock = tf.compat.v1.test.mock
  est = tf_mock.Mock(spec=estimator_lib.Estimator)
  est.config = self._run_config.replace(experimental_max_worker_delay_secs=0)
  train_spec = tf_mock.Mock(spec=training.TrainSpec, hooks=[])
  eval_spec = tf_mock.Mock(spec=training.EvalSpec)
  executor = training._TrainingExecutor(est, train_spec, eval_spec)
  with tf_mock.patch.object(time, 'sleep') as patched_sleep:
    self._run_task(executor)
  self.assertFalse(patched_sleep.called)
class TrainingExecutorRunChiefTest(_TrainingExecutorTrainingTest,
3
Source: training_test.py
with Apache License 2.0
from tensorflow
def test_no_delay_for_chief(self, _):
  """Checks running the task never calls time.sleep."""
  tf_mock = tf.compat.v1.test.mock
  est = tf_mock.Mock(spec=estimator_lib.Estimator)
  est.config = self._run_config
  train_spec = tf_mock.Mock(spec=training.TrainSpec, hooks=[])
  eval_spec = tf_mock.Mock(spec=training.EvalSpec)
  executor = training._TrainingExecutor(est, train_spec, eval_spec)
  with tf_mock.patch.object(time, 'sleep') as patched_sleep:
    self._run_task(executor)
  patched_sleep.assert_not_called()
class TrainingExecutorRunMasterTest(tf.test.TestCase):
3
Source: training_test.py
with Apache License 2.0
from tensorflow
def test_train_with_no_eval_spec_fails(self, mock_server, unused_mock_sleep):
  """Checks run_master raises TypeError when eval_spec is None."""
  mock_est = tf.compat.v1.test.mock.Mock(spec=estimator_lib.Estimator)
  # Stub evaluate to return a dict carrying a global step.
  mock_est.evaluate = lambda *args, **kw: {
      tf.compat.v1.GraphKeys.GLOBAL_STEP: 123
  }
  mock_est.config = self._run_config
  train_spec = training.TrainSpec(
      input_fn=lambda: 1, max_steps=2, hooks=[_FakeHook()])
  eval_spec = None
  executor = training._TrainingExecutor(mock_est, train_spec, eval_spec)
  # assertRaisesRegex: the assertRaisesRegexp alias was removed in Python 3.12.
  with self.assertRaisesRegex(TypeError, _INVALID_EVAL_SPEC_MSG):
    executor.run_master()
@tf.compat.v1.test.mock.patch.object(time, 'sleep')
3
Source: training_test.py
with Apache License 2.0
from tensorflow
def test_evaluate_with_no_eval_spec_fails(self):
  """Checks run_evaluator raises TypeError when eval_spec is None."""
  mock_est = tf.compat.v1.test.mock.Mock(spec=estimator_lib.Estimator)
  mock_est.latest_checkpoint.return_value = 'latest_it_is'
  mock_train_spec = tf.compat.v1.test.mock.Mock(spec=training.TrainSpec)
  self._set_up_mock_est_to_train_and_evaluate_once(mock_est, mock_train_spec)
  eval_spec = None
  executor = training._TrainingExecutor(mock_est, mock_train_spec, eval_spec)
  # assertRaisesRegex: the assertRaisesRegexp alias was removed in Python 3.12.
  with self.assertRaisesRegex(TypeError, _INVALID_EVAL_SPEC_MSG):
    executor.run_evaluator()
def test_evaluate_with_train_hooks(self):
3
Source: training_test.py
with Apache License 2.0
from tensorflow
def test_errors_out_if_evaluate_returns_empty_dict(self):
  """Checks run_evaluator raises ValueError when evaluate returns {}."""
  mock_est = tf.compat.v1.test.mock.Mock(spec=estimator_lib.Estimator)
  train_spec = training.TrainSpec(input_fn=lambda: 1)
  eval_spec = training.EvalSpec(
      input_fn=(lambda: 1), start_delay_secs=0, throttle_secs=0)
  mock_est.evaluate.return_value = {}
  executor = training._TrainingExecutor(mock_est, train_spec, eval_spec)
  # assertRaisesRegex: the assertRaisesRegexp alias was removed in Python 3.12.
  with self.assertRaisesRegex(ValueError, _INVALID_EMPTY_EVAL_RESULT_ERR):
    executor.run_evaluator()
def test_errors_out_if_evaluate_returns_non_dict(self):
3
Source: training_test.py
with Apache License 2.0
from tensorflow
def test_errors_out_if_evaluate_returns_non_dict(self):
  """Checks run_evaluator raises TypeError when evaluate returns a non-dict."""
  mock_est = tf.compat.v1.test.mock.Mock(spec=estimator_lib.Estimator)
  train_spec = training.TrainSpec(input_fn=lambda: 1)
  eval_spec = training.EvalSpec(
      input_fn=(lambda: 1), start_delay_secs=0, throttle_secs=0)
  mock_est.evaluate.return_value = 123
  executor = training._TrainingExecutor(mock_est, train_spec, eval_spec)
  # assertRaisesRegex: the assertRaisesRegexp alias was removed in Python 3.12.
  with self.assertRaisesRegex(TypeError, _INVALID_EVAL_RESULT_TYPE_ERR):
    executor.run_evaluator()
def test_errors_out_if_evaluate_returns_dict_without_global_step(self):
3
Source: training_test.py
with Apache License 2.0
from tensorflow
def test_errors_out_if_evaluate_returns_dict_without_global_step(self):
  """Checks run_evaluator raises ValueError when the result lacks a global step."""
  mock_est = tf.compat.v1.test.mock.Mock(spec=estimator_lib.Estimator)
  train_spec = training.TrainSpec(input_fn=lambda: 1)
  eval_spec = training.EvalSpec(
      input_fn=(lambda: 1), start_delay_secs=0, throttle_secs=0)
  mock_est.evaluate.return_value = {'loss': 123}
  executor = training._TrainingExecutor(mock_est, train_spec, eval_spec)
  # assertRaisesRegex: the assertRaisesRegexp alias was removed in Python 3.12.
  with self.assertRaisesRegex(ValueError,
                              _MISSING_GLOBAL_STEP_IN_EVAL_RESULT_ERR):
    executor.run_evaluator()
class TrainingExecutorRunPsTest(tf.test.TestCase):
3
Source: training_test.py
with Apache License 2.0
from tensorflow
def test_fail_with_empty_task_type(self):
  """Checks run_ps raises RuntimeError when task_type is empty."""
  mock_est = tf.compat.v1.test.mock.Mock(spec=estimator_lib.Estimator)
  mock_train_spec = tf.compat.v1.test.mock.Mock(spec=training.TrainSpec)
  mock_eval_spec = tf.compat.v1.test.mock.Mock(spec=training.EvalSpec)
  # RunConfig-spec'd mock standing in for the estimator config.
  mock_est.config = tf.compat.v1.test.mock.PropertyMock(
      spec=run_config_lib.RunConfig)
  mock_est.config.cluster_spec = tf.train.ClusterSpec({'ps': ['dummy']})
  mock_est.config.master = 'grpc://...'
  mock_est.config.task_type = ''
  mock_est.config.task_id = 2
  # assertRaisesRegex: the assertRaisesRegexp alias was removed in Python 3.12.
  with self.assertRaisesRegex(RuntimeError,
                              _INVALID_CONFIG_FOR_STD_SERVER_MSG):
    training._TrainingExecutor(mock_est, mock_train_spec,
                               mock_eval_spec).run_ps()
def test_fail_with_none_task_id(self):
3
Source: training_test.py
with Apache License 2.0
from tensorflow
def test_fail_with_none_task_id(self):
  """Checks run_ps raises RuntimeError when task_id is None."""
  mock_est = tf.compat.v1.test.mock.Mock(spec=estimator_lib.Estimator)
  mock_train_spec = tf.compat.v1.test.mock.Mock(spec=training.TrainSpec)
  mock_eval_spec = tf.compat.v1.test.mock.Mock(spec=training.EvalSpec)
  # RunConfig-spec'd mock standing in for the estimator config.
  mock_est.config = tf.compat.v1.test.mock.PropertyMock(
      spec=run_config_lib.RunConfig)
  mock_est.config.cluster_spec = tf.train.ClusterSpec({'ps': ['dummy']})
  mock_est.config.master = 'grpc://...'
  mock_est.config.task_type = 'ps'
  mock_est.config.task_id = None
  # assertRaisesRegex: the assertRaisesRegexp alias was removed in Python 3.12.
  with self.assertRaisesRegex(RuntimeError,
                              _INVALID_CONFIG_FOR_STD_SERVER_MSG):
    training._TrainingExecutor(mock_est, mock_train_spec,
                               mock_eval_spec).run_ps()
class StopAtSecsHookTest(tf.test.TestCase):
3
Source: training_test.py
with Apache License 2.0
from tensorflow
def test_train_with_no_eval_spec_fails(self):
  """Checks run_local raises TypeError when eval_spec is None."""
  mock_est = tf.compat.v1.test.mock.Mock(spec=estimator_lib.Estimator)
  train_spec = training.TrainSpec(
      input_fn=lambda: 1, max_steps=300, hooks=[_FakeHook()])
  eval_spec = None
  executor = training._TrainingExecutor(mock_est, train_spec, eval_spec)
  # assertRaisesRegex: the assertRaisesRegexp alias was removed in Python 3.12.
  with self.assertRaisesRegex(TypeError, _INVALID_EVAL_SPEC_MSG):
    executor.run_local()
def test_train_hooks(self):
3
Source: training_test.py
with Apache License 2.0
from tensorflow
def test_errors_out_if_evaluate_returns_empty_dict(self):
  """Checks run_local raises ValueError when evaluate returns {}."""
  est = estimator_lib.Estimator(
      model_fn=self._model_fn,
      config=run_config_lib.RunConfig(save_checkpoints_steps=2))
  mock_est = tf.compat.v1.test.mock.Mock(
      spec=estimator_lib.Estimator, wraps=est)
  train_spec = training.TrainSpec(input_fn=self._input_fn)
  eval_spec = training.EvalSpec(
      input_fn=lambda: self._input_fn(repeat=False), throttle_secs=0)
  # An explicit return_value takes precedence over the wrapped evaluate().
  mock_est.evaluate.return_value = {}
  executor = training._TrainingExecutor(mock_est, train_spec, eval_spec)
  # assertRaisesRegex: the assertRaisesRegexp alias was removed in Python 3.12.
  with self.assertRaisesRegex(ValueError, _INVALID_EMPTY_EVAL_RESULT_ERR):
    executor.run_local()
def test_errors_out_if_evaluate_returns_non_dict(self):
3
Source: training_test.py
with Apache License 2.0
from tensorflow
def test_errors_out_if_evaluate_returns_non_dict(self):
  """Checks run_local raises TypeError when evaluate returns a non-dict."""
  est = estimator_lib.Estimator(
      model_fn=self._model_fn,
      config=run_config_lib.RunConfig(save_checkpoints_steps=2))
  mock_est = tf.compat.v1.test.mock.Mock(
      spec=estimator_lib.Estimator, wraps=est)
  train_spec = training.TrainSpec(input_fn=self._input_fn)
  eval_spec = training.EvalSpec(
      input_fn=lambda: self._input_fn(repeat=False), throttle_secs=0)
  # An explicit return_value takes precedence over the wrapped evaluate().
  mock_est.evaluate.return_value = 123
  executor = training._TrainingExecutor(mock_est, train_spec, eval_spec)
  # assertRaisesRegex: the assertRaisesRegexp alias was removed in Python 3.12.
  with self.assertRaisesRegex(TypeError, _INVALID_EVAL_RESULT_TYPE_ERR):
    executor.run_local()
def test_errors_out_if_evaluate_returns_dict_without_global_step(self):
3
Source: training_test.py
with Apache License 2.0
from tensorflow
def test_errors_out_if_evaluate_returns_dict_without_global_step(self):
  """Checks run_local raises ValueError when the result lacks a global step."""
  est = estimator_lib.Estimator(
      model_fn=self._model_fn,
      config=run_config_lib.RunConfig(save_checkpoints_steps=2))
  mock_est = tf.compat.v1.test.mock.Mock(
      spec=estimator_lib.Estimator, wraps=est)
  train_spec = training.TrainSpec(input_fn=self._input_fn)
  eval_spec = training.EvalSpec(
      input_fn=lambda: self._input_fn(repeat=False), throttle_secs=0)
  # An explicit return_value takes precedence over the wrapped evaluate().
  mock_est.evaluate.return_value = {'loss': 123}
  executor = training._TrainingExecutor(mock_est, train_spec, eval_spec)
  # assertRaisesRegex: the assertRaisesRegexp alias was removed in Python 3.12.
  with self.assertRaisesRegex(ValueError,
                              _MISSING_GLOBAL_STEP_IN_EVAL_RESULT_ERR):
    executor.run_local()
def test_train_and_evaluate_return_metrics(self):
3
Source: training_test.py
with Apache License 2.0
from tensorflow
def test_invalid_task_type(self):
  """Checks run() raises ValueError for an unrecognized (empty) task_type."""
  mock_est = tf.compat.v1.test.mock.Mock(spec=estimator_lib.Estimator)
  mock_train_spec = tf.compat.v1.test.mock.Mock(spec=training.TrainSpec)
  mock_eval_spec = tf.compat.v1.test.mock.Mock(spec=training.EvalSpec)
  # Unspec'd mock config; only cluster_spec and task_type are set explicitly.
  mock_est.config = tf.compat.v1.test.mock.Mock()
  mock_est.config.cluster_spec = tf.train.ClusterSpec({'1': ['dummy']})
  mock_est.config.task_type = ''
  executor = training._TrainingExecutor(mock_est, mock_train_spec,
                                        mock_eval_spec)
  # assertRaisesRegex: the assertRaisesRegexp alias was removed in Python 3.12.
  with self.assertRaisesRegex(ValueError, _INVALID_TASK_TYPE):
    executor.run()
class TrainAndEvaluateIntegrationTest(tf.test.TestCase):
3
Source: driver_test.py
with Apache License 2.0
from tensorflow
def setUp(self):
  """Creates a per-test output dir path and a driver over mocked metadata."""
  super().setUp()
  # get_temp_dir() is evaluated unconditionally, as the fallback value.
  fallback_dir = self.get_temp_dir()
  outputs_root = os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', fallback_dir)
  self._test_dir = os.path.join(outputs_root, self._testMethodName)
  # Metadata is a bare mock; the driver under test wraps it.
  self._mock_metadata = tf.compat.v1.test.mock.Mock()
  self._file_based_driver = driver.FileBasedDriver(self._mock_metadata)
def testResolveExecProperties(self):
3
Source: publisher_test.py
with Apache License 2.0
from tensorflow
def setUp(self):
  """Builds mocked metadata plus the pipeline/component fixtures for tests."""
  super().setUp()
  tf_mock = tf.compat.v1.test.mock
  self._mock_metadata = tf_mock.Mock()
  self._mock_metadata.publish_execution = tf_mock.Mock()
  self._output_dict = {'output_data': [_OutputType()]}
  self._exec_properties = {'k': 'v'}
  pipeline_info = data_types.PipelineInfo(
      pipeline_name='my_pipeline', pipeline_root='/tmp', run_id='my_run_id')
  self._pipeline_info = pipeline_info
  self._component_info = data_types.ComponentInfo(
      component_type='a.b.c',
      component_id='my_component',
      pipeline_info=pipeline_info)
def testPrepareExecutionComplete(self):
0
Source: training_test.py
with Apache License 2.0
from tensorflow
def test_error_out_if_evaluator_task_id_is_non_zero(self):
  """Checks an evaluator task with index 1 makes train_and_evaluate raise."""
  tf_config = {
      'cluster': {
          run_config_lib.TaskType.CHIEF: ['host0:0'],
      },
      'task': {
          'type': run_config_lib.TaskType.EVALUATOR,
          'index': 1
      }
  }
  mock_est = tf.compat.v1.test.mock.Mock(spec=estimator_lib.Estimator)
  mock_est.config = _create_run_config_with_cluster_spec(tf_config)
  mock_train_spec = tf.compat.v1.test.mock.Mock(spec=training.TrainSpec)
  mock_eval_spec = tf.compat.v1.test.mock.Mock(spec=training.EvalSpec)
  # assertRaisesRegex: the assertRaisesRegexp alias was removed in Python 3.12.
  with self.assertRaisesRegex(ValueError, _INVALID_EVAL_TASK_ID_ERR):
    training.train_and_evaluate(mock_est, mock_train_spec, mock_eval_spec)
def test_invalid_estimator(self):
0
Source: training_test.py
with Apache License 2.0
from tensorflow
def test_train_with_train_spec(self, mock_server, unused_mock_sleep):
  """Checks the task starts a server and trains, with no eval or export."""
  tf_mock = tf.compat.v1.test.mock
  est = tf_mock.Mock(spec=estimator_lib.Estimator)
  est.config = self._run_config
  train_spec = training.TrainSpec(
      input_fn=lambda: 1, max_steps=2, hooks=[_FakeHook()])
  eval_spec = tf_mock.Mock(spec=training.EvalSpec)
  server_instance = mock_server.return_value
  self._run_task(training._TrainingExecutor(est, train_spec, eval_spec))

  cfg = est.config
  mock_server.assert_called_with(
      cfg.cluster_spec,
      job_name=cfg.task_type,
      task_index=cfg.task_id,
      config=tf_mock.ANY,
      protocol=None,
      start=False)
  # Built with start=False, so it must be started explicitly.
  self.assertTrue(server_instance.start.called)
  est.train.assert_called_with(
      input_fn=train_spec.input_fn,
      max_steps=train_spec.max_steps,
      hooks=list(train_spec.hooks),
      saving_listeners=tf_mock.ANY)
  est.evaluate.assert_not_called()
  est.export_saved_model.assert_not_called()
@tf.compat.v1.test.mock.patch.object(time, 'sleep')
0
Source: training_test.py
with Apache License 2.0
from tensorflow
def test_train_with_no_eval_spec(self, mock_server, unused_mock_sleep):
  """Checks training works with eval_spec=None and skips eval/export."""
  tf_mock = tf.compat.v1.test.mock
  est = tf_mock.Mock(spec=estimator_lib.Estimator)
  est.config = self._run_config
  train_spec = training.TrainSpec(
      input_fn=lambda: 1, max_steps=2, hooks=[_FakeHook()])
  server_instance = mock_server.return_value
  self._run_task(training._TrainingExecutor(est, train_spec, None))

  cfg = est.config
  mock_server.assert_called_with(
      cfg.cluster_spec,
      job_name=cfg.task_type,
      task_index=cfg.task_id,
      config=tf_mock.ANY,
      protocol=None,
      start=False)
  # Built with start=False, so it must be started explicitly.
  self.assertTrue(server_instance.start.called)
  est.train.assert_called_with(
      input_fn=train_spec.input_fn,
      max_steps=train_spec.max_steps,
      hooks=list(train_spec.hooks),
      saving_listeners=tf_mock.ANY)
  est.evaluate.assert_not_called()
  est.export_saved_model.assert_not_called()
@tf.compat.v1.test.mock.patch.object(time, 'sleep')
0
Source: training_test.py
with Apache License 2.0
from tensorflow
def test_single_worker_node_with_empty_tf_master(self, mock_server,
                                                 unused_mock_sleep):
  """Checks a one-worker cluster with empty master trains without a server."""
  tf_mock = tf.compat.v1.test.mock
  est = tf_mock.Mock(spec=estimator_lib.Estimator)
  train_spec = tf_mock.Mock(spec=training.TrainSpec, hooks=[])
  eval_spec = tf_mock.Mock(spec=training.EvalSpec)
  est.config = tf_mock.PropertyMock(spec=run_config_lib.RunConfig)
  cfg = est.config
  # Single node cluster.
  cfg.cluster_spec = tf.train.ClusterSpec({'worker': ['dummy']})
  cfg.master = ''
  cfg.task_type = 'worker'
  cfg.task_id = 2
  self._run_task(training._TrainingExecutor(est, train_spec, eval_spec))
  self.assertTrue(est.train.called)
  mock_server.assert_not_called()
def test_fail_with_empty_task_type(self):
0
Source: training_test.py
with Apache License 2.0
from tensorflow
def test_no_delay_for_master(self, _):
  """Checks run_master never calls time.sleep."""
  tf_mock = tf.compat.v1.test.mock
  est = tf_mock.Mock(spec=estimator_lib.Estimator)
  est.evaluate = lambda *args, **kw: {tf.compat.v1.GraphKeys.GLOBAL_STEP: 123}
  est.config = self._run_config
  train_spec = tf_mock.Mock(spec=training.TrainSpec, max_steps=123, hooks=[])
  eval_spec = tf_mock.Mock(spec=training.EvalSpec, exporters=[])
  train_spec.saving_listeners = ()
  executor = training._TrainingExecutor(est, train_spec, eval_spec)
  with tf_mock.patch.object(time, 'sleep') as patched_sleep:
    executor.run_master()
  patched_sleep.assert_not_called()
@tf.compat.v1.test.mock.patch.object(time, 'sleep')
0
Source: training_test.py
with Apache License 2.0
from tensorflow
def test_train_with_train_spec(self, mock_server, unused_mock_sleep):
  """Checks run_master starts a server and trains without exporting."""
  tf_mock = tf.compat.v1.test.mock
  est = tf_mock.Mock(spec=estimator_lib.Estimator)
  est.evaluate = lambda *args, **kw: {tf.compat.v1.GraphKeys.GLOBAL_STEP: 123}
  est.config = self._run_config
  train_spec = training.TrainSpec(
      input_fn=lambda: 1, max_steps=2, hooks=[_FakeHook()])
  eval_spec = tf_mock.Mock(spec=training.EvalSpec, exporters=[])
  server_instance = mock_server.return_value
  training._TrainingExecutor(est, train_spec, eval_spec).run_master()

  cfg = est.config
  mock_server.assert_called_with(
      cfg.cluster_spec,
      job_name=cfg.task_type,
      task_index=cfg.task_id,
      config=tf_mock.ANY,
      protocol=None,
      start=False)
  # Built with start=False, so it must be started explicitly.
  self.assertTrue(server_instance.start.called)
  est.train.assert_called_with(
      input_fn=train_spec.input_fn,
      max_steps=train_spec.max_steps,
      hooks=list(train_spec.hooks),
      saving_listeners=tf_mock.ANY)
  est.export_saved_model.assert_not_called()
@tf.compat.v1.test.mock.patch.object(time, 'sleep')
0
Source: training_test.py
with Apache License 2.0
from tensorflow
def test_train_with_train_hooks(self, mock_server, unused_mock_sleep):
  """Checks run_master appends extra train_hooks after the TrainSpec hooks."""
  tf_mock = tf.compat.v1.test.mock
  est = tf_mock.Mock(spec=estimator_lib.Estimator)
  est.evaluate = lambda *args, **kw: {tf.compat.v1.GraphKeys.GLOBAL_STEP: 123}
  est.config = self._run_config
  train_spec = training.TrainSpec(
      input_fn=lambda: 1, max_steps=2, hooks=[_FakeHook()])
  eval_spec = tf_mock.Mock(spec=training.EvalSpec, exporters=[])
  extra_hooks = [_FakeHook()]
  executor = training._TrainingExecutor(
      est, train_spec, eval_spec, train_hooks=extra_hooks)
  executor.run_master()
  expected_hooks = list(train_spec.hooks) + extra_hooks
  est.train.assert_called_with(
      input_fn=train_spec.input_fn,
      max_steps=train_spec.max_steps,
      hooks=expected_hooks,
      saving_listeners=tf_mock.ANY)
@tf.compat.v1.test.mock.patch.object(time, 'sleep')
0
Source: training_test.py
with Apache License 2.0
from tensorflow
def test_no_server_startup_in_google(self, mock_server, unused_mock_sleep):
  """Checks run_master creates no server when TF_CONFIG is the Google one."""
  tf_mock = tf.compat.v1.test.mock
  est = tf_mock.Mock(spec=estimator_lib.Estimator)
  est.evaluate = lambda *args, **kw: {tf.compat.v1.GraphKeys.GLOBAL_STEP: 123}
  est.config = self._run_config
  train_spec = tf_mock.Mock(spec=training.TrainSpec, max_steps=123, hooks=[])
  eval_spec = tf_mock.Mock(spec=training.EvalSpec, exporters=[])
  train_spec.saving_listeners = ()
  executor = training._TrainingExecutor(est, train_spec, eval_spec)
  env_overrides = {'TF_CONFIG': json.dumps(_TF_CONFIG_FOR_GOOGLE)}
  with tf_mock.patch.dict('os.environ', env_overrides):
    executor.run_master()
  mock_server.assert_not_called()
def test_fail_with_empty_cluster_spec(self):
0
Source: training_test.py
with Apache License 2.0
from tensorflow
def test_fail_with_empty_cluster_spec(self):
  """Checks run_master raises RuntimeError when cluster_spec is None."""
  mock_est = tf.compat.v1.test.mock.Mock(spec=estimator_lib.Estimator)
  mock_train_spec = tf.compat.v1.test.mock.Mock(spec=training.TrainSpec)
  mock_eval_spec = tf.compat.v1.test.mock.Mock(spec=training.EvalSpec)
  # RunConfig-spec'd mock standing in for the estimator config.
  mock_est.config = tf.compat.v1.test.mock.PropertyMock(
      spec=run_config_lib.RunConfig)
  mock_est.config.cluster_spec = None
  mock_est.config.master = 'grpc://...'
  mock_est.config.task_type = 'master'
  mock_est.config.task_id = 2
  mock_train_spec.saving_listeners = tuple([])
  # assertRaisesRegex: the assertRaisesRegexp alias was removed in Python 3.12.
  with self.assertRaisesRegex(RuntimeError,
                              _INVALID_CONFIG_FOR_STD_SERVER_MSG):
    training._TrainingExecutor(mock_est, mock_train_spec,
                               mock_eval_spec).run_master()
def test_fail_with_empty_master(self):
0
Source: training_test.py
with Apache License 2.0
from tensorflow
def test_fail_with_empty_master(self):
  """Checks run_master raises RuntimeError for a multi-node empty master."""
  mock_est = tf.compat.v1.test.mock.Mock(spec=estimator_lib.Estimator)
  mock_train_spec = tf.compat.v1.test.mock.Mock(spec=training.TrainSpec)
  mock_eval_spec = tf.compat.v1.test.mock.Mock(spec=training.EvalSpec)
  # RunConfig-spec'd mock standing in for the estimator config.
  mock_est.config = tf.compat.v1.test.mock.PropertyMock(
      spec=run_config_lib.RunConfig)
  mock_est.config.cluster_spec = tf.train.ClusterSpec({
      'master': ['dummy'],
      'worker': ['dummy1']
  })
  mock_est.config.master = ''
  mock_est.config.task_type = 'master'
  mock_est.config.task_id = 0
  mock_train_spec.saving_listeners = tuple([])
  # assertRaisesRegex: the assertRaisesRegexp alias was removed in Python 3.12.
  with self.assertRaisesRegex(RuntimeError,
                              _INVALID_CONFIG_FOR_STD_SERVER_MSG):
    training._TrainingExecutor(mock_est, mock_train_spec,
                               mock_eval_spec).run_master()
@tf.compat.v1.test.mock.patch.object(time, 'sleep')
0
Source: training_test.py
with Apache License 2.0
from tensorflow
def test_single_master_node_with_empty_tf_master(self, mock_server,
                                                 unused_mock_sleep):
  """Checks a one-master cluster with empty master trains without a server."""
  tf_mock = tf.compat.v1.test.mock
  est = tf_mock.Mock(spec=estimator_lib.Estimator)
  est.evaluate = lambda *args, **kw: {tf.compat.v1.GraphKeys.GLOBAL_STEP: 123}
  train_spec = tf_mock.Mock(spec=training.TrainSpec, max_steps=123, hooks=[])
  eval_spec = tf_mock.Mock(spec=training.EvalSpec, exporters=[])
  est.config = tf_mock.PropertyMock(spec=run_config_lib.RunConfig)
  cfg = est.config
  cfg.cluster_spec = tf.train.ClusterSpec({'master': ['dummy']})
  cfg.master = ''
  cfg.task_type = 'master'
  cfg.task_id = 0
  train_spec.saving_listeners = ()
  executor = training._TrainingExecutor(est, train_spec, eval_spec)
  executor.run_master()
  mock_server.assert_not_called()
  self.assertTrue(est.train.called)
def test_fail_with_empty_task_type(self):
0
Source: training_test.py
with Apache License 2.0
from tensorflow
def test_fail_with_empty_task_type(self):
  """Checks run_master raises RuntimeError when task_type is empty."""
  mock_est = tf.compat.v1.test.mock.Mock(spec=estimator_lib.Estimator)
  mock_train_spec = tf.compat.v1.test.mock.Mock(spec=training.TrainSpec)
  mock_eval_spec = tf.compat.v1.test.mock.Mock(spec=training.EvalSpec)
  # RunConfig-spec'd mock standing in for the estimator config.
  mock_est.config = tf.compat.v1.test.mock.PropertyMock(
      spec=run_config_lib.RunConfig)
  mock_est.config.cluster_spec = tf.train.ClusterSpec({'master': ['dummy']})
  mock_est.config.master = 'grpc://...'
  mock_est.config.task_type = ''
  mock_est.config.task_id = 2
  mock_train_spec.saving_listeners = tuple([])
  # assertRaisesRegex: the assertRaisesRegexp alias was removed in Python 3.12.
  with self.assertRaisesRegex(RuntimeError,
                              _INVALID_CONFIG_FOR_STD_SERVER_MSG):
    training._TrainingExecutor(mock_est, mock_train_spec,
                               mock_eval_spec).run_master()
def test_fail_with_none_task_id(self):
0
Source : training_test.py
with Apache License 2.0
from tensorflow
with Apache License 2.0
from tensorflow
def test_fail_with_none_task_id(self):
  """`run_master` raises when the run config's `task_id` is None.

  The cluster spec, master address and task type are set but `task_id` is
  None, so the executor is expected to raise a `RuntimeError` whose message
  matches `_INVALID_CONFIG_FOR_STD_SERVER_MSG`.
  """
  mock_est = tf.compat.v1.test.mock.Mock(spec=estimator_lib.Estimator)
  mock_train_spec = tf.compat.v1.test.mock.Mock(spec=training.TrainSpec)
  mock_eval_spec = tf.compat.v1.test.mock.Mock(spec=training.EvalSpec)
  mock_est.config = tf.compat.v1.test.mock.PropertyMock(
      spec=run_config_lib.RunConfig)
  mock_est.config.cluster_spec = tf.train.ClusterSpec({'master': ['dummy']})
  mock_est.config.master = 'grpc://...'
  mock_est.config.task_type = 'master'
  mock_est.config.task_id = None
  mock_train_spec.saving_listeners = ()
  # `assertRaisesRegexp` is a deprecated alias of `assertRaisesRegex` that was
  # removed from `unittest.TestCase` in Python 3.12.
  with self.assertRaisesRegex(RuntimeError,
                              _INVALID_CONFIG_FOR_STD_SERVER_MSG):
    training._TrainingExecutor(mock_est, mock_train_spec,
                               mock_eval_spec).run_master()
@tf.compat.v1.test.mock.patch.object(server_lib, 'Server')
0
Source : training_test.py
with Apache License 2.0
from tensorflow
with Apache License 2.0
from tensorflow
def test_run_master_triggers_evaluate_and_export(self, _):
  """Checkpoint saves during `run_master` drive evaluation and export.

  The fake `train` calls the first saving listener's `begin`, then
  `after_save` at global steps 0 and 10. `evaluate` reports global step ==
  `max_steps`. The test expects `evaluate` to receive the latest checkpoint
  and the exporter to run exactly once, flagged as the final export.
  """
  def fake_train(saving_listeners, *args, **kwargs):
    # A saving listener must be supplied; the real Estimator would call its
    # `after_save` on each checkpoint.
    del args, kwargs
    listener = saving_listeners[0]
    listener.begin()
    for step in (0, 10):
      listener.after_save(session=None, global_step_value=step)

  estimator = tf.compat.v1.test.mock.Mock(
      spec=estimator_lib.Estimator, model_dir='path/', train=fake_train)
  estimator.latest_checkpoint.return_value = 'checkpoint_path/'
  estimator.config = self._run_config

  exporter = tf.compat.v1.test.mock.PropertyMock(spec=exporter_lib.Exporter)
  exporter.name = 'see_whether_export_is_called'

  train_spec = training.TrainSpec(input_fn=lambda: 1, max_steps=300)
  eval_spec = training.EvalSpec(
      input_fn=lambda: 1, steps=2, exporters=exporter)
  final_result = {_GLOBAL_STEP_KEY: train_spec.max_steps}
  estimator.evaluate.return_value = final_result

  training._TrainingExecutor(estimator, train_spec, eval_spec).run_master()

  estimator.evaluate.assert_called_with(
      name=eval_spec.name,
      input_fn=eval_spec.input_fn,
      steps=eval_spec.steps,
      checkpoint_path='checkpoint_path/',
      hooks=eval_spec.hooks)
  # Equivalent to asserting call_count == 1 followed by assert_called_with.
  exporter.export.assert_called_once_with(
      estimator=estimator,
      export_path=os.path.join('path/', 'export', exporter.name),
      checkpoint_path='checkpoint_path/',
      eval_result=final_result,
      is_the_final_export=True)
@tf.compat.v1.test.mock.patch.object(basic_session_run_hooks,
0
Source : training_test.py
with Apache License 2.0
from tensorflow
with Apache License 2.0
from tensorflow
def test_run_master_throttle_eval(self, _, mock_timer_class):
  """Throttling in `run_master` skips saves the timer does not allow.

  Four checkpoints are saved, and the timer is told to fire on three of
  them. `evaluate` yields only two results, the second at `max_steps`. The
  test expects exactly two evaluations and two exports, with only the last
  export flagged as final.
  """
  # NOTE(review): `mock_timer_class` presumably patches a timer class in
  # `basic_session_run_hooks` via a decorator outside this view — confirm.
  timer = tf.compat.v1.test.mock.Mock()
  mock_timer_class.return_value = timer
  # Throttle-timer decision for each of the four checkpoint saves, in order.
  timer_fires = (True, True, False, True)

  def fake_train(saving_listeners, *args, **kwargs):
    del args, kwargs
    listener = saving_listeners[0]
    listener.begin()
    for fires in timer_fires:
      timer.should_trigger_for_step.return_value = fires
      listener.after_save(session=None, global_step_value=None)

  estimator = tf.compat.v1.test.mock.Mock(
      spec=estimator_lib.Estimator, model_dir='path/', train=fake_train)
  estimator.latest_checkpoint.side_effect = ['ckpt1', 'ckpt2']
  estimator.config = self._run_config

  exporter = tf.compat.v1.test.mock.PropertyMock(spec=exporter_lib.Exporter)
  exporter.name = 'see_whether_export_is_called'

  train_spec = training.TrainSpec(input_fn=lambda: 1, max_steps=300)
  eval_spec = training.EvalSpec(
      input_fn=lambda: 1, steps=2, exporters=exporter, throttle_secs=10)
  estimator.evaluate.side_effect = [
      {_GLOBAL_STEP_KEY: train_spec.max_steps // 2},
      {_GLOBAL_STEP_KEY: train_spec.max_steps},
  ]

  training._TrainingExecutor(estimator, train_spec, eval_spec).run_master()

  self.assertEqual(2, estimator.evaluate.call_count)
  self.assertEqual(2, exporter.export.call_count)
  final_flags = [
      call_kwargs['is_the_final_export']
      for unused_call_args, call_kwargs in exporter.export.call_args_list
  ]
  self.assertEqual([False, True], final_flags)
@tf.compat.v1.test.mock.patch.object(basic_session_run_hooks,
0
Source : training_test.py
with Apache License 2.0
from tensorflow
with Apache License 2.0
from tensorflow
def test_run_master_throttle_eval_which_skips_final_ckpt(
    self, _, mock_timer_class):
  """The end of training is evaluated even when throttling would skip it.

  Three saves happen during training, and the timer fires only for the
  first two. The listener's `end` is then called at step 300 with the timer
  still not firing. The test expects two evaluations and two exports, with
  only the last export flagged as final.
  """
  # NOTE(review): `mock_timer_class` presumably patches a timer class in
  # `basic_session_run_hooks` via a decorator outside this view — confirm.
  timer = tf.compat.v1.test.mock.Mock()
  mock_timer_class.return_value = timer
  # (timer fires?, global step) for each of the three in-training saves; the
  # first entry is the initial save.
  saves = ((True, 0), (True, 125), (False, 250))

  def fake_train(saving_listeners, *args, **kwargs):
    del args, kwargs
    listener = saving_listeners[0]
    listener.begin()
    for fires, step in saves:
      timer.should_trigger_for_step.return_value = fires
      listener.after_save(session=None, global_step_value=step)
    # At the end, evaluation should happen even though throttling (the timer
    # not firing) would otherwise prevent it.
    timer.should_trigger_for_step.return_value = False
    listener.end(session=None, global_step_value=300)

  estimator = tf.compat.v1.test.mock.Mock(
      spec=estimator_lib.Estimator, model_dir='path/', train=fake_train)
  estimator.latest_checkpoint.side_effect = ['ckpt1', 'ckpt2']
  estimator.config = self._run_config

  exporter = tf.compat.v1.test.mock.PropertyMock(spec=exporter_lib.Exporter)
  exporter.name = 'see_whether_export_is_called'

  train_spec = training.TrainSpec(input_fn=lambda: 1, max_steps=300)
  eval_spec = training.EvalSpec(
      input_fn=lambda: 1, steps=2, exporters=exporter, throttle_secs=10)
  estimator.evaluate.side_effect = [
      {_GLOBAL_STEP_KEY: train_spec.max_steps // 2},
      {_GLOBAL_STEP_KEY: train_spec.max_steps},
  ]

  training._TrainingExecutor(estimator, train_spec, eval_spec).run_master()

  self.assertEqual(2, estimator.evaluate.call_count)
  self.assertEqual(2, exporter.export.call_count)
  final_flags = [
      call_kwargs['is_the_final_export']
      for unused_call_args, call_kwargs in exporter.export.call_args_list
  ]
  self.assertEqual([False, True], final_flags)
class TrainingExecutorRunEvaluatorTest(tf.test.TestCase):
0
Source : training_test.py
with Apache License 2.0
from tensorflow
with Apache License 2.0
from tensorflow
def test_evaluate_with_evaluate_spec(self):
  """`run_evaluator` evaluates the latest checkpoint using the EvalSpec.

  The test expects `evaluate` to receive the spec's name, input_fn, steps
  and hooks together with the latest checkpoint path, and expects `train`
  never to be called.
  """
  estimator = tf.compat.v1.test.mock.Mock(spec=estimator_lib.Estimator)
  estimator.latest_checkpoint.return_value = 'latest_it_is'
  train_spec = tf.compat.v1.test.mock.Mock(spec=training.TrainSpec)
  self._set_up_mock_est_to_train_and_evaluate_once(estimator, train_spec)

  eval_spec = training.EvalSpec(
      input_fn=lambda: 1,
      steps=2,
      hooks=[_FakeHook()],
      name='cont_eval',
      start_delay_secs=0,
      throttle_secs=0)
  evaluator = training._TrainingExecutor(estimator, train_spec, eval_spec)
  evaluator.run_evaluator()

  estimator.evaluate.assert_called_with(
      name='cont_eval',
      input_fn=eval_spec.input_fn,
      steps=eval_spec.steps,
      checkpoint_path='latest_it_is',
      hooks=eval_spec.hooks)
  self.assertFalse(estimator.train.called)
def test_evaluate_with_no_eval_spec_fails(self):
0
Source : training_test.py
with Apache License 2.0
from tensorflow
with Apache License 2.0
from tensorflow
def test_evaluate_with_train_hooks(self):
  """Hooks passed as `train_hooks` are not run by `run_evaluator`.

  A mock `SessionRunHook` is given to the executor as a training hook. After
  evaluation, the test expects its `begin` never to have been called.
  """
  estimator = tf.compat.v1.test.mock.Mock(spec=estimator_lib.Estimator)
  estimator.latest_checkpoint.return_value = 'latest_it_is'
  train_spec = tf.compat.v1.test.mock.Mock(spec=training.TrainSpec)
  self._set_up_mock_est_to_train_and_evaluate_once(estimator, train_spec)

  eval_spec = training.EvalSpec(
      input_fn=lambda: 1,
      steps=2,
      hooks=[_FakeHook()],
      name='cont_eval',
      start_delay_secs=0,
      throttle_secs=0)
  train_only_hook = tf.compat.v1.test.mock.Mock(
      spec=tf.compat.v1.train.SessionRunHook)

  training._TrainingExecutor(
      estimator, train_spec, eval_spec,
      train_hooks=[train_only_hook]).run_evaluator()

  # A training-only hook must stay untouched during evaluation.
  train_only_hook.begin.assert_not_called()
def test_evaluate_multiple_times(self):
0
Source : training_test.py
with Apache License 2.0
from tensorflow
with Apache License 2.0
from tensorflow
def test_evaluate_multiple_times(self):
  """`run_evaluator` evaluates and exports once per reported result.

  `evaluate` first reports half of `max_steps`, then `max_steps`. The test
  expects two evaluations and two exports, of which exactly one is final
  and no export comes after it.
  """
  max_steps = 200
  fake_estimator = tf.compat.v1.test.mock.Mock(spec=estimator_lib.Estimator)
  fake_estimator.model_dir = tf.compat.as_bytes(
      tf.compat.v1.test.get_temp_dir())
  fake_estimator.evaluate.side_effect = [
      {_GLOBAL_STEP_KEY: max_steps // 2},
      {_GLOBAL_STEP_KEY: max_steps},
  ]
  fake_estimator.latest_checkpoint.side_effect = ['path_1', 'path_2']
  train_spec = tf.compat.v1.test.mock.Mock(spec=training.TrainSpec)
  train_spec.max_steps = max_steps

  # Counters are stored on the estimator mock. The test expects the executor
  # to hand this same mock to `export` as its `estimator` argument.
  fake_estimator.times_export_was_called = 0
  fake_estimator.times_final_export_was_called = 0

  def export(estimator, export_path, checkpoint_path, eval_result,
             is_the_final_export):
    del export_path, checkpoint_path, eval_result
    estimator.times_export_was_called += 1
    # The final export must be the last one: none may follow it.
    self.assertEqual(0, estimator.times_final_export_was_called)
    if is_the_final_export:
      estimator.times_final_export_was_called += 1

  exporter = tf.compat.v1.test.mock.PropertyMock(spec=exporter_lib.Exporter)
  exporter.name = 'see_how_many_times_export_is_called'
  exporter.export = export

  eval_spec = training.EvalSpec(
      input_fn=lambda: 1,
      start_delay_secs=0,
      throttle_secs=0,
      exporters=exporter)
  evaluator = training._TrainingExecutor(
      fake_estimator, train_spec, eval_spec)
  evaluator.run_evaluator()

  self.assertEqual(2, fake_estimator.evaluate.call_count)
  self.assertEqual(2, fake_estimator.times_export_was_called)
  self.assertEqual(1, fake_estimator.times_final_export_was_called)
def test_evaluate_listener_before_eval(self):
0
Source : training_test.py
with Apache License 2.0
from tensorflow
with Apache License 2.0
from tensorflow
def test_evaluate_listener_before_eval(self):
  """A `before_eval` that returns False stops continuous evaluation.

  Without early stopping, `evaluate` would be run twice. The listener
  allows only the first evaluation, so the test expects one `evaluate` call
  and two `before_eval` calls.
  """
  max_steps = 200
  estimator = tf.compat.v1.test.mock.Mock(spec=estimator_lib.Estimator)
  estimator.model_dir = tf.compat.as_bytes(tf.compat.v1.test.get_temp_dir())
  # Two results: without early stopping, both would be consumed.
  estimator.evaluate.side_effect = [
      {_GLOBAL_STEP_KEY: max_steps // 2},
      {_GLOBAL_STEP_KEY: max_steps},
  ]
  estimator.latest_checkpoint.side_effect = ['path_1', 'path_2']
  train_spec = tf.compat.v1.test.mock.Mock(
      spec=training.TrainSpec, hooks=[])
  train_spec.max_steps = max_steps

  class _AllowFirstEvalOnly(training._ContinuousEvalListener):
    """Permits the first evaluation and vetoes every later one."""

    def __init__(self):
      self.call_count = 0

    def before_eval(self):
      self.call_count += 1
      return self.call_count == 1

  listener = _AllowFirstEvalOnly()
  eval_spec = training.EvalSpec(
      input_fn=lambda: 1, start_delay_secs=0, throttle_secs=0)
  evaluator = training._TrainingExecutor(
      estimator, train_spec, eval_spec, continuous_eval_listener=listener)
  evaluator.run_evaluator()

  # The second `before_eval` returned False, so only one evaluation ran.
  self.assertEqual(1, estimator.evaluate.call_count)
  self.assertEqual(2, listener.call_count)
def test_evaluate_listener_after_eval(self):
See More Examples