我正在 tensorflow 中训练卷积模型。在对模型进行了大约 70 个 epochs 的训练后,大约需要 1.5 小时,我无法保存模型。它给了我ValueError: GraphDef cannot be larger than 2GB
。我发现随着训练的进行,图中节点的数量不断增加。
在 epoch 0、3、6、9,图中的节点数分别为 7214、7238、7262、7286。当我使用with tf.Session() as sess:
,而不是sess = tf.Session()
将会话作为 传递时,节点数分别为 3982、4006、4030、4054 个 epochs 0、3、6、9。
在这个答案中,据说随着节点被添加到图中,它可能会超过其最大大小。我需要帮助了解节点数量如何在我的图表中不断增加。
我使用以下代码训练我的模型:
def runModel(data):
'''
Defines cost, optimizer functions, and runs the graph
'''
X, y,keep_prob = modelInputs((755, 567, 1),4)
logits = cnnModel(X,keep_prob)
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=y), name="cost")
optimizer = tf.train.AdamOptimizer(.0001).minimize(cost)
correct_pred = tf.equal(tf.argmax(logits, 1), tf.argmax(y, 1), name="correct_pred")
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32), name='accuracy')
sess = tf.Session()
sess.run(tf.global_variables_initializer())
saver = tf.train.Saver()
for e in range(12):
batch_x, batch_y = data.next_batch(30)
x = tf.reshape(batch_x, [30, 755, 567, 1]).eval(session=sess)
batch_y = tf.one_hot(batch_y,4).eval(session=sess)
sess.run(optimizer, feed_dict={X: x, y: batch_y,keep_prob:0.5})
if e%3==0:
n = len([n.name for n in tf.get_default_graph().as_graph_def().node])
print("No.of nodes: ",n,"\n")
current_cost = sess.run(cost, feed_dict={X: x, y: batch_y,keep_prob:1.0})
acc = sess.run(accuracy, feed_dict={X: x, y: batch_y,keep_prob:1.0})
print("At epoch {epoch:>3d}, cost is {a:>10.4f}, accuracy is {b:>8.5f}".format(epoch=e, a=current_cost, b=acc))
节点数量增加的原因是什么?
您正在训练循环中创建新节点。特别是,您正在调用tf.reshape
and tf.one_hot
,每个都会创建一个(或多个)节点。您可以:
我会推荐第二个,因为使用 TensorFlow 进行数据准备似乎没有任何好处。你可以有类似的东西:
import numpy as np
# ...
x = np.reshape(batch_x, [30, 755, 567, 1])
# ...
# One way of doing one-hot encoding with NumPy
classes_arr = np.arange(4).reshape([1] * batch_y.ndims + [-1])
batch_y = (np.expand_dims(batch_y, -1) == classes_arr).astype(batch_y.dtype)
# ...
PD:我还建议tf.Session()
在with
上下文管理器中使用,以确保close()
在最后调用其方法(除非您想稍后继续使用相同的会话)。
本文收集自互联网,转载请注明来源。
如有侵权,请联系[email protected] 删除。
我来说两句