nebula.core.models.cifar10.dualagg#

Attributes#

logging_training

Classes#

ContrastiveLoss

Contrastive loss function.

DualAggModel

Dual aggregation model for CIFAR-10.

Module Contents#

nebula.core.models.cifar10.dualagg.logging_training#
class nebula.core.models.cifar10.dualagg.ContrastiveLoss(mu=0.5)#

Bases: torch.nn.Module

Contrastive loss function.

mu#
cross_entropy_loss#
forward(local_out, global_out, historical_out, labels)#

Calculates the contrastive loss between the local output, global output, and historical output.

Parameters:
  • local_out (torch.Tensor) – The local output tensor of shape (batch_size, embedding_size).

  • global_out (torch.Tensor) – The global output tensor of shape (batch_size, embedding_size).

  • historical_out (torch.Tensor) – The historical output tensor of shape (batch_size, embedding_size).

  • labels (torch.Tensor) – The ground truth labels tensor of shape (batch_size,).

Returns:

The contrastive loss value.

Return type:

torch.Tensor

Raises:

ValueError – If the input tensors have different shapes.

Notes

  • The contrastive loss is calculated as the difference between the mean cosine similarity of the local output with the historical output and the mean cosine similarity of the local output with the global output, multiplied by a scaling factor mu.

  • The cosine similarity values represent the similarity between corresponding vectors in the input tensors: higher values indicate greater similarity, lower values indicate less similarity.
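
A minimal sketch of this loss following the notes above. It assumes the stored cross_entropy_loss term is added to the scaled similarity difference and that local_out carries class logits; both are assumptions, and the source may combine the terms differently:

```python
import torch.nn.functional as F
from torch import nn


class ContrastiveLossSketch(nn.Module):
    """Illustrative re-implementation of the loss described above."""

    def __init__(self, mu=0.5):
        super().__init__()
        self.mu = mu
        self.cross_entropy_loss = nn.CrossEntropyLoss()

    def forward(self, local_out, global_out, historical_out, labels):
        if not (local_out.shape == global_out.shape == historical_out.shape):
            raise ValueError("Input tensors must have the same shape.")
        # Mean cosine similarity between corresponding rows of each pair.
        sim_hist = F.cosine_similarity(local_out, historical_out, dim=1).mean()
        sim_glob = F.cosine_similarity(local_out, global_out, dim=1).mean()
        # Difference of the two mean similarities, scaled by mu (as in the notes).
        contrastive = self.mu * (sim_hist - sim_glob)
        # Assumption: local_out are class logits, so cross-entropy applies directly.
        return self.cross_entropy_loss(local_out, labels) + contrastive
```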

class nebula.core.models.cifar10.dualagg.DualAggModel(input_channels=3, num_classes=10, learning_rate=0.001, mu=0.5, metrics=None, confusion_matrix=None, seed=None)#

Bases: lightning.LightningModule

process_metrics(phase, y_pred, y, loss=None, mode='local')#

Calculate and log metrics for the given phase.

Parameters:
  • phase (str) – One of ‘Train’, ‘Validation’, or ‘Test’.

  • y_pred (torch.Tensor) – Model predictions.

  • y (torch.Tensor) – Ground truth labels.

  • loss (torch.Tensor, optional) – Loss value.

  • mode (str) – Which model’s metrics to update: ‘local’, ‘historical’, or ‘global’.
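
A hedged usage sketch from inside a Lightning step; local_out, global_out, historical_out, and y are hypothetical names for the batch outputs and labels, and only the process_metrics signature is from the source:

```python
# Illustrative call site inside a training/validation/test step.
loss = self.criterion(local_out, global_out, historical_out, y)
y_pred = local_out.argmax(dim=1)
self.process_metrics("Train", y_pred, y, loss, mode="local")
```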

log_metrics_by_epoch(phase, print_cm=False, plot_cm=False, mode='local')#

Log all metrics at the end of an epoch for the given phase.

Parameters:
  • phase (str) – One of ‘Train’, ‘Validation’, or ‘Test’.

  • print_cm (bool) – Whether to print the confusion matrix.

  • plot_cm (bool) – Whether to plot the confusion matrix.

  • mode (str) – Which model’s metrics to log: ‘local’, ‘historical’, or ‘global’.
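
One way the epoch-end hooks below might invoke it, sketched under the assumption that each of the three model heads is logged in turn:

```python
# Illustrative epoch-end hook; the real hook may differ.
def on_train_epoch_end(self):
    for mode in ("local", "historical", "global"):
        self.log_metrics_by_epoch("Train", print_cm=True, plot_cm=False, mode=mode)
```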

input_channels#
num_classes#
learning_rate#
mu#
local_train_metrics#
local_val_metrics#
local_test_metrics#
historical_train_metrics#
historical_val_metrics#
historical_test_metrics#
global_train_metrics#
global_val_metrics#
global_test_metrics#
local_epoch_global_number#
historical_epoch_global_number#
global_epoch_global_number#
config#
example_input_array#
criterion#
model#
historical_model#
global_model#
forward(x, mode='local')#

Forward pass of the model.
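
The mode argument presumably dispatches to one of the three submodels (model, historical_model, global_model); a minimal sketch of that dispatch:

```python
# Illustrative dispatch; the actual forward may differ in detail.
def forward(self, x, mode="local"):
    if mode == "local":
        return self.model(x)
    if mode == "historical":
        return self.historical_model(x)
    if mode == "global":
        return self.global_model(x)
    raise ValueError(f"Unknown mode: {mode}")
```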

configure_optimizers()#
step(batch, batch_idx, phase)#
training_step(batch, batch_id)#

Training step for the model.

Parameters:
  • batch – The current training batch.

  • batch_id – Index of the batch within the epoch.

Returns:

The training loss for the batch.

on_train_epoch_end()#
validation_step(batch, batch_idx)#

Validation step for the model.

Parameters:
  • batch – The current validation batch.

  • batch_idx – Index of the batch within the epoch.

Returns:

The validation loss for the batch.

on_validation_epoch_end()#
test_step(batch, batch_idx)#

Test step for the model.

Parameters:
  • batch – The current test batch.

  • batch_idx – Index of the batch within the epoch.

Returns:

The test loss for the batch.

on_test_epoch_end()#
save_historical_model()#

Save the current local model as the historical model.
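
A minimal sketch, assuming the local and historical submodels share an identical architecture so the state dict transfers directly:

```python
# Illustrative snapshot of the local weights into the historical model;
# load_state_dict copies the tensor values, so no explicit deepcopy is needed.
def save_historical_model(self):
    self.historical_model.load_state_dict(self.model.state_dict())
```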

global_load_state_dict(state_dict)#

Load the given state dictionary into the global model.

Parameters:
  • state_dict (dict) – The state dictionary to load into the global model.

historical_load_state_dict(state_dict)#

Load the given state dictionary into the historical model.

Parameters:
  • state_dict (dict) – The state dictionary to load into the historical model.
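
A hedged sketch of how both loaders could reuse adapt_state_dict_for_model (documented below); the ‘model’ prefix passed here is an assumption about the incoming key layout:

```python
# Illustrative; the real methods may remap or filter keys differently.
def global_load_state_dict(self, state_dict):
    adapted = self.adapt_state_dict_for_model(state_dict, "model")
    self.global_model.load_state_dict(adapted)


def historical_load_state_dict(self, state_dict):
    adapted = self.adapt_state_dict_for_model(state_dict, "model")
    self.historical_model.load_state_dict(adapted)
```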

adapt_state_dict_for_model(state_dict, model_prefix)#

Adapt the keys in the provided state_dict to match the structure expected by the model.
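
A minimal sketch of such adaptation, assuming the adjustment consists of stripping a model_prefix from each key:

```python
# Illustrative key remapping; the exact prefix convention is an assumption.
def adapt_state_dict_for_model(self, state_dict, model_prefix):
    prefix = f"{model_prefix}."
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }
```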

get_global_model_parameters()#

Get the parameters of the global model.

print_summary()#

Print a summary of the local, historical, and global models to check whether their parameters are identical.
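
One way to make that comparison, sketched parameter by parameter (not the source implementation):

```python
import torch


# Illustrative comparison of the three submodels' parameters.
def print_summary(self):
    for (name, p_local), p_hist, p_glob in zip(
        self.model.named_parameters(),
        self.historical_model.parameters(),
        self.global_model.parameters(),
    ):
        print(
            f"{name}: local==historical {torch.equal(p_local, p_hist)}, "
            f"local==global {torch.equal(p_local, p_glob)}"
        )
```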