Source: tensorflow/python/data/ops/iterator_ops.py
@staticmethod
def from_structure(output_types,
                   output_shapes=None,
                   shared_name=None,
                   output_classes=None):
    """Creates a new, uninitialized `Iterator` with the given structure.

    This iterator-constructing method can be used to create an iterator that
    is reusable with many different datasets.

    The returned iterator is not bound to a particular dataset, and it has
    no `initializer`. To initialize the iterator, run the operation returned
    by `Iterator.make_initializer(dataset)`.

    The following is an example:

    ```python
    iterator = Iterator.from_structure(tf.int64, tf.TensorShape([]))

    dataset_range = Dataset.range(10)
    range_initializer = iterator.make_initializer(dataset_range)

    dataset_evens = dataset_range.filter(lambda x: x % 2 == 0)
    evens_initializer = iterator.make_initializer(dataset_evens)

    # Define a model based on the iterator; in this example, the model_fn
    # is expected to take scalar tf.int64 Tensors as input (see
    # the definition of 'iterator' above).
    prediction, loss = model_fn(iterator.get_next())

    # Train for `num_epochs`, where for each epoch, we first iterate over
    # dataset_range, and then iterate over dataset_evens.
    for _ in range(num_epochs):
        # Initialize the iterator to `dataset_range`
        sess.run(range_initializer)
        while True:
            try:
                pred, loss_val = sess.run([prediction, loss])
            except tf.errors.OutOfRangeError:
                break

        # Initialize the iterator to `dataset_evens`
        sess.run(evens_initializer)
        while True:
            try:
                pred, loss_val = sess.run([prediction, loss])
            except tf.errors.OutOfRangeError:
                break
    ```

    Args:
        output_types: A nested structure of `tf.DType` objects corresponding
            to each component of an element of this dataset.
        output_shapes: (Optional.) A nested structure of `tf.TensorShape`
            objects corresponding to each component of an element of this
            dataset. If omitted, each component will have an unconstrained
            shape.
        shared_name: (Optional.) If non-empty, this iterator will be shared
            under the given name across multiple sessions that share the same
            devices (e.g. when using a remote server).
        output_classes: (Optional.) A nested structure of Python `type`
            objects corresponding to each component of an element of this
            iterator. If omitted, each component is assumed to be of type
            `tf.Tensor`.

    Returns:
        An `Iterator`.

    Raises:
        TypeError: If the structures of `output_shapes` and `output_types`
            are not the same. (The check is delegated to `nest`, which may
            report some structural mismatches as `ValueError` instead.)
    """
    output_types = nest.map_structure(dtypes.as_dtype, output_types)
    if output_shapes is None:
        # No shapes given: every component gets a fully unknown shape.
        output_shapes = nest.map_structure(
            lambda _: tensor_shape.TensorShape(None), output_types)
    else:
        # Convert only down to the depth of `output_types`, so each type
        # component is paired with exactly one shape.
        output_shapes = nest.map_structure_up_to(
            output_types, tensor_shape.as_shape, output_shapes)
    if output_classes is None:
        output_classes = nest.map_structure(lambda _: ops.Tensor, output_types)
    nest.assert_same_structure(output_types, output_shapes)
    if shared_name is None:
        shared_name = ""

    # Both iterator ops take the same arguments: flat lists of the types and
    # shapes produced by `sparse.as_dense_types` / `sparse.as_dense_shapes`
    # for the given `output_classes`. Compute them once; this is pure Python
    # and creates no graph ops, so it need not run inside a device scope.
    op_kwargs = dict(
        container="",
        shared_name=shared_name,
        output_types=nest.flatten(
            sparse.as_dense_types(output_types, output_classes)),
        output_shapes=nest.flatten(
            sparse.as_dense_shapes(output_shapes, output_classes)))

    if compat.forward_compatible(2018, 8, 3):
        # Past the forward-compatibility date: use the `iterator_v2` op.
        if _device_stack_is_empty():
            # No enclosing device scope: pin the iterator resource to CPU.
            with ops.device("/cpu:0"):
                iterator_resource = gen_dataset_ops.iterator_v2(**op_kwargs)
        else:
            iterator_resource = gen_dataset_ops.iterator_v2(**op_kwargs)
    else:
        # Legacy op, created in whatever device scope is currently active.
        iterator_resource = gen_dataset_ops.iterator(**op_kwargs)
    return Iterator(iterator_resource, None, output_types, output_shapes,
                    output_classes)
上面文档中的示例
# Example from the `from_structure` documentation.
# NOTE: `model_fn`, `num_epochs` and `sess` (a `tf.Session`) are placeholders
# that are not defined in this snippet; the reader must supply them.
iterator = tf.data.Iterator.from_structure(tf.int64, tf.TensorShape([]))

dataset_range = tf.data.Dataset.range(10)
range_initializer = iterator.make_initializer(dataset_range)

dataset_evens = dataset_range.filter(lambda x: x % 2 == 0)
evens_initializer = iterator.make_initializer(dataset_evens)

# Define a model based on the iterator; in this example, the model_fn
# is expected to take scalar tf.int64 Tensors as input (see
# the definition of 'iterator' above).
prediction, loss = model_fn(iterator.get_next())

# Train for `num_epochs`, where for each epoch, we first iterate over
# dataset_range, and then iterate over dataset_evens.
for _ in range(num_epochs):
    # Initialize the iterator to `dataset_range` and consume it until the
    # end of the dataset is signalled by OutOfRangeError.
    sess.run(range_initializer)
    while True:
        try:
            pred, loss_val = sess.run([prediction, loss])
        except tf.errors.OutOfRangeError:
            break

    # Initialize the iterator to `dataset_evens` and consume it the same way.
    sess.run(evens_initializer)
    while True:
        try:
            pred, loss_val = sess.run([prediction, loss])
        except tf.errors.OutOfRangeError:
            break
更多推荐
tensorflow tf.data.Iterator.from_structure()(用给定的结构创建一个新的未初始化的迭代器Iterator)
发布评论