본문 바로가기
NLP/논문리뷰

[논문 Review] 06. REALM : Retrieval-Augmented Language Model Pre-Training

by ㅣlㅣl 2023. 12. 13.
retriever - encoder 모델 구조를 통해 QA task에서의 성능을 높이자!


https://arxiv.org/abs/2002.08909

 

REALM: Retrieval-Augmented Language Model Pre-Training

Language model pre-training has been shown to capture a surprising amount of world knowledge, crucial for NLP tasks such as question answering. However, this knowledge is stored implicitly in the parameters of a neural network, requiring ever-larger networ

arxiv.org

 

Abstract

언어 모델 사전학습은 많은 양의 world knowledge를 포착하는 것으로 나타났고 이는 QA와 같은 NLP 작업에 매우 중요하다. 하지만 이러한 지식은 네트워크 파라미터에 함축적으로 저장되므로, 더 많은 지식을 위해서는 점점 더 큰 규모의 네트워크가 필요해진다.

좀 더 모듈화되고 해석 가능한 방식으로 지식을 저장하기 위해서 언어 모델 사전 학습을 latent knowledge retriever 로 보강하여 Wikipedia 같은 웹 상의 대규모 말뭉치에서 수집된 문서를 검색해 사용할 수 있게 한다. 

본 논문에서는 수백만 개의 문서들을 고려하는 retrieval step을 통해 backpropagation을 진행하며 unsupervised 방식으로 knowledge retriever를 pre-train하는 방식을 보인다.

Open-QA라는 어려운 태스크에 대해  REALM 모델의 Fine-tuning을 진행했는데, 그 결과 3가지 Open-QA 벤치마크에 대해 이전 방법론들을 효과적으로 넘어서는 성능을 보였다.


1. Introduction

최근 언어 모델 사전 학습의 발전으로 BERT, RoBERTa, T5와 같은 모델은 방대한 지식을 저장하고 있다. 이는 신경망의 파라미터에 암시적으로 저장되고, 따라서 네트워크에 어떤 지식이 어디에 저장되어 있는지는 파악하기가 어렵다. 게다가 저장 공간은 네트워크 크기에 따라 제한되기 때문에 더 많은 지식을 학습하기 위해서는 더 큰 규모의 네트워크를 훈련해야 하는데, 이는 큰 훈련 비용을 수반한다.

따라서 보다 해석 가능하고 모듈화된 방식으로 지식을 포착하기 위해 learned textual-knowledge retriever을 통해 LM의 pre-training을 강화하는 새로운 프레임 워크, REALM을 제안한다.

이전 모델들이 지식을 파라미터에 암시적으로 저장한 것과는 달리 모델이 어떤 지식을 retrieve하고 추론 과정에서 사용할지를 결정하도록 만들어 사전 학습 때 습득한 지식의 역할을 명시적으로 제시한다.

각 예측 이전 LM은 retriever를 사용해 대량의 말뭉치로부터 문서를 검색한 다음, 해당 문서를 통해 예측에 도움이 되는 정보를 얻는다. 이 모델을 end-to-end로 학습하기 위해서는 전체 말뭉치를 고려하는 retrieval step을 통해 backpropagation을 진행해야 한다.

모든 말뭉치에 대해서 backpropagation을 진행하려면 cost가 상당히 많이 들어가는데, 이를 개선하기 위한 방법은 뒷 챕터에서 보다 자세히 설명하도록 하겠다.

[그림 1] REALM backpropagation

\(\theta\) : knowledge retriever와 관련된 모든 파라미터
\(\phi\) : knowledge augmented encoder와 관련된 모든 파라미터

그림 1에 REALM의 backpropagation 과정이 나와있다. textual knowledge corpus Z에서 지식을 반환하는 neural knowledge retriever를 통해서 LM pretraining을 강화한다. 

 

REALM의 중요한 포인트는 retriever를 학습하는 것인데, LM의 perplexity를 개선하는(=예측에 도움이 되는) retrieval은 보상을 받고 그렇지 않은 retrieval은 패널티를 받는 방식으로 학습이 진행된다.

그림 1의 문장을 예시로 들면, 주어진 입력 "The [MASK] at the top of the pyramid" 에서의 예측을 위해 Retrieval은 "top of the pyramid" 에 대한 정보를 포함하고 있는 문서 "The pyramidion on top allows for less material higher up the pyramid" 를 선택했을 때 보상을 받아야 한다.

이를 위해서는 retrieve-then-predict 접근 방식을 잠재 변수 언어 모델 (latent variable language model) 로 모델링하여 marginal likelihood를 최적화한다.

  • latent variable model : 관찰 가능한 데이터에서 직접 측정할 수 없는 '잠재 변수'를 이용해 모델을 구성하는 방식이다. 잠재 변수는 직접 관찰되지 않지만 관찰 가능한 데이터의 분포를 설명하는 데 중요한 역할을 하는 변수이다. 예를 들어, 사용자의 구매 패턴에서 각각의 구매 이벤트는 관찰 가능하지만, 사용자의 숨겨진 구매 선호도는 직접적으로 관찰할 수 없는 잠재 변수가 될 수 있다. [각주:1]
  • marginal likelihood : knowledge augmented encoder 에서 marginal log-likelihood loss를 사용해 학습을 진행

Marginal likelihood에 대한 정의 & 간단한 예시

더보기

결합분포하는 두 확률변수에서 한 변수만의 확률분포를 고려하는 경우 확률변수의 marginal distribution이라 한다.

= joint distribution이 주어졌을 때,  variable 하나에 대해서 관심이 없는 상태

예를 들어서 동전을 두 개를 던집니다. 그런데, 이 동전 두 개가 독립적이지 않다고 합시다.

첫번째 동전이 어떻게 나오느냐에 따라서, 두번째 동전이 head이냐, tail이냐가 바뀐다고 가정하죠.

그런데 어느 순간 보니까 첫번째 동전은 아무 의미가 없는 것 같습니다.
만약 우리가 두번째 동전이 head가 나온 것만이 중요하다고 생각한다면, 두번째 동전이 head가 나온 case에 대해서 첫번째 동전의 case를 모두 더합니다. 그럼 그것이 marginalize라는 것입니다.

여러 개의 확률 변수로 구성된 조합 확률분포(joint distribution)에서 한 가지 변수에 대한 확률값을 추정하기 위해 나머지 변수를 모두 적분하여 제거하는 과정을 말합니다. [각주:2]

 

여기서는 document z가 명시적으로 드러나지 않는 latent variable이 되어 marginalize의 대상이 된다.

 

Pre-training 중에 규모가 큰 retrieval 모듈을 통합하는 것은 상당한 계산적 어려움을 초래하는데, 각 pre-training step에서 수백만 개의 후보 문서 Z를 고려해야 하고 이를 통해 backpropagation을 계산해야 하기 때문이다.

이 문제를 해결하기 위해 각 문서에 대해 수행되는 계산이 캐시되고 비동기적으로 업데이트될 수 있도록 Retriever를 구성하고, 최상의 문서를 선택하는 것이 Maximum Inner Product Search (MIPS) 로 공식화될 수 있도록 한다.

 

  • MIPS

쿼리 벡터를 q, 답변 후보군을 V라고 하겠습니다. 우리는 결국 V에 속해 있는 각 답변 벡터들 v_i 중에서 q와의 ‘상성이 가장 좋은’ 것을 찾아야합니다. 그 좋은 상성의 기준이 q와 v_i와의 내적이 큰 것이면 MIPS, q와 v_i의 유클리디안 거리(Euclidean Distance)가 작은 것이면 NNS, q와 v_i의 코사인유사도(Cosine Similarity)가 큰 것이면 MCSS가 됩니다. [각주:3]

MIPS 수식

 

선행 연구[각주:4][각주:5]에서 네트워크에 discrete retrieval step을 추가하는 것이 도움이 된다는 것을 증명했지만, LM pre-training에 프레임워크를 적용하지 않았고 큰 문서를 다루기 위해 non-learned retriever를 사용했다.

다른 연구인 knn-LM[각주:6]은 retrieval 매커니즘의 적용 방식이 불명확해서 downstream task로 fine-tuning 되지 않았고, 목표 작업에 대해 레이블링 된 예제만 사용할 수 있는 반면 REALM retriever다른 태스크에 transfer 될 수 있도록 설계되었고 레이블링 되지 않은 텍스트 데이터를 사용한다.

 

3가지 openQA benchmark에 대해서 REALM 방식으로 pre-training된 모델을 테스트해본 결과 시도해본 벤치마크 모두에서 SOTA를 달성했고, 이전 SOTA보다 4~16%의 큰 성능 향상을 보였다.

  • Natural Questions-Open
  • WebQuestions
  • CuratedTrec

 

 

2. Background

Language model pre-training

LM pre-training의 목표는 레이블링되지 않은 텍스트 데이터로부터 언어의 유용한 표현 방법을 학습하는 것이다.

Masked LM

레이블링되지 않은 pre-training corpus X가 주어졌을 때 텍스트의 토큰을 확률적으로 마스킹샘플 (x, y)를 생성하고, 모델은 입력에서 누락된 토큰을 예측하도록 훈련되는데 이 과정에서 모델은 구문 및 의미 정보, 일부 world knowledge를 인코딩하는 방법을 학습하게 된다.

e.g., x = “The [MASK] is the currency [MASK] the UK”; y = (“pound”, “of”)).

 

Open-domain question answering (Open-QA)

모델이 world knowledge를 포함하는 능력을 평가하기 위해서 world knowledge가 중요한 downstream task가 필요하다.

이 때 자연어 처리 태스크 중 가장 knowledge-intensive한 태스크가 바로 open-QA이다.

e.g.
input(question) x : “What is the currency of the UK?”
output(answer) y : “pound”

 

Open-QA vs Reading Comprehension

  • Open-QA는 SQuAD와 같은 전통적인 읽기 이해(RC) 작업과 달리 답변이 포함된 문서를 미리 식별하지 않는다.
  • RC 모델은 1개의 문서에 대한 이해가 필요하지만, Open-QA는 수백만 개의 문서에 대한 질문이 있을 수 있으므로 그에 대한 지식을 보유해야 한다.

 

Open-QA에 대해 retrieval-approach를 시행한 선행연구[각주:7][각주:8] [footnote]Chen, D., Fisch, A., Weston, J., and Bordes, A. Reading wikipedia to answer open-domain questions. In Pro-ceedings of the 55th Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers),volume 1, pp. 1870–1879, 2017.[/footnote]에서는 질문 x가 주어지면 textual knowledge corpus Z를 지식 소스로 활용하여 x와 잠재적 관련성이 존재하는 문서 z를 검색한 다음, 문서에서 답변 y를 추출한다. 본 연구도 이에 영감을 받아 LM pre-training 학습 방식에 해당 방법을 적용하였다.

최근에는 x에 seq2seq 모델을 적용하여 y를 토큰 단위로 직접 생성하는 generation-based approach가 제안되었는데, 뒷 부분에서 두 방식의 선행 연구와 성능을 비교한다.

 

3. Approach

 

[그림 2] REALM의 전체 figure

왼쪽 그림은 Pre-training 과정, 오른쪽 그림은 Fine-tuning 과정을 나타낸다.

이 때 \(\theta\)는 retriever와 관련된 모든 파라미터, \(\phi\)는 encoder와 관련된 모든 파라미터이다.

 

 

REALM's generative process

pre-training, fine-tuning 모두 입력 x를 받아 가능한 출력 y에 대한 분포 p(y|x)를 학습한다.

  • pre-training : Masked Language Modeling task
    • Pre-training corpus X에서 일부 토큰을 마스킹한 문장을 통해 누락된 토큰 값인 answer y를 예측
  • Fine-tuning : Open-QA task
    • x : question, y : answer

 

REALM은 p(y|x)를 retrive - predict의 2단계로 분해한다.

  • Retrieve : 입력 x가 주어졌을 때 knowledge-corpus Z에서 유용한 문서 z를 검색하고, 이를 p(z|x) 분포의 샘플로 모델링
  • Predict : 검색된 문서 z, 원래 입력 x를 모두 조건으로 한 p(y|z,x) 분포로 모델링해 출력 y를 생성한다. 이후 y를 생성할 전체 likelihood를 구하기 위해 z를 latent variable로 취급하며, 가능한 모든 문서 z에 대해 marginalize를 진행하여 다음과 같은 최종 수식을 얻을 수 있다.

 

Model Architecture

모델은 neural knowledge retriever, knowledge-augmented encoder 2개의 주요 구성요소로 이루어져있다.

 

Neural knowledge retriever

\(p(z|x)\)를 모델링한다.

\(Embed_{input}\) : x를 d차원 벡터로 매핑하는 임베딩 함수
\(Embed_{doc}\) : z를 d차원 벡터로 매핑하는 임베딩 함수
f(x, z) : x와 z 벡터 임베딩을 내적한 relevance score

retrieval distribution은 모든 relevance scores에 대한 softmax이다.

BERT-style transformer를 사용하여 임베딩 함수를 구현하는데, text를 wordpiece tokenizer를 통해 결합하고 [SEP] 토큰으로 분리하며, [CLS] 토큰을 접두어로 사용하고 마지막에 [SEP] 토큰을 추가한다.

BERT에서와 마찬가지로 이를 Transformer에 전달해 각 토큰마다 하나의 벡터(=[CLS]의 hidden state vector) 를 생성한다. 

 

마지막으로 벡터의 차원 축소를 위해 선형 투영을 진행하고 수식은 아래와 같다.

\(z_{title}\) : document의 title
\(z_{body}\) : document의 body

 

 

knowledge-augmented encoder

\(p(y|z, x)\)를 모델링한다.

주어진 입력 x검색된 문서 z단일 시퀀스로 결합하고, Retriever module에서 사용된 것과는 다른 Transformer에 넣는다. 

여기서는 pre-training과 fine-tuning에서의 architecture에 차이가 있다.

  • pre-training : 각 [MASK] 토큰의 원래 값을 x로 예측해야 하는데, 이를 위해서 BERT에서 쓰인 것과 동일한 MLM loss를 사용한다.

\(BERT_{MASK_{(j)}}\) : j번째 마스킹 토큰에 상응하는 transformer 출력 벡터
\(J_x\) : x에 존재하는 [MASK] 토큰의 총 개수
\(w_j\) : token \(y_j\) 에 대해 학습된 word embedding

 

  • fine-tuning : 답변 y가 어떤 문서 z에서 연속적인 토큰 시퀀스로 발견될 수 있다고 가정하고, S(z, y)y와 일치하는 span 집합이다.

\(BERT_{START_{(s)}}\), \(BERT_{END_{(s)}}\) : span S의 시작 토큰, 끝 토큰에 해당하는 transformer 출력 벡터
MLP : Feed Forward Neural Network
\(\phi\) : knowledge - augmented encoder와 관련된 모든 파라미터

 

 

Training

pre-training, fine-tuning 두 과정에서 모두 log likelihood인 \(log p(y|x)\)를 최대화하는 방식으로 학습이 진행된다.

retriever, encoder 모두 미분 가능한 네트워크이므로 모델 파라미터 \(\theta, \phi\) 에 대한  \(log p(y|x)\) 의 gradient를 계산하고 SGD 방식을 통해서 이를 최적화할 수 있다.

 

SGD 방법론

https://ll2ll.tistory.com/29

 

[논문 Review] 04. Overview of Gradient Descent algorithms

Gradient Descent를 활용한 Optimizer가 어떻게 발전했는지 알아보자 Abstract 본 논문에서는 각 Gradient Descent Algorithms의 장단점을 알아보고, 다양한 알고리즘의 동작에 대한 직관을 가질 수 있게 한다. 수

ll2ll.tistory.com

 

 

그러나 p(y|x) 를 계산하기 위해서는 다음과 같이 knowledge corpus Z에 속해있는 모든 documents z 에 대해 합연산을 수행해야 하므로 연산 비용이 크게 든다.

따라서 이에 대한 근사치를 구하는 방법으로 p(z|x) 에서 가장 높은 확률을 가진 상위 k개의 문서만 합산하는 방식을 사용한다. 

 

그렇다면 상위 k개의 문서를 뽑을 때 어떻게 해야 효율적으로 찾을 수 있을까?

p(z|x) 하에서 문서의 ordering은 neural knowledge retriever에서 구한 relevance score f(x,z) 에 따른다.

따라서, MIPS 알고리즘을 적용해 top k documents를 뽑을 수 있다.

 

MIPS 적용을 위해 먼저 모든 \(z\in Z\) 에 대해 \(Embed_{doc}(z)\) 를 계산해야 하고 효율적인 search index를 구축해야 한다. 그러나 이 데이터 구조는 \(Embed_{doc}\)의 파라미터 \(\theta\) 가 업데이트되면 더 이상 유지되지 않으므로, gradient update 이후에는 search index가 "무효" 상태가 된다.

이를 해결하기 위해 수백 번의 training step마다 모든 문서를 비동기적으로 다시 embedding & indexing하여 index를 새롭게 갱신한다. (MIPS index의 유효성은 갱신할 때마다 살짝 떨어지지만, 이는 오직 상위 k개의 문서를 선별하는 것에만 사용되므로 그렇게 중요하지 않다)

상위 k개의 문서를 검색해서 받아온 후 새롭게 갱신된 \(\theta\)를 사용해 p(z|x)와 gradient를 다시 계산한다. 이후 섹션 4.5에서 이 절차가 안정적인 최적화를 가능하게 한다는 것을 실험적으로 입증했다.

 

Implementing asynchronous MIPS refreshes

아래 2가지 작업을 병렬적으로 수행하며 비동기적으로 MIPS index를 갱신한다.

  • trainer : 파라미터에 대한 gradient update를 수행
  • index builder : 문서를 embedding & indexing

trainer는 index builder에게 파라미터의 스냅샷인 \(\theta '\) 를 전송하고, trainer가 학습을 지속하는 동안 index builder전달받은 \(\theta '\) 를 이용해 새로운 인덱스를 빌드한다. 

index builder는 인덱스 구축이 완료될 때 다시 MIPS index를 trainer로 보내고 이 과정을 반복하며 MIPS index를 갱신한다.

asynchronous refresh는 pre-training, fine-tuning 두 과정에서 모두 사용 가능하지만 이번 REALM 모델에서는 pre-training 에서만 사용하였고, fine-tuning에서는 보다 간략한 과정을 위해 pre-trained theta를 사용해 초기에 MIPS index를 한 번만 빌드하고 \(Embed_{doc}\)를 업데이트 하지 않는다.

논문에서는 pre-training 과정에서 이미 좋은 \(Embed_{doc}\) 를 얻을 수 있기 때문에 효과적인 방법이라고 했으나, 인덱스를 갱신할 경우 성능이 더 향상될 수도 있다는 점을 언급했다.

 

What does the retriever learn?

REALM의 knowledge retrieval은 잠재적이기 때문에 훈련 목표가 어떻게 의미 있는 검색 도출을 이끌어내는지는 명확하지 않다. 이번 섹션에서는 예측에 도움을 주는 검색에 대해 어떻게 보상을 제공하는지를 다룬다.

pre-training 과정에서 knowledge retriever의 파라미터 \(\theta\) 에 대한 그래디언트를 계산하여 다음 수식에 따라 relevance score f(x,z) 를 변경한다.

relevance score f(x,z) : knowledge retriever가 문서 z에 할당하는 점수
p(y| z,x) : 문서 z를 사용할 때 올바른 출력 y를 예측할 확률
p(y|x) : p(z|x)에서 문서를 무작위로 샘플링할 때 예상되는 p(y|x, z)의 값

각 문서 z에 대해 gradient는 retriever가 relevance scorer(z)만큼 변경하도록 유도한다. (r(z)가 양수이면 증가, 음수이면 감소)

r(z)는 p(y|z, x) > p(y|x) 인 경우에만 양수가 되므로, 문서 z는 예상보다 더 나은 성능을 보일 때마다 양수 업데이트를 받는다.

 

 

Injecting inductive biases into pre-training

의미 있는 검색을 유도하기 위한 추가 전략들이다.

 

Salient span masking

REALM pre-training에서는 마스킹된 토큰을 예측하기 위해 world knowledge가 필요한 예시 x에 집중하고자 한다.

그러나 일부 MLM span은 local context만을 필요로 한다. world knowledge를 필요로 하는 문제에 집중하기 위해 "Salient spans" 에 대해 마스킹을 진행한다.

  • Salient spans? 문장에서 핵심적인, 두드러진 정보를 담고 있는 단어 -> 언어 표현 학습만으로는 맞추기 힘듦!
  • 본 논문에서는 아래 예시와 같이 named entity와 date를 salient spans로 규정한 듯 하다.
e.g. "United Kindom", "July 1969"

 

named entities 식별을 위해 BERT-based tagger를 이용하고 날짜를 식별하기 위해 정규식을 사용한다.

Experiments에서 이러한 마스킹 방법이 기존 BERT, SpanBERT의 마스킹 방법보다 성능 향상에 도움을 준다는 사실을 밝혔다.

 

Null document

salient span masking이 있더라도 모든 마스킹된 토큰이 예측에 world knowledge를 필요로 하는 것은 아니다.

이와 같이 world knowledge가 필요없는 경우를 모델링하기 위해 검색된 k개의 문서들 상단에 null document \(\emptyset\) 를 추가하고, knowledge가 필요없는 경우에는 추가한 null document를 retrieve 하도록 한다.

 

Prohibiting trivial retrievals

pre-training corpus Xknowledge corpus Z가 같다면 너무 많은 정보를 제공하는 자명한 후보 문서 z가 존재할 수 있다. 

이게 무슨 말이냐면, "마스킹된 문장 x가 문서 z에서 나온 경우 마스킹되지 않은 원본 문장을 z에서 볼 수 있다" 는 얘기다. 

이렇게 y를 예측해버린다면 p(z|x)의 graidient가 커지고, 모델은 우리가 의도한 학습 방법이 아니라 x와 z 사이 정확한 문자열 일치를 찾는 방법을 학습하게 될 것이다.

따라서 pre-training 중에서는 이러한 후보들을 제외한다.

 

Initialization

훈련이 시작될 때 Retriever에 \(Embed_{input}(x)\), \(Embed_{doc}(z)\) 에 대한 좋은 임베딩을 가지고 있지 않으면 반환된 문서 z도 x와 별로 연관성이 없을 수 있다. 이렇게 되면 knowledge augmented encoder를 반환된 문서들을 제대로 참조하지 않는 방향으로 학습시켜 버릴 수 있는데, 이후에도 knowledge augmented encoder가 제대로 된 gradient를 받아 업데이트하는 것이 불가능해지므로 악순환이 반복된다.

이러한 cold-start problem을 해결하고자 주어진 문장에 대해 Inverse Cloze Task(ICT) 라는 간단한 훈련 목표를 사용해  \(Embed_{input}(x)\), \(Embed_{doc}(z)\) 에 대해 warm-start가 가능하도록 한다.

  • ICT (Inverse Cloze Task) : standard cloze task에서 context를 기반으로 masked-out text를 예측하는 방식을 사용했다면, ICT는 문장의 역추정을 요구하고 문장의 맥락을 예측한다. [각주:9]

ICT 예시

무작위 문장(질문인척 하는)과 문장의 context(관련이 있는 context인 척 하는)가 text snippet(=단편, 한편의 글)으로 부터 주어진다. "얼룩말은 4개의 이동 패턴을 보인다 : 걷다가, 총총 걷다가, 구보 하듯이 걷다가 전속력으로 질주한다. 얼룩말은 일반적으로 말보다 느린 대신 매우 좋은 지구력이 있어서 포식자로부터 도망갈 수 있도록 한다. 쫓기게 되는 상황에서 얼룩말은 측면으로 지그재그로 움직인다..." 한 배치내에 있는 후보 지문들 중에서 관련이 있는 지문을 고르는 것이 목적이다. [각주:10]

 

retriever는 기존 Open Retrieval Question Answering system의 파라미터를 초기값으로 설정하였고, knowledge-augmented encoder의 경우 uncased BERT 모델의 파라미터를 초기값으로 설정했다.

해당 기법에 대한 자세한 설명은 선행 연구[각주:11] 에 나와있다.

 

 

4. Experiments

Open-QA Benchmarks

  • NaturalQuestions-Open

google query - answer로 구성되어 있으며 각 answer에는 "answer type"이 존재한다. 이 중 최대 5개의 토큰으로 구성된 "short answer type" 질문들로만 구성한다.

  • WebQuestions

Google Suggest API를 통해 수집된 해당 데이터셋은 하나의 seed question을 사용해 관련 질문으로 확장한 질문 집합으로 구성했다.

  • CuratedTrec

MSNSearch, AskJeeves와 같은 사이트에서 수집된 실제 사용자 QA 쌍 모음이다. 복수 정답 또는 다양한 철자 변형을 포함하기 위해 해당 데이터셋의 답변은 모든 정답과 일치하는 정규식으로 정의된다.

그러나 이러한 유형의 supervised dataset으로 generation-based model을 학습하는 방법은 불명확하므로 해당 데이터셋에서는 실험을 진행하지 않았다.

 

Approaches compared (Retrieval VS Generation)

Retrieval-based OpenQA

대부분의 기존 시스템은 먼저 knowledge corpus에서 잠재적으로 관련 있는 문서를 검색한 다음 RC system을 이용해 해당 문서에서 정답을 추출한다. 이와 같이 Retrieval-based approach에서는 지식이 corpus에 "명시적으로" 저장된다. 

이를 구현한 접근 방식에는 다음과 같은 선행 연구들이 존재한다.

[non-learned heuristic retrieval]

  • sparse BoW matching[각주:12]
  • 질문에 대한 entity linking을 통해 소수의 관련성 있는 문서 집합 선택
  • DrQA
  • HardEM
  • GraphRetriever
  • PathRetriever

여기서는 학습된 모델에 의해 문서들이 re-ranked 되긴 하지만 초기 heuristic retrieval step의 한계로 인해 수렴은 제한될 수 있다.

 

[MIPS index]

REALM처럼 잠재 변수 모델을 사용해 Open-QA를 공식화하고, marginal likelihood를 최대화하는 방식으로 학습을 진행한다. 또한 retriever 역시 ICT를 통해 초기화된다.

그러나 REALM에서는 새로운 LM pre-training 단계를 추가하고, 고정된 인덱스를 사용하는 대신 MIPS index로 backpropagation을 진행한다는 차이점이 존재한다.

 

Generation-based OpenQA

시퀀스 예측 태스크처럼, 질문을 인코딩한 다음 해당 인코딩을 바탕으로 토큰 단위 디코딩으로 답변을 반환하는 방식이다.

처음에는 얼마나 많은 양의 지식을 모델에 넣을 수 있는지 불분명했지만 GPT-2에서 seq2seq를 통해 주어진 컨텍스트를 사용하지 않고도 직접 답을 생성할 수 있는 능력을 암시했다. 그러나 해당 실험에서는 Fine-tuning이 부족해서 성능 자체는 경쟁력이 그닥 없었다.

 

GPT2에 대한 지난 포스팅

https://ll2ll.tistory.com/30

 

[논문 Review] 05. Language Models are Unsupervised Multitask Learners

학습시킨 범용적 사전학습 모델의 성능을 보다 높여보자! Abstract QA, 기계 번역, 독해, 요약과 같은 자연어 처리 작업은 일반적으로 task-specific 데이터셋을 통한 지도학습으로 접근한다. 본 논문에

ll2ll.tistory.com

 

T5도 이와 유사하게 가능성을 보여주었지만, 컨텍스트 문서가 제공되는 reading comprehension task에서만 실험이 진행되었다.

 

REALM과의 비교를 위해 T5를 Fine-tuning한 것과 비교했고, [각주:14]

모델 크기에 따른 비교도 하기 위해 T5 모델에 대해 Base, Large, 이것보다 큰 11B 파라미터 모델과도 비교했다.

 

 

Implementation Details

Fine-tuning

  • 직관적 비교를 위해 선행 연구[각주:15]에서 사용된 모든 하이퍼파라미터를 그대로 사용
  • knowledge corpus는 2018년 12월까지의 정보를 담고 있는 영문판 위키피디아를 사용
  • 문서에 대해 최대 288개의 BERT wordpieces로 greed split을 진행한 결과 1,300만개가 조금 넘는 크기의 검색 후보(문서)
  • Fine-tuning inference 진행 시 top-5 후보 선정
  • 모든 모델은 12GB GPU 단일 디바이스에서 실행가능

 

Pre-training

  • 64개의 google cloud TPU
  • 20만번의 step 동안 pre-training
  • batch size = 512; learning rate = 3e-5
  • MIPS index의 document embedding step은 16개의 TPU를 이용해 병렬화됨
  • 각 예시에 대해 (null document를 포함한) 8개의 후보를 검색하고 marginalize
  • pre-training corpus는 아래 2개를 선택해서 실험해보았다
    • Wikipedia (knowledge corpus Z와 동일)
    • CC-News

 

Main results

[표 1] Open-QA benchmark 성능 비교

  • REALM은 가장 큰 T5-11B 모델보다 크기는 30배 더 작으면서도 성능이 뛰어남
  • 모든 시스템 중 REALM과 가장 직접적 비교가 가능한 것은 Fine-tuning setting, 하이퍼파라미터, 학습 데이터가 동일한 ORQA
  • 따라서 ORQA에 비해 REALM 성능이 개선된 것은 사전훈련 방법때문이라고 볼 수 있으며, 이는 단일 corpus setting (X = wikipedia), 분리 corpus setting(X = CC-news) 모두에 적용될 수 있음
  • 20-80개의 문서를 검색하는 다른 검색 기반 시스템 (PathRetriever, GraphRetriever) 과 비교했을 때 REALM은 5개의 문서만 검색하면서도 우수한 성능을 발휘함

 

Analysis

Ablation study

[표 2] REALM에 대한 ablation study

  • retriever, encoder 모두 모델 성능 개선에 도움을 줌
  • 마스킹 기법의 영향을 알아보기 위해 2가지 기법을 모두 실험
    • BERT에서의 random token masking
    • SpanBERT에서의 random span masking
  • salient span masking이 REALM에서 중요한 역할
  • pre-training 과정에서 병렬 프로세스를 실행해 re-indexing을 진행하는데, 500 steps 당 1번 정도의 re-indexing이 이루어진다. 중요성을 입증하고자 더 느린 re-indexing 주기와 비교했을 때, 갱신한지 오래된 인덱스는 모델 학습에 악영향을 미치며 주기를 더 줄이면 더 나은 최적화의 가능성 존재

[표 3] MLM prediction example

  • (c) REALM은 (a) BERT에 비해 훨씬 높은 확률로 올바른 단어를 찾아냄
  • (b) REALM은 정답과 관련이 있는 문서를 반환해서, 정답에 대한 marginalized probability를 극적으로 높인다.
  • 이는 REALM이 레이블링되지 않은 텍스트만으로 학습되었음에도 불구하고 마스킹된 단어를 채우기 위한 관련성 높은 문서를 검색할 수 있는 능력을 보여줌

 

 

5. Discussion and Related Work

Open-QA를 넘어, 보다 폭넓은 범위의 태스크에서 REALM을 활용할 수 있는 방법을 소개한다.

 

Language modeling with corpus as context

Language representation model은 예측을 수행할 때 주변 단어, 문장, 문단 순으로 점점 더 넓은 범위의 문맥을 통합하고 있다.  REALM은 이 다음 단계인 전체 텍스트 코퍼스를 고려한 모델로 나아갔다고 볼 수 있다.

 

Retrieve-and-edit with learned retrieval

입력 텍스트의 분산을 더 잘 설명하고 제어 가능한 생성을 위해 선행 연구[각주:16]에서는 lexical overlap이 많은 텍스트에서의 retrieve-and-edit 프레임워크를 제안했다.

REALM도 비슷한 접근 방식을 사용하지만, 모델이 어떤 텍스트가 perplexity를 줄이는 데 가장 유용한지 스스로 학습한다는 점이 다르다. Retriever를 함께 학습함으로써 REALM은 lexical overlap을 넘어 정보에 따라 답을 생성할 수 있는 능력을 가지게 된다.

 

Scalable grounded neural memory

document index는 임베딩을 key로 가지는 메모리라고 볼 수 있는데, 선행 연구에 따라 이러한 확장 가능한 메모리 레이어를 대규모 LM에 통합할 수 있다. 선행 연구와의 차이점은 REALM에서의 메모리는 명명되지 않은 값 벡터가 아니라 문서와 연관되어 있으므로 답변이 보다 해석 가능하고 신뢰성 있는 출처를 가지게 된다.

 

Unsupervised Corpus Alignment

선행 연구인 seq2seq with attention 에서는 텍스트가 관련있는 토큰의 잠재적인 선택으로 생성된다. 그 결과 target token - source token 사이 model-centric한 정렬이 이루어진다.

이와 유사하게 REALM은 관련 문서의 잠재적 선택을 통해 텍스트를 생성하는데, 이를 통해 pre-training corpus X와 knowledge corpus Z 사이의 unsupervised alignments를 얻을 수 있다.

 

 


참고 문헌

https://gbdai.tistory.com/63

  1. https://wikidocs.net/200922 [본문으로]
  2. https://wikidocs.net/21790 [본문으로]
  3. https://medium.com/platfarm/mips-c1db30a3e73e [본문으로]
  4. Miller, A., Fisch, A., Dodge, J., Karimi, A.-H., Bordes, A.,and Weston, J. Key-alue memory networks for directlyreading documents. arXiv preprint arXiv:1606.03126,2016. [본문으로]
  5. Chen, D., Fisch, A., Weston, J., and Bordes, A. Reading wikipedia to answer open-domain questions. In Proceedings of the 55th Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers), volume 1, pp. 1870–1879, 2017. [본문으로]
  6. Khandelwal, U., Levy, O., Jurafsky, D., Zettlemoyer,L., and Lewis, M. Generalization through memorization:Nearest neighbor language models. ArXiv,abs/1911.00172, 2019. [본문으로]
  7. Brill, E., Dumais, S., and Banko, M. An analysis of the askmsr question-answering system. In Empirical Methods in Natural Language Processing, 2002. [본문으로]
  8. Lee, K., Chang, M.-W., and Toutanova, K. Latent retrieval for weakly supervised open domain question answering.In Proceedings of the Conference of Association for Computational Linguistics, 2019. [본문으로]
  9. https://jeonsworld.github.io/NLP/orqa/ [본문으로]
  10. https://velog.io/@sangmandu/Latent-Retrieval-for-Weakly-SupervisedOpen-Domain-Question-Answering [본문으로]
  11. Lee, K., Chang, M.-W., and Toutanova, K. Latent retrieval
    for weakly supervised open domain question answering.In Proceedings of the Conference of Association for Computational Linguistics, 2019. [본문으로]
  12. Robertson, S., Zaragoza, H., et al. The probabilistic relevance framework: Bm25 and beyond. Foundations and Trends in Information Retrieval, 3(4):333–389, 2009. [본문으로]
  13. Lee, K., Chang, M.-W., and Toutanova, K. Latent retrieval for weakly supervised open domain question answering.In Proceedings of the Conference of Association for Computational Linguistics, 2019. [본문으로]
  14. Roberts, A., Raffel, C., and Shazeer, N. How much knowledge can you pack into the parameters of a language model? arXiv preprint arXiv:TBD, 2020. [본문으로]
  15. Lee, K., Chang, M.-W., and Toutanova, K. Latent retrieval for weakly supervised open domain question answering.In Proceedings of the Conference of Association for Computational Linguistics, 2019. [본문으로]
  16. Hashimoto, T. B., Guu, K., Oren, Y., and Liang, P. S.A retrieve-and-edit framework for predicting structured outputs. In Advances in Neural Information Processing Systems, pp. 10052–10062, 2018. [본문으로]