Tech ONTAP Blogs

DLIO: An Approach to Overcome Storage Benchmarking Challenges for Deep Learning Workloads

RodrigoNascimento
NetApp

In the first post of our series, we explored the AI/ML workflow through the lens of a Medallion Data Architecture. We explained our rationale for identifying the key stages of the pipeline to target for storage benchmarking.

 

In this post, we introduce DLIO, a benchmarking tool purpose-built to simulate the I/O patterns of Deep Learning (DL) workloads. We'll walk you through its capabilities and how it enables storage benchmarking without the need for any AI hardware.

 

Deep Learning I/O (DLIO)

 

DLIO is a benchmark tool that emulates the I/O patterns and behavior of deep learning applications [1a]. It was designed to mimic the AI/ML training process with the intent of measuring how fast data is served from storage to RAM.

 

During the training process, data is loaded in batches concurrently through multiple threads while accelerators execute training. After processing each batch, the accelerator triggers a request to the host, prompting the loading of another batch from storage. This iterative cycle guarantees uninterrupted data processing, contributing to the efficiency of the training process [1b].

 

Many new AI accelerators (e.g., GPUs, DPUs, TPUs, Cerebras) have been designed and deployed to speed up computation during training [2]. That hardware is not cheap, but the good news is that you don't need any of it to run DLIO and benchmark your storage solution for your AI/ML pipeline.

 

Deep learning frameworks like PyTorch and TensorFlow provide an abstraction called a data loader, which simplifies key aspects of data handling such as batching, shuffling, and parallel data loading.

 

When you iterate over a data loader instance, it triggers I/O operations - this is when the data loader opens files, reads samples, and prepares them for processing. Once the data is transferred to the GPU, the computation phase begins, including forward and backward propagation. Interestingly, during this computation phase, no I/O operations related to training occur.
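
To make that pattern concrete, here is a minimal PyTorch sketch; the synthetic in-memory dataset and the hyperparameters are made up for illustration and are not taken from DLIO.

# Minimal sketch of the data loader pattern described above (illustrative values).
import torch
from torch.utils.data import DataLoader, TensorDataset

# Synthetic dataset standing in for samples on disk; a real Dataset would open
# files and read samples from storage.
dataset = TensorDataset(torch.randn(1024, 3, 64, 64), torch.randint(0, 10, (1024,)))
loader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4)

device = "cuda" if torch.cuda.is_available() else "cpu"
for inputs, labels in loader:
    # Iterating the loader is what triggers I/O: worker processes fetch and batch samples.
    inputs, labels = inputs.to(device), labels.to(device)
    # The forward pass, loss, backward pass, and optimizer step would run here;
    # during that computation phase no training-related I/O is issued.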

 

Therefore, if you want to measure how efficiently your storage solution delivers data to the GPU, you should focus specifically on the performance of the data loading mechanism. The authors of DLIO recognized this pattern and came up with an elegant solution shown in Figure 1: replacing the computation stage with a sleep function.

 

The duration of the sleep function should match the time a specific GPU model takes to perform forward and backward propagation when training a given model. This approach allows researchers to isolate and accurately measure the performance of the data loading stage without the need to invest in GPU hardware.

 

Figure 1. DLIO solution for storage benchmark. Adapted from [1b] with modifications.
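
In spirit, the approach looks like the sketch below. This is an illustration of the idea rather than DLIO's actual code, and the 0.323-second value is only an example of a measured per-step compute time.

# Illustrative sketch: replace the GPU computation with a sleep of the measured duration.
import time
import torch
from torch.utils.data import DataLoader, TensorDataset

loader = DataLoader(TensorDataset(torch.randn(256, 3, 64, 64)),
                    batch_size=32, num_workers=4)
EMULATED_COMPUTE_TIME = 0.323             # seconds a real accelerator would spend per step

io_time = 0.0
t0 = time.perf_counter()
for (batch,) in loader:
    io_time += time.perf_counter() - t0   # time spent waiting for the batch
    time.sleep(EMULATED_COMPUTE_TIME)     # stand-in for forward/backward passes on the GPU
    t0 = time.perf_counter()

print(f"Time spent waiting on data: {io_time:.2f} s")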

 

The DLIO benchmark achieves over 90% similarity in I/O behavior, which validates that it is an accurate representation of real applications. The remaining 3-6% gap exists because all applications have a distribution of transfer request sizes, which is represented as a median request size within the benchmark [2].

 

DLIO includes a variety of deep learning workload examples, such as UNET-3D, CosmoFlow, ResNet50, and LLaMA 3. It also supports the creation of customized workloads through a flexible configuration system, as illustrated below.
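
As a rough illustration, a workload template can be selected and its parameters tweaked with Hydra-style ++ overrides, the same syntax used in the installation test further below. The override keys shown here are examples only; check the YAML files under configs/workload for the parameters your DLIO version actually supports.

# Example only: select a workload template and override a couple of its parameters.
mpirun -np 8 dlio_benchmark workload=unet3d_a100 \
    ++workload.dataset.num_files_train=1024 \
    ++workload.reader.batch_size=4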

Let's take a closer look at DLIO. I'll show you the steps I followed to get it working on my Ubuntu 22.04 virtual machine.
 

DLIO Installation Steps

 
I began by setting up a virtual machine running Ubuntu Server, opting for the minimal installation to keep the environment lightweight. I'm currently using Ubuntu 22.04, which includes Python 3.10.12 by default. As of this writing, Python 3.10.12 is the required version for installing DLIO without any compatibility issues. Once your VM is ready, you need to follow the steps outlined below.
 
1. Begin by installing the OS packages required by DLIO. Pay special attention to the MPI package. Based on my experience, MPICH tends to be more straightforward to work with compared to OpenMPI.
 
sudo apt install -y build-essential git vim sysstat cmake libhdf5-dev hwloc libhwloc-dev mpich libmpich-dev bc
 
2. Clone the DLIO repository.
 
git clone https://github.com/argonne-lcf/dlio_benchmark.git

 

3. Install the Python modules required by DLIO:

 

pip3 install -r dlio_benchmark/requirements.txt

 

4. To avoid some warning messages thrown out by TensorFlow when running a workload, install the following package:

 

pip3 install tensorflow-cpu

 

5. Change to the DLIO directory and install dlio_benchmark:

 

cd dlio_benchmark ; pip3 install .

 

6. Run the dlio_benchmark command to verify that your installation was successful:

 

mpirun -np 8 dlio_benchmark workload=unet3d_a100 ++workload.workflow.generate_data=True ++workload.workflow.train=False

 

If you run dlio_benchmark and encounter an error indicating that the shared library libmpi.so.12 is missing, execute the command below and try again:

 

cd /lib/x86_64-linux-gnu ; sudo ln -s libmpich.so.12 libmpi.so.12

 

Next, let me show you how DLIO works its magic to measure storage performance for deep learning workloads, from loading the datasets to simulating the computation stage.

 

DLIO Execution Flow

 

DLIO begins by initializing the MPI stack via the DLIOMPI.get_instance().initialize() method.

 

# dlio_benchmark/main.py

def main() -> None:
    """
    The main method to start the benchmark runtime.
    """
    DLIOMPI.get_instance().initialize()
    run_benchmark()
    DLIOMPI.get_instance().finalize()

 

The DLIOMPI.initialize() method sets up the MPI environment by calling MPI.Init(), updates the MPI state to MPIState.MPI_INITIALIZED, and opens the MPI.COMM_WORLD communicator, which encompasses all participating processes.

 

# dlio_benchmark/utils/utility.py

class DLIOMPI:
    ...
    def initialize(self):
        from mpi4py import MPI
        if self.mpi_state == MPIState.UNINITIALIZED:
            # MPI may have already been initialized by dlio_benchmark_test.py
            if not MPI.Is_initialized():
                MPI.Init()
            
            self.mpi_state = MPIState.MPI_INITIALIZED
            self.mpi_rank = MPI.COMM_WORLD.rank
            self.mpi_size = MPI.COMM_WORLD.size
            self.mpi_world = MPI.COMM_WORLD
            split_comm = MPI.COMM_WORLD.Split_type(MPI.COMM_TYPE_SHARED)
            # Get the number of nodes
            self.mpi_ppn = split_comm.size
            self.mpi_local_rank = split_comm.rank
            self.mpi_nodes = self.mpi_size//split_comm.size
        elif self.mpi_state == MPIState.CHILD_INITIALIZED:
            raise Exception(f"method {self.classname()}.initialize() called in a child process")
        else:
            pass    # redundant call

 

Next, the run_benchmark() function is invoked, which instantiates a DLIOBenchmark object using a workload configuration. This configuration defines parameters such as the directory where training and checkpoint files are stored, the number of training files, and the batch size, among other options needed to set up a training workload. The benchmark is then executed through a sequence of method calls: initialize(), run(), finalize().

 

# dlio_benchmark/main.py

@hydra.main(version_base=None, config_path="configs", config_name="config")
def run_benchmark(cfg: DictConfig):    
    benchmark = DLIOBenchmark(cfg['workload'])
    benchmark.initialize()
    benchmark.run()
    benchmark.finalize()

 

The run() method coordinates the training process across all epochs. For each epoch, it prepares the dataset for reading, performs training, and records execution stats using the StatsCounter class via the stats property of the benchmark object.

 

Training is initiated by the line steps = self._train(epoch). To understand the training execution in detail, let's examine the _train(self, epoch) method.

 

# dlio_benchmark/main.py
...
class DLIOBenchmark:
    ...
    @dlp.log
    def run(self):
        ...
        if (not self.generate_only) and (not self.args.checkpoint_only):
            ...
            for epoch in range(1, self.epochs + 1):
                self.stats.start_epoch(epoch)
                self.next_checkpoint_step = self.steps_between_checkpoints
                self.stats.start_train(epoch)
                steps = self._train(epoch)
                self.stats.end_train(epoch, steps)
                self.logger.debug(f"{utcnow()} Rank {self.my_rank} returned after {steps} steps.")
                self.framework.get_loader(DatasetType.TRAIN).finalize()
                # Perform evaluation if enabled
                if self.do_eval and epoch >= next_eval_epoch:
                    next_eval_epoch += self.epochs_between_evals
                    self.stats.start_eval(epoch)
                    self._eval(epoch)
                    self.stats.end_eval(epoch)
                    self.framework.get_loader(DatasetType.VALID).finalize()
                self.args.reconfigure(epoch + 1) # reconfigure once per epoch
                self.stats.end_epoch(epoch)

        if (self.args.checkpoint_only):
            self._checkpoint()            
        self.stats.end_run()

 

The data is loaded in batches via the for batch in loader.next(): loop. The interesting part here is how the training computation is simulated using a sleep function. This simulation begins with the call to self.framework.compute(batch, epoch, block_step, self.computation_time).

 

# dlio_benchmark/main.py 
...
class DLIOBenchmark:
...
    def _train(self, epoch):
        """
        Training loop for reading the dataset and performing training computations.
        :return: returns total steps.
        """
        block = 1  # A continuous period of training steps, ended by checkpointing
        block_step = overall_step = 1  # Steps are taken within blocks
        max_steps = math.floor(self.num_samples * self.num_files_train / self.batch_size / self.comm_size)
        self.steps_per_epoch = max_steps
        # Start the very first block
        self.stats.start_block(epoch, block)
        loader = self.framework.get_loader(dataset_type=DatasetType.TRAIN)
        self.stats.start_loading()
        for batch in loader.next():
            self.stats.batch_loaded(epoch, overall_step, block)
            computation_time = self.args.computation_time
            if (isinstance(computation_time, dict) and len(computation_time) > 0) or (isinstance(computation_time, float) and  computation_time > 0):
                self.framework.trace_object("Train", overall_step, 1)
            self.stats.start_compute()
            self.framework.compute(batch, epoch, block_step, self.computation_time)
            self.stats.batch_processed(epoch, overall_step, block)
            # This is the barrier to simulate allreduce. It is required to simulate the actual workloads.
            self.comm.barrier()
            if self.do_checkpoint and (
                    self.steps_between_checkpoints >= 0) and overall_step == self.next_checkpoint_step:
                self.stats.end_block(epoch, block, block_step)
                self.stats.start_save_ckpt(epoch, block, overall_step)
                self.checkpointing_mechanism.save_checkpoint(epoch, overall_step)
                self.stats.end_save_ckpt(epoch, block)
                block += 1
                # Reset the number of steps after every checkpoint to mark the start of a new block
                block_step = 1
                self.next_checkpoint_step += self.steps_between_checkpoints
            else:
                block_step += 1
            overall_step += 1
            if overall_step > max_steps or ((self.total_training_steps > 0) and (overall_step > self.total_training_steps)):
                if self.args.my_rank == 0:
                    self.logger.info(f"{utcnow()} Maximum number of steps reached")
                if (block_step != 1 and self.do_checkpoint) or (not self.do_checkpoint):
                    self.stats.end_block(epoch, block, block_step - 1)
                break
            # start a new block here
            if block_step == 1 and block != 1:
                self.stats.start_block(epoch, block)
            self.stats.start_loading()

        self.comm.barrier()
        if self.do_checkpoint and (self.steps_between_checkpoints < 0) and (epoch == self.next_checkpoint_epoch):
            self.stats.end_block(epoch, block, block_step-1)
            self.stats.start_save_ckpt(epoch, block, overall_step-1)
            self.checkpointing_mechanism.save_checkpoint(epoch, overall_step)
            self.stats.end_save_ckpt(epoch, block)
            self.next_checkpoint_epoch += self.epochs_between_checkpoints
        return overall_step

 

The compute method is implemented by the Framework class, which serves as an abstract base class defining the required methods for the classes implementing a framework like PyTorch or TensorFlow.

 

In the PyTorch implementation, the compute method invokes the model() method, which in turn calls a sleep function located in the utils/utility.py module. Specifically, the line base_sleep(sleep_time) simulates the time an accelerator takes to complete the computation stage, which includes the forward pass, backward pass, and weight and bias updates.

 

# dlio_benchmark/utils/utility.py
...
def sleep(config):
    sleep_time = 0.0
    if isinstance(config, dict) and len(config) > 0:
        if "type" in config:
            if config["type"] == "normal":
                sleep_time = np.random.normal(config["mean"], config["stdev"])
            elif config["type"] == "uniform":
                sleep_time = np.random.uniform(config["min"], config["max"])
            elif config["type"] == "gamma":
                sleep_time = np.random.gamma(config["shape"], config["scale"])
            elif config["type"] == "exponential":
                sleep_time = np.random.exponential(config["scale"])
            elif config["type"] == "poisson":
                sleep_time = np.random.poisson(config["lam"])
        else:
            if "mean" in config:
                if "stdev" in config:
                    sleep_time = np.random.normal(config["mean"], config["stdev"])
                else:
                    sleep_time = config["mean"]
    elif isinstance(config, (int, float)):
        sleep_time = config
    sleep_time = abs(sleep_time)
    if sleep_time > 0.0:
        base_sleep(sleep_time)
    return sleep_time

 

For example, during UNET-3D training, an NVIDIA H100 was measured to take approximately 0.323 seconds to complete this stage. This value is passed to DLIO via the workload configuration file, using the computation_time key in its train section.
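
For illustration, here is how such a value flows through the sleep() helper shown above, assuming the dlio_benchmark package is installed as described earlier; the distribution parameters in the second call are made up.

# Illustrative only: a fixed computation time, or a distribution to model
# step-to-step variation, both handled by the sleep() helper shown above.
from dlio_benchmark.utils.utility import sleep

sleep(0.323)                                             # fixed per-step compute time (seconds)
sleep({"type": "normal", "mean": 0.323, "stdev": 0.05})  # made-up variation around the mean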

 

The use of base_sleep(sleep_time) allows performance testing of storage systems for deep learning workloads without requiring expensive accelerators of any type in the lab. It's worth noting that DLIO's authors chose to alias Python's native sleep function as base_sleep in their implementation.

 

# dlio_benchmark/utils/utility.py
...
from time import time, sleep as base_sleep

 

Key Takeaways

 

  1. DLIO does not require any accelerator (e.g., GPU, TPU, DPU) to benchmark your storage system.

  2. Benchmark pass criteria are based on both throughput and latency. Therefore, focusing solely on high throughput is insufficient. You must also ensure the system responds quickly enough to maintain high accelerator utilization (AU).

  3. Accelerator utilization depends on the workload type (see the sketch after this list). For example:

    • To pass a UNET-3D benchmark, AU must be ≥ 90%
    • To pass a CosmoFlow benchmark, AU must be ≥ 70%
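
As a rough sketch of what AU means here, it can be thought of as the fraction of wall-clock time the (emulated) accelerator spends computing rather than waiting on data. The numbers below are illustrative; DLIO's own reports should be used for real pass/fail decisions.

# Illustrative AU calculation: emulated compute time divided by total wall-clock time.
steps = 1000
compute_time_per_step = 0.323      # emulated accelerator compute time per step (seconds)
measured_epoch_time = 380.0        # measured wall-clock time for the epoch (seconds)

au = (steps * compute_time_per_step) / measured_epoch_time
print(f"AU = {au:.1%}")            # ~85% here, which would fail a UNET-3D run (needs >= 90%)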

 

Closing Thoughts

 

Thanks for sticking with us through this deep dive! We know it's a lot to take in, but by now you should have a solid understanding of the challenges involved in measuring storage performance for deep learning workloads in a cost-efficient way, and of the rationale behind our approach to overcoming them.

 

In the next post of this series, we'll explore our methodology and share performance results from training a UNET-3D model using an AWS FSx for NetApp ONTAP scale-out file system.

 

References

 

[1a] DLIO Benchmark. Available from: <https://dlio-benchmark.readthedocs.io/en/latest/>

 

[1b] DLIO Benchmark Overview. Available from: <https://dlio-benchmark.readthedocs.io/en/latest/overview.html>

 

[2] H. Devarajan, H. Zheng, A. Kougkas, X.-H. Sun and V. Vishwanath, "DLIO: A Data-Centric Benchmark for Scientific Deep Learning Applications," 2021 IEEE/ACM 21st International Symposium on Cluster, Cloud and Internet Computing (CCGrid), Melbourne, Australia, 2021, pp. 81-91, doi: 10.1109/CCGrid51090.2021.00018.
