import optuna
from optuna._imports import try_import

with try_import() as _imports:
    import mxnet as mx  # NOQA


class MXNetPruningCallback(object):
"""MXNet callback to prune unpromising trials.
See `the example <https://github.com/optuna/optuna/blob/master/
examples/pruning/mxnet_integration.py>`__
if you want to add a pruning callback which observes accuracy.
Args:
trial:
A :class:`~optuna.trial.Trial` corresponding to the current evaluation of the
objective function.
eval_metric:
An evaluation metric name for pruning, e.g., ``cross-entropy`` and
``accuracy``. If using default metrics like mxnet.metrics.Accuracy, use it's
default metric name. For custom metrics, use the metric_name provided to
constructor. Please refer to `mxnet.metrics reference
<https://mxnet.apache.org/api/python/metric/metric.html>`_ for further details.
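
    Example:

        A minimal sketch of wiring this callback into an Optuna objective. The
        synthetic data, the two-layer network, and the ``"accuracy"`` metric
        name below are illustrative assumptions, not requirements of this
        integration:

        .. code-block:: python

            import mxnet as mx
            import numpy as np
            import optuna

            # Synthetic data standing in for a real dataset.
            X = np.random.rand(256, 20).astype("float32")
            y = np.random.randint(0, 10, 256)
            train_iter = mx.io.NDArrayIter(X, y, batch_size=32)
            val_iter = mx.io.NDArrayIter(X, y, batch_size=32)

            def objective(trial):
                num_hidden = trial.suggest_int("num_hidden", 16, 128)
                data = mx.sym.Variable("data")
                net = mx.sym.FullyConnected(data, num_hidden=num_hidden)
                net = mx.sym.Activation(net, act_type="relu")
                net = mx.sym.FullyConnected(net, num_hidden=10)
                net = mx.sym.SoftmaxOutput(net, name="softmax")

                model = mx.mod.Module(symbol=net)
                model.fit(
                    train_data=train_iter,
                    eval_data=val_iter,
                    eval_metric="accuracy",
                    # Reports validation accuracy after each epoch and raises
                    # optuna.TrialPruned for unpromising trials.
                    eval_end_callback=MXNetPruningCallback(trial, eval_metric="accuracy"),
                    num_epoch=10,
                )
                return model.score(val_iter, "accuracy")[0][1]

            study = optuna.create_study(
                direction="maximize", pruner=optuna.pruners.MedianPruner()
            )
            study.optimize(objective, n_trials=20)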
"""
    def __init__(self, trial, eval_metric):
        # type: (optuna.trial.Trial, str) -> None

        _imports.check()

        self._trial = trial
        self._eval_metric = eval_metric
    def __call__(self, param):
        # type: (mx.model.BatchEndParam,) -> None

        if param.eval_metric is not None:
            metric_names, metric_values = param.eval_metric.get()
            # ``EvalMetric.get()`` returns either a single name/value pair or,
            # for composite metrics, parallel lists of names and values.
            if isinstance(metric_names, list) and self._eval_metric in metric_names:
                current_score = metric_values[metric_names.index(self._eval_metric)]
            elif metric_names == self._eval_metric:
                current_score = metric_values
            else:
                raise ValueError(
                    'The entry associated with the metric name "{}" '
                    "is not found in the evaluation result list {}.".format(
                        self._eval_metric, str(metric_names)
                    )
                )
            self._trial.report(current_score, step=param.epoch)
            if self._trial.should_prune():
                message = "Trial was pruned at epoch {}.".format(param.epoch)
                raise optuna.TrialPruned(message)