trimmedmean

TrimmedMean

Bases: Aggregator

Aggregator: TrimmedMean
Authors: Dong Yin et al.
Year: 2018
Note: https://arxiv.org/pdf/1803.01498.pdf

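The coordinate-wise trimmed mean discards, for each model parameter, the β largest and β smallest values across the received models and averages what remains, which bounds the influence any single contributor (e.g. a Byzantine client) can have on each coordinate. A minimal standalone sketch of the idea, with made-up values that are not part of the NEBULA API:

import numpy as np

# Rows are models, columns are coordinates of one flattened layer.
weights = np.array([
    [1.0, 10.0],
    [2.0, 20.0],
    [3.0, 30.0],
    [100.0, -5.0],  # outlier, e.g. a Byzantine client
])
beta = 1
sorted_w = np.sort(weights, axis=0)             # sort each coordinate
trimmed = sorted_w[beta : len(weights) - beta]  # drop beta from each end
print(trimmed.mean(axis=0))                     # -> [ 2.5 15. ]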
Source code in nebula/core/aggregation/trimmedmean.py
import numpy as np
import torch

from nebula.core.aggregation.aggregator import Aggregator


class TrimmedMean(Aggregator):
    """
    Aggregator: TrimmedMean
    Authors: Dong Yin et al.
    Year: 2018
    Note: https://arxiv.org/pdf/1803.01498.pdf
    """

    def __init__(self, config=None, beta=0, **kwargs):
        super().__init__(config, **kwargs)
        # beta: number of values trimmed from each end of every sorted coordinate
        self.beta = beta

    def get_trimmedmean(self, weights):
        # weights: tensor whose first dimension indexes the models
        weight_len = len(weights)

        if weight_len <= 2 * self.beta:
            # not enough models to trim beta values from each end,
            # so fall back to a plain mean over all of them
            remaining_weights = weights
            res = torch.mean(remaining_weights, 0)

        else:
            # remove the beta largest and beta smallest items per coordinate
            arr_weights = np.asarray(weights)
            nobs = arr_weights.shape[0]
            start = self.beta
            end = nobs - self.beta
            atmp = np.partition(arr_weights, (start, end - 1), 0)
            sl = [slice(None)] * atmp.ndim
            sl[0] = slice(start, end)
            # mean of the remaining (untrimmed) weights
            arr_mean = np.mean(atmp[tuple(sl)], axis=0)
            res = torch.tensor(arr_mean)

        return res

    def run_aggregation(self, models):
        super().run_aggregation(models)

        # each value is a (parameters, _) pair; only the parameters are used here
        models = list(models.values())
        models_params = [m for m, _ in models]

        total_models = len(models)

        # accumulator with the same layer names and shapes as the model
        accum = {layer: torch.zeros_like(param).float() for layer, param in models[-1][0].items()}

        for layer in accum:
            weight_layer = accum[layer]
            # get the shape of the layer tensor
            l_shape = list(weight_layer.shape)

            # get the number of elements in the layer tensor
            number_layer_weights = torch.numel(weight_layer)
            # if it's a 0-d (scalar) tensor
            if l_shape == []:
                weights = torch.tensor([models_params[j][layer] for j in range(total_models)])
                weights = weights.double()
                w = self.get_trimmedmean(weights)
                accum[layer] = w

            else:
                # flatten the tensor
                weight_layer_flatten = weight_layer.view(number_layer_weights)

                # stack each model's flattened layer: row i holds model i's weights
                models_layer_weight_flatten = torch.stack(
                    [models_params[j][layer].view(number_layer_weights) for j in range(total_models)],
                    0,
                )

                # column j holds [w_1j, w_2j, ..., w_mj], where w_ij is the jth parameter of the ith local model
                trimmedmean = self.get_trimmedmean(models_layer_weight_flatten)
                accum[layer] = trimmedmean.view(l_shape)

        return accum
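
A minimal usage sketch, assuming the base Aggregator accepts config=None and that models maps node identifiers to (parameters, weight) pairs, as the unpacking in run_aggregation suggests; node names and values are illustrative, not part of the NEBULA API:

import torch

# Three hypothetical nodes, each contributing (state_dict, weight).
models = {
    "node_a": ({"fc.weight": torch.tensor([[1.0, 2.0]])}, 100),
    "node_b": ({"fc.weight": torch.tensor([[1.1, 2.1]])}, 100),
    "node_c": ({"fc.weight": torch.tensor([[9.0, -7.0]])}, 100),  # outlier
}

aggregator = TrimmedMean(beta=1)  # trim one model from each end per coordinate
aggregated = aggregator.run_aggregation(models)
print(aggregated["fc.weight"])  # per-coordinate trimmed mean: [[1.1, 2.0]]

With three models and beta=1 this keeps only the middle value of each coordinate (the coordinate-wise median), so the outlier contribution is discarded entirely.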