PyTorch integration API reference
API reference for the NeptuneLogger class of the Neptune-PyTorch integration.
You can use the `NeptuneLogger` to capture model training metadata when working with PyTorch.
NeptuneLogger
Captures model training metadata and logs them to Neptune.
Parameters
| Name | Type | Default | Description |
|---|---|---|---|
| run | `Run` or `Handler` | - | (required) An existing run reference, as returned by `neptune.init_run()`, or a namespace handler. |
| base_namespace | str, optional | "training" | Namespace under which all metadata logged by the Neptune logger will be stored. |
| model | torch.nn.Module | - | (required) PyTorch model object to be tracked. |
| log_model_diagram | bool, optional | False | Whether to save the model visualization. Requires `torchviz` to be installed. |
| log_gradients | bool, optional | False | Whether to track the Frobenius norm of the gradients. |
| log_parameters | bool, optional | False | Whether to track the Frobenius norm of the parameters. |
| log_freq | int, optional | 100 | How often to log the parameter and gradient norms. Applies only if `log_parameters` or `log_gradients` is set to `True`. |
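The Frobenius norm of a tensor is the square root of the sum of its squared entries. The plain-Python sketch below illustrates that quantity and one plausible reading of `log_freq`. It is not the library's implementation: the helper names `frobenius_norm` and `should_log` are hypothetical, and so is the assumption that logging happens on every step divisible by `log_freq`.

```python
import math
from typing import Sequence


def frobenius_norm(matrix: Sequence[Sequence[float]]) -> float:
    """Square root of the sum of squared entries (the Frobenius norm)."""
    return math.sqrt(sum(x * x for row in matrix for x in row))


def should_log(step: int, log_freq: int = 100) -> bool:
    """Hypothetical gating rule: log on every step that is a multiple of log_freq."""
    if log_freq <= 0:
        raise ValueError("log_freq must be a positive integer")
    return step % log_freq == 0


# A 2x2 "weight matrix": sqrt(3^2 + 4^2 + 0^2 + 0^2) = 5.0
weights = [[3.0, 4.0], [0.0, 0.0]]
print(frobenius_norm(weights))  # 5.0

# Under this assumption, log_freq=50 logs only steps 0, 50, 100, 150
print([s for s in range(0, 151) if should_log(s, log_freq=50)])  # [0, 50, 100, 150]
```

In PyTorch itself, `tensor.norm()` computes this same Frobenius norm by default (`p="fro"`).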
Examples
Creating a Neptune run and callback
Create a run:
```python
import neptune

run = neptune.init_run()
```
As a best practice, you should save your Neptune API token and project name as environment variables:
```shell
export NEPTUNE_API_TOKEN="h0dHBzOi8aHR0cHM6Lkc78ghs74kl0jv...Yh3Kb8"
export NEPTUNE_PROJECT="ml-team/classification"
```
Alternatively, you can pass the information when using a function that takes `api_token` and `project` as arguments:
```python
run = neptune.init_run(
    api_token="h0dHBzOi8aHR0cHM6Lkc78ghs74kl0jv...Yh3Kb8",  # (1)
    project="ml-team/classification",  # (2)
)
```
1. To find your API token, expand the user menu in the bottom-left corner and select Get my API token.
2. You can copy the project path from the project details (→ Details & privacy).
If you haven't registered, you can log anonymously to a public project:
```python
run = neptune.init_run(
    api_token=neptune.ANONYMOUS_API_TOKEN,
    project="common/quickstarts",
)
```
Make sure not to publish sensitive data through your code!
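If you rely on the environment variables, checking for them before creating the run surfaces a missing value early. The following is a minimal standard-library sketch. The helper `resolve_neptune_credentials` is hypothetical, and the rule that an explicitly passed value takes precedence over the environment variable is an assumption of this sketch.

```python
import os
from typing import Optional, Tuple


def resolve_neptune_credentials(
    api_token: Optional[str] = None,
    project: Optional[str] = None,
) -> Tuple[str, str]:
    """Return (api_token, project), preferring explicit arguments over env vars.

    Hypothetical helper: raises ValueError if either value cannot be found.
    """
    token = api_token or os.environ.get("NEPTUNE_API_TOKEN")
    proj = project or os.environ.get("NEPTUNE_PROJECT")
    missing = [
        name
        for name, value in (("NEPTUNE_API_TOKEN", token), ("NEPTUNE_PROJECT", proj))
        if not value
    ]
    if missing:
        raise ValueError(f"Missing Neptune credentials: {', '.join(missing)}")
    return token, proj
```

You could then pass the resolved values to `neptune.init_run(api_token=..., project=...)`. Keep the token itself out of version control.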
Instantiate the Neptune callback:
```python
from neptune_pytorch import NeptuneLogger

neptune_logger = NeptuneLogger(run=run, model=model)
```
Train your model:
```python
import torch.nn.functional as F

# Assumes model, optimizer, train_loader, and device are already defined
for epoch in range(1, 4):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
```
Additional options
```python
import neptune
from neptune_pytorch import NeptuneLogger

run = neptune.init_run(
    name="My PyTorch run",
    tags=["test", "pytorch"],
    dependencies="infer",
)

neptune_logger = NeptuneLogger(
    run=run,
    model=model,
    base_namespace="test",
    log_model_diagram=True,
    log_gradients=True,
    log_parameters=True,
    log_freq=50,
)
```
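Because `log_model_diagram=True` requires `torchviz`, you may want to enable it only when the package is installed. Below is a minimal sketch using the standard library's `importlib.util.find_spec`. The helper name `module_available` is hypothetical; you would pass the resulting flag as `NeptuneLogger(..., log_model_diagram=log_model_diagram)`.

```python
import importlib.util


def module_available(name: str) -> bool:
    """Return True if the top-level module `name` can be found in this environment."""
    return importlib.util.find_spec(name) is not None


# Only request the model diagram when its optional dependency is present.
log_model_diagram = module_available("torchviz")
print(f"log_model_diagram={log_model_diagram}")
```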
log_checkpoint()
Uploads a model checkpoint to Neptune, into a namespace called `model/checkpoints` nested under the base namespace of the run.
The filename is set to `checkpoint_<checkpoint number>.pt` by default, but can be customized.
Parameters
| Name | Type | Default | Description |
|---|---|---|---|
| checkpoint_name | str, optional | `checkpoint_<checkpoint number>` | Name for the logged checkpoint file. If no name is given, the default name is used; the checkpoint number starts from 1 and is incremented automatically on each call. The extension `.pt` is added automatically. |
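The default naming rule described above (a counter that starts from 1, grows on each call, and gets `.pt` appended) can be sketched in plain Python as follows. `CheckpointNamer` is a hypothetical illustration, not part of `neptune_pytorch`. This sketch also assumes that a call with a custom name does not advance the counter.

```python
from typing import Optional


class CheckpointNamer:
    """Hypothetical sketch of the default checkpoint file naming."""

    def __init__(self) -> None:
        self._next_number = 1  # the default numbering starts from 1

    def filename(self, checkpoint_name: Optional[str] = None) -> str:
        if checkpoint_name is None:
            # Default name: checkpoint_<checkpoint number>, then advance the counter.
            checkpoint_name = f"checkpoint_{self._next_number}"
            self._next_number += 1
        # Assumption: custom names do not advance the counter.
        return f"{checkpoint_name}.pt"


namer = CheckpointNamer()
print(namer.filename())        # checkpoint_1.pt
print(namer.filename())        # checkpoint_2.pt
print(namer.filename("best"))  # best.pt
print(namer.filename())        # checkpoint_3.pt
```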
Example
```python
from neptune_pytorch import NeptuneLogger

neptune_logger = NeptuneLogger(...)
...
for epoch in range(parameters["epochs"]):
    ...
    neptune_logger.log_checkpoint()
```
log_model()
Uploads the model to Neptune, into a namespace called `model` nested under the base namespace of the run.
The filename is set to `model.pt` by default, but can be customized.
Parameters
| Name | Type | Default | Description |
|---|---|---|---|
| model_name | str, optional | `model` | Name for the logged model file. The extension `.pt` is added automatically. |
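Both upload methods nest their files under the run's base namespace (`training` by default, or whatever you pass as `base_namespace`). The hypothetical helper below only builds the resulting namespace strings, assuming `/`-separated paths. It does not talk to Neptune, and the exact field names are best confirmed in the Neptune app.

```python
def field_path(base_namespace: str, *parts: str) -> str:
    """Join namespace segments with '/', ignoring empty segments and stray slashes."""
    segments = [base_namespace, *parts]
    return "/".join(s.strip("/") for s in segments if s.strip("/"))


print(field_path("training", "model"))                 # training/model
print(field_path("training", "model", "checkpoints"))  # training/model/checkpoints
print(field_path("test", "model"))                     # test/model
```

With the Neptune client, `run["training/model"]` and `run["training"]["model"]` refer to the same namespace.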
Example
```python
from neptune_pytorch import NeptuneLogger

neptune_logger = NeptuneLogger(...)
...
neptune_logger.log_model()
```
save_checkpoint
See `log_checkpoint()`.
save_model
See `log_model()`.
See also
neptune-pytorch repo on GitHub
This page was originally sourced from the legacy docs.