import collections
import itertools
import random
from optuna._experimental import experimental
from optuna.samplers import BaseSampler
from optuna import type_checking
if type_checking.TYPE_CHECKING:
    from typing import Any  # NOQA
    from typing import Dict  # NOQA
    from typing import List  # NOQA
    from typing import Mapping  # NOQA
    from typing import Sequence  # NOQA
    from typing import Union
    from optuna.distributions import BaseDistribution  # NOQA
    from optuna.study import Study  # NOQA
    from optuna.trial import FrozenTrial  # NOQA
    GridValueType = Union[str, float, int, bool, None]
[文档]@experimental("1.2.0")
class GridSampler(BaseSampler):
    """Sampler using grid search.
    With :class:`~optuna.samplers.GridSampler`, the trials suggest all combinations of parameters
    in the given search space during the study.
    Example:
        .. testcode::
            import optuna
            def objective(trial):
                x = trial.suggest_uniform('x', -100, 100)
                y = trial.suggest_int('y', -100, 100)
                return x ** 2 + y ** 2
            search_space = {
                'x': [-50, 0, 50],
                'y': [-99, 0, 99]
            }
            study = optuna.create_study(sampler=optuna.samplers.GridSampler(search_space))
            study.optimize(objective, n_trials=3*3)
    Note:
        :class:`~optuna.samplers.GridSampler` raises an error if all combinations in the passed
        ``search_space`` has already been evaluated. Please make sure that unnecessary trials do
        not run during optimization by properly setting ``n_trials`` in the
        :func:`~optuna.study.Study.optimize` method.
    Note:
        :class:`~optuna.samplers.GridSampler` does not take care of a parameter's quantization
        specified by discrete suggest methods but just samples one of values specified in the
        search space. E.g., in the following code snippet, either of ``-0.5`` or ``0.5`` is
        sampled as ``x`` instead of an integer point.
        .. testcode::
            import optuna
            def objective(trial):
                # The following suggest method specifies integer points between -5 and 5.
                x = trial.suggest_discrete_uniform('x', -5, 5, 1)
                return x ** 2
            # Non-int points are specified in the grid.
            search_space = {'x': [-0.5, 0.5]}
            study = optuna.create_study(sampler=optuna.samplers.GridSampler(search_space))
            study.optimize(objective, n_trials=2)
    Args:
        search_space:
            A dictionary whose key and value are a parameter name and the corresponding candidates
            of values, respectively.
    """
    def __init__(self, search_space):
        # type: (Mapping[str, Sequence[GridValueType]]) -> None
        for param_name, param_values in search_space.items():
            for value in param_values:
                self._check_value(param_name, value)
        self._search_space = collections.OrderedDict()
        for param_name, param_values in sorted(search_space.items(), key=lambda x: x[0]):
            self._search_space[param_name] = sorted(param_values)
        self._all_grids = list(itertools.product(*self._search_space.values()))
        self._param_names = sorted(search_space.keys())
        self._n_min_trials = len(self._all_grids)
    def infer_relative_search_space(self, study, trial):
        # type: (Study, FrozenTrial) -> Dict[str, BaseDistribution]
        return {}
    def sample_relative(self, study, trial, search_space):
        # type: (Study, FrozenTrial, Dict[str, BaseDistribution]) -> Dict[str, Any]
        # Instead of returning param values, GridSampler puts the target grid id as a system attr,
        # and the values are returned from `sample_independent`. This is because the distribution
        # object is hard to get at the beginning of trial, while we need the access to the object
        # to validate the sampled value.
        unvisited_grids = self._get_unvisited_grid_ids(study)
        if len(unvisited_grids) == 0:
            raise ValueError(
                "All grids have been evaluated. If you want to avoid this error, "
                "please make sure that unnecessary trials do not run during "
                "optimization by properly setting `n_trials` in `study.optimize`."
            )
        # In distributed optimization, multiple workers may simultaneously pick up the same grid.
        # To make the conflict less frequent, the grid is chosen randomly.
        grid_id = random.choice(unvisited_grids)
        study._storage.set_trial_system_attr(trial._trial_id, "search_space", self._search_space)
        study._storage.set_trial_system_attr(trial._trial_id, "grid_id", grid_id)
        return {}
    def sample_independent(self, study, trial, param_name, param_distribution):
        # type: (Study, FrozenTrial, str, BaseDistribution) -> Any
        if param_name not in self._search_space:
            message = "The parameter name, {}, is not found in the given grid.".format(param_name)
            raise ValueError(message)
        # TODO(c-bata): Reduce the number of duplicated evaluations on multiple workers.
        # Current selection logic may evaluate the same parameters multiple times.
        # See https://gist.github.com/c-bata/f759f64becb24eea2040f4b2e3afce8f for details.
        grid_id = trial.system_attrs["grid_id"]
        param_value = self._all_grids[grid_id][self._param_names.index(param_name)]
        contains = param_distribution._contains(param_distribution.to_internal_repr(param_value))
        if not contains:
            raise ValueError(
                "The value `{}` is out of range of the parameter `{}`. Please make "
                "sure the search space of the `GridSampler` only contains values "
                "consistent with the distribution specified in the objective "
                "function. The distribution is: `{}`.".format(
                    param_value, param_name, param_distribution
                )
            )
        return param_value
    @staticmethod
    def _check_value(param_name, param_value):
        # type: (str, Any) -> None
        if param_value is None or isinstance(param_value, (str, int, float, bool)):
            return
        raise ValueError(
            "{} contains a value with the type of {}, which is not supported by "
            "`GridSampler`. Please make sure a value is `str`, `int`, `float`, `bool`"
            " or `None`.".format(param_name, type(param_value))
        )
    def _get_unvisited_grid_ids(self, study):
        # type: (Study) -> List[int]
        # List up unvisited grids based on already finished ones.
        visited_grids = []
        for t in study.trials:
            if (
                t.state.is_finished()
                and "grid_id" in t.system_attrs
                and self._same_search_space(t.system_attrs["search_space"])
            ):
                visited_grids.append(t.system_attrs["grid_id"])
        unvisited_grids = set(range(self._n_min_trials)) - set(visited_grids)
        return list(unvisited_grids)
    def _same_search_space(self, search_space):
        # type: (Mapping[str, Sequence[GridValueType]]) -> bool
        if set(search_space.keys()) != set(self._search_space.keys()):
            return False
        for param_name in search_space.keys():
            if len(search_space[param_name]) != len(self._search_space[param_name]):
                return False
            for i, param_value in enumerate(sorted(search_space[param_name])):
                if param_value != self._search_space[param_name][i]:
                    return False
        return True