- Начало работы с ML Space
- Инструкции
- Обучение моделей
- Примеры препроцессинга данных
- Установить библиотеки из Git-репозитория
- Запустить процесс обучения
- Обучить модель с использованием PyTorch Elastic Learning
- Обучить модель с использованием библиотеки Horovod
- Сохранить промежуточные результаты обучения (чекпоинты)
- Собрать и использовать кастомный Docker-образ для задачи обучения на основе базового образа платформы
- Собрать и использовать кастомный Docker-образ для задачи обучения на основе внешнего образа
- Использовать GitLab CI при работе с Environments
- Тарификация
- Термины и сокращения
- Обратиться в поддержку
Сохранить промежуточные результаты обучения (чекпоинты)
Сохранение чекпоинтов позволяет:
не потерять результаты обучения, если задача по какой-то причине остановилась, например из-за выхода за денежный лимит или ошибок обучения;
возобновить обучение модели из последнего сохраненного состояния;
делиться обученными моделями с другими пользователями, чтобы те могли восстановить объект с моделью без повторного обучения.
Рекомендуется сохранять промежуточные результаты с некоторой периодичностью, например, в конце каждой эпохи или после того, как закончится итерация по небольшим блокам обучения.
В процессе обучения модели промежуточные результаты сохраняются в рабочем каталоге пользователя /home/jovyan/. Их можно скачать через интерфейс Jupyter Notebook/JupyterLab или скопировать в хранилище S3 из локально доступной файловой системы. Подробнее о выгрузке промежуточных результатов обучения на S3 — в примерах на GitHub.
Рассмотрим, как сохранять промежуточные результаты обучения для наиболее распространенных фреймворков.
Сохранить чекпоинты в Keras
Чтобы сохранить промежуточные результаты и зафиксировать состояние модели в определенный момент, используется механизм обратных вызовов как экземпляр класса ModelCheckpoint.
# Horovod: save checkpoints only on worker 0 to prevent other workers from corrupting themif hvd.rank() == 0:callbacks.append(tf.keras.callbacks.ModelCheckpoint(os.path.join('checkpoints','checkpoint-{epoch}.h5'))# Train the model# Horovod: adjust number of steps based on number of GPUsmnist_model.fit(dataset, steps_per_epoch=500 // hvd.size(), callbacks=callbacks, epochs=24, verbose=verbose)
Сохранить чекпоинты в TensorFlow
Восстановить сеанс Tensorflow можно, использовав конструктор MonitoredTrainingSession() с аргументом checkpoint_dir.
# Horovod: Save checkpoints only on worker 0 to prevent other workers from corrupting themcheckpoint_dir = '/tmp/train_logs' if hvd.rank() == 0 else None# The MonitoredTrainingSession takes care of session initialization,# restoring from a checkpoint, saving to a checkpoint, and closing when done# or an error occurswith tf.train.MonitoredTrainingSession(checkpoint_dir=checkpoint_dir,<some_other_variables>) as sess:
Сохранить чекпоинты в PyTorch
Сохранение параметров модели с помощью torch.save().
# Save checkpoints only on worker 0 to prevent other workers from corrupting themif hvd.rank() == 0torch.save(the_model.state_dict(), PATH)
Ниже показано, как возобновить обучение модели из последнего сохраненного состояния. За основу взят следующий пример.
num_epochs = 3print(f'Start train {num_epochs} epochs total')# Loading from checkpoint# https://pytorch.org/tutorials/beginner/saving_loading_models.htmllast_epoch = 0import osfor root, dirs, files in os.walk(BASE_DIR.joinpath('logs')):saved_models = [model_filename for model_filename in files if ".bin" in model_filename]if saved_models:checkpoint = torch.load(os.path.join(root, saved_models[-1]))clf.load_state_dict(checkpoint['model_state_dict']) # Loading model weights and other training parametersoptimizer.load_state_dict(checkpoint['optimizer_state_dict'])last_epoch = checkpoint['epoch']print(f"Continue training from {last_epoch} epoch")# Start trainingmlflow.set_tracking_uri('file:/home/jovyan/mlruns')mlflow.set_experiment("pytorch_tensorboard_mlflow.ipynb")with mlflow.start_run(nested=True) as run:for epoch in range(num_epochs):if last_epoch:epoch += last_epoch + 1print("Epoch %d" % epoch)train(epoch, clf, optimizer, writer)test(epoch, clf, writer)# Save checkpoint every epochtorch.save({'epoch': epoch,'model_state_dict': clf.state_dict(),'optimizer_state_dict': optimizer.state_dict(),}, BASE_DIR.joinpath('logs/log_' + current_time + f"/model_epoch_{epoch}.bin"))writer.close()