ABOUT ME

-

Today
-
Yesterday
-
Total
-
  • [논문리뷰] Improving non-transferable representation learning by harnessing content and style
    논문 스터디 2024. 2. 8. 16:41

     

     

     

    논문 : https://openreview.net/pdf?id=FYKVPOHCpE

     

    ICLR2024 splotlight 에서 재미있는 논문 몇 편을 찾았다. 그 중 첫 번째로 읽는 논문이다. Non-transferable representation learning 이라는 제목에 transfer learning, representation learning에 관심이 있어 읽어보게 되었다. 읽다보니 domain adaptation, continual learning, knowledge distillation 과 상당히 비슷하다는 느낌을 받았다. 찾아보니 저자가 또 그쪽에 많은 논문을 낸 듯 하다. 

     

    Introduction

     

    먼저 task를 소개하겠다. Non-transferable representation learning이라는 분야가 생소한 것 같지만 그렇지 않다. 우리가 transfer 하고싶은 정보만 보내고 그 외에 transfer하고싶지 않은 정보는 보내지 않는걸 연구하는 분야인 것 같다. transfer라고 하면 굉장히 많은 분야들이 떠오르는데, 위에 말한 DA, CL, KD가 그렇다. 

     

    이 논문에서는 transfer할 정보를 딱 정해뒀다기보단, content와 style이라는 factor를 정의하고 label과의 spurious correlation, fake independence를 없애면서 content와 style 각각의 label과의 causal relationship을 찾는 듯 하다. 

     

    그래서 이 논문의 contribution은 (1) 이미지의 특징을 content와 style로 나눠 spurious correlation과 fake independence를 정의한다. 기존에는 statistical dependence로 나눴었는데 그렇게 하면 small distribution shifts에 fragile하다. (2) 기존의 문제를 해결하기위해, contents와 style을 latent factor로 하는 causal model을 만들어 variational inference를 했다고 한다. 

     

    그럼 여기서 드는 생각, 만약 배경이 다양한 오리 사진들이 있으면, content는 오리고 style은 배경인데, 여기서 오리만 학습시키고싶다면 contrastive learning 같은걸 한 다음 linear probing을 하면 배경에 상관없이 오리만 학습이 될까? 그중에서도 오리의 배경이 중복되어 배경이 같이 학습될 수도 있다. 이걸 어떻게 제거할까? 이걸 어떻게 나눌까?

     

    ViT needs register와 같이 class마다 register를 두고 하나는 content를 학습하고, 하나는 style을 학습하도록 할 수 있을까? causal relationship을 어떻게 찾지?

     

    Related works

    관련된 연구를 두가지로 나눠놨는데, 하나는 NTL(non-transferable learning)이고 하나는 causal-inspired representation learning 이다. 먼저 NTL은 certain domain or all domain에서 learned model의 generalizability를 restrict하는 task이다. 이걸 하는 이유는 intellectual property protection 과 controlleable AI 라고 한다. 아까 말했듯 가장 관련된 분야는 DA와 DG(domain generalization) 이고 DA와 task가 반대라고 보면 된다. DA는 source domain에서 학습한 모델을 target domain에서도 잘 학습하도록 하는 task이고 NTL은 certain domain(target specified) 관점에서만 봤을 때 source domain에서 학습한 모델을 target domain에서 어떤 것들은 학습되지 않도록?하는 것 같다. 그리고 또 all domain(source only) 관점에서 봤을 때 DG와 반대라고 하는데 내가 아직 DG가 뭔지 잘 모른다. 공부를 좀 해야겠다. 

     

    NTL은 2022년에 Wang et al, 가 처음으로 propose 했다고 한다. 정말 얼마 안된 task이다. 이 논문에서는 target domain과 label 사이의 KL divergence를 maximizing하고 source와 target domain의 distribution의 MMD를 maximizing하는 statistical dependence relaxation term을 넣었다. 근데 이유는 잘 모르겠지만 이런 statistical based method는 spurious correlation과 fake independence에 쉽게 빠질 수 있다고 한다. 많이 쓰던 방식인데 왜일까? 아마도 distribution을 비교하는 방식들이다 보니 거리가 멀게 되는 것들이 정확히 뭔지 정의되지 않아서 그런가? 약간 뭔지 알 것 같으면서도 정확히는 잘 모르겠다.

     

    그리고 그 다음 causal inspired representation learning 에 대한 연구들이다. statistical learning이 image background와 class label의 spurious correlation 에 빠지기 쉽게 만들고 이러한 문제들은 poor generalizability towards unseen domains라는 결과를 낳는다. 이러한 이유로 causal-inspired representation learning이 promising한데, 이 task는 다른 causal factor에 의해서 invariant한 representation을 학습하는 것을 목표로 한다. 그래서 causal-inspired DA 와 DG가 있다. 여기서는 certain domain에서 generalizability 없이 학습하는 걸 목표로하고 있다고 하는데 이 뜻은 다른 domain에서 variant한 feature representation을 학습하겠다는 얘기다. 

     

    위의 말이 좀 애매한데, source domain에서 학습한 모델이 target domain에선 잘 안되도록 한다는 건데, 그때 잘 안되도록 하는 걸 controllable하게 만드는건가? variant한 feature representation을 학습하면 그냥 잘 안되는거 아닌가? 그럼 그때 loss는 어떻게 나오는거지?

     

    Methods

     

    method에서 설명할 내용은 3가지이다. 첫번째로 idea non-transferable mechanism based on causal model, 두번째로  variational inference framework to disentangle the unobservable content factor C and style factor S. 세번째로 dual-path knowledge distillation to harness the disentangled factors as guidance to teach a student to learn ideal non-transferable representations. NTL task는 target-specified 와 source-only 두가지 프로세스로 나눠서 설명한다. 

    Notation.

     

     

     

    instance X in domain D with label Y has two unobservable cause variables, content factor C and style factor S. 반면  X, Y, D 는 observable variables 이며 우리는 XYD의 joint distribution인 source domain data D_s, target domain data D_t 에 접근할 수 있다. 우리는 이 두 training set을 merge하여 D라 하고 y_i는 content label, x_i는 i번째 sample, d_i는 domain label이다. NTL의 goal은 f_ntl 이라는 classification model을 학습시키는 것인데 동시에 target domain performance를 degrade하고 source domain performance를 maintain해야 한다. f_ntl은 feature extractor f_e와 classifier f_cls 로 나뉜다. 

     

     

     

    1. Causal-driven non-transferable representation

    causal model을 구성했다는 건, X는 C,S에 dependent 하고, D는 S에 dependent 하고, Y는 C에 dependent 하다는 것 정도만 알면 될 것 같다. 그런데 여기서 잠깐, 그럼 모델을 구성하고 나서는? 그래서 dependent 하고 나면 뭐가 좋은거지? 라는 생각이 드는데, 여기서는 PGM(probabilistic graphical model)에서 배웠던 VAE(variational auto-encoder)와 ELBO(evidence lower bound)개념을 가져온다. 왜 가져오느냐?는 좀더 생각해보자. X, Y, D 는 observable variable이고, C,S는 unobservable하다. 그리고 여기서, target domain 의 representation N_t 와 source domain representation N_s 가 independent 하고, N_t와 label이 Independent하면, optimally untransferable하다라고 definition을 내린다. 그렇다면, N_t는 content와 연관이 있으면 안되고, N_s 랑도 연관이 있으면 안된다. 동시에 N_s는 content와 연관이 있어야 한다. 따라서 N_t를 style과 dependent 하게 만든 것이다. 

     

    2. Disentanglement under the causal model

    latent factor와 observable variable 사이에 causal model이 형성되는 것은 알겠구, 그럼 이걸 가지고 뭘 하냐? 가 앞에서 나왔는 질문인데, 챗지피티한테 물어보니 입력 데이터에서 다양한 요소들을 분리하여 독립적으로 학습시키거나 추출하는 과정이 disentanglement (분리) 이고 이는 데이터의 특성을 명확이 이해하고 잠재된 요소를 추출하는데 도움이 된다고 한다. 

     

    그래서 causal model을 disentangle 한다는 것은 입력 데이터 요소들이 인과적으로 형성되어 있고 이를 독립적으로 학습하거나 추출한다. 즉, 데이터의 각 특성이 서로 다른 요인에 의해 독립적으로 변화되는 것 을 인식하고 이를 학습한다. 그리고 또 disentanglement는 latent variable을 찾는 과정이라고 보면 된다. 잠재변수는 데이터의 표현력을 높이고, 데이터의 분포를 잘 모델링하는데 사용된다. 그래서 disentanglement로 찾아진 하나의 잠재변수는 한가지 특성이나 요인에 집중되도록 하는 것 이 목표이다. 이를 통해 데이터의 특성을 더 잘 이해할 수 있다. 

     

    이제 다시 본론으로 돌아와서, C,S 는 naturally X에 entangled 되어있고, 이들을 이용하기(infer 하기) 위해서 우리는 observable한 X, Y, D를 사용한다. 그리고 joint distribution을 factorize 한다. 

    $$P(X,Y,D,C,S) = P(C,S)P(Y|C,S) P(X|C,S) P(D|C,S)$$

    그리고 이제 factorize 된 식을 보고 저자들은 VAE를 떠올렸다는데 나는 왜인지 잘 모르겠다. 그래서 VAE를 다시 복습해보자.

     

    https://process-mining.tistory.com/161 이 블로그에 자세히 설명되어 있다. 

     

    여전히 왜 저 factorize 된 식을 보고 VAE가 떠오르는지는 의문이지만, encoder $q_\phi(z|x)$ 로 input을 latent space로 변환할 때, input x가 주어졌을 때 latent z의 분포 $q_\phi(z|x)$ (posterior)를 approximate 할 것이고 이를 approximate 할 때는 이 정규분포를 나타내는 평균과 표준편차를 찾는것이 목적이다. 

     

    VAE를 학습할 때 loss는 뭘까? MLE (maximum likelihood)를 택한다. 즉 $p_\theta(x)$ 를 maximize하는 theta를 찾는 것을 목적으로 한다. 그래서 $\log p_\theta (x)$ 를 maximize 한다. ( $\log p_\theta (x)$ 이게 뭔데??? x의 분포를 흉내내어 비슷한 걸 만들어낼 거니까 x의 log likelihood를 maximize하는 theta를 찾고 그때 encoder와 decoder를 이용하도록 식을 수정하는듯)

     

    아무튼 또다시 본론으로 돌아와서, 이 논문에서는 C,S에 각각 encoder $\hat{q}_{\phi_c}(c|x), \hat{q}_{\phi_s}(s|x)$ 를 놓고 posterior $q_{\phi_c}(C|X), q_{\phi_s}(S|X)$ 를 learnable parameter $\phi$로 모델링 하도록 한다. 그리고 distribution $p_{\theta_x}(X|C,S)$를 모델링하는 decoder $\hat{p}_{\theta_x}$ 를 놓는다. 더해서, $P(Y|C,S)$ 를 $P(Y|C)$ 로 approximate 하고, $P(D|C,S) = P(D|S)$ 도 마찬가지로 한다. 그러므로, classifier $\hat{p}_{\theta_d}, \hat{p}_{\theta_y}$를 이용하여 distribution $p_{\theta_d}(D|S), p_{\theta_y}(Y|C)$ 를 모델링한다. 

     

    위의 파라미터들($\phi_c, \phi_s, \theta_x, \theta_y, \theta_d$)을 학습하기 위해, ELBO를 maximize 한다. 

     

    $\log p(x,y,d) = \log \int_c \int_s p(x,y,d,c,s)dc ds$ 로 시작, 

     

    결론적으로, 

     

    $ELBO(x,y,d) = KL(q_{\phi_c}(c|x)||p(C)) - KL(q_{\phi_s}(s|x)||p(S)) $

     

    $ + \mathbb{E}_{c\sim q_{\phi_c}(c|x)}[\log q_{\theta_c}(y|c)] + \mathbb{E}_{s\sim q_{\phi_s}(s|x)}[\log q_{\theta_s}(d|s)] $

     

    $+ \mathbb{E}_{c,s}[\log p_{\theta_x}(x|c,s)]$

     

     

    아무튼 저 파라미터들을 학습시켜  latent variable C,S를 만들어 X,Y,D 와 비슷한 데이터를 만드는 encoder와 decoder, classifier들을 얻은 다음, source domain의 instance $X_s$, target domain의 instance $X_t$ 를 input으로 하는 dual path knowledge distillation을 구성한다. encoder가 C,S를 뱉으면 그걸 content classifier로 보내는데, 그러면 content인 C는 잘 classify 될것이고, style인 S는 잘 classify되지 않을 것이다. $X_s$의 path에서는 content classifier output의 logit을 $f_{ntl}$ 의 output logit과 MSE를 최소화하도록 $f_{ntl}$을 distill하고, $X_t$의 path에서도 content classifier output의 logit과 $f_{ntl}$의 output logit의 MSE를 최소화하도록 한다. 

     

    처음엔 이 방식이 좀 말이 안된다고 생각했다. content와 style이 weak relationship이 있지만 그래도 style을 content classifier에 넣는다고? 라는 생각을 했었다. 물론 지금도 dual path KD 또는 NTL task는 이 방법의 장점을 잘 살린다는 생각이 들지는 않는다. content와 style을 잘 구분하는 encoder를 구성했다면, classification 뿐만 아니라 segmentation이나 detection에도 사용될 수 있고, style에 상관없이 content를 잘 classify하는 DA에 사용될 수도 있고, 여러 class의 content를 구분할 수 있다면 clustering 같은 곳에 사용 될 수 있지 않나..

     

    그리고 이게 NTL에 쓰인 이유도 약간 좀 이해가 안가는게, source에서 content 뽑고, target에서 style 뽑아 content classifier를 돌린걸 KD하면 source에서 잘되고 target에서 잘 안되는건 당연한데,, source에서 잘되고 target에서 잘 안되는 방법은 좀 더 쉽게 접근할 수 있지 않나 생각했다. 물론 그러면 논문이 ICLR에 억셉되진 않겠지만.

     

    아무튼 content와 style을 구분하는 방법과 NTL task에 대해 알려준 논문으로 재미있게 읽었다.

Designed by Tistory.