nebula.core.models.cifar10.resnet
Attributes
Classes
Functions

Module Contents
- nebula.core.models.cifar10.resnet.IMAGE_SIZE = 32
- nebula.core.models.cifar10.resnet.BATCH_SIZE
- nebula.core.models.cifar10.resnet.classifiers
- nebula.core.models.cifar10.resnet.conv_block(input_channels, num_classes, pool=False)
- class nebula.core.models.cifar10.resnet.CIFAR10ModelResNet(input_channels=3, num_classes=10, learning_rate=0.001, metrics=None, confusion_matrix=None, seed=None, implementation='scratch', classifier='resnet9')
Bases: lightning.LightningModule
- process_metrics(phase, y_pred, y, loss=None)
- log_metrics_by_epoch(phase, print_cm=False, plot_cm=False)
- train_metrics
- val_metrics
- test_metrics
- implementation
- classifier
- example_input_array
- learning_rate
- criterion
- model
- epoch_global_number
- forward(x)
- configure_optimizers()
- step(batch, batch_idx, phase)
- training_step(batch, batch_id)
- on_train_epoch_end()
- validation_step(batch, batch_idx)
- on_validation_epoch_end()
- test_step(batch, batch_idx)
- on_test_epoch_end()