The abstract elbo.ElboModel

An ElboModel is an abstract class that allows the ELBO service to automatically checkpoint your training.

Extending your model class

Extend the abstract ElboModel class along with nn.Module in your PyTorch model class. You are then required to implement two methods:

  • save_state - This method should save the state of the model, along with any other state information needed to resume training.

  • load_state - This method should load the state of the model from the input directory passed in as state_dir.
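To illustrate what "required to implement" means, here is a minimal sketch of an abstract base class built with Python's `abc` module. `ElboModelSketch`, `Incomplete`, and `Complete` are hypothetical names: this is not the source of `elbo.ElboModel`, the real class may declare additional methods, and the signatures are assumed from the MNISTClassifier example on this page. A subclass that leaves either method unimplemented cannot be instantiated.

```python
from abc import ABC, abstractmethod


class ElboModelSketch(ABC):
    """Hypothetical stand-in for elbo.ElboModel, NOT the library's source."""

    @abstractmethod
    def save_state(self) -> None:
        """Persist everything needed to resume training."""

    @abstractmethod
    def load_state(self, state_dir: str) -> None:
        """Restore state previously written by save_state."""


class Incomplete(ElboModelSketch):
    # Implements save_state only, so the class is still abstract.
    def save_state(self) -> None:
        pass


class Complete(ElboModelSketch):
    def __init__(self) -> None:
        self.saved = False
        self.loaded_from = None

    def save_state(self) -> None:
        self.saved = True

    def load_state(self, state_dir: str) -> None:
        self.loaded_from = state_dir


try:
    Incomplete()  # raises TypeError: load_state is still abstract
except TypeError as err:
    print(f"Cannot instantiate: {err}")

model = Complete()
model.save_state()
model.load_state("artifacts")
```

Because `abc.ABCMeta` derives from `type`, a class like this can be combined with `nn.Module` through multiple inheritance, which is the pattern used in the example that follows.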

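```python
import os

import torch
import torch.nn as nn
from elbo import ElboModel  # assumed import path; may differ by version


class MNISTClassifier(ElboModel, nn.Module):
    # __init__() and forward() omitted; define your layers as usual.

    def get_artifacts_directory(self):
        return 'artifacts'

    def save_state(self):
        # Write the weights into the artifacts directory, creating it if needed.
        artifacts_dir = self.get_artifacts_directory()
        os.makedirs(artifacts_dir, exist_ok=True)
        model_path = os.path.join(artifacts_dir, "mnist_model")
        print(f"Saving model to {model_path}")
        torch.save(self.state_dict(), model_path)

    def load_state(self, state_dir):
        # Read the weights back from the directory passed in by the caller.
        model_path = os.path.join(state_dir, "mnist_model")
        print(f"Loading model from {model_path}")
        self.load_state_dict(torch.load(model_path))
```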

Good to know: The training loop calls these methods at periodic intervals to checkpoint the state of training. Make sure that everything needed for training to resume from a previous checkpoint (for example, optimizer state as well as model weights) is saved and loaded through these methods.
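As a sketch of state that may need checkpointing beyond the model weights, the hypothetical `TinyClassifier` below saves its optimizer state and an epoch counter alongside `state_dict()` in a single file, then restores them into a fresh instance. It does not inherit from `ElboModel`, so it runs without the ELBO package installed; in a real project you would extend `ElboModel` as shown above. The file name `checkpoint.pt` and the simulated restart are assumptions for illustration, and ELBO's own training loop is not shown.

```python
import os
import tempfile

import torch
import torch.nn as nn


class TinyClassifier(nn.Module):
    """Hypothetical model; in practice it would also extend ElboModel."""

    def __init__(self, artifacts_dir: str) -> None:
        super().__init__()
        self.artifacts_dir = artifacts_dir
        self.fc = nn.Linear(4, 2)
        # Extra training state that must survive a restart.
        self.optimizer = torch.optim.SGD(self.parameters(), lr=0.1, momentum=0.9)
        self.epoch = 0

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc(x)

    def save_state(self) -> None:
        # Bundle weights, optimizer state and epoch into one checkpoint file.
        os.makedirs(self.artifacts_dir, exist_ok=True)
        checkpoint = {
            "model": self.state_dict(),
            "optimizer": self.optimizer.state_dict(),
            "epoch": self.epoch,
        }
        torch.save(checkpoint, os.path.join(self.artifacts_dir, "checkpoint.pt"))

    def load_state(self, state_dir: str) -> None:
        # Restore every piece of state written by save_state.
        checkpoint = torch.load(os.path.join(state_dir, "checkpoint.pt"))
        self.load_state_dict(checkpoint["model"])
        self.optimizer.load_state_dict(checkpoint["optimizer"])
        self.epoch = checkpoint["epoch"]


with tempfile.TemporaryDirectory() as tmp:
    model = TinyClassifier(tmp)
    # One training step so the optimizer has momentum buffers to save.
    loss = model(torch.randn(8, 4)).sum()
    loss.backward()
    model.optimizer.step()
    model.epoch = 3
    model.save_state()

    # Simulate a restart: a fresh instance resumes from the checkpoint.
    restored = TinyClassifier(tmp)
    restored.load_state(tmp)
    print(restored.epoch)  # 3
```

Writing everything into one file keeps a checkpoint consistent: the weights, optimizer momentum, and epoch counter are always saved and restored together.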


Last updated 3 years ago