使用从Java程序中应用的Dropout恢复保存的模型

2020年9月8日 12点热度 0条评论

我有一个应用了辍学的经过预先训练的模型,我想从Java程序中恢复它。

对于我的应用程序,在推理步骤中,我需要打开输出,并多次重复向模型输入数据并获得一系列预测。

我做了什么:

加载模型并初始化会话

model = SavedModelBundle.load ("path_to_model", "serve");
sess  = model.session();

送入模型(重复多次,例如3次)

for (i = 0; i < 3; i++) 
    t_pred = sess.runner().feed("x", x).fetch("y").run().get(0);

假设:

第一次:获取数组A1 = [y1,y2,y3]

第二次:获取数组A2 = [z1,z2,z3]

...

我想要相同的推断,但A2与A1不同。

我知道辍学面具会随着时间而改变。

我想我需要“ seed”变量作为python API中的变量。但我找不到任何参考。

我试过的

为了获得相同的预测列表,我需要加载模型并多次初始化会话。

for (i = 0; i < 3; i++) 
    model = SavedModelBundle.load ("path_to_model", "serve");
    sess  = model.session();
    t_pred = sess.runner().feed("x", x).fetch("y").run().get(0);

但这不是最佳选择,因为它需要花费一些时间来加载模型并可能导致与内存相关的问题。

我该如何解决这个问题?

先感谢您!

解决方案如下:

最后,我解决了这个问题。

重新打开会话时,我认为该会话已重新初始化是错误的:Session s = modelBundle.session();

用涉及的图重新初始化。

byte[] metaGraph = Files.readAllBytes(Paths.get(save_path));
Graph g = new Graph();
Session sess = new Session(g);

但这会导致错误:

“尝试使用未初始化的值”

我通过更改将模型保存在python中的方式来修复了该错误。

以前,我使用过:

builder = tf.saved_model.builder.SavedModelBuilder(save_folder)
builder.add_meta_graph_and_variables(sess,[tf.saved_model.tag_constants.SERVING])
save_path = builder.save()

似乎没有保存种子(初始化为局部变量),然后导致模型未保存状态。

我将其更改为:

with tf.gfile.GFile(save_path, 'wb') as f:
   f.write(out_graph_def.SerializeToString())

而且效果很好^^。