본문 바로가기

데이터사이언스/federated learning

Labeled distillation data을 교환하는 federated distillation methods

 0. introdution

 Federated Learning에서 distilation을 이용하는 방법은 여러 방법으로 분류할 수 있다. 지난 글에서는 distillation이 어디에서 어디로 이루어지는지를 기준으로 나누어서 client에서 server로 지식을 distill하는 방식들을 살펴보았다. 이번에는 서버와 client가 어떤 정보를 주고받는지를 기준으로 분류를 해보려고 한다.

 

  서버와 client가 주고받는 정보를 크게 분류한다면 모델 파라미터, public data가 있을 것이다. 모델 파라미터를 주고받는 방식은 기존 연합학습에서 사용되는 방식으로, communication cost가 높지만 상대적으로 높은 정확도를 가진다. public data를 주고받는 방식은 communication cost를 줄일 수 있다. 하지만 client가 data를 전달 받아서 distillation을 진행하기 때문에 client의 연산 부담이 증가한다는 단점을 가진다. 이 방식은 주고 받는 public data에 label이 존재하는지를 기준으로 두 가지로 나눌 수 있다. 이 글에서는 labeled public data를 주고받는 federated distillation 방식을 살펴볼 예정이다.

 

 

1. FedMD

FedMD[1]는 FD에서 public data만을 주고받는 방식을 최초로 제안한 논문이다. 각 라운드에 클라이언트는 서버에서 받은 data로 distillation을 통해 지식을 전달받고, 자신의 데이터로 학습한다. 마지막으로 서버의 데이터로 추론을 진행하고, 그 값을 다시 서버로 보낸다. 서버에서는 client의 예측을 averaging하고, 그 결과를 각 cleint에게 다시 보낸다. 이 전체적인 프로세스는 향후 논문에도 지속해서 사용된다. 

 

2. MHAT

MHAT[2]은 서버에서 단순하게 client 예측의 평균을 구해서 aggregation 하는 것이 아니라, 어떤 client의 예측에 가중치를줄지 판단하는 model을 사용해, 효과적인 aggregation이 될 수 있도록 학습을 진행한다. 이를 통해서 client의 모델이 다른 경우에도 aggregation을 효과적으로 진행할 수 있도록 한다.

 

 

3.FEDGEMS

[3]에서는 client보다 자원이 많은 server의 특징을 고려하여 서버에서 더 큰 모델을 사용하는 GEM(larGer sErver Model)을 도입한 FEDGEM을 소개한다. 더 나아가서 효과적인 client 모델을 선정하고 가중치를 매기는 프로토콜 FEDGEMS을 제안한다.

 

(1) Self-distillation of server knowledge

 서버는 자신의 데이터를 성공적으로 예측한 $S_{Correct}$, 예측에 실패한 $S_{Incorrect}$로 나눈다. 그리고 $S_{Correct}$에 대해서는 실제 label을 바탕으로 cross-entrophy를 계산하여 loss 함수 $L_{S_{1}}$을 구한다. 또한 정답 logit을 $l^{i}_{global}$에 저장한다. 여러 round를 진행하면서 업데이트된 $l_{global}$에 정답이 없는 데이터를  $S^{*}_{Incorrect}$라고 정의한다.

 

 self-distillation의 최종 목적함수 $L_{{S}_{2}}$는 다음과 같다.

self-distillation의 최종 목적함수

이런 self-distillation은 1. 정보를 교환할 필요가 없고, 2. 서버 자신의 모델 구조에 맞는 데이터로 학습할 수 있다는 장점이 있다.

 

(2) selective ensemble

 서버의 모델이 예측을 실패한 $S^{*}_{Incorrect}$에 존재하는 데이터에 대해서만 client의 정보를 전달받는다. 

FEDGEMS framework

$S^{*}_{Incorrect}$ 안  $(x_{i}, y_{i})$에 대해서 예측을 성공한 클라이언트와 실패한 클라이언트로 나눈다, 예측을 성공한 client의 logit만을 사용해서 distillation을 진행한다. 예측을 성공한 client의 logit은 entrophy가 낮을수록 높은 신뢰도를 가진다고 판단한다. 이 가중치가 아래 식에서 $\alpha_{{C}_{j}}$이다.

distillation loss from client to server

 

 

reference

[1]: FedMD: Heterogenous Federated Learning via Model Distillation

[2]: MHAT: An efficient model-heterogenous aggregation training scheme for federated learning

[3]: FEDGEMS: FEDERATED LEARNING OF LARGER SERVER MODELS VIA SELECTIVE KNOWLEDGE FUSION