A-A+

Tensorflow中数据集的使用方法(tf.data.Dataset)

2020年04月30日 11:34 汪洋大海 暂无评论 阅读 81 views 次

看此文章时建议配合【Tensorflow中数据集的使用方法(tf.data.Dataset)】一同观看。
使用Dataset管理数据集需要首先定义数据来源,我们可以使用numpy或者tensor定义的数据作为数据来源来定义Dataset,假设我们有如下numpy定义的代码。

1、引入必要的包

  1. import numpy as np
  2. import tensorflow as tf

2、使用numpy构造数据集

  1. seed = 1000 # 定义随机数产生的方式
  2. data_size = 10 # 数据集大小
  3. n_repeats = 10 # 数据集重复的次数,这个值就是平常我们见到的max_epoch
  4. batch_size = 6 # 批处理的大小
  5. np.random.seed(seed)
  6. # 在这里我们定义两个特征集合和一个标签集合,features1有三列特征,features2有4列特征,labels是0-2之间的一种
  7. features1 = np.random.random((data_size, 3))
  8. features2 = np.random.random((data_size, 4))
  9. labels = np.random.randint(0, 3, data_size)
  10. # 打印测试
  11. print(features1)
  12. print(features2)
  13. print(labels)

3、将numpy数据转换成Dataset

一般我们常使用tf.data.Dataset.from_tensor_slices方法加载数据。同时,Dataset提供了repeat()和batch()方法方便我们建立循环的数据,repeat参数给定一个整型值就可以使数据重复几份,而batch则是将数据以多少条进行批处理,也就是按照batch参数大小切割数据。

注意,repeat和batch的先后顺序不一样 ,结果是不同的,先repeat再batch会把数据先复制N份变成一个大数据,然后batch是根据这个大的数据来做的。例如,上面我们构造了10个数据,先repeat10份就有100个,再假设batch设置为6,那么最终数据是100/6+1=17份,那么也就是循环17次,如果先batch设置为6,那么数据先变成了10/6+1=2份,再repeat10次就有了20份数据了,循环要20次。这个一定要注意。

  1. # dataset = tf.data.Dataset.from_tensor_slices((features1, features2, labels)).repeat(10).batch(6)
  2. dataset = tf.data.Dataset.from_tensor_slices((features1, features2, labels)).batch(batch_size).repeat(n_repeats)

4、获取数据迭代器

数据准备完成之后需要获取数据迭代器供后面迭代使用,Tensorflow创建迭代器的方法有四种,其中单词迭代器和可初始化的迭代器是最常见的两种:

  1. # 单次迭代器只能循环使用一次数据,而且单次迭代器不需要手动显示调用sess.run()进行初始化即可使用
  2. iterator = dataset.make_one_shot_iterator()
  3. # 可初始化的迭代器可以重新初始化进行循环,但是需要手动显示调用sess.run()才能循环
  4. iterator = dataset.make_initializable_iterator()
  5. # 创建了迭代器之后,我们获取迭代器结果便于后面的运行,注意,这里不会产生迭代,只是建立tensorflow的计算图,因此不会消耗迭代
  6. next_element = iterator.get_next()

5、创建了迭代器之后就可以循环数据了

迭代器循环的停止通过捕获数据越界的错误进行

  1. count = 0
  2. with tf.Session() as sess:
  3. # 这是显示初始化,当我们的迭代器是dataset.make_initializable_iterator()的时候,才需要调用这个方法,否则不需要
  4. sess.run(iterator.initializer)
  5. # 无线循环数据,直到越界
  6. while True:
  7. try:
  8. features1_batch, features2_batch, labels_batch = sess.run(next_element)
  9. count += 1
  10. print(count)
  11. except tf.errors.OutOfRangeError:
  12. break

这里的count输出与上面repeat和batch的先后顺序有关,大家可以自己更换代码查看。

6、使用tqdm循环输出

除了上述捕获越界错误外,我们也可手动计算epoch循环次数和batch循环次数来确定终止的情况。可以配合tqdm包进行输出。tqdm是一个快速,可扩展的Python进度条,可以在 Python 长循环中添加一个进度提示信息,用户只需要封装任意的迭代器 tqdm(iterator)即可。我们先引入必要的包:

  1. import time
  2. from tqdm import trange

接下来我们使用自己计算的结果循环:

  1. # 注意,这里实例总数显然是先repeat再batch的结果,要根据实际情况做改变
  2. total_instances = data_size * n_repeats
  3. steps_per_epoch = data_size // batch_size if data_size / batch_size == 0 else data_size // batch_size + 1
  4. with tf.Session() as sess:
  5. sess.run(iterator.initializer)
  6. for epoch in range(n_repeats):
  7. tqr = trange(steps_per_epoch, desc="%2d" % (epoch + 1), leave=False)
  8. for _ in tqr:
  9. features1_batch, features2_batch, labels_batch = sess.run(next_element)
  10. # 由于这里循环没有计算过程,速度很快,看不到进度条,我们加了暂停0.5秒便于观察结果
  11. time.sleep(0.5)
  12. # 由于所有数据都已经循环完毕,如下代码将会报越界的错误,证明我们是对的
  13. sess.run(next_element)

我们可以看到如下的进度条:

以上就是Tensorflow中dataset的读取、循环使用的基本概念。

完整代码可以参考Github:https://github.com/df19900725/tensorflow_example

文章来源:https://www.datalearner.com/blog/1051556350245210

布施恩德可便相知重

微信扫一扫打赏

支付宝扫一扫打赏

×
标签:

给我留言