我有一个应用了辍学的经过预先训练的模型,我想从Java程序中恢复它。
对于我的应用程序,在推理步骤中,我需要打开输出,并多次重复向模型输入数据并获得一系列预测。
我做了什么:
加载模型并初始化会话
model = SavedModelBundle.load ("path_to_model", "serve");
sess = model.session();
送入模型(重复多次,例如3次)
for (i = 0; i < 3; i++)
t_pred = sess.runner().feed("x", x).fetch("y").run().get(0);
假设:
第一次:获取数组A1 = [y1,y2,y3]
第二次:获取数组A2 = [z1,z2,z3]
...
我想要相同的推断,但A2与A1不同。
我知道辍学面具会随着时间而改变。
我想我需要“ seed”变量作为python API中的变量。但我找不到任何参考。
我试过的
为了获得相同的预测列表,我需要加载模型并多次初始化会话。
for (i = 0; i < 3; i++)
model = SavedModelBundle.load ("path_to_model", "serve");
sess = model.session();
t_pred = sess.runner().feed("x", x).fetch("y").run().get(0);
但这不是最佳选择,因为它需要花费一些时间来加载模型并可能导致与内存相关的问题。
我该如何解决这个问题?
先感谢您!
解决方案如下:
最后,我解决了这个问题。
重新打开会话时,我认为该会话已重新初始化是错误的:Session s = modelBundle.session();
用涉及的图重新初始化。
byte[] metaGraph = Files.readAllBytes(Paths.get(save_path));
Graph g = new Graph();
Session sess = new Session(g);
但这会导致错误:
“尝试使用未初始化的值”
我通过更改将模型保存在python中的方式来修复了该错误。
以前,我使用过:
builder = tf.saved_model.builder.SavedModelBuilder(save_folder)
builder.add_meta_graph_and_variables(sess,[tf.saved_model.tag_constants.SERVING])
save_path = builder.save()
似乎没有保存种子(初始化为局部变量),然后导致模型未保存状态。
我将其更改为:
with tf.gfile.GFile(save_path, 'wb') as f:
f.write(out_graph_def.SerializeToString())
而且效果很好^^。