
I have some training data in a NumPy array. It fits in memory, but it is bigger than 2GB. I'm using tf.keras and the Dataset API. Here is a simplified, self-contained example:

import numpy as np
import tensorflow as tf
from tensorflow.keras import layers

model = tf.keras.Sequential([
    layers.Dense(64, activation='relu', input_shape=(32,)),
    layers.Dense(64, activation='relu'),
    layers.Dense(1)
])

model.compile(optimizer=tf.train.AdamOptimizer(0.001),
              loss='mse',
              metrics=['mae'])

# generate some big input datasets, bigger than 2GB
data = np.random.random((1024*1024*8, 32))
labels = np.random.random((1024*1024*8, 1))
val_data = np.random.random((100, 32))
val_labels = np.random.random((100, 1))

train_dataset = tf.data.Dataset.from_tensor_slices((data, labels))
train_dataset = train_dataset.batch(32).repeat()

val_dataset = tf.data.Dataset.from_tensor_slices((val_data, val_labels))
val_dataset = val_dataset.batch(32).repeat()

model.fit(train_dataset, epochs=10, steps_per_epoch=30,
          validation_data=val_dataset, validation_steps=3)

Executing this results in the error "Cannot create a tensor proto whose content is larger than 2GB". The documentation lists a solution to this problem (https://www.tensorflow.org/guide/datasets#consuming_numpy_arrays): define the dataset in terms of tf.placeholder tensors, then pass the actual arrays through feed_dict when initializing the iterator in a session.
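A quick size check shows where the limit is hit. This assumes 8 bytes per element (np.random.random returns float64) and assumes TensorFlow rejects a tensor proto of 2**31 bytes or more; treat the exact threshold as an assumption.

```python
# Row count used in the question's snippet.
ROWS = 1024 * 1024 * 8
FLOAT64_BYTES = 8   # np.random.random returns float64
LIMIT = 1 << 31     # assumed 2 GiB tensor-proto threshold

data_bytes = ROWS * 32 * FLOAT64_BYTES   # data has shape (ROWS, 32)
labels_bytes = ROWS * 1 * FLOAT64_BYTES  # labels has shape (ROWS, 1)

print(data_bytes, data_bytes >= LIMIT)      # 2147483648 True
print(labels_bytes, labels_bytes >= LIMIT)  # 67108864 False
```

So `labels` (64 MiB) is fine on its own; it is `data`, at exactly 2 GiB, that trips the check.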

The main question is: how do I do this with tf.keras? I cannot feed anything to the placeholders when I call model.fit(), and in fact, when I introduced the placeholders, I got errors saying "You must feed a value for placeholder tensor".
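For reference, the placeholder pattern from the linked guide looks roughly like the sketch below in a plain session. It is written against the tf.compat.v1 API so that it should also run under TensorFlow 2 with eager execution disabled (an assumption: a TensorFlow version that ships tf.compat.v1, i.e. 1.13 or later), and it uses small arrays instead of the 2GB ones. The arrays are only supplied in `sess.run(iterator.initializer, feed_dict=...)`, which is the step that has no obvious counterpart in a plain model.fit() call.

```python
import numpy as np
import tensorflow as tf

# Placeholders need graph mode; this switches TF 2 out of eager execution.
tf.compat.v1.disable_eager_execution()

# Small stand-ins for the real (2GB) arrays.
features = np.random.random((1000, 32)).astype(np.float32)
labels = np.random.random((1000, 1)).astype(np.float32)

# Placeholders keep the array contents out of the serialized graph.
features_ph = tf.compat.v1.placeholder(features.dtype, features.shape)
labels_ph = tf.compat.v1.placeholder(labels.dtype, labels.shape)

dataset = tf.compat.v1.data.Dataset.from_tensor_slices((features_ph, labels_ph))
dataset = dataset.batch(32)
iterator = tf.compat.v1.data.make_initializable_iterator(dataset)
next_x, next_y = iterator.get_next()

with tf.compat.v1.Session() as sess:
    # The NumPy arrays are fed here, when the iterator is initialized.
    sess.run(iterator.initializer,
             feed_dict={features_ph: features, labels_ph: labels})
    x_batch, y_batch = sess.run([next_x, next_y])

print(x_batch.shape, y_batch.shape)  # (32, 32) (32, 1)
```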

Answer

As with the Estimator API, you can use tf.data.Dataset.from_generator:

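# Split the arrays into 1024 chunks; the generator hands them to tf.data at
# run time, so the data is never embedded in the graph as one big constant.
data_chunks = np.split(data, 1024)
labels_chunks = np.split(labels, 1024)

def generator():
    for i, j in zip(data_chunks, labels_chunks):
        yield i, j

train_dataset = tf.data.Dataset.from_generator(
    generator, (tf.float32, tf.float32),
    output_shapes=(tf.TensorShape([None, 32]), tf.TensorShape([None, 1])))
# Each element is a whole chunk, so split it back into single rows before
# shuffling and batching; the shuffle buffer of 10000 is an arbitrary choice.
train_dataset = train_dataset.flat_map(
    lambda x, y: tf.data.Dataset.from_tensor_slices((x, y)))
train_dataset = train_dataset.shuffle(10000).batch(32).repeat()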
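Note that each element the generator yields is a whole chunk of rows, not a single row. The TensorFlow-free sketch below mimics the pipeline with NumPy and itertools to show what that means for batching. The helpers `rows` and `batches` are hypothetical stand-ins for flat_map(from_tensor_slices) and .batch(); shuffling is omitted and the arrays are shrunk to 8192 rows, so each of the 1024 chunks holds 8 rows.

```python
import itertools
import numpy as np

# Small stand-ins for the question's arrays (the real ones are ~2 GiB).
data = np.random.random((1024 * 8, 32))
labels = np.random.random((1024 * 8, 1))

data_chunks = np.split(data, 1024)      # 1024 chunks of shape (8, 32)
labels_chunks = np.split(labels, 1024)  # 1024 chunks of shape (8, 1)

def generator():
    # Like the answer's generator: one (data, labels) chunk per element.
    for x, y in zip(data_chunks, labels_chunks):
        yield x, y

def rows(chunks):
    # Hypothetical stand-in for flat_map(from_tensor_slices): chunk -> rows.
    for x, y in chunks:
        for i in range(len(x)):
            yield x[i], y[i]

def batches(items, batch_size):
    # Hypothetical stand-in for .batch(): stack consecutive elements.
    while True:
        group = list(itertools.islice(items, batch_size))
        if not group:
            return
        xs, ys = zip(*group)
        yield np.stack(xs), np.stack(ys)

first_x, first_y = next(batches(rows(generator()), 32))
print(first_x.shape, first_y.shape)  # (32, 32) (32, 1)
```

Passing generator() straight to batches(...) instead produces arrays of shape (32, 8, 32), which is not the (batch, 32) input the model expects.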

Also take a look at https://github.com/tensorflow/tensorflow/issues/24520.
