Сохранить промежуточные результаты обучения (чекпоинты)

Сохранение чекпоинтов позволяет:

  • не потерять результаты обучения, если задача по какой-то причине остановилась, например из-за выхода за денежный лимит или ошибок обучения;

  • возобновить обучение модели из последнего сохраненного состояния;

  • делиться обученными моделями с другими пользователями, чтобы те могли восстановить объект с моделью без повторного обучения.

Рекомендуется сохранять промежуточные результаты с некоторой периодичностью, например, в конце каждой эпохи или после того, как закончится итерация по небольшим блокам обучения.

В процессе обучения модели промежуточные результаты сохраняются в рабочем каталоге пользователя /home/jovyan/. Их можно скачать через интерфейс Jupyter Notebook/JupyterLab или скопировать в хранилище S3 из локально доступной файловой системы. Подробнее о выгрузке промежуточных результатов обучения на S3 — в примере на GitHub.

Рассмотрим, как сохранять промежуточные результаты обучения для наиболее распространенных фреймворков.

Сохранить чекпоинты в Keras

Чтобы сохранить промежуточные результаты и зафиксировать состояние модели в определенный момент, используется механизм обратных вызовов как экземпляр класса ModelCheckpoint.

# Horovod: save checkpoints only on worker 0 to prevent other workers from corrupting them
    if hvd.rank() == 0:
       callbacks.append(tf.keras.callbacks.ModelCheckpoint(os.path.join('checkpoints','checkpoint-{epoch}.h5'))

# Train the model
# Horovod: adjust number of steps based on number of GPUs
mnist_model.fit(dataset, steps_per_epoch=500 // hvd.size(), callbacks=callbacks, epochs=24, verbose=verbose)

Сохранить чекпоинты в TensorFlow

Восстановить сеанс Tensorflow можно, использовав конструктор MonitoredTrainingSession() с аргументом checkpoint_dir.

# Horovod: Save checkpoints only on worker 0 to prevent other workers from corrupting them
checkpoint_dir = '/tmp/train_logs' if hvd.rank() == 0 else None

# The MonitoredTrainingSession takes care of session initialization,
# restoring from a checkpoint, saving to a checkpoint, and closing when done
# or an error occurs
with tf.train.MonitoredTrainingSession(checkpoint_dir=checkpoint_dir,
                                       <some_other_variables>) as sess:

Сохранить чекпоинты в PyTorch

Сохранение параметров модели с помощью torch.save().

# Save checkpoints only on worker 0 to prevent other workers from corrupting them
if hvd.rank() == 0
    torch.save(the_model.state_dict(), PATH)

Ниже показано, как возобновить обучение модели из последнего сохраненного состояния. За основу взят следующий пример.

num_epochs = 3
print(f'Start train {num_epochs} epochs total')

# Loading from checkpoint
# https://pytorch.org/tutorials/beginner/saving_loading_models.html
last_epoch = 0
import os
for root, dirs, files in os.walk(BASE_DIR.joinpath('logs')):
    saved_models = [model_filename for model_filename in files if ".bin" in model_filename]

if saved_models:
    checkpoint = torch.load(os.path.join(root, saved_models[-1]))
    clf.load_state_dict(checkpoint['model_state_dict'])      # Loading model weights and other training parameters
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    last_epoch = checkpoint['epoch']
    print(f"Continue training from {last_epoch} epoch")

# Start training
mlflow.set_tracking_uri('file:/home/jovyan/mlruns')
mlflow.set_experiment("pytorch_tensorboard_mlflow.ipynb")
with mlflow.start_run(nested=True) as run:
    for epoch in range(num_epochs):
        if last_epoch:
            epoch += last_epoch + 1

        print("Epoch %d" % epoch)
        train(epoch, clf, optimizer, writer)
        test(epoch, clf, writer)
        # Save checkpoint every epoch
        torch.save({
            'epoch': epoch,
            'model_state_dict': clf.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
        }, BASE_DIR.joinpath('logs/log_' + current_time + f"/model_epoch_{epoch}.bin"))
        writer.close()
Запустили Evolution free tier
для Dev & Test
Получить