Imports

Write TFRecords

Serialize Functions

serialize_fn[source]

serialize_fn(features:dict, return_feature_desc=False)
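A hedged usage sketch (the feature values and the tuple return with return_feature_desc=True are illustrative assumptions, not confirmed by the source):

import numpy as np

features = {
    'input_ids': np.array([1, 2, 3], dtype='int64'),
    'float_array': np.array([4., 5., 6.], dtype='float32'),
}

# Serialized tf.train.Example bytes, ready for a TFRecord writer.
serialized = serialize_fn(features)

# Assumption: with return_feature_desc=True, the feature description needed to
# parse the record back is returned alongside the serialized bytes.
serialized, feature_desc = serialize_fn(features, return_feature_desc=True)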

Make TFRecord

make_tfrecord_local[source]

make_tfrecord_local(data_list, output_dir, serialize_fn, mode='train', example_per_file=100000, prefix='', **kwargs)

Make TFRecord files and return the total number of records.
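A hedged usage sketch; the shape of data_list elements and the output file naming are assumptions inferred from the tests and logs below:

import tempfile

output_dir = tempfile.mkdtemp()
data_list = [
    {'input_ids': [1, 2, 3], 'float_array': [4., 5., 6.]},
    {'input_ids': [4, 5, 6], 'float_array': [1., 2., 3.]},
]

# Writes files such as train_00000.tfrecord under output_dir, starting a new
# file every example_per_file examples, and returns the total record count.
total_records = make_tfrecord_local(
    data_list, output_dir=output_dir, serialize_fn=serialize_fn, mode='train')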

make_tfrecord_pyspark[source]

make_tfrecord_pyspark(data_list, output_dir:str, serialize_fn:Callable, mode='train', example_per_file=100000, prefix='', **kwargs)

Make TFRecord files with pyspark and return the total number of records.

make_tfrecord[source]

make_tfrecord(data_list, output_dir, serialize_fn, mode='train', example_per_file=100000, prefix='', **kwargs)
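The tests below toggle set_is_pyspark before calling make_tfrecord, which suggests it simply routes to the local or pyspark writer. A hedged sketch of that dispatch (the get_is_pyspark helper is hypothetical):

def make_tfrecord_sketch(data_list, output_dir, serialize_fn, mode='train',
                         example_per_file=100000, prefix='', **kwargs):
    # Route to the pyspark writer when the global pyspark flag is set,
    # otherwise fall back to the plain local writer.
    writer = make_tfrecord_pyspark if get_is_pyspark() else make_tfrecord_local
    return writer(data_list, output_dir, serialize_fn, mode=mode,
                  example_per_file=example_per_file, prefix=prefix, **kwargs)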

Local write tfrecord test

Pyspark write tfrecord test

from pyspark import SparkContext, SparkConf

# The tensorflow-hadoop connector jar lets Spark write TFRecord files.
jar_path = '/data/m3tl/tmp/tensorflow-hadoop-1.10.0.jar'
conf = SparkConf().set('spark.jars', jar_path)
sc = SparkContext(conf=conf)

from m3tl.utils import set_is_pyspark
import tempfile

# Tell m3tl that the incoming data is a pyspark RDD so make_tfrecord
# dispatches to the pyspark writer.
set_is_pyspark(True)
test_features_rdd = sc.parallelize([test_features]).coalesce(1)
pyspark_dir = tempfile.mkdtemp()
make_tfrecord(
    test_features_rdd, output_dir=test_base.tmpfiledir, serialize_fn=serialize_fn, pyspark_dir=pyspark_dir)
0
from m3tl.read_write_tfrecord import make_feature_desc

# Rebuild the tf.io feature description from the JSON metadata written
# alongside the TFRecords, then parse the file Spark produced.
json_path = os.path.join(test_base.tmpfiledir, 'train_feature_desc.json')
feature_desc = make_feature_desc(json.load(open(json_path, 'r')))

tfrecord_path = os.path.join(pyspark_dir, 'train', 'part-r-00000')
tfr_dataset = tf.data.TFRecordDataset(tfrecord_path)


def _parse_fn(x):
    return tf.io.parse_single_example(x, feature_desc)


tfr_dataset = tfr_dataset.map(_parse_fn)

# Variable-length features are parsed as sparse tensors; densify before comparing.
for i in tfr_dataset.take(1):
    assert np.all(
        tf.sparse.to_dense(i['float_array']).numpy() == np.array([4., 5., 6.], dtype='float32'))

Chain problems and write API

chain_processed_data[source]

chain_processed_data(problem_preproc_gen_dict:Dict[str, Iterator[T_co]])
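For plain Python generators, chaining amounts to concatenating the per-problem example streams; a hedged sketch (the real function also handles pyspark RDDs and, per the warning in the logs below, may hold data in memory):

from itertools import chain
from typing import Dict, Iterator

def chain_processed_data_sketch(problem_preproc_gen_dict: Dict[str, Iterator]) -> Iterator:
    # Examples from problems joined with '&' end up in one stream, so they are
    # written into the same problem-chunk TFRecord files.
    return chain(*problem_preproc_gen_dict.values())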

write_tfrecord[source]

write_tfrecord(params:Params, replace=False)

Write TFRecord for every problem chunk

Output location: params.tmp_file_dir

Arguments: params {params} -- params

Keyword Arguments: replace {bool} -- Whether to replace existing tfrecord (default: {False})

set_is_pyspark(False)

# Assign a mix of '&'-chained and '|'-separated problems, then write TFRecords
# for every problem chunk.
test_base.params.assign_problem(
    'weibo_fake_ner&weibo_fake_cls|weibo_fake_multi_cls|weibo_masklm|weibo_premask_mlm', base_dir=test_base.tmpckptdir)
write_tfrecord(
    params=test_base.params, replace=True)

# '&'-chained problems share one output directory; the others get their own.
assert os.path.exists(os.path.join(
    test_base.tmpfiledir, 'weibo_fake_cls_weibo_fake_ner'))
assert os.path.exists(os.path.join(
    test_base.tmpfiledir, 'weibo_fake_multi_cls'))
2021-06-19 21:41:03.855 | WARNING  | m3tl.base_params:assign_problem:642 - base_dir and dir_name arguments will be deprecated in the future. Please use model_dir instead.
2021-06-19 21:41:03.856 | WARNING  | m3tl.base_params:prepare_dir:361 - bert_config not exists. will load model from huggingface checkpoint.
2021-06-19 21:41:03.935 | WARNING  | __main__:chain_processed_data:15 - Chaining problems with & may consume a lot of memory if data is not pyspark RDD.
2021-06-19 21:41:03.942 | DEBUG    | __main__:_write_fn:11 - Writing /tmp/tmpa5zj7lsf/weibo_fake_cls_weibo_fake_ner/train_00000.tfrecord
2021-06-19 21:41:03.973 | WARNING  | __main__:chain_processed_data:15 - Chaining problems with & may consume a lot of memory if data is not pyspark RDD.
2021-06-19 21:41:03.979 | DEBUG    | __main__:_write_fn:11 - Writing /tmp/tmpa5zj7lsf/weibo_fake_cls_weibo_fake_ner/eval_00000.tfrecord
2021-06-19 21:41:04.005 | DEBUG    | __main__:_write_fn:11 - Writing /tmp/tmpa5zj7lsf/weibo_fake_multi_cls/train_00000.tfrecord
2021-06-19 21:41:04.030 | DEBUG    | __main__:_write_fn:11 - Writing /tmp/tmpa5zj7lsf/weibo_fake_multi_cls/eval_00000.tfrecord
2021-06-19 21:41:04.110 | DEBUG    | __main__:_write_fn:11 - Writing /tmp/tmpa5zj7lsf/weibo_masklm/train_00000.tfrecord
2021-06-19 21:41:04.162 | DEBUG    | __main__:_write_fn:11 - Writing /tmp/tmpa5zj7lsf/weibo_masklm/eval_00000.tfrecord
2021-06-19 21:41:04.229 | DEBUG    | __main__:_write_fn:11 - Writing /tmp/tmpa5zj7lsf/weibo_premask_mlm/train_00000.tfrecord
2021-06-19 21:41:04.297 | DEBUG    | __main__:_write_fn:11 - Writing /tmp/tmpa5zj7lsf/weibo_premask_mlm/eval_00000.tfrecord

Read TFRecords

make_feature_desc[source]

make_feature_desc(feature_desc_dict:dict)
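The pyspark test above feeds make_feature_desc the JSON written next to the TFRecords and passes the result straight to tf.io.parse_single_example, with parsed values coming back as sparse tensors. A hedged sketch of what the conversion might look like (the JSON layout and type names are assumptions):

import tensorflow as tf

def make_feature_desc_sketch(feature_desc_dict: dict) -> dict:
    dtype_map = {'int64': tf.int64, 'float32': tf.float32, 'string': tf.string}
    # Variable-length features, which is why the test densifies with
    # tf.sparse.to_dense after parsing.
    return {name: tf.io.VarLenFeature(dtype_map[dtype_str])
            for name, dtype_str in feature_desc_dict.items()}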

reshape_tensors_in_dataset[source]

reshape_tensors_in_dataset(example, feature_desc_dict:dict)

Reshape serialized tensor back to its original shape

Arguments: example {Example} -- Example

Returns: Example -- Example
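A hedged sketch of the idea: densify the parsed sparse tensors and restore each tensor's original shape from a companion shape entry. The '<name>_shape_value' key convention below is hypothetical:

import tensorflow as tf

def reshape_tensors_in_dataset_sketch(example: dict, feature_desc_dict: dict) -> dict:
    # Densify VarLen features first so shapes and values are plain tensors.
    example = {k: tf.sparse.to_dense(v) if isinstance(v, tf.SparseTensor) else v
               for k, v in example.items()}
    # Restore original shapes from the (hypothetical) companion shape entries.
    for name in list(example):
        shape_key = name + '_shape_value'
        if shape_key in example:
            example[name] = tf.reshape(example[name], example[shape_key])
    return example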

add_loss_multiplier[source]

add_loss_multiplier(example, problem)
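A hedged guess at the behavior: attach a per-problem loss multiplier of 1 so that, once dummy features for other problems are mixed in, losses of problems absent from a batch can be zeroed out. The key name is an assumption:

import tensorflow as tf

def add_loss_multiplier_sketch(example: dict, problem: str) -> dict:
    # Hypothetical key name; real examples get multiplier 1, dummy ones 0.
    example[f'{problem}_loss_multiplier'] = tf.constant(1, dtype=tf.int32)
    return example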

set_shape_for_dataset[source]

set_shape_for_dataset(example, feature_desc_dict)

get_dummy_features[source]

get_dummy_features(dataset_dict, feature_desc_dict)

Get dummy features. Dummy features are used to make sure every feature dict at every iteration has the same keys.

Example:
problem A: {'input_ids': [1,2,3], 'A_label_ids': 1}
problem B: {'input_ids': [1,2,3], 'B_label_ids': 2}

Then dummy features: {'A_label_ids': 0, 'B_label_ids': 0}

At each iteration, we sample a problem; say we sampled A. Then:
feature dict without dummy: {'input_ids': [1,2,3], 'A_label_ids': 1}
feature dict with dummy: {'input_ids': [1,2,3], 'A_label_ids': 1, 'B_label_ids': 0}

Arguments: dataset_dict {dict} -- dict of datasets of all problems

Returns: dummy_features -- dict of dummy tensors
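A hedged sketch of the idea (the real function also consults feature_desc_dict for dtypes; here every placeholder is just an int32 zero):

import tensorflow as tf

def get_dummy_features_sketch(dataset_dict: dict) -> dict:
    # Peek one example per problem to see which keys each dataset produces.
    keys_per_problem = [set(next(iter(d.take(1))).keys())
                        for d in dataset_dict.values()]
    union = set.union(*keys_per_problem)
    shared = set.intersection(*keys_per_problem)
    # Problem-specific keys get a zero placeholder so every batch has the same keys.
    return {key: tf.constant(0, dtype=tf.int32) for key in union - shared}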

add_dummy_features_to_dataset[source]

add_dummy_features_to_dataset(example, dummy_features)

Add dummy features to dataset

feature dict without dummy: {'input_ids': [1,2,3], 'A_label_ids': 1}
feature dict with dummy: {'input_ids': [1,2,3], 'A_label_ids': 1, 'B_label_ids': 0}

Arguments:
example {data example} -- dataset example
dummy_features {dict} -- dict of dummy tensors
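A hedged sketch: merge the dummy placeholders into an example without overwriting the keys it already has:

def add_dummy_features_to_dataset_sketch(example: dict, dummy_features: dict) -> dict:
    for key, value in dummy_features.items():
        if key not in example:
            example[key] = value
    return example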

read_tfrecord[source]

read_tfrecord(params:Params, mode:str)

Read and parse TFRecord for every problem

The returned datasets are parsed and reshaped, with sparse tensors converted to dense and dummy features added.

Arguments:
params {Params} -- params
mode {str} -- mode, one of train, eval or predict

Returns: dict -- dict with keys: problem name, values: dataset
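A hedged usage sketch, assuming params already has problems assigned and TFRecords written as in the cells above:

dataset_dict = read_tfrecord(params=test_base.params, mode='train')
for problem_name, dataset in dataset_dict.items():
    one_example = next(iter(dataset.take(1)))
    print(problem_name, sorted(one_example.keys()))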

Local read tfrecord test

Pyspark read tfrecord test

NOTE: Test pyspark-generated tfrecord

from m3tl.test_base import PysparkTestBase

pyspark_test_base = PysparkTestBase()

# Problem chunks are the '&'-joined groups of problems that share TFRecord files.
problem_chunk_str = pyspark_test_base.params.get_problem_chunk(as_str=True)[0]

write_tfrecord(params=pyspark_test_base.params, replace=True)

# The pyspark writer leaves a problem_info.txt under each problem chunk's output path.
assert os.path.exists(os.path.join(
    pyspark_test_base.params.pyspark_output_path, problem_chunk_str, 'problem_info.txt'))
2021-06-19 21:41:07.551 | INFO     | m3tl.base_params:register_multiple_problems:538 - Adding new problem pyspark_fake_seq_tag, problem type: seq_tag
2021-06-19 21:41:07.552 | INFO     | m3tl.base_params:register_multiple_problems:538 - Adding new problem pyspark_fake_multi_cls, problem type: multi_cls
2021-06-19 21:41:07.552 | INFO     | m3tl.base_params:register_multiple_problems:538 - Adding new problem pyspark_fake_cls, problem type: cls
2021-06-19 21:41:07.553 | WARNING  | m3tl.base_params:assign_problem:642 - base_dir and dir_name arguments will be deprecated in the future. Please use model_dir instead.
2021-06-19 21:41:07.554 | WARNING  | m3tl.base_params:prepare_dir:361 - bert_config not exists. will load model from huggingface checkpoint.
2021-06-19 21:41:13.188 | INFO     | m3tl.utils:set_phase:478 - Setting phase to train