The elbo.ElboEpochIterator
The iterator class is a wrapper around a standard Python iterator, such as the range of epochs you would normally loop over. Use it in place of that iterator in your training epoch loop.
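To make the wrapping idea concrete, here is a minimal sketch of the general pattern: a class that yields each epoch from an inner iterable and runs a hook once the loop body for that epoch has finished. The class name EpochWrapper and the on_epoch_end hook are hypothetical and illustrate the pattern only. This is not elbo's actual implementation.

```python
from typing import Callable, Iterable, Iterator


class EpochWrapper:
    """Hypothetical wrapper illustrating the pattern; not elbo's implementation."""

    def __init__(self, epochs: Iterable[int], on_epoch_end: Callable[[int], None]):
        self.epochs = epochs
        self.on_epoch_end = on_epoch_end

    def __iter__(self) -> Iterator[int]:
        for epoch in self.epochs:
            # Hand control to the caller's loop body for this epoch
            yield epoch
            # Execution resumes here after the loop body for this epoch completes
            self.on_epoch_end(epoch)


finished = []
for epoch in EpochWrapper(range(0, 3), finished.append):
    pass  # a training step would go here
print(finished)  # [0, 1, 2]
```

Because the hook runs after the yield, it only fires once the caller has finished the epoch's work. That makes this a natural place for a wrapper to do per-epoch bookkeeping such as checkpointing.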
Usage
The elbo.ElboEpochIterator takes the following arguments:
The range of epochs, usually range(0, num_epochs)
The PyTorch model being trained
save_state_interval: how often the model state is saved to the artifacts directory. A value of 5 means the model state is saved every 5 epochs.
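The following sketch shows which epochs would trigger a save for a given save_state_interval. It assumes that epochs are 0-indexed, as with range(0, num_epochs), and that a save happens after every save_state_interval completed epochs. The helper save_epochs is hypothetical, and elbo's exact timing may differ.

```python
from typing import List


def save_epochs(num_epochs: int, save_state_interval: int) -> List[int]:
    """Return the 0-indexed epochs after which a save would occur (assumed semantics)."""
    if save_state_interval < 1:
        raise ValueError("save_state_interval must be >= 1")
    # Assumption: save after completing epochs save_state_interval, 2 * save_state_interval, ...
    return [e for e in range(num_epochs) if (e + 1) % save_state_interval == 0]


print(save_epochs(10, 5))  # [4, 9]
```

Under this assumption, save_state_interval=1 (as in the example below) saves after every epoch. If the interval is larger than the number of epochs, no save is triggered.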
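import elbo.elbo
from torchvision import datasets, transforms

# MNISTClassifier, train() and test() are your own model and helpers, defined elsewhere in the script.

if __name__ == '__main__':
    print("Training MNIST classifier")
    train_data = datasets.MNIST("data", train=True, transform=transforms.ToTensor(), download=True)
    test_data = datasets.MNIST("data", train=False, transform=transforms.ToTensor(), download=True)
    model = MNISTClassifier()
    num_epochs = 10
    # Wrap the epoch range; the model state is saved every save_state_interval epochs
    for epoch in elbo.elbo.ElboEpochIterator(range(0, num_epochs), model, save_state_interval=1):
        loss = train(model, train_data)
        print(f"Epoch = {epoch} Loss = {loss}")
    test(model, test_data)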