안녕하세요! 5G DNA* 기술 개발 (*Digital Network Automation)이라는 주제로 SKT AI Fellowship 4기에서 활동 중인 팀 DNA(다나와) 입니다.
이전 연구 계획 글에서 저희 팀의 목표는 네트워크로부터 수집되는 빅데이터를 이용하여 고객의 체감을 개선하는 것으로, 이를 위해 서비스 품질 예측 모형을 개발하고자 한다고 전해드렸습니다. 특히, 해당 모델은 일반적으로 결과에 대한 해석이 불가능한 인공지능 모델에서 한 단계 발전하여 결과에 대한 해석이 가능하다는 점을 강조 드렸는데요, 이러한 모델을 구현해내기 위해 저희 팀에서는 다양한 관련 논문에 대해 survey를 진행하고 그 중 중요한 논문에 대해서는 내부 세미나를 진행하고 있습니다.
앞으로 Fellowship활동 기간 동안 이렇게 내부 세미나를 진행한 논문들에 대한 리뷰를 포스팅을 통해 함께 공유 드리고자 합니다.
첫번째로 저희가 공유해 드릴 논문의 제목은 RETAIN : An Interpretable Predictive Model for Healthcare using Reverse Time Attention Mechanism (NIPS 2016) 입니다. 시계열 데이터에 적용되고 interpretable한 모델을 제안하는 논문으로, 해석이 가능한 모델을 만들고자 하는 생각을 들게 해준 논문입니다.
1.Data notation
사용하는 데이터는 EHR데이터로 환자가 병원에 방문했을 때 처방 혹은 진단 받은 기록을 모아 놓은 데이터입니다.
sequnetial한 high-dimensional clincal variables(진단기록, 처방기록 등) 으로 요약할 수 있습니다.
r개의 다른 variable을 가지는 총 N명 중 n번째 환자에 대한 표현은 다음과 같습니다.
(t(n)i,x(n)i)∈R×Rr,i=1,...,T(n)
이때 t(n)i는 i번째 visit의 timstamp를 의미하고 T(n)는 n번째 환자의 총 방문 횟수를 의미합니다.
해당 논문은 편의상 (n)을 생략하고 진행합니다.
예측하고자 하는 label은 each time step에서의 yi∈{0,1}s 혹은 마지막 time step에서의 y∈{0,1}s입니다.
예시1)ESM(Encounter Sequence Modeling)
환자의 각 visit이 medical code들의 set {c1,c2,⋯,cn}으로 표현됩니다. 따라서 number of variable은 r=|C|으로 input은 x∈{0,1}|C|의 binary vector로 표현됩니다. ESM의 목표는 주어진 input sequence x1,⋯,xT 에서 각 time step i에 대해 next visit인 x2,⋯,xT+1 을 예측하는 것이 목표 입니다. 즉, s=|C|입니다.
예시2)L2D(Learning to Diagnose)
xi는 continuous clinic measures로 xi∈Rr 입니다. L2D의 목표는 input sequnce x1,⋯,xT 가 주어졌을 때 특정 질병(s=1)혹은 여러 개의 질병들(s>1)의 발병 여부를 예측하는 것입니다.
2.Methology
해당 논문의 모델인 RETAIN을 설명하기에 앞서 모델에 활용되는 Attention based neural network에 대해서 간단하게 설명드리겠습니다.
Attention의 유용성은 language translation task에서 알 수 있습니다. 일반적으로 전체 문장을 하나의 고정된 크기의 vector로 표현하는 것은 효율적이지 않기에 neural translation machine들은 고정된 크기의 vector로 표현된 문장을 해석하는데 어려움을 겪습니다. 따라서 성능향상을 위해 Attention이 사용되고, 이 때 Attention은 다음과 같은 역할을 합니다. 예를 들어 original 언어(input)에서 S의 length를 가지는 하나의 문장이 주어졌다고 하면, 문장속의 단어를 표현하는 h1,⋯,hS의 vector들을 만들어 냅니다. target 언어(output)에서 j번째에 해당하는 단어를 찾기 위해 attention αji for i=1,⋯,S을 만들어 냅니다. 그리고 context vector cj=∑iαjihi를 계산하고 이를 통해 target언어 예측에 사용합니다. 즉, attention을 활용해 target언어를 예측함에 있어 orginal 언어에서 어떤 단어에 집중해야 하는지를 알게 해줍니다. RETAIN에서 이러한 attention의 활용은 의사가 환자의 기록을 살필 때 특정 clinic 정보에 더 중점을 두는 행동과 유사하다고 볼 수 있습니다.
RETAIN

모델의 구조는 위의 이미지와 같고 총 5가지 step으로 구성되며 있습니다.
Step1. Embedding v생성.
Step2. RNNα를 이용한 attention α생성.
Step3. RNNβ를 이용한 attention β생성.
Step4. attetion들과 embedded vector를 이용한 convex vector c 생성.
Step5. Prediction.
모델의 입력은 i번째 visit까지의 timestamp 이고 xi∈Rr입니다.
각각의 Step을 수식과 함께 표현하면 다음과 같습니다.
Step1.
vi=Wembxi,Wemb∈Rm×r 으로 vi는 input vector xi의 embedding을 의미합니다.
attention의 경우 visit-level attention α와 variable-level attention β로 2가지를 사용합니다. 이는 모델 결과에 대한 interpretability를 얻기 위함인데 뒷 부분에 설명을 따로 추가하도록 하겠습니다.
Step2.
gi,gi−1,⋯,g1=RNNα(vi,vi−1,⋯,v1)ej=wTαgj+bα,for j=1,⋯,iα1,α2,⋯,αi=Softmax(e1,e2,⋯,ei)
gi∈Rp로 RNNα의 hidden layer에 속하고 wα∈Rq와 bα∈R는 학습 가능한 파라미터 입니다.
Step3.
hi,hi−1,⋯,h1=RNNβ(vi,vi−1,⋯,v1)βj=tanh(Wβhj+bβ),for j=1,⋯,i
hi∈Rq로 RNNβ의 hidden layer에 속하고 Wβ∈Rm×q와 bβ∈Rm는 학습 가능한 파라미터 입니다.
주의해야할 점은 α는 visit-level attention으로 (α1,α2,⋯,αi)가 Ri의 차원을 이루지만 β는 variable-level attention으로 β1,β2,⋯,βi 로 표현되는 각각의 방문이 Rm의 차원을 이룸을 기억해야 합니다.
Step4.
ci=∑ij=1αjβj⊙vj로 i번째 timestamp에 대응하는 context vector ci는 i시점 까지의 embedding vector vi와 대응하는 attention들이 element-wise하게 곱해짐으로 이전 시점의 값들 중 어떤 시점의 값에 더 중점을 줘야 하는지가 반영된 것으로 생각할 수 있습니다.
Step5.
ˆyi=Softmax(Wci+b)로 예측값을 얻어낼 수 있고 loss는 cross-entropy를 사용합니다.
3.Interpreting RETAIN
논문에서 제시하는 모델은 다른 모델과 달리 결과가 Interpretable하다는 특징을 가지고 있습니다.
이번에는 어떻게 모델이 해석력을 가지게 되었는지 & 어떻게 해석을 하는지 두가지 측면에 대해 설명드리겠습니다.
3-1)어떻게 모델이 해석력을 가지게 되었는지

해당 이미지는 attention을 사용하는 일반적인 모델과 RETAIN의 구조를 보여주는 이미지입니다.
크게 embedding vector vi를 만드는 부분⋯(ㄱ)과 attention weight를 만드는 부분⋯(ㄴ)으로 나눈다고 할때, 일반적인 모델은 (ㄱ)에서 RNN , (ㄴ)에서 MLP를 사용하는 반면 RETAIN에서는 (ㄱ)에서 MLP, (ㄴ)에서 RNN을 사용함을 확인할 수 있습니다. RNN에서는 recurrent한 weights가 과거의 정보를 hidden layer에 반복적으로 제공하는 특징을 가지고 있습니다. 따라서 일반적인 모델에서는 embedding vector vi를 형성하는데 input xi가 어떻게 작용하는지를 파악할 수가 없는 반면 RETAIN에서는 MLP를 통해 embedding vector를 형성하기 때문에 input xi가 어떻게 작용하는지를 파악이 가능하다고 설명하고 있습니다.
3-2)어떻게 해석을 하는지
예측에 있어 가장 많이 기여한 visit을 찾는 것은 단순히 가장 큰 αi값에 대응하는 viist으로 생각하면 되지만, 가장 많이 기여한 variable을 찾는 것은 각각의 visit이 다양한 variable들의 ensemble로 표현되기 때문에 좀 더 복잡합니다. 따라서 이를 표현하기 위해 visit-level attention α만 사용하는 것이 아닌 variable-level attention β까지 도입하였습니다.
α와β를 고정한 채, original input x1,1,⋯,x1,r,⋯,xi,1,⋯,xi,r 값 변화에 따른 label yi,1,⋯,yi,s의 확률 변화를 계산하면 다음과 같습니다.
p(yi|x1,⋯,xi)=p(yi|ci)=Softmax(Wci+b)=Softmax(W(i∑j=1αjβj⊙vj)+b)=Softmax(W(i∑j=1αjβj⊙r∑k=1xj,kWemb[:,k])+b)=Softmax(i∑j=1r∑k=1xj,kαjW(βj⊙Wemb[:,k])+b)
Eq(1)에서 Eq(2)로의 변형은 convex vector가 attention들을 가중치로 가지는 embedding vector들의 합으로 표현이 가능하기 때문이고 Eq(2)에서 Eq(3)으로의 변형은 embedding vector는 input xi의 element를 가중치로 가지는 Wemb의 column들의 합으로 표현이 가능하기 때문입니다.
여기서 중요한 점은 Eq(4)가 xj,k별로 분해가 가능하다는 점입니다. 덕분에 해당 확률값에 대한 xj,k 각각의 영향력을 계산할 수 있습니다. 이를 ω(yi,xj,k)=αjW(βj⊙Wemb[:,k]) xj,k로 denote하고 해석은 yi를 예측하는데 있어 time step j에서의 input xj의 k번째 variable의 contribution입니다. 해당 contribution 값이 클 수록 영향력 있는 변수라고 생각할 수 있습니다.
4.Experiment
해당 논문에서는 환자에 대한 visit sequence x1,⋯,xT가 주어졌을 때 해당 환자가 HF(Heart Failure)를 진단 받았는지를 예측하는 문제에 대한 실험을 진행하였습니다. detail한 setting을 생략하고 결과를 보여드리면 다음과 같습니다.

여기서 RNN+αM은 위에서 설명드린 attention을 사용하는 일반적인 모델을 의미합니다. convtext vector를 생성함에 있어 RNN을 사용하기 때문에 해석력을 갖지 않습니다. RNN+αR은 RNN+αM과 동일한데 attention생성에 있어 MLP 대신 RNN을 사용하여 RNN을 사용한 attention의 효과가 있는지를 확인하기 위함입니다.
위의 결과를 보면 알 수 있듯이 RETAIN이 다른 방법론들에 비해 성능이 뛰어나진 않습니다. 하지만 RETAIN은 다른 방법론들과 달리 결과에 대한 해석이 가능하므로, 다른 방법론들의 성능을 유지하면서 설명력을 갖는다는 점이 큰 장점입니다.
해당 실험에 있어 RETAIN이 정말로 해석력을 갖는지를 보여주며 마무리하겠습니다.

SD,ESL,BN 등등은 데이터 설명 부분에서 말씀 드린 medical code 즉 variable들 입니다.
test input sequnce를 살펴보면 해당 환자는 HF의 증상인 CD, CA, HVD 등을 보이기 전에 피부 문제인 SD, BN, ESL 등을 겪었음을 볼 수 있습니다. 해당 그래프는 학습된 RETAIN모델에 대해 test input sequnce를 입력하여 얻어낸 contribution들을 나타낸 것입니다.(x축이 시점, y축이 contribution)
(a)는 test input sequnce을 그대로 모델에 입력한 결과입니다. 환자의 초기 방문에서 보였던 skin-related code들이 HF 진단에 크게 영향을 주지 않고, 최근 방문에서의 HF의 증상들이 더 관련있다는 적절한 결과를 도출하였습니다.
(b)는 test input sequnce의 순서를 거꾸로 입력하여 얻어낸 결과입니다. HF의 증상에 대한 contribution이 여전히 가장 높게 나오지만 (a)에서의 HF의 증상에 대한 contribution보다는 작게 나왔음을 확인 가능합니다. 또한 HF을 진단 받을 확률이 (a)의 0.2474보다 훨씬 낮은 0.0905값이 나와 RETAIN이 데이터의 temporal한 특성도 잘 이용한 모델임을 확인 가능합니다.
(c)는 test input sequence에 HF의 증상인 CD에 대한 약인 AA와 AC를 특정 시간에 처방하였다고 가정한 데이터를 입력한 결과 입니다. 보시다시피 AA와 AC의 contribution은 음수로 HF 진단확률을 낮추게끔 영향을 주고 실제 진단 확률도 (a)의 0.2474보다 낮은 0.2165를 얻었음을 확인 가능합니다.
이렇게 첫번째 내부 세미나 논문인 RETAIN에 대한 리뷰를 마무리 짓겠습니다.
Multivariate한 시계열 데이터를 활용한 예측을 진행하며 설명력까지 갖는다는 점에서 저희 팀 목표를 이루기 위한 어느정도의 아이디어를 준 논문이었습니다. 물론 데이터의 특성이 다르고, 추가적으로 다양하게 반영할 수 있는 요소가 많기 때문에 이러한 부분들을 해결하고 반영하기 위해 저희 팀에서는 이어서 다양한 논문들에 대한 리뷰를 열심히 진행하고 있습니다. 이에 대한 내용도 계속 블로깅을 진행할 예정이니 많은 관심 부탁드립니다. 감사합니다!
팀원 소개

Reference
Choi, E., Bahadori, M. T., Sun, J., Kulas, J., Schuetz, A., & Stewart, W. (2016). Retain: An interpretable predictive model for healthcare using reverse time attention mechanism. Advances in neural information processing systems, 29.
'SKT AI Fellowship 4기' 카테고리의 다른 글
5G DNA(Digital Network Automation) 기술 개발 - 연구과정(2) (0) | 2022.09.03 |
---|