import optuna
with optuna._imports.try_import() as _imports:
import xgboost as xgb # NOQA
def _get_callback_context(env):
# type: (xgb.core.CallbackEnv) -> str
"""Return whether the current callback context is cv or train.
.. note::
`Reference
<https://github.com/dmlc/xgboost/blob/master/python-package/xgboost/callback.py>`_.
"""
if env.model is None and env.cvfolds is not None:
context = "cv"
else:
context = "train"
return context
class XGBoostPruningCallback(object):
    """Callback for XGBoost to prune unpromising trials.

    See `the example <https://github.com/optuna/optuna/blob/master/
    examples/pruning/xgboost_integration.py>`__
    if you want to add a pruning callback which observes validation AUC of
    a XGBoost model.

    Args:
        trial:
            A :class:`~optuna.trial.Trial` corresponding to the current evaluation of the
            objective function.
        observation_key:
            An evaluation metric for pruning, e.g., ``validation-error`` and
            ``validation-merror``. When using the Scikit-Learn API, the index number of
            ``eval_set`` must be included in the ``observation_key``, e.g., ``validation_0-error``
            and ``validation_0-merror``. Please refer to ``eval_metric`` in
            `XGBoost reference <https://xgboost.readthedocs.io/en/latest/parameter.html>`_
            for further details.
    """

    def __init__(self, trial, observation_key):
        # type: (optuna.trial.Trial, str) -> None

        # Fail fast if xgboost could not be imported at module load time.
        _imports.check()

        self._trial = trial
        self._observation_key = observation_key

    def __call__(self, env):
        # type: (xgb.core.CallbackEnv) -> None
        """Report the observed metric to the trial and prune if requested.

        Raises:
            optuna.TrialPruned: If the trial should be pruned at the current iteration.
        """

        context = _get_callback_context(env)
        evaluation_result_list = env.evaluation_result_list
        if context == "cv":
            # Remove a third element: the stddev of the metric across the cross-validation folds.
            evaluation_result_list = [(key, metric) for key, metric, _ in evaluation_result_list]
        current_score = dict(evaluation_result_list)[self._observation_key]
        self._trial.report(current_score, step=env.iteration)
        if self._trial.should_prune():
            message = "Trial was pruned at iteration {}.".format(env.iteration)
            raise optuna.TrialPruned(message)