Multi-Center Federated Learning
0x00 Abstract
Existing federated learning approaches usually adopt a single global model to capture the shared knowledge of all users by aggregating their gradients, regardless of the discrepancy between their data distributions.
Our paper proposes a novel multi-center aggregation mechanism for federated learning, which learns multiple global models from non-IID user data and simultaneously derives the optimal matching between users and centers.
We formulate the problem as a joint optimization that can be efficiently solved by a stochastic expectation maximization (EM) algorithm.
Keywords: federated learning, clustering, multi-center, EM
0x01 Introduction
Early federated learning approaches use only one global model as a single center to aggregate the information of all users. Stochastic gradient descent (SGD) for single-center aggregation is designed for IID data and therefore conflicts with the non-IID setting of federated learning.
Sattler et al. proposed clustered federated learning (FedCluster), which addresses the non-IID issue by dividing users into multiple clusters. However, FedCluster builds its hierarchical clustering through multiple rounds of bipartite separation, and each round requires running federated SGD until convergence.
We formulate the problem of multicenter federated learning as jointly optimizing the clustering of users and the global model for each cluster such that 1) each user’s local model is assigned to its closest global model, and 2) each global model leads to the smallest loss over all the users in the associated cluster.
Main contributions:
 We propose a novel multicenter aggregation approach (Section 4.1) to address the nonIID challenge of federated learning.
 We design an objective function, namely multicenter federated loss (Section 4.2), for user clustering in our problem.
 We propose Federated Stochastic Expectation Maximization (FeSEM) (Section 4.3) to solve the optimization of the proposed objective function.
 We present the algorithm (Section 4.4) as an easy-to-implement and strong baseline for federated learning, and evaluate its effectiveness on benchmark datasets (Section 5).
0x02 Related Work
Prior work addressing non-IID or heterogeneous data in the federated setting:
“Clustered federated learning: Model-agnostic distributed multitask optimization under privacy constraints” proposed clustered federated learning (FedCluster). FedCluster integrates federated learning and bipartitioning-based clustering into one framework.
“Robust federated learning in a heterogeneous environment” proposed robust federated learning in three steps: 1) learn a local model on each device, 2) cluster the model parameters into multiple groups, each corresponding to a homogeneous dataset, and 3) run robust distributed optimization within each cluster.
“FedDANE: A federated Newton-type method” proposed FedDANE, which adapts DANE to the federated setting. FedDANE is a federated Newton-type optimization method.
“Federated optimization in heterogeneous networks” proposed FedProx, a generalization and re-parameterization of FedAvg. FedProx adds a proximal term to the objective of each device's supervised learning task. This term measures the parameter-based distance between the server model and the local model.
“Federated learning with personalization layers” proposed FedPer, which adds personalized layers to each local model to tackle heterogeneous data.
0x03 Background
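 Each device holds private data $\mathcal{D}_i=\{\mathcal{X}_i,\mathcal{Y}_i\}$ and performs supervised learning to obtain a model $\mathcal{M}_i:\mathcal{X}_i\to\mathcal{Y}_i$:
$W_i'=\mathop{\arg\min}\limits_{W_i}L(\mathcal{M}_i,\mathcal{D}_i,W_i)$
Here $L(\cdot)$ is the loss function and $W_i$ are the model weights.
 Federated learning minimizes the loss over all devices. Each device's loss is weighted by its share of the data, and $m$ is the number of devices:
$\mathcal{L}=\sum_{i=1}^m\frac{|\mathcal{D}_i|}{\sum_j|\mathcal{D}_j|}L(\mathcal{M}_i,\mathcal{D}_i,W_i)$
The federated aggregation step correspondingly becomes: $\tilde{W}^g=\sum_{i=1}^m\frac{|\mathcal{D}_i|}{\sum_j|\mathcal{D}_j|}W_i$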
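The sketch below shows this data-size-weighted aggregation. It assumes each local model's weights are flattened into a plain list of floats and that `data_sizes[i]` holds $|\mathcal{D}_i|$. The function name `fedavg` is hypothetical and does not come from the paper.

```python
from typing import List, Sequence


def fedavg(local_weights: Sequence[Sequence[float]],
           data_sizes: Sequence[int]) -> List[float]:
    """Aggregate flattened local weights W_i with coefficients |D_i| / sum_j |D_j|."""
    if not local_weights or len(local_weights) != len(data_sizes):
        raise ValueError("need one data size per local model")
    total = sum(data_sizes)
    dim = len(local_weights[0])
    global_w = [0.0] * dim
    for w, n in zip(local_weights, data_sizes):
        coef = n / total  # this device's share of all training data
        for d in range(dim):
            global_w[d] += coef * w[d]
    return global_w


# The device holding 3 of the 4 samples pulls the global model toward its weights.
print(fedavg([[0.0, 4.0], [4.0, 0.0]], [1, 3]))  # [3.0, 1.0]
```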
0x04 Methodology
1.Multicenter Aggregation
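Because each user's habits differ, the local models are also optimized in different ways.
The paper divides the $m$ training nodes into $K$ groups according to a particular scheme and aggregates the models within each group. To describe this scheme, it introduces the notion of cluster distance, illustrated in the figure below:
The figure appears to map the models into a 2-D feature space, where the cluster distance is the sum of the distances from the aggregated model to the other models.
However, the paper does not describe how this figure was produced.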
2.Objective Function
To address the non-IID problem, the paper proposes two federated losses:

distance-based federated loss: a new objective function that uses the distance between the parameters of the global and local models;

multi-center federated loss: the total distance-based loss of aggregating local models to multiple centers.
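Distance-based federated loss (DF-Loss)
DF-Loss rests on the assumption that a “good” model initialization converges more easily to the global optimum.
Moreover, considering the limited computation power and insufficient training data on each device, a “good” initialization is vitally important to train a supervised learning model on the device.
The basis for this figure is not explained either… even if code were released, would reproduction only succeed by chance?
Why DF-Loss can stand in for the original loss function:
the closer a “good” initial model is to the optimum $W^*$ → the more easily it converges to $W^*$ → the lower the conventional loss.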
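The new objective function is (no data-size weights are used here):
$\mathcal{L}=\frac{1}{m}\sum_{i=1}^mDist(W_i,\tilde{W})$
Here $Dist(\cdot)$ measures the distance between two models. Several choices are possible, and the paper uses the squared $L_2$ norm: $Dist(W_i,\tilde{W})\triangleq\|W_i-\tilde{W}\|^2$.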
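The sketch below computes the unweighted DF-Loss with the squared $L_2$ distance. It assumes flattened weight lists, and the names `sq_l2` and `df_loss` are hypothetical. The usage lines also illustrate a standard property of this loss: the plain average of the local models minimizes it, so it scores lower than an arbitrary point such as the origin.

```python
from typing import Sequence


def sq_l2(a: Sequence[float], b: Sequence[float]) -> float:
    """Dist(a, b) = ||a - b||^2 for two flattened weight vectors."""
    if len(a) != len(b):
        raise ValueError("weight vectors must have the same length")
    return sum((x - y) ** 2 for x, y in zip(a, b))


def df_loss(local_weights: Sequence[Sequence[float]],
            global_w: Sequence[float]) -> float:
    """L = (1/m) * sum_i Dist(W_i, W~), without data-size weights."""
    m = len(local_weights)
    return sum(sq_l2(w, global_w) for w in local_weights) / m


local = [[1.0, 0.0], [0.0, 1.0]]
print(df_loss(local, [0.0, 0.0]))  # 1.0
print(df_loss(local, [0.5, 0.5]))  # 0.5, at the mean of the local models
```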
Multi-center DF-Loss
Furthermore, according to the non-IID assumption, the datasets on different devices can be grouped into multiple clusters, where the on-device datasets in the same cluster are likely to be generated from the same distribution.
The multi-center loss is defined as the sum of all distances within all clusters:
$\mathcal{L}=\frac{1}{m}\sum_{k=1}^{K}\sum_{i=1}^m r_i^{(k)}Dist(W_i,\tilde{W}^{(k)})$
Here $r_i^{(k)}\in\{0,1\}$ indicates whether node $i$ belongs to cluster $k$, and $\tilde{W}^{(k)}$ is the aggregated model of cluster $k$.
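The sketch below computes the multi-center DF-Loss under the same flattened-weight assumption. Because each $r_i$ is one-hot, it stores the assignment as a hypothetical list `assignment`, where `assignment[i] = k` means $r_i^{(k)}=1$. This encoding and the function names are illustrative choices, not the paper's.

```python
from typing import Sequence


def sq_l2(a: Sequence[float], b: Sequence[float]) -> float:
    """Squared L2 distance between two flattened weight vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


def multi_center_df_loss(local_weights: Sequence[Sequence[float]],
                         centers: Sequence[Sequence[float]],
                         assignment: Sequence[int]) -> float:
    """L = (1/m) * sum_k sum_i r_i^(k) * Dist(W_i, W~^(k)).

    With one-hot r_i only the assigned center contributes, so the double
    sum collapses to a single sum over devices.
    """
    m = len(local_weights)
    return sum(sq_l2(w, centers[k])
               for w, k in zip(local_weights, assignment)) / m


local = [[0.0, 0.0], [0.0, 2.0], [10.0, 10.0], [10.0, 12.0]]
centers = [[0.0, 1.0], [10.0, 11.0]]
print(multi_center_df_loss(local, centers, [0, 0, 1, 1]))  # 1.0
print(multi_center_df_loss(local, centers, [1, 1, 0, 0]))  # 201.0, a poor matching
```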
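3.Optimization Method
The objective is optimized with the EM algorithm. Because $W_i$ is randomly initialized, the paper uses the stochastic variant, Stochastic Expectation Maximization (SEM), with the following steps:
– E-step: update the cluster assignments $r_i^{(k)}$ with $W_i$ fixed;
– M-step: update the cluster centers $\tilde{W}^{(k)}$;
– local update: update the local models, using $\tilde{W}^{(k)}$ as their new initialization.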
E-step: $r_i^{(k)}=\begin{cases}1, & \text{if } k=\arg\min_j Dist(W_i,\tilde{W}^{(j)})\\ 0, & \text{otherwise}\end{cases}$
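The sketch below implements the hard-assignment E-step with the squared $L_2$ distance. It returns the index of the nearest center for each local model rather than a one-hot vector. The formula does not say how ties are broken, so the sketch assumes the lowest center index wins. The function names are hypothetical.

```python
from typing import List, Sequence


def sq_l2(a: Sequence[float], b: Sequence[float]) -> float:
    """Squared L2 distance between two flattened weight vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


def e_step(local_weights: Sequence[Sequence[float]],
           centers: Sequence[Sequence[float]]) -> List[int]:
    """assignment[i] = argmin_j Dist(W_i, W~^(j)), i.e. r_i^(k) = 1 for k = assignment[i]."""
    assignment = []
    for w in local_weights:
        dists = [sq_l2(w, c) for c in centers]
        # min() returns the first minimal index, so ties go to the lowest k (assumption).
        assignment.append(min(range(len(centers)), key=dists.__getitem__))
    return assignment


print(e_step([[0.0, 0.0], [9.0, 9.0], [1.0, 0.0]],
             [[0.0, 0.0], [10.0, 10.0]]))  # [0, 1, 0]
```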
M-step: $\tilde{W}^{(k)}=\frac{1}{\sum_{i=1}^m r_i^{(k)}}\sum_{i=1}^m r_i^{(k)}W_i$
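The M-step sets each center to the unweighted mean of the local models assigned to it. The formula divides by zero when a cluster receives no models. The sketch below handles that case by keeping the previous center, which is an assumption rather than something the paper specifies. The function and parameter names are hypothetical.

```python
from typing import List, Sequence


def m_step(local_weights: Sequence[Sequence[float]],
           assignment: Sequence[int],
           old_centers: Sequence[Sequence[float]]) -> List[List[float]]:
    """W~^(k) = (1 / sum_i r_i^(k)) * sum_i r_i^(k) W_i, an unweighted mean per cluster."""
    num_clusters = len(old_centers)
    dim = len(old_centers[0])
    sums = [[0.0] * dim for _ in range(num_clusters)]
    counts = [0] * num_clusters
    for w, k in zip(local_weights, assignment):
        counts[k] += 1
        for d in range(dim):
            sums[k][d] += w[d]
    new_centers = []
    for k in range(num_clusters):
        if counts[k] == 0:
            # Empty cluster: the formula is undefined, so keep the old center (assumption).
            new_centers.append([float(v) for v in old_centers[k]])
        else:
            new_centers.append([s / counts[k] for s in sums[k]])
    return new_centers


print(m_step([[0.0, 0.0], [2.0, 2.0], [10.0, 10.0]], [0, 0, 1],
             [[5.0, 5.0], [5.0, 5.0]]))  # [[1.0, 1.0], [10.0, 10.0]]
```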
Updating the local models: the global model's parameters $\tilde{W}^{(k)}$ are sent to each device in cluster $k$ to update its local model. Each device then fine-tunes its local parameters $W_i$ with a supervised learning algorithm on its own private training data.
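The paper leaves the on-device supervised learner open. The sketch below assumes a hypothetical 1-D linear model $y = w_0 + w_1 x$, trained by full-batch gradient descent on mean squared error and initialized from the cluster center. The names `local_update`, `lr` and `epochs` are illustrative. The usage lines also show the “good initialization” assumption behind DF-Loss on this toy problem: under the same small step budget, the initialization nearer the optimum ends closer to it.

```python
from typing import List, Sequence


def local_update(center: Sequence[float], xs: Sequence[float],
                 ys: Sequence[float], lr: float = 0.1,
                 epochs: int = 500) -> List[float]:
    """Fine-tune W_i = [w0, w1] from the cluster center on private data (xs, ys).

    Hypothetical learner: full-batch gradient descent on the mean squared
    error of the linear model y = w0 + w1 * x.
    """
    w0, w1 = center  # start from the global model of this device's cluster
    n = len(xs)
    for _ in range(epochs):
        residuals = [w0 + w1 * x - y for x, y in zip(xs, ys)]
        g0 = 2.0 / n * sum(residuals)
        g1 = 2.0 / n * sum(r * x for r, x in zip(residuals, xs))
        w0 -= lr * g0
        w1 -= lr * g1
    return [w0, w1]


def sq_err(w: Sequence[float]) -> float:
    """Squared distance to this toy device's optimum [1, 2]."""
    return (w[0] - 1.0) ** 2 + (w[1] - 2.0) ** 2


xs = [0.0, 1.0, 2.0, 3.0]
ys = [1.0 + 2.0 * x for x in xs]  # this device's private data follows y = 1 + 2x
print(local_update([0.0, 0.0], xs, ys))  # approximately [1.0, 2.0]

# Same 5-step budget: the initialization near the optimum ends much closer.
near = local_update([0.9, 1.9], xs, ys, epochs=5)
far = local_update([-10.0, -10.0], xs, ys, epochs=5)
print(sq_err(near) < sq_err(far))  # True
```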
4. Algorithm
In particular, in the third step (updating the local models), each local model is fine-tuned by running Algorithm 2.
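The toy loop below combines the three SEM steps (E-step, M-step, local update) end to end. It is not a reproduction of the paper's algorithm. It assumes several choices of its own:

- a hypothetical 1-D linear-regression learner trained by gradient descent;
- noiseless synthetic data on four devices drawn from two underlying lines;
- a deterministic farthest-point initialization of the centers.

The function names `fesem` and `linear_device` are also hypothetical.

```python
from typing import List, Sequence, Tuple

Vector = List[float]
Dataset = Tuple[List[float], List[float]]  # (xs, ys) held privately by one device


def sq_l2(a: Sequence[float], b: Sequence[float]) -> float:
    return sum((x - y) ** 2 for x, y in zip(a, b))


def local_update(init: Sequence[float], data: Dataset,
                 lr: float = 0.1, epochs: int = 20) -> Vector:
    """Hypothetical learner: gradient descent on the MSE of y = w0 + w1 * x."""
    xs, ys = data
    w0, w1 = init
    n = len(xs)
    for _ in range(epochs):
        res = [w0 + w1 * x - y for x, y in zip(xs, ys)]
        g0 = 2.0 / n * sum(res)
        g1 = 2.0 / n * sum(r * x for r, x in zip(res, xs))
        w0, w1 = w0 - lr * g0, w1 - lr * g1
    return [w0, w1]


def e_step(models: Sequence[Vector], centers: Sequence[Vector]) -> List[int]:
    # Nearest center per local model; ties go to the lowest index.
    return [min(range(len(centers)), key=lambda k: sq_l2(w, centers[k]))
            for w in models]


def m_step(models: Sequence[Vector], assignment: Sequence[int],
           old_centers: Sequence[Vector]) -> List[Vector]:
    new_centers = []
    for k, old in enumerate(old_centers):
        members = [w for w, a in zip(models, assignment) if a == k]
        if not members:  # empty cluster: keep the previous center (assumption)
            new_centers.append(list(old))
        else:
            new_centers.append([sum(col) / len(members) for col in zip(*members)])
    return new_centers


def fesem(datasets: Sequence[Dataset], num_clusters: int, rounds: int = 10):
    # Local pre-training from a zero initialization.
    models = [local_update([0.0, 0.0], d) for d in datasets]
    # Deterministic farthest-point center initialization (assumption).
    centers = [list(models[0])]
    while len(centers) < num_clusters:
        far = max(models, key=lambda w: min(sq_l2(w, c) for c in centers))
        centers.append(list(far))
    assignment: List[int] = []
    for _ in range(rounds):
        assignment = e_step(models, centers)           # E-step
        centers = m_step(models, assignment, centers)  # M-step
        models = [local_update(centers[k], d)          # local update
                  for k, d in zip(assignment, datasets)]
    return centers, assignment, models


def linear_device(b0: float, b1: float) -> Dataset:
    """Noiseless toy data y = b0 + b1 * x for one hypothetical device."""
    xs = [0.0, 1.0, 2.0, 3.0]
    return xs, [b0 + b1 * x for x in xs]


datasets = [linear_device(1.0, 2.0), linear_device(1.2, 2.0),
            linear_device(-1.0, -2.0), linear_device(-1.2, -2.0)]
centers, assignment, _ = fesem(datasets, num_clusters=2)
print(assignment)  # [0, 0, 1, 1]
print(centers)     # approximately [[1.1, 2.0], [-1.1, -2.0]]
```

On this toy data, each center converges to the average of its devices' optimal parameters. The last local update leaves every device's model between its cluster center and its own optimum.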