import json
import os
import tempfile

import numpy as np
import tensorflow as tf
from pyspark import SparkConf, SparkContext

from m3tl.read_write_tfrecord import make_feature_desc, make_tfrecord, write_tfrecord
from m3tl.utils import set_is_pyspark

# Point Spark at the tensorflow-hadoop connector jar so executors can write TFRecords.
jar_path = '/data/m3tl/tmp/tensorflow-hadoop-1.10.0.jar'
conf = SparkConf().set('spark.jars', jar_path)
sc = SparkContext(conf=conf)
set_is_pyspark(True)
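# test_features, serialize_fn, and test_base are fixtures defined earlier in
# this test suite; the assertion further down implies test_features carries a
# 'float_array' feature equal to [4., 5., 6.].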
# A single-partition RDD gives exactly one output part file.
test_features_rdd = sc.parallelize([test_features]).coalesce(1)
pyspark_dir = tempfile.mkdtemp()
make_tfrecord(
    test_features_rdd, output_dir=test_base.tmpfiledir,
    serialize_fn=serialize_fn, pyspark_dir=pyspark_dir)
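# In pyspark mode, make_tfrecord writes the records as Hadoop-style part files
# under pyspark_dir and the feature-description JSON under output_dir.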
# Rebuild the tf.io parsing spec from the saved feature description.
json_path = os.path.join(test_base.tmpfiledir, 'train_feature_desc.json')
with open(json_path, 'r') as f:
    feature_desc = make_feature_desc(json.load(f))
tfrecord_path = os.path.join(pyspark_dir, 'train', 'part-r-00000')
tfr_dataset = tf.data.TFRecordDataset(tfrecord_path)
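# feature_desc maps feature names to tf.io feature specs; variable-length
# features come back as SparseTensor, hence the tf.sparse.to_dense below.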
def _parse_fn(x):
    return tf.io.parse_single_example(x, feature_desc)


tfr_dataset = tfr_dataset.map(_parse_fn)

for i in tfr_dataset.take(1):
    assert np.all(tf.sparse.to_dense(i['float_array']).numpy()
                  == np.array([4., 5., 6.], dtype='float32'))
set_is_pyspark(False)
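# With pyspark mode off, write_tfrecord exercises the plain (non-Spark) path:
# assign a multi-problem string and check the per-chunk output directories.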
test_base.params.assign_problem(
    'weibo_fake_ner&weibo_fake_cls|weibo_fake_multi_cls|weibo_masklm|weibo_premask_mlm',
    base_dir=test_base.tmpckptdir)
write_tfrecord(params=test_base.params, replace=True)
assert os.path.exists(os.path.join(
    test_base.tmpfiledir, 'weibo_fake_cls_weibo_fake_ner'))
assert os.path.exists(os.path.join(
    test_base.tmpfiledir, 'weibo_fake_multi_cls'))
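# Problems joined with '&' share one chunk directory (here the sorted,
# underscore-joined 'weibo_fake_cls_weibo_fake_ner'), while '|'-separated
# problems such as 'weibo_fake_multi_cls' each get their own.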
from m3tl.test_base import PysparkTestBase

pyspark_test_base = PysparkTestBase()
problem_chunk_str = pyspark_test_base.params.get_problem_chunk(as_str=True)[0]
write_tfrecord(params=pyspark_test_base.params, replace=True)
assert os.path.exists(os.path.join(
    pyspark_test_base.params.pyspark_output_path, problem_chunk_str,
    'problem_info.txt'))
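# Optional cleanup (not part of the original test): stop the Spark context and
# remove the scratch directory.
import shutil
sc.stop()
shutil.rmtree(pyspark_dir, ignore_errors=True)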