import tensorflow as tf
from m3tl.test_base import TestBase
from m3tl.input_fn import train_eval_input_fn
# Shared m3tl test fixture; its `params` object configures everything below.
test_base = TestBase()
params = test_base.params

# Hidden size taken from the BERT config stored on `params`.
# NOTE(review): presumably used later to size layers built on top of the encoder — confirm.
hidden_dim = params.bert_config.hidden_size

# Build the train/eval input pipeline from the same params and pull its first
# element as NumPy values via `as_numpy_iterator()`.
# NOTE(review): assumes `train_eval_input_fn` returns a `tf.data.Dataset` that is
# non-empty; `next(...)` raises StopIteration if the pipeline yields nothing.
train_dataset = train_eval_input_fn(params=params)
one_batch = next(train_dataset.as_numpy_iterator())

Imports and utils

Top Layer

class MaskLM [source]

MaskLM(*args, **kwargs) :: Model

Multimodal masked language model (MLM) top layer.

Get or make label encoder function

masklm_get_or_make_label_encoder_fn [source]

masklm_get_or_make_label_encoder_fn(params:BaseParams, problem:str, mode:str, label_list:List[str], *args, **kwargs)

Label handling function

masklm_label_handling_fn [source]

masklm_label_handling_fn(target, label_encoder=None, tokenizer=None, decoding_length=None, *args, **kwargs)