The elbo.ElboEpochIterator
ElboEpochIterator is an iterator class that wraps a standard Python iterator, such as the range of epochs. Use it in place of that iterator in your training epoch loop.
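
The wrapping idea can be illustrated with a minimal sketch. This is not elbo's implementation: the class name WrappingEpochIterator and the on_epoch_end callback are hypothetical. A for loop only requests the next epoch after the loop body has finished, so __next__ is a natural place to run end-of-epoch work such as saving state.

```python
from typing import Callable, Iterable, Iterator, Optional


class WrappingEpochIterator:
    """Hypothetical sketch (not elbo's code): wraps any iterable of epoch
    indices and runs a callback after the loop body finishes each epoch."""

    def __init__(self, epochs: Iterable[int], on_epoch_end: Callable[[int], None]):
        self._epochs = iter(epochs)
        self._on_epoch_end = on_epoch_end
        self._current: Optional[int] = None

    def __iter__(self) -> Iterator[int]:
        return self

    def __next__(self) -> int:
        # The loop asks for the next epoch only after the body for the
        # previous epoch has run, so end-of-epoch work happens here.
        if self._current is not None:
            self._on_epoch_end(self._current)
            self._current = None
        epoch = next(self._epochs)  # raises StopIteration when exhausted
        self._current = epoch
        return epoch


finished = []
for epoch in WrappingEpochIterator(range(3), finished.append):
    pass  # training for this epoch would go here
print(finished)  # prints [0, 1, 2]
```

One consequence of this design: if the loop exits early with break, the end-of-epoch work for the last epoch never runs.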


The elbo.ElboEpochIterator takes the following arguments:
  • The range of epochs, usually range(0, num_epochs)
  • The PyTorch model being trained
  • save_state_interval: how often the model state is saved to the artifacts directory. For example, a value of 5 saves the model state every 5 epochs.
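
To make the save_state_interval semantics concrete, here is a sketch under an assumption the description above does not spell out: a save happens after every save_state_interval-th completed epoch, counting from 1. The function epochs_with_checkpoints and its save_state callback are hypothetical and not part of elbo.

```python
from typing import Callable, Iterable, Iterator, List


def epochs_with_checkpoints(
    epochs: Iterable[int],
    save_state: Callable[[int], None],
    save_state_interval: int = 1,
) -> Iterator[int]:
    """Hypothetical sketch of save_state_interval semantics.

    Assumption (not confirmed by the elbo docs): state is saved after every
    save_state_interval-th completed epoch, counting from 1.
    """
    if save_state_interval < 1:
        raise ValueError("save_state_interval must be >= 1")
    for completed, epoch in enumerate(epochs, start=1):
        yield epoch  # the caller's loop body (training) runs here
        if completed % save_state_interval == 0:
            save_state(epoch)


saved: List[int] = []
for epoch in epochs_with_checkpoints(range(0, 12), saved.append, save_state_interval=5):
    pass  # train(model, train_data) would go here
print(saved)  # prints [4, 9]
```

In a PyTorch script, save_state could be something like lambda epoch: torch.save(model.state_dict(), f"artifacts/model_{epoch}.pt"); that file naming is an illustrative assumption, not elbo's actual layout.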
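Example training script for an MNIST classifier:

import elbo.elbo
from torchvision import datasets, transforms

# MNISTClassifier, train() and test() are assumed to be defined elsewhere in your script.

if __name__ == '__main__':
    print("Training MNIST classifier")
    train_data = datasets.MNIST("data", train=True, transform=transforms.ToTensor(), download=True)
    test_data = datasets.MNIST("data", train=False, transform=transforms.ToTensor(), download=True)
    model = MNISTClassifier()
    num_epochs = 10
    # Wrap the epoch range so elbo saves the model state every epoch
    for epoch in elbo.elbo.ElboEpochIterator(range(0, num_epochs), model, save_state_interval=1):
        loss = train(model, train_data)
        print(f"Epoch = {epoch} Loss = {loss}")
    test(model, test_data)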