Battle-Tested LLM Training: From Dataset to Data Iterator

If you find an interesting dataset (these days often from Hugging Face or TFDS) and you’d like to use it for LLM training, this post is for you! Specifically, I’ll explain the process that gradually turns a Hugging Face dataset into an iterator that’s ready to feed batches of data to model training. Conceptually, it takes four steps.

To make it concrete, I’ll use MaxText’s make_hf_iterator as the reference code and openwebtext-10k as the input dataset.

load raw dataset

First, let’s load the raw openwebtext-10k dataset. With streaming=True, as shown below, the data files are not downloaded up front. Instead, examples are streamed progressively as we iterate over the dataset.

from datasets import load_dataset

dataset = load_dataset("stas/openwebtext-10k", split="train", streaming=True)
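To see what streaming buys us, here is a minimal pure-Python sketch of lazy iteration. It assumes a hypothetical in-memory list of JSON lines standing in for remote data files, and it only illustrates the idea; it is not how the datasets library implements streaming internally. Records are parsed only as the iterator is consumed, so taking the first few examples touches only those records.

```python
import itertools
import json

# Hypothetical stand-in for remote data files: one JSON record per line.
RAW_LINES = [json.dumps({"text": f"document {i}"}) for i in range(10_000)]

parsed_count = 0  # how many records have actually been read so far


def stream_records(lines):
    """Yield records one at a time instead of loading them all up front."""
    global parsed_count
    for line in lines:
        parsed_count += 1
        yield json.loads(line)


# Creating the generator reads nothing; records are parsed on demand.
stream = stream_records(RAW_LINES)
first_three = list(itertools.islice(stream, 3))
print([r["text"] for r in first_three])  # ['document 0', 'document 1', 'document 2']
print(parsed_count)  # 3
```

With the datasets library itself, the analogous peek is next(iter(dataset)), which pulls only the first example from the stream.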

tokenize

At this stage, we need to tokenize the raw dataset’s text field and truncate each tokenized sequence to a predefined max_length. In practice, we first create a tokenizer via tokenizer_path, either from a local file or from the Hugging Face Hub. In the example below, we use the t5-small tokenizer, which is fetched directly from the Hugging Face Hub.

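import transformers

# Sets some constants
add_bos, add_eos, max_length = True, True, 512
tokenizer_path = "t5-small"
data_column_name = "text"

# Creates a tokenizer
tokenizer = transformers.AutoTokenizer.from_pretrained(
    tokenizer_path, add_bos_token=add_bos, add_eos_token=add_eos, model_max_length=max_length
)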

Tokenization is then done by mapping _input_pipeline_utils.tokenization over the dataset with the dataset’s map function. This function applies the tokenizer above to the data_column_name field of each example and truncates the result to max_length - 1 tokens (the value passed in fn_kwargs below).

from maxtext.MaxText.input_pipeline import _input_pipeline_utils

dataset = dataset.map(
    _input_pipeline_utils.tokenization, batched=True,
    fn_kwargs={"hf_tokenizer": tokenizer, "max_length": max_length - 1, "column_name": data_column_name},
)
# Post-tokenization: keeps only the token field and renames it back to data_column_name.
dataset = dataset.select_columns(["input_ids"]).rename_column("input_ids", data_column_name)
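To make the map step concrete without downloading a tokenizer, here is a self-contained sketch. It swaps t5-small for a hypothetical whitespace tokenizer (toy_tokenize, with made-up BOS/EOS ids) and mimics the contract of a batched map function: it receives a dict of columns, each a list of examples, and returns new columns. As with the truncation described above, every output is capped at max_length; in this sketch the special tokens count toward that cap, which is an assumption of the toy.

```python
BOS_ID, EOS_ID = 1, 2  # hypothetical special-token ids; 0 is kept free for padding


def toy_tokenize(text, vocab):
    """Hypothetical whitespace tokenizer; new words get ids starting at 3."""
    return [vocab.setdefault(word, len(vocab) + 3) for word in text.split()]


def toy_tokenization(batch, vocab, max_length, column_name, add_bos=True, add_eos=True):
    """Batched map function: a dict of lists in, a dict of new columns out."""
    budget = max(0, max_length - int(add_bos) - int(add_eos))  # room for real tokens
    input_ids = []
    for text in batch[column_name]:
        ids = toy_tokenize(text, vocab)[:budget]  # truncate the body
        input_ids.append(([BOS_ID] if add_bos else []) + ids + ([EOS_ID] if add_eos else []))
    return {"input_ids": input_ids}


vocab = {}
batch = {"text": ["the cat sat", "a much longer sentence that will be truncated"]}
out = toy_tokenization(batch, vocab, max_length=6, column_name="text")
print(out["input_ids"])  # [[1, 3, 4, 5, 2], [1, 6, 7, 8, 9, 2]]
```

Reserving id 0 matters later: the packing and shifting sketches below use 0 as the padding and fill value.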

transform: pack and shift

After tokenization, data examples become token sequences of varying lengths. Padding each of them to the full context window would waste compute, so to increase training efficiency we pack as many sequences as possible into each context window (max_length in the code). Here we use grain’s experimental packing API, PackAndBatchOperation.

In a multi-host setting, each host (i.e., process) gets an equal share of the global batch (say 512 examples). Since this input pipeline runs on every host in parallel, we batch at the host level, i.e., with a batch size of global_batch_size // jax.process_count().
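As a quick check on that arithmetic, here is a tiny sketch; per_host_batch_size is a hypothetical helper and the host counts are illustrative. It makes explicit the assumption that the global batch divides evenly across hosts, which plain floor division would not catch.

```python
def per_host_batch_size(global_batch_size: int, process_count: int) -> int:
    """Each host builds an equal slice of the global batch."""
    if global_batch_size % process_count != 0:
        raise ValueError(
            f"global batch {global_batch_size} is not divisible by {process_count} hosts"
        )
    return global_batch_size // process_count


for hosts in (1, 4, 16):
    print(hosts, per_host_batch_size(512, hosts))
# 1 512
# 4 128
# 16 32
```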

import grain.python as grain
import jax

# Sets some constants.
global_batch_size = 512

# Adds packing transformation.
transformations = []
# HFNormalizeFeatures makes two copies of `text` field: one is called
# `inputs` and the other `targets`.
transformations.append(_input_pipeline_utils.HFNormalizeFeatures(data_column_name))
transformations.append(
    grain.experimental.PackAndBatchOperation(
        # In multi-host setting, each host (i.e., process) gets an equal share
        # of the global batch size.
        # And this input pipeline runs at host-level in parallel, thus we want
        # host-level batch size here.
        batch_size=global_batch_size // jax.process_count(),
        length_struct={"inputs": max_length, "targets": max_length},
    )
)
# Post-packing: reformatting the output tuple into a flat dict.
transformations.append(_input_pipeline_utils.ReformatPacking())
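To build intuition for what packing produces, here is a simplified greedy first-fit packer written from scratch. It is not grain’s implementation and its placement policy may differ; it only illustrates placing variable-length sequences into fixed-length rows while recording, for every token, a segment id (which original sequence it came from, with 0 meaning padding) and its position within that sequence. Downstream code typically uses these to keep attention from crossing sequence boundaries. The function returns a triple of per-row lists, similar in spirit to the tuple that the post-packing step flattens into a dict. pack_first_fit is a hypothetical name.

```python
def pack_first_fit(sequences, row_length, num_rows, pad_id=0):
    """Greedy first-fit packing of token sequences into fixed-length rows.

    Returns ((tokens, segment_ids, positions), leftovers). Each of the three
    is a list of num_rows rows of length row_length. Segment ids restart at 1
    in every row; 0 marks padding. Sequences that fit in no row are returned
    as leftovers (a streaming packer would carry them into the next batch).
    """
    rows = [[] for _ in range(num_rows)]  # the sequences placed in each row
    leftovers = []
    for seq in sequences:
        assert len(seq) <= row_length, "truncate sequences before packing"
        for row in rows:
            if sum(len(s) for s in row) + len(seq) <= row_length:
                row.append(seq)  # first row with enough room wins
                break
        else:
            leftovers.append(seq)

    tokens, segment_ids, positions = [], [], []
    for row in rows:
        t, s, p = [], [], []
        for seg, seq in enumerate(row, start=1):
            t += seq
            s += [seg] * len(seq)
            p += list(range(len(seq)))
        pad = row_length - len(t)
        tokens.append(t + [pad_id] * pad)
        segment_ids.append(s + [0] * pad)
        positions.append(p + [0] * pad)
    return (tokens, segment_ids, positions), leftovers


seqs = [[5, 6, 7], [8, 9], [10, 11, 12, 13], [14]]
(tokens, segs, pos), rest = pack_first_fit(seqs, row_length=6, num_rows=2)
print(tokens)  # [[5, 6, 7, 8, 9, 14], [10, 11, 12, 13, 0, 0]]
print(segs)    # [[1, 1, 1, 2, 2, 3], [1, 1, 1, 1, 0, 0]]
print(pos)     # [[0, 1, 2, 0, 1, 0], [0, 1, 2, 3, 0, 0]]
```

Note how [14] lands in the first row: first-fit keeps scanning earlier rows for gaps, which is what lets packing fill most of each context window.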

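Finally, we shift the inputs field one token to the right to prepare it for teacher-forcing training: at each position, the model sees the preceding ground-truth tokens in inputs and is trained to predict the token at the same position in targets.

# Shifts `inputs` one token to the right along the sequence axis (axis=1).
transformations.append(_input_pipeline_utils.ShiftData(axis=1))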

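Here is a minimal pure-Python sketch of the shift on one packed row. Since inputs starts out as a copy of targets, shifting right by one position (inserting 0 at the front and dropping the last token) means inputs[i] holds the token that precedes targets[i]. For packed rows, the sketch also applies one common refinement, which is an assumption here rather than something stated above: the input at the first token of each segment is zeroed, so a sequence never conditions on the last token of the previous sequence in the same row. shift_right and shift_packed are hypothetical helpers, and targets_segmentation is a hypothetical field name for the segment ids.

```python
def shift_right(row, fill=0):
    """Shift a sequence one position to the right, dropping the last element."""
    return [fill] + row[:-1]


def shift_packed(example):
    """Build teacher-forcing inputs from a packed example.

    Expects "targets" and "targets_segmentation" lists of equal length and
    returns a copy with "inputs" set to the shifted targets, where the first
    token of every segment (and padding) gets 0 instead of the previous token.
    """
    targets = example["targets"]
    seg = example["targets_segmentation"]
    shifted = shift_right(targets)
    shifted_seg = shift_right(seg)
    # Keep a shifted token only if it comes from the same segment.
    inputs = [tok if s == prev_s else 0 for tok, s, prev_s in zip(shifted, seg, shifted_seg)]
    return {**example, "inputs": inputs}


packed = {"targets": [5, 6, 7, 8, 9, 14], "targets_segmentation": [1, 1, 1, 2, 2, 3]}
print(shift_right(packed["targets"]))  # [0, 5, 6, 7, 8, 9]
print(shift_packed(packed)["inputs"])  # [0, 5, 6, 0, 8, 0]
```

Reading the second output against targets: input 0 predicts 5, 5 predicts 6, 6 predicts 7, then 0 predicts 8 because 8 starts a new sequence.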
With all the transformations defined, we need to tell each host how to sample records from the dataset. The most common settings are the number of epochs (num_epochs), which shard of the dataset the current host should load (shard_options), and whether to shuffle (shuffle).

sampler = grain.IndexSampler(
    num_records=len(dataset), num_epochs=1, shuffle=False,
    shard_options=grain.ShardOptions(
        shard_index=dataloading_host_index, shard_count=dataloading_host_count, drop_remainder=False))
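To illustrate what shard_options controls, here is a sketch that splits record indices across hosts into contiguous, near-equal ranges, with drop_remainder deciding whether every host is trimmed to the same size. This is one reasonable scheme written for illustration; grain’s exact index assignment may differ, and shard_indices is a hypothetical helper.

```python
def shard_indices(num_records, shard_index, shard_count, drop_remainder=False):
    """Return the record indices that host `shard_index` should read.

    Records are split into `shard_count` contiguous ranges whose sizes differ
    by at most one. With drop_remainder=True every shard is trimmed to the
    size of the smallest one, so all hosts see the same number of records.
    """
    base, extra = divmod(num_records, shard_count)
    # The first `extra` shards each get one additional record.
    start = shard_index * base + min(shard_index, extra)
    size = base if drop_remainder else base + (1 if shard_index < extra else 0)
    return list(range(start, start + size))


for host in range(3):
    print(host, shard_indices(10, host, 3))
# 0 [0, 1, 2, 3]
# 1 [4, 5, 6]
# 2 [7, 8, 9]
```

In this sketch, drop_remainder=False loses no records but lets hosts see slightly different numbers of them; drop_remainder=True trades the last few records for equal shard sizes.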

put together

We put everything together with the grain.DataLoader API, which takes the tokenized dataset, the training transformations, and the sampler. Iterating over the returned dataloader (iter(dataloader)) yields the batches that the downstream training loop needs.

dataloader = grain.DataLoader(
    data_source=dataset, operations=transformations, sampler=sampler,
    read_options=grain.ReadOptions(num_threads=1, prefetch_buffer_size=128),
)

data_iter = iter(dataloader)
batch = next(data_iter)
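Conceptually, the data loader walks the sampler’s indices, reads each record from the data source, and runs the operations in order. The sketch below mimics that flow with plain Python generators so it can run without grain. toy_data_loader, normalize_features, and batch are hypothetical names; the real grain.DataLoader adds multi-threaded reads and prefetching (see read_options above) on top of this basic flow.

```python
from typing import Callable, Iterable, Iterator, Sequence


def toy_data_loader(
    data_source: Sequence[dict],
    indices: Iterable[int],
    operations: Sequence[Callable[[Iterator], Iterator]],
) -> Iterator:
    """Read records in sampler order, then apply each operation in turn."""
    records = (data_source[i] for i in indices)
    for op in operations:
        records = op(records)  # each op lazily wraps the previous iterator
    return records


def normalize_features(records):
    """Per-record map: copy `text` into `inputs` and `targets`."""
    for r in records:
        yield {"inputs": list(r["text"]), "targets": list(r["text"])}


def batch(size):
    """Return an op that groups consecutive records into lists of `size`."""
    def op(records):
        buf = []
        for r in records:
            buf.append(r)
            if len(buf) == size:
                yield buf
                buf = []
        if buf:  # keep the final partial batch
            yield buf
    return op


source = [{"text": [i, i + 1]} for i in range(5)]
loader = toy_data_loader(source, indices=[0, 2, 4, 1, 3], operations=[normalize_features, batch(2)])
data_iter = iter(loader)
first = next(data_iter)
print([r["inputs"] for r in first])  # [[0, 1], [2, 3]]
```

Because every stage is a generator, nothing is read or transformed until the training loop asks for the next batch.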

Final words

Feel free to run and fork input_pipeline_data2iter.ipynb if you’d like to run a complete version of the input pipeline. Note that the returned batch sits in host CPU memory, so it still has to be sharded across the TPU devices before it is fed to the pjitted train step. In MaxText, this is done by MultiHostDataLoadIterator. If you’d like to know the details, this previous post could be of interest.

