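# سلام. در مثالی که برایتان قرار داده‌ام (با دیتاست MNIST)، داده‌های هر سه بخش train، test و validation در فایل‌های tfrecord ذخیره می‌شوند و سپس، به‌عنوان نمونه، فقط داده‌های بخش train بارگذاری و نمایش داده می‌شوند. مسیر دیتاست MNIST و مسیری را که فایل‌های tfrecord قرار است در آن ایجاد شوند، در تابع main مشخص کنید.
# (English) This example uses the MNIST dataset: it saves each of the three splits (train, test, validation) to its own .tfrecord file, then, as a demonstration, loads and displays only the train split. Set the MNIST dataset path and the output directory for the .tfrecord files in main().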
import tensorflow as tf
import numpy as np
from tensorflow.contrib.learn.python.learn.datasets import mnist
import os
import cv2
def intt64Feature(value):
    """Build a ``tf.train.Feature`` whose Int64List contains exactly ``value``.

    Args:
        value: A single integer; it is wrapped in a one-element list.

    Returns:
        A ``tf.train.Feature`` with its ``int64_list`` field set.
    """
    int64_list = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=int64_list)
def bytesFeature(value):
    """Build a ``tf.train.Feature`` whose BytesList contains exactly ``value``.

    Args:
        value: A single bytes object; it is wrapped in a one-element list.

    Returns:
        A ``tf.train.Feature`` with its ``bytes_list`` field set.
    """
    bytes_list = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=bytes_list)
def createMNIST2TFRecord(mnist_path, tf_path, validation_size=1000):
    """Write the MNIST train/test/validation splits to TFRecord files.

    MNIST is read from ``mnist_path`` with ``mnist.read_data_sets`` as uint8
    pixels without flattening (``reshape=False``). One file per split is
    written to ``<tf_path>/<split>.tfrecord``. Each record is a
    ``tf.train.Example`` with these features:

    - int64 features ``height``, ``width``, ``depth`` and ``label``;
    - a bytes feature ``raw_image`` holding the image's raw pixel buffer.

    Args:
        mnist_path: Directory passed to ``mnist.read_data_sets``.
        tf_path: Output directory for the ``.tfrecord`` files. It is created
            if it does not exist yet.
        validation_size: Number of training images held out as the
            validation split (defaults to the previous hard-coded 1000).
    """
    datasets = mnist.read_data_sets(mnist_path, dtype=tf.uint8, reshape=False,
                                    validation_size=validation_size)
    # TFRecordWriter cannot create missing parent directories.
    if not os.path.isdir(tf_path):
        os.makedirs(tf_path)
    for cur_split in ("train", "test", "validation"):
        print("saving " + cur_split)
        # Look the split up by field name. read_data_sets returns a
        # namedtuple ordered (train, validation, test), so indexing it by
        # position in this tuple would swap the test and validation files.
        dataset = getattr(datasets, cur_split)
        file_name = os.path.join(tf_path, cur_split + ".tfrecord")
        print(file_name)
        _writeSplitToTFRecord(dataset, file_name)


def _writeSplitToTFRecord(dataset, file_name):
    """Serialize every (image, label) pair of one MNIST split into ``file_name``."""
    # All images in a split share one shape: (count, height, width, depth).
    # Read the dimensions once instead of once per record.
    _, height, width, depth = dataset.images.shape
    # The context manager closes the writer even if serialization fails.
    with tf.python_io.TFRecordWriter(file_name) as writer:
        for cur_img, label in zip(dataset.images, dataset.labels):
            example = tf.train.Example(features=tf.train.Features(
                feature={
                    'height': intt64Feature(height),
                    'width': intt64Feature(width),
                    'depth': intt64Feature(depth),
                    # Labels arrive as numpy scalars; store a plain int.
                    'label': intt64Feature(int(label)),
                    # tobytes() replaces the deprecated ndarray.tostring().
                    "raw_image": bytesFeature(cur_img.tobytes()),
                }
            ))
            writer.write(example.SerializeToString())
def loadMNISTFromTFRecord(file_name):
    """Show every image stored in an MNIST TFRecord file, one at a time.

    Each record in ``file_name`` is decoded back into an image using its
    stored height/width/depth. The record's label is printed, and the image
    is shown in an OpenCV window named "view". The viewer waits for a key
    press before moving on. All OpenCV windows are closed after the last
    record.

    Args:
        file_name: Path of the ``.tfrecord`` file to read.
    """
    # tf_record_iterator is a Python generator: it finishes with
    # StopIteration (which a for-loop handles), not tf.errors.OutOfRangeError.
    # The previous next()/except OutOfRangeError loop therefore crashed at
    # end of file.
    for example_serialized in tf.python_io.tf_record_iterator(file_name):
        reshaped_img, label = _parseMNISTExample(example_serialized)
        print("label: %d" % label)
        cv2.imshow("view", reshaped_img)
        cv2.waitKey(0)
    cv2.destroyAllWindows()


def _parseMNISTExample(example_serialized):
    """Decode one serialized ``tf.train.Example`` into ``(image, label)``."""
    example = tf.train.Example()
    example.ParseFromString(example_serialized)
    feature = example.features.feature
    width = feature['width'].int64_list.value[0]
    height = feature['height'].int64_list.value[0]
    depth = feature['depth'].int64_list.value[0]
    label = feature['label'].int64_list.value[0]
    raw_image = feature['raw_image'].bytes_list.value[0]
    # frombuffer replaces the deprecated binary mode of np.fromstring.
    # The pixels are decoded as uint8, matching the dtype=tf.uint8 read in
    # createMNIST2TFRecord.
    flat_image = np.frombuffer(raw_image, np.uint8)
    # Reshape with the stored depth rather than -1, so a record whose byte
    # count does not match its header raises instead of producing a
    # silently wrong shape.
    return flat_image.reshape((height, width, depth)), label
def main(mnist_path=r"D:\Database\MNIST", tf_path=r"D:\tf_test"):
    """Convert MNIST to TFRecord files, then browse the train split.

    Calling ``main()`` with no arguments behaves as before. The defaults are
    the original author's local Windows paths; pass your own or edit them.

    Args:
        mnist_path: Directory passed to ``createMNIST2TFRecord`` as the
            MNIST source.
        tf_path: Directory in which the train/test/validation ``.tfrecord``
            files are written and from which ``train.tfrecord`` is read back.
    """
    createMNIST2TFRecord(mnist_path, tf_path)
    train_file_name = os.path.join(tf_path, "train.tfrecord")
    loadMNISTFromTFRecord(train_file_name)
if __name__ == "__main__":
    # Run the conversion and the viewer only when executed as a script,
    # not when this module is imported.
    main()