nebula.core.models.nebulamodel
Attributes
- logging_training
Classes
- NebulaModel: Abstract class for the NEBULA model.
- NebulaModelStandalone: Abstract class for the NEBULA model.
Module Contents
- nebula.core.models.nebulamodel.logging_training
- class nebula.core.models.nebulamodel.NebulaModel(input_channels=1, num_classes=10, learning_rate=0.001, metrics=None, confusion_matrix=None, seed=None)
Bases: lightning.LightningModule, abc.ABC
Abstract base class that defines the interface shared by all NEBULA models. Subclasses must implement forward() and configure_optimizers().
- process_metrics(phase, y_pred, y, loss=None)
Calculate and log metrics for the given phase. Metrics are computed on every batch.
Parameters:
  - phase (str): One of ‘Train’, ‘Validation’, or ‘Test’.
  - y_pred (torch.Tensor): Model predictions.
  - y (torch.Tensor): Ground-truth labels.
  - loss (torch.Tensor, optional): Loss value.
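To make the per-batch behaviour concrete, here is a minimal standard-library sketch. It assumes predictions arrive as per-class scores and tracks accuracy only; BatchMetricTracker and PhaseState are hypothetical names, and the real class presumably keeps its metric state in the train_metrics, val_metrics, and test_metrics attributes rather than in plain counters.

```python
from dataclasses import dataclass

# Phases accepted by process_metrics, per the docstring above
VALID_PHASES = ("Train", "Validation", "Test")


@dataclass
class PhaseState:
    correct: int = 0
    total: int = 0
    loss_sum: float = 0.0
    loss_batches: int = 0


class BatchMetricTracker:
    """Hypothetical stand-in for the per-phase metric state of NebulaModel."""

    def __init__(self):
        self.state = {phase: PhaseState() for phase in VALID_PHASES}

    def process_metrics(self, phase, y_pred, y, loss=None):
        if phase not in self.state:
            raise ValueError(f"Unknown phase: {phase!r}")
        s = self.state[phase]
        # Each entry of y_pred is a list of class scores; the prediction is its argmax
        predicted = [max(range(len(scores)), key=scores.__getitem__) for scores in y_pred]
        batch_correct = sum(int(p == t) for p, t in zip(predicted, y))
        s.correct += batch_correct
        s.total += len(y)
        if loss is not None:
            s.loss_sum += float(loss)
            s.loss_batches += 1
        # Batch-level accuracy, returned so a caller could log it right away
        return batch_correct / len(y)
```

Each call both returns the batch value and updates running totals, so end-of-phase metrics can later be computed over all samples rather than by averaging batch results.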
- log_metrics_end(phase)
Log the metrics for the given phase.
Parameters:
  - phase (str): One of ‘Train’, ‘Validation’, ‘Test (Local)’, or ‘Test (Global)’.
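A minimal sketch of an end-of-phase logging step, assuming per-batch counts are accumulated first and then logged and reset once per phase. EpochMetricLogger, add_batch, and the "<phase>/Accuracy" key format are hypothetical; only the phase names come from the docstring.

```python
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("nebula_metrics_sketch")

# Phase names accepted by log_metrics_end, per the docstring above
PHASES = ("Train", "Validation", "Test (Local)", "Test (Global)")


class EpochMetricLogger:
    """Hypothetical stand-in: aggregates batch counts, logs once per phase, then resets."""

    def __init__(self):
        self.counts = {phase: {"correct": 0, "total": 0} for phase in PHASES}

    def add_batch(self, phase, correct, total):
        # Called once per batch, e.g. from a process_metrics-like method
        self.counts[phase]["correct"] += correct
        self.counts[phase]["total"] += total

    def log_metrics_end(self, phase):
        c = self.counts[phase]
        accuracy = c["correct"] / c["total"] if c["total"] else 0.0
        logger.info("%s accuracy: %.4f", phase, accuracy)
        # Reset so the next epoch of this phase starts from zero
        self.counts[phase] = {"correct": 0, "total": 0}
        # The "<phase>/Accuracy" key format is a hypothetical naming choice
        return {f"{phase}/Accuracy": accuracy}
```

Summing counts and dividing once avoids averaging per-batch accuracies, which would over-weight a small final batch.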
- generate_confusion_matrix(phase, print_cm=False, plot_cm=False)
Generate the confusion matrix for the given phase and optionally print or plot it.
Parameters:
  - phase (str): One of ‘Train’, ‘Validation’, ‘Test (Local)’, or ‘Test (Global)’.
  - print_cm (bool, optional): Print the confusion matrix. Defaults to False.
  - plot_cm (bool, optional): Plot the confusion matrix. Defaults to False.
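For reference, the sketch below shows what a confusion matrix counts, using the common convention of rows for true classes and columns for predicted classes. build_confusion_matrix is a hypothetical helper: it takes labels explicitly instead of a phase and omits plotting, whereas the documented method presumably works from predictions collected during the phase.

```python
def build_confusion_matrix(y_true, y_pred, num_classes, print_cm=False):
    """Hypothetical helper: rows are true classes, columns are predicted classes."""
    cm = [[0] * num_classes for _ in range(num_classes)]
    for true_label, pred_label in zip(y_true, y_pred):
        cm[true_label][pred_label] += 1
    if print_cm:
        # Plain-text rendering, analogous to the print_cm option
        for row in cm:
            print(" ".join(f"{count:3d}" for count in row))
    return cm


cm = build_confusion_matrix([0, 1, 2, 2, 1], [0, 2, 2, 2, 1], num_classes=3, print_cm=True)
```

Correct predictions accumulate on the diagonal; here the one sample of class 1 predicted as class 2 appears at row 1, column 2.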
- input_channels
- num_classes
- learning_rate
- train_metrics
- val_metrics
- test_metrics
- test_metrics_global
- global_number
- abstract forward(x)
Forward pass of the model.
- abstract configure_optimizers()
Return the optimizer configuration used for training.
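The abstract-method contract can be illustrated without PyTorch. This sketch uses a hypothetical stdlib stand-in, NebulaModelInterface, with the same two abstract methods and a subset of the constructor arguments; ConstantModel is a toy subclass, and its dictionary "optimizer" is a placeholder for what a real subclass would return, such as a torch.optim optimizer.

```python
import abc


class NebulaModelInterface(abc.ABC):
    """Hypothetical stand-in mirroring NebulaModel's two abstract methods."""

    def __init__(self, input_channels=1, num_classes=10, learning_rate=0.001):
        self.input_channels = input_channels
        self.num_classes = num_classes
        self.learning_rate = learning_rate

    @abc.abstractmethod
    def forward(self, x):
        """Return model outputs for input x."""

    @abc.abstractmethod
    def configure_optimizers(self):
        """Return the optimizer configuration."""


class ConstantModel(NebulaModelInterface):
    def forward(self, x):
        # Score every input with the same uniform distribution over classes
        return [[1.0 / self.num_classes] * self.num_classes for _ in x]

    def configure_optimizers(self):
        # Placeholder; a real subclass might return
        # torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        return {"optimizer": "placeholder", "lr": self.learning_rate}
```

Because both methods are declared with abc.abstractmethod, instantiating the base class, or a subclass that implements only one of them, raises TypeError.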
- step(batch, batch_idx, phase)
Common step shared by the training, validation, and test steps. The phase argument identifies which phase the batch belongs to.
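The signatures suggest that the phase-specific hooks delegate to step with a fixed phase string. The sketch below assumes exactly that; StepSketch and its toy mean-absolute-error loss are hypothetical and stand in for the model's real loss computation.

```python
class StepSketch:
    """Hypothetical stand-in: the per-phase step hooks delegate to one step()."""

    def __init__(self):
        self.seen = []  # (phase, batch_idx) pairs, recorded for illustration

    def step(self, batch, batch_idx, phase):
        x, y = batch
        # Toy loss: mean absolute error between inputs and targets
        loss = sum(abs(a - b) for a, b in zip(x, y)) / len(y)
        self.seen.append((phase, batch_idx))
        return loss

    def training_step(self, batch, batch_idx):
        return self.step(batch, batch_idx, phase="Train")

    def validation_step(self, batch, batch_idx):
        return self.step(batch, batch_idx, phase="Validation")
```

Keeping the loss and metric logic in one method means the phases can differ only in the phase label passed along.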
- training_step(batch, batch_idx)
Training step for the model.
Parameters:
  - batch: The current batch of training data.
  - batch_idx (int): Index of the batch within the epoch.
- on_train_start()
- on_train_end()
- on_train_epoch_end()
- validation_step(batch, batch_idx)
Validation step for the model.
Parameters:
  - batch: The current batch of validation data.
  - batch_idx (int): Index of the batch.
- on_validation_end()
- on_validation_epoch_end()
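- test_step(batch, batch_idx, dataloader_idx=None)
Test step for the model.
Parameters:
  - batch: The current batch of test data.
  - batch_idx (int): Index of the batch.
  - dataloader_idx (int, optional): Index of the test dataloader that produced the batch, used when more than one test dataloader is configured.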
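The phase names ‘Test (Local)’ and ‘Test (Global)’, together with the separate test_metrics and test_metrics_global attributes, suggest two test dataloaders. This sketch assumes dataloader 0 is the local test set and dataloader 1 the global one; that assignment, and the resolve_test_phase helper, are hypothetical.

```python
# Hypothetical mapping; the index-to-phase assignment is an assumption
TEST_PHASES = {0: "Test (Local)", 1: "Test (Global)"}


def resolve_test_phase(dataloader_idx):
    # Lightning passes dataloader_idx only when several test dataloaders are used;
    # here a missing index is treated as the local test set (assumption).
    if dataloader_idx is None:
        return TEST_PHASES[0]
    try:
        return TEST_PHASES[dataloader_idx]
    except KeyError:
        raise ValueError(f"Unexpected dataloader_idx: {dataloader_idx}") from None
```

A test_step built on this would pass the resolved name to step so that local and global results are tracked separately.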
- on_test_start()
- on_test_end()
- on_test_epoch_end()
- class nebula.core.models.nebulamodel.NebulaModelStandalone(*args, **kwargs)
Bases: NebulaModel
Abstract class for the NEBULA model. It inherits the NebulaModel interface and overrides the end-of-epoch and end-of-phase hooks listed below; on_train_start and on_test_start are inherited unchanged.
- on_train_end()
- on_train_epoch_end()
- on_validation_end()
- on_validation_epoch_end()
- on_test_end()
- on_test_epoch_end()
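The hooks above are invoked by the Lightning Trainer rather than called directly. This sketch uses a simplified, hypothetical fit loop that calls the training hooks in the same relative order as Lightning (on_train_start, then on_train_epoch_end after each epoch, then on_train_end). BaseModelSketch and StandaloneSketch are hypothetical; they show only that a subclass overriding the end hooks, as NebulaModelStandalone does, replaces the parent behaviour while inheriting on_train_start. What the real overrides do is not documented here.

```python
class BaseModelSketch:
    """Hypothetical stand-in for NebulaModel's training hooks."""

    def __init__(self):
        self.calls = []

    def on_train_start(self):
        self.calls.append("base:on_train_start")

    def on_train_epoch_end(self):
        self.calls.append("base:on_train_epoch_end")

    def on_train_end(self):
        self.calls.append("base:on_train_end")


class StandaloneSketch(BaseModelSketch):
    # Overrides only the end hooks; on_train_start is inherited
    def on_train_epoch_end(self):
        self.calls.append("standalone:on_train_epoch_end")

    def on_train_end(self):
        self.calls.append("standalone:on_train_end")


def fit(model, epochs, batches_per_epoch):
    """Simplified stand-in for the order in which a Trainer calls these hooks."""
    model.on_train_start()
    for _ in range(epochs):
        for _ in range(batches_per_epoch):
            pass  # training_step would run here
        model.on_train_epoch_end()
    model.on_train_end()
    return model.calls
```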