# The abstract elbo.ElboModel

## Extending your model class

Extend the abstract `ElboModel` class along with `nn.Module` in your PyTorch model class. Doing so requires you to implement two methods:

* `save_state` - This method should save the state of the model, along with any other state information needed to resume training.
* `load_state` - This method should load the state of the model from the input directory (`state_dir`).

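```python
import os

import torch
from torch import nn

# Import path assumed from the page title (`elbo.ElboModel`); adjust it to your installed elbo package.
from elbo import ElboModel


class MNISTClassifier(ElboModel, nn.Module):
    # Layer definitions and forward() omitted for brevity.

    def get_artifacts_directory(self):
        return 'artifacts'

    def save_state(self):
        # Create the artifacts directory if needed, then write the weights.
        os.makedirs(self.get_artifacts_directory(), exist_ok=True)
        model_path = os.path.join(self.get_artifacts_directory(), "mnist_model")
        torch.save(self.state_dict(), model_path)
        print(f"Saving model to {model_path}")

    def load_state(self, state_dir):
        # Read the weights from the directory passed in by the caller.
        model_path = os.path.join(state_dir, "mnist_model")
        print(f"Loading model from {model_path}")
        self.load_state_dict(torch.load(model_path))
```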
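Being "required" to implement both methods is the usual abstract-base-class pattern in Python. The sketch below uses the standard `abc` module with a hypothetical `CheckpointableModel` stand-in; it is not the real `ElboModel` source, only an illustration of why a subclass that omits either method cannot be instantiated.

```python
from abc import ABC, abstractmethod


class CheckpointableModel(ABC):
    """Hypothetical stand-in for an abstract base such as ElboModel."""

    @abstractmethod
    def save_state(self):
        """Persist everything needed to resume training."""

    @abstractmethod
    def load_state(self, state_dir):
        """Restore state previously written by save_state."""


class Incomplete(CheckpointableModel):
    # Only one of the two abstract methods is implemented.
    def save_state(self):
        pass


class Complete(CheckpointableModel):
    # Both abstract methods are overridden, so instances can be created.
    def save_state(self):
        pass

    def load_state(self, state_dir):
        pass


try:
    Incomplete()
except TypeError as err:
    print(f"Cannot instantiate: {err}")

model = Complete()
```

If `ElboModel` is built this way, Python raises the `TypeError` at construction time, so a missing `save_state` or `load_state` surfaces before any training starts.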

{% hint style="info" %}
**Good to know:** The training loop calls these methods at periodic intervals to checkpoint the state of training. Make sure that everything needed to resume training from a previous checkpoint is saved and loaded through these methods.
{% endhint %}
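As a rough illustration of the hint above, the sketch below shows how a training loop might call `save_state` every few steps and `load_state` after an interruption. Everything here is hypothetical: `ToyModel`, `run_training`, the `checkpoint_every` and `stop_at` parameters, and the `state.json` file are illustrative names, not part of the elbo API, and plain JSON stands in for `torch.save`. The key point is that the step counter is saved alongside the weights; without it, a resumed run would start over from step 0.

```python
import json
import os
import tempfile


class ToyModel:
    """Hypothetical model whose 'weights' are a single float."""

    def __init__(self, artifacts_dir):
        self.artifacts_dir = artifacts_dir
        self.weight = 0.0
        self.step = 0  # training progress must be checkpointed too

    def save_state(self):
        os.makedirs(self.artifacts_dir, exist_ok=True)
        path = os.path.join(self.artifacts_dir, "state.json")
        with open(path, "w") as f:
            json.dump({"weight": self.weight, "step": self.step}, f)

    def load_state(self, state_dir):
        path = os.path.join(state_dir, "state.json")
        with open(path) as f:
            state = json.load(f)
        self.weight = state["weight"]
        self.step = state["step"]


def run_training(model, total_steps, checkpoint_every=2, stop_at=None):
    """Hypothetical loop: one 'update' per step, with periodic checkpoints.

    stop_at simulates an interruption such as a preempted machine.
    """
    while model.step < total_steps:
        if stop_at is not None and model.step == stop_at:
            return
        model.weight += 1.0  # stand-in for an optimizer update
        model.step += 1
        if model.step % checkpoint_every == 0:
            model.save_state()


with tempfile.TemporaryDirectory() as artifacts:
    first = ToyModel(artifacts)
    # Interrupted at step 5; the last checkpoint was written at step 4.
    run_training(first, total_steps=6, stop_at=5)

    resumed = ToyModel(artifacts)
    resumed.load_state(artifacts)  # restores weight=4.0, step=4
    run_training(resumed, total_steps=6)
    print(resumed.step, resumed.weight)  # 6 6.0
```

The update made at step 5 before the interruption is lost and redone after resuming, which is why more frequent checkpoints mean less repeated work. In the PyTorch version, the same idea typically means saving the optimizer's `state_dict()` and the current epoch or step next to the model's `state_dict()`.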
