# optuna.trial._fixed source code

import datetime
from typing import Optional

from optuna import distributions
from optuna.distributions import CategoricalDistribution
from optuna.distributions import DiscreteUniformDistribution
from optuna.distributions import IntLogUniformDistribution
from optuna.distributions import IntUniformDistribution
from optuna.distributions import LogUniformDistribution
from optuna.distributions import UniformDistribution
from optuna.trial._base import BaseTrial
from optuna import type_checking

# The names below appear only in type annotations / ``# type:`` comments in this module, so they
# are imported behind a type-checking guard instead of unconditionally.
# NOTE(review): optuna.type_checking.TYPE_CHECKING presumably mirrors typing.TYPE_CHECKING
# (False at runtime) — confirm in optuna.type_checking.
if type_checking.TYPE_CHECKING:
    from typing import Any  # NOQA
    from typing import Dict  # NOQA
    from typing import Sequence  # NOQA
    from typing import Union  # NOQA

    from optuna.distributions import BaseDistribution  # NOQA
    from optuna.distributions import CategoricalChoiceType  # NOQA

    # Alias for the two continuous (non-discretized) float distributions.
    # NOTE(review): not referenced anywhere else in this file as seen here; it may be vestigial
    # or imported by other modules — confirm before removing.
    FloatingPointDistributionType = Union[UniformDistribution, LogUniformDistribution]


class FixedTrial(BaseTrial):
    """A trial class which suggests a fixed value for each parameter.

    This object has the same methods as :class:`~optuna.trial.Trial`, and it suggests pre-defined
    parameter values. The parameter values can be determined at the construction of the
    :class:`~optuna.trial.FixedTrial` object. In contrast to :class:`~optuna.trial.Trial`,
    :class:`~optuna.trial.FixedTrial` does not depend on :class:`~optuna.study.Study`, and it is
    useful for deploying optimization results.

    Example:

        Evaluate an objective function with parameter values given by a user.

        .. testcode::

            import optuna

            def objective(trial):
                x = trial.suggest_uniform('x', -100, 100)
                y = trial.suggest_categorical('y', [-1, 0, 1])
                return x ** 2 + y

            assert objective(optuna.trial.FixedTrial({'x': 1, 'y': 0})) == 1


    .. note::
        Please refer to :class:`~optuna.trial.Trial` for details of methods and properties.

    Args:
        params:
            A dictionary containing all parameters.
        number:
            A trial number. Defaults to ``0``.

    """

    def __init__(self, params, number=0):
        # type: (Dict[str, Any], int) -> None

        # Every value the user supplied; ``_suggest`` reads from it but never modifies it.
        self._params = params
        # Only the parameters that were actually requested via a ``suggest_*`` call.
        self._suggested_params = {}  # type: Dict[str, Any]
        self._distributions = {}  # type: Dict[str, BaseDistribution]
        self._user_attrs = {}  # type: Dict[str, Any]
        self._system_attrs = {}  # type: Dict[str, Any]
        self._datetime_start = datetime.datetime.now()
        self._number = number

    def suggest_float(
        self, name: str, low: float, high: float, *, step: Optional[float] = None, log: bool = False
    ) -> float:
        """Return the fixed value of a floating-point parameter.

        ``step`` selects a :class:`~optuna.distributions.DiscreteUniformDistribution`, ``log``
        selects a :class:`~optuna.distributions.LogUniformDistribution`, and otherwise a
        :class:`~optuna.distributions.UniformDistribution` is used for validation.

        Raises:
            ValueError:
                If both ``step`` and ``log`` are specified, if ``name`` was not given at
                construction, or if its value is not contained in the chosen distribution.
        """
        if step is not None:
            if log:
                raise ValueError("The parameter `step` is not supported when `log` is True.")
            return self._suggest(name, DiscreteUniformDistribution(low=low, high=high, q=step))

        if log:
            return self._suggest(name, LogUniformDistribution(low=low, high=high))
        return self._suggest(name, UniformDistribution(low=low, high=high))

    def suggest_uniform(self, name, low, high):
        # type: (str, float, float) -> float
        """Return the fixed value of ``name``, validated against ``UniformDistribution``."""
        return self._suggest(name, UniformDistribution(low=low, high=high))

    def suggest_loguniform(self, name, low, high):
        # type: (str, float, float) -> float
        """Return the fixed value of ``name``, validated against ``LogUniformDistribution``."""
        return self._suggest(name, LogUniformDistribution(low=low, high=high))

    def suggest_discrete_uniform(self, name: str, low: float, high: float, q: float) -> float:
        """Return the fixed value of ``name``, validated against ``DiscreteUniformDistribution``."""
        discrete = DiscreteUniformDistribution(low=low, high=high, q=q)
        return self._suggest(name, discrete)

    def suggest_int(self, name: str, low: int, high: int, step: int = 1, log: bool = False) -> int:
        """Return the fixed value of an integer parameter, converted with ``int()``.

        Raises:
            ValueError:
                If ``log`` is True while ``step != 1``, if ``name`` was not given at
                construction, or if its value is not contained in the chosen distribution.
        """
        if log and step != 1:
            # Two literals are concatenated, so the separating space must be explicit.
            raise ValueError(
                "The parameter `step != 1` is not supported when `log` is True. "
                "The specified `step` is {}.".format(step)
            )

        if log:
            distribution = IntLogUniformDistribution(
                low=low, high=high, step=step
            )  # type: Union[IntUniformDistribution, IntLogUniformDistribution]
        else:
            distribution = IntUniformDistribution(low=low, high=high, step=step)
        return int(self._suggest(name, distribution))

    def suggest_categorical(self, name, choices):
        # type: (str, Sequence[CategoricalChoiceType]) -> CategoricalChoiceType
        """Return the fixed value of ``name``, validated against ``choices``."""
        choices = tuple(choices)
        return self._suggest(name, CategoricalDistribution(choices=choices))

    def _suggest(self, name, distribution):
        # type: (str, BaseDistribution) -> Any
        """Look up the fixed value for ``name``, validate it, and record it as suggested.

        Raises:
            ValueError:
                If ``name`` was not supplied at construction, or its value is not contained in
                ``distribution``.
        """
        if name not in self._params:
            raise ValueError(
                "The value of the parameter '{}' is not found. Please set it at "
                "the construction of the FixedTrial object.".format(name)
            )

        value = self._params[name]
        # Containment is checked on the internal representation, which is what ``_contains``
        # receives.
        param_value_in_internal_repr = distribution.to_internal_repr(value)
        if not distribution._contains(param_value_in_internal_repr):
            raise ValueError(
                "The value {} of the parameter '{}' is out of "
                "the range of the distribution {}.".format(value, name, distribution)
            )

        # Re-suggesting an existing name is checked against its earlier distribution
        # (presumably raising if they are incompatible — see check_distribution_compatibility).
        if name in self._distributions:
            distributions.check_distribution_compatibility(self._distributions[name], distribution)

        self._suggested_params[name] = value
        self._distributions[name] = distribution
        return value

    def report(self, value, step):
        # type: (float, int) -> None
        """Do nothing: a fixed trial has no study to report intermediate values to."""
        pass

    def should_prune(self) -> bool:
        """Always return ``False``; a fixed trial is never pruned."""
        return False

    def set_user_attr(self, key, value):
        # type: (str, Any) -> None
        """Store ``value`` under ``key`` in this trial's user attributes."""
        self._user_attrs[key] = value

    def set_system_attr(self, key, value):
        # type: (str, Any) -> None
        """Store ``value`` under ``key`` in this trial's system attributes."""
        self._system_attrs[key] = value

    @property
    def params(self):
        # type: () -> Dict[str, Any]
        """Parameters that have been suggested so far (not every value given at construction)."""
        return self._suggested_params

    @property
    def distributions(self):
        # type: () -> Dict[str, BaseDistribution]
        """Distributions used for each suggested parameter, keyed by parameter name."""
        return self._distributions

    @property
    def user_attrs(self):
        # type: () -> Dict[str, Any]
        """User attributes set through :meth:`set_user_attr`."""
        return self._user_attrs

    @property
    def system_attrs(self):
        # type: () -> Dict[str, Any]
        """System attributes set through :meth:`set_system_attr`."""
        return self._system_attrs

    @property
    def datetime_start(self):
        # type: () -> Optional[datetime.datetime]
        """Local time at which this object was constructed."""
        return self._datetime_start

    @property
    def number(self) -> int:
        """Trial number passed at construction."""
        return self._number