optuna.storages._redis 源代码

import copy
from datetime import datetime
import pickle

from optuna._experimental import experimental
from optuna._imports import try_import
from optuna import distributions
from optuna import exceptions
from optuna.storages._base import DEFAULT_STUDY_NAME_PREFIX
from optuna.storages import BaseStorage
from optuna.study import StudyDirection
from optuna.study import StudySummary
from optuna.trial import FrozenTrial
from optuna.trial import TrialState
from optuna import type_checking


with try_import() as _imports:
    import redis


if type_checking.TYPE_CHECKING:
    from typing import Any  # NOQA
    from typing import Dict  # NOQA
    from typing import List  # NOQA
    from typing import Optional  # NOQA


[文档]@experimental("1.4.0") class RedisStorage(BaseStorage): """Storage class for Redis backend. Note that library users can instantiate this class, but the attributes provided by this class are not supposed to be directly accessed by them. Example: We create an :class:`~optuna.storages.RedisStorage` instance using the given redis database URL. .. code:: >>> import optuna >>> >>> def objective(trial): >>> ... >>> >>> storage = optuna.storages.RedisStorage( >>> url='redis://passwd@localhost:port/db', >>> ) >>> >>> study = optuna.create_study(storage=storage) >>> study.optimize(objective) Args: url: URL of the redis storage, password and db are optional. (ie: redis://localhost:6379) .. note:: If you use plan to use Redis as a storage mechanism for optuna, make sure Redis in installed and running. Please execute ``$ pip install -U redis`` to install redis python library. """ def __init__(self, url): # type: (str) -> None _imports.check() self._redis = redis.Redis.from_url(url) def create_new_study(self, study_name=None): # type: (Optional[str]) -> int if study_name is not None and self._redis.exists(self._key_study_name(study_name)): raise exceptions.DuplicatedStudyError if not self._redis.exists("study_counter"): # We need the counter to start with 0. self._redis.set("study_counter", -1) study_id = self._redis.incr("study_counter", 1) # We need the trial_number counter to start with 0. self._redis.set("study_id:{:010d}:trial_number".format(study_id), -1) if study_name is None: study_name = "{}{:010d}".format(DEFAULT_STUDY_NAME_PREFIX, study_id) with self._redis.pipeline() as pipe: pipe.multi() pipe.set(self._key_study_name(study_name), pickle.dumps(study_id)) pipe.set("study_id:{:010d}:study_name".format(study_id), pickle.dumps(study_name)) pipe.set( "study_id:{:010d}:direction".format(study_id), pickle.dumps(StudyDirection.NOT_SET), ) study_summary = StudySummary( study_name=study_name, direction=StudyDirection.NOT_SET, best_trial=None, user_attrs={}, system_attrs={}, n_trials=0, datetime_start=None, study_id=study_id, ) pipe.rpush("study_list", pickle.dumps(study_id)) pipe.set(self._key_study_summary(study_id), pickle.dumps(study_summary)) pipe.execute() return study_id def delete_study(self, study_id): # type: (int) -> None self._check_study_id(study_id) with self._redis.pipeline() as pipe: pipe.multi() # Sumaries pipe.delete(self._key_study_summary(study_id)) pipe.lrem("study_list", 0, pickle.dumps(study_id)) # Trials trial_ids = self._get_study_trials(study_id) for trial_id in trial_ids: pipe.delete("trial_id:{:010d}:frozentrial".format(trial_id)) pipe.delete("trial_id:{:010d}:study_id".format(trial_id)) pipe.delete("study_id:{:010d}:trial_list".format(study_id)) pipe.delete("study_id:{:010d}:trial_number".format(study_id)) # Study study_name = self.get_study_name_from_id(study_id) pipe.delete("study_name:{}:study_id".format(study_name)) pipe.delete("study_id:{:010d}:study_name".format(study_id)) pipe.delete("study_id:{:010d}:direction".format(study_id)) pipe.delete("study_id:{:010d}:best_trial_id".format(study_id)) pipe.delete("study_id:{:010d}:params_distribution".format(study_id)) pipe.execute() @staticmethod def _key_study_name(study_name): # type: (str) -> str return "study_name:{}:study_id".format(study_name) @staticmethod def _key_study_summary(study_id): # type: (int) -> str return "study_id:{:010d}:study_summary".format(study_id) def _set_study_summary(self, study_id, study_summary): # type: (int, StudySummary) -> None self._redis.set(self._key_study_summary(study_id), pickle.dumps(study_summary)) def _get_study_summary(self, study_id): # type: (int) -> StudySummary summary_pkl = self._redis.get(self._key_study_summary(study_id)) assert summary_pkl is not None return pickle.loads(summary_pkl) def _del_study_summary(self, study_id): # type: (int) -> None self._redis.delete(self._key_study_summary(study_id)) @staticmethod def _key_study_direction(study_id): # type: (int) -> str return "study_id:{:010d}:direction".format(study_id) def set_study_direction(self, study_id, direction): # type: (int, StudyDirection) -> None self._check_study_id(study_id) if self._redis.exists(self._key_study_direction(study_id)): direction_pkl = self._redis.get(self._key_study_direction(study_id)) assert direction_pkl is not None current_direction = pickle.loads(direction_pkl) if current_direction != StudyDirection.NOT_SET and current_direction != direction: raise ValueError( "Cannot overwrite study direction from {} to {}.".format( current_direction, direction ) ) with self._redis.pipeline() as pipe: pipe.multi() pipe.set(self._key_study_direction(study_id), pickle.dumps(direction)) study_summary = self._get_study_summary(study_id) study_summary.direction = direction pipe.set(self._key_study_summary(study_id), pickle.dumps(study_summary)) pipe.execute() def set_study_user_attr(self, study_id, key, value): # type: (int, str, Any) -> None self._check_study_id(study_id) study_summary = self._get_study_summary(study_id) study_summary.user_attrs[key] = value self._set_study_summary(study_id, study_summary) def set_study_system_attr(self, study_id, key, value): # type: (int, str, Any) -> None self._check_study_id(study_id) study_summary = self._get_study_summary(study_id) study_summary.system_attrs[key] = value self._set_study_summary(study_id, study_summary) def get_study_id_from_name(self, study_name): # type: (str) -> int if not self._redis.exists(self._key_study_name(study_name)): raise KeyError("No such study {}.".format(study_name)) study_id_pkl = self._redis.get(self._key_study_name(study_name)) assert study_id_pkl is not None return pickle.loads(study_id_pkl) def get_study_id_from_trial_id(self, trial_id): # type: (int) -> int study_id_pkl = self._redis.get("trial_id:{:010d}:study_id".format(trial_id)) if study_id_pkl is None: raise KeyError("No such trial: {}.".format(trial_id)) return pickle.loads(study_id_pkl) def get_study_name_from_id(self, study_id): # type: (int) -> str self._check_study_id(study_id) study_name_pkl = self._redis.get("study_id:{:010d}:study_name".format(study_id)) if study_name_pkl is None: raise KeyError("No such study: {}.".format(study_id)) return pickle.loads(study_name_pkl) def get_study_direction(self, study_id): # type: (int) -> StudyDirection direction_pkl = self._redis.get("study_id:{:010d}:direction".format(study_id)) if direction_pkl is None: raise KeyError("No such study: {}.".format(study_id)) return pickle.loads(direction_pkl) def get_study_user_attrs(self, study_id): # type: (int) -> Dict[str, Any] self._check_study_id(study_id) study_summary = self._get_study_summary(study_id) return copy.deepcopy(study_summary.user_attrs) def get_study_system_attrs(self, study_id): # type: (int) -> Dict[str, Any] self._check_study_id(study_id) study_summary = self._get_study_summary(study_id) return copy.deepcopy(study_summary.system_attrs) @staticmethod def _key_study_param_distribution(study_id): # type: (int) -> str return "study_id:{:010d}:params_distribution".format(study_id) def _get_study_param_distribution(self, study_id): # type: (int) -> Dict if self._redis.exists(self._key_study_param_distribution(study_id)): param_distribution_pkl = self._redis.get(self._key_study_param_distribution(study_id)) assert param_distribution_pkl is not None return pickle.loads(param_distribution_pkl) else: return {} def _set_study_param_distribution(self, study_id, param_distribution): # type: (int, Dict) -> None self._redis.set( self._key_study_param_distribution(study_id), pickle.dumps(param_distribution) ) def get_all_study_summaries(self): # type: () -> List[StudySummary] study_summaries = [] study_ids = [pickle.loads(sid) for sid in self._redis.lrange("study_list", 0, -1)] for study_id in study_ids: study_summary = self._get_study_summary(study_id) study_summaries.append(study_summary) return study_summaries def create_new_trial(self, study_id, template_trial=None): # type: (int, Optional[FrozenTrial]) -> int self._check_study_id(study_id) if template_trial is None: trial = self._create_running_trial() else: trial = copy.deepcopy(template_trial) if not self._redis.exists("trial_counter"): self._redis.set("trial_counter", -1) trial_id = self._redis.incr("trial_counter", 1) trial_number = self._redis.incr("study_id:{:010d}:trial_number".format(study_id)) trial.number = trial_number trial._trial_id = trial_id with self._redis.pipeline() as pipe: pipe.multi() pipe.set(self._key_trial(trial_id), pickle.dumps(trial)) pipe.set("trial_id:{:010d}:study_id".format(trial_id), pickle.dumps(study_id)) pipe.rpush("study_id:{:010d}:trial_list".format(study_id), trial_id) pipe.execute() pipe.multi() study_summary = self._get_study_summary(study_id) study_summary.n_trials = len(self._get_study_trials(study_id)) min_datetime_start = min([t.datetime_start for t in self.get_all_trials(study_id)]) study_summary.datetime_start = min_datetime_start pipe.set(self._key_study_summary(study_id), pickle.dumps(study_summary)) pipe.execute() if trial.state.is_finished(): self._update_cache(trial_id) return trial_id @staticmethod def _create_running_trial(): # type: () -> FrozenTrial return FrozenTrial( trial_id=-1, # dummy value. number=-1, # dummy value. state=TrialState.RUNNING, params={}, distributions={}, user_attrs={}, system_attrs={}, value=None, intermediate_values={}, datetime_start=datetime.now(), datetime_complete=None, ) def set_trial_state(self, trial_id, state): # type: (int, TrialState) -> bool self._check_trial_id(trial_id) trial = self.get_trial(trial_id) self.check_trial_is_updatable(trial_id, trial.state) if state == TrialState.RUNNING and trial.state != TrialState.WAITING: return False trial.state = state if state.is_finished(): trial.datetime_complete = datetime.now() self._redis.set(self._key_trial(trial_id), pickle.dumps(trial)) self._update_cache(trial_id) else: self._redis.set(self._key_trial(trial_id), pickle.dumps(trial)) return True def set_trial_param(self, trial_id, param_name, param_value_internal, distribution): # type: (int, str, float, distributions.BaseDistribution) -> None self._check_trial_id(trial_id) self.check_trial_is_updatable(trial_id, self.get_trial(trial_id).state) # Check param distribution compatibility with previous trial(s). study_id = self.get_study_id_from_trial_id(trial_id) param_distribution = self._get_study_param_distribution(study_id) if param_name in param_distribution: distributions.check_distribution_compatibility( param_distribution[param_name], distribution ) trial = self.get_trial(trial_id) with self._redis.pipeline() as pipe: pipe.multi() # Set study param distribution. param_distribution[param_name] = distribution pipe.set( self._key_study_param_distribution(study_id), pickle.dumps(param_distribution) ) # Set params. trial.params[param_name] = distribution.to_external_repr(param_value_internal) trial.distributions[param_name] = distribution pipe.set(self._key_trial(trial_id), pickle.dumps(trial)) pipe.execute() def get_trial_number_from_id(self, trial_id): # type: (int) -> int return self.get_trial(trial_id).number @staticmethod def _key_best_trial(study_id): # type: (int) -> str return "study_id:{:010d}:best_trial_id".format(study_id) def get_best_trial(self, study_id): # type: (int) -> FrozenTrial if not self._redis.exists(self._key_best_trial(study_id)): all_trials = self.get_all_trials(study_id, deepcopy=False) all_trials = [t for t in all_trials if t.state is TrialState.COMPLETE] if len(all_trials) == 0: raise ValueError("No trials are completed yet.") if self.get_study_direction(study_id) == StudyDirection.MAXIMIZE: best_trial = max(all_trials, key=lambda t: t.value) else: best_trial = min(all_trials, key=lambda t: t.value) self._set_best_trial(study_id, best_trial.number) else: best_trial_id_pkl = self._redis.get(self._key_best_trial(study_id)) assert best_trial_id_pkl is not None best_trial_id = pickle.loads(best_trial_id_pkl) best_trial = self.get_trial(best_trial_id) return best_trial def _set_best_trial(self, study_id, trial_id): # type: (int, int) -> None with self._redis.pipeline() as pipe: pipe.multi() pipe.set(self._key_best_trial(study_id), pickle.dumps(trial_id)) study_summary = self._get_study_summary(study_id) study_summary.best_trial = self.get_trial(trial_id) pipe.set(self._key_study_summary(study_id), pickle.dumps(study_summary)) pipe.execute() def get_trial_param(self, trial_id, param_name): # type: (int, str) -> float distribution = self.get_trial(trial_id).distributions[param_name] return distribution.to_internal_repr(self.get_trial(trial_id).params[param_name]) def set_trial_value(self, trial_id, value): # type: (int, float) -> None self._check_trial_id(trial_id) trial = self.get_trial(trial_id) self.check_trial_is_updatable(trial_id, trial.state) trial.value = value self._redis.set(self._key_trial(trial_id), pickle.dumps(trial)) def _update_cache(self, trial_id): # type: (int) -> None trial = self.get_trial(trial_id) if trial.state != TrialState.COMPLETE: return study_id = self.get_study_id_from_trial_id(trial_id) if not self._redis.exists("study_id:{:010d}:best_trial_id".format(study_id)): self._set_best_trial(study_id, trial_id) return best_value_or_none = self.get_best_trial(study_id).value assert best_value_or_none is not None assert trial.value is not None best_value = float(best_value_or_none) new_value = float(trial.value) # Complete trials do not have `None` values. assert new_value is not None if self.get_study_direction(study_id) == StudyDirection.MAXIMIZE: if new_value > best_value: self._set_best_trial(study_id, trial_id) else: if new_value < best_value: self._set_best_trial(study_id, trial_id) return def set_trial_intermediate_value(self, trial_id, step, intermediate_value): # type: (int, int, float) -> None self._check_trial_id(trial_id) frozen_trial = self.get_trial(trial_id) self.check_trial_is_updatable(trial_id, frozen_trial.state) frozen_trial.intermediate_values[step] = intermediate_value self._set_trial(trial_id, frozen_trial) def set_trial_user_attr(self, trial_id, key, value): # type: (int, str, Any) -> None self._check_trial_id(trial_id) trial = self.get_trial(trial_id) self.check_trial_is_updatable(trial_id, trial.state) trial.user_attrs[key] = value self._set_trial(trial_id, trial) def set_trial_system_attr(self, trial_id, key, value): # type: (int, str, Any) -> None self._check_trial_id(trial_id) trial = self.get_trial(trial_id) self.check_trial_is_updatable(trial_id, trial.state) trial.system_attrs[key] = value self._set_trial(trial_id, trial) @staticmethod def _key_trial(trial_id): # type: (int) -> str return "trial_id:{:010d}:frozentrial".format(trial_id) def get_trial(self, trial_id): # type: (int) -> FrozenTrial self._check_trial_id(trial_id) frozen_trial_pkl = self._redis.get(self._key_trial(trial_id)) assert frozen_trial_pkl is not None return pickle.loads(frozen_trial_pkl) def _set_trial(self, trial_id, trial): # type: (int, FrozenTrial) -> None self._redis.set(self._key_trial(trial_id), pickle.dumps(trial)) def _del_trial(self, trial_id): # type: (int) -> None with self._redis.pipeline() as pipe: pipe.multi() pipe.delete(self._key_trial(trial_id)) pipe.delete("trial_id:{:010d}:study_id".format(trial_id)) pipe.execute() def _get_study_trials(self, study_id): # type: (int) -> List[int] self._check_study_id(study_id) study_trial_list_key = "study_id:{:010d}:trial_list".format(study_id) return [int(tid) for tid in self._redis.lrange(study_trial_list_key, 0, -1)] def get_all_trials(self, study_id, deepcopy=True): # type: (int, bool) -> List[FrozenTrial] self._check_study_id(study_id) trials = [] trial_ids = self._get_study_trials(study_id) for trial_id in trial_ids: frozen_trial = self.get_trial(trial_id) trials.append(frozen_trial) if deepcopy: return copy.deepcopy(trials) else: return trials def get_n_trials(self, study_id, state=None): # type: (int, Optional[TrialState]) -> int self._check_study_id(study_id) if state is None: return len(self.get_all_trials(study_id)) return len([t for t in self.get_all_trials(study_id) if t.state == state]) def read_trials_from_remote_storage(self, study_id: int) -> None: self._check_study_id(study_id) def _check_study_id(self, study_id): # type: (int) -> None if not self._redis.exists("study_id:{:010d}:study_name".format(study_id)): raise KeyError("study_id {} does not exist.".format(study_id)) def _check_trial_id(self, trial_id: int) -> None: if not self._redis.exists(self._key_trial(trial_id)): raise KeyError("study_id {} does not exist.".format(trial_id))