Paddle中有各种各样的 Sampler:

  • paddle.io.BatchSampler
  • paddle.io.Sampler
  • paddle.io.WeightedRandomSampler
  • paddle.io.SequenceSampler
  • paddle.io.DistributedBatchSampler
  • paddle.io.RandomSampler

先说结论,这些 Sampler 都是用来给数据做采样的,但是为了效率我又不能直接把数据拿出来对吧,我做采样时,只把数据集中数据的索引拿出来即可,通过索引访问数据即可

来先看看最基本的 Sampler

1. Sampler

class Sampler:
    def __init__(self, data_source=None):
        self.data_source = data_source

    def __iter__(self):
        raise NotImplementedError

    # Not define __len__ method in this base class here for __len__
    # is not needed in same sence, e.g. paddle.io.IterableDataset

可以看到 Sampler 就是定义了一个抽象类,由于有 __iter__ 那么他是一个迭代器,我们在继承它写自己的 Sampler,一定要注意定义__len__方法

为方便理解迭代器,这里稍微改一下写一个自己的类:

class Sampler:
    def __init__(self, data_source=None): # data_source 就是 paddle.io.dataset 对象
        self.data_source = data_source

    def __iter__(self):
        yield 1
        yield 2

Sampler() 就是一个实例化的迭代器:

for i in Sampler():
    print(i)

每次返回一个值:

1
2

那么接下来就清楚了,所有的 Sampler 都是迭代器,返回了数据的索引,继承自 Sampler,要重写 __len____iter__ 方法

SequenceSampler

趁热打铁,来看看 SequenceSampler 类

class SequenceSampler(Sampler):

    def __init__(self, data_source):
        self.data_source = data_source

    def __iter__(self):
        return iter(range(len(self.data_source)))

    def __len__(self):
        return len(self.data_source)

SequenceSampler 真的很简单,只是会返回一个内容为 [0, 1, 2, ..., len(data_source) -1] 的迭代器

举个例子,如果你的数据集长度为7

>>> dataset = you_data()
>>> print(len(dataset))
7

那么

for i in SequenceSampler(dataset):
	print(i)
0
1
2
3
4
5
6

官方文档是这么说的:

顺序迭代 data_source 返回样本下标,即一次返回 [0, 1, 2, …, len(data_source) - 1]

更多推荐

Paddle 中的几个 Sampler 用法(一)