nebula.core.models.cifar10.resnet

Attributes

IMAGE_SIZE
BATCH_SIZE
classifiers

Classes

CIFAR10ModelResNet([input_channels, num_classes, ...])

Functions

conv_block(input_channels, num_classes[, pool])

Module Contents

nebula.core.models.cifar10.resnet.IMAGE_SIZE = 32
nebula.core.models.cifar10.resnet.BATCH_SIZE
nebula.core.models.cifar10.resnet.classifiers
nebula.core.models.cifar10.resnet.conv_block(input_channels, num_classes, pool=False)
class nebula.core.models.cifar10.resnet.CIFAR10ModelResNet(input_channels=3, num_classes=10, learning_rate=0.001, metrics=None, confusion_matrix=None, seed=None, implementation='scratch', classifier='resnet9')

Bases: lightning.LightningModule

process_metrics(phase, y_pred, y, loss=None)
log_metrics_by_epoch(phase, print_cm=False, plot_cm=False)
train_metrics
val_metrics
test_metrics
implementation
classifier
example_input_array
learning_rate
criterion
model
epoch_global_number
forward(x)
configure_optimizers()
step(batch, batch_idx, phase)
training_step(batch, batch_id)
on_train_epoch_end()
validation_step(batch, batch_idx)
on_validation_epoch_end()
test_step(batch, batch_idx)
on_test_epoch_end()