from collections import defaultdict
from typing import Any
from typing import DefaultDict
from typing import Dict
from typing import List
from typing import Optional
from optuna.logging import get_logger
from optuna.study import Study
from optuna.study import StudyDirection
from optuna.trial import TrialState
from optuna.visualization._plotly_imports import _imports
# Module-level logger used to report empty or unplottable studies.
_logger = get_logger(__name__)

# ``go`` is only bound when the plotly import succeeded; the public entry point calls
# ``_imports.check()`` before anything touches it.
if _imports.is_successful():
    from optuna.visualization._plotly_imports import go
def plot_parallel_coordinate(study: Study, params: Optional[List[str]] = None) -> "go.Figure":
    """Plot the high-dimensional parameter relationships in a study.

    Only completed trials are plotted. Note that if a trial is missing a value for any of the
    plotted parameters, that trial is not plotted.

    Example:

        The following code snippet shows how to plot the high-dimensional parameter
        relationships.

        .. testcode::

            import optuna

            def objective(trial):
                x = trial.suggest_uniform('x', -100, 100)
                y = trial.suggest_categorical('y', [-1, 0, 1])
                return x ** 2 + y

            study = optuna.create_study()
            study.optimize(objective, n_trials=10)

            optuna.visualization.plot_parallel_coordinate(study, params=['x', 'y'])

        .. raw:: html

            <iframe src="../_static/plot_parallel_coordinate.html"
             width="100%" height="500px" frameborder="0">
            </iframe>

    Args:
        study:
            A :class:`~optuna.study.Study` object whose trials are plotted for their objective
            values.
        params:
            Parameter list to visualize. The default is all parameters.

    Returns:
        A :class:`plotly.graph_objs.Figure` object.

    Raises:
        ValueError:
            If ``params`` contains a name that does not appear in any completed trial.
    """
    # Verify the optional plotly dependency before any ``go`` object is built.
    _imports.check()
    return _get_parallel_coordinate_plot(study, params)
def _get_parallel_coordinate_plot(study: Study, params: Optional[List[str]] = None) -> "go.Figure":
    """Build the parallel coordinate figure for ``study``.

    Only ``COMPLETE`` trials are considered, and a trial is drawn only if it has a value for
    every plotted parameter. This keeps every dimension listing the same trials in the same
    order.

    Args:
        study:
            The study whose trials are plotted.
        params:
            Names of the parameters to plot. ``None`` plots every parameter that appears in at
            least one completed trial.

    Returns:
        A :class:`plotly.graph_objs.Figure`. It has no data if the study has no completed
        trials, or if no completed trial has all of the plotted parameters.

    Raises:
        ValueError:
            If a name in ``params`` does not appear in any completed trial.
    """
    layout = go.Layout(title="Parallel Coordinate Plot",)

    trials = [trial for trial in study.trials if trial.state == TrialState.COMPLETE]
    if len(trials) == 0:
        _logger.warning("Your study does not have any completed trials.")
        return go.Figure(data=[], layout=layout)

    all_params = {p_name for t in trials for p_name in t.params.keys()}
    if params is not None:
        for input_p_name in params:
            if input_p_name not in all_params:
                raise ValueError("Parameter {} does not exist in your study.".format(input_p_name))
        all_params = set(params)
    sorted_params = sorted(all_params)

    # Each Parcoords dimension is a column matched by position with the objective column.
    # A trial missing any plotted parameter must therefore be dropped from every column.
    # Skipping only the missing value would shift later entries onto the wrong trials.
    trials = [t for t in trials if all(p_name in t.params for p_name in sorted_params)]
    if len(trials) == 0:
        _logger.warning("Your study has only completed trials with missing parameters.")
        return go.Figure(data=[], layout=layout)

    objective_values = tuple(t.value for t in trials)
    dims = [
        {
            "label": "Objective Value",
            "values": objective_values,
            "range": (min(objective_values), max(objective_values)),
        }
    ]  # type: List[Dict[str, Any]]

    for p_name in sorted_params:
        values = [t.params[p_name] for t in trials]

        is_categorical = False
        try:
            tuple(map(float, values))
        except (TypeError, ValueError):
            # Values that are not float-convertible are plotted as integer codes.
            # Codes are assigned in first-seen order.
            vocab = defaultdict(lambda: len(vocab))  # type: DefaultDict[str, int]
            values = [vocab[v] for v in values]
            is_categorical = True

        dim = {
            # Long names are truncated to keep the axis labels readable.
            "label": p_name if len(p_name) < 20 else "{}...".format(p_name[:17]),
            "values": tuple(values),
            "range": (min(values), max(values)),
        }  # type: Dict[str, Any]
        if is_categorical:
            dim["tickvals"] = list(range(len(vocab)))
            # ticktext[i] must be the label for tickvals[i]: the category whose code is i.
            dim["ticktext"] = [label for label, _ in sorted(vocab.items(), key=lambda x: x[1])]
        dims.append(dim)

    traces = [
        go.Parcoords(
            dimensions=dims,
            labelangle=30,
            labelside="bottom",
            line={
                "color": dims[0]["values"],
                "colorscale": "blues",
                "colorbar": {"title": "Objective Value"},
                "showscale": True,
                # Reverse the scale for minimization so that better (lower) values are darker.
                "reversescale": study.direction == StudyDirection.MINIMIZE,
            },
        )
    ]

    figure = go.Figure(data=traces, layout=layout)
    return figure