Minimal Example

In this example, we'll create toy problems and then train, evaluate and predict with them. But first, we need to understand what a "problem" means here. Essentially, a problem has a name (string), a problem type (string), and a preprocessing function (callable). The following problem types are predefined:

from m3tl.params import Params
params = Params()
for problem_type in params.list_available_problem_types():
    print('`{problem_type}`: {desc}'.format(
        desc=params.problem_type_desc[problem_type], problem_type=problem_type))
`cls`: Classification
`multi_cls`: Multi-Label Classification
`seq_tag`: Sequence Labeling
`masklm`: Masked Language Model
`pretrain`: NSP+MLM(Deprecated)
`regression`: Regression
`vector_fit`: Vector Fitting
`premask_mlm`: Pre-masked Masked Language Model
`contrastive_learning`: Contrastive Learning
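To make the three parts of a problem concrete, here is a minimal plain-Python sketch. ProblemSpec and PREDEFINED_PROBLEM_TYPES are hypothetical names, not part of m3tl; the type list is copied from the output above. The sketch only accepts the predefined types, whereas m3tl also lets you register your own problem types (covered later).

```python
from dataclasses import dataclass
from typing import Callable

# Hypothetical constant: the problem types printed above.
PREDEFINED_PROBLEM_TYPES = {
    'cls', 'multi_cls', 'seq_tag', 'masklm', 'pretrain',
    'regression', 'vector_fit', 'premask_mlm', 'contrastive_learning',
}


@dataclass
class ProblemSpec:
    """Hypothetical container for the three parts of a problem."""
    name: str
    problem_type: str
    preprocessing_fn: Callable

    def __post_init__(self):
        # This sketch rejects anything outside the predefined list.
        if self.problem_type not in PREDEFINED_PROBLEM_TYPES:
            raise ValueError(f'unknown problem type: {self.problem_type!r}')


def toy_cls(params, mode):
    # Placeholder preprocessing function: one input, one target.
    return ['this is a test'], ['a']


spec = ProblemSpec(name='toy_cls', problem_type='cls', preprocessing_fn=toy_cls)
```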

Normally, you would want to use this library for multi-task learning. There are two chaining operators that can be used to combine problems:

  • &. If two problems have the same inputs, they can be chained with &. Problems chained with & are trained at the same time.
  • |. If two problems don't have the same inputs, they need to be chained with |. Problems chained with | are sampled, so each training instance comes from one of them.

If your problem does not fall into the predefined problem types, you can implement your own and register it to params. We will cover this topic later. Let's start with a simple example: adding a classification problem and a sequence labeling problem.
problem_type_dict = {'toy_cls': 'cls', 'toy_seq_tag': 'seq_tag'}

Next, we need to implement a preprocessing function for each problem. The preprocessing function is a callable that:

  • has the same name as the problem
  • has a fixed input signature
  • returns (or yields) inputs and targets
  • is decorated by m3tl.preproc_decorator.preprocessing_fn
import m3tl
from m3tl.preproc_decorator import preprocessing_fn
from m3tl.params import Params
from m3tl.special_tokens import TRAIN
@preprocessing_fn
def toy_cls(params: Params, mode: str):
    "Simple example to demonstrate single-modal tuple-of-lists return"
    if mode == TRAIN:
        toy_input = ['this is a test' for _ in range(10)]
        toy_target = ['a' if i <=5 else 'b' for i in range(10)]
    else:
        toy_input = ['this is a test' for _ in range(10)]
        toy_target = ['a' if i <=5 else 'b' for i in range(10)]
    return toy_input, toy_target

@preprocessing_fn
def toy_seq_tag(params: Params, mode: str):
    "Simple example to demonstrate single-modal tuple-of-lists return"
    if mode == TRAIN:
        toy_input = ['this is a test'.split(' ') for _ in range(10)]
        toy_target = [['a', 'b', 'c', 'd'] for _ in range(10)]
    else:
        toy_input = ['this is a test'.split(' ') for _ in range(10)]
        toy_target = [['a', 'b', 'c', 'd'] for _ in range(10)]
    return toy_input, toy_target

processing_fn_dict = {'toy_cls': toy_cls, 'toy_seq_tag': toy_seq_tag}
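The naming and signature rules listed earlier can be checked with Python's inspect module. The sketch below is a hypothetical helper, not something m3tl provides, and it checks an undecorated copy of toy_cls so that it runs without the library.

```python
import inspect

# Assumed argument list, taken from the toy functions above.
EXPECTED_ARGS = ['params', 'mode']


def check_preprocessing_fn(fn, problem_name):
    """Hypothetical helper: enforce the naming and signature rules."""
    if fn.__name__ != problem_name:
        raise ValueError(f'{fn.__name__!r} does not match problem {problem_name!r}')
    arg_names = list(inspect.signature(fn).parameters)
    if arg_names != EXPECTED_ARGS:
        raise TypeError(f'expected arguments {EXPECTED_ARGS}, got {arg_names}')


def toy_cls(params, mode):
    # Undecorated copy of the toy classification problem.
    toy_input = ['this is a test' for _ in range(10)]
    toy_target = ['a' if i <= 5 else 'b' for i in range(10)]
    return toy_input, toy_target


check_preprocessing_fn(toy_cls, 'toy_cls')  # passes silently
inputs, targets = toy_cls(params=None, mode='train')
assert len(inputs) == len(targets)  # one target per input
```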

Now we're good to go! Since these two toy problems share the same input, we can chain them with &.
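The chaining syntax can be illustrated with a small parser. This is a hypothetical sketch of the semantics described earlier, not m3tl's own parser: | splits the string into groups that are sampled per training instance, and & joins problems within a group that are trained together. Joining a group with "_" reproduces the toy_cls_toy_seq_tag name that appears in the training log. The problem name other_problem and the sampling weights are made up for illustration.

```python
import random


def parse_problem_chain(problem):
    """Hypothetical parser: '|' separates sampled groups, '&' joins
    problems that share inputs and are trained together."""
    return [group.split('&') for group in problem.split('|')]


def sample_group(groups, weights, rng):
    """Pick the group that supplies the next training instance."""
    return rng.choices(groups, weights=weights, k=1)[0]


groups = parse_problem_chain('toy_cls&toy_seq_tag')
print(groups)  # [['toy_cls', 'toy_seq_tag']]
group_names = ['_'.join(g) for g in groups]

# A mixed chain with a second, hypothetical problem and made-up weights.
mixed = parse_problem_chain('toy_cls&toy_seq_tag|other_problem')
rng = random.Random(0)
picked = sample_group(mixed, weights=[0.5, 0.5], rng=rng)
```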

from m3tl.run_bert_multitask import train_bert_multitask, eval_bert_multitask, predict_bert_multitask
problem = 'toy_cls&toy_seq_tag'
# train
model = train_bert_multitask(
    problem=problem,
    num_epochs=1,
    problem_type_dict=problem_type_dict,
    processing_fn_dict=processing_fn_dict,
    continue_training=False
)
2021-06-24 16:07:25.768 | INFO     | m3tl.base_params:register_multiple_problems:543 - Adding new problem toy_cls, problem type: cls
2021-06-24 16:07:25.769 | INFO     | m3tl.base_params:register_multiple_problems:543 - Adding new problem toy_seq_tag, problem type: seq_tag
2021-06-24 16:07:25.770 | WARNING  | m3tl.base_params:prepare_dir:363 - bert_config not exists. will load model from huggingface checkpoint.
2021-06-24 16:07:27.851 | WARNING  | m3tl.read_write_tfrecord:chain_processed_data:258 - Chaining problems with & may consume a lot of memory if data is not pyspark RDD.
2021-06-24 16:07:27.853 | DEBUG    | m3tl.bert_preprocessing.create_bert_features:_create_multimodal_bert_features:514 - text: this is a test
2021-06-24 16:07:27.854 | DEBUG    | m3tl.bert_preprocessing.create_bert_features:_create_multimodal_bert_features:514 - text_modal_type: text
2021-06-24 16:07:27.854 | DEBUG    | m3tl.bert_preprocessing.create_bert_features:_create_multimodal_bert_features:519 - text_input_ids: [101, 8554, 8310, 143, 10060, 102]
2021-06-24 16:07:27.855 | DEBUG    | m3tl.bert_preprocessing.create_bert_features:_create_multimodal_bert_features:519 - text_mask: [1, 1, 1, 1, 1, 1]
2021-06-24 16:07:27.855 | DEBUG    | m3tl.bert_preprocessing.create_bert_features:_create_multimodal_bert_features:519 - text_segment_ids: [0, 0, 0, 0, 0, 0]
2021-06-24 16:07:27.856 | DEBUG    | m3tl.bert_preprocessing.create_bert_features:_create_multimodal_bert_features:519 - toy_cls_label_ids: 0
2021-06-24 16:07:27.862 | DEBUG    | m3tl.bert_preprocessing.create_bert_features:_create_multimodal_bert_features:514 - text: ['this', 'is', 'a', 'test']
2021-06-24 16:07:27.862 | DEBUG    | m3tl.bert_preprocessing.create_bert_features:_create_multimodal_bert_features:514 - text_modal_type: text
2021-06-24 16:07:27.863 | DEBUG    | m3tl.bert_preprocessing.create_bert_features:_create_multimodal_bert_features:519 - text_input_ids: [101, 8554, 8310, 143, 10060, 102]
2021-06-24 16:07:27.863 | DEBUG    | m3tl.bert_preprocessing.create_bert_features:_create_multimodal_bert_features:519 - text_mask: [1, 1, 1, 1, 1, 1]
2021-06-24 16:07:27.863 | DEBUG    | m3tl.bert_preprocessing.create_bert_features:_create_multimodal_bert_features:519 - text_segment_ids: [0, 0, 0, 0, 0, 0]
2021-06-24 16:07:27.864 | DEBUG    | m3tl.bert_preprocessing.create_bert_features:_create_multimodal_bert_features:519 - toy_seq_tag_label_ids: [0, 1, 2, 3, 4, 0]
2021-06-24 16:07:27.867 | DEBUG    | m3tl.read_write_tfrecord:_write_fn:135 - Writing tmp/toy_cls_toy_seq_tag/train_00000.tfrecord
2021-06-24 16:07:27.898 | WARNING  | m3tl.read_write_tfrecord:chain_processed_data:258 - Chaining problems with & may consume a lot of memory if data is not pyspark RDD.
2021-06-24 16:07:27.899 | DEBUG    | m3tl.bert_preprocessing.create_bert_features:_create_multimodal_bert_features:514 - text: this is a test
2021-06-24 16:07:27.900 | DEBUG    | m3tl.bert_preprocessing.create_bert_features:_create_multimodal_bert_features:514 - text_modal_type: text
2021-06-24 16:07:27.900 | DEBUG    | m3tl.bert_preprocessing.create_bert_features:_create_multimodal_bert_features:519 - text_input_ids: [101, 8554, 8310, 143, 10060, 102]
2021-06-24 16:07:27.901 | DEBUG    | m3tl.bert_preprocessing.create_bert_features:_create_multimodal_bert_features:519 - text_mask: [1, 1, 1, 1, 1, 1]
2021-06-24 16:07:27.901 | DEBUG    | m3tl.bert_preprocessing.create_bert_features:_create_multimodal_bert_features:519 - text_segment_ids: [0, 0, 0, 0, 0, 0]
2021-06-24 16:07:27.901 | DEBUG    | m3tl.bert_preprocessing.create_bert_features:_create_multimodal_bert_features:519 - toy_cls_label_ids: 0
2021-06-24 16:07:27.905 | DEBUG    | m3tl.bert_preprocessing.create_bert_features:_create_multimodal_bert_features:514 - text: ['this', 'is', 'a', 'test']
2021-06-24 16:07:27.906 | DEBUG    | m3tl.bert_preprocessing.create_bert_features:_create_multimodal_bert_features:514 - text_modal_type: text
2021-06-24 16:07:27.906 | DEBUG    | m3tl.bert_preprocessing.create_bert_features:_create_multimodal_bert_features:519 - text_input_ids: [101, 8554, 8310, 143, 10060, 102]
2021-06-24 16:07:27.907 | DEBUG    | m3tl.bert_preprocessing.create_bert_features:_create_multimodal_bert_features:519 - text_mask: [1, 1, 1, 1, 1, 1]
2021-06-24 16:07:27.907 | DEBUG    | m3tl.bert_preprocessing.create_bert_features:_create_multimodal_bert_features:519 - text_segment_ids: [0, 0, 0, 0, 0, 0]
2021-06-24 16:07:27.907 | DEBUG    | m3tl.bert_preprocessing.create_bert_features:_create_multimodal_bert_features:519 - toy_seq_tag_label_ids: [0, 1, 2, 3, 4, 0]
2021-06-24 16:07:27.910 | DEBUG    | m3tl.read_write_tfrecord:_write_fn:135 - Writing tmp/toy_cls_toy_seq_tag/eval_00000.tfrecord
2021-06-24 16:07:28.601 | INFO     | m3tl.input_fn:train_eval_input_fn:59 - sampling weights: 
2021-06-24 16:07:28.602 | INFO     | m3tl.input_fn:train_eval_input_fn:60 - {
    "toy_cls_toy_seq_tag": 1.0
}
2021-06-24 16:07:28.750 | INFO     | m3tl.input_fn:train_eval_input_fn:59 - sampling weights: 
2021-06-24 16:07:28.751 | INFO     | m3tl.input_fn:train_eval_input_fn:60 - {
    "toy_cls_toy_seq_tag": 1.0
}
2021-06-24 16:07:28.943 | CRITICAL | m3tl.base_params:update_train_steps:456 - Updating train_steps to 1
2021-06-24 16:07:29.062 | INFO     | m3tl.input_fn:train_eval_input_fn:59 - sampling weights: 
2021-06-24 16:07:29.063 | INFO     | m3tl.input_fn:train_eval_input_fn:60 - {
    "toy_cls_toy_seq_tag": 1.0
}
WARNING:tensorflow:There are non-GPU devices in `tf.distribute.Strategy`, not using nccl allreduce.
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:CPU:0',)
Some layers from the model checkpoint at bert-base-chinese were not used when initializing TFBertModel: ['mlm___cls', 'nsp___cls']
- This IS expected if you are initializing TFBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing TFBertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
All the layers of TFBertModel were initialized from the model checkpoint at bert-base-chinese.
If your task is similar to the task the model of the checkpoint was trained on, you can already use TFBertModel for predictions without further training.
2021-06-24 16:07:33.091 | CRITICAL | m3tl.embedding_layer.base:__init__:58 - Modal Type id mapping: 
 {
    "text": 0
}
2021-06-24 16:07:33.278 | INFO     | m3tl.utils:set_phase:478 - Setting phase to infer
2021-06-24 16:07:33.289 | CRITICAL | m3tl.model_fn:compile:271 - Initial lr: 0.0
2021-06-24 16:07:33.290 | CRITICAL | m3tl.model_fn:compile:272 - Train steps: 1
2021-06-24 16:07:33.290 | CRITICAL | m3tl.model_fn:compile:273 - Warmup steps: 0
2021-06-24 16:07:33.496 | INFO     | m3tl.utils:set_phase:478 - Setting phase to train
The parameters `output_attentions`, `output_hidden_states` and `use_cache` cannot be updated when calling a model.They have to be set to True/False in the config object (i.e.: `config=XConfig.from_pretrained('name', output_attentions=True)`).
The parameter `return_dict` cannot be set in graph mode and will always be set to `True`.
2021-06-24 16:07:50.903 | INFO     | m3tl.utils:set_phase:478 - Setting phase to train
1/1 [==============================] - ETA: 0s - mean_acc: 0.9431 - toy_cls_acc: 0.6000 - toy_seq_tag_acc: 0.0857 - BertMultiTaskTop/toy_cls/losses/0: 0.7577 - BertMultiTaskTop/toy_seq_tag/losses/0: 2.3291
2021-06-24 16:08:04.196 | INFO     | m3tl.utils:set_phase:478 - Setting phase to eval
The parameters `output_attentions`, `output_hidden_states` and `use_cache` cannot be updated when calling a model.They have to be set to True/False in the config object (i.e.: `config=XConfig.from_pretrained('name', output_attentions=True)`).
The parameter `return_dict` cannot be set in graph mode and will always be set to `True`.
WARNING:tensorflow:Your input ran out of data; interrupting training. Make sure that your dataset or generator can generate at least `steps_per_epoch * epochs` batches (in this case, 1000 batches). You may need to use the repeat() function when building your dataset.
1/1 [==============================] - 39s 39s/step - mean_acc: 0.9431 - toy_cls_acc: 0.6000 - toy_seq_tag_acc: 0.0857 - BertMultiTaskTop/toy_cls/losses/0: 0.7577 - BertMultiTaskTop/toy_seq_tag/losses/0: 2.3291 - val_loss: 3.1377 - val_mean_acc: 0.3000 - val_toy_cls_acc: 0.6000 - val_toy_seq_tag_acc: 0.0000e+00 - val_BertMultiTaskTop/toy_cls/losses/0: 0.6794 - val_BertMultiTaskTop/toy_seq_tag/losses/0: 2.4584
Model: "BertMultiTask"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
BertMultiTaskBody (BertMulti multiple                  102268416 
_________________________________________________________________
basic_mtl (BasicMTL)         multiple                  0         
_________________________________________________________________
BertMultiTaskTop (BertMultiT multiple                  5387      
_________________________________________________________________
sum_loss_combination (SumLos multiple                  0         
=================================================================
Total params: 102,273,805
Trainable params: 102,273,799
Non-trainable params: 6
_________________________________________________________________

For evaluation, we need to pass either model_dir or model to the function. Note that the "Unresolved object in checkpoint" warnings raised by TensorFlow are expected: the optimizer's state is not initialized during evaluation and prediction, so those checkpoint values go unused.

# eval
eval_dict = eval_bert_multitask(
    problem=problem,
    problem_type_dict=problem_type_dict,
    processing_fn_dict=processing_fn_dict,
    model_dir=model.params.ckpt_dir)
2021-06-24 16:08:13.640 | INFO     | m3tl.base_params:register_multiple_problems:543 - Adding new problem toy_cls, problem type: cls
2021-06-24 16:08:13.641 | INFO     | m3tl.base_params:register_multiple_problems:543 - Adding new problem toy_seq_tag, problem type: seq_tag
2021-06-24 16:08:13.782 | INFO     | m3tl.input_fn:train_eval_input_fn:59 - sampling weights: 
2021-06-24 16:08:13.782 | INFO     | m3tl.input_fn:train_eval_input_fn:60 - {
    "toy_cls_toy_seq_tag": 1.0
}
2021-06-24 16:08:14.095 | INFO     | m3tl.input_fn:train_eval_input_fn:59 - sampling weights: 
2021-06-24 16:08:14.096 | INFO     | m3tl.input_fn:train_eval_input_fn:60 - {
    "toy_cls_toy_seq_tag": 1.0
}
WARNING:tensorflow:There are non-GPU devices in `tf.distribute.Strategy`, not using nccl allreduce.
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:CPU:0',)
2021-06-24 16:08:15.113 | CRITICAL | m3tl.embedding_layer.base:__init__:58 - Modal Type id mapping: 
 {
    "text": 0
}
2021-06-24 16:08:15.189 | INFO     | m3tl.utils:set_phase:478 - Setting phase to infer
The parameters `output_attentions`, `output_hidden_states` and `use_cache` cannot be updated when calling a model.They have to be set to True/False in the config object (i.e.: `config=XConfig.from_pretrained('name', output_attentions=True)`).
The parameter `return_dict` cannot be set in graph mode and will always be set to `True`.
2021-06-24 16:08:17.691 | CRITICAL | m3tl.model_fn:compile:271 - Initial lr: 0.0
2021-06-24 16:08:17.692 | CRITICAL | m3tl.model_fn:compile:272 - Train steps: 1
2021-06-24 16:08:17.693 | CRITICAL | m3tl.model_fn:compile:273 - Warmup steps: 0
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.iter
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.beta_1
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.beta_2
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.decay
WARNING:tensorflow:A checkpoint was restored (e.g. tf.train.Checkpoint.restore or tf.keras.Model.load_weights) but not all checkpointed values were used. See above for specific issues. Use expect_partial() on the load status object, e.g. tf.train.Checkpoint.restore(...).expect_partial(), to silence these warnings, or use assert_consumed() to make the check explicit. See https://www.tensorflow.org/guide/checkpoint#loading_mechanics for details.
2021-06-24 16:08:18.193 | INFO     | m3tl.utils:set_phase:478 - Setting phase to eval
The parameters `output_attentions`, `output_hidden_states` and `use_cache` cannot be updated when calling a model.They have to be set to True/False in the config object (i.e.: `config=XConfig.from_pretrained('name', output_attentions=True)`).
The parameter `return_dict` cannot be set in graph mode and will always be set to `True`.
2/2 [==============================] - 8s 8ms/step - loss: 3.1377 - mean_acc: 0.3000 - toy_cls_acc: 0.6000 - toy_seq_tag_acc: 0.0000e+00 - BertMultiTaskTop/toy_cls/losses/0: 0.6794 - BertMultiTaskTop/toy_seq_tag/losses/0: 2.4584
print(eval_dict)
{'loss': 3.1377058029174805, 'mean_acc': 0.30000001192092896, 'toy_cls_acc': 0.6000000238418579, 'toy_seq_tag_acc': 0.0, 'BertMultiTaskTop/toy_cls/losses/0': 0.6793524026870728, 'BertMultiTaskTop/toy_seq_tag/losses/0': 2.458353281021118}
# predict
fake_inputs = ['this is a test'.split(' ') for _ in range(10)]
pred, model = predict_bert_multitask(
    problem=problem,
    inputs=fake_inputs, model_dir=model.params.ckpt_dir,
    problem_type_dict=problem_type_dict,
    processing_fn_dict=processing_fn_dict, return_model=True)
2021-06-24 16:08:26.800 | INFO     | m3tl.utils:set_phase:478 - Setting phase to infer
2021-06-24 16:08:26.801 | INFO     | m3tl.base_params:register_multiple_problems:543 - Adding new problem toy_cls, problem type: cls
2021-06-24 16:08:26.802 | INFO     | m3tl.base_params:register_multiple_problems:543 - Adding new problem toy_seq_tag, problem type: seq_tag
2021-06-24 16:08:26.822 | INFO     | m3tl.run_bert_multitask:predict_bert_multitask:464 - Checkpoint dir: models/toy_cls_toy_seq_tag_ckpt
2021-06-24 16:08:29.839 | DEBUG    | m3tl.bert_preprocessing.create_bert_features:_create_multimodal_bert_features:514 - text: ['this', 'is', 'a', 'test']
2021-06-24 16:08:29.840 | DEBUG    | m3tl.bert_preprocessing.create_bert_features:_create_multimodal_bert_features:514 - text_modal_type: text
2021-06-24 16:08:29.841 | DEBUG    | m3tl.bert_preprocessing.create_bert_features:_create_multimodal_bert_features:519 - text_input_ids: [101, 8554, 8310, 143, 10060, 102]
2021-06-24 16:08:29.841 | DEBUG    | m3tl.bert_preprocessing.create_bert_features:_create_multimodal_bert_features:519 - text_mask: [1, 1, 1, 1, 1, 1]
2021-06-24 16:08:29.841 | DEBUG    | m3tl.bert_preprocessing.create_bert_features:_create_multimodal_bert_features:519 - text_segment_ids: [0, 0, 0, 0, 0, 0]
2021-06-24 16:08:29.891 | DEBUG    | m3tl.bert_preprocessing.create_bert_features:_create_multimodal_bert_features:514 - text: ['this', 'is', 'a', 'test']
2021-06-24 16:08:29.892 | DEBUG    | m3tl.bert_preprocessing.create_bert_features:_create_multimodal_bert_features:514 - text_modal_type: text
2021-06-24 16:08:29.892 | DEBUG    | m3tl.bert_preprocessing.create_bert_features:_create_multimodal_bert_features:519 - text_input_ids: [101, 8554, 8310, 143, 10060, 102]
2021-06-24 16:08:29.893 | DEBUG    | m3tl.bert_preprocessing.create_bert_features:_create_multimodal_bert_features:519 - text_mask: [1, 1, 1, 1, 1, 1]
2021-06-24 16:08:29.894 | DEBUG    | m3tl.bert_preprocessing.create_bert_features:_create_multimodal_bert_features:519 - text_segment_ids: [0, 0, 0, 0, 0, 0]
2021-06-24 16:08:29.914 | DEBUG    | m3tl.bert_preprocessing.create_bert_features:_create_multimodal_bert_features:514 - text: ['this', 'is', 'a', 'test']
2021-06-24 16:08:29.915 | DEBUG    | m3tl.bert_preprocessing.create_bert_features:_create_multimodal_bert_features:514 - text_modal_type: text
2021-06-24 16:08:29.915 | DEBUG    | m3tl.bert_preprocessing.create_bert_features:_create_multimodal_bert_features:519 - text_input_ids: [101, 8554, 8310, 143, 10060, 102]
2021-06-24 16:08:29.916 | DEBUG    | m3tl.bert_preprocessing.create_bert_features:_create_multimodal_bert_features:519 - text_mask: [1, 1, 1, 1, 1, 1]
2021-06-24 16:08:29.916 | DEBUG    | m3tl.bert_preprocessing.create_bert_features:_create_multimodal_bert_features:519 - text_segment_ids: [0, 0, 0, 0, 0, 0]
2021-06-24 16:08:30.835 | CRITICAL | m3tl.embedding_layer.base:__init__:58 - Modal Type id mapping: 
 {
    "text": 0
}
2021-06-24 16:08:30.907 | INFO     | m3tl.utils:set_phase:478 - Setting phase to infer
The parameters `output_attentions`, `output_hidden_states` and `use_cache` cannot be updated when calling a model.They have to be set to True/False in the config object (i.e.: `config=XConfig.from_pretrained('name', output_attentions=True)`).
The parameter `return_dict` cannot be set in graph mode and will always be set to `True`.
2021-06-24 16:08:33.300 | INFO     | m3tl.utils:set_phase:478 - Setting phase to infer
The parameters `output_attentions`, `output_hidden_states` and `use_cache` cannot be updated when calling a model.They have to be set to True/False in the config object (i.e.: `config=XConfig.from_pretrained('name', output_attentions=True)`).
The parameter `return_dict` cannot be set in graph mode and will always be set to `True`.
2021-06-24 16:08:39.912 | DEBUG    | m3tl.bert_preprocessing.create_bert_features:_create_multimodal_bert_features:514 - text: ['this', 'is', 'a', 'test']
2021-06-24 16:08:39.913 | DEBUG    | m3tl.bert_preprocessing.create_bert_features:_create_multimodal_bert_features:514 - text_modal_type: text
2021-06-24 16:08:39.913 | DEBUG    | m3tl.bert_preprocessing.create_bert_features:_create_multimodal_bert_features:519 - text_input_ids: [101, 8554, 8310, 143, 10060, 102]
2021-06-24 16:08:39.914 | DEBUG    | m3tl.bert_preprocessing.create_bert_features:_create_multimodal_bert_features:519 - text_mask: [1, 1, 1, 1, 1, 1]
2021-06-24 16:08:39.914 | DEBUG    | m3tl.bert_preprocessing.create_bert_features:_create_multimodal_bert_features:519 - text_segment_ids: [0, 0, 0, 0, 0, 0]

pred is a dictionary mapping each problem name to an array of predicted probability distributions.

for problem_name, prob_array in pred.items():
    print(f'{problem_name} - {prob_array.shape}')
toy_cls - (10, 2)
toy_seq_tag - (10, 7, 5)
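To turn these arrays into predicted labels, take the argmax over the last axis, which appears to index the classes (2 for toy_cls, 5 for toy_seq_tag). This sketch uses random arrays with the same shapes instead of a real pred, and the ['a', 'b'] id-to-label list is an assumption based on the toy targets; the real mapping lives in m3tl's label encoder.

```python
import numpy as np

# Random stand-ins with the shapes printed above, normalized so each
# distribution sums to 1 like a real probability array.
rng = np.random.default_rng(0)
raw = {'toy_cls': rng.random((10, 2)), 'toy_seq_tag': rng.random((10, 7, 5))}
pred = {name: x / x.sum(axis=-1, keepdims=True) for name, x in raw.items()}

# Hypothetical id -> label list for toy_cls.
cls_labels = ['a', 'b']

# Assuming the last axis indexes classes, argmax over it gives predicted ids.
pred_ids = {name: probs.argmax(axis=-1) for name, probs in pred.items()}
toy_cls_pred = [cls_labels[i] for i in pred_ids['toy_cls']]

print(pred_ids['toy_cls'].shape)      # (10,)
print(pred_ids['toy_seq_tag'].shape)  # (10, 7)
```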

Use Different Models

By default, BERT is used as the base model. Thanks to the transformers library, it's easy to switch to other state-of-the-art transformer models: set a few configuration fields on a Params object and pass it to the train function as the params argument.

# change model to distilbert-base-uncased
from m3tl.params import Params
params = Params()
# specify model and its loading module
params.transformer_model_name = 'distilbert-base-uncased'
params.transformer_model_loading = 'TFDistilBertModel'
# specify tokenizer and its loading module
params.transformer_tokenizer_name = 'distilbert-base-uncased'
params.transformer_tokenizer_loading = 'DistilBertTokenizer'
# specify config and its loading module
params.transformer_config_name = 'distilbert-base-uncased'
params.transformer_config_loading = 'DistilBertConfig'
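If you switch checkpoints often, the six settings above can be bundled into one call. use_transformer is a hypothetical helper, not an m3tl function, and SimpleNamespace stands in for Params so the sketch runs without the library. It assumes the model, tokenizer and config all come from the same checkpoint name, as in the DistilBERT example.

```python
from types import SimpleNamespace


def use_transformer(params, checkpoint, model_cls, tokenizer_cls, config_cls):
    """Hypothetical helper: set the six transformer fields in one call."""
    # Model and its loading class.
    params.transformer_model_name = checkpoint
    params.transformer_model_loading = model_cls
    # Tokenizer and its loading class.
    params.transformer_tokenizer_name = checkpoint
    params.transformer_tokenizer_loading = tokenizer_cls
    # Config and its loading class.
    params.transformer_config_name = checkpoint
    params.transformer_config_loading = config_cls
    return params


# SimpleNamespace stands in for m3tl's Params object.
params = use_transformer(
    SimpleNamespace(), 'distilbert-base-uncased',
    'TFDistilBertModel', 'DistilBertTokenizer', 'DistilBertConfig')
```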

Besides the "body" model, we can also choose the MTL model. The default is hard parameter sharing, but several other MTL models are implemented. To see what's available, run:

import json
print(json.dumps(params.list_available_mtl_setup(), indent=4))
{
    "available_mtl_model": [
        "basic",
        "mmoe"
    ],
    "available_problem_sampling_strategy": [],
    "available_loss_combination_strategy": [
        "sum"
    ],
    "available_gradient_surgery": []
}
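For intuition about what "mmoe" adds over hard parameter sharing, here is a generic NumPy forward pass of a multi-gate mixture-of-experts layer, assuming that is what "mmoe" refers to. It is not m3tl's MMoE implementation: each expert here is a single linear layer with ReLU, and each task has its own softmax gate that mixes the expert outputs. With a single expert, every gate weight is 1, so all tasks share one representation, which is hard parameter sharing.

```python
import numpy as np


def softmax(x, axis=-1):
    # Numerically stable softmax.
    z = x - x.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


def mmoe_forward(x, expert_weights, gate_weights):
    """Generic MMoE forward pass (sketch).

    x:              (batch, d_in) shared features from the body model
    expert_weights: (n_experts, d_in, d_expert), one linear+ReLU expert each
    gate_weights:   list of (d_in, n_experts), one gate per task
    returns:        list of (batch, d_expert), one mixed representation per task
    """
    # Every expert sees the same input: (n_experts, batch, d_expert).
    experts = np.maximum(np.einsum('bi,eio->ebo', x, expert_weights), 0.0)
    outputs = []
    for w_gate in gate_weights:
        gate = softmax(x @ w_gate)  # (batch, n_experts), rows sum to 1
        # Task-specific weighted sum over experts.
        outputs.append(np.einsum('be,ebo->bo', gate, experts))
    return outputs


rng = np.random.default_rng(0)
batch, d_in, d_expert, n_experts, n_tasks = 4, 8, 6, 3, 2
x = rng.normal(size=(batch, d_in))
expert_weights = rng.normal(size=(n_experts, d_in, d_expert))
gate_weights = [rng.normal(size=(d_in, n_experts)) for _ in range(n_tasks)]
task_outputs = mmoe_forward(x, expert_weights, gate_weights)
```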
# train model with mmoe
params.assign_mtl_model('mmoe')
model = train_bert_multitask(
    problem=problem,
    num_epochs=1,
    problem_type_dict=problem_type_dict,
    processing_fn_dict=processing_fn_dict,
    continue_training=False,
    params=params # pass params
)
2021-06-24 16:08:41.917 | INFO     | m3tl.base_params:register_multiple_problems:543 - Adding new problem toy_cls, problem type: cls
2021-06-24 16:08:41.918 | INFO     | m3tl.base_params:register_multiple_problems:543 - Adding new problem toy_seq_tag, problem type: seq_tag
2021-06-24 16:08:41.919 | WARNING  | m3tl.base_params:prepare_dir:363 - bert_config not exists. will load model from huggingface checkpoint.
2021-06-24 16:08:44.124 | INFO     | m3tl.input_fn:train_eval_input_fn:59 - sampling weights: 
2021-06-24 16:08:44.125 | INFO     | m3tl.input_fn:train_eval_input_fn:60 - {
    "toy_cls_toy_seq_tag": 1.0
}
2021-06-24 16:08:44.279 | INFO     | m3tl.input_fn:train_eval_input_fn:59 - sampling weights: 
2021-06-24 16:08:44.280 | INFO     | m3tl.input_fn:train_eval_input_fn:60 - {
    "toy_cls_toy_seq_tag": 1.0
}
2021-06-24 16:08:44.467 | CRITICAL | m3tl.base_params:update_train_steps:456 - Updating train_steps to 1
2021-06-24 16:08:44.589 | INFO     | m3tl.input_fn:train_eval_input_fn:59 - sampling weights: 
2021-06-24 16:08:44.589 | INFO     | m3tl.input_fn:train_eval_input_fn:60 - {
    "toy_cls_toy_seq_tag": 1.0
}
WARNING:tensorflow:There are non-GPU devices in `tf.distribute.Strategy`, not using nccl allreduce.
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:CPU:0',)
Some layers from the model checkpoint at distilbert-base-uncased were not used when initializing TFDistilBertModel: ['vocab_transform', 'vocab_projector', 'vocab_layer_norm', 'activation_13']
- This IS expected if you are initializing TFDistilBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing TFDistilBertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
All the layers of TFDistilBertModel were initialized from the model checkpoint at distilbert-base-uncased.
If your task is similar to the task the model of the checkpoint was trained on, you can already use TFDistilBertModel for predictions without further training.
2021-06-24 16:08:47.617 | CRITICAL | m3tl.embedding_layer.base:__init__:58 - Modal Type id mapping: 
 {
    "text": 0
}
2021-06-24 16:08:47.695 | INFO     | m3tl.utils:set_phase:478 - Setting phase to infer
2021-06-24 16:08:47.702 | CRITICAL | m3tl.model_fn:compile:271 - Initial lr: 0.0
2021-06-24 16:08:47.703 | CRITICAL | m3tl.model_fn:compile:272 - Train steps: 1
2021-06-24 16:08:47.703 | CRITICAL | m3tl.model_fn:compile:273 - Warmup steps: 0
2021-06-24 16:08:47.874 | INFO     | m3tl.utils:set_phase:478 - Setting phase to train
The parameters `output_attentions`, `output_hidden_states` and `use_cache` cannot be updated when calling a model.They have to be set to True/False in the config object (i.e.: `config=XConfig.from_pretrained('name', output_attentions=True)`).
The parameter `return_dict` cannot be set in graph mode and will always be set to `True`.
2021-06-24 16:08:55.282 | INFO     | m3tl.utils:set_phase:478 - Setting phase to train
1/1 [==============================] - ETA: 0s - mean_acc: 0.7434 - toy_cls_acc: 0.4000 - toy_seq_tag_acc: 0.2714 - BertMultiTaskTop/toy_cls/losses/0: 0.6950 - BertMultiTaskTop/toy_seq_tag/losses/0: 1.6072
2021-06-24 16:09:05.034 | INFO     | m3tl.utils:set_phase:478 - Setting phase to eval
The parameters `output_attentions`, `output_hidden_states` and `use_cache` cannot be updated when calling a model.They have to be set to True/False in the config object (i.e.: `config=XConfig.from_pretrained('name', output_attentions=True)`).
The parameter `return_dict` cannot be set in graph mode and will always be set to `True`.
WARNING:tensorflow:Your input ran out of data; interrupting training. Make sure that your dataset or generator can generate at least `steps_per_epoch * epochs` batches (in this case, 1000 batches). You may need to use the repeat() function when building your dataset.
1/1 [==============================] - 22s 22s/step - mean_acc: 0.7434 - toy_cls_acc: 0.4000 - toy_seq_tag_acc: 0.2714 - BertMultiTaskTop/toy_cls/losses/0: 0.6950 - BertMultiTaskTop/toy_seq_tag/losses/0: 1.6072 - val_loss: 2.3018 - val_mean_acc: 0.3429 - val_toy_cls_acc: 0.4000 - val_toy_seq_tag_acc: 0.2857 - val_BertMultiTaskTop/toy_cls/losses/0: 0.6948 - val_BertMultiTaskTop/toy_seq_tag/losses/0: 1.6070
Model: "BertMultiTask"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
BertMultiTaskBody (BertMulti multiple                  66363648  
_________________________________________________________________
m_mo_e (MMoE)                multiple                  799760    
_________________________________________________________________
BertMultiTaskTop (BertMultiT multiple                  907       
_________________________________________________________________
sum_loss_combination_3 (SumL multiple                  0         
=================================================================
Total params: 67,164,317
Trainable params: 67,164,311
Non-trainable params: 6
_________________________________________________________________

Write More Flexible Preprocessing Functions

The simplest preprocessing function returns a tuple of lists (inputs and labels), as shown above. However, inputs can get complicated when doing multi-modal multi-task learning. In this case, we can use a dictionary to store the data, with some special keys:

  • "inputs_" and "labels_" prefixes. We still divide the preprocessing output into inputs and labels. By adding the "inputs_" and "labels_" prefixes to the dictionary keys, the module will handle them correctly in train, eval and predict.
  • "_modal_type" and "_modal_info" suffixes. Adding these suffixes indicates the modal type of an input. If they're not provided, the module will try to infer the correct information from the data.
  • i. If specified, this key will be used to join problems chained with &. It is required if any problems are chained with &.

Example:

import pprint
from m3tl.predefined_problems.test_data import generate_fake_data
gen = generate_fake_data(output_format='gen_dict')
pprint.pprint(next(gen))
{'inputs_array': array([0.89512351, 0.89110354, 0.70502249, 0.23868364, 0.40018975,
       0.52657185, 0.87574078, 0.08114504, 0.93732932, 0.24289513]),
 'inputs_cate': 0,
 'inputs_cate_modal_info': 1,
 'inputs_cate_modal_type': 'category',
 'inputs_record_id': 0,
 'inputs_text': 'this is a test',
 'labels': 'a'}
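The key conventions can be sketched in plain Python on the record above. split_record and infer_modal_type are hypothetical helpers, not m3tl functions, and the array is shortened. The inference fallback mirrors the modal types reported in the debug log of the next example (text for strings, category for integers, array otherwise); m3tl's actual inference rules may differ. Note that the sample uses a bare labels key, so the sketch accepts both labels and labels_-prefixed keys.

```python
def infer_modal_type(value):
    """Hypothetical fallback mirroring the modal types seen in the debug log."""
    if isinstance(value, str):
        return 'text'
    if isinstance(value, int):
        return 'category'
    return 'array'


def split_record(record):
    """Split a preprocessing dict into inputs, labels and per-input modal info."""
    inputs, labels, modal = {}, {}, {}
    for key, value in record.items():
        if key.startswith('inputs_'):
            name = key[len('inputs_'):]
            for suffix in ('_modal_type', '_modal_info'):
                if name.endswith(suffix):
                    # e.g. 'cate_modal_type' -> modal['cate']['modal_type']
                    modal.setdefault(name[:-len(suffix)], {})[suffix[1:]] = value
                    break
            else:
                inputs[name] = value
        elif key == 'labels' or key.startswith('labels_'):
            labels[key] = value
    # Infer a modal type for every input that did not declare one.
    for name, value in inputs.items():
        modal.setdefault(name, {}).setdefault('modal_type', infer_modal_type(value))
    return inputs, labels, modal


# The record printed above, with the array shortened.
record = {
    'inputs_array': [0.89512351, 0.89110354, 0.70502249],
    'inputs_cate': 0,
    'inputs_cate_modal_info': 1,
    'inputs_cate_modal_type': 'category',
    'inputs_record_id': 0,
    'inputs_text': 'this is a test',
    'labels': 'a',
}
inputs, labels, modal = split_record(record)
```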

Local Preprocessing

Your preprocessing function can return either a list of dictionaries or a generator of dictionaries.

from m3tl.utils import get_or_make_label_encoder
from m3tl.special_tokens import TRAIN
import inspect
params.num_cpus = 1
@preprocessing_fn
def toy_cls(params: Params, mode: str):
    # IMPORTANT: create the label encoder for this problem before returning the data
    get_or_make_label_encoder(
        params=params,
        problem=inspect.currentframe().f_code.co_name, # current function name
        mode=mode,
        label_list=['a', 'b'],
        overwrite=True
    )
    return generate_fake_data(output_format='gen_dict')

params.register_problem(problem_name='toy_cls', problem_type='cls', processing_fn=toy_cls)

# then you can call the preproc function and take a look at the result
pprint.pprint(next(toy_cls(params, TRAIN)))
2021-06-24 16:09:15.822 | DEBUG    | m3tl.bert_preprocessing.create_bert_features:_create_multimodal_bert_features:514 - record_id: 0
2021-06-24 16:09:15.823 | DEBUG    | m3tl.bert_preprocessing.create_bert_features:_create_multimodal_bert_features:514 - text: this is a test
2021-06-24 16:09:15.824 | DEBUG    | m3tl.bert_preprocessing.create_bert_features:_create_multimodal_bert_features:514 - array: [0.96637881 0.49298023 0.38897724 0.36710049 0.36735467 0.28640803
 0.39647259 0.30369951 0.35238779 0.05860911]
2021-06-24 16:09:15.825 | DEBUG    | m3tl.bert_preprocessing.create_bert_features:_create_multimodal_bert_features:514 - cate: 0
2021-06-24 16:09:15.825 | DEBUG    | m3tl.bert_preprocessing.create_bert_features:_create_multimodal_bert_features:514 - cate_modal_type: category
2021-06-24 16:09:15.825 | DEBUG    | m3tl.bert_preprocessing.create_bert_features:_create_multimodal_bert_features:514 - cate_modal_info: 1
2021-06-24 16:09:15.826 | DEBUG    | m3tl.bert_preprocessing.create_bert_features:_create_multimodal_bert_features:514 - record_id_modal_type: category
2021-06-24 16:09:15.826 | DEBUG    | m3tl.bert_preprocessing.create_bert_features:_create_multimodal_bert_features:514 - text_modal_type: text
2021-06-24 16:09:15.827 | DEBUG    | m3tl.bert_preprocessing.create_bert_features:_create_multimodal_bert_features:514 - array_modal_type: array
2021-06-24 16:09:15.827 | DEBUG    | m3tl.bert_preprocessing.create_bert_features:_create_multimodal_bert_features:519 - record_id: 0
2021-06-24 16:09:15.827 | DEBUG    | m3tl.bert_preprocessing.create_bert_features:_create_multimodal_bert_features:519 - text_input_ids: [101, 2023, 2003, 1037, 3231, 102]
2021-06-24 16:09:15.828 | DEBUG    | m3tl.bert_preprocessing.create_bert_features:_create_multimodal_bert_features:519 - text_mask: [1, 1, 1, 1, 1, 1]
2021-06-24 16:09:15.828 | DEBUG    | m3tl.bert_preprocessing.create_bert_features:_create_multimodal_bert_features:519 - text_segment_ids: [0, 0, 0, 0, 0, 0]
2021-06-24 16:09:15.829 | DEBUG    | m3tl.bert_preprocessing.create_bert_features:_create_multimodal_bert_features:519 - toy_cls_label_ids: 0
2021-06-24 16:09:15.829 | DEBUG    | m3tl.bert_preprocessing.create_bert_features:_create_multimodal_bert_features:519 - array_input_ids: [[0.96637881 0.49298023 0.38897724 0.36710049 0.36735467 0.28640803
  0.39647259 0.30369951 0.35238779 0.05860911]]
2021-06-24 16:09:15.830 | DEBUG    | m3tl.bert_preprocessing.create_bert_features:_create_multimodal_bert_features:519 - array_mask: [1]
2021-06-24 16:09:15.830 | DEBUG    | m3tl.bert_preprocessing.create_bert_features:_create_multimodal_bert_features:519 - array_segment_ids: [0]
2021-06-24 16:09:15.831 | DEBUG    | m3tl.bert_preprocessing.create_bert_features:_create_multimodal_bert_features:519 - cate_input_ids: [0]
2021-06-24 16:09:15.831 | DEBUG    | m3tl.bert_preprocessing.create_bert_features:_create_multimodal_bert_features:519 - cate_mask: [1]
2021-06-24 16:09:15.832 | DEBUG    | m3tl.bert_preprocessing.create_bert_features:_create_multimodal_bert_features:519 - cate_segment_ids: [0]
{'array_input_ids': array([[0.96637881, 0.49298023, 0.38897724, 0.36710049, 0.36735467,
        0.28640803, 0.39647259, 0.30369951, 0.35238779, 0.05860911]]),
 'array_mask': [1],
 'array_segment_ids': array([0], dtype=int32),
 'cate_input_ids': array([0]),
 'cate_mask': [1],
 'cate_segment_ids': array([0], dtype=int32),
 'record_id': 0,
 'text_input_ids': [101, 2023, 2003, 1037, 3231, 102],
 'text_mask': [1, 1, 1, 1, 1, 1],
 'text_segment_ids': [0, 0, 0, 0, 0, 0],
 'toy_cls_label_ids': 0}
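The problem= argument above relies on inspect.currentframe().f_code.co_name. This returns the name of the function whose body is currently running, so the label encoder is stored under the problem name only as long as the function is named after the problem. The plain-Python sketch below shows that wrapping the function does not change the name seen inside its body. It uses a hypothetical pass-through decorator, passthrough_decorator, as a stand-in for preprocessing_fn, whose internals are not shown here.

```python
import functools
import inspect


def passthrough_decorator(fn):
    # Hypothetical stand-in for a decorator such as preprocessing_fn.
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        return fn(*args, **kwargs)
    return wrapper


@passthrough_decorator
def toy_cls():
    # currentframe() is the frame of toy_cls's own body, not the wrapper's frame,
    # so co_name is the name written in the def statement.
    return inspect.currentframe().f_code.co_name


print(toy_cls())  # toy_cls
```

If you rename the function without renaming the problem, the label encoder would likely be saved under the wrong name.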

PySpark Preprocessing (Experimental)

If your data is too large to process locally, you can instead return a PySpark RDD from your preprocessing function.

If two problems are chained with & but share only part of their inputs, the preprocessing functions must return RDDs.

from m3tl.utils import set_is_pyspark
import tempfile
set_is_pyspark(True)

@preprocessing_fn
def toy_cls(params: Params, mode: str):
    return generate_fake_data(output_format='rdd')

params.register_problem(problem_name='toy_cls', problem_type='cls', processing_fn=toy_cls)

# set pyspark output path
params.pyspark_output_path = tempfile.mkdtemp()

# then you can call the preproc function and take a look at the result
toy_cls_rdd = toy_cls(params, TRAIN)
pprint.pprint(toy_cls_rdd.collect()[0])
{'array_input_ids': array([[0.00258592, 0.20750642, 0.25051955, 0.85366103, 0.93457556,
        0.05154129, 0.48336023, 0.36393742, 0.40964549, 0.77414554]]),
 'array_mask': [1],
 'array_segment_ids': array([0], dtype=int32),
 'cate_input_ids': array([0]),
 'cate_mask': [1],
 'cate_segment_ids': array([0], dtype=int32),
 'record_id': 0,
 'text_input_ids': [101, 2023, 2003, 1037, 3231, 102],
 'text_mask': [1, 1, 1, 1, 1, 1],
 'text_segment_ids': [0, 0, 0, 0, 0, 0],
 'toy_cls_label_ids': 0}

What Happened?

The inputs returned by the preprocessing function are tokenized with a transformers tokenizer, which is configurable as shown earlier. The labels are encoded (or tokenized, if the target is text) as scalars or NumPy arrays. The encoded inputs and targets are then serialized and written as TFRecord files. Note that existing TFRecord files will NOT be overwritten, even if you run the code again. To change the data, you need to manually remove the TFRecord directory. The default directory is ./tmp/{problem_name}.
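Because existing TFRecord files are not overwritten, regenerating the data means deleting the folder first. The minimal sketch below uses only the standard library. clear_tfrecord_dir is a hypothetical helper, not part of m3tl. In practice tmp_dir would presumably be params.tmp_file_dir (by default ./tmp). folder_name would be the problem name, or the joined folder name such as toy_cls_toy_seq_tag for problems chained with &.

```python
import os
import shutil


def clear_tfrecord_dir(tmp_dir: str, folder_name: str) -> bool:
    """Delete tmp_dir/folder_name so the TFRecord files get rebuilt on the next run.

    Returns True if a folder was removed, False if there was nothing to remove.
    """
    target = os.path.join(tmp_dir, folder_name)
    if not os.path.isdir(target):
        return False
    shutil.rmtree(target)  # removes the folder and every file inside it
    return True


# Hypothetical usage: clear_tfrecord_dir(params.tmp_file_dir, 'toy_cls_toy_seq_tag')
```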

Once the TFRecord files are created, you can check the feature info by opening the JSON files in the corresponding directory.

First, we make sure the TFRecord files are created by calling train_eval_input_fn.

from m3tl.input_fn import train_eval_input_fn

dataset = train_eval_input_fn(params)

Below is the TFRecord directory tree.

tmp/
    toy_cls_toy_seq_tag/
        eval_00000.tfrecord
        train_feature_desc.json
        problem_info.txt
        eval_feature_desc.json
        train_00000.tfrecord
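To produce a listing like this for your own runs, a short standard-library sketch is enough. print_tree is a hypothetical helper. It sorts entries alphabetically, so its ordering can differ from the listing above.

```python
import os
from typing import List


def print_tree(root: str, indent: str = '    ') -> List[str]:
    """Print and return an indented listing of root; directories end with '/'."""
    lines = []
    root = os.path.abspath(root)
    for dirpath, dirnames, filenames in os.walk(root):
        dirnames.sort()  # sorting in place also fixes the order os.walk descends
        rel = os.path.relpath(dirpath, root)
        depth = 0 if rel == '.' else rel.count(os.sep) + 1
        lines.append(indent * depth + os.path.basename(dirpath) + '/')
        for name in sorted(filenames):
            lines.append(indent * (depth + 1) + name)
    for line in lines:
        print(line)
    return lines


# Hypothetical usage: print_tree(params.tmp_file_dir)
```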

We can take a look at the JSON file.

import json
import os

# problems chained by & share a single TFRecord folder
json_path = os.path.join(params.tmp_file_dir, 'toy_cls_toy_seq_tag', 'train_feature_desc.json')
with open(json_path, 'r', encoding='utf8') as f:
    print(json.dumps(json.load(f), indent=4))
{
    "text_input_ids": "int64",
    "text_input_ids_shape_value": [
        null
    ],
    "text_input_ids_shape": "int64",
    "text_mask": "int64",
    "text_mask_shape_value": [
        null
    ],
    "text_mask_shape": "int64",
    "text_segment_ids": "int64",
    "text_segment_ids_shape_value": [
        null
    ],
    "text_segment_ids_shape": "int64",
    "toy_cls_label_ids": "int64",
    "toy_cls_label_ids_shape": "int64",
    "toy_cls_label_ids_shape_value": [],
    "toy_seq_tag_label_ids": "int64",
    "toy_seq_tag_label_ids_shape_value": [
        null
    ],
    "toy_seq_tag_label_ids_shape": "int64"
}
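Judging from the file above, each feature has a dtype entry, a <name>_shape entry, and a <name>_shape_value list. A null in that list appears to mark a variable-length dimension, and an empty list a scalar (such as the cls label). That reading is our assumption from this output, not documented m3tl behavior. Under that assumption, a small standard-library sketch can condense the file into one line per feature. summarize_feature_desc is hypothetical.

```python
import json
from typing import Dict, List, Optional, Tuple


def summarize_feature_desc(desc: Dict) -> Dict[str, Tuple[str, List[Optional[int]]]]:
    """Map each base feature name to (dtype, shape), skipping the *_shape helper keys.

    Assumption: a None (JSON null) dimension means variable length, [] means scalar.
    """
    summary = {}
    for key, dtype in desc.items():
        if key.endswith('_shape') or key.endswith('_shape_value'):
            continue  # helper entries, folded into their base feature below
        summary[key] = (dtype, desc.get(key + '_shape_value', []))
    return summary


# A trimmed copy of the file printed above.
desc = json.loads('''{
    "text_input_ids": "int64", "text_input_ids_shape_value": [null], "text_input_ids_shape": "int64",
    "toy_cls_label_ids": "int64", "toy_cls_label_ids_shape": "int64", "toy_cls_label_ids_shape_value": []
}''')
for name, (dtype, shape) in summarize_feature_desc(desc).items():
    kind = 'scalar' if not shape else 'dims ' + str(['?' if d is None else d for d in shape])
    print(f'{name}: {dtype}, {kind}')
```

With the real file, pass json.load(f) instead of the inline sample.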