Skip to content

median

Median

Bases: Aggregator

Aggregator: Median. Authors: Dong Yin et al. Year: 2018. Note: https://arxiv.org/pdf/1803.01498.pdf

Source code in nebula/core/aggregation/median.py
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
class Median(Aggregator):
    """
    Aggregator: Median
    Authors: Dong Yin et al.
    Year: 2018
    Note: https://arxiv.org/pdf/1803.01498.pdf

    Coordinate-wise median aggregation: every entry of the aggregated model
    is the median of that entry across all the models being aggregated.
    """

    def __init__(self, config=None, **kwargs):
        super().__init__(config, **kwargs)

    def get_median(self, weights):
        """
        Return the median of ``weights`` along dimension 0.

        Args:
            weights: Tensor whose first dimension enumerates the values to
                take the median over (one slice per model).

        Returns:
            Tensor of shape ``weights.shape[1:]`` on the same device as
            ``weights``. With an odd number of slices this is the middle
            value; with an even number it is the mean of the two middle
            values.

        Raises:
            ValueError: If ``weights`` has no slices along dimension 0.
        """
        count = len(weights)
        if count == 0:
            raise ValueError("Cannot compute the median of an empty set of weights")

        if count % 2 == 1:
            # torch.median returns (values, indices); only the values matter.
            return torch.median(weights, 0).values

        # torch.median would return the lower of the two middle values for an
        # even count, so sort and average the middle pair instead. Staying in
        # torch (rather than round-tripping through NumPy) keeps this working
        # for CUDA tensors and keeps the result on the input's device.
        if not torch.is_floating_point(weights):
            # Averaging integer values must not truncate.
            weights = weights.double()
        ordered = torch.sort(weights, dim=0).values
        upper = count // 2
        return (ordered[upper - 1] + ordered[upper]) / 2

    def run_aggregation(self, models):
        """
        Aggregate ``models`` with a coordinate-wise median per layer.

        Args:
            models: Mapping whose values are ``(parameters, weight)`` pairs,
                where ``parameters`` maps layer names to tensors. The second
                element of each pair is not used by this aggregator.

        Returns:
            Dict mapping every layer name of the last model to the median of
            that layer across all models, with the layer's original shape.
        """
        super().run_aggregation(models)

        models_params = [params for params, _ in models.values()]

        aggregated = {}
        # The last model serves as the template for the set of layers.
        for layer, reference in models_params[-1].items():
            # Stack along a new leading "model" dimension; the median along
            # that dimension works for any rank, so no flattening is needed.
            stacked = torch.stack([params[layer] for params in models_params], 0)
            if reference.dim() == 0:
                # Scalar entries are aggregated in double precision.
                stacked = stacked.double()
            aggregated[layer] = self.get_median(stacked)
        return aggregated