Attention#
#transformer, #attention
Note
Attention a vector of importance weights
in order to predict or infer one element, such as pixel in an image or a word in a sentence.
how strongly it is correlated with(or attends to) other elements and take the sum of their values weighted by the attention vector as the approximation of the target
์ด๋ฏธ์ง์ ํ ์คํธ ์์์ ํ๋์ ์์๋ pixel, word์ผ ๊ฒ์ด๋ค. ์ด์ ๋ํด์ attention์ ๋ค๋ฅธ ๊ฒ๋ค๊ณผ ์ฐ๊ด์ฑ์ด ์ผ๋ง๋ ๊ฐํ ์ง๋ฅผ ํํํ๋ ์ญํ ์ ํด์ค๋ค. ๊ทธ๋ฐ๋ฐ ์ด์ ๋ํด์๋ ๊ทธ๋์ seq2seq๊ฐ ๋์ผํ ์ญํ ์ ํ์์ง๋ง long-term problem ๋๋ฌธ์ ๋ฉ๋ฆฌ ์๋ ์์์ ๋ํ ๊ธฐ์ต๋ ฅ ๊ฐ์๋ก attention ์ฌ์ฉํ๊ฒ ๋๋ค. seq2seq๋ 2014๋ ์ attention์ 2015๋ ์ transformer๋ 2016๋ ์ ๊ฐ๊ฐ ๋์๋ค.
the secret sauce invented by attention is to create shortcuts between the context vector and the entire source input. The weights of these shortcut connections are customizable for each output element.
why attention?#
์ ์ด์ attention์ ๊ธฐ๊ณ ๋ฒ์ญ ์์ ์์ ๋์์ผ๋ก ๋์๋ ์๊ณ ๋ฆฌ์ฆ์ด๋ค. ๊ธฐ๊ณ ๋ฒ์ญ ์์ ์ ๋ชฉํ๋ A ์ธ์ด ๋ฌธ์ฅ์ B ์ธ์ด๋ก ์๋์ผ๋ก ๋ฒ์ญํ๋ ๊ฒ์ด๋ค. ์ด๊ธฐ ๊ธฐ๊ณ ๋ฒ์ญ ๋ชจ๋ธ์์๋ ์ ์ฒด ์์ค ๋ฌธ์ฅ์ ํ๋์ ๊ณ ์ ๋ context ๋ฒกํฐ๋ก ์์ฝํ๋ ค๊ณ ์๋ํ๋ค. ๊ทธ๋ฌ๋ ์ด ๋ฐฉ์์ ๊ธด ๋ฌธ์ฅ ์ฒ๋ฆฌ์ ๋ฒ์ญ ํ์ง์ ํ๊ณ๊ฐ ์กด์ฌํ๋ค.
Attention Mechanism์ ์ธ์ฝ๋์ ๋ง์ง๋ง ์จ๊ฒจ์ง ์ํ ํ๋(last hidden state)๋ก๋ถํฐ ๋จ์ผ context ๋ฒกํฐ๋ฅผ ๋ง๋๋(seq2seq๋ฐฉ์) ๋์ , ์์ค ์ ๋ ฅ ์ ์ฒด์ context ๋ฒกํฐ ์ฌ์ด์ ์ง๋ฆ๊ธธ ์ฐ๊ฒฐ์ ๋ง๋๋ ๊ฒ์ด๋ค. Attention์ ์ด๋ฌํ shortcut connection weight๋ฅผ ๊ฐ ์ถ๋ ฅ ์์(๋ฒ์ญ์ ๊ฐ ์์น์ ๋ํ ํ ํฐ)์ ๋ํด์ ์ปค์คํฐ๋ง์ด์ง ํ ์ ์๋๋ก ํ๋ค. ์ด๋ ๋ชจ๋ธ์ด ์ถ๋ ฅ ์์ฑ ์ค์ ์ด๋ค ๋ถ๋ถ์ ๋ ์ง์ค ํด์ผ ํ๋์ง๋ฅผ ํ์ตํ ์ ์๊ฒ ํ๋ค.
context ๋ฒกํฐ๋ ์ ์ฒด ์์ค ์ ๋ ฅ ์ํ์ค์ ๋ํด์ ์ ๊ทผํ ์ ์๋ค. ๋ฐ๋ผ์ ๊ธด ์์ค ๋ฌธ์ฅ์ ์ฒ๋ฆฌํ๋๋ฐ ์์ด์ ๋ฌธ์ ๊ฐ ์์ผ๋ฉฐ, ๋ชจ๋ธ์ ์์ค์ ๋์ ๊ฐ์ ์ ๋ ฌ์ ํ์ตํ๊ณ ์ ์ดํ๋ค. ์ฌ๊ธฐ์ ์ ๋ ฌ์ ํ์ตํ๋ค๋ ๊ฒ์ ๋จ์ด ๊ฐ ์์ ๊ด๊ณ๋ฅผ ํ์ตํ๋ค๋ ๊ฒ์ ์๋ฏธํ๋ค. ์ธ์ด๋ง๋ค ๋ฌธ๋ฒ์ ๋ฐ๋ผ ์ด์๊ณผ ๋ชฉ์ ์ด๊ฐ ์์ผ ํ ์์น๊ฐ ๋ค๋ฅผ ๊ฒ์ธ๋ฐ, ์ธ์ด๋ ๋ฌ๋ผ๋ ์๋ฏธ๊ฐ ๊ฐ์ ๋ฌธ์ฅ์์๋ ์์๋ค์ด ์์น๊ฐ ๋ค๋ฅผ ์ ์๋ค. ๊ทธ์ ๋ํ ํ์ต์ ํ๋ค๋ ๊ฒ์ด๋ค. ์ด๋ฌํ ๊ฒ์ ๋ํ ํ์ต์ ํ๋ ๊ณผ์ ์ด ๊ณง ์ธ์ด๋ฅผ ๊ธฐ๊ณ๊ฐ ํ์ตํ๋ ๊ฒ์ผ๋ก ๋ณผ ์ ์๋ค. context ๋ฒกํฐ๋ ์ฌ๋ฌ ์ ๋ณด ์์ค๋ฅผ ๊ฒฐํฉํ์ฌ ์ถ๋ ฅ์ ์์ฑํ๊ฒ ๋๋ค. context ๋ฒกํฐ๋ ์ถ๋ ฅ ์์๋ง๋ค ์ ๋ฐ์ดํธ๋๋ฉฐ, ๊ฐ ์์์ ๋ํ ๋ฒ์ญ์ ๋ํ ์ ๋ณด๋ฅผ ์ ๊ณตํ๋ค.
Encoder hidden state : ๊ธฐ์กด ์ ๋ณด
Decoder hidden state : ์ด์ ์ถ๋ ฅ ์ ๋ณด
์์ค์ ๋์ ๊ฐ์ ์ ๋ ฌ ์ ๋ณด
Shortcut#
shortcut connection์ ๋ฃ๋ ์๊ฐ ๋ ์ค๋ฅด๋ ๋ ผ๋ฌธ์ด ์๋ค. ๋ฐ๋ก resnet์ด๋ค. resnet์์๋ deep nn์์ ๋ฐ์ํ๋ gradient lossing ๋ฌธ์ ๋ฅผ ํด๊ฒฐํ๊ธฐ ์ํด ๋์จ architecture๋ก ๊ต์ฅํ ์ผ์ธ์ด์ ๋ ํ๋ ๊ตฌ์กฐ์ด๋ค. ์ง๊ธ๊น์ง๋ ๋ง์ architecture๋ค์ backbone์ผ๋ก ์ฌ์ฉ๋๊ธฐ๋ ํ ๋งํผ ๊ทธ architecture๊ฐ ๊ฐ์ง๋ ๊ตฌ์กฐ์ ๊ฐ์ ์ด ํฌ๋ค. ๊ทธ ์ค์์๋ ์ค์ํ๋ ๊ฒ์ด ์ด shorcut์ด๋ค.
Resnet์์์ shortcut์ Residual connection
์ด๋ผ๊ณ ๋ ๋ถ๋ฆฐ๋ค. ์ด๋ ๋คํธ์ํฌ ๋ ์ด์ด๋ฅผ ๊ฑด๋๋ฐ๊ณ ์
๋ ฅ๊ณผ ์ถ๋ ฅ์ ์ง์ ์ฐ๊ฒฐํจ์ผ๋ก์จ gradient๊ฐ ์ํํ๊ฒ ํ๋ฅผ ์ ์๋๋ก ๋์๋ค.
Attention์์์ shortcut์ ์ธ์ฝ๋์ ๋์ฝ๋ ์ฌ์ด์ ์ฐ๊ฒฐ์ ๋งํ๋ค. ๊ฐ ์ถ๋ ฅ ์์๊ฐ ์์ค ์ ๋ ฅ์ ๋ชจ๋ ์์์ ๋ํ sum of weight๋ฅผ ๊ณ์ฐํ๊ณ ์ด๋ฅผ ์ฌ์ฉํ์ฌ ์ถ๋ ฅ์ ์์ฑํ๋ ๋ฉ์ปค๋์ฆ์ ์ค๋ช ํ๋ค. ์ด๋ ๋ชจ๋ธ์ด ์ ๋ ฅ๊ณผ ์ถ๋ ฅ ๊ฐ์ ๊ด๊ณ๋ฅผ ํ์ตํ๋๋ฐ ์ฌ์ฉํ๋ค๋ ์ ์์ ๋ค๋ฅด๋ค.
Definition#
\(\mathbf{x}\)๋ source sequnece๋ฅผ ๋ํ๋ด๊ณ , \(\mathbf{y}\)๋ target sequence๋ฅผ ๋ํ๋ธ๋ค. ๊ฐ๊ฐ n,m์ ๊ธธ์ด๋ฅผ ๊ฐ์ง๋ค. bold์ฒด ์ฒ๋ฆฌ๋ ๊ฒ์ vector๋ฅผ ๋ํ๋ธ๋ค๋ ๊ฒ์ ํญ์ ๊ธฐ์ตํ์.
encoder๋ bidirectional rnn์ด๋ค. forward hidden state๋ \(\overrightarrow{h_i}\)์ด๊ณ , backward hidden state๋ \(\overleftarrow{h_i}\)์ด๋ค. \(i\)๋ 1~n์ ๋ฒ์๋ฅผ ๊ฐ์ง๋ ์๊ฐ ๋ณ์๋ฅผ ๋ํ๋ธ๋ค. encoder๋ ๊ฒฐ๊ตญ ํด๋น ๋ ๋ณ์๋ฅผ concatenationํ ๊ฐ์ด๋ผ๊ณ ๋ณผ ์ ์๋ค. (์ฝ๊ฐ์ transpose๋ฅผ ํ๊ณ โฆ)
decoder๋ hidden state ๊ฐ์ผ๋ก \(s_t=f(s_{t-1},y_{t-1},c)\)๋ฅผ ๊ฐ์ง๋ค. t ์์ ์์์ context vector c๋ฅผ ๋ฃ์ด์ฃผ๊ณ , ๋ฐ๋ก ์ ์์ (t-1)์์์ output word๊ฐ๊ณผ ์ค์ ํ๊ฒ ๊ฐ y(t-1)๋ฅผ ๋ฃ์ด์ค๋ค. ๊ทธ๋ ๊ฒ ํจ์ผ๋ก์จ t์์ ์ hidden state \(s_t\)๋ฅผ ๊ณ์ฐํ๋ค.
๊ธฐ์กด์ ๋ฐฉ๋ฒ๋ก ์ ๋ชจ๋ hidden state๋ฅผ ๊ณ์ฐํ ๋ ๋์ผํ context vector๋ฅผ ์ฌ์ฉํ๋ ๋ฐ๋ฉด, Bahdanau et al.(2014)์ hidden state๋ฅผ ๊ณ์ฐํ ๋๋ง๋ค ์๋กญ๊ฒ ๋ง๋ค์ด์ง context vector๋ฅผ ์ฌ์ฉํ๊ฒ ๋๋ค. ์ด๋ฅผ ์ํด alignment model(score function) \(a\)๋ฅผ ์ฌ์ฉํ๋ค.
\(a\)๋ฅผ ์ด์ฉํ์ฌ \(s_{t-1}, h_j\)๊ฐ ์ผ๋ง๋ ์ ์ฌํ์ง๋ฅผ ๋ํ๋ด๋ attention score \(e_{t-1,j}\)๋ฅผ ๊ณ์ฐํ๋ค. ์ด๋ ๋ค์ํ a์ ์ ํ์ด ์ฌ์ฉ๋ ์ ์์ผ๋ฉฐ Bahdanau et al.(2014)์์ ์ฌ์ฉํ ๊ฒ์ ์๋์ ๊ฐ๋ค
softmax๋ฅผ ์ด์ฉํด์ attention score๋ฅผ attention weight(alignment vector) \(a_{t-1,j}\)๋ก ๋ณํ
๊ณ์ฐ๋ attention weight์ encoder ๋ถ๋ถ์ hidden state \(h_j\)๋ฅผ ๊ฐ์คํ๊ท ํ์ฌ t์์ ์ ์ฌ์ฉ๋ context vector๋ฅผ ๊ณ์ฐํ๋ค.