
Documentation for Nebuladataset Module

NebulaDataset

Source code in nebula/core/datasets/nebuladataset.py
class NebulaDataset:
    def __init__(
        self,
        num_classes=10,
        partitions_number=1,
        batch_size=32,
        num_workers=4,
        iid=True,
        partition="dirichlet",
        partition_parameter=0.5,
        seed=42,
        config_dir=None,
    ):
        self.num_classes = num_classes
        self.partitions_number = partitions_number
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.iid = iid
        self.partition = partition
        self.partition_parameter = partition_parameter
        self.seed = seed
        self.config_dir = config_dir

        logging.info(
            f"Dataset {self.__class__.__name__} initialized | Partitions: {self.partitions_number} | IID: {self.iid} | Partition: {self.partition} | Partition parameter: {self.partition_parameter}"
        )

        # Dataset
        self.train_set = None
        self.train_indices_map = None
        self.test_set = None
        self.test_indices_map = None
        self.local_test_indices_map = None

        enable_deterministic(self.seed)

    @abstractmethod
    def initialize_dataset(self):
        """
        Initialize the dataset. This should load or create the dataset.
        """
        raise NotImplementedError("Subclasses must implement this method.")

    def clear(self):
        """
        Clear the dataset. This should remove or reset the dataset.
        """
        if self.train_set is not None and hasattr(self.train_set, "close"):
            self.train_set.close()
        if self.test_set is not None and hasattr(self.test_set, "close"):
            self.test_set.close()

        self.train_set = None
        self.train_indices_map = None
        self.test_set = None
        self.test_indices_map = None
        self.local_test_indices_map = None

    def data_partitioning(self, plot=False):
        """
        Perform the data partitioning.
        """

        logging.info(
            f"Partitioning data for {self.__class__.__name__} | Partitions: {self.partitions_number} | IID: {self.iid} | Partition: {self.partition} | Partition parameter: {self.partition_parameter}"
        )

        self.train_indices_map = (
            self.generate_iid_map(self.train_set)
            if self.iid
            else self.generate_non_iid_map(self.train_set, self.partition, self.partition_parameter)
        )
        self.test_indices_map = self.get_test_indices_map()
        self.local_test_indices_map = self.get_local_test_indices_map()

        if plot:
            self.plot_data_distribution("train", self.train_set, self.train_indices_map)
            self.plot_all_data_distribution("train", self.train_set, self.train_indices_map)
            self.plot_data_distribution("local_test", self.test_set, self.local_test_indices_map)
            self.plot_all_data_distribution("local_test", self.test_set, self.local_test_indices_map)

        self.save_partitions()

    def get_test_indices_map(self):
        """
        Get the indices of the test set for each participant.

        Returns:
            A dictionary mapping participant_id to a list of indices.
        """
        try:
            test_indices_map = {}
            for participant_id in range(self.partitions_number):
                test_indices_map[participant_id] = list(range(len(self.test_set)))
            return test_indices_map
        except Exception as e:
            logging.exception(f"Error in get_test_indices_map: {e}")

    def get_local_test_indices_map(self):
        """
        Get the indices of the local test set for each participant.
        Only test samples whose labels appear in the participant's training partition are selected.

        Returns:
            A dictionary mapping participant_id to a list of indices.
        """
        try:
            local_test_indices_map = {}
            test_targets = np.array(self.test_set.targets)
            for participant_id in range(self.partitions_number):
                train_labels = np.array([self.train_set.targets[idx] for idx in self.train_indices_map[participant_id]])
                indices = np.where(np.isin(test_targets, train_labels))[0].tolist()
                local_test_indices_map[participant_id] = indices
            return local_test_indices_map
        except Exception as e:
            logging.exception(f"Error in get_local_test_indices_map: {e}")
            raise

    def save_partition(self, obj, file, name):
        try:
            logging.info(f"Saving pickled object of type {type(obj)}")
            pickled = pickle.dumps(obj)
            ds = file.create_dataset(name, data=np.void(pickled))
            ds.attrs["__type__"] = "pickle"
            logging.info(f"Saved pickled object of type {type(obj)} to {name}")
        except Exception as e:
            logging.exception(f"Error saving object to HDF5: {e}")
            raise

    def save_partitions(self):
        """
        Save the global test set and each participant's training partition to HDF5 files.
        One file is written per participant plus a shared global_test.h5; the partition
        objects are pickled and stored inside the HDF5 datasets.
        """
        try:
            logging.info(f"Saving partitions data for ALL participants ({self.partitions_number}) in {self.config_dir}")
            path = self.config_dir
            if not os.path.exists(path):
                raise FileNotFoundError(f"Path {path} does not exist")
            # Check that the partition maps have the expected number of partitions
            if not (
                len(self.train_indices_map)
                == len(self.test_indices_map)
                == len(self.local_test_indices_map)
                == self.partitions_number
            ):
                raise ValueError("One of the partition maps has an unexpected length.")

            # Save global test data
            file_name = os.path.join(path, "global_test.h5")
            with h5py.File(file_name, "w") as f:
                indices = list(range(len(self.test_set)))
                test_data = [self.test_set[i] for i in indices]
                self.save_partition(test_data, f, "test_data")
                f["test_data"].attrs["num_classes"] = self.num_classes
                test_targets = np.array(self.test_set.targets)
                f.create_dataset("test_targets", data=test_targets, compression="gzip")

            for participant in range(self.partitions_number):
                file_name = os.path.join(path, f"participant_{participant}_train.h5")
                with h5py.File(file_name, "w") as f:
                    logging.info(f"Saving training data for participant {participant} in {file_name}")
                    indices = self.train_indices_map[participant]
                    train_data = [self.train_set[i] for i in indices]
                    self.save_partition(train_data, f, "train_data")
                    f["train_data"].attrs["num_classes"] = self.num_classes
                    train_targets = np.array([self.train_set.targets[i] for i in indices])
                    f.create_dataset("train_targets", data=train_targets, compression="gzip")
                    logging.info(f"Partition saved for participant {participant}.")

            logging.info("Successfully saved all partition files.")

        except Exception as e:
            logging.exception(f"Error in save_partitions: {e}")

        finally:
            self.clear()
            logging.info("Cleared dataset after saving partitions.")

    @abstractmethod
    def generate_non_iid_map(self, dataset, partition="dirichlet", plot=False):
        """
        Create a non-iid map of the dataset.
        """
        pass

    @abstractmethod
    def generate_iid_map(self, dataset, plot=False):
        """
        Create an iid map of the dataset.
        """
        pass

    def plot_data_distribution(self, phase, dataset, partitions_map):
        """
        Plot the data distribution of the dataset.

        One bar chart is produced per partition, according to the partitions map provided.

        Args:
            phase: The phase of the dataset (train, test, local_test).
            dataset: The dataset to plot (torch.utils.data.Dataset).
            partitions_map: The map of the dataset partitions.
        """
        logging_training.info(f"Plotting data distribution for {phase} dataset")
        # Plot the data distribution of the dataset, one graph per partition
        sns.set()
        sns.set_style("whitegrid", {"axes.grid": False})
        sns.set_context("paper", font_scale=1.5)
        sns.set_palette("Set2")

        # Plot bar charts for each partition
        partition_index = 0
        for partition_index in range(self.partitions_number):
            indices = partitions_map[partition_index]
            class_counts = [0] * self.num_classes
            for idx in indices:
                label = dataset.targets[idx]
                class_counts[label] += 1

            logging_training.info(f"[{phase}] Participant {partition_index} total samples: {len(indices)}")
            logging_training.info(f"[{phase}] Participant {partition_index} data distribution: {class_counts}")

            plt.figure(figsize=(12, 8))
            plt.bar(range(self.num_classes), class_counts)
            for j, count in enumerate(class_counts):
                plt.text(j, count, str(count), ha="center", va="bottom", fontsize=12)
            plt.xlabel("Class")
            plt.ylabel("Number of samples")
            plt.xticks(range(self.num_classes))
            plt.title(
                f"[{phase}] Participant {partition_index} data distribution ({'IID' if self.iid else f'Non-IID - {self.partition}'} - {self.partition_parameter})"
            )
            plt.tight_layout()
            path_to_save = f"{self.config_dir}/participant_{partition_index}_data_distribution_{'iid' if self.iid else 'non_iid'}{'_' + self.partition if not self.iid else ''}_{phase}.pdf"
            logging_training.info(f"Saving data distribution for participant {partition_index} to {path_to_save}")
            plt.savefig(path_to_save, dpi=300, bbox_inches="tight")
            plt.close()

        if hasattr(self, "tsne") and self.tsne:
            self.visualize_tsne(dataset)

    def visualize_tsne(self, dataset):
        X = []  # Flattened feature vectors of the samples
        y = []  # Corresponding labels of the samples
        for idx in range(len(dataset)):  # 'dataset' must be indexable and yield (sample, label) pairs
            sample, label = dataset[idx]
            X.append(sample.flatten())
            y.append(label)

        X = np.array(X)
        y = np.array(y)

        tsne = TSNE(n_components=2, verbose=1, perplexity=40, n_iter=300)
        tsne_results = tsne.fit_transform(X)

        plt.figure(figsize=(16, 10))
        sns.scatterplot(
            x=tsne_results[:, 0],
            y=tsne_results[:, 1],
            hue=y,
            palette=sns.color_palette("hsv", self.num_classes),
            legend="full",
            alpha=0.7,
        )

        plt.title("t-SNE visualization of the dataset")
        plt.xlabel("t-SNE axis 1")
        plt.ylabel("t-SNE axis 2")
        plt.legend(title="Class")
        plt.tight_layout()

        path_to_save_tsne = f"{self.config_dir}/tsne_visualization.png"
        plt.savefig(path_to_save_tsne, dpi=300, bbox_inches="tight")
        plt.close()

    def dirichlet_partition(
        self,
        dataset: Any,
        alpha: float = 0.5,
        min_samples_size: int = 50,
        balanced: bool = False,
        max_iter: int = 100,
        verbose: bool = True,
    ) -> dict[int, list[int]]:
        """
        Partition the dataset among clients using a Dirichlet distribution.
        This function ensures each client gets at least min_samples_size samples.

        Parameters
        ----------
        dataset : Dataset
            The dataset to partition. Must have a 'targets' attribute.
        alpha : float, default=0.5
            Concentration parameter for the Dirichlet distribution.
        min_samples_size : int, default=50
            Minimum number of samples required in each partition.
        balanced : bool, default=False
            If True, distribute each class evenly among clients.
            Otherwise, allocate according to a Dirichlet distribution.
        max_iter : int, default=100
            Maximum number of iterations to try finding a valid partition.
        verbose : bool, default=True
            If True, print debug information per iteration.

        Returns
        -------
        partitions : dict[int, list[int]]
            Dictionary mapping each client index to a list of sample indices.
        """
        # Extract targets and unique labels.
        y_data = self._get_targets(dataset)
        unique_labels = np.unique(y_data)

        # For each class, get a shuffled list of indices.
        class_indices = {}
        base_rng = np.random.default_rng(self.seed)
        for label in unique_labels:
            idx = np.where(y_data == label)[0]
            base_rng.shuffle(idx)
            class_indices[label] = idx

        # Prepare container for client indices.
        indices_per_partition = [[] for _ in range(self.partitions_number)]

        def allocate_for_label(label_idx: np.ndarray, rng: np.random.Generator) -> np.ndarray:
            num_label_samples = len(label_idx)
            if balanced:
                proportions = np.full(self.partitions_number, 1.0 / self.partitions_number)
            else:
                proportions = rng.dirichlet([alpha] * self.partitions_number)
            sample_counts = (proportions * num_label_samples).astype(int)
            remainder = num_label_samples - sample_counts.sum()
            if remainder > 0:
                extra_indices = rng.choice(self.partitions_number, size=remainder, replace=False)
                for idx in extra_indices:
                    sample_counts[idx] += 1
            return sample_counts

        for iteration in range(1, max_iter + 1):
            rng = np.random.default_rng(self.seed + iteration)
            temp_indices_per_partition = [[] for _ in range(self.partitions_number)]
            for label in unique_labels:
                label_idx = class_indices[label]
                counts = allocate_for_label(label_idx, rng)
                start = 0
                for client_idx, count in enumerate(counts):
                    end = start + count
                    temp_indices_per_partition[client_idx].extend(label_idx[start:end])
                    start = end

            client_sizes = [len(indices) for indices in temp_indices_per_partition]
            if min(client_sizes) >= min_samples_size:
                indices_per_partition = temp_indices_per_partition
                if verbose:
                    print(f"Partition successful at iteration {iteration}. Client sizes: {client_sizes}")
                break
            if verbose:
                print(f"Iteration {iteration}: client sizes {client_sizes}")

        else:
            raise ValueError(
                f"Could not create partitions with at least {min_samples_size} samples per client after {max_iter} iterations."
            )

        initial_partition = {i: indices for i, indices in enumerate(indices_per_partition)}

        final_partition = self.postprocess_partition(initial_partition, y_data)

        return final_partition

    @staticmethod
    def _get_targets(dataset) -> np.ndarray:
        if isinstance(dataset.targets, np.ndarray):
            return dataset.targets
        elif hasattr(dataset.targets, "numpy"):
            return dataset.targets.numpy()
        else:
            return np.asarray(dataset.targets)

    def postprocess_partition(
        self, partition: dict[int, list[int]], y_data: np.ndarray, min_samples_per_class: int = 10
    ) -> dict[int, list[int]]:
        """
        Post-process a partition to remove (and reassign) classes with too few samples per client.

        For each class:
        - For clients with a count > 0 but below min_samples_per_class, remove those samples.
        - Reassign the removed samples to the client that already has the maximum count for that class.

        Parameters
        ----------
        partition : dict[int, list[int]]
            The initial partition mapping client indices to sample indices.
        y_data : np.ndarray
            The array of labels corresponding to the dataset samples.
        min_samples_per_class : int, default=10
            The minimum acceptable number of samples per class for each client.

        Returns
        -------
        new_partition : dict[int, list[int]]
            The updated partition.
        """
        # Copy partition so we can modify it.
        new_partition = {client: list(indices) for client, indices in partition.items()}

        # Iterate over each class.
        for label in np.unique(y_data):
            # For each client, count how many samples of this label exist.
            client_counts = {}
            for client, indices in new_partition.items():
                client_counts[client] = np.sum(np.array(y_data)[indices] == label)

            # Identify clients with fewer than min_samples_per_class but nonzero counts.
            donors = [client for client, count in client_counts.items() if 0 < count < min_samples_per_class]
            # Identify potential recipients: those with at least min_samples_per_class.
            recipients = [client for client, count in client_counts.items() if count >= min_samples_per_class]
            # If no client meets the threshold, choose the one with the highest count.
            if not recipients:
                best_recipient = max(client_counts, key=client_counts.get)
                recipients = [best_recipient]
            # Choose the recipient with the maximum count.
            best_recipient = max(recipients, key=lambda c: client_counts[c])

            # For each donor, remove samples of this label and reassign them.
            for donor in donors:
                donor_indices = new_partition[donor]
                # Identify indices corresponding to this label.
                donor_label_indices = [idx for idx in donor_indices if y_data[idx] == label]
                # Remove these from the donor.
                new_partition[donor] = [idx for idx in donor_indices if y_data[idx] != label]
                # Add these to the best recipient.
                new_partition[best_recipient].extend(donor_label_indices)

        return new_partition

    def homo_partition(self, dataset):
        """
        Homogeneously partition the dataset into multiple subsets.

        This function divides a dataset into a specified number of subsets, where each subset
        is intended to have a roughly equal number of samples. This method aims to ensure a
        homogeneous distribution of data across all subsets. It's particularly useful in
        scenarios where a uniform distribution of data is desired among all federated learning
        clients.

        Args:
            dataset (torch.utils.data.Dataset): The dataset to partition. It should have
                                                'data' and 'targets' attributes.

        Returns:
            dict: A dictionary where keys are subset indices (ranging from 0 to partitions_number-1)
                and values are lists of indices corresponding to the samples in each subset.

        The function randomly shuffles the entire dataset and then splits it into the number
        of subsets specified by `partitions_number`. It ensures that each subset has a similar number
        of samples. The function also prints the class distribution in each subset for reference.

        Example usage:
            federated_data = homo_partition(my_dataset)
            # This creates federated data subsets with homogeneous distribution.
        """
        n_nets = self.partitions_number

        n_train = len(dataset.targets)
        np.random.seed(self.seed)
        idxs = np.random.permutation(n_train)
        batch_idxs = np.array_split(idxs, n_nets)
        net_dataidx_map = {i: batch_idxs[i] for i in range(n_nets)}

        # partitioned_datasets = []
        for i in range(self.partitions_number):
            # subset = torch.utils.data.Subset(dataset, net_dataidx_map[i])
            # partitioned_datasets.append(subset)

            # Print class distribution in the current partition
            class_counts = [0] * self.num_classes
            for idx in net_dataidx_map[i]:
                label = dataset.targets[idx]
                class_counts[label] += 1
            logging.info(f"Partition {i + 1} class distribution: {class_counts}")

        return net_dataidx_map

    def balanced_iid_partition(self, dataset):
        """
        Partition the dataset into balanced and IID (Independent and Identically Distributed)
        subsets for each client.

        This function divides a dataset into a specified number of subsets (federated clients),
        where each subset has an equal class distribution. This makes the partition suitable for
        simulating IID data scenarios in federated learning.

        Args:
            dataset (list): The dataset to partition. It should be a list of tuples where each
                            tuple represents a data sample and its corresponding label.

        Returns:
            dict: A dictionary where keys are client IDs (ranging from 0 to partitions_number-1) and
                    values are lists of indices corresponding to the samples assigned to each client.

        The function ensures that each class is represented equally in each subset. The
        partitioning process involves iterating over each class, shuffling the indices of that class,
        and then splitting them equally among the clients. The function does not print the class
        distribution in each subset.

        Example usage:
            federated_data = balanced_iid_partition(my_dataset)
            # This creates federated data subsets with equal class distributions.
        """
        num_clients = self.partitions_number
        clients_data = {i: [] for i in range(num_clients)}

        # Get the labels from the dataset
        if isinstance(dataset.targets, np.ndarray):
            labels = dataset.targets
        elif hasattr(dataset.targets, "numpy"):  # Check if it's a tensor with .numpy() method
            labels = dataset.targets.numpy()
        else:  # If it's a list
            labels = np.asarray(dataset.targets)

        label_counts = np.bincount(labels)
        min_label = label_counts.argmin()
        min_count = label_counts[min_label]

        for label in range(self.num_classes):
            # Get the indices of the same label samples
            label_indices = np.where(labels == label)[0]
            np.random.seed(self.seed)
            np.random.shuffle(label_indices)

            # Split the data based on their labels
            samples_per_client = min_count // num_clients

            for i in range(num_clients):
                start_idx = i * samples_per_client
                end_idx = (i + 1) * samples_per_client
                clients_data[i].extend(label_indices[start_idx:end_idx])

        return clients_data

    def unbalanced_iid_partition(self, dataset, imbalance_factor=2):
        """
        Partition the dataset into multiple IID (Independent and Identically Distributed)
        subsets of different sizes.

        This function divides a dataset into a specified number of IID subsets (federated
        clients), where each subset has a different number of samples. The number of samples
        in each subset is determined by an imbalance factor, making the partition suitable
        for simulating imbalanced data scenarios in federated learning.

        Args:
            dataset (list): The dataset to partition. It should be a list of tuples where
                            each tuple represents a data sample and its corresponding label.
            imbalance_factor (float): The factor to determine the degree of imbalance
                                    among the subsets. A lower imbalance factor leads to more
                                    imbalanced partitions.

        Returns:
            dict: A dictionary where keys are client IDs (ranging from 0 to partitions_number-1) and
                    values are lists of indices corresponding to the samples assigned to each client.

        The function ensures that each class is represented in each subset but with varying
        proportions. The partitioning process involves iterating over each class, shuffling
        the indices of that class, and then splitting them according to the calculated subset
        sizes. The function does not print the class distribution in each subset.

        Example usage:
            federated_data = unbalanced_iid_partition(my_dataset, imbalance_factor=2)
            # This creates federated data subsets with varying number of samples based on
            # an imbalance factor of 2.
        """
        num_clients = self.partitions_number
        clients_data = {i: [] for i in range(num_clients)}

        # Get the labels from the dataset
        labels = np.array([dataset.targets[idx] for idx in range(len(dataset))])
        label_counts = np.bincount(labels)

        min_label = label_counts.argmin()
        min_count = label_counts[min_label]

        # Set the initial_subset_size
        initial_subset_size = min_count // num_clients

        # Calculate the number of samples for each subset based on the imbalance factor
        subset_sizes = [initial_subset_size]
        for i in range(1, num_clients):
            subset_sizes.append(int(subset_sizes[i - 1] * ((imbalance_factor - 1) / imbalance_factor)))

        for label in range(self.num_classes):
            # Get the indices of the same label samples
            label_indices = np.where(labels == label)[0]
            np.random.seed(self.seed)
            np.random.shuffle(label_indices)

            # Split the data based on their labels
            start = 0
            for i in range(num_clients):
                end = start + subset_sizes[i]
                clients_data[i].extend(label_indices[start:end])
                start = end

        return clients_data

    def percentage_partition(self, dataset, percentage=20):
        """
        Partition a dataset into multiple subsets with a specified level of non-IID-ness.

        Args:
            dataset (torch.utils.data.Dataset): The dataset to partition. It should have
                                                'data' and 'targets' attributes.
            percentage (int): A value between 0 and 100 that specifies the desired
                                level of non-IID-ness for the labels of the federated data.
                                This percentage controls the imbalance in the class distribution
                                across different subsets.

        Returns:
            dict: A dictionary where keys are subset indices (ranging from 0 to partitions_number-1)
                and values are lists of indices corresponding to the samples in each subset.

        Example usage:
            federated_data = percentage_partition(my_dataset, percentage=20)
            # This creates federated data subsets with varying class distributions based on
            # a percentage of 20.
        """
        if isinstance(dataset.targets, np.ndarray):
            y_train = dataset.targets
        elif hasattr(dataset.targets, "numpy"):  # Check if it's a tensor with .numpy() method
            y_train = dataset.targets.numpy()
        else:  # If it's a list
            y_train = np.asarray(dataset.targets)

        num_classes = self.num_classes
        num_subsets = self.partitions_number
        class_indices = {i: np.where(y_train == i)[0] for i in range(num_classes)}

        # Get the labels from the dataset
        labels = np.array([dataset.targets[idx] for idx in range(len(dataset))])
        label_counts = np.bincount(labels)

        min_label = label_counts.argmin()
        min_count = label_counts[min_label]

        classes_per_subset = int(num_classes * percentage / 100)
        if classes_per_subset < 1:
            raise ValueError("The percentage is too low to assign at least one class to each subset.")

        subset_indices = [[] for _ in range(num_subsets)]
        class_list = list(range(num_classes))
        np.random.seed(self.seed)
        np.random.shuffle(class_list)

        for i in range(num_subsets):
            for j in range(classes_per_subset):
                # Use modulo operation to cycle through the class_list
                class_idx = class_list[(i * classes_per_subset + j) % num_classes]
                indices = class_indices[class_idx]
                np.random.seed(self.seed)
                np.random.shuffle(indices)
                # Select min_count // 2 samples of this class (half the size of the rarest class)
                subset_indices[i].extend(indices[: min_count // 2])

            class_counts = np.bincount(np.array([dataset.targets[idx] for idx in subset_indices[i]]))
            logging.info(f"Partition {i + 1} class distribution: {class_counts.tolist()}")

        partitioned_datasets = {i: subset_indices[i] for i in range(num_subsets)}

        return partitioned_datasets

    def plot_all_data_distribution(self, phase, dataset, partitions_map):
        """

        Plot all of the data distribution of the dataset according to the partitions map provided.

        Args:
            dataset: The dataset to plot (torch.utils.data.Dataset).
            partitions_map: The map of the dataset partitions.
        """
        sns.set()
        sns.set_style("whitegrid", {"axes.grid": False})
        sns.set_context("paper", font_scale=1.5)
        sns.set_palette("Set2")

        num_clients = len(partitions_map)
        num_classes = self.num_classes

        # Plot number of samples per class in the dataset
        plt.figure(figsize=(12, 8))

        class_counts = [0] * num_classes
        for target in dataset.targets:
            class_counts[target] += 1

        plt.bar(range(num_classes), class_counts, tick_label=dataset.classes)
        for j, count in enumerate(class_counts):
            plt.text(j, count, str(count), ha="center", va="bottom", fontsize=12)
        plt.title(f"[{phase}] Number of samples per class in the dataset")
        plt.xlabel("Class")
        plt.ylabel("Number of samples")
        plt.tight_layout()

        path_to_save_class_distribution = f"{self.config_dir}/full_data_distribution_{'iid' if self.iid else 'non_iid'}{'_' + self.partition if not self.iid else ''}_{phase}.pdf"
        plt.savefig(path_to_save_class_distribution, dpi=300, bbox_inches="tight")
        plt.close()

        # Plot distribution of samples across participants
        plt.figure(figsize=(12, 8))

        label_distribution = [[] for _ in range(num_classes)]
        for c_id, idc in partitions_map.items():
            for idx in idc:
                label_distribution[dataset.targets[idx]].append(c_id)

        plt.hist(
            label_distribution,
            stacked=True,
            bins=np.arange(-0.5, num_clients + 1.5, 1),
            label=dataset.classes,
            rwidth=0.5,
        )
        plt.xticks(
            np.arange(num_clients),
            ["Participant %d" % (c_id + 1) for c_id in range(num_clients)],
        )
        plt.title(f"[{phase}] Distribution of split datasets")
        plt.xlabel("Participant")
        plt.ylabel("Number of samples")
        plt.xticks(range(num_clients), [f" {i}" for i in range(num_clients)])
        plt.legend(loc="upper right")
        plt.tight_layout()

        path_to_save = f"{self.config_dir}/all_data_distribution_HIST_{'iid' if self.iid else 'non_iid'}{'_' + self.partition if not self.iid else ''}_{phase}.pdf"
        plt.savefig(path_to_save, dpi=300, bbox_inches="tight")
        plt.close()

        plt.figure(figsize=(12, 8))
        max_point_size = 1200
        min_point_size = 0

        for i in range(self.partitions_number):
            class_counts = [0] * self.num_classes
            indices = partitions_map[i]
            for idx in indices:
                label = dataset.targets[idx]
                class_counts[label] += 1

            # Normalize the point sizes for this partition, handling the case where max_samples_partition is 0
            max_samples_partition = max(class_counts)
            if max_samples_partition == 0:
                logging.warning(f"[{phase}] Participant {i} has no samples. Skipping size normalization.")
                sizes = [min_point_size] * self.num_classes
            else:
                sizes = [
                    (size / max_samples_partition) * (max_point_size - min_point_size) + min_point_size
                    for size in class_counts
                ]
            plt.scatter([i] * self.num_classes, range(self.num_classes), s=sizes, alpha=0.5)

        plt.xlabel("Participant")
        plt.ylabel("Class")
        plt.xticks(range(self.partitions_number))
        plt.yticks(range(self.num_classes))
        plt.title(
            f"[{phase}] Data distribution across participants ({'IID' if self.iid else f'Non-IID - {self.partition}'} - {self.partition_parameter})"
        )
        plt.tight_layout()

        path_to_save = f"{self.config_dir}/all_data_distribution_CIRCLES_{'iid' if self.iid else 'non_iid'}{'_' + self.partition if not self.iid else ''}_{phase}.pdf"
        plt.savefig(path_to_save, dpi=300, bbox_inches="tight")
        plt.close()
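
The class above only provides the partitioning machinery; concrete datasets subclass it and implement the abstract hooks. The following sketch is illustrative only: MNISTExample, the torchvision-based loading, and the /tmp/nebula path are assumptions, not part of the library. It shows one way the hooks could be wired to the partition helpers documented below.

# Hypothetical sketch of a NebulaDataset subclass (names are illustrative, not library code).
from torchvision import datasets, transforms

from nebula.core.datasets.nebuladataset import NebulaDataset


class MNISTExample(NebulaDataset):
    def initialize_dataset(self):
        tf = transforms.ToTensor()
        self.train_set = datasets.MNIST("data", train=True, download=True, transform=tf)
        self.test_set = datasets.MNIST("data", train=False, download=True, transform=tf)

    def generate_iid_map(self, dataset, plot=False):
        # Equal class distribution per participant.
        return self.balanced_iid_partition(dataset)

    # data_partitioning() calls this hook with (dataset, partition, partition_parameter)
    # positionally, so the third parameter receives the partition parameter.
    def generate_non_iid_map(self, dataset, partition="dirichlet", partition_parameter=0.5, plot=False):
        if partition == "dirichlet":
            return self.dirichlet_partition(dataset, alpha=partition_parameter)
        if partition == "percentage":
            return self.percentage_partition(dataset, percentage=partition_parameter)
        raise ValueError(f"Unknown partition strategy: {partition}")


# Usage sketch (config_dir must exist, because save_partitions() writes HDF5 files there):
# ds = MNISTExample(num_classes=10, partitions_number=5, iid=False,
#                   partition="dirichlet", partition_parameter=0.5, config_dir="/tmp/nebula")
# ds.initialize_dataset()
# ds.data_partitioning(plot=False)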

balanced_iid_partition(dataset)

Partition the dataset into balanced and IID (Independent and Identically Distributed) subsets for each client.

This function divides a dataset into a specified number of subsets (federated clients), where each subset has an equal class distribution. This makes the partition suitable for simulating IID data scenarios in federated learning.

Parameters:

dataset (list, required): The dataset to partition. It should be a list of tuples where each tuple represents a data sample and its corresponding label.

Returns:

dict: A dictionary where keys are client IDs (ranging from 0 to partitions_number-1) and values are lists of indices corresponding to the samples assigned to each client.

The function ensures that each class is represented equally in each subset. The partitioning process involves iterating over each class, shuffling the indices of that class, and then splitting them equally among the clients. The function does not print the class distribution in each subset.

Example usage

federated_data = balanced_iid_partition(my_dataset)

This creates federated data subsets with equal class distributions.

Source code in nebula/core/datasets/nebuladataset.py
def balanced_iid_partition(self, dataset):
    """
    Partition the dataset into balanced and IID (Independent and Identically Distributed)
    subsets for each client.

    This function divides a dataset into a specified number of subsets (federated clients),
    where each subset has an equal class distribution. This makes the partition suitable for
    simulating IID data scenarios in federated learning.

    Args:
        dataset (list): The dataset to partition. It should be a list of tuples where each
                        tuple represents a data sample and its corresponding label.

    Returns:
        dict: A dictionary where keys are client IDs (ranging from 0 to partitions_number-1) and
                values are lists of indices corresponding to the samples assigned to each client.

    The function ensures that each class is represented equally in each subset. The
    partitioning process involves iterating over each class, shuffling the indices of that class,
    and then splitting them equally among the clients. The function does not print the class
    distribution in each subset.

    Example usage:
        federated_data = balanced_iid_partition(my_dataset)
        # This creates federated data subsets with equal class distributions.
    """
    num_clients = self.partitions_number
    clients_data = {i: [] for i in range(num_clients)}

    # Get the labels from the dataset
    if isinstance(dataset.targets, np.ndarray):
        labels = dataset.targets
    elif hasattr(dataset.targets, "numpy"):  # Check if it's a tensor with .numpy() method
        labels = dataset.targets.numpy()
    else:  # If it's a list
        labels = np.asarray(dataset.targets)

    label_counts = np.bincount(labels)
    min_label = label_counts.argmin()
    min_count = label_counts[min_label]

    for label in range(self.num_classes):
        # Get the indices of the same label samples
        label_indices = np.where(labels == label)[0]
        np.random.seed(self.seed)
        np.random.shuffle(label_indices)

        # Split the data based on their labels
        samples_per_client = min_count // num_clients

        for i in range(num_clients):
            start_idx = i * samples_per_client
            end_idx = (i + 1) * samples_per_client
            clients_data[i].extend(label_indices[start_idx:end_idx])

    return clients_data
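
As a quick, self-contained illustration of the splitting rule (min_count // num_clients samples of every class per client), the sketch below calls the method through a lightweight stand-in object. DummySet and the SimpleNamespace stub are assumptions made for the example, not library API; in practice the method runs on a NebulaDataset subclass instance.

# Hedged sketch: balanced_iid_partition on a tiny synthetic dataset.
import numpy as np
from types import SimpleNamespace

from nebula.core.datasets.nebuladataset import NebulaDataset


class DummySet:
    def __init__(self, targets):
        self.targets = np.array(targets)


stub = SimpleNamespace(partitions_number=2, num_classes=3, seed=42)
dataset = DummySet([0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2])

clients = NebulaDataset.balanced_iid_partition(stub, dataset)
# min_count // num_clients = 4 // 2 = 2 samples of each class per client.
print({cid: sorted(int(i) for i in idx) for cid, idx in clients.items()})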

clear()

Clear the dataset. This should remove or reset the dataset.

Source code in nebula/core/datasets/nebuladataset.py
def clear(self):
    """
    Clear the dataset. This should remove or reset the dataset.
    """
    if self.train_set is not None and hasattr(self.train_set, "close"):
        self.train_set.close()
    if self.test_set is not None and hasattr(self.test_set, "close"):
        self.test_set.close()

    self.train_set = None
    self.train_indices_map = None
    self.test_set = None
    self.test_indices_map = None
    self.local_test_indices_map = None

data_partitioning(plot=False)

Perform the data partitioning.

Source code in nebula/core/datasets/nebuladataset.py
def data_partitioning(self, plot=False):
    """
    Perform the data partitioning.
    """

    logging.info(
        f"Partitioning data for {self.__class__.__name__} | Partitions: {self.partitions_number} | IID: {self.iid} | Partition: {self.partition} | Partition parameter: {self.partition_parameter}"
    )

    self.train_indices_map = (
        self.generate_iid_map(self.train_set)
        if self.iid
        else self.generate_non_iid_map(self.train_set, self.partition, self.partition_parameter)
    )
    self.test_indices_map = self.get_test_indices_map()
    self.local_test_indices_map = self.get_local_test_indices_map()

    if plot:
        self.plot_data_distribution("train", self.train_set, self.train_indices_map)
        self.plot_all_data_distribution("train", self.train_set, self.train_indices_map)
        self.plot_data_distribution("local_test", self.test_set, self.local_test_indices_map)
        self.plot_all_data_distribution("local_test", self.test_set, self.local_test_indices_map)

    self.save_partitions()

dirichlet_partition(dataset, alpha=0.5, min_samples_size=50, balanced=False, max_iter=100, verbose=True)

Partition the dataset among clients using a Dirichlet distribution. This function ensures each client gets at least min_samples_size samples.

Parameters

dataset : Dataset
    The dataset to partition. Must have a 'targets' attribute.
alpha : float, default=0.5
    Concentration parameter for the Dirichlet distribution.
min_samples_size : int, default=50
    Minimum number of samples required in each partition.
balanced : bool, default=False
    If True, distribute each class evenly among clients. Otherwise, allocate according to a Dirichlet distribution.
max_iter : int, default=100
    Maximum number of iterations to try finding a valid partition.
verbose : bool, default=True
    If True, print debug information per iteration.

Returns

partitions : dict[int, list[int]]
    Dictionary mapping each client index to a list of sample indices.

Source code in nebula/core/datasets/nebuladataset.py
def dirichlet_partition(
    self,
    dataset: Any,
    alpha: float = 0.5,
    min_samples_size: int = 50,
    balanced: bool = False,
    max_iter: int = 100,
    verbose: bool = True,
) -> dict[int, list[int]]:
    """
    Partition the dataset among clients using a Dirichlet distribution.
    This function ensures each client gets at least min_samples_size samples.

    Parameters
    ----------
    dataset : Dataset
        The dataset to partition. Must have a 'targets' attribute.
    alpha : float, default=0.5
        Concentration parameter for the Dirichlet distribution.
    min_samples_size : int, default=50
        Minimum number of samples required in each partition.
    balanced : bool, default=False
        If True, distribute each class evenly among clients.
        Otherwise, allocate according to a Dirichlet distribution.
    max_iter : int, default=100
        Maximum number of iterations to try finding a valid partition.
    verbose : bool, default=True
        If True, print debug information per iteration.

    Returns
    -------
    partitions : dict[int, list[int]]
        Dictionary mapping each client index to a list of sample indices.
    """
    # Extract targets and unique labels.
    y_data = self._get_targets(dataset)
    unique_labels = np.unique(y_data)

    # For each class, get a shuffled list of indices.
    class_indices = {}
    base_rng = np.random.default_rng(self.seed)
    for label in unique_labels:
        idx = np.where(y_data == label)[0]
        base_rng.shuffle(idx)
        class_indices[label] = idx

    # Prepare container for client indices.
    indices_per_partition = [[] for _ in range(self.partitions_number)]

    def allocate_for_label(label_idx: np.ndarray, rng: np.random.Generator) -> np.ndarray:
        num_label_samples = len(label_idx)
        if balanced:
            proportions = np.full(self.partitions_number, 1.0 / self.partitions_number)
        else:
            proportions = rng.dirichlet([alpha] * self.partitions_number)
        sample_counts = (proportions * num_label_samples).astype(int)
        remainder = num_label_samples - sample_counts.sum()
        if remainder > 0:
            extra_indices = rng.choice(self.partitions_number, size=remainder, replace=False)
            for idx in extra_indices:
                sample_counts[idx] += 1
        return sample_counts

    for iteration in range(1, max_iter + 1):
        rng = np.random.default_rng(self.seed + iteration)
        temp_indices_per_partition = [[] for _ in range(self.partitions_number)]
        for label in unique_labels:
            label_idx = class_indices[label]
            counts = allocate_for_label(label_idx, rng)
            start = 0
            for client_idx, count in enumerate(counts):
                end = start + count
                temp_indices_per_partition[client_idx].extend(label_idx[start:end])
                start = end

        client_sizes = [len(indices) for indices in temp_indices_per_partition]
        if min(client_sizes) >= min_samples_size:
            indices_per_partition = temp_indices_per_partition
            if verbose:
                print(f"Partition successful at iteration {iteration}. Client sizes: {client_sizes}")
            break
        if verbose:
            print(f"Iteration {iteration}: client sizes {client_sizes}")

    else:
        raise ValueError(
            f"Could not create partitions with at least {min_samples_size} samples per client after {max_iter} iterations."
        )

    initial_partition = {i: indices for i, indices in enumerate(indices_per_partition)}

    final_partition = self.postprocess_partition(initial_partition, y_data)

    return final_partition
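
To build intuition for the alpha parameter, the standalone numpy sketch below (an illustration only, not a call into the class) samples per-class proportions for a few alpha values and mirrors the allocation step used inside dirichlet_partition.

# Standalone numpy sketch: how the concentration parameter alpha shapes the
# per-class proportions assigned to each participant.
import numpy as np

rng = np.random.default_rng(42)
partitions_number = 5

for alpha in (0.1, 0.5, 10.0):
    proportions = rng.dirichlet([alpha] * partitions_number)
    print(f"alpha={alpha:>4}: {np.round(proportions, 3)}")
# Small alpha -> a few participants receive most of a class (strong non-IID skew);
# large alpha -> proportions approach the uniform 1 / partitions_number split.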

generate_iid_map(dataset, plot=False) abstractmethod

Create an iid map of the dataset.

Source code in nebula/core/datasets/nebuladataset.py
@abstractmethod
def generate_iid_map(self, dataset, plot=False):
    """
    Create an iid map of the dataset.
    """
    pass

generate_non_iid_map(dataset, partition='dirichlet', plot=False) abstractmethod

Create a non-iid map of the dataset.

Source code in nebula/core/datasets/nebuladataset.py
@abstractmethod
def generate_non_iid_map(self, dataset, partition="dirichlet", plot=False):
    """
    Create a non-iid map of the dataset.
    """
    pass

get_local_test_indices_map()

Get the indices of the local test set for each participant. Only test samples whose labels appear in the participant's training partition are selected.

Returns:

A dictionary mapping participant_id to a list of indices.

Source code in nebula/core/datasets/nebuladataset.py
def get_local_test_indices_map(self):
    """
    Get the indices of the local test set for each participant.
    Only test samples whose labels appear in the participant's training partition are selected.

    Returns:
        A dictionary mapping participant_id to a list of indices.
    """
    try:
        local_test_indices_map = {}
        test_targets = np.array(self.test_set.targets)
        for participant_id in range(self.partitions_number):
            train_labels = np.array([self.train_set.targets[idx] for idx in self.train_indices_map[participant_id]])
            indices = np.where(np.isin(test_targets, train_labels))[0].tolist()
            local_test_indices_map[participant_id] = indices
        return local_test_indices_map
    except Exception as e:
        logging.exception(f"Error in get_local_test_indices_map: {e}")
        raise
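
The selection rule reduces to an np.isin filter over the test labels; a minimal standalone sketch with toy labels (not real data):

# Toy sketch of the np.isin selection used above.
import numpy as np

test_targets = np.array([0, 1, 2, 3, 0, 1, 2, 3])
train_labels = np.array([1, 1, 3])  # labels present in one participant's training split
local_test = np.where(np.isin(test_targets, train_labels))[0].tolist()
print(local_test)  # [1, 3, 5, 7] -> only test samples whose label the participant trains on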

get_test_indices_map()

Get the indices of the test set for each participant.

Returns:

A dictionary mapping participant_id to a list of indices.

Source code in nebula/core/datasets/nebuladataset.py
def get_test_indices_map(self):
    """
    Get the indices of the test set for each participant.

    Returns:
        A dictionary mapping participant_id to a list of indices.
    """
    try:
        test_indices_map = {}
        for participant_id in range(self.partitions_number):
            test_indices_map[participant_id] = list(range(len(self.test_set)))
        return test_indices_map
    except Exception as e:
        logging.exception(f"Error in get_test_indices_map: {e}")

homo_partition(dataset)

Homogeneously partition the dataset into multiple subsets.

This function divides a dataset into a specified number of subsets, where each subset is intended to have a roughly equal number of samples. This method aims to ensure a homogeneous distribution of data across all subsets. It's particularly useful in scenarios where a uniform distribution of data is desired among all federated learning clients.

Parameters:

dataset (Dataset, required): The dataset to partition. It should have 'data' and 'targets' attributes.

Returns:

dict: A dictionary where keys are subset indices (ranging from 0 to partitions_number-1) and values are lists of indices corresponding to the samples in each subset.

The function randomly shuffles the entire dataset and then splits it into the number of subsets specified by partitions_number. It ensures that each subset has a similar number of samples. The function also prints the class distribution in each subset for reference.

Example usage

federated_data = homo_partition(my_dataset)

This creates federated data subsets with homogeneous distribution.

Source code in nebula/core/datasets/nebuladataset.py
def homo_partition(self, dataset):
    """
    Homogeneously partition the dataset into multiple subsets.

    This function divides a dataset into a specified number of subsets, where each subset
    is intended to have a roughly equal number of samples. This method aims to ensure a
    homogeneous distribution of data across all subsets. It's particularly useful in
    scenarios where a uniform distribution of data is desired among all federated learning
    clients.

    Args:
        dataset (torch.utils.data.Dataset): The dataset to partition. It should have
                                            'data' and 'targets' attributes.

    Returns:
        dict: A dictionary where keys are subset indices (ranging from 0 to partitions_number-1)
            and values are lists of indices corresponding to the samples in each subset.

    The function randomly shuffles the entire dataset and then splits it into the number
    of subsets specified by `partitions_number`. It ensures that each subset has a similar number
    of samples. The function also prints the class distribution in each subset for reference.

    Example usage:
        federated_data = homo_partition(my_dataset)
        # This creates federated data subsets with homogeneous distribution.
    """
    n_nets = self.partitions_number

    n_train = len(dataset.targets)
    np.random.seed(self.seed)
    idxs = np.random.permutation(n_train)
    batch_idxs = np.array_split(idxs, n_nets)
    net_dataidx_map = {i: batch_idxs[i] for i in range(n_nets)}

    # partitioned_datasets = []
    for i in range(self.partitions_number):
        # subset = torch.utils.data.Subset(dataset, net_dataidx_map[i])
        # partitioned_datasets.append(subset)

        # Print class distribution in the current partition
        class_counts = [0] * self.num_classes
        for idx in net_dataidx_map[i]:
            label = dataset.targets[idx]
            class_counts[label] += 1
        logging.info(f"Partition {i + 1} class distribution: {class_counts}")

    return net_dataidx_map
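
The core of the method is a seeded shuffle followed by np.array_split, which yields nearly equal partition sizes; a minimal standalone sketch with toy indices:

# Toy sketch of the shuffle-and-split step used above.
import numpy as np

np.random.seed(42)
idxs = np.random.permutation(10)      # shuffled sample indices
batch_idxs = np.array_split(idxs, 3)  # 3 participants -> partition sizes 4, 3, 3
print({i: b.tolist() for i, b in enumerate(batch_idxs)})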

initialize_dataset() abstractmethod

Initialize the dataset. This should load or create the dataset.

Source code in nebula/core/datasets/nebuladataset.py
@abstractmethod
def initialize_dataset(self):
    """
    Initialize the dataset. This should load or create the dataset.
    """
    raise NotImplementedError("Subclasses must implement this method.")

percentage_partition(dataset, percentage=20)

Partition a dataset into multiple subsets with a specified level of non-IID-ness.

Parameters:

dataset (Dataset, required): The dataset to partition. It should have 'data' and 'targets' attributes.
percentage (int, default=20): A value between 0 and 100 that specifies the desired level of non-IID-ness for the labels of the federated data. This percentage controls the imbalance in the class distribution across different subsets.

Returns:

dict: A dictionary where keys are subset indices (ranging from 0 to partitions_number-1) and values are lists of indices corresponding to the samples in each subset.

Example usage

federated_data = percentage_partition(my_dataset, percentage=20)

This creates federated data subsets with varying class distributions based on a percentage of 20.

Source code in nebula/core/datasets/nebuladataset.py
def percentage_partition(self, dataset, percentage=20):
    """
    Partition a dataset into multiple subsets with a specified level of non-IID-ness.

    Args:
        dataset (torch.utils.data.Dataset): The dataset to partition. It should have
                                            'data' and 'targets' attributes.
        percentage (int): A value between 0 and 100 that specifies the desired
                            level of non-IID-ness for the labels of the federated data.
                            This percentage controls the imbalance in the class distribution
                            across different subsets.

    Returns:
        dict: A dictionary where keys are subset indices (ranging from 0 to partitions_number-1)
            and values are lists of indices corresponding to the samples in each subset.

    Example usage:
        federated_data = percentage_partition(my_dataset, percentage=20)
        # This creates federated data subsets with varying class distributions based on
        # a percentage of 20.
    """
    if isinstance(dataset.targets, np.ndarray):
        y_train = dataset.targets
    elif hasattr(dataset.targets, "numpy"):  # Check if it's a tensor with .numpy() method
        y_train = dataset.targets.numpy()
    else:  # If it's a list
        y_train = np.asarray(dataset.targets)

    num_classes = self.num_classes
    num_subsets = self.partitions_number
    class_indices = {i: np.where(y_train == i)[0] for i in range(num_classes)}

    # Get the labels from the dataset
    labels = np.array([dataset.targets[idx] for idx in range(len(dataset))])
    label_counts = np.bincount(labels)

    min_label = label_counts.argmin()
    min_count = label_counts[min_label]

    classes_per_subset = int(num_classes * percentage / 100)
    if classes_per_subset < 1:
        raise ValueError("The percentage is too low to assign at least one class to each subset.")

    subset_indices = [[] for _ in range(num_subsets)]
    class_list = list(range(num_classes))
    np.random.seed(self.seed)
    np.random.shuffle(class_list)

    for i in range(num_subsets):
        for j in range(classes_per_subset):
            # Use modulo operation to cycle through the class_list
            class_idx = class_list[(i * classes_per_subset + j) % num_classes]
            indices = class_indices[class_idx]
            np.random.seed(self.seed)
            np.random.shuffle(indices)
            # Take half of the smallest class's sample count from this class (min_count // 2 indices)
            subset_indices[i].extend(indices[: min_count // 2])

        class_counts = np.bincount(np.array([dataset.targets[idx] for idx in subset_indices[i]]))
        logging.info(f"Partition {i + 1} class distribution: {class_counts.tolist()}")

    partitioned_datasets = {i: subset_indices[i] for i in range(num_subsets)}

    return partitioned_datasets
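
For intuition: with num_classes = 10 and percentage = 20, each subset receives int(10 * 20 / 100) = 2 classes, and the shuffled class list is cycled so consecutive subsets cover different classes. A standalone sketch of just that selection step, with illustrative values:

    import numpy as np

    num_classes, num_subsets, percentage, seed = 10, 5, 20, 42
    classes_per_subset = int(num_classes * percentage / 100)  # -> 2

    class_list = list(range(num_classes))
    np.random.seed(seed)
    np.random.shuffle(class_list)

    for i in range(num_subsets):
        # Same modulo cycling as in percentage_partition above.
        assigned = [class_list[(i * classes_per_subset + j) % num_classes] for j in range(classes_per_subset)]
        print(f"Subset {i}: classes {assigned}")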

plot_all_data_distribution(phase, dataset, partitions_map)

Plot the overall data distribution of the dataset according to the provided partitions map.

Parameters:

    phase: The phase of the dataset (train, test, local_test). (required)
    dataset: The dataset to plot (torch.utils.data.Dataset). (required)
    partitions_map: The map of the dataset partitions. (required)
Source code in nebula/core/datasets/nebuladataset.py, lines 955-1055
def plot_all_data_distribution(self, phase, dataset, partitions_map):
    """

    Plot all of the data distribution of the dataset according to the partitions map provided.

    Args:
        dataset: The dataset to plot (torch.utils.data.Dataset).
        partitions_map: The map of the dataset partitions.
    """
    sns.set()
    sns.set_style("whitegrid", {"axes.grid": False})
    sns.set_context("paper", font_scale=1.5)
    sns.set_palette("Set2")

    num_clients = len(partitions_map)
    num_classes = self.num_classes

    # Plot number of samples per class in the dataset
    plt.figure(figsize=(12, 8))

    class_counts = [0] * num_classes
    for target in dataset.targets:
        class_counts[target] += 1

    plt.bar(range(num_classes), class_counts, tick_label=dataset.classes)
    for j, count in enumerate(class_counts):
        plt.text(j, count, str(count), ha="center", va="bottom", fontsize=12)
    plt.title(f"[{phase}] Number of samples per class in the dataset")
    plt.xlabel("Class")
    plt.ylabel("Number of samples")
    plt.tight_layout()

    path_to_save_class_distribution = f"{self.config_dir}/full_data_distribution_{'iid' if self.iid else 'non_iid'}{'_' + self.partition if not self.iid else ''}_{phase}.pdf"
    plt.savefig(path_to_save_class_distribution, dpi=300, bbox_inches="tight")
    plt.close()

    # Plot distribution of samples across participants
    plt.figure(figsize=(12, 8))

    label_distribution = [[] for _ in range(num_classes)]
    for c_id, idc in partitions_map.items():
        for idx in idc:
            label_distribution[dataset.targets[idx]].append(c_id)

    plt.hist(
        label_distribution,
        stacked=True,
        bins=np.arange(-0.5, num_clients + 1.5, 1),
        label=dataset.classes,
        rwidth=0.5,
    )
    plt.xticks(
        np.arange(num_clients),
        ["Participant %d" % (c_id + 1) for c_id in range(num_clients)],
    )
    plt.title(f"[{phase}] Distribution of splited datasets")
    plt.xlabel("Participant")
    plt.ylabel("Number of samples")
    plt.xticks(range(num_clients), [f" {i}" for i in range(num_clients)])
    plt.legend(loc="upper right")
    plt.tight_layout()

    path_to_save = f"{self.config_dir}/all_data_distribution_HIST_{'iid' if self.iid else 'non_iid'}{'_' + self.partition if not self.iid else ''}_{phase}.pdf"
    plt.savefig(path_to_save, dpi=300, bbox_inches="tight")
    plt.close()

    plt.figure(figsize=(12, 8))
    max_point_size = 1200
    min_point_size = 0

    for i in range(self.partitions_number):
        class_counts = [0] * self.num_classes
        indices = partitions_map[i]
        for idx in indices:
            label = dataset.targets[idx]
            class_counts[label] += 1

        # Normalize the point sizes for this partition, handling the case where max_samples_partition is 0
        max_samples_partition = max(class_counts)
        if max_samples_partition == 0:
            logging.warning(f"[{phase}] Participant {i} has no samples. Skipping size normalization.")
            sizes = [min_point_size] * self.num_classes
        else:
            sizes = [
                (size / max_samples_partition) * (max_point_size - min_point_size) + min_point_size
                for size in class_counts
            ]
        plt.scatter([i] * self.num_classes, range(self.num_classes), s=sizes, alpha=0.5)

    plt.xlabel("Participant")
    plt.ylabel("Class")
    plt.xticks(range(self.partitions_number))
    plt.yticks(range(self.num_classes))
    plt.title(
        f"[{phase}] Data distribution across participants ({'IID' if self.iid else f'Non-IID - {self.partition}'} - {self.partition_parameter})"
    )
    plt.tight_layout()

    path_to_save = f"{self.config_dir}/all_data_distribution_CIRCLES_{'iid' if self.iid else 'non_iid'}{'_' + self.partition if not self.iid else ''}_{phase}.pdf"
    plt.savefig(path_to_save, dpi=300, bbox_inches="tight")
    plt.close()

plot_data_distribution(phase, dataset, partitions_map)

Plot the data distribution of the dataset.

Plot the data distribution of the dataset according to the partitions map provided.

Parameters:

    phase: The phase of the dataset (train, test, local_test). (required)
    dataset: The dataset to plot (torch.utils.data.Dataset). (required)
    partitions_map: The map of the dataset partitions. (required)
Source code in nebula/core/datasets/nebuladataset.py, lines 477-524
def plot_data_distribution(self, phase, dataset, partitions_map):
    """
    Plot the data distribution of the dataset.

    Plot the data distribution of the dataset according to the partitions map provided.

    Args:
        phase: The phase of the dataset (train, test, local_test).
        dataset: The dataset to plot (torch.utils.data.Dataset).
        partitions_map: The map of the dataset partitions.
    """
    logging_training.info(f"Plotting data distribution for {phase} dataset")
    # Plot the data distribution of the dataset, one graph per partition
    sns.set()
    sns.set_style("whitegrid", {"axes.grid": False})
    sns.set_context("paper", font_scale=1.5)
    sns.set_palette("Set2")

    # Plot bar charts for each partition
    partition_index = 0
    for partition_index in range(self.partitions_number):
        indices = partitions_map[partition_index]
        class_counts = [0] * self.num_classes
        for idx in indices:
            label = dataset.targets[idx]
            class_counts[label] += 1

        logging_training.info(f"[{phase}] Participant {partition_index} total samples: {len(indices)}")
        logging_training.info(f"[{phase}] Participant {partition_index} data distribution: {class_counts}")

        plt.figure(figsize=(12, 8))
        plt.bar(range(self.num_classes), class_counts)
        for j, count in enumerate(class_counts):
            plt.text(j, count, str(count), ha="center", va="bottom", fontsize=12)
        plt.xlabel("Class")
        plt.ylabel("Number of samples")
        plt.xticks(range(self.num_classes))
        plt.title(
            f"[{phase}] Participant {partition_index} data distribution ({'IID' if self.iid else f'Non-IID - {self.partition}'} - {self.partition_parameter})"
        )
        plt.tight_layout()
        path_to_save = f"{self.config_dir}/participant_{partition_index}_data_distribution_{'iid' if self.iid else 'non_iid'}{'_' + self.partition if not self.iid else ''}_{phase}.pdf"
        logging_training.info(f"Saving data distribution for participant {partition_index} to {path_to_save}")
        plt.savefig(path_to_save, dpi=300, bbox_inches="tight")
        plt.close()

    if hasattr(self, "tsne") and self.tsne:
        self.visualize_tsne(dataset)

postprocess_partition(partition, y_data, min_samples_per_class=10)

Post-process a partition to remove (and reassign) classes with too few samples per client.

For each class:

- For clients with a count > 0 but below min_samples_per_class, remove those samples.
- Reassign the removed samples to the client that already has the maximum count for that class.

Parameters:

    partition : dict[int, list[int]]
        The initial partition mapping client indices to sample indices.
    y_data : np.ndarray
        The array of labels corresponding to the dataset samples.
    min_samples_per_class : int, default=10
        The minimum acceptable number of samples per class for each client.

Returns:

    new_partition : dict[int, list[int]]
        The updated partition.

Source code in nebula/core/datasets/nebuladataset.py, lines 664-719
def postprocess_partition(
    self, partition: dict[int, list[int]], y_data: np.ndarray, min_samples_per_class: int = 10
) -> dict[int, list[int]]:
    """
    Post-process a partition to remove (and reassign) classes with too few samples per client.

    For each class:
    - For clients with a count > 0 but below min_samples_per_class, remove those samples.
    - Reassign the removed samples to the client that already has the maximum count for that class.

    Parameters
    ----------
    partition : dict[int, list[int]]
        The initial partition mapping client indices to sample indices.
    y_data : np.ndarray
        The array of labels corresponding to the dataset samples.
    min_samples_per_class : int, default=10
        The minimum acceptable number of samples per class for each client.

    Returns
    -------
    new_partition : dict[int, list[int]]
        The updated partition.
    """
    # Copy partition so we can modify it.
    new_partition = {client: list(indices) for client, indices in partition.items()}

    # Iterate over each class.
    for label in np.unique(y_data):
        # For each client, count how many samples of this label exist.
        client_counts = {}
        for client, indices in new_partition.items():
            client_counts[client] = np.sum(np.array(y_data)[indices] == label)

        # Identify clients with fewer than min_samples_per_class but nonzero counts.
        donors = [client for client, count in client_counts.items() if 0 < count < min_samples_per_class]
        # Identify potential recipients: those with at least min_samples_per_class.
        recipients = [client for client, count in client_counts.items() if count >= min_samples_per_class]
        # If no client meets the threshold, choose the one with the highest count.
        if not recipients:
            best_recipient = max(client_counts, key=client_counts.get)
            recipients = [best_recipient]
        # Choose the recipient with the maximum count.
        best_recipient = max(recipients, key=lambda c: client_counts[c])

        # For each donor, remove samples of this label and reassign them.
        for donor in donors:
            donor_indices = new_partition[donor]
            # Identify indices corresponding to this label.
            donor_label_indices = [idx for idx in donor_indices if y_data[idx] == label]
            # Remove these from the donor.
            new_partition[donor] = [idx for idx in donor_indices if y_data[idx] != label]
            # Add these to the best recipient.
            new_partition[best_recipient].extend(donor_label_indices)

    return new_partition
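
A tiny worked example of the reassignment rule (all values illustrative; nebula_ds stands for an already-initialized NebulaDataset subclass instance): client 1 holds only two samples of class 0, which is below the threshold, so those samples move to client 0, which already has the largest count for that class.

    import numpy as np

    y_data = np.array([0] * 10 + [1] * 2)      # indices 0-9 are class 0, indices 10-11 are class 1
    partition = {
        0: [0, 1, 2, 3, 4, 5, 6, 7],           # 8 samples of class 0
        1: [8, 9, 10, 11],                     # 2 samples of class 0 (too few) + 2 of class 1
    }

    new_partition = nebula_ds.postprocess_partition(partition, y_data, min_samples_per_class=3)
    # Class 0: indices 8 and 9 are removed from client 1 and appended to client 0.
    # Class 1: no client reaches the threshold, so client 1 (the largest holder) keeps its samples.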

save_partitions()

Save the partition data to separate HDF5 files. The controller writes one training file per partition, plus a shared global test file used by all participants.

Source code in nebula/core/datasets/nebuladataset.py, lines 413-461
def save_partitions(self):
    """
    Save the partition data to separate HDF5 files.
    The controller writes one training file per partition, plus a shared global test file.
    """
    try:
        logging.info(f"Saving partitions data for ALL participants ({self.partitions_number}) in {self.config_dir}")
        path = self.config_dir
        if not os.path.exists(path):
            raise FileNotFoundError(f"Path {path} does not exist")
        # Check that the partition maps have the expected number of partitions
        if not (
            len(self.train_indices_map)
            == len(self.test_indices_map)
            == len(self.local_test_indices_map)
            == self.partitions_number
        ):
            raise ValueError("One of the partition maps has an unexpected length.")

        # Save global test data
        file_name = os.path.join(path, "global_test.h5")
        with h5py.File(file_name, "w") as f:
            indices = list(range(len(self.test_set)))
            test_data = [self.test_set[i] for i in indices]
            self.save_partition(test_data, f, "test_data")
            f["test_data"].attrs["num_classes"] = self.num_classes
            test_targets = np.array(self.test_set.targets)
            f.create_dataset("test_targets", data=test_targets, compression="gzip")

        for participant in range(self.partitions_number):
            file_name = os.path.join(path, f"participant_{participant}_train.h5")
            with h5py.File(file_name, "w") as f:
                logging.info(f"Saving training data for participant {participant} in {file_name}")
                indices = self.train_indices_map[participant]
                train_data = [self.train_set[i] for i in indices]
                self.save_partition(train_data, f, "train_data")
                f["train_data"].attrs["num_classes"] = self.num_classes
                train_targets = np.array([self.train_set.targets[i] for i in indices])
                f.create_dataset("train_targets", data=train_targets, compression="gzip")
                logging.info(f"Partition saved for participant {participant}.")

        logging.info("Successfully saved all partition files.")

    except Exception as e:
        logging.exception(f"Error in save_partitions: {e}")

    finally:
        self.clear()
        logging.info("Cleared dataset after saving partitions.")

unbalanced_iid_partition(dataset, imbalance_factor=2)

Partition the dataset into multiple IID (Independent and Identically Distributed) subsets of different sizes.

This function divides a dataset into a specified number of IID subsets (federated clients), where each subset has a different number of samples. The number of samples in each subset is determined by an imbalance factor, making the partition suitable for simulating imbalanced data scenarios in federated learning.

Parameters:

    dataset (list): The dataset to partition. It should be a list of tuples where each tuple represents a data sample and its corresponding label. (required)
    imbalance_factor (float): The factor to determine the degree of imbalance among the subsets. A lower imbalance factor leads to more imbalanced partitions. (default: 2)

Returns:

    dict: A dictionary where keys are client IDs (ranging from 0 to partitions_number-1) and values are lists of indices corresponding to the samples assigned to each client.

The function ensures that each class is represented in each subset but with varying proportions. The partitioning process involves iterating over each class, shuffling the indices of that class, and then splitting them according to the calculated subset sizes. The function does not print the class distribution in each subset.

Example usage:

    federated_data = unbalanced_iid_partition(my_dataset, imbalance_factor=2)
    # This creates federated data subsets with varying number of samples based on
    # an imbalance factor of 2.

Source code in nebula/core/datasets/nebuladataset.py, lines 826-888
def unbalanced_iid_partition(self, dataset, imbalance_factor=2):
    """
    Partition the dataset into multiple IID (Independent and Identically Distributed)
    subsets of different sizes.

    This function divides a dataset into a specified number of IID subsets (federated
    clients), where each subset has a different number of samples. The number of samples
    in each subset is determined by an imbalance factor, making the partition suitable
    for simulating imbalanced data scenarios in federated learning.

    Args:
        dataset (list): The dataset to partition. It should be a list of tuples where
                        each tuple represents a data sample and its corresponding label.
        imbalance_factor (float): The factor to determine the degree of imbalance
                                among the subsets. A lower imbalance factor leads to more
                                imbalanced partitions.

    Returns:
        dict: A dictionary where keys are client IDs (ranging from 0 to partitions_number-1) and
                values are lists of indices corresponding to the samples assigned to each client.

    The function ensures that each class is represented in each subset but with varying
    proportions. The partitioning process involves iterating over each class, shuffling
    the indices of that class, and then splitting them according to the calculated subset
    sizes. The function does not print the class distribution in each subset.

    Example usage:
        federated_data = unbalanced_iid_partition(my_dataset, imbalance_factor=2)
        # This creates federated data subsets with varying number of samples based on
        # an imbalance factor of 2.
    """
    num_clients = self.partitions_number
    clients_data = {i: [] for i in range(num_clients)}

    # Get the labels from the dataset
    labels = np.array([dataset.targets[idx] for idx in range(len(dataset))])
    label_counts = np.bincount(labels)

    min_label = label_counts.argmin()
    min_count = label_counts[min_label]

    # Set the initial_subset_size
    initial_subset_size = min_count // num_clients

    # Calculate the number of samples for each subset based on the imbalance factor
    subset_sizes = [initial_subset_size]
    for i in range(1, num_clients):
        subset_sizes.append(int(subset_sizes[i - 1] * ((imbalance_factor - 1) / imbalance_factor)))

    for label in range(self.num_classes):
        # Get the indices of the same label samples
        label_indices = np.where(labels == label)[0]
        np.random.seed(self.seed)
        np.random.shuffle(label_indices)

        # Split the data based on their labels
        start = 0
        for i in range(num_clients):
            end = start + subset_sizes[i]
            clients_data[i].extend(label_indices[start:end])
            start = end

    return clients_data
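
Concretely, each client's per-class quota shrinks by a factor of (imbalance_factor - 1) / imbalance_factor relative to the previous client. A quick numeric sketch with illustrative values (min_count = 400 samples in the rarest class, 4 clients, imbalance_factor = 2):

    num_clients, imbalance_factor, min_count = 4, 2, 400
    initial_subset_size = min_count // num_clients            # 100

    subset_sizes = [initial_subset_size]
    for i in range(1, num_clients):
        subset_sizes.append(int(subset_sizes[i - 1] * ((imbalance_factor - 1) / imbalance_factor)))

    print(subset_sizes)  # [100, 50, 25, 12] samples per class for each client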

NebulaPartition

A class to handle the partitioning of datasets for federated learning.

Source code in nebula/core/datasets/nebuladataset.py, lines 149-281
class NebulaPartition:
    """
    A class to handle the partitioning of datasets for federated learning.
    """

    def __init__(self, handler: NebulaPartitionHandler, config: dict[str, Any] | None = None):
        self.handler = handler
        self.config = config if config is not None else {}

        self.train_set = None
        self.train_indices = None

        self.test_set = None
        self.test_indices = None

        self.local_test_set = None
        self.local_test_indices = None

        enable_deterministic(seed=self.config.participant["scenario_args"]["random_seed"])

    def get_train_indices(self):
        """
        Get the indices of the training set based on the indices map.
        """
        if self.train_indices is None:
            return None
        return self.train_indices

    def get_test_indices(self):
        """
        Get the indices of the test set based on the indices map.
        """
        if self.test_indices is None:
            return None
        return self.test_indices

    def get_local_test_indices(self):
        """
        Get the indices of the local test set based on the indices map.
        """
        if self.local_test_indices is None:
            return None
        return self.local_test_indices

    def get_train_labels(self):
        """
        Get the labels of the training set based on the indices map.
        """
        if self.train_indices is None:
            return None
        return [self.train_set.targets[idx] for idx in self.train_indices]

    def get_test_labels(self):
        """
        Get the labels of the test set based on the indices map.
        """
        if self.test_indices is None:
            return None
        return [self.test_set.targets[idx] for idx in self.test_indices]

    def get_local_test_labels(self):
        """
        Get the labels of the test set based on the indices map.
        """
        if self.local_test_indices is None:
            return None
        return [self.test_set.targets[idx] for idx in self.local_test_indices]

    def set_local_test_indices(self):
        """
        Set the local test indices for the current node.
        """
        test_labels = self.get_test_labels()
        train_labels = self.get_train_labels()

        if test_labels is None or train_labels is None:
            logging_training.warning("Either test_labels or train_labels is None in set_local_test_indices")
            return []

        if self.test_set is None:
            logging_training.warning("test_set is None in set_local_test_indices")
            return []

        return [idx for idx in range(len(self.test_set)) if test_labels[idx] in train_labels]

    def log_partition(self):
        logging_training.info(f"{'=' * 10}")
        logging_training.info(
            f"LOG NEBULA PARTITION DATASET [Participant {self.config.participant['device_args']['idx']}]"
        )
        logging_training.info(f"{'=' * 10}")
        logging_training.info(f"TRAIN - Train labels unique: {set(self.get_train_labels())}")
        logging_training.info(f"TRAIN - Length of train indices map: {len(self.get_train_indices())}")
        logging_training.info(f"{'=' * 10}")
        logging_training.info(f"LOCAL - Test labels unique: {set(self.get_local_test_labels())}")
        logging_training.info(f"LOCAL - Length of test indices map: {len(self.get_local_test_indices())}")
        logging_training.info(f"{'=' * 10}")
        logging_training.info(f"GLOBAL - Test labels unique: {set(self.get_test_labels())}")
        logging_training.info(f"GLOBAL - Length of test indices map: {len(self.get_test_indices())}")
        logging_training.info(f"{'=' * 10}")

    def load_partition(self):
        """
        Load only the partition data corresponding to the current node.
        The node loads its train, test, and local test partition data from HDF5 files.
        """
        try:
            p = self.config.participant["device_args"]["idx"]
            logging_training.info(f"Loading partition data for participant {p}")
            path = self.config.participant["tracking_args"]["config_dir"]

            train_partition_file = os.path.join(path, f"participant_{p}_train.h5")
            wait_for_file(train_partition_file)
            logging_training.info(f"Loading train data from {train_partition_file}")
            self.train_set = self.handler(train_partition_file, "train", config=self.config)
            self.train_indices = list(range(len(self.train_set)))

            test_partition_file = os.path.join(path, "global_test.h5")
            wait_for_file(test_partition_file)
            logging_training.info(f"Loading test data from {test_partition_file}")
            self.test_set = self.handler(test_partition_file, "test", config=self.config)
            self.test_indices = list(range(len(self.test_set)))

            self.local_test_set = self.handler(
                test_partition_file, "local_test", config=self.config, empty=True
            )
            self.local_test_set.set_data(self.test_set.data, self.test_set.targets)
            self.local_test_indices = self.set_local_test_indices()

            logging_training.info(f"Successfully loaded partition data for participant {p}.")
        except Exception as e:
            logging_training.error(f"Error loading partition: {e}")
            raise

get_local_test_indices()

Get the indices of the local test set based on the indices map.

Source code in nebula/core/datasets/nebuladataset.py, lines 185-191
def get_local_test_indices(self):
    """
    Get the indices of the local test set based on the indices map.
    """
    if self.local_test_indices is None:
        return None
    return self.local_test_indices

get_local_test_labels()

Get the labels of the test set based on the indices map.

Source code in nebula/core/datasets/nebuladataset.py, lines 209-215
def get_local_test_labels(self):
    """
    Get the labels of the test set based on the indices map.
    """
    if self.local_test_indices is None:
        return None
    return [self.test_set.targets[idx] for idx in self.local_test_indices]

get_test_indices()

Get the indices of the test set based on the indices map.

Source code in nebula/core/datasets/nebuladataset.py, lines 177-183
def get_test_indices(self):
    """
    Get the indices of the test set based on the indices map.
    """
    if self.test_indices is None:
        return None
    return self.test_indices

get_test_labels()

Get the labels of the test set based on the indices map.

Source code in nebula/core/datasets/nebuladataset.py, lines 201-207
def get_test_labels(self):
    """
    Get the labels of the test set based on the indices map.
    """
    if self.test_indices is None:
        return None
    return [self.test_set.targets[idx] for idx in self.test_indices]

get_train_indices()

Get the indices of the training set based on the indices map.

Source code in nebula/core/datasets/nebuladataset.py, lines 169-175
def get_train_indices(self):
    """
    Get the indices of the training set based on the indices map.
    """
    if self.train_indices is None:
        return None
    return self.train_indices

get_train_labels()

Get the labels of the training set based on the indices map.

Source code in nebula/core/datasets/nebuladataset.py, lines 193-199
def get_train_labels(self):
    """
    Get the labels of the training set based on the indices map.
    """
    if self.train_indices is None:
        return None
    return [self.train_set.targets[idx] for idx in self.train_indices]

load_partition()

Load only the partition data corresponding to the current node. The node loads its train, test, and local test partition data from HDF5 files.

Source code in nebula/core/datasets/nebuladataset.py, lines 250-281
def load_partition(self):
    """
    Load only the partition data corresponding to the current node.
    The node loads its train, test, and local test partition data from HDF5 files.
    """
    try:
        p = self.config.participant["device_args"]["idx"]
        logging_training.info(f"Loading partition data for participant {p}")
        path = self.config.participant["tracking_args"]["config_dir"]

        train_partition_file = os.path.join(path, f"participant_{p}_train.h5")
        wait_for_file(train_partition_file)
        logging_training.info(f"Loading train data from {train_partition_file}")
        self.train_set = self.handler(train_partition_file, "train", config=self.config)
        self.train_indices = list(range(len(self.train_set)))

        test_partition_file = os.path.join(path, "global_test.h5")
        wait_for_file(test_partition_file)
        logging_training.info(f"Loading test data from {test_partition_file}")
        self.test_set = self.handler(test_partition_file, "test", config=self.config)
        self.test_indices = list(range(len(self.test_set)))

        self.local_test_set = self.handler(
            test_partition_file, "local_test", config=self.config, empty=True
        )
        self.local_test_set.set_data(self.test_set.data, self.test_set.targets)
        self.local_test_indices = self.set_local_test_indices()

        logging_training.info(f"Successfully loaded partition data for participant {p}.")
    except Exception as e:
        logging_training.error(f"Error loading partition: {e}")
        raise

set_local_test_indices()

Set the local test indices for the current node.

Source code in nebula/core/datasets/nebuladataset.py, lines 217-232
def set_local_test_indices(self):
    """
    Set the local test indices for the current node.
    """
    test_labels = self.get_test_labels()
    train_labels = self.get_train_labels()

    if test_labels is None or train_labels is None:
        logging_training.warning("Either test_labels or train_labels is None in set_local_test_indices")
        return []

    if self.test_set is None:
        logging_training.warning("test_set is None in set_local_test_indices")
        return []

    return [idx for idx in range(len(self.test_set)) if test_labels[idx] in train_labels]
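
The rule is simply: keep a global test sample if its label appears anywhere among this node's training labels. A toy illustration of that membership test (values made up); in practice, converting train_labels to a set first avoids a linear scan per test sample:

    train_labels = [0, 0, 3, 3, 3]           # labels this node trains on
    test_labels = [0, 1, 2, 3, 4, 0, 3]      # labels of the global test set, by index

    train_label_set = set(train_labels)
    local_test_indices = [idx for idx in range(len(test_labels)) if test_labels[idx] in train_label_set]
    print(local_test_indices)  # [0, 3, 5, 6] -> only samples of classes 0 and 3 are kept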

NebulaPartitionHandler

Bases: Dataset, ABC

A class to handle the loading of datasets from HDF5 files.

Source code in nebula/core/datasets/nebuladataset.py, lines 34-146
class NebulaPartitionHandler(Dataset, ABC):
    """
    A class to handle the loading of datasets from HDF5 files.
    """

    def __init__(
        self,
        file_path: str,
        prefix: str = "train",
        config: dict[str, Any] | None = None,
        empty: bool = False,
    ):
        self.file_path = file_path
        self.prefix = prefix
        self.config = config
        self.empty = empty
        self.transform = None
        self.target_transform = None
        self.file = None

        self.data = None
        self.targets = None
        self.num_classes = None
        self.length = None

        self.load_data()

    def load_data(self):
        if self.empty:
            logging_training.info(
                f"[NebulaPartitionHandler] No data loaded for {self.prefix} partition. Empty dataset."
            )
            return
        with h5py.File(self.file_path, "r") as f:
            prefix = (
                "test" if self.prefix == "local_test" else self.prefix
            )  # Local test uses the test prefix (same data but different split)
            self.data = self.load_partition(f, f"{prefix}_data")
            self.targets = np.array(f[f"{prefix}_targets"])
            self.num_classes = f[f"{prefix}_data"].attrs.get("num_classes", 0)
            self.length = len(self.data)
        logging_training.info(
            f"[NebulaPartitionHandler] [{self.prefix}] Loaded {self.length} samples from {self.file_path} and {self.num_classes} classes."
        )

    def close(self):
        if self.file is not None:
            self.file.close()
            self.file = None
            logging_training.info(f"[NebulaPartitionHandler] Closed file {self.file_path}")

    def __del__(self):
        self.close()

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        data = self.data[idx]
        target = self.targets[idx]
        return data, target

    def set_data(self, data, targets, data_opt=None, targets_opt=None):
        """
        Set the data and targets for the dataset.
        """
        try:
            # Input validation
            if data is None or targets is None:
                raise ValueError("Primary data and targets cannot be None")

            if len(data) != len(targets):
                raise ValueError(f"Data and targets length mismatch: {len(data)} vs {len(targets)}")

            if data_opt is None or targets_opt is None:
                self.data = data
                self.targets = targets
                self.length = len(data)
                logging_training.info(f"[NebulaPartitionHandler] Set data with {self.length} samples.")
                return

            if len(data_opt) != len(targets_opt):
                raise ValueError(f"Optional data and targets length mismatch: {len(data_opt)} vs {len(targets_opt)}")

            main_count = int(len(data) * 0.8)
            opt_count = min(len(data_opt), int(len(data) * (1 - 0.8)))
            if isinstance(data, np.ndarray):
                self.data = np.concatenate((data[:main_count], data_opt[:opt_count]))
            else:
                self.data = data[:main_count] + data_opt[:opt_count]

            if isinstance(targets, np.ndarray):
                self.targets = np.concatenate((targets[:main_count], targets_opt[:opt_count]))
            else:
                self.targets = targets[:main_count] + targets_opt[:opt_count]
            self.length = len(self.data)

        except Exception as e:
            logging_training.exception(f"Error setting data: {e}")

    def load_partition(self, file, name):
        item = file[name]
        if isinstance(item, h5py.Dataset):
            typ = item.attrs.get("__type__", None)
            if typ == "pickle":
                logging_training.info(f"Loading pickled object from {name}")
                return pickle.loads(item[()].tobytes())
            else:
                logging_training.warning(f"[NebulaPartitionHandler] Unknown type encountered: {typ} for item {name}")
                return item[()]
        else:
            logging_training.warning(f"[NebulaPartitionHandler] Unknown item encountered: {item} for item {name}")
            return item[()]

set_data(data, targets, data_opt=None, targets_opt=None)

Set the data and targets for the dataset.

Source code in nebula/core/datasets/nebuladataset.py, lines 96-132
def set_data(self, data, targets, data_opt=None, targets_opt=None):
    """
    Set the data and targets for the dataset.
    """
    try:
        # Input validation
        if data is None or targets is None:
            raise ValueError("Primary data and targets cannot be None")

        if len(data) != len(targets):
            raise ValueError(f"Data and targets length mismatch: {len(data)} vs {len(targets)}")

        if data_opt is None or targets_opt is None:
            self.data = data
            self.targets = targets
            self.length = len(data)
            logging_training.info(f"[NebulaPartitionHandler] Set data with {self.length} samples.")
            return

        if len(data_opt) != len(targets_opt):
            raise ValueError(f"Optional data and targets length mismatch: {len(data_opt)} vs {len(targets_opt)}")

        main_count = int(len(data) * 0.8)
        opt_count = min(len(data_opt), int(len(data) * (1 - 0.8)))
        if isinstance(data, np.ndarray):
            self.data = np.concatenate((data[:main_count], data_opt[:opt_count]))
        else:
            self.data = data[:main_count] + data_opt[:opt_count]

        if isinstance(targets, np.ndarray):
            self.targets = np.concatenate((targets[:main_count], targets_opt[:opt_count]))
        else:
            self.targets = targets[:main_count] + targets_opt[:opt_count]
        self.length = len(self.data)

    except Exception as e:
        logging_training.exception(f"Error setting data: {e}")
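
When optional data is supplied, roughly 80% of the primary samples are kept and the optional source fills up to the remaining 20%. A standalone numeric sketch of that arithmetic with list inputs (sizes are illustrative; note the int() truncation on the 20% share):

    data, targets = list(range(100)), [0] * 100                 # 100 primary samples
    data_opt, targets_opt = list(range(1000, 1030)), [1] * 30   # 30 optional samples

    main_count = int(len(data) * 0.8)                           # 80
    opt_count = min(len(data_opt), int(len(data) * (1 - 0.8)))  # 19: int() truncates 100 * (1 - 0.8) = 19.999...

    mixed_data = data[:main_count] + data_opt[:opt_count]
    mixed_targets = targets[:main_count] + targets_opt[:opt_count]
    print(len(mixed_data), len(mixed_targets))                  # 99 99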

wait_for_file(file_path)

Wait until the given file exists, polling repeatedly until it appears.

Source code in nebula/core/datasets/nebuladataset.py, lines 27-31
def wait_for_file(file_path):
    """Wait until the given file exists, polling every 'interval' seconds."""
    while not os.path.exists(file_path):
        logging_training.info(f"Waiting for file: {file_path}")
    return
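
The loop above polls without a pause; if a sleep between checks (and an optional timeout) is wanted, a small variant could look like this (a sketch, not part of the module):

    import os
    import time

    def wait_for_file_with_interval(file_path, interval=1.0, timeout=None):
        """Poll for a file every `interval` seconds, optionally giving up after `timeout` seconds."""
        start = time.monotonic()
        while not os.path.exists(file_path):
            if timeout is not None and time.monotonic() - start > timeout:
                raise TimeoutError(f"File {file_path} did not appear within {timeout} seconds")
            time.sleep(interval)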