一步步教你理解LSTM

重磅乾貨,第一時間送達

1 什麼是LSTM

LSTM全名是Long Short-Term Memory,長短時記憶網路,可以用來處理時序資料,在自然語言處理和語音識別等領域應用廣泛。和原始的迴圈神經網路RNN相比,LSTM解決了RNN的梯度消失問題,可以處理長序列資料,成為當前最流行的RNN變體。

2 LSTM應用舉例

假設我們的模型的輸入是依次輸入一句話的每個單詞,我們需要對單詞做分類,比如有兩句話:(1)arrive Beijing on November 2nd,這裡的Beijing是目的地;(2)leave Beijing on November 2nd,這裡的Beijing是出發地。如果用普通的神經網路,輸入是‘Beijing’,那麼輸出一定就是確定的,但事實上我們希望在‘Beijing’前面是‘arrive’時,‘Beijing’被識別為目的地,在‘Beijing’前面時‘leave’時,‘Beijing’被識別為出發地。這裡LSTM就會派上用場,因為LSTM可以記住歷史資訊,在讀到‘Beijing’時,LSTM還知道在前面是‘arrive’還是‘leave’,根據歷史資訊來做出不同的判斷,即使輸入是相同的,輸出也會不同。

3 LSTM結構剖析

普通的神經元是一個輸入,一個輸出,如圖所示:

對於神經元h1來講,輸入就是x1,輸出就是y1,LSTM做的就是把普通的神經元,替換成LSTM的單元。

一步步教你理解LSTM

從圖中可以看到LSTM有四個輸入,分別是input(模型輸入),forget gate(遺忘門),input gate(輸入門),以及output gate(輸出門)。因此相比普通的神經網路,LSTM的引數量是它們的4倍。這3個門訊號都是處於0~1之間的實數,1代表完全開啟,0代表關閉。遺忘門:決定了前一時刻中memory中的是否會被記住,當遺忘門開啟時,前一刻的記憶會被保留,當遺忘門關閉時,前一刻的記憶就會被清空。輸入門:決定當前的輸入有多少被保留下來,因為在序列輸入中,並不是每個時刻的輸入的資訊都是同等重要的,當輸入完全沒有用時,輸入門關閉,也就是此時刻的輸入資訊被丟棄了。輸出門:決定當前memroy的資訊有多少會被立即輸出,輸出門開啟時,會被全部輸出,當輸出門關閉時,當前memory中的資訊不會被輸出。

4 LSTM公式推導

有了上面的知識,再來推導LSTM的公式就很簡單了,圖中代表遺忘門,代表輸入門,代表輸出門。C是memroy cell,儲存記憶資訊。代表上一時刻的記憶資訊,代表當前時刻的記憶資訊,h是LSTM單元的輸出,是前一刻的輸出。

一步步教你理解LSTM

遺忘門計算:

這裡的是把兩個向量拼接起來的意思,用sigmoid函式主要原因是得到有個0~1之間的數,作為遺忘門的控制訊號。

輸入門計算:

當前輸入:

當前時刻的記憶資訊的更新:

從這個公式可以看出,前一刻的記憶資訊透過遺忘門,當前時刻的輸入透過輸入門,加起來更新當前的記憶資訊。

輸入門計算:

LSTM的輸出,是由輸出門和當前記憶資訊共同決定的:

這樣我們就明白了LSTM的前向計算過程。有了LSTM前向傳播演算法,推導反向傳播演算法就很容易了, 透過梯度下降法迭代更新我們所有的引數,關鍵點在於計算所有引數基於損失函式的偏導數,這裡就不細講了。

小結

LSTM雖然結構複雜,但是隻要理順了裡面的各個部分和之間的關係,是不難掌握的。在實際使用中,可以藉助演算法庫如Keras,PyTorch等來搞定,但是仍然需要理解LSTM的模型結構。

參考文獻

https://www。youtube。com/watch?v=rTqmWlnwz_0&index=35&list=PLJV_el3uVTsPy9oCRY30oBPNLCo89yu49

https://zybuluo。com/hanbingtao/note/581764

http://www。cnblogs。com/pinard/p/6519110。html

http://blog。echen。me/2017/05/30/exploring-lstms/

下載1:OpenCV-Contrib擴充套件模組中文版教程

下載2:Python視覺實戰專案52講

下載3:OpenCV實戰專案20講

交流群