skorch integration API reference

API reference for the NeptuneLogger class of the Neptune-skorch integration.

You can use a Neptune callback to capture model training metadata when using skorch.

NeptuneLogger

Captures NeuralNetClassifier history and logs the metadata to Neptune.

Parameters

Name | Type | Default | Description
run | Run | (required) | An existing run reference, as returned by neptune.init_run(), or a namespace handler.
log_on_batch_end | bool, optional | False | Whether to log loss and other metrics at the batch level.
close_after_train | bool, optional | True | Whether to close the run object once training finishes. Set to False if you want to continue logging to the same run or if you use the run as a context manager.
keys_ignored | str or list of str | None | Key or list of keys that should not be logged to Neptune. In addition to the keys provided by the user, keys such as those starting with "event_" or ending with "_best" are ignored by default.
base_namespace | str, optional | "training" | Namespace under which all metadata logged by the Neptune callback is stored.
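The following plain-Python sketch makes the keys_ignored rules concrete. It applies only the filters described in the table: user-supplied keys, plus keys starting with "event_" or ending with "_best". The helper would_be_logged is hypothetical and not part of the skorch or Neptune API. The real callback may skip additional bookkeeping keys.

```python
from typing import Iterable, List, Optional, Union


def would_be_logged(
    keys: Iterable[str],
    keys_ignored: Optional[Union[str, List[str]]] = None,
) -> List[str]:
    """Return the keys that survive the documented default filters.

    Hypothetical helper: mirrors only the rules stated in the parameter
    table; skorch's own filtering may exclude further keys.
    """
    # keys_ignored accepts either a single string or a list of strings
    if isinstance(keys_ignored, str):
        keys_ignored = [keys_ignored]
    ignored = set(keys_ignored or [])

    return [
        key
        for key in keys
        if key not in ignored
        and not key.startswith("event_")  # e.g. checkpoint events
        and not key.endswith("_best")  # "is this the best value so far" flags
    ]


# Example epoch-level keys (illustrative only)
epoch_keys = ["train_loss", "train_loss_best", "valid_loss", "valid_acc", "event_cp", "dur"]
print(would_be_logged(epoch_keys, keys_ignored="dur"))
# ['train_loss', 'valid_loss', 'valid_acc']
```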

Examples

Create a NeptuneLogger callback:

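import neptune
from skorch.callbacks import NeptuneLogger

# close_after_train=False keeps the run open so you can log more metadata after fit()
neptune_logger = NeptuneLogger(neptune.init_run(), close_after_train=False)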
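When you call neptune.init_run() without arguments, it reads the project name and API token from the NEPTUNE_PROJECT and NEPTUNE_API_TOKEN environment variables. You can also pass project= and api_token= to it directly. The stdlib sketch below fails fast before a long training job if either variable is missing. The function name missing_neptune_env is hypothetical.

```python
import os
from typing import List, Mapping, Optional

# Variables neptune.init_run() falls back to when project/api_token
# are not passed explicitly.
REQUIRED_NEPTUNE_VARS = ("NEPTUNE_PROJECT", "NEPTUNE_API_TOKEN")


def missing_neptune_env(env: Optional[Mapping[str, str]] = None) -> List[str]:
    """Return the required Neptune variables that are unset or empty."""
    env = os.environ if env is None else env
    return [name for name in REQUIRED_NEPTUNE_VARS if not env.get(name)]


missing = missing_neptune_env()
if missing:
    print(f"Set {', '.join(missing)} before calling neptune.init_run()")
```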

(Optional) Set the path to the checkpoints directory:

from skorch.callbacks import Checkpoint

checkpoint_dirname = "./checkpoints"
checkpoint = Checkpoint(dirname=checkpoint_dirname)

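Pass the callback to the net's callbacks argument:

from skorch import NeuralNetClassifier

# ClassifierModule is your own torch.nn.Module subclass;
# X and y are your training features and labels.
net = NeuralNetClassifier(
 ClassifierModule,
 max_epochs=20,
 lr=0.01,
 # Drop checkpoint from the list if you skipped the optional step above
 callbacks=[neptune_logger, checkpoint],
)

# Run training
net.fit(X, y)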
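During fit(), skorch stores one dict per epoch in net.history, and each epoch dict holds a "batches" list of per-batch dicts. The sketch below uses a small hand-written history with made-up values to show how log_on_batch_end changes what would be collected. Epoch-level metrics are always gathered, and batch-level ones only when the flag is True. The helper collect_series is hypothetical, and the "epoch/" and "batch/" labels are illustrative only. They are not the exact field paths the callback creates under base_namespace.

```python
from typing import Dict, List


def collect_series(history: List[dict], log_on_batch_end: bool = False) -> Dict[str, List[float]]:
    """Simplified model of which values a history-based logger would append."""
    series: Dict[str, List[float]] = {}

    def add(label: str, value: float) -> None:
        series.setdefault(label, []).append(value)

    for epoch in history:
        if log_on_batch_end:
            # One value per batch, e.g. the batch-level training loss
            for batch in epoch["batches"]:
                for key, value in batch.items():
                    add(f"batch/{key}", value)
        # One value per epoch; skip bookkeeping and default-ignored keys
        for key, value in epoch.items():
            if key in ("epoch", "batches") or key.startswith("event_") or key.endswith("_best"):
                continue
            add(f"epoch/{key}", value)
    return series


history = [
    {"epoch": 1, "train_loss": 0.70, "train_loss_best": True, "event_cp": True,
     "batches": [{"train_loss": 0.75}, {"train_loss": 0.65}]},
    {"epoch": 2, "train_loss": 0.50, "train_loss_best": True, "event_cp": True,
     "batches": [{"train_loss": 0.55}, {"train_loss": 0.45}]},
]

print(sorted(collect_series(history)))  # ['epoch/train_loss']
print(len(collect_series(history, log_on_batch_end=True)["batch/train_loss"]))  # 4
```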

Log additional metrics after training has finished:

from sklearn.metrics import roc_auc_score

y_pred = net.predict_proba(X)
auc = roc_auc_score(y, y_pred[:, 1])
neptune_logger.run["roc_auc_score"].append(auc)
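roc_auc_score(y, y_pred[:, 1]) scores the predicted probability of the positive class, which is column 1 of predict_proba. As a sanity check on what the logged number means, the stdlib sketch below computes the same quantity another way. It is the probability that a randomly chosen positive example is scored higher than a randomly chosen negative one, with ties counting as one half. The function name pairwise_auc is hypothetical, and this is an O(P * N) illustration. Use scikit-learn's implementation in practice.

```python
from typing import Sequence


def pairwise_auc(y_true: Sequence[int], scores: Sequence[float]) -> float:
    """Binary ROC AUC (labels 0/1) by comparing every positive/negative pair."""
    pos = [s for t, s in zip(y_true, scores) if t == 1]
    neg = [s for t, s in zip(y_true, scores) if t == 0]
    if not pos or not neg:
        raise ValueError("need at least one positive and one negative example")

    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5  # a tie counts as half a correct ordering
    return wins / (len(pos) * len(neg))


print(pairwise_auc([0, 0, 1, 1], [0.1, 0.4, 0.35, 0.8]))  # 0.75
```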

Log charts, such as an ROC curve:

from scikitplot.metrics import plot_roc
import matplotlib.pyplot as plt

from neptune.types import File

fig, ax = plt.subplots(figsize=(16, 12))
plot_roc(y, y_pred, ax=ax)
neptune_logger.run["roc_curve"].upload(File.as_html(fig))
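plot_roc draws the ROC curve from the predicted probabilities. Each decision threshold contributes one (false positive rate, true positive rate) point. The sketch below computes those points for the binary case in plain Python, which can help when checking the uploaded chart. The function name roc_points is hypothetical. Tied scores are consumed together, so each distinct score yields exactly one point.

```python
from typing import List, Sequence, Tuple


def roc_points(y_true: Sequence[int], scores: Sequence[float]) -> List[Tuple[float, float]]:
    """Return (FPR, TPR) pairs, sweeping the threshold from high to low."""
    n_pos = sum(1 for t in y_true if t == 1)
    n_neg = len(y_true) - n_pos
    if n_pos == 0 or n_neg == 0:
        raise ValueError("need both positive and negative examples")

    pairs = sorted(zip(scores, y_true), key=lambda p: p[0], reverse=True)
    points = [(0.0, 0.0)]
    tp = fp = 0
    i = 0
    while i < len(pairs):
        score = pairs[i][0]
        # Consume every example sharing this score before emitting a point
        while i < len(pairs) and pairs[i][0] == score:
            if pairs[i][1] == 1:
                tp += 1
            else:
                fp += 1
            i += 1
        points.append((fp / n_neg, tp / n_pos))
    return points


print(roc_points([0, 0, 1, 1], [0.1, 0.4, 0.35, 0.8]))
# [(0.0, 0.0), (0.0, 0.5), (0.5, 0.5), (0.5, 1.0), (1.0, 1.0)]
```

The trapezoidal area under these points is the ROC AUC, which is 0.75 for this example.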

Save the net's parameters and upload them after training:

net.save_params(f_params="basic_model.pkl")
neptune_logger.run["basic_model"].upload("basic_model.pkl")

Close the run if needed

If you set close_after_train=False, stop the run when you are done logging:

neptune_logger.run.stop()
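If post-training logging can fail, for example because a file to upload is missing, wrap it in try/finally so that stop() still runs. The sketch below uses a hypothetical FakeRun stand-in so it runs without a Neptune account. Its item assignment stands in for the append() and upload() calls above. With a real run, you would keep the same structure and call neptune_logger.run.stop() in the finally clause.

```python
class FakeRun:
    """Hypothetical stand-in for a Neptune run; records fields instead of logging."""

    def __init__(self) -> None:
        self.fields = {}
        self.stopped = False

    def __setitem__(self, key, value):
        if self.stopped:
            raise RuntimeError("run already stopped")
        self.fields[key] = value

    def stop(self) -> None:
        self.stopped = True


def log_after_training(run, extra: dict) -> None:
    """Log extra fields, then always stop the run, even if logging raises."""
    try:
        for key, value in extra.items():
            run[key] = value
    finally:
        run.stop()


run = FakeRun()
log_after_training(run, {"roc_auc_score": 0.75})
print(run.fields, run.stopped)  # {'roc_auc_score': 0.75} True
```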

See also

NeptuneLogger in the skorch API reference


This page was originally sourced from the legacy docs.