对无望的 Trial 进行剪枝 (Pruning)¶
该功能可以在训练的早期阶段自动终止无望的 Trial (a.k.a., 自动化 early-stopping). Optuna 提供了一些接口,可以用于在迭代训练算法中简洁地实现剪枝 (Pruning)。
开启 Pruner¶
为了打开 Pruning 功能,你需要在迭代式训练的每一步完成后调用函数 report()
和 should_prune()
report()
定期监测这个过程中的目标函数值。should_prune()
根据提前定义好的条件,判定该 trial 是否需要终止。
"""filename: prune.py"""
import sklearn.datasets
import sklearn.linear_model
import sklearn.model_selection
import optuna
def objective(trial):
iris = sklearn.datasets.load_iris()
classes = list(set(iris.target))
train_x, valid_x, train_y, valid_y = \
sklearn.model_selection.train_test_split(iris.data, iris.target, test_size=0.25, random_state=0)
alpha = trial.suggest_loguniform('alpha', 1e-5, 1e-1)
clf = sklearn.linear_model.SGDClassifier(alpha=alpha)
for step in range(100):
clf.partial_fit(train_x, train_y, classes=classes)
# Report intermediate objective value.
intermediate_value = 1.0 - clf.score(valid_x, valid_y)
trial.report(intermediate_value, step)
# Handle pruning based on the intermediate value.
if trial.should_prune():
raise optuna.TrialPruned()
return 1.0 - clf.score(valid_x, valid_y)
# Set up the median stopping rule as the pruning condition.
study = optuna.create_study(pruner=optuna.pruners.MedianPruner())
study.optimize(objective, n_trials=20)
运行上述脚本:
$ python prune.py
[I 2020-06-12 16:54:23,876] Trial 0 finished with value: 0.3157894736842105 and parameters: {'alpha': 0.00181467547181131}. Best is trial 0 with value: 0.3157894736842105.
[I 2020-06-12 16:54:23,981] Trial 1 finished with value: 0.07894736842105265 and parameters: {'alpha': 0.015378744419287613}. Best is trial 1 with value: 0.07894736842105265.
[I 2020-06-12 16:54:24,083] Trial 2 finished with value: 0.21052631578947367 and parameters: {'alpha': 0.04089428832878595}. Best is trial 1 with value: 0.07894736842105265.
[I 2020-06-12 16:54:24,185] Trial 3 finished with value: 0.052631578947368474 and parameters: {'alpha': 0.004018735937374473}. Best is trial 3 with value: 0.052631578947368474.
[I 2020-06-12 16:54:24,303] Trial 4 finished with value: 0.07894736842105265 and parameters: {'alpha': 2.805688697062864e-05}. Best is trial 3 with value: 0.052631578947368474.
[I 2020-06-12 16:54:24,315] Trial 5 pruned.
[I 2020-06-12 16:54:24,355] Trial 6 pruned.
[I 2020-06-12 16:54:24,511] Trial 7 finished with value: 0.052631578947368474 and parameters: {'alpha': 2.243775785299103e-05}. Best is trial 3 with value: 0.052631578947368474.
[I 2020-06-12 16:54:24,625] Trial 8 finished with value: 0.1842105263157895 and parameters: {'alpha': 0.007021209286214553}. Best is trial 3 with value: 0.052631578947368474.
[I 2020-06-12 16:54:24,629] Trial 9 pruned.
...
我们可以在输出信息中看到 Setting status of trial#{} as TrialState.PRUNED
.这意味着这些 trial 在他们完成迭代之前就被终止了。
用于 Pruning 的集成模块¶
为了能更加方便地实现 pruning, Optuna 为以下框架提供了集成模块。
TensorFlow
optuna.integration.TensorFlowPruningHook
PyTorch Ignite
optuna.integration.PyTorchIgnitePruningHandler
PyTorch Lightning
optuna.integration.PyTorchLightningPruningCallback
比如, XGBoostPruningCallback
在无需修改训练迭代逻辑的情况下引入了 pruning.(完整脚本见 example .)
pruning_callback = optuna.integration.XGBoostPruningCallback(trial, 'validation-error')
bst = xgb.train(param, dtrain, evals=[(dvalid, 'validation')], callbacks=[pruning_callback])