Paddle中有各种各样的 Sampler:
paddle.io.BatchSampler
paddle.io.Sampler
paddle.io.WeightedRandomSampler
paddle.io.SequenceSampler
paddle.io.DistributedBatchSampler
paddle.io.RandomSampler
先说结论,这些 Sampler
都是用来给数据做采样的,但是为了效率我又不能直接把数据拿出来对吧,我做采样时,只把数据集中数据的索引拿出来即可,通过索引访问数据即可
来先看看最基本的 Sampler
1. Sampler
class Sampler:
def __init__(self, data_source=None):
self.data_source = data_source
def __iter__(self):
raise NotImplementedError
# Not define __len__ method in this base class here for __len__
# is not needed in same sence, e.g. paddle.io.IterableDataset
可以看到 Sampler
就是定义了一个抽象类,由于有 __iter__
那么他是一个迭代器,我们在继承它写自己的 Sampler
,一定要注意定义__len__
方法
为方便理解迭代器,这里稍微改一下写一个自己的类:
class Sampler:
def __init__(self, data_source=None): # data_source 就是 paddle.io.dataset 对象
self.data_source = data_source
def __iter__(self):
yield 1
yield 2
Sampler()
就是一个实例化的迭代器:
for i in Sampler():
print(i)
每次返回一个值:
1
2
那么接下来就清楚了,所有的 Sampler
都是迭代器,返回了数据的索引,继承自 Sampler
,要重写 __len__
和 __iter__
方法
SequenceSampler
趁热打铁,来看看 SequenceSampler 类
class SequenceSampler(Sampler):
def __init__(self, data_source):
self.data_source = data_source
def __iter__(self):
return iter(range(len(self.data_source)))
def __len__(self):
return len(self.data_source)
SequenceSampler 真的很简单,只是会返回一个内容为 [0, 1, 2, ..., len(data_source) -1]
的迭代器
举个例子,如果你的数据集长度为7
>>> dataset = you_data()
>>> print(len(dataset))
7
那么
for i in SequenceSampler(dataset):
print(i)
0
1
2
3
4
5
6
官方文档是这么说的:
顺序迭代 data_source 返回样本下标,即一次返回 [0, 1, 2, …, len(data_source) - 1]
更多推荐
Paddle 中的几个 Sampler 用法(一)
发布评论