TensorFlow2入门到进阶—— 回调函数

1、回调定义

在模型训练期间的某些点调用的实用程序。

2、函数种类

class BaseLogger:累积指标的时期平均值的回调。

class CSVLogger:将纪元结果流式传输到csv文件的回调。

class Callback:用于建立新回调的抽象基类。

class EarlyStopping:当监视的数量停止改善时,停止训练。

class History:将事件记录到History对象中的回调。

class LambdaCallback:用于即时创建简单,自定义回调的回调。

class LearningRateScheduler:学习率调度程序。

class ModelCheckpoint:每个时期后保存模型。

class ProgbarLogger:将指标输出到标准输出的回调。

class ReduceLROnPlateau:当指标停止改善时,降低学习率。

class RemoteMonitor:用于将事件流传输到服务器的回调。

class TensorBoard:为TensorBoard启用可视化。

class TerminateOnNaN:当遇到NaN丢失时回调将终止训练。

3、常用函数详解

3.1 earlystopping

提前终止函数,当训练过程中达到预定条件后提前停止训练。
函数:

tf.keras.callbacks.EarlyStopping(
    monitor='val_loss', min_delta=0, patience=0, verbose=0, mode='auto',
    baseline=None, restore_best_weights=False
)

参数:

monitor:要监视的数量。
min_delta:监视数量的最小变化(有资格视为改进),即绝对变化小于min_delta,将不视为改进。
patience:没有改善的时期数,之后训练将停止。
verbose:详细模式。
mode:之一{“auto”, “min”, “max”}。在min模式下,当监视的数量停止减少时,训练将停止;在max 模式下,当监视的数量停止增加时,它将停止;在auto 模式下,将根据监视数量的名称自动推断出方向。
baseline:监视数量的基准值。如果模型没有显示出超过基线的改善,培训将停止。
restore_best_weights:是否从时期以受监视数量的最佳值恢复模型权重。如果为False,则使用在训练的最后一步获得的模型权重。
举例:在5轮迭代中,改变量小于e-3 时提前停止。
 

keras.callbacks.EarlyStopping(patience=5,min_delta=1e-3)

3.2 modelcheckpoint

保存每次迭代过程中产生的模型,需要一个文件来保存其内容。

函数:

tf.keras.callbacks.ModelCheckpoint(
    filepath, monitor='val_loss', verbose=0, save_best_only=False,
    save_weights_only=False, mode='auto', save_freq='epoch', **kwargs
)

参数:
filepath:字符串,保存模型文件的路径。
monitor:要监视的数量。
verbose:详细模式,0或1。
save_best_only:如果为save_best_only=True,则根据覆盖的数量的最新最佳模型不会被覆盖。如果filepath不包含类似格式的选项,{epoch}则 filepath每个新的更好的模型都将覆盖它们。
mode:{auto,min,max}之一。如果为save_best_only=True,则基于监视数量的最大化或最小化来决定覆盖当前保存文件。对于val_acc,这应该是max,对于val_loss这应该是min,等等。在auto 模式下,将根据监视数量的名称自动推断出方向。
save_weights_only:如果为True,则仅保存模型的权重(model.save_weights(filepath)),否则保存完整的模型(model.save(filepath))。
save_freq:‘epoch’或整数。使用时’epoch’,回调函数会在每个时期后保存模型。使用整数时,回调将在最后一次保存后看到这么多样本的批处理结束时保存模型。请注意,如果保存未与时间段保持一致,则受监视的指标可能会不太可靠(它可能只反映1个批次,因为每个时间段都会重置该指标)。默认为 ‘epoch’
**kwargs:向后兼容的其他参数。可能的键是period。

举例:

output_model_file = os.path.join(logdir,
                                'fashion_mnist_model.h5')

    #这里save_best_only改为True表示保存最好的模型,否则保存最近的模型
keras.callbacks.ModelCheckpoint(output_model_file,
                                   save_best_only = True),

3.3 tensorboard

可以提供运行过程中可视化,需要一个文件夹保存。
函数:

tf.keras.callbacks.TensorBoard(
    log_dir='logs', histogram_freq=0, write_graph=True, write_images=False,
    update_freq='epoch', profile_batch=2, embeddings_freq=0,
    embeddings_metadata=None, **kwargs
)

参数:

log_dir:将要由TensorBoard解析的日志文件保存到的目录路径。
histogram_freq:计算模型各层的激活度和权重直方图的频率(以历元计)。如果设置为0,将不计算直方图。必须为直方图可视化指定验证数据(或拆分)。
write_graph:是否在TensorBoard中可视化图形。当write_graph设置为True时,日志文件可能会变得很大。
write_images:是否编写模型权重以在TensorBoard中可视化为图像。
update_freq:‘batch’或’epoch’或整数。使用时’batch’,每批之后将损失和指标写入TensorBoard。同样适用于’epoch’。如果使用整数,假设1000,回调将每1000批将指标和损失写入TensorBoard。请注意,过于频繁地向TensorBoard写入可能会减慢您的训练速度。
profile_batch:分析批次以采样计算特征。默认情况下,它将配置第二批。将profile_batch = 0设置为禁用分析。必须在TensorFlow急切模式下运行。
embeddings_freq:嵌入层可视化的频率(以历元计)。如果设置为0,则嵌入将不可见。
embeddings_metadata:将层名称映射到文件名的字典,该嵌入层的元数据保存在该文件名中。查看 有关元数据文件格式的 详细信息。如果相同的元数据文件用于所有嵌入层,则可以传递字符串。

举例:
 

logdir = os.path.join("callbacks")
if not os.path.exists(logdir):
    os.mkdir(logdir)

 keras.callbacks.TensorBoard(logdir)

使用:
在程序中使用tensorboard后,打开电脑的powershell,将位置切换到文件夹保存的位置,之后输入如下指令:

tensorboard --logdir=callbacks

返回的一大堆数据最后,会有一个网址,用浏览器打开即可。

3.4 程序举例

当然,上面只是函数介绍,下面用一段程序举例:

#Tensorboart,earlystopping,ModelCheckpoint
#Tensorboart可以提供运行过程中可视化,需要一个文件夹保存
#ModelCheckpoint为保存每次迭代过程中产生的模型,需要一个文件
#EarlyStopping为提前停止

logdir = os.path.join("callbacks")
if not os.path.exists(logdir):
    os.mkdir(logdir)
output_model_file = os.path.join(logdir,
                                'fashion_mnist_model.h5')
callbacks=[
    keras.callbacks.TensorBoard(logdir),
    #这里save_best_only改为True表示保存最好的模型,否则保存最近的模型
    keras.callbacks.ModelCheckpoint(output_model_file,
                                   save_best_only = True),
    keras.callbacks.EarlyStopping(patience=5,min_delta=1e-3)
]

history=model.fit(x_train_scaled,y_train,epochs=10,
         validation_data=(x_valid_scaled,y_valid),
         callbacks = callbacks )

这段程序只是训练部分,如果想具体了解,请阅读本系列之前的博客。

未经允许不得转载!TensorFlow2入门到进阶—— 回调函数