The elbo.ElboEpochIterator

ElboEpochIterator wraps a standard Python iterator, typically the range of epochs. Use it as the iterable of your training epoch loop.
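To illustrate the wrapping idea, here is a minimal, hypothetical sketch. EpochIteratorSketch and on_epoch_end are illustrative names and are not part of elbo, and this is not elbo's actual implementation. The wrapper yields each epoch unchanged and runs a hook after the loop body for that epoch finishes. A wrapper can use this point to do work between epochs, such as saving state.

```python
from typing import Callable, Generic, Iterable, Iterator, TypeVar

T = TypeVar("T")


class EpochIteratorSketch(Generic[T]):
    """Hypothetical wrapper (not elbo's code): yields the wrapped iterable's
    items unchanged and calls a hook after each item's loop body finishes."""

    def __init__(self, epochs: Iterable[T], on_epoch_end: Callable[[T], None]) -> None:
        self.epochs = epochs
        self.on_epoch_end = on_epoch_end

    def __iter__(self) -> Iterator[T]:
        for epoch in self.epochs:
            # Hand the epoch to the caller's loop body.
            yield epoch
            # Execution resumes here only when the loop asks for the next
            # epoch, i.e. after the body for this epoch has completed.
            self.on_epoch_end(epoch)


events = []
for epoch in EpochIteratorSketch(range(3), lambda e: events.append(f"end of epoch {e}")):
    events.append(f"training epoch {epoch}")

print(events)
# ['training epoch 0', 'end of epoch 0', 'training epoch 1', 'end of epoch 1', 'training epoch 2', 'end of epoch 2']
```

Because the hook runs only when the loop requests the next epoch, breaking out of the loop early skips the hook for the epoch in progress.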

Usage

ElboEpochIterator takes the following arguments:

  • The range of epochs, usually range(0, num_epochs)

  • The PyTorch model being trained

  • save_state_interval: how often, in epochs, the model state is saved to the artifacts directory. A value of 5 saves the model state every 5 epochs.
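As a sketch of the save_state_interval arithmetic, the hypothetical helper checkpoint_epochs (not part of elbo) lists the 0-based epochs after which the model state would be saved. It assumes a save happens after every save_state_interval-th completed epoch, counting from the first. elbo's exact choice of epochs may differ.

```python
from typing import List


def checkpoint_epochs(num_epochs: int, save_state_interval: int) -> List[int]:
    """Hypothetical helper: return the 0-based epochs after which state is saved.

    Assumption: a save happens after every save_state_interval-th completed
    epoch, so an interval of 5 saves after the 5th, 10th, ... epoch.
    """
    if save_state_interval < 1:
        raise ValueError("save_state_interval must be a positive integer")
    return [
        epoch
        for epoch in range(num_epochs)
        if (epoch + 1) % save_state_interval == 0
    ]


print(checkpoint_epochs(10, 5))  # [4, 9]
print(checkpoint_epochs(10, 1))  # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
```

Under this assumption, save_state_interval=1 saves the state after every epoch, and larger values trade checkpoint frequency for less disk I/O.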

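For example, to train an MNIST classifier:

```python
import elbo.elbo
from torchvision import datasets, transforms

# MNISTClassifier, train() and test() are your own model and helpers, defined elsewhere.

if __name__ == '__main__':
    print("Training MNIST classifier")
    train_data = datasets.MNIST("data", train=True, transform=transforms.ToTensor(), download=True)
    test_data = datasets.MNIST("data", train=False, transform=transforms.ToTensor(), download=True)
    model = MNISTClassifier()
    num_epochs = 10

    # Wrap the epoch range; save_state_interval=1 saves the model state every epoch.
    for epoch in elbo.elbo.ElboEpochIterator(range(0, num_epochs), model, save_state_interval=1):
        loss = train(model, train_data)
        print(f"Epoch = {epoch} Loss = {loss}")

    test(model, test_data)
```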
