Deberta V3#
DEBERTAV3: IMPROVING DEBERTA USING ELECTRA-STYLE PRE-TRAINING WITH GRADIENT-DISENTANGLED EMBEDDING SHARING
Published as a conference paper at ICLR 2023 Paper PDF
Abstract
This paper presents a new pre-trained language model, DeBERTaV3
, which improves the original DeBERTa model by replacing masked language modeling(MLM) with replaced token detection (RTD)
, a more sample-ef๏ฌcient pre-training task. Our analysis shows that vanilla embedding sharing in ELECTRA hurts training ef๏ฌciency and model performance, because the training losses of the discriminator and the generator pull token embeddings in different directions, creating the โtug-of-warโ dynamics. We thus propose a new gradient-disentangled embedding sharing method that avoids the tug-of-war dynamics, improving both training ef๏ฌciency and the quality of the pre-trained model. We have pre-trained DeBERTaV3 using the same settings as DeBERTa to demonstrate its exceptional performance on a wide range of downstream natural language understanding (NLU) tasks. Taking the GLUE benchmark with eight tasks as an example, the DeBERTaV3 Large model achieves a 91.37% average score, which is 1.37% higher than DeBERTa and 1.91% higher than ELECTRA, setting a new state-of-the-art (SOTA) among the models with a similar structure. Furthermore, we have pre-trained a multilingual model mDeBERTaV3 and observed a larger improvement over strong baselines compared to English models. For example, the mDeBERTaV3 Base achieves a 79.8% zero-shot cross-lingual accuracy on XNLI and a 3.6% improvement over XLM-R Base, creating a new SOTA on this benchmark. Our models and code are publicly available at microsoft/DeBERTa.
๋ฌธ์ ์ค์ #
PLMs(Pre-trained Language Models)๋ฅผ scaling up์ ํด์ ํ๋ผ๋ฏธํฐ๋ฅผ ์์ญ์ต ์๋ฐฑ๋ง ๋จ์๋ก ๋๋ฆฌ๋ ๊ฒ์ด ํ์คํ ์ฑ๋ฅํฅ์์ด ์์๊ณ ์ง๊ธ๊น์ง์ ์ฃผ๋์ ์ธ ๋ฐฉ๋ฒ์ด์์ง๋ง, ๋ ์ฃผ์ํ ๊ฒ์ parameter๋ฅผ ์ค์ด๊ณ computation cost๋ฅผ ์ค์ด๋ ๊ฒ์ด๋ผ๊ณ ๋งํ๋ค.
Improving Efficency
incorporating disentangled attention(imporoved relative-position encoding mechanism)
DeBERTA๋ 1.5B๊น์ง scaling up์ ํจ์ผ๋ก์จ SuperGLUE์์ ์ฒ์์ผ๋ก ์ฌ๋ performance๋ฅผ ๋์ด์ฐ๋ค.
Replaced Token Detection(RTD) vs Masked language modeling(MLM)
proposed by ELECTRA(2020)
transformer encoder๋ฅผ ์ค์ผ๋ token๋ฅผ ์์ธกํ๋๋ฐ ์ฌ์ฉํ๋ BERT(MLM)์๋ ๋ค๋ฅด๊ฒ
RTD๋ generator, discriminator ๋ฅผ ์ฌ์ฉํ๋ค. generator๋ ํค๊น๋ฆฌ๋ ์ค์ผ์ ๋ง๋ค์ด๋ด๊ณ , discriminator๋ generator๊ฐ ๋ง๋ ์ค์ผ๋ ํ ํฐ์ original inputs๊ณผ ๊ตฌ๋ถ์ ํด๋ด๋ คํ๋ค. ๋ง์น GAN(Generative Adversarial Networks)๋ ์๋นํ ๋น์ทํ ๋ฉด์ด ์๋ค. ์ฐจ์ด์ ์ ๋ํด์ ๋ถ๋ช ํ๊ฒ ํ์.
์ฌ๊ธฐ์ ์ด์ ์ DeBERTa์์ V3๋ก ๋์๊ฐ๋ฉด์ ๋ฐ๋ ์ ์ผ๋ก ๋ ๊ฐ์ง๋ก ๊ผฝ๋๋ค. ํ๋๋ ์์์ ๋งํ BERT์ MLM์ ELECTRA ์คํ์ผ์ RTD(where the model is trained as a discriminator to predict whethre a token in the corrupt input is either original or replaced by a generator)๋ก ๋ฐ๊พธ๋ ๊ฒ์ด๊ณ , ๋ค๋ฅธ ํ๋๋ new embedding sharing method์ด๋ค. ELECTRA์์ generator discriminator๋ ๊ฐ์ token embedding์ ๊ณต์ ํ๋ค.
๊ทธ๋ฌ๋ ์ฐ๊ตฌ์๋ค์ ๋ณธ์ธ๋ค์ ์ฐ๊ตฌ์ ๋ฐ๋ฅด๋ฉด ์ด๊ฒ์ ํ์ต ํจ์จ์ฑ ๋ฉด์ด๋ ๋ชจ๋ธ์ ์ฑ๋ฅ๋ฉด์์ ๋ถ์ ์ ์ธ ์ํฅ์ ์ค๋ค๊ณ ๋งํ๋ค. training losses of the discriminator and the generator pull token embeddings into opposite directions
์ฆ ์์ฑ๊ธฐ์ ๋ถ๋ฅ๊ธฐ์ ํ์ต์ ๋ฐฉํฅ์ฑ์ด ๋ฐ๋์ด๊ธฐ ๋๋ฌธ์ ํ์ต loss๊ฐ ๊ฐํก์งํกํ ์ ๋ฐ์ ์๋ค๋ ๊ฒ์ด๋ค. ๊ทธ๋ผ ์๊ฐํด๋ณผ ์ ์๋ ๊ฒ์ด ๋ค์ ๋๊ฐ์ loss๋ฅผ ๋ง๋ค์ด์ ํ์ต ๋ฐฉํฅ์ฑ์ ๋ฐ๋๋ก ๊ฐ ์๋๋ก ํ์๊น ํ๋ ์ ์ ์๊ฐํด ๋ณผ ์ ์๊ฒ ๋ค. token embedding์ ๋ฌ๋ฆฌ ์ค๋ค๋ ๊ฒ์ด ์ด๋ค ์๋ฏธ์ธ์ง ๋ค์์ ํ์คํ๊ฒ ๋์์ผ ํ ๊ฒ์ด๋ค. MLM์ generator๋ฅผ token ์ค์์ ์๋ก ๊ด๋ จ์ด ์์ด๋ณด์ด๋ ๊ฐ๊น์ด ๊ฒ๋ค์ ์๋ก ์ก์๋น๊ธฐ๋ฉด์ ํ-์ต์ ์งํํ๊ฒ ๋๋ค. ํ์ง๋ง ๋ฐ๋ฉด์ RTD์ discriminator๋ ์๋ฏธ์ ์ผ๋ก ๊ฐ๊น์ด token์ ์ฌ์ด๋ฅผ ์ต๋ํ ๋ฉ๋ฆฌํ๋ ๋ฐฉํฅ์ผ๋ก ์ด์ง๋ถ๋ฅ ์ต์ ํ(๋ง๋ค ์๋๋ค)๋ฅผ ํ๊ณ pull their embedding์ ํ๊ฒ ๋จ์ผ๋ก์จ ๋์ฑ ๊ตฌ๋ถ์ ์ ํ ์ ์๋๋ก ํ๋ ๊ฒ์ด๋ค. ์ด๋ฌํ โ์ค๋ค๋ฆฌ๊ธฐ tug-of-warโ์ ๊ฐ์ ์ญํ์ด ํ์ต์ ๋ง์น๊ณ , ๋ชจ๋ธ ์ฑ๋ฅ์ ๋จ์ดํธ๋ฆฌ๋ ๊ฒ์ด๋ผ๊ณ ๋งํ๋ค. ๊ทธ๋ ๋ค๊ณ ๋ฌด์กฐ๊ฑด์ ์ผ๋ก seperated embedding์ ํ ์๋ ์๋ ๊ฒ์ด generator์ embedding์ discriminator์ ๋ค์ด์คํธ๋ฆผ taskํ์ต์ ํฌํจ์ํค๋ ๊ฒ์ด ๋์์ด ๋๋ค๊ณ ๋งํ๋ ๋
ผ๋ฌธ๋ ์์๊ธฐ ๋๋ฌธ์ด๋ค.
๊ทธ๋์ ๊ทธ๋ค์ด ์ ์ํ๋ ๊ฒ์ new gradient-disentangled embedding sharing(GDES) method
์ด๋ค. the generator shares its embeddings with the discriminator but stops the gradients from the discriminator to the generator embeddings. embedding sharing์ ์ฅ์ ๋ง์ ์ทจํ๋, ์ค๋ค๋ฆฌ๊ธฐ ์ญํ์ ํผํ ์ ์๋๋ก gradient์ ํ๋ฆ์ด discriminator์์ generator๋ก ํ๋ฅด์ง๋ ์๋๋ก ํ๋ ๋ฐฉ์์ธ ๊ฒ์ด๋ค.
Model Table#
Model |
Vocabulary(K) |
Backbone Parameters(M) |
Hidden Size |
Layers |
Note |
---|---|---|---|---|---|
128 |
1320 |
1536 |
48 |
128K new SPM vocab |
|
128 |
710 |
1536 |
24 |
128K new SPM vocab |
|
50 |
700 |
1024 |
48 |
Same vocab as RoBERTa |
|
50 |
350 |
1024 |
24 |
Same vocab as RoBERTa |
|
50 |
100 |
768 |
12 |
Same vocab as RoBERTa |
|
128 |
304 |
1024 |
24 |
128K new SPM vocab |
|
128 |
86 |
768 |
12 |
128K new SPM vocab |
|
128 |
44 |
768 |
6 |
128K new SPM vocab |
|
128 |
22 |
384 |
12 |
128K new SPM vocab |
|
250 |
86 |
768 |
12 |
250K new SPM vocab, multi-lingual model with 102 languages |
์ฐธ์กฐ
This is the model(89.9) that surpassed T5 11B(89.3) and human performance(89.8) on SuperGLUE for the first time. 128K new SPM vocab.
These V3 DeBERTa models are deberta models pre-trained with ELECTRA-style objective plus gradient-disentangled embedding sharing which significantly improves the model efficiency.
Background#
1. Transformer#
Transformer ๊ธฐ๋ฐ ์ธ์ด๋ชจ๋ธ๋ค์ \(L\)๊ฐ์ transformer block์ด ์์ฌ์ง ํํ๋ก ๊ตฌ์ฑ๋๋ค. ๊ฐ ๋ธ๋ฝ๋ค์ multi-head self-attention layer๋ค์ ํฌํจํ๊ณ ๊ทธ ๋ค๋ก์ fully-connected positional feed-forward network๊ฐ ๋ค๋ฐ๋ฅธ๋ค. ๊ธฐ์กด์ self-attention ๋ฉ์ปค๋์ฆ์ ๋จ์ด์ ์์น ์ ๋ณด๋ฅผ encodeํ๋๋ฐ๋ ์ ํฉํ์ง ์์๋ค. ๊ทธ๋์ ๊ธฐ์กด์ ์ ๊ทผ๋ฒ๋ค์ positional bias๋ฅผ ๊ฐ input word embedding์ ๋ํจ์ผ๋ก์จ, content์ position์ ๋ฐ๋ผ์ ๊ฐ์ด ๋ฌ๋ผ์ง๋ vector๋ก ํํํ๋ ค๊ณ ํ์๋ค. ์ด positional bias๋ absolute position embedding, relative position embedding ๋ฑ์ด ์์๋ค. ์๋์ ์์น ์๋ฒ ๋ฉ์ด ์ข์ ๊ฒฐ๊ณผ๋ค์ ์ต๊ทผ๊น์ง๋ ๋ณด์ฌ์ฃผ๊ณ ์๋ ์ถ์ธ๋ผ๊ณ ํ๋ค.
2. DeBERTa#
BERT๋ก๋ถํฐ ๋ ๊ฐ์ง ๊ฐ์ ์ ์ ๋ณด์ฌ์ค๋ค. ์ฐ์ DA(Disentengled Attention : ๋ถ๋ฆฌ๋ ์ดํ
์
), ๊ทธ๋ฆฌ๊ณ enhanced mask decoder์ด๋ค. ์ด์ ์ single vector๋ก ํ๋์ input word์ ๋ด์ฉ๊ณผ ์์ง์ ๋ณด๋ฅผ ํํํ๋ ค๋ ๊ฒ๊ณผ๋ ๋ค๋ฅธ๊ฒ, DA๋ ๋ ๊ฐ์ seperate vector๋ฅผ ์ฌ์ฉํ๋ค. one for the content and one for the position
. ๊ทธ๋ฌ๋ฉด์ DA ๋ฉ์ปค๋์ฆ์ ๋จ์ด๋ค ์ฌ์ด์ attention weight๋ disentangled matrices์ ์ํด์ ๊ณ์ฐ๋๊ณ ์ด๋ ๊ฐ๊ฐ ๋ด์ฉ๊ณผ ์๋์ ์์น ๋ ๊ฐ๊ฐ ๊ฐ๊ฐ์ ํ๋ ฌ๋ก ๋ค๋ฅด๊ฒ ๊ณ์ฐ๋๋ค.
๊ทธ๋ฆฌ๊ณ MLM์ ๋ํด์๋ BERT์ ๋์ผํ๊ฒ ์ฌ์ฉ๋๋ค. DA๊ฐ ์ด๋ฏธ ๋ด์ฉ๊ณผ ์๋์ ์์น์ ๋ํ ๊ณ ๋ฏผ์ด ๋ค์ด๊ฐ๋ ์์ง๋ง, ์ค์ํ ๊ฒ์ absolute position์ ๋ํ ๊ณ ๋ฏผ์ ์๋ค. absolute position์ ์์ธก์์ ๊ฝค๋ ์ฃผ์ํ ์์์์ผ๋ก ์ด๋ฌํ ์ ์ ๋ณด์ํ๊ธฐ ์ํด์ DeBERTa์์๋ enhanced Mask Decoder๋ฅผ MLM์ ๋ณด์ํ๊ธฐ ์ํด์ ์ฌ์ฉํ๋ค. ์ด๋ MLM decoding layer์์ context word์ absolute position information์ด ์ถ๊ฐ๋ก ๋ค์ด๊ฐ๋ ๋ฐฉ์์ด๋ค.
3. ELECTRA#
2.3.1 Masked Language Model(MLM)#
Large-scale Transformer-based PLMs ๋ ๋ณดํต ๋ง์ ์์ ํ ์คํธ๋ก ์ฌ์ ํ์ต๋๋ฉด์ self-supervision objective(์๊ธฐ์ฃผ๋ํ์ต์ ๋ชฉ์ )์ธ MLM์ ์คํํ๊ณ ์ด ๋ง์ธ ์ฆ์จ ๋ฌธ๋งฅ์ ์ดํดํ๊ฒ ๋๋ค๋ ๊ฒ์ด๋ค.
\(X = \{x_i\}\)๋ ํ๋์ sequence์ด๊ณ \(\tilde{X}\)๋ 15%์ ํ ํฐ์ด ์ค์ผ(masking)๋ ๊ฒ์ด๋ค. ๋ชฉํ๋ ๋ค์ reconstruct \(X\)์ด๋ค. ๋ฐฉ๋ฒ์ language model์ predicting the masked tokens \(\tilde{x} \text{ conditioned on } \tilde{X}\) ํ๋ฉด์ ํ๋ จ์ํค๋ ๊ฒ์ด๊ณ parameterized by \(\theta\)์ด๋ค.
C : index set of the masked tokens in the sequence
BERT์์๋ 10%์ masked tokens๋ฅผ ๋ฐ๊พผ ์ํ๋ก ์ ์งํ๊ณ , ๋ค๋ฅธ 10%์ ๋ฌด์์ ์ ํ๋ ํ ํฐ์ ๋ฐ๊พธ์๊ณ , ๋๋จธ์ง 80%๋ ์์ \([MASK]\) token์ผ๋ก ์ ์งํ๋ค.
2.3.2 Replaced Token Detection(RTD)#
BERT๋ ํ๋์ transformer encoder๋ฅผ ์ฌ์ฉํ๊ณ , MLM์ผ๋ก ํ๋ จ๋์๋ค. ์ด์ ๋ค๋ฅด๊ฒ ELECTRA๋ ๋ ๊ฐ์ transformer encoders๋ฅผ ์ฌ์ฉํ๋ฉด์ GAN์ฒ๋ผ ํ๋ จํ๋ค. Generator encoder๋ MLM์ผ๋ก ํ๋ จ๋์๊ณ , discriminator encoder๋ token-level binary classifier
๋ก ํ๋ จ๋์๋ค. generator๋ input sequence์์ ๋ง์คํน๋ token์ ๋์ฒดํ ambiguousํ ํ ํฐ์ ์์ฑํ๋ค. ๊ทธ๋ฆฌ๊ณ ์ด๋ ๊ฒ ๋ง๋ค์ด์ง sequence๋ dicriminator๋ก ๋ค์ด๊ฐ์ ํด๋น ํ ํฐ์ด original ํ ํฐ์ด ๋ง๋์ง ์๋๋ฉด generator๊ฐ ๋ง๋ ํ ํฐ์ธ์ง๋ฅผ ์ด์ง ๋ถ๋ฅํ๋ค. ๊ทธ๋ฆฌ๊ณ ์ด ์ด์ง๋ถ๋ฅํ๋ ๊ฒ์ด RTD์ด๋ค. ์ฌ๊ธฐ์ parameterized ๋๋ ๋ถ๋ถ์ด \(\theta_{G}\)๋ generator์ ํ๋ผ๋ฏธํฐ์ด๊ณ , \(\theta_{D}\)๋ discriminator์ ํ๋ผ๋ฏธํฐ์ด๋ค. Loss function of the generator๋ ์๋์ ๊ฐ๋ค.
\(p_{\theta_G} \big(\tilde{x}_{i,G} = x_i|\tilde{X}_G \big) \Big)\) : ์ด ๋ถ๋ถ์ด G๊ฐ \(\tilde{X}_G\)๋ฅผ \(x_i\)๋ก ์ฌ๊ตฌ์ฑํ ํ๋ฅ ์ด๋ค.
ํนํ๋ masking๋ token(\(\tilde{X}_G\))์ ์๋ณธ์์ randomly maskingํ 15%์ด๋ค.
Log์ ์ฑ์ง
log์์ ๋ถ์ - ๋ ์ฌ๊ตฌ์ฑํ ํ๋ฅ ์ด ๋์์ง์๋ก generator๊ฐ ์ผ์ ์ํ๋ ๊ฒ์ด๊ธฐ ๋๋ฌธ์, ์์ ์ฌ๊ตฌ์ฑํ ํ๋ฅ ์ด ๋์์ง์๋ก loss๊ฐ์ ๋ฎ์์ง๋ ๊ตฌ์ฑ์ด ๋์ด์ผ ๋ชจ๋ธ์ด ๋ชฉํํจ์์ ์๊ฑฐํ ํ์ต์ ์งํํ๊ธฐ์ ๋ถ์ ๊ฒ์ด๋ค.
log๋ 0~1์ฌ์ด์ ํ๋ฅ ๊ฐ\(\mathbb{E}\)์ 0~\(\infty\) ๊ฐ์ผ๋ก ๋ณํํ๋ค. ์ด๋ฅผ ํตํด์ underflow(0 * 0 * โฆ -> 0์ ์ ์ ๋ ๊ฐ๊น์์ง๋)๋ฅผ ๋ฐฉ์งํ๋ฉฐ, ๊ณฑ์ ์ ๋ง์ ์ผ๋ก ๋ณํํ๋ ์ฑ์ง์ ์ด์ฉํ๋ค.
discriminator์ ๋ค์ด๊ฐ๋ input sequences๋ generator์ output probability์ ๋ฐ๋ผ์ new tokens์ด masked tokens ์๋ฆฌ๋ฅผ ์ฑ์ด์ฑ๋ก ๋ค์ด์จ๋ค. ๊ทธ๋์ i๊ฐ C์์ ์์ผ๋ฉด ๊ทธ๋๋ก x์ด๊ณ (replaced token์ด ์๋๋ผ ์์ ์๋ณธ์ด๊ธฐ ๋๋ฌธ), index๊ฐ C์ ์๋๊ฑฐ๋ง ๋น๊ต๋ฅผํ๋ค.
\(\sim\) : sim ๊ธฐํธ๋ โdistributed asโ ๋๋ โhas a distribution ofโ๋ก ํด์๋๋ฉฐ, โํน์ ํ๋ฅ ๋ถํฌ๋ฅผ ๋ฐ๋ฅธ๋ค๋ ์๋ฏธโ, โ๋ฐ๋ฅธ๋ค, ๋ถํฌํ๋คโ์ ๋๋ค. ์ฌ๊ธฐ์ sim ๊ธฐํธ๋ ์ฃผ์ด์ง ํ๋ฅ ๋ถํฌ \(p_{\theta_G}\)์ ๋ฐ๋ผ ๋ณ์ \(\tilde{x}_i\)๊ฐ ๋ถํฌํ๋ค๋ ๊ฒ์ ์๋ฏธ.
๋ฐ๋ผ์, \(\tilde{x}_i \sim p{\theta_G} (\tilde{x}_{i,G} = x_i|\tilde{X}_G)\)๋ ๋ณ์ \(\tilde{x}_i\)๊ฐ ์กฐ๊ฑด๋ถ ํ๋ฅ ๋ถํฌ \(p{\theta_G} (\tilde{x}_{i,G} = x_i|\tilde{X}_G)\)๋ฅผ ๋ฐ๋ฅธ๋ค๋ ๊ฒ์ ์๋ฏธํฉ๋๋ค. ์ด ๋ถํฌ๋ \(p_{\theta_G}\) ํ๋ผ๋ฏธํฐ๋ฅผ ๊ฐ์ง๋ฉฐ, \(\tilde{X}_G\)๊ฐ ์ฃผ์ด์ก์ ๋ \(\tilde{x}_{i,G} = x_i\)์ธ ์กฐ๊ฑด๋ถ ๋ถํฌ๋ฅผ ๋ํ๋ ๋๋ค. ์ ์์์ ์ ์ฒด์ ์ธ ์๋ฏธ๋, ์ธ๋ฑ์ค i๊ฐ ์งํฉ C์ ํฌํจ๋์ด ์์ผ๋ฉด, \(\tilde{x}_i\)๋ ์กฐ๊ฑด๋ถ ๋ถํฌ \(p_{\theta_G} (\tilde{x}_{i,G} = x_i|\tilde{X}_G)\)๋ฅผ ๋ฐ๋ฅด๊ณ , ๊ทธ๋ ์ง ์์ผ๋ฉด \(\tilde{x}_i\)๋ \(x_i\)์ ๋์ผํ๋ค๋ ๊ฒ์ ์๋ฏธํฉ๋๋ค.
\(\mathbb{I}\) : indicator function์ ๋งํ๋ค. \((\tilde{x_{i,D}} = x_i)\)๋ฅผ ์ถฉ์กฑํ๋ฉด 1์ ๋ฐํํ๊ณ , ์๋๋ฉด(\(\tilde{X}_D,i\)) 0์ ๋ฐํํ๋ ํจ์๋ฅผ ๋งํ๋ค. ๋งค์ฐ ์๊ฒฉํ ํจ์๋ก ๋ณผ ์ ์์ผ๋ฉฐ ํ๋ณด๋ โsigmoidโ, โtahnโ, โReLUโ, โLeaky ReLUโ๋ฑ์ด ์๋ค.
์ discriminator์ loss function์ input์ \(\tilde{X_D}\) ์ด๋ฉฐ ์ด๊ฒ์ ์์ 3๋ฒ equation์ result์ด๋ค. ์ฆ discriminator์์ ํ๋ฒ ๊ฑธ๋ฌ์ ธ ๋์จ ๊ฒ์ด loss function์ ๋ค์ด๊ฐ๋ ๊ฒ์ด๋ค.
์ด์ฒด์ ์ธ ELECTRA์ loss function์ ์๋์ ๊ฐ์ด ์ ๋ฆฌํ ์ ์๋ค. $\( L = L_{MLM} + \lambda L_{RTD} \)$
\(\lambda\) : discriminator loss function์ ๋ํ weight๋ฅผ ๋ํ๋. ๋ชจ๋ธ ํ์ต์์ ํด๋น ์์ค์ ์ค์์ฑ์ ์กฐ์ ํ๋๋ฐ ์ฌ์ฉ๋จ. ์ฌ๊ธฐ์๋ 50์์ผ๋ก MLM์ ๋นํด์ 50๋ฐฐ์ ๊ฐ์ค์น๋ฅผ ๋ ๊ณฑํด์ค๋ค๋ ์๋ฏธ์์ผ๋ก RTD์ ์์ฒญ ์ค์์ฑ์ ๋๊ฒ ์ณ์ฃผ๋ ๊ฒ์ด๋ผ๊ณ ๋ณผ ์ ์๋ค.
DeBERTaV3#
DeBERTa + RTD training loss + new weight-sharing method
3.1 DeBERTa + RTD#
ELECTRA์์ ๊ฐ์ ธ์จ RTD, ๊ทธ๋ฆฌ๊ณ DeBERTa disentangled attention mechanism์ ํฉ์ ํ๋ฆฌํธ๋ ์ด๋ ๊ณผ์ ์์ ๊ฐ๋จํ๊ณ ํจ๊ณผ์ ์ธ ๊ฒ์ผ๋ก ํ๋ณ๋์๋ค. ์ด์ DeBERTa์์ ์ฌ์ฉ๋์๋ MLM objective๋ฅผ RTD objective๋ก ๋ฐ๊ฟ์ผ๋ก์จ ๋์ฑ disentangled attention mechainsm์ ๋์ฑ ๊ฐํํ๋ ๊ฒ์ด๋ค.
training ๋ฐ์ดํฐ๋ก๋ Wikipedia, bookcorpus์ ๋ฐ์ดํฐ๊ฐ ์ฌ์ฉ๋์๋ค. generator๋ discriminator์ ๊ฐ์ width๋ฅผ ๊ฐ์ง๋ depth๋ ์ ๋ฐ๋ง ๊ฐ์ ธ๊ฐ๋ค. batch size๋ 2048์ด๋ฉฐ 125,000 step์ด ํ๋ จ๋์๋ค. learning_rate = 5e-4, warmup_steps = 10,000, ๊ทธ๋ฆฌ๊ณ ์์์ ๋งํ๋ฏ์ด RTD loss function์ ๊ฐ์ค์น๋ฅผ 50์ ์ค์ผ๋ก์ optimization hyperparameter๋ฅผ ์ฌ์ฉํ์๋ค.
๊ฒ์ฆ ๋ฐ์ดํฐ๋ก๋ MNLI, SQuAD v2.0์ ์ฌ์ฉํ์๊ณ , ์ด ๋ฐ์ดํฐ๋ค์ ๋ํ ์ ๋ฆฌ๋ ํ์ํ ๊ฒ์ด๋ค. ๊ฒฐ๊ณผ์ ์ผ๋ก DeBERTa๋ฅผ ์๋ํ์ง๋ง ๋์ฑ๋ improved๋ ์๋ ํฌ์ธํธ๋ฅผ ๋งํ๋ ์ง์ ์ด ์๋ค. token Embedding Sharing(ES) used for RTD(๊ธฐ์กด์ ์ฌ์ฉ๋์๋)๋ฅผ new Gradient-Disentangled Embedding Sharing(GDES) method๋ก ๋ฐ๊ฟ์ผ๋ก์จ ๋์ฑ ๋ฐ์ ๋ ๊ฐ๋ฅ์ฑ์ด ์๋ค๊ณ ๋งํ๋ค.
3.2 Token Embedding Sharing (in ELECTRA)#
ELECTRA์์๋ generator์ discriminator๊ฐ token embedding์ ๊ณต์ ํ๋ค. ์ด๊ฒ์ด Embedding Sharing(ES)
์ด๋ค. ์ด ๋ฐฉ๋ฒ์ generator๊ฐ discriminator์ input์ผ๋ก ๋ค์ด๊ฐ ์ ๋ณด๋ฅผ ์ ๊ณตํจ์ผ๋ก์จ ํ์ต์ ํ์ํ parameter๋ฅผ ์ค์ฌ์ฃผ๋ ์ญํ ์ ํ๊ณ ํ์ต์ ์ฉ์ดํ๊ฒ ํด์ค๋ค. ๊ทธ๋ฌ๋ ์์์ ๋งํ๋ฏ์ด ๋ ๊ธฐ์ ์ ๋ชฉ์ ๋ฐฉํฅ์ฑ์ด ๋ฐ๋์ด๊ธฐ ๋๋ฌธ์ ์๋ก๋ฅผ ๋ฐฉํดํ๊ณ , ํ์ต ์๋ ด์ ์ ํดํ ๊ฐ๋ฅ์ฑ์ด ํฌ๋ค.
\(E\) : token embeddings
\(g_E\) : gradients = \(\frac{\delta L_{MLM}}{\delta E} + \lambda\frac{\delta L_{RTD}}{\delta E}\)
์์ equation์ token embeddings(E)๊ฐ ๋ ๊ฐ์ ์ผ์์์ gradient๋ฅผ ํ ๋ฒ์ ์กฐ์ ํ๋ฉด์ update๋๋ค๋ ๊ฒ์ ์๋ฏธํ๋ค. ์์์ ๋งํ๋ฏ์ด ์ด๊ฒ์ ์ค๋ค๋ฆฌ๊ธฐ ์ด๋ค. ์์ฃผ ์กฐ์ฌ์ค๋ฝ๊ฒ update ์๋๋ฅผ ์กฐ์ ํ๋ฉด์(small learning_rate, gradient clipping) ํ์ต์ ์งํํ๋ฉด ๊ฒฐ๋ก ์ ์ผ๋ก๋ ์๋ ด์ ํ๊ธฐ๋ ํ๋ค๊ณ ๋งํ๋ค. ํ์ง๋ง ๋ ๊ฐ์ task๊ฐ ์ ๋ฐ๋์ ๋ชฉ์ ์ ๊ฐ์ง๋ค๋ฉด ์ด๊ฒ์ ๊ต์ฅํ ๋นํจ์จ์ ์ด๋ฉฐ, ํด๋น ์ํฉ(MLM,RTD)์ ์ ํํ ๊ทธ๋ฐ ์ํฉ์ด๋ผ๊ณ ๋ณผ ์ ์๋ค. ๋ ๊ฐ์ task๊ฐ token embedding์ ์ ๋ฐ์ดํธ ํ๋ฉด์ ๋ฐ๋ผ๋ ๊ฒ์ด ํ๋๋ ์ ์ฌ์ฑ์ ๋ฐ๋ผ์ ๊ฐ๊น๊ฒ ํ๊ณ , ๋ค๋ฅธ ํ๋๋ ์ ์ฌ์ฑ์ ๋ฐ๋ผ์ ๋ฉ๊ฒํ์ฌ ๋ถ๋ฅํ๋ ๊ฒ์ด๊ธฐ ๋๋ฌธ์ด๋ค.
์ด๊ฒ์ ์ค์ ๋ก ํ์ธํ๊ธฐ ์ํด์ ์ฌ๋ฌ ๋ค์ํ ELECTRA๋ฅผ ๊ตฌํํ๋, ํด๋น ELECTRA๋ค์ token embedding์ ๊ณต์ ํ์ง ์๋๋ก ๊ตฌํํ๋ค๊ณ ํ๋ค. ๊ทธ๋ ๊ฒ ๊ตฌํ์ ํ๋ฉด No Embedding Sharing(NES)๊ฐ ๋๋ ๊ฒ์ด๋ค. ์๋ค๋ gradient update๊ฐ ๊ฐ๊ฐ ๋๋ค. ์ฐ์ ์ (1) generator์ parameter(token embedding with \(E_G\))๊ฐ MLM loss๋ฅผ back-propํ๋ฉด์ ์ ๋ฐ์ดํธ๋๊ณ , (2) ์ดํ์ discriminator๊ฐ generator output์ input์ผ๋ก ๋ฐ๋๋ค (3) ๋ง์ง๋ง์ผ๋ก discriminator parameter(token embeddings with \(E_D\))๋ฅผ RTD loss๋ฅผ back-propํ๋ฉด์ updateํ๋ค.
์ด๋ค์ 3๊ฐ์ง๋ก ES vs NES๋ฅผ ๋น๊ตํ๋ค๊ณ ํ๋ค.
convergence speed : NES๊ฐ ๋น์ฐํ gradient conflict๋ฅผ ๋ฐฉ์งํจ์ผ๋ก ์น๋ฆฌ
quality of token embeddings : average cosine similiarity scores๋ฅผ ๋น๊ตํ๋ค. \(E_G\)์์๋ ๊ต์ฅํ ํจ๊ณผ๊ฐ ์ข์์ง๋ง \(E_D\)๋ ํ์ต์ ๊ฑฐ์ ๋ชปํ ๊ฒ์ผ๋ก ๋ณด์๋ค. ํจ๊ณผ๊ฐ ์ข๋ค๋ ๊ฒ์ ์๋ฏธ์ ์ผ๋ก coherent ์ผ๊ด์ฑ ์๊ฒ \(E_G\)๊ฐ ์ฝ์ฌ์ธ ์ ์ฌ๋๊ฐ ๋งค์ฐ ๋์์ง๋ ๊ฒ์ ๋งํ๋ค.
performance on downstream NLP tasks : ๋ํ NES๊ฐ ๋ค์ด์คํธ๋ฆผ test์์๋ ์ข์ ๋ชจ์ต์ ๋ณด์ด์ง ๋ชปํ๋ค
ES๊ฐ generator embedding์ผ๋ก๋ถํฐ discriminator๊ฐ ํ์ต์ ํ ๋ ๋์์ ๋ฐ๋๊ฒ ์ฅ์ ์ด ์๋ค๋ ๊ฒ์ ์ ์ ์๋ค.
??? average cosine similiarity of word embeddings of the G vs D#
average cosine similiarity๊ฐ ๋์์๋ก ์ข์ ๊ฒ์ธ๊ฐ? ์ด๋ค ์๋ฏธ์ธ์ง ์ ๋๋ก ์ดํด๋ฅผ ๋ชปํ ๊ฑฐ ๊ฐ๋ค.
3.3 Gradient-Disentangled Embedding Sharing(GDES)#
ES, NES์ ๋จ์ ์ ๊ฝค ๋ซ๊ธฐ ์ํด ํด๋น ๋ ผ๋ฌธ์์ ์ค์ํ๊ฒ ๋งํ๋ ์ง์ ์ด๋ค. ๋ ๊ฐ์ ์ฅ ๋จ์ ์ด ๋ถ๋ช ํ๊ฒ ์กด์ฌํ๋ฉด์ ๋ ๊ฐ๋ฅผ ๋ชจ๋ ์ฑ๊ธธ ๋ฐฉ๋ฒ์ผ๋ก ๋์จ ๊ฒ์ด๋ค. ํ ๋ฒ ์ ๋ฆฌ๋ฅผ ํ์๋ฉด ES๋ ํ์ต์ ๋๋ฆฌ์ง๋ง generator output : token embedding๋ฅผ discriminator๊ฐ ์ฐธ์กฐํ๋ฉด์ ํ์ต parameter reducing์ ๋์์ ๋ฐ๋ ๋ค๋ ๊ฒ์ด๋ค. ํ์ง๋ง ๋จ์ ์ generator discriminator token embeddings๊ฐ ๋ ๋ค ์ผ๊ด์ฑ์ด ์์ด์ง๋ค๋ ๊ฒ์ด๋ค.
๋ฐ๋ฉด์ NES๋ ํ์ต์ด ๊ต์ฅํ ๋นจ๋ผ์ง๋ค. G,D์ ๋ฐฉํฅ์ฑ์ ์ ๋ฐ๋์ ์ฑ์ง์ ํด๊ฒฐํด ์ค์ผ๋ก์จ ํ์ต์ด ์ฉ์ดํ๊ฒ ๋๋ค๊ณ ๋ณผ ์ ์๋ค. ํ์ง๋ง ๊ฒฐ๋ก ์ ์ผ๋ก๋ ํ์ต์ ์คํ๋ ค ES๋ณด๋ค ๋ชปํ ๊ผด์ด ๋๋ค. ๊ทธ๋๋ ์ฅ์ ์ G์ token embedding์ด ์ฝ์ฌ์ธ ์ ์ฌ๋๊ฐ ๋์ ์ผ๊ด์ฑ ์๋ embedding์ ๋ง๋ค์ด์ฃผ๋ ๊ฒฝํฅ์ฑ์ ๋ง๋๋ ๊ฒ์ด ๊ฐ๋ฅํ๋ค๊ณ ๋ณผ ์ ์๋ค.
์ด ๋ชจ๋ ๋จ์ ์ ์ปค๋ฒํ๊ณ ๋๋์ฒด ์ด๋ป๊ฒ ์ฅ์ ๋ง ๋จ๊ธด๋ค๋ ๊ฒ์ธ๊ฐ? ์ฅ์ ๋ง ๋จ๊ธด๋ค๋ฉด ํ์ต์ ์๋๋ ๋นจ๋ผ์ง๋ฉด์ ๋์์ G,D์ token embedding์ด ์ ์ฌ์ฑ์ ์ ์งํ๋ ๊ฒ์ด ๋ ๊ฒ์ด๋ค. ํ์๋ฅผ ๋ ผ๋ฌธ์์๋ โlearn from same vocabulary and leverage the rich semantic information encoded in the embeddingsโ๋ผ๊ณ ๋งํ๋ค.
์ด๊ฒ์ GDES๋ ์ค์ง generator embeddings๋ฅผ MLL loss๋ง ๊ฐ์ง๊ณ ์ ๋ฐ์ดํธํจ์ผ๋ก์จ, output์ ์ผ๊ด์ฑ๊ณผ ํต์ผ์ฑ์ ์ ์งํ๋ค. ๊ทธ๊ฒ์ ์์์ ์ผ๋ก ๋ณด๋ฉด
์๋๋ NES์์๋ \(E_D\)๋ก ๋ฐ๋ก backpropํ๋ ๊ฒ์ re-parameterizeํ์ฌ discriminator embedding๋ฅผ ์๋ก ์ ์ํ๋ค. \(sg\)d์ ์ญํ ์ \(E_G\)์์ ๋์จ gradient๊ฐ ๊ณ์ ํ๋ฌ๋ค์ด๊ฐ๋ ๊ฒ์ ๋ง๊ณ , residual embeddings \(E_{\Delta}\) ๋ง์ ์ ๋ฐ์ดํธ ํ๊ฒ ํ๋ ๊ฒ์ด๋ค. ์ด๊ฒ์ residual learning์์์ ์์ด๋์ด์ ๊ต์ฅํ ์ ์ฌํ ๊ฒ ๊ฐ์๋ฐ!!!.
(1) G output -> input for discriminator = \(E_G\)
(2) update \(E_G\) \(E_D\) with MLM loss
(3) D run on G output
(4) update \(E_D\) with RTD loss with only \(E_{\Delta}\)
(5) after training, \(E_{\Delta}\) + \(E_G\) = \(E_D\)
์ ์ embedding sharing์ ์ฐจ์ด๋ง ์๊ณ , computation cost์ ์ฐจ์ด๋ ์์ผ๋ค. computation cost์ ์ฐจ์ด๊ฐ ์์ด ์์ด๋์ด๋ง์ผ๋ก ์ฑ๋ฅ์ up์ ํ ๊ฒ๋ resnet์ด๋ ๋น์ทํ๋ค.
์ฝ์ฌ์ธ ํ๊ท ์ ์ฌ๋์์๋ ์ฐจ์ด๊ฐ NES๋ณด๋ค๋ ๋ํ๋ฐ ์ด๋ โpreserves more semantic information in the discriminator embeddings through the partial weight shargingโ์ด๋ผ๊ณ ๋งํ๋ค. ์ฌ๊ธฐ์ ๋ณด์ด๋ partial weight sharing์ด \(E_{\Delta}\)์ด๋ฉฐ ์ด๊ฒ์ด embedding์ ์์ฐจ๋ฅผ ํ์ตํ๋ ๋ฐฉ์์ผ๋ก ์งํ๋จ์ผ๋ก์จ ํ์ต์ ์ฉ์ดํ๊ฒ ๊ฐ์ ธ๊ฐ๋คโฆ ์ ๋๋ก ๋ณด์ธ๋ค.
Conclusion#
pre-training paradigm for language models based on the combination of DeBERTa and ELECTRA, two state-of-the-art models that use relative position encoding and replaced token detection (RTD) respectively
interference issue between the generator and the discriminator in the RTD framework which is well known as the โtug-of-warโ dynamics.
GDES : the discriminator to leverage the semantic information encoded in the generatorโs embedding layer without interfering with the generatorโs gradients and thus improves the pre-training ef๏ฌciency
a new way of sharing information between the generator and the discriminator in the RTD framework, which can be easily applied to other RTD-based language models
debertav3-large : 1.37% on the GLUE average score
๋ชฉ์ :
parameter-ef๏ฌcient pre-trained language models