srv-img ML Space

Сохранить промежуточные результаты обучения (чекпоинты)

Сохранение чекпоинтов позволяет:

  • не потерять результаты обучения, если задача по какой-то причине остановилась, например из-за выхода за денежный лимит или ошибок обучения;

  • возобновить обучение модели из последнего сохраненного состояния;

  • делиться обученными моделями с другими пользователями, чтобы те могли восстановить объект с моделью без повторного обучения.

Рекомендуется сохранять промежуточные результаты с некоторой периодичностью, например, в конце каждой эпохи или после того, как закончится итерация по небольшим блокам обучения.

В процессе обучения модели промежуточные результаты сохраняются в рабочем каталоге пользователя /home/jovyan/. Их можно скачать через интерфейс Jupyter Notebook/JupyterLab или скопировать в хранилище S3 из локально доступной файловой системы. Подробнее о выгрузке промежуточных результатов обучения на S3 — в примерах на GitHub.

Рассмотрим, как сохранять промежуточные результаты обучения для наиболее распространенных фреймворков.

Сохранить чекпоинты в Keras

Чтобы сохранить промежуточные результаты и зафиксировать состояние модели в определенный момент, используется механизм обратных вызовов как экземпляр класса ModelCheckpoint.

# Horovod: save checkpoints only on worker 0 to prevent other workers from corrupting them
if hvd . rank () == 0 :
callbacks . append ( tf . keras . callbacks . ModelCheckpoint ( os . path . join ( 'checkpoints' , 'checkpoint- {epoch} .h5' ))
# Train the model
# Horovod: adjust number of steps based on number of GPUs
mnist_model . fit ( dataset , steps_per_epoch = 500 // hvd . size (), callbacks = callbacks , epochs = 24 , verbose = verbose )

Сохранить чекпоинты в TensorFlow

Восстановить сеанс Tensorflow можно, использовав конструктор MonitoredTrainingSession() с аргументом checkpoint_dir.

# Horovod: Save checkpoints only on worker 0 to prevent other workers from corrupting them
checkpoint_dir = '/tmp/train_logs' if hvd . rank () == 0 else None
# The MonitoredTrainingSession takes care of session initialization,
# restoring from a checkpoint, saving to a checkpoint, and closing when done
# or an error occurs
with tf . train . MonitoredTrainingSession ( checkpoint_dir = checkpoint_dir ,
< some_other_variables > ) as sess :

Сохранить чекпоинты в PyTorch

Сохранение параметров модели с помощью torch.save().

# Save checkpoints only on worker 0 to prevent other workers from corrupting them
if hvd . rank () == 0
torch . save ( the_model . state_dict (), PATH )

Ниже показано, как возобновить обучение модели из последнего сохраненного состояния. За основу взят следующий пример.

num_epochs = 3
print ( f 'Start train { num_epochs } epochs total' )
# Loading from checkpoint
# https://pytorch.org/tutorials/beginner/saving_loading_models.html
last_epoch = 0
import os
for root , dirs , files in os . walk ( BASE_DIR . joinpath ( 'logs' )):
saved_models = [ model_filename for model_filename in files if ".bin" in model_filename ]
if saved_models :
checkpoint = torch . load ( os . path . join ( root , saved_models [ - 1 ]))
clf . load_state_dict ( checkpoint [ 'model_state_dict' ]) # Loading model weights and other training parameters
optimizer . load_state_dict ( checkpoint [ 'optimizer_state_dict' ])
last_epoch = checkpoint [ 'epoch' ]
print ( f "Continue training from { last_epoch } epoch" )
# Start training
mlflow . set_tracking_uri ( 'file:/home/jovyan/mlruns' )
mlflow . set_experiment ( "pytorch_tensorboard_mlflow.ipynb" )
with mlflow . start_run ( nested = True ) as run :
for epoch in range ( num_epochs ):
if last_epoch :
epoch += last_epoch + 1
print ( "Epoch %d " % epoch )
train ( epoch , clf , optimizer , writer )
test ( epoch , clf , writer )
# Save checkpoint every epoch
torch . save ({
'epoch' : epoch ,
'model_state_dict' : clf . state_dict (),
'optimizer_state_dict' : optimizer . state_dict (),
}, BASE_DIR . joinpath ( 'logs/log_' + current_time + f "/model_epoch_ { epoch } .bin" ))
writer . close ()
ML Space