Skip to content

fedavgSVM

FedAvgSVM

Bases: Aggregator

Aggregator: Federated Averaging (FedAvg). Authors: McMahan et al. Year: 2016. Note: This is a modified version of FedAvg for SVMs.

Source code in nebula/core/aggregation/fedavgSVM.py
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
class FedAvgSVM(Aggregator):
    """
    Aggregator: Federated Averaging (FedAvg)
    Authors: McMahan et al.
    Year: 2016
    Note: This is a modified version of FedAvg for SVMs.

    Linear SVMs are merged by taking the weighted mean of their ``coef_``
    and ``intercept_`` arrays, using the weight paired with each model.
    """

    def __init__(self, config=None, **kwargs):
        super().__init__(config, **kwargs)

    def run_aggregation(self, models):
        """Return a ``LinearSVC`` holding the weighted average of ``models``.

        Args:
            models: Mapping whose values are ``(model, weight)`` pairs. The
                keys are not used here. The weight is presumably the number
                of training samples behind each model -- confirm against
                the caller.

        Returns:
            A new ``LinearSVC`` whose ``coef_`` and ``intercept_`` are the
            weighted means of the inputs, or ``None`` when there is nothing
            to aggregate, when any model is not a ``LinearSVC``, or when the
            weights sum to zero.
        """
        super().run_aggregation(models)

        entries = list(models.values())
        if not entries:
            # Nothing to average; indexing into the list would raise.
            return None

        # Validate every model *before* touching ``coef_`` so that a foreign
        # model type yields ``None`` instead of an AttributeError.
        if not all(isinstance(model, LinearSVC) for model, _ in entries):
            return None

        total_samples = sum(weight for _, weight in entries)
        if total_samples == 0:
            # A zero denominator would produce NaN/inf parameters.
            return None

        reference = entries[-1][0]
        # Force a float accumulator so the in-place ops below cannot fail on
        # an integer-typed coefficient array.
        coeff_accum = np.zeros_like(reference.coef_, dtype=float)
        intercept_accum = 0.0

        for model, weight in entries:
            coeff_accum += model.coef_ * weight
            intercept_accum += model.intercept_ * weight

        coeff_accum /= total_samples
        intercept_accum /= total_samples

        aggregated_svm = LinearSVC()
        aggregated_svm.coef_ = coeff_accum
        aggregated_svm.intercept_ = intercept_accum

        # ``LinearSVC.predict`` maps decision scores through ``classes_``;
        # carry it over when the inputs are fitted (assumes all participants
        # share the same label set, which identical ``coef_`` shapes imply).
        classes = getattr(reference, "classes_", None)
        if classes is not None:
            aggregated_svm.classes_ = classes

        return aggregated_svm