optuna.samplers._grid source code

import collections
import itertools
import random

from optuna._experimental import experimental
from optuna.samplers import BaseSampler
from optuna import type_checking

# Names imported here are referenced only from the ``# type:`` comments in this module,
# so they are guarded by ``type_checking.TYPE_CHECKING`` rather than imported unconditionally.
if type_checking.TYPE_CHECKING:
    from typing import Any  # NOQA
    from typing import Dict  # NOQA
    from typing import List  # NOQA
    from typing import Mapping  # NOQA
    from typing import Sequence  # NOQA
    from typing import Union

    from optuna.distributions import BaseDistribution  # NOQA
    from optuna.study import Study  # NOQA
    from optuna.trial import FrozenTrial  # NOQA

    # Types accepted as grid candidate values; mirrors the runtime check in
    # ``GridSampler._check_value``. Defined only under TYPE_CHECKING, so it must not be
    # referenced outside of type comments.
    GridValueType = Union[str, float, int, bool, None]


@experimental("1.2.0")
class GridSampler(BaseSampler):
    """Sampler using grid search.

    With :class:`~optuna.samplers.GridSampler`, the trials suggest all combinations of parameters
    in the given search space during the study.

    Example:

        .. testcode::

            import optuna

            def objective(trial):
                x = trial.suggest_uniform('x', -100, 100)
                y = trial.suggest_int('y', -100, 100)
                return x ** 2 + y ** 2

            search_space = {
                'x': [-50, 0, 50],
                'y': [-99, 0, 99]
            }
            study = optuna.create_study(sampler=optuna.samplers.GridSampler(search_space))
            study.optimize(objective, n_trials=3*3)

    Note:

        :class:`~optuna.samplers.GridSampler` raises an error if all combinations in the passed
        ``search_space`` have already been evaluated. Please make sure that unnecessary trials do
        not run during optimization by properly setting ``n_trials`` in the
        :func:`~optuna.study.Study.optimize` method.

    Note:

        :class:`~optuna.samplers.GridSampler` does not take care of a parameter's quantization
        specified by discrete suggest methods but just samples one of the values specified in
        the search space. E.g., in the following code snippet, either ``-0.5`` or ``0.5`` is
        sampled as ``x`` instead of an integer point.

        .. testcode::

            import optuna

            def objective(trial):
                # The following suggest method specifies integer points between -5 and 5.
                x = trial.suggest_discrete_uniform('x', -5, 5, 1)
                return x ** 2

            # Non-int points are specified in the grid.
            search_space = {'x': [-0.5, 0.5]}
            study = optuna.create_study(sampler=optuna.samplers.GridSampler(search_space))
            study.optimize(objective, n_trials=2)

    Args:
        search_space:
            A dictionary whose key and value are a parameter name and the corresponding candidates
            of values, respectively. Each candidate must be a ``str``, ``int``, ``float``,
            ``bool`` or ``None``; a single parameter may mix these types. Candidates are put into
            a canonical order (``None`` first, then numbers, then strings), so the same search
            space given in a different order maps to the same grid ids.

    Raises:
        ValueError:
            If a candidate value has an unsupported type.
    """

    def __init__(self, search_space):
        # type: (Mapping[str, Sequence[GridValueType]]) -> None

        for param_name, param_values in search_space.items():
            for value in param_values:
                self._check_value(param_name, value)

        # Parameters are ordered by name and candidates by ``_sort_key``; grid ids stored in
        # trial system attrs are indices into ``_all_grids``, so this ordering must be canonical.
        self._search_space = collections.OrderedDict()
        for param_name, param_values in sorted(search_space.items(), key=lambda x: x[0]):
            self._search_space[param_name] = sorted(param_values, key=self._sort_key)

        # ``_all_grids[grid_id]`` is a tuple of values ordered like ``_param_names``.
        self._all_grids = list(itertools.product(*self._search_space.values()))
        self._param_names = sorted(search_space.keys())
        self._n_min_trials = len(self._all_grids)

    def infer_relative_search_space(self, study, trial):
        # type: (Study, FrozenTrial) -> Dict[str, BaseDistribution]
        """Return an empty relative search space; values come from ``sample_independent``."""

        return {}

    def sample_relative(self, study, trial, search_space):
        # type: (Study, FrozenTrial, Dict[str, BaseDistribution]) -> Dict[str, Any]
        """Pick an unvisited grid and record its id on the trial.

        Returns:
            Always an empty dict.

        Raises:
            ValueError:
                If every grid of this search space has already been evaluated.
        """

        # Instead of returning param values, GridSampler puts the target grid id as a system attr,
        # and the values are returned from `sample_independent`. This is because the distribution
        # object is hard to get at the beginning of trial, while we need the access to the object
        # to validate the sampled value.

        unvisited_grids = self._get_unvisited_grid_ids(study)

        if len(unvisited_grids) == 0:
            raise ValueError(
                "All grids have been evaluated. If you want to avoid this error, "
                "please make sure that unnecessary trials do not run during "
                "optimization by properly setting `n_trials` in `study.optimize`."
            )

        # In distributed optimization, multiple workers may simultaneously pick up the same grid.
        # To make the conflict less frequent, the grid is chosen randomly.
        grid_id = random.choice(unvisited_grids)

        study._storage.set_trial_system_attr(trial._trial_id, "search_space", self._search_space)
        study._storage.set_trial_system_attr(trial._trial_id, "grid_id", grid_id)

        return {}

    def sample_independent(self, study, trial, param_name, param_distribution):
        # type: (Study, FrozenTrial, str, BaseDistribution) -> Any
        """Return the value of ``param_name`` from the grid chosen in ``sample_relative``.

        Raises:
            ValueError:
                If ``param_name`` is not in the search space, or the grid value is not contained
                in ``param_distribution``.
        """

        if param_name not in self._search_space:
            message = "The parameter name, {}, is not found in the given grid.".format(param_name)
            raise ValueError(message)

        # TODO(c-bata): Reduce the number of duplicated evaluations on multiple workers.
        # Current selection logic may evaluate the same parameters multiple times.
        # See https://gist.github.com/c-bata/f759f64becb24eea2040f4b2e3afce8f for details.
        grid_id = trial.system_attrs["grid_id"]
        param_value = self._all_grids[grid_id][self._param_names.index(param_name)]
        contains = param_distribution._contains(param_distribution.to_internal_repr(param_value))
        if not contains:
            raise ValueError(
                "The value `{}` is out of range of the parameter `{}`. Please make "
                "sure the search space of the `GridSampler` only contains values "
                "consistent with the distribution specified in the objective "
                "function. The distribution is: `{}`.".format(
                    param_value, param_name, param_distribution
                )
            )

        return param_value

    @staticmethod
    def _check_value(param_name, param_value):
        # type: (str, Any) -> None
        """Raise ``ValueError`` unless ``param_value`` is a supported grid value type."""

        if param_value is None or isinstance(param_value, (str, int, float, bool)):
            return

        raise ValueError(
            "{} contains a value with the type of {}, which is not supported by "
            "`GridSampler`. Please make sure a value is `str`, `int`, `float`, `bool`"
            " or `None`.".format(param_name, type(param_value))
        )

    @staticmethod
    def _sort_key(value):
        # type: (GridValueType) -> Any
        """Return a key that totally orders grid candidate values.

        ``None``, numbers (``bool``/``int``/``float``) and strings are not mutually orderable in
        Python 3, so a plain ``sorted`` raises ``TypeError`` on a mixed candidate list. Values are
        ranked by kind first and compared naturally within a kind, which yields exactly the same
        order as ``sorted`` for homogeneous lists (keeping previously stored grid ids valid).
        """

        if value is None:
            return (0, 0)
        if isinstance(value, str):
            return (2, value)
        return (1, value)

    def _get_unvisited_grid_ids(self, study):
        # type: (Study) -> List[int]
        """Return ids of grids not yet covered by a finished trial with the same search space."""

        # List up unvisited grids based on already finished ones.
        visited_grids = {
            t.system_attrs["grid_id"]
            for t in study.trials
            if t.state.is_finished()
            and "grid_id" in t.system_attrs
            and self._same_search_space(t.system_attrs["search_space"])
        }

        unvisited_grids = set(range(self._n_min_trials)) - visited_grids

        return list(unvisited_grids)

    def _same_search_space(self, search_space):
        # type: (Mapping[str, Sequence[GridValueType]]) -> bool
        """Return whether ``search_space`` has the same names and candidates as this sampler."""

        if set(search_space.keys()) != set(self._search_space.keys()):
            return False

        for param_name in search_space.keys():
            own_values = self._search_space[param_name]
            if len(search_space[param_name]) != len(own_values):
                return False

            # Use the same ordering as ``__init__`` so mixed-type candidate lists
            # (e.g. ``[None, 1, "a"]``) are compared without raising ``TypeError``.
            candidates = sorted(search_space[param_name], key=self._sort_key)
            for param_value, own_value in zip(candidates, own_values):
                if param_value != own_value:
                    return False

        return True