Embedding learning is a commonly used technique to deal with categorical features in deep recommendation models by mapping sparse features into dense vectors. However, embedding tables can be large because of the corresponding feature sizes. Distributed training has been adopted to place the tables on multiple hardware devices, such as GPUs, but even with distributed training, embedding tables are often efficiency bottlenecks.

Typically, embedding lookup consists of four stages. In the forward pass, the sparse indices are mapped into dense vectors (forward computation), which are then sent to the target devices (forward communication). In the backward pass, the gradients of the embedding vectors are sent back from the target devices (backward communication) and applied to the embedding vectors (backward computation). The tables can easily lead to imbalances if not carefully partitioned. For example, the random placement in Figure 1 is bottlenecked by GPU2 with a 56.6-millisecond latency.
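To make the notion of a bottleneck concrete, here is a minimal sketch. It assumes, purely for illustration, that the four stage latencies on a device simply add up and that the slowest device determines the step latency. The device names and millisecond values are hypothetical and are not the measurements shown in Figure 1.

```python
from typing import Dict, List, Tuple

# Order of the four stages described above (latencies in milliseconds).
STAGE_NAMES = ["fwd_comp", "fwd_comm", "bwd_comm", "bwd_comp"]


def device_latency(stage_ms: List[float]) -> float:
    """Total latency of one device under the simplifying additive assumption."""
    return sum(stage_ms)


def bottleneck(placement_ms: Dict[str, List[float]]) -> Tuple[str, float]:
    """Return the slowest device and its total latency."""
    totals = {dev: device_latency(ms) for dev, ms in placement_ms.items()}
    slowest = max(totals, key=totals.get)
    return slowest, totals[slowest]


# Hypothetical stage latencies for an unbalanced placement.
example = {
    "GPU0": [5.0, 4.0, 4.0, 6.0],
    "GPU1": [8.0, 5.0, 5.0, 9.0],
    "GPU2": [12.0, 6.0, 6.0, 14.0],
}

print(bottleneck(example))  # ('GPU2', 38.0)
```

Under this simplified model, moving a table off the slowest device helps only if it does not push another device past the current maximum, which is why balancing is a combinatorial problem.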

In a paper published in October 2022, “DreamShard: Generalizable embedding table placement for recommender systems,” we focus on embedding table sharding, or how to place the embedding tables to balance computation and communication costs. This is the basis of DreamShard, a reinforcement learning approach for embedding table sharding (see Figure 2).

DreamShard learns a cost network to directly predict the costs of the embedding operations. Specifically, the network takes as input the table features (e.g., table dimension) of each single table and outputs the computation and communication costs. It then trains a policy network by interacting with an estimated Markov decision process (MDP), without real GPU executions, where the states and the rewards are estimated by the predictions of the cost network. Equipped with sum reductions for the table representations and max reductions for the device representations, the two networks can directly generalize to unseen placement tasks with different numbers of tables and/or devices, without fine-tuning.

DreamShard outperforms the existing heuristic strategies on both open-sourced synthetic tables and Meta tables, achieving up to a 19 percent speedup over the state of the art. It can generalize to unseen tasks that have different numbers of tables and/or devices with a negligible performance drop (< 0.5 milliseconds), and its inference is efficient: it can place hundreds of tables in less than one second.

The idea behind DreamShard is to formulate the table placement process as an MDP and train a cost network to estimate its states and rewards. A policy network with a tailored generalizable network architecture is trained by efficiently interacting with the estimated MDP. The two networks are updated iteratively to improve the state/reward estimation and the placement policy.

We formulate embedding table sharding as an MDP, where we assign tables to devices one by one (see Figure 3). At each step, the legal actions are the devices that would not run out of memory if assigned the current table. At the final step, we get the reward by collecting the cost from the hardware. The one-by-one placement enables the agent to generalize across different numbers of tables. For example, an agent trained on an MDP with very few tables can be applied to another MDP with more tables by simply executing more steps.
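The one-by-one placement loop can be sketched as follows. This is an illustrative skeleton rather than DreamShard's implementation: the function names, the gigabyte units, and the `least_loaded` stand-in policy are all assumptions, and the learned policy network would take the place of `choose`.

```python
from typing import Callable, List, Sequence


def legal_devices(table_gb: float, used_gb: Sequence[float],
                  capacity_gb: float) -> List[int]:
    """Devices that still have room for the current table."""
    return [d for d, used in enumerate(used_gb)
            if used + table_gb <= capacity_gb]


def place_tables(table_sizes_gb: Sequence[float], num_devices: int,
                 capacity_gb: float,
                 choose: Callable[[List[int], List[float]], int]) -> List[int]:
    """Assign tables to devices one at a time; one MDP step per table."""
    used = [0.0] * num_devices
    placement = []
    for size in table_sizes_gb:
        legal = legal_devices(size, used, capacity_gb)
        if not legal:
            raise ValueError("no device has enough memory for this table")
        device = choose(legal, used)  # the policy picks among legal actions
        used[device] += size
        placement.append(device)
    return placement


def least_loaded(legal: List[int], used: List[float]) -> int:
    """Hypothetical stand-in policy: pick the legal device using least memory."""
    return min(legal, key=lambda d: used[d])


print(place_tables([4.0, 3.0, 2.0, 2.0], num_devices=2,
                   capacity_gb=6.0, choose=least_loaded))  # [0, 1, 1, 0]
```

Because the loop only ever looks at the current table and the current device state, the same `choose` function works unchanged for a task with more tables; it simply runs for more steps.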

Interacting with the above MDP is computationally expensive since it requires GPU execution. To address this, we build an estimated MDP by approximating the cost features (state) and the reward with a cost network (see Figure 4). The cost network is designed around two ideas. First, it uses a shared MLP to map raw table features into table representations. For any unseen tables, this MLP can be directly applied to extract table representations. Second, it enables a fixed-dimension representation for each device with sum reductions (i.e., the element-wise sum of the table representations in the device), and similarly for the overall representation across devices with max reductions. The reduced representations are then followed by multiple MLP heads for cost predictions.

For unseen tasks with different numbers of tables and/or devices, the reductions will always lead to fixed-dimension device/overall representations so that the prediction heads can be directly applied. We train the cost network with mean squared error loss using the cost data collected from the GPUs. Once trained, it can predict the cost features or the reward with a single forward pass without GPU execution.
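The effect of the reductions can be illustrated with a small, untrained sketch in plain Python. Everything here is an assumption made for illustration: the class name `TinyCostNet`, the single-layer "MLPs", the random weights, and the feature values. The only point is that the sum and max reductions yield fixed-dimension inputs for the prediction heads, regardless of how many tables or devices a task has.

```python
import random
from typing import List, Tuple

Vector = List[float]


def make_linear(rng: random.Random, in_dim: int, out_dim: int) -> List[Vector]:
    """Random weight matrix with out_dim rows and in_dim columns."""
    return [[rng.uniform(-1.0, 1.0) for _ in range(in_dim)]
            for _ in range(out_dim)]


def linear_relu(weights: List[Vector], x: Vector) -> Vector:
    """One linear layer followed by ReLU (a stand-in for a full MLP)."""
    return [max(0.0, sum(w * v for w, v in zip(row, x))) for row in weights]


def sum_reduce(vectors: List[Vector]) -> Vector:
    """Element-wise sum over a list of equal-length vectors."""
    return [sum(col) for col in zip(*vectors)]


def max_reduce(vectors: List[Vector]) -> Vector:
    """Element-wise max over a list of equal-length vectors."""
    return [max(col) for col in zip(*vectors)]


class TinyCostNet:
    def __init__(self, feat_dim: int, hidden_dim: int, seed: int = 0) -> None:
        rng = random.Random(seed)
        self.hidden_dim = hidden_dim
        self.table_mlp = make_linear(rng, feat_dim, hidden_dim)   # shared
        self.device_head = make_linear(rng, hidden_dim, 1)
        self.overall_head = make_linear(rng, hidden_dim, 1)

    def device_repr(self, tables: List[Vector]) -> Vector:
        """Sum of table representations; zeros for an empty device."""
        if not tables:
            return [0.0] * self.hidden_dim
        return sum_reduce([linear_relu(self.table_mlp, t) for t in tables])

    def predict(self, placement: List[List[Vector]]) -> Tuple[List[float], float]:
        """Per-device cost estimates and one overall cost estimate."""
        reprs = [self.device_repr(tables) for tables in placement]
        device_costs = [linear_relu(self.device_head, r)[0] for r in reprs]
        overall = linear_relu(self.overall_head, max_reduce(reprs))[0]
        return device_costs, overall


net = TinyCostNet(feat_dim=3, hidden_dim=4)
t = [1.0, 0.5, 2.0]  # hypothetical raw table features
costs_2gpu, overall_2gpu = net.predict([[t, t], [t]])
costs_3gpu, overall_3gpu = net.predict([[t], [t, t, t], []])
print(len(costs_2gpu), len(costs_3gpu))  # 2 3
```

The same weights serve both the two-device and the three-device task; in practice the weights would be fit with a mean squared error loss against costs measured on GPUs, as described above.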

The policy network uses a shared MLP and sum reductions to produce a fixed-dimension representation, which is then concatenated with the cost features to obtain the device representation (see Figure 5). Because the number of available devices may vary, the action space is variable. To accommodate this, a shared MLP processes each device representation separately to obtain a confidence score, and a softmax layer turns the scores into action probabilities. This design allows the policy to generalize across different numbers of devices.

The training procedure iteratively executes the following: 1) collecting cost data from GPUs based on the placements generated by the current policy, 2) updating the cost network with the previously collected cost data, and 3) updating the policy network by interacting with the current estimated MDP.
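The per-device scoring followed by a softmax can be sketched as below. This is a simplified illustration, not the actual policy network: the shared scorer is reduced to a single hypothetical weight vector, and the device representations are made-up numbers.

```python
import math
from typing import List


def softmax(scores: List[float]) -> List[float]:
    """Numerically stable softmax over a list of scores."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def score_device(device_repr: List[float], weights: List[float]) -> float:
    """Shared scorer applied to one device representation (a dot product here)."""
    return sum(w * x for w, x in zip(weights, device_repr))


def action_probabilities(device_reprs: List[List[float]],
                         weights: List[float]) -> List[float]:
    """One probability per device, for any number of devices."""
    return softmax([score_device(r, weights) for r in device_reprs])


weights = [0.5, -1.0, 0.25]  # hypothetical shared parameters

two_devices = [[1.0, 0.2, 0.0], [0.3, 0.9, 1.0]]
four_devices = two_devices + [[0.0, 0.0, 0.0], [2.0, 0.1, 0.5]]

print(len(action_probabilities(two_devices, weights)))   # 2
print(len(action_probabilities(four_devices, weights)))  # 4
```

Because the same scorer is applied to each device independently, the number of parameters does not depend on the number of devices, and the softmax normalizes over however many scores it receives.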

Embedding table sharding is a significant design challenge in the distributed training of deep recommendation models. Optimizing embedding table sharding can greatly boost the training throughput since embedding computation and communication are often the bottlenecks. Researchers and practitioners who work on efficiency problems in recommendation models would find DreamShard interesting and useful.

This work also provides a concrete example of how RL can be used to improve machine learning system design. The idea of combining learned neural cost models with reinforcement learning could be applied to many combinatorial optimization problems in system design.

—

Thanks to the many people who provided technical insights, discussions, and feedback: Dhruv Choudhary, Chris Cummins, Xizhou Feng, Aaron Ferber, Yuchen Hao, Pavani Panakanti, Soohee Lee, Zhongyi Lin, Zirui Liu, Geet Sethi, Srinivas Sridharan, Zhou Wang, Justin Wong, Carole-Jean Wu, and Yufei Zhu.

We also deeply appreciate the support from our leadership team: Leo Cazares, Binu John, Sukwoo Kang, Richard Kaufmann, Arun Kejariwal, Max Leung, Parth Malani, Martin Patterson, and Rishi Sinha.