Recurrent Neural Network(RNN)の基礎を学んだ
ニューラルネットワークが現在盛んに研究されているが、その中でも系列データを扱うRecurrent Neural Network(RNN)について学習した。 CNNは研究でもよく取り扱っていたが、RNNは手付かずだったので。。
教材はこれとか
- 作者: 岡谷貴之
- 出版社/メーカー: 講談社
- 発売日: 2015/04/08
- メディア: 単行本(ソフトカバー)
- この商品を含むブログ (13件) を見る
これとか
Deep Learning (Adaptive Computation and Machine Learning series)
- 作者: Ian Goodfellow,Yoshua Bengio,Aaron Courville
- 出版社/メーカー: The MIT Press
- 発売日: 2016/11/18
- メディア: ハードカバー
- この商品を含むブログ (1件) を見る
を見た。Bengio先生の方は日本語訳プロジェクトが進行中のようで、暫定版がScribdで公開されている。もう少し待てば製本版が買えるようになるかな。
http://www.deeplearningbook.me/
系列データ
RNNは系列データを対象とする。わかりやすい例だと音声、文章など。このようなデータを一般に
と表す。は可変であり、添え字は大体を使う。
便宜上
を時刻と呼ぶことが多いが、厳密に時刻と対応している必要はなく、順序に意味があるデータならばなんでもいい(画像、動画、ボリュームデータなど)。多次元への拡張はMulti-dimentional Recurrent Neural Networkで取り上げられているが、またおいおい見ていく。とりあえずここでは、音声の波形のような一次元データを考えるものとする。
RNNの仕組み(順伝播)
RNNにも色々な種類があるようだが、最も単純な回帰構造を持つネットワークは以下のような感じ。
文字通り、時刻tの隠れ層の出力が時刻t+1の隠れ層に影響するような構造をしている。入力がd次元T系列だとする。入力層から隠れ層につながる重みがU, 隠れ層から隠れ層の重みがW, 隠れ層から出力層の重みがVで表されている。Wは各時刻で共通の重みを用いる(ただし時々刻々と更新されていく)。
入力を受け取ってを出力、を受け取ってを出力...といった手順で、入力の系列長と同じ長さの系列を出力する(CTC, connectionist temporarl classificationという構造を使えばこの限りではない)。
入力層、隠れ層、出力層の添え字をそれぞれi, j, kで表し、時刻tにおけるネットワークの各ユニットへの入力を
、隠れ層の入出力を、出力層の入出力をとする。またノードiからjへの重みを、ノードjからj'を、ノードjからkをと書く。
時刻tの隠れ層への入力は以下で表される。
但しバイアス項は重みと入力ににまとめて省略しているものとする。隠れ層の出力は
は活性化関数。 から時刻tにおける出力を随時計算していくわけだが、t=1に関しては初期値を定める必要がある。通常はとするらしい。まとめると
RNNの出力は隠れ層の出力と重みから以下のように計算される。
まとめると
と書ける。この出力を用いて誤差関数を計算する。これは通常のネットワークと同様で、例えば分類問題ならはsoftmaxになり、誤差関数はcross entropyを計算する。
出力系列に対する教示信号とするとき、cross entropyは以下のようになる
この場合、単一のラベルを系列長分だけ推定するネットワークとなる。
逆伝搬
RNNの学習は通常のニューラルネット同様に勾配降下により行われるが、重みによる微分を求める必要がある。主にRTRL法(real time recurrent learning)、BPTT法(back propagation through time)の二つが用いられる。後者のほうが
よりシンプルということで、BPTT法についてみていく。
RNNを時刻で展開すると順伝搬型NNとみなすことができるということで、以下のような伝播を行う。
展開したネットワークを通常の順伝播ネットワークとみなすことで誤差逆伝播を実現する。ただし気を付けることが二つある。一つは、
展開すると各時刻で別々の層に誤差を伝播しているように見えるが、実際に更新されるのは同じWである。二つ目は、時刻tの隠れ層に伝播される誤差は時刻tの出力層と、時刻t+1の隠れ層からの二つであること。
通常のネットワークの第l+1層から第l層への誤差、入力での微分は、
と書けることから、RNNについてもこの考えを導入する。時刻tの出力層のユニットkにおけるデルタを
と書き、同時刻の中間層のユニットjのデルタを
と書くことにする。時刻tにおける隠れ層のユニットは、tでの出力層とt+1の隠れ層のユニットとつながりがあるため、は
のように計算できる。から始め、tを小さくしながら繰り返し計算すると隠れ層の誤差が求められる。
各層のデルタが計算できたため、誤差の各層の重みによる微分を計算できるようになった。重みはU, W, Vの三つがあるため、それぞれに関してみていく。は各tの隠れ層のユニットjへの入力 にのみ含まれるため、
となる。同様にして、は、各tの隠れ層のユニットjへの入力にのみ含まれるため、
となり、は各tの出力層のユニットkへの入力にのみ含まれるので、
と求まる。各重みの誤差勾配を計算したあとは、重みを勾配方向に更新する通常の勾配降下法が適用される。
まとめ
以上が最も単純なRNNの順伝搬、逆伝搬の構造である。理論上は時刻tにおける出力はそれまでの系列の入力の影響を受け、データに潜むcontextを学習できる・・・ということになるが、実際は系列が長くなりすぎると勾配消失が起こるため、ごくわずかな期間の情報しか反映できない問題がある。これを解決した手法がGated recurrent unit(GRU)とかLong short-term memory unit(LSTM)とかあり、実際に成果を上げているモデルはこれらを用いている。