Writing TFRecords in Apache Beam with Java

February 26, 2020

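How can I write the following code in Java? Given a list of records/dictionaries in Java, how do I write a Beam pipeline that writes them to a TFRecord file as serialized tf.train.Example protos?
There are many examples in Python, such as the one below. How can I write the same logic in Java?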

import tensorflow as tf
import apache_beam as beam
from apache_beam.runners.interactive import interactive_runner
from apache_beam.coders import ProtoCoder

class Foo(beam.DoFn):
  def process(self, element, *args, **kwargs):
    import tensorflow as tf

    foo = element.get('foo')
    bar = element.get('bar')

    feature = {
      "foo":
        tf.train.Feature(bytes_list=tf.train.BytesList(value=[foo.encode('utf-8')])),
      "bar":
        tf.train.Feature(bytes_list=tf.train.BytesList(value=[bar.encode('utf-8')]))
    }
    example_proto = tf.train.Example(features=tf.train.Features(feature=feature))
    yield example_proto

p = beam.Pipeline(runner=interactive_runner.InteractiveRunner())

records = p | "Create records" >> beam.Create([{'foo': 'abc', 'bar': 'pqr'} for _ in range(10)])
tf_examples = records | "Convert to tf examples" >> beam.ParDo(Foo())
tf_examples | "Dump Records" >> beam.io.WriteToTFRecord(file_path_prefix="./output/data-",
                                                    coder=ProtoCoder(tf.train.Example),  # pass the message class, not an instance
                                                    file_name_suffix='.tfrecord', num_shards=2)

p.run()

Solution:

I tried doing this in Java but still ran into some problems. The follow-up question is posted as "Write tfrecords from beam pipeline?"
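
For context on the Java side: the Beam Java SDK's counterpart to WriteToTFRecord is org.apache.beam.sdk.io.TFRecordIO, whose write() transform consumes a PCollection<byte[]>. Each tf.train.Example would therefore need to be serialized to bytes first (for instance with toByteArray() on the generated protobuf class) before it reaches the sink. The sink then frames each byte array as one record. The sketch below is not Beam code; it only illustrates that per-record framing using the JDK (Java 9+ for CRC32C), which may help when checking output files by hand. The class and method names are hypothetical, and the payload is a stand-in for a serialized Example.

```java
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.charset.StandardCharsets;
import java.util.zip.CRC32C;

// Hypothetical helper, for illustration only: frames one payload the way a
// TFRecord file stores it. In a Beam pipeline the sink does this for you.
class TfRecordFraming {
    // Constant added to the rotated checksum by the TFRecord format.
    private static final int MASK_DELTA = 0xa282ead8;

    // CRC32C of the data, rotated right by 15 bits, plus MASK_DELTA (mod 2^32).
    static int maskedCrc32c(byte[] data) {
        CRC32C crc = new CRC32C();
        crc.update(data, 0, data.length);
        int value = (int) crc.getValue();
        return Integer.rotateRight(value, 15) + MASK_DELTA;
    }

    // Record layout: uint64 length | uint32 masked CRC of length |
    //                payload bytes | uint32 masked CRC of payload (little-endian).
    static byte[] frame(byte[] payload) {
        byte[] lengthBytes = ByteBuffer.allocate(8)
                .order(ByteOrder.LITTLE_ENDIAN)
                .putLong(payload.length)
                .array();
        ByteBuffer out = ByteBuffer.allocate(8 + 4 + payload.length + 4)
                .order(ByteOrder.LITTLE_ENDIAN);
        out.put(lengthBytes);
        out.putInt(maskedCrc32c(lengthBytes));
        out.put(payload);
        out.putInt(maskedCrc32c(payload));
        return out.array();
    }

    public static void main(String[] args) {
        // Stand-in bytes; a real pipeline would pass a serialized tf.train.Example.
        byte[] payload = "abc".getBytes(StandardCharsets.UTF_8);
        byte[] record = frame(payload);
        // Framing adds 16 bytes of overhead: 8 + 4 + payload + 4.
        System.out.println(record.length);
    }
}
```

Because the sink applies this framing itself, the Java pipeline only has to produce the serialized bytes of each Example; writing the length and checksum fields by hand inside a DoFn would frame the data twice.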