import tensorflow as tf
from m3tl.test_base import TestBase
from m3tl.input_fn import train_eval_input_fn
# Module-level fixtures: build the shared TestBase, then expose its params
# object and the BERT hidden size read from params.bert_config.
test_base = TestBase()
params = test_base.params
hidden_dim = params.bert_config.hidden_size

# Build the train/eval input pipeline from those params and pull its first
# element via the dataset's NumPy iterator, keeping it as `one_batch`.
train_dataset = train_eval_input_fn(params=params)
one_batch = next(iter(train_dataset.as_numpy_iterator()))