Bases: Aggregator
Aggregator: Federated Averaging (FedAvg)
Authors: McMahan et al.
Year: 2016
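
FedAvg combines the participants' models by a sample-weighted average of their parameters: with n_k local samples at participant k and n = \sum_k n_k,

    w \leftarrow \sum_{k=1}^{K} \frac{n_k}{n} \, w_k

which is exactly the weighted sum computed in run_aggregation below.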
Source code in nebula/core/aggregation/fedavg.py
import gc

import torch

# Assumed import path, inferred from the file location shown above.
from nebula.core.aggregation.aggregator import Aggregator


class FedAvg(Aggregator):
    """
    Aggregator: Federated Averaging (FedAvg)
    Authors: McMahan et al.
    Year: 2016
    """

    def __init__(self, config=None, **kwargs):
        super().__init__(config, **kwargs)

    def run_aggregation(self, models):
        super().run_aggregation(models)

        # Each value is a (state_dict, num_samples) pair; discard the keys.
        models = list(models.values())

        total_samples = float(sum(weight for _, weight in models))
        if total_samples == 0:
            raise ValueError("Total number of samples must be greater than zero.")

        # Zero-initialise the accumulator with the last model's layer shapes.
        last_model_params = models[-1][0]
        accum = {layer: torch.zeros_like(param, dtype=torch.float32) for layer, param in last_model_params.items()}

        # Weighted sum: each model contributes in proportion to its sample count.
        with torch.no_grad():
            for model_parameters, weight in models:
                normalized_weight = weight / total_samples
                for layer in accum:
                    accum[layer].add_(
                        model_parameters[layer].to(accum[layer].dtype),
                        alpha=normalized_weight,
                    )

        del models
        gc.collect()

        # self.print_model_size(accum)
        return accum
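
A minimal usage sketch, not taken from the NEBULA sources: the participant keys, the toy nn.Linear models, and the assumption that the base Aggregator accepts config=None are all hypothetical. run_aggregation expects a mapping from participant identifiers to (state_dict, num_samples) pairs, as the implementation above unpacks.

import torch
import torch.nn as nn

# Two hypothetical participants' local models.
model_a, model_b = nn.Linear(4, 2), nn.Linear(4, 2)

# Keys are participant identifiers; each value pairs a state_dict with the
# number of local training samples used as the averaging weight.
models = {
    "node-1": (model_a.state_dict(), 600),
    "node-2": (model_b.state_dict(), 400),
}

aggregator = FedAvg()  # assumes the base Aggregator tolerates config=None
averaged = aggregator.run_aggregation(models)

# Every layer is the sample-weighted average: 0.6 * model_a + 0.4 * model_b.
torch.testing.assert_close(
    averaged["weight"],
    0.6 * model_a.state_dict()["weight"] + 0.4 * model_b.state_dict()["weight"],
)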