LSTM

In this section, we discuss practical LSTM optimization techniques proposed in DeepCPU: Serving RNN-based Deep Learning Models 10x Faster [1]. We walk through the work to capture its high-level ideas and skip most of the details. If you plan to apply these techniques to your own systems, we highly recommend reading the paper for more insight.

LSTM Revisited

LSTM and GRU are variants of RNNs: they inherit the basic recurrent structure of an RNN but use different cell computations to capture long-term dependencies along a sequence. We take LSTM as an example and illustrate its cell computation.

$$
\begin{aligned}
i_t &= \sigma(W_i x_t + U_i h_{t-1} + b_i) \\
f_t &= \sigma(W_f x_t + U_f h_{t-1} + b_f) \\
o_t &= \sigma(W_o x_t + U_o h_{t-1} + b_o) \\
c_t &= f_t \odot c_{t-1} + i_t \odot \tanh(W_c x_t + U_c h_{t-1} + b_c) \\
h_t &= o_t \odot \tanh(c_t)
\end{aligned}
$$

where σ is the logistic sigmoid and ⊙ denotes elementwise multiplication.
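To make the equations concrete, here is a minimal pure-Python sketch of one LSTM time step. It assumes each gate's weights are stored as lists of rows (W[g] has h rows of length e, U[g] has h rows of length h) so that matvec(W[g], x) computes W_g x_t as in the equations; the dictionary keys "i", "f", "o", "c" and the helper names are illustrative, not taken from any library.

```python
import math


def sigmoid(v):
    return 1.0 / (1.0 + math.exp(-v))


def matvec(M, v):
    """M is a list of rows; returns the matrix-vector product M v."""
    return [sum(m * x for m, x in zip(row, v)) for row in M]


def lstm_cell(x, h_prev, c_prev, W, U, b):
    """One LSTM step following the equations above (hypothetical layout:
    W[g] is [h][e], U[g] is [h][h], b[g] has length h, for g in "ifoc")."""

    def preact(g):
        # W_g x_t + U_g h_{t-1} + b_g: two small matrix-vector products per gate
        return [wx + uh + bias
                for wx, uh, bias in zip(matvec(W[g], x), matvec(U[g], h_prev), b[g])]

    i = [sigmoid(v) for v in preact("i")]
    f = [sigmoid(v) for v in preact("f")]
    o = [sigmoid(v) for v in preact("o")]
    g = [math.tanh(v) for v in preact("c")]
    # elementwise cell update and output
    c = [ft * ct + it * gt for ft, ct, it, gt in zip(f, c_prev, i, g)]
    h = [ot * math.tanh(ct) for ot, ct in zip(o, c)]
    return h, c
```

Note that every gate performs two matrix-vector products per time step; this is exactly the small-MM pattern discussed in the next section.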

In LSTM, the total amount of computation is dominated by matrix multiplications (MMs). Popular DL toolkits such as TensorFlow and PyTorch have already put a lot of work into optimizing their GEMM kernels. What else can we leverage to improve LSTM performance? In other words, what is wrong with existing LSTM implementations?

The Problem: Low Computational Intensity

During inference, the batch size is typically 1 or a small number. At each time slot t, the dimensions of the input tensor x and the weight tensor W are [1, e] and [e, h] respectively, where e is the embedding size and h is the hidden size. According to the equation in the blocking section, with k = e as the shared (inner) dimension, the computational intensity of the MM between these two matrices is:

$$q = \frac{2k}{k + \frac{k}{n_r} + 2} \approx 2, \quad \text{given } 2 \ll n_r \ll k$$
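The formula can be checked numerically with a small sketch. The function below assumes a simple counting model (read the m x k input once, read the k x n_r weight block once, and read plus write the m x n_r output block); with m = 1 it reduces to the expression above. The generalization to m rows is an assumption added here to show how stacking more input rows raises the intensity.

```python
def gemm_block_intensity(m, k, n_r):
    """FLOPs per element moved for an (m x k) times (k x n_r) block product.

    Assumed data-movement model: m*k input elements, k*n_r weight elements,
    and 2*m*n_r output elements (one read and one write)."""
    flops = 2 * m * k * n_r
    moved = m * k + k * n_r + 2 * m * n_r
    return flops / moved


# Batch size 1 (a single input vector): intensity stays just below 2.
print(gemm_block_intensity(1, 1024, 16))
# Stacking 8 input rows shares each weight element across 8 rows.
print(gemm_block_intensity(8, 1024, 16))
```

With m = 1 every weight element is used for only one multiply-add, which is why the intensity cannot exceed about 2 no matter how the weight matrix is blocked.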

The input x is a small matrix with one of its dimensions equal to 1, so the MM degenerates into a vector-matrix product. The resulting poor computational intensity is the key reason the GEMM calculation slows down so much. GEMM kernels in existing NN toolkits work well on large matrices but poorly on the small matrices found in RNNs. During training, the small-matrix issue can be compensated for by using a large batch size; however, this is not applicable to inference.

Using the STREAM benchmark tool, the observed data bandwidth between the L3 and L2 caches on a Xeon E5-2650 machine is 62.5 GigaFloats/s (250 GB/s). With the computational intensity of 2 derived above, the best achievable LSTM performance on this machine is at most 125 GFLOPS (62.5 x 2), which is less than 8% of the CPU's theoretical peak of 1.69 TFLOPS. As a result, the CPU sits mostly idle, waiting for data.
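The bound in the paragraph above is an instance of the roofline model: attainable performance is the smaller of the compute peak and bandwidth times intensity. This sketch reproduces the arithmetic using only the numbers quoted in the text (250 GB/s, 4-byte floats, 1.69 TFLOPS peak, intensity 2).

```python
def attainable_gflops(peak_gflops, bandwidth_gfloats, intensity):
    """Roofline bound: performance is capped either by the compute peak or by
    how fast data arrives (bandwidth * FLOPs per element moved)."""
    return min(peak_gflops, bandwidth_gfloats * intensity)


bandwidth = 250.0 / 4   # 250 GB/s of 4-byte floats -> 62.5 GigaFloats/s
peak = 1690.0           # 1.69 TFLOPS theoretical peak, in GFLOPS
perf = attainable_gflops(peak, bandwidth, 2.0)
print(f"{perf} GFLOPS, {perf / peak:.1%} of peak")  # 125.0 GFLOPS, 7.4% of peak
```

Since the bandwidth term is the minimum here, raising the intensity (not further tuning the arithmetic) is what moves performance toward the peak.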

Fuse Matrices

The phases in LSTM could be divided into two categories:

  • Time-dependent phase. This phase consists of the MMs that depend on earlier time slots, e.g., the MMs that take the previous hidden state h_{t-1} as input.

  • Time-independent phase. This phase contains no MM that depends on other time slots, e.g., the MMs that take the current input x_t as input.

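In the time-independent phase, the four weight matrices can be fused into a single larger matrix [W_i, W_f, W_c, W_o] of shape [e, 4h], since they share the same input x_t. Moreover, because no MM in this phase depends on an earlier time slot, the inputs of all time slots can be stacked into one [seq_len, e] matrix and multiplied by the fused weights in a single GEMM. With these adjustments, the low computational intensity caused by the small input x is relieved.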
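The fusion can be illustrated with a minimal pure-Python sketch, assuming the row-vector convention used in this section (x_t is [1, e], each W_g is [e, h]). The helper names are hypothetical; a real implementation would hand the single fused product to an optimized GEMM routine, and biases (omitted here) would be added per gate afterwards.

```python
def matmul(A, B):
    """Plain GEMM on lists of rows: A is [n, k], B is [k, m]."""
    cols = list(zip(*B))
    return [[sum(a * b for a, b in zip(row, col)) for col in cols] for row in A]


def fuse_gate_weights(Ws):
    """Concatenate gate matrices [W_i, W_f, W_c, W_o], each [e, h], column-wise into [e, 4h]."""
    e = len(Ws[0])
    return [[v for W in Ws for v in W[r]] for r in range(e)]


def time_independent_phase(X, Ws):
    """X stacks x_1..x_T as rows ([seq_len, e]). One GEMM replaces 4 * seq_len
    vector-matrix products. Returns out[t][g], the pre-activation slice of gate g."""
    h = len(Ws[0][0])
    Z = matmul(X, fuse_gate_weights(Ws))  # [seq_len, 4h]
    return [[row[g * h:(g + 1) * h] for g in range(len(Ws))] for row in Z]
```

Each weight element is now reused seq_len times within one GEMM instead of once per vector-matrix product, which is what lifts the intensity above 2. The results are identical to computing each gate and time slot separately.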

Loop Tiling for Weight Matrix

In the time-dependent phase, the recurrent weight matrix is needed to compute c_t and h_t in every time slot. To exploit this data reuse across the sequence and improve temporal locality, we need to choose a blocking size for the weight matrix such that its blocks are not evicted from the cache until they are no longer needed.

We can achieve this with loop tiling that iterates over blocks of the weight matrix in the outer loop. In that case, the data movement for W drops from seq_len*|W| to |W|, which is independent of the sequence length.
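Why the blocking size matters can be seen with a toy cache model. This sketch assumes an LRU cache measured in weight blocks and a computation that touches all weight blocks in the same order at every time slot; both the cache policy and the block granularity are simplifying assumptions, not a model of a specific CPU.

```python
from collections import OrderedDict


def weight_block_loads(seq_len, num_blocks, cache_capacity):
    """Count weight-block fetches into an LRU cache (capacity in blocks) when
    each time slot sweeps over all num_blocks blocks in order."""
    cache = OrderedDict()
    loads = 0
    for _ in range(seq_len):
        for blk in range(num_blocks):
            if blk in cache:
                cache.move_to_end(blk)          # hit: mark as most recently used
            else:
                loads += 1                      # miss: fetch from the next level
                cache[blk] = True
                if len(cache) > cache_capacity:
                    cache.popitem(last=False)   # evict the least recently used block
    return loads


print(weight_block_loads(100, 8, 8))  # 8: weights stay resident, |W| traffic
print(weight_block_loads(100, 8, 7))  # 800: cyclic sweep thrashes LRU, seq_len*|W| traffic
```

The second case shows the cliff: once the working set exceeds the cache by even one block, a cyclic access pattern evicts every block before its reuse.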

LSTM Optimization

Weight-Centric Streamlining (WCS)

To guarantee this reuse, the computation must be performed where the weights are, i.e., the mapping between parallel partitions and the cores that execute them must not change across time-dependent phases (TDPs).

In essence, the authors use ParallelOuterRNN: it creates the partitions outside the RNN loop and binds each partition to a thread id (i.e., a compute core), so that each partition can keep its weights in that core's private L2 cache across phases and time slots.
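The following is a minimal sketch of the weight-centric idea, not DeepCPU's or onnxruntime's implementation. It assumes a hypothetical partitioning by hidden units: each worker owns the U rows of all four gates for its slice of hidden units, copies them once before the time loop, and only the new hidden state is shared between workers, with a barrier per time slot. The input pre_x[t][g] is assumed to hold the precomputed time-independent part W_g x_t + b_g. CPython threads are neither pinned to cores here nor run in parallel under the GIL, so this shows the ownership and synchronization pattern only, not the speedup.

```python
import math
import threading


def sigmoid(v):
    return 1.0 / (1.0 + math.exp(-v))


def wcs_lstm(U, pre_x, h0, c0, num_workers):
    """Hypothetical weight-centric LSTM loop. U[g] is [h][h] (rows = hidden units);
    pre_x[t][g] is a length-h list holding W_g x_t + b_g for time slot t."""
    hidden, seq_len = len(h0), len(pre_x)
    buffers = [list(h0), [0.0] * hidden]   # double buffer for h_{t-1} / h_t
    c = list(c0)                           # each worker writes only its own units
    barrier = threading.Barrier(num_workers)

    def worker(lo, hi):
        # The partition is created outside the RNN loop: this worker's weight
        # rows are copied once and reused for every time slot.
        my_rows = {g: [list(U[g][j]) for j in range(lo, hi)] for g in "ifoc"}
        for t in range(seq_len):
            h_prev, h_next = buffers[t % 2], buffers[(t + 1) % 2]
            for local, j in enumerate(range(lo, hi)):
                z = {g: pre_x[t][g][j] + sum(w * hk for w, hk in zip(my_rows[g][local], h_prev))
                     for g in "ifoc"}
                c[j] = sigmoid(z["f"]) * c[j] + sigmoid(z["i"]) * math.tanh(z["c"])
                h_next[j] = sigmoid(z["o"]) * math.tanh(c[j])
            barrier.wait()  # all of h_t must be ready before any worker starts t + 1

    bounds = [(p * hidden // num_workers, (p + 1) * hidden // num_workers)
              for p in range(num_workers)]
    threads = [threading.Thread(target=worker, args=bnd) for bnd in bounds]
    for th in threads:
        th.start()
    for th in threads:
        th.join()
    return buffers[seq_len % 2], c
```

Because each hidden unit is always computed by the same worker in the same order, the result does not depend on num_workers; only the placement of the weight rows changes.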

Implementation

The techniques in this work have already been applied to onnxruntime. You can check the implementation in the onnxruntime source code.

References

[1] Zhang, M., Rajbhandari, S., Wang, W. and He, Y., 2018. DeepCPU: Serving RNN-based deep learning models 10x faster. In 2018 USENIX Annual Technical Conference (USENIX ATC 18) (pp. 951-965).
