LSTM
In this section, we will discuss practical LSTM optimization techniques proposed in Serving RNN-based Deep Learning Models 10x Faster [1]. We will walk through the work and pick up the high-level ideas while skipping most of the details. It is highly recommended that you read the paper to gain more insight if you plan to apply these techniques to your own systems.
LSTM Revisited
LSTM and GRU are among the variations of RNNs: they inherit the basic recurrent structure of an RNN but use different cell computations to capture long-term dependencies along sequences. We take LSTM as an example and illustrate its cell computation.
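As a reference for the discussion below, the standard LSTM cell step can be sketched in NumPy. This is a generic textbook formulation, not code from the paper; the names (lstm_cell, W, U, b) and the i/f/c/o gate ordering are my own choices.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_cell(x_t, h_prev, c_prev, W, U, b):
    """One LSTM time step.
    W: [e, 4h] input weights, U: [h, 4h] recurrent weights, b: [4h] bias.
    Gates are stacked in the order i, f, candidate, o."""
    h = h_prev.shape[-1]
    z = x_t @ W + h_prev @ U + b        # these two MMs dominate the cost
    i = sigmoid(z[..., 0*h:1*h])        # input gate
    f = sigmoid(z[..., 1*h:2*h])        # forget gate
    g = np.tanh(z[..., 2*h:3*h])        # candidate cell state
    o = sigmoid(z[..., 3*h:4*h])        # output gate
    c_t = f * c_prev + i * g            # new cell state
    h_t = o * np.tanh(c_t)              # new hidden state
    return h_t, c_t
```

Note that h_t feeds back into the next step's MM against U, which is the cross-time-slot dependency discussed below.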
In LSTM, the total amount of computation is dominated by MMs (matrix multiplications). Popular DL toolkits such as TensorFlow or PyTorch have already done a lot of work to optimize their GEMM kernels. What is left for us to leverage to improve the performance of LSTM? Or, what are the problems with existing LSTM implementations?
The Problem: Low Computational Intensity
During the inference phase, the batch size is very likely to be 1 or a few. At each time slot t, the dimensions of the input tensor x and the weight tensor W are [1, e] and [e, h] respectively, where e is the embedding size and h is the hidden size. Following the equation in the blocking section, the computational intensity of the MM between these two matrices is

intensity = FLOPs / floats moved = 2eh / (e + eh + h) ≈ 2

since the eh weight elements dominate the data movement.
The input x is a small vector with one of its dimensions equal to 1. This results in very poor computational intensity and is the key reason the GEMM calculation slows down so significantly. GEMM kernels in existing NN toolkits work well on large matrices, but poorly on the small matrices found in RNNs. During the training phase, the small-matrix issue can be compensated for by setting a large batch size. However, this is not applicable to inference.
Using the STREAM benchmark tool, the observed data bandwidth between the L3 and L2 caches on the Xeon E5-2650 machine is 62.5 GigaFloats/s (250 GB/s). With the computational intensity of 2 above, the best achievable performance of LSTM on this machine is at most 125 GFLOPS (62.5 x 2). This is less than 8% of the CPU's theoretical peak of 1.69 TFLOPS. As a result, the CPU sits idle, waiting for data I/O.
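This memory-bound roofline arithmetic is easy to check; the figures below are the ones quoted in the text, and the variable names are mine:

```python
# Roofline estimate for batch-1 LSTM MMs (figures from the text above).
bandwidth_gfloats = 62.5    # observed L3->L2 bandwidth, GigaFloats/s (250 GB/s)
intensity = 2.0             # FLOPs per float moved for the [1,e] x [e,h] MM
peak_gflops = 1690.0        # theoretical peak of the Xeon E5-2650, GFLOPS

achievable = bandwidth_gfloats * intensity  # bandwidth-bound ceiling: 125 GFLOPS
utilization = achievable / peak_gflops      # fraction of peak, under 8%
```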
Fuse Matrices
The phases in LSTM can be divided into two categories:

- Time-dependent phase: consists of MMs that have dependencies across time slots, e.g., the MMs taking the hidden state h_t as input.
- Time-independent phase: does not contain any MM that has a dependency across time slots.
In the time-independent phase, the four weight matrices can be fused into a single larger one, [W_i, W_f, W_c, W_o], since they share a common input matrix. With this adjustment, the problem of low computational intensity caused by the small input x is relieved.
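A minimal sketch of the fusion in NumPy, assuming four [e, h] weight matrices that all consume the same input x (shapes and names are illustrative):

```python
import numpy as np

rng = np.random.default_rng(0)
e, h = 128, 256
x = rng.standard_normal((1, e))                 # batch-1 input
W_i, W_f, W_c, W_o = (rng.standard_normal((e, h)) for _ in range(4))

# Unfused: four skinny GEMMs, each with poor computational intensity.
outs = [x @ W for W in (W_i, W_f, W_c, W_o)]

# Fused: one [e, 4h] GEMM; x is streamed through once instead of four times,
# and the kernel sees a larger, more efficient matrix shape.
W_fused = np.concatenate([W_i, W_f, W_c, W_o], axis=1)
out_fused = x @ W_fused

assert np.allclose(np.concatenate(outs, axis=1), out_fused)
```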
Loop Tiling for Weight Matrix
In the time-dependent phase, the weight matrix is needed to compute c_t and h_t at every time slot. To exploit this data-reuse opportunity across the sequence and improve temporal locality, it is necessary to choose a good blocking size for the weight matrix so that it is not evicted from cache until it is no longer needed.
We can use loop tiling, iterating over the weight matrix in the outer loop, to achieve this goal. With that loop order, the data movement for W is reduced from seq_len*|W| to |W|, which is independent of the sequence length.
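The loop order can be sketched as follows. This is only a schematic in NumPy, not a tuned kernel, and it is shown on a sequence MM without cross-step dependencies so that hoisting the weight tile into the outer loop is plainly legal; the names are mine:

```python
import numpy as np

def seq_mm_weight_tiled(X, W, tile=64):
    """X: [seq_len, e] inputs for the whole sequence, W: [e, h] weights.
    Column tiles of W sit in the outer loop and time steps in the inner
    loop, so each tile of W is streamed from memory once (|W| floats in
    total) instead of once per time step (seq_len * |W|)."""
    seq_len, e = X.shape
    h = W.shape[1]
    Z = np.empty((seq_len, h))
    for j in range(0, h, tile):          # outer loop: one weight tile
        W_tile = W[:, j:j + tile]        # stays cache-resident ...
        for t in range(seq_len):         # ... while every time step reuses it
            Z[t, j:j + tile] = X[t] @ W_tile
    return Z
```

For the truly time-dependent MM, h_t must be fully formed before step t+1 can start, so there the same reuse is obtained by sizing tiles to stay resident in cache across the time loop rather than by literally reordering it.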

Weight-Centric Streamlining (WCS)
However, to ensure this reuse, the computation must be carried out where the weights reside, i.e., the mapping between parallel partitions and the cores that execute them must not change across time-dependent phases (TDPs).
In essence, they use ParallelOuterRNN: it creates partitions outside the RNN loop, and each partition is bound to a fixed thread ID (computation core), so the thread-private L2 cache is leveraged across phases and time steps.
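A conceptual sketch of this fixed partition-to-thread binding, using Python threads (the function name wcs_recurrent is mine, the gate math is reduced to a single tanh recurrence, and real implementations additionally pin each thread to a core; this only illustrates that every worker keeps the same slice of the weights for the whole sequence):

```python
import numpy as np
from threading import Barrier, Thread

def wcs_recurrent(h0, U, seq_len, n_workers=2):
    """Weight-centric streamlining sketch: U is split column-wise into
    fixed per-worker partitions, and each worker computes the same
    partition at every time step, so its slice of U can stay in that
    core's private cache across the sequence.
    Recurrence used here: h_t = tanh(h_{t-1} @ U)."""
    h = U.shape[1]
    bounds = np.linspace(0, h, n_workers + 1, dtype=int)
    state = [h0]                          # shared current hidden state
    nxt = np.empty_like(h0)               # shared buffer for h_t
    barrier = Barrier(n_workers)

    def worker(w):
        lo, hi = bounds[w], bounds[w + 1]
        U_part = U[:, lo:hi]              # this worker's fixed partition
        for _ in range(seq_len):
            nxt[:, lo:hi] = np.tanh(state[0] @ U_part)
            barrier.wait()                # all partitions of h_t are ready
            if w == 0:
                state[0] = nxt.copy()     # publish h_t
            barrier.wait()                # h_t visible before the next step

    threads = [Thread(target=worker, args=(w,)) for w in range(n_workers)]
    for t in threads: t.start()
    for t in threads: t.join()
    return state[0]
```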
Implementation
The techniques in the work have already been applied to onnxruntime. You can check the implementation here.
References
[1] Zhang, M., Rajbhandari, S., Wang, W., and He, Y., 2018. DeepCPU: Serving RNN-based deep learning models 10x faster. In 2018 USENIX Annual Technical Conference (USENIX ATC 18), pp. 951-965.