The abstract elbo.ElboModel
An ElboModel is an abstract class that allows the ELBO service to automatically checkpoint your training.
Extending your model class
Extend the abstract ElboModel along with nn.Module in your PyTorch model class. You are then required to implement two methods:
save_state - This method should save the state of the model and any other state information needed.
load_state - This method should load the state of the model from the input directory.
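To illustrate what "required to implement" means for an abstract class, here is a minimal sketch using Python's standard abc module. CheckpointableModel, IncompleteModel and CompleteModel are hypothetical stand-ins invented for this illustration; the real ElboModel may enforce its contract differently.

```python
from abc import ABC, abstractmethod


class CheckpointableModel(ABC):
    """Hypothetical stand-in for ElboModel (assumption, not the real class)."""

    @abstractmethod
    def save_state(self):
        """Persist the model state."""

    @abstractmethod
    def load_state(self, state_dir):
        """Restore the model state from state_dir."""


class IncompleteModel(CheckpointableModel):
    # load_state is missing, so this class is still abstract.
    def save_state(self):
        pass


class CompleteModel(CheckpointableModel):
    # Both abstract methods are overridden, so instances can be created.
    def save_state(self):
        pass

    def load_state(self, state_dir):
        pass


def can_instantiate(cls):
    """Return True if cls() succeeds, False if Python rejects it as abstract."""
    try:
        cls()
        return True
    except TypeError:
        return False
```

With this stand-in, can_instantiate(IncompleteModel) is False and can_instantiate(CompleteModel) is True, because Python refuses to instantiate a class that leaves an abstract method unimplemented.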
import os

import torch
from torch import nn

# ElboModel is provided by the elbo package.
class MNISTClassifier(ElboModel, nn.Module):
    def get_artifacts_directory(self):
        return 'artifacts'

    def save_state(self):
        model_path = os.path.join(self.get_artifacts_directory(), "mnist_model")
        torch.save(self.state_dict(), model_path)
        print(f"Saving model to {model_path}")

    def load_state(self, state_dir):
        model_path = os.path.join(self.get_artifacts_directory(), "mnist_model")
        print(f"Loading model from {model_path}")
        self.load_state_dict(torch.load(model_path))
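The save and load steps can be checked locally with a round trip. The sketch below is an illustration under stated assumptions: TinyClassifier is a hypothetical model that extends only nn.Module (not the real ElboModel), it creates the artifacts directory before saving, and its load_state reads from the state_dir argument, on the assumption that this is the directory the description of load_state refers to.

```python
import os
import tempfile

import torch
from torch import nn


class TinyClassifier(nn.Module):
    """Hypothetical stand-in model; it does not extend the real ElboModel."""

    def __init__(self, artifacts_dir):
        super().__init__()
        self.artifacts_dir = artifacts_dir
        self.fc = nn.Linear(4, 2)

    def get_artifacts_directory(self):
        return self.artifacts_dir

    def save_state(self):
        # Create the directory first so torch.save does not fail on a missing path.
        os.makedirs(self.get_artifacts_directory(), exist_ok=True)
        model_path = os.path.join(self.get_artifacts_directory(), "tiny_model")
        torch.save(self.state_dict(), model_path)
        return model_path

    def load_state(self, state_dir):
        # Assumption: state_dir is the directory that holds the saved file.
        model_path = os.path.join(state_dir, "tiny_model")
        self.load_state_dict(torch.load(model_path, map_location="cpu"))


def round_trip():
    """Save one model, load into a fresh one, and compare all parameters."""
    with tempfile.TemporaryDirectory() as tmp:
        original = TinyClassifier(tmp)
        original.save_state()
        restored = TinyClassifier(tmp)
        restored.load_state(tmp)
        return all(
            torch.equal(a, b)
            for a, b in zip(
                original.state_dict().values(),
                restored.state_dict().values(),
            )
        )
```

Calling round_trip() returns True when the restored parameters match the saved ones exactly. Passing map_location="cpu" is a design choice for this sketch so that a checkpoint written on a GPU machine can still be loaded on a CPU-only one.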