Bases: Aggregator
Aggregator: Federated Averaging (FedAvg)
Authors: McMahan et al.
Year: 2016
Note: This is a modified version of FedAvg for SVMs.
Source code in nebula/core/aggregation/fedavgSVM.py
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
class FedAvgSVM(Aggregator):
    """
    Aggregator: Federated Averaging (FedAvg)
    Authors: McMahan et al.
    Year: 2016
    Note: This is a modified version of FedAvg for SVMs.

    Instead of averaging generic parameter tensors, this aggregator averages
    the ``coef_`` and ``intercept_`` attributes of scikit-learn ``LinearSVC``
    models, weighting each model by its reported number of samples.
    """

    def __init__(self, config=None, **kwargs):
        super().__init__(config, **kwargs)

    def run_aggregation(self, models):
        """Combine linear SVMs into one via a sample-weighted average.

        Args:
            models: Mapping whose values are ``(model, num_samples)`` pairs.
                The keys are ignored here (presumably node identifiers --
                verify against the caller).

        Returns:
            A new ``LinearSVC`` whose ``coef_`` and ``intercept_`` are the
            sample-weighted means of the inputs, or ``None`` when there is
            nothing to aggregate, any model is not a ``LinearSVC``, or the
            total sample weight is zero.
        """
        super().run_aggregation(models)

        entries = list(models.values())
        if not entries:
            return None

        # Validate every entry before reading ``coef_`` so a non-SVM model
        # yields ``None`` rather than an AttributeError.
        if not all(isinstance(model, LinearSVC) for model, _ in entries):
            return None

        total_samples = sum(weight for _, weight in entries)
        if total_samples == 0:
            # A weighted mean with zero total weight is undefined; refuse
            # rather than emit NaN/inf parameters.
            return None

        reference = entries[-1][0]
        coeff_accum = np.zeros_like(reference.coef_, dtype=float)
        # Starts as a scalar; becomes an array once the first (array)
        # intercept is added, matching the shape of ``intercept_``.
        intercept_accum = 0.0
        for model, weight in entries:
            coeff_accum += model.coef_ * weight
            intercept_accum += model.intercept_ * weight
        coeff_accum /= total_samples
        intercept_accum /= total_samples

        aggregated_svm = LinearSVC()
        aggregated_svm.coef_ = coeff_accum
        aggregated_svm.intercept_ = intercept_accum
        # ``predict`` maps decision scores through ``classes_``; without it the
        # aggregated model cannot predict. Assumes every node was trained on
        # the same label set -- NOTE(review): confirm this holds upstream.
        if hasattr(reference, "classes_"):
            aggregated_svm.classes_ = reference.classes_
        return aggregated_svm
|