nebula.core.training.lightning
Attributes
- logging_training

Classes
- NebulaProgressBar: Nebula progress bar for training.
- Lightning
Module Contents
- nebula.core.training.lightning.logging_training
- class nebula.core.training.lightning.NebulaProgressBar
Bases: lightning.pytorch.callbacks.ProgressBar
Nebula progress bar for training. Logs the percentage of completion of the training process using logging.
- enable = True
- disable()
Disable the progress bar logging.
- on_train_epoch_start(trainer, pl_module)
Called when the training epoch starts.
- on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx)
Called at the end of each training batch.
- on_train_epoch_end(trainer, pl_module)
Called at the end of the training epoch.
- on_validation_epoch_start(trainer, pl_module)
Called when the validation epoch starts.
- on_validation_epoch_end(trainer, pl_module)
Called at the end of the validation epoch.
- on_test_batch_start(trainer, pl_module, batch, batch_idx, dataloader_idx)
Called at the start of each test batch.
- on_test_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
Called at the end of each test batch.
- on_test_epoch_start(trainer, pl_module)
Called when the test epoch starts.
- on_test_epoch_end(trainer, pl_module)
Called at the end of the test epoch.
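The idea behind NebulaProgressBar (log the percentage of training completed rather than draw a bar) can be sketched without Lightning installed. This is a minimal, hypothetical sketch, not the class's actual implementation: the class name LoggingProgressBarSketch, the 10 % logging step, the `logged` list, and the logger name are all assumptions. The hook names and signatures mirror the ones listed above, and `trainer.num_training_batches` is the Lightning Trainer attribute giving the number of training batches per epoch.

```python
import logging
from types import SimpleNamespace

# Hypothetical logger name; the real module exposes its own logging_training attribute.
logger = logging.getLogger("nebula.training.sketch")


class LoggingProgressBarSketch:
    """Framework-free sketch of a progress bar that logs completion percentages.

    Hook names and signatures mirror NebulaProgressBar; the base class
    (lightning.pytorch.callbacks.ProgressBar) is omitted so the sketch runs
    without Lightning installed.
    """

    def __init__(self, step_percent: int = 10):
        self.enable = True
        self.step_percent = step_percent  # assumption: log once per 10 % step
        self.logged: list[int] = []  # percentages logged in the current epoch
        self._last_bucket = 0

    def disable(self):
        """Stop emitting log records."""
        self.enable = False

    def on_train_epoch_start(self, trainer, pl_module):
        # Reset per-epoch state so every epoch logs its own progress.
        self.logged = []
        self._last_bucket = 0

    def percent_complete(self, trainer, batch_idx: int) -> int:
        # batch_idx is 0-based, so batch_idx + 1 batches have finished.
        total = trainer.num_training_batches
        return int(100 * (batch_idx + 1) / total) if total else 100

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        if not self.enable:
            return
        percent = self.percent_complete(trainer, batch_idx)
        bucket = percent // self.step_percent
        # Log only when a new step_percent boundary has been crossed.
        if bucket > self._last_bucket:
            self._last_bucket = bucket
            self.logged.append(percent)
            logger.info("Training progress: %d%%", percent)


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    bar = LoggingProgressBarSketch()
    fake_trainer = SimpleNamespace(num_training_batches=20)  # stand-in for a Trainer
    bar.on_train_epoch_start(fake_trainer, None)
    for idx in range(20):
        bar.on_train_batch_end(fake_trainer, None, None, None, idx)
```

With 20 batches per epoch, the sketch logs at 10 %, 20 %, ..., 100 %. Logging only at step boundaries keeps the log readable for long epochs, at the cost of coarser progress reporting.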
- class nebula.core.training.lightning.Lightning(model, data, config=None, logger=None)
- DEFAULT_MODEL_WEIGHT = 1
- BYPASS_MODEL_WEIGHT = 0
- model
- data
- config
- epochs = 1
- round = 0
- property logger
- get_round()
- set_model(model)
- set_data(data)
- create_trainer()
- validate_neighbour_model(neighbour_model_param)
- get_hash_model()
- Returns:
SHA256 hash of model parameters
- Return type:
str
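get_hash_model returns the SHA256 hash of the model parameters as a string. A minimal sketch of that idea follows, assuming the parameters are available as a mapping from parameter name to a flat sequence of floats. The real method presumably hashes PyTorch tensors, and its exact byte encoding is not documented here; hash_parameters and its encoding (sorted names, length prefixes, little-endian float32) are hypothetical choices.

```python
import hashlib
import struct
from typing import Mapping, Sequence


def hash_parameters(params: Mapping[str, Sequence[float]]) -> str:
    """Return the SHA256 hex digest of named parameters (hypothetical helper).

    Names are visited in sorted order so the digest does not depend on the
    mapping's insertion order. Each entry contributes its UTF-8 name, its
    element count, and its values packed as little-endian float32.
    """
    digest = hashlib.sha256()
    for name in sorted(params):
        values = params[name]
        encoded_name = name.encode("utf-8")
        # Length prefixes keep the boundaries between entries unambiguous.
        digest.update(struct.pack("<I", len(encoded_name)))
        digest.update(encoded_name)
        digest.update(struct.pack("<I", len(values)))
        digest.update(struct.pack(f"<{len(values)}f", *values))
    return digest.hexdigest()
```

The result is a 64-character hexadecimal string. Such a digest is a cheap way to check whether two parameter sets are identical without comparing every value.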
- set_epochs(epochs)
- serialize_model(model)
- deserialize_model(data)
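serialize_model and deserialize_model convert a model to bytes and back so it can be stored or transmitted. The byte format is not documented here, so the following is only a sketch under an assumption: a gzip-compressed pickle of a plain parameter dictionary. The real methods may use a different format (for example torch.save); serialize_params and deserialize_params are hypothetical names.

```python
import gzip
import pickle
from typing import Any


def serialize_params(params: Any) -> bytes:
    """Hypothetical stand-in for serialize_model: pickle, then gzip-compress."""
    return gzip.compress(pickle.dumps(params, protocol=pickle.HIGHEST_PROTOCOL))


def deserialize_params(data: bytes) -> Any:
    """Hypothetical stand-in for deserialize_model: the inverse of serialize_params.

    Only call this on data from trusted sources: unpickling can execute
    arbitrary code.
    """
    return pickle.loads(gzip.decompress(data))
```

A round trip returns an equal object, for example deserialize_params(serialize_params({"w": [1.0]})) gives back {"w": [1.0]}. Compression trades some CPU time for smaller payloads.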
- set_model_parameters(params, initialize=False)
- get_model_parameters(bytes=False)
- async train()
- async test()
- get_model_weight()
- on_round_start()
- on_round_end()
- on_learning_cycle_end()