ABOUT ME

-

Today
-
Yesterday
-
Total
-
  • [논문 리뷰] source free domain adaptation via distribution estimation
    논문 스터디 2023. 6. 22. 17:21

    어디까지나 뇌피셜인 블로그

    논문 링크 : https://arxiv.org/pdf/2204.11257.pdf

    Motivation

    domain adaptation 은 source dataset으로 train된 model을 unlabeled target dataset에 사용하려는 task이다. source dataset과 real-world는 다르므로 source dataset으로 train된 모델을 general하게 사용하기는 어렵기 때문에 관심을 얻고 있다. domain adaptation은 곧 source dataset의 distribution과 target dataset의 distribution이 다른 문제를 말하는 domain shift problem을 푸는 것과 같다. 

     

    그 중에서도 source dataset에 접근하지 않고 source dataset으로 train된 model의 weight만을 사용해 domain adaptation 하는 task를 source free domain adaptation이라고 하고, 이 논문에서는 source domain의 distribution을 estimation 하는 방법을 사용한다.

     

    먼저 선행되어야 하는 notation들을 소개한다. 

     

    source dataset 을 $D_s = {(x_i^s, y_i^s)}_{i=q}^{n_s}$, target dataset  $D_t = {(x_i^t)}_{i=1}^{n_t}$,  linear classifier  $G(·)$, CNN feature extractor를 $F(·)$, m dimensional feature representation을 $f = F(x) \in \mathbb{R}^m $, trained model  $G(F(·))$, weights learned by $G$ 를 $w^G$, $k$-th weight vector of $w^G$ 를 $w_k^G$ 라고 하겠다. 

     

    따라서 linear classifier $G$에 의해 predict된 class label은 아래와 같이 나타내어 진다. 

     

    $$\hat{y_i} = \arg \max_k {f_i^{\top} w_k^G}$$

     

    각 element는 feature와 weight vector의 dot product이고, k-th class의 data는 k-th weight vector of $G$를 activate하는 feature representation을 도출한다. 따라서 $w_k^G$ 는 k-th class를 나타내는 anchor 라고 볼 수 있다. 

     

    Method

    1) Pseudo-labeling by exploiting anchors

    본 논문에서는 spherical k-means 를 통해 target data를 cluster하고 pseudo label 을 부여한다. initial center 는 anchor들로 선정한다 : $\mathit{A}^{(0)}_k = w^G_k$. $\hat{y}^t_i = \arg \min_k Dist(\mathit{A}^{(m)}_k, f^t_i)$ 로 m번째 iteration의 cluster center와 $f^t_i$ 의 cosine distance를 최소화하는 k 가 target data의 pseudo label이 된다. 그리고 initial center는 anchor였으나 $\hat{y}^t_i = k$ 이면 $f^t_i$ 쪽으로 옮겨가게 된다. 

     

    2) Source distribution estimation

    SFDA task는 traditional DA 와 달리 source data의 distritubution을 알지 못하므로, 본 논문에서는 source feature distribution을  estimation 한다. 첫 번째로, source data의 feature representation은 class-conditioned multivariate Gaussian distribution을 따른다고 가정한다 : $f^s_{i,k} ~ \mathit{N}^s_k(\mu^s_k, \sigma^s_k)$, where $f^s_{i,k} = \mathbb{F}(x^s_i|y^s_i=k) $ 여기서 $\mu^s_k$ 가 feature representation 의 center이고 covariance matrix 는 feature의 variation 과 rich semantic information 을 나타낸다..고 하는데 feature variation까지는 이해가 가는데 rich semantic information은 이해가 잘 안간다.

     

    아무튼 surrogate distribution $\mathit{N}^{sur}_k(\hat{\mu}^s_k, \hat{\sigma}^s_k)$ 으로 source distribution $\mathit{N}^s_k$ 을 근사한다. $\hat{\mu}^s_k$ 를 구하기 위해  feature mean을 directly 사용하면 domain distribution shift problem 을 반영하지 못하므로, anchor 를 사용하여 surrogate source distribution의 estimator mean을 calibrate한다. 

     

    $$\hat{\mu}^s_k = ||\bar{f^t_k}_2|| \cdot {{w^G_k} \over {||w^G_k||^2}}$$

     

    여기서 의미하는 바는 estimated source feature mean의 scale은 target feature의 scale과 같고 direction은 anchor 와 같다는 것이다. 그리고  covariance matrix는 이전의 연구들에서 class conditioned covariance 가 activated semantic direction과 서로다른  feature channel간의 correlation을 나타낸다는 것을 밝혔다고 한다.. ;; 읽어봐야겠다. 왜? 직관적으로 잘 와닿지 않는다.

     

    본 논문에서는 target feature의 intra-class semantic information 이 source 와 roughly consistent한다고 가정한다. 

    따라서 source covariance의 estimator 를 target feature의 statistics로 구한다. 

     

    $$\hat{\Sigma}^s_k = \gamma \cdot \Sigma^t_k = \gamma \cdot {{{\mathbf{f}^t_k} \cdot {{\mathbf{f}^t_k}^\top}} \over {\sum \mathbb{1}(\hat{y}^t_i = k)}}$$

     

     where ${\mathbf{f}^t_k} = [f^t_{1,k} - \bar{f}^t_k, ... ,f^t_{i,k} - \bar{f}^t_k, ... ]$ 인 matrix whose columms are centralized target features of k-th class in ${D'}_t$ 

     

    아무튼 $\gamma$ 를 controlling coefficient 로 놓고 sampled surrogate features의 sampling range와 semantic diversity를 조절한다고 한다. 

     

    anchors와 target features를 이용하여, $K$ class-conditioned surrogate source distributions 

     

    $$ \mathcal{N}_k^{sur}(||\bar{f}_t^k||_2 {{w_k^G} \over {||w_k^G||_2}}, {{\gamma \cdot f_k^t \cdot {f_k^t}^{\top}} \over {\sum \mathbb{1}(\hat{y}_i^t = k)}} ), k \in \mathcal{C}$$

     

    를 derive한다. where we can sample surrogate features ${f_k}^{sur} \sim {\mathcal{N}_k}^{sur}(\hat{\mu}_k^s, \hat{\Sigma}_k^s)$

     

     

    3) Source-free domain adaptation

    이전 section에서, pretrained model에 보존된 domain knowledge를 이용하여 accessing source data 없이 source distribution 을 estimate 했다. 따라서 이제 surrogate source data인 estimated distritubution으로부터 data 를 sample하여 SFDA 문제는 traditional DA 문제로 바뀌게 된다. 여기서 Contrastive Domain Discrepancy (CDD) 를 targe domain과 estimated source distribution을 explicitly align 하기 위해 사용한다. 

     

     

     

Designed by Tistory.