# Source code of optuna.multi_objective.visualization._pareto_front

import json
from typing import List
from typing import Optional

import optuna
from optuna._experimental import experimental
from optuna.multi_objective.study import MultiObjectiveStudy
from optuna.multi_objective.trial import FrozenMultiObjectiveTrial
from optuna.visualization._plotly_imports import _imports

if _imports.is_successful():
    from optuna.visualization._plotly_imports import go

# Module-level logger; the Pareto-front plotting helpers use it to warn when a
# study's Pareto front is empty.
_logger = optuna.logging.get_logger(__name__)


@experimental("2.0.0")
def plot_pareto_front(
    study: MultiObjectiveStudy, names: Optional[List[str]] = None
) -> "go.Figure":
    """Plot the pareto front of a study.

    Only the trials returned by ``study.get_pareto_front_trials()`` are drawn. A 2D scatter
    plot is used for two-objective studies and a 3D scatter plot for three-objective studies.
    Hovering over a marker shows the trial's number, objective values and parameters.

    Example:

        The following code snippet shows how to plot the pareto front of a study.

        .. testcode::

            import optuna

            def objective(trial):
                x = trial.suggest_float("x", 0, 5)
                y = trial.suggest_float("y", 0, 3)

                v0 = (4 * x) ** 2 + (4 * y) ** 2
                v1 = (x - 5) ** 2 + (y - 5) ** 2
                return v0, v1

            study = optuna.multi_objective.create_study(["minimize", "minimize"])
            study.optimize(objective, n_trials=50)

            optuna.multi_objective.visualization.plot_pareto_front(study)

        .. raw:: html

            <iframe src="../../_static/plot_pareto_front.html"
             width="100%" height="500px" frameborder="0">
            </iframe>

    Args:
        study:
            A :class:`~optuna.multi_objective.study.MultiObjectiveStudy` object whose trials
            are plotted for their objective values.
        names:
            Objective name list used as the axis titles. If :obj:`None` is specified,
            "Objective {objective_index}" is used instead.

    Returns:
        A :class:`plotly.graph_objs.Figure` object.

    Raises:
        :exc:`ValueError`:
            If the number of objectives of ``study`` isn't 2 or 3, or if ``names`` is given
            and its length differs from the number of objectives.
    """

    # Ensure the optional plotly import succeeded before building any figure.
    _imports.check()

    if study.n_objectives == 2:
        return _get_pareto_front_2d(study, names)
    elif study.n_objectives == 3:
        return _get_pareto_front_3d(study, names)
    else:
        raise ValueError("`plot_pareto_front` function only supports 2 or 3 objective studies.")


def _get_pareto_front_2d(study: MultiObjectiveStudy, names: Optional[List[str]]) -> "go.Figure":
    """Build a 2D scatter plot of the Pareto-front trials of ``study``.

    Args:
        study:
            The study to plot. :func:`plot_pareto_front` dispatches here when
            ``study.n_objectives == 2``.
        names:
            Titles for the x and y axes, or :obj:`None` to use "Objective 0" and "Objective 1".

    Returns:
        A :class:`plotly.graph_objs.Figure` with one marker per Pareto-front trial.

    Raises:
        :exc:`ValueError`:
            If ``names`` is given and does not contain exactly two entries.
    """
    # Validate ``names`` before querying the study so bad arguments fail fast.
    names = _resolve_objective_names(names, 2)
    trials = _get_pareto_front_trials_or_warn(study)

    data = go.Scatter(
        x=[t.values[0] for t in trials],
        y=[t.values[1] for t in trials],
        text=[_make_hovertext(t) for t in trials],
        mode="markers",
        # An empty ``<extra></extra>`` suppresses plotly's secondary hover box.
        hovertemplate="%{text}<extra></extra>",
    )
    layout = go.Layout(title="Pareto-front Plot", xaxis_title=names[0], yaxis_title=names[1])
    return go.Figure(data=data, layout=layout)


def _get_pareto_front_3d(study: MultiObjectiveStudy, names: Optional[List[str]]) -> "go.Figure":
    """Build a 3D scatter plot of the Pareto-front trials of ``study``.

    Args:
        study:
            The study to plot. :func:`plot_pareto_front` dispatches here when
            ``study.n_objectives == 3``.
        names:
            Titles for the x, y and z axes, or :obj:`None` to use "Objective 0",
            "Objective 1" and "Objective 2".

    Returns:
        A :class:`plotly.graph_objs.Figure` with one marker per Pareto-front trial.

    Raises:
        :exc:`ValueError`:
            If ``names`` is given and does not contain exactly three entries.
    """
    # Validate ``names`` before querying the study so bad arguments fail fast.
    names = _resolve_objective_names(names, 3)
    trials = _get_pareto_front_trials_or_warn(study)

    data = go.Scatter3d(
        x=[t.values[0] for t in trials],
        y=[t.values[1] for t in trials],
        z=[t.values[2] for t in trials],
        text=[_make_hovertext(t) for t in trials],
        mode="markers",
        hovertemplate="%{text}<extra></extra>",
    )
    # Unlike the 2D layout, 3D axis titles are set on the ``scene``, not on ``xaxis``/``yaxis``.
    layout = go.Layout(
        title="Pareto-front Plot",
        scene={"xaxis_title": names[0], "yaxis_title": names[1], "zaxis_title": names[2]},
    )
    return go.Figure(data=data, layout=layout)


def _resolve_objective_names(names: Optional[List[str]], n_objectives: int) -> List[str]:
    """Return the axis titles to use for a plot with ``n_objectives`` objectives.

    If ``names`` is :obj:`None`, this returns ``["Objective 0", ..., "Objective {n - 1}"]``.
    Otherwise ``names`` is returned unchanged after its length has been checked.

    Raises:
        :exc:`ValueError`:
            If ``names`` is given and its length is not ``n_objectives``.
    """
    if names is None:
        return ["Objective {}".format(i) for i in range(n_objectives)]
    if len(names) != n_objectives:
        raise ValueError("The length of `names` is supposed to be {}.".format(n_objectives))
    return names


def _get_pareto_front_trials_or_warn(
    study: MultiObjectiveStudy,
) -> List[FrozenMultiObjectiveTrial]:
    """Return ``study.get_pareto_front_trials()``, logging a warning if it is empty.

    An empty front is not an error. The caller still builds and returns an empty figure.
    """
    trials = study.get_pareto_front_trials()
    if not trials:
        _logger.warning("Your study does not have any completed trials.")
    return trials


def _make_hovertext(trial: FrozenMultiObjectiveTrial) -> str:
    """Format ``trial`` as pretty-printed JSON for use as a plotly hover label.

    The JSON object holds the trial's ``number``, ``values`` and ``params``. Newlines are
    replaced with ``<br>`` so the indentation survives in the HTML-rendered hover label.
    """
    text = json.dumps(
        {"number": trial.number, "values": trial.values, "params": trial.params}, indent=2
    )
    return text.replace("\n", "<br>")