Service API Example on Hartmann6
The Ax Service API is designed to let the user control the scheduling of trials and the computation of data while still interacting with Ax through an easy-to-use interface.
The user iteratively:
- Queries Ax for candidates
- Schedules / deploys them however they choose
- Computes data and logs it to Ax
- Repeats
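The query, schedule, compute, log cycle can be sketched in plain Python. ToyClient below is a hypothetical stand-in that proposes uniformly random points instead of running Bayesian optimization; its get_next_trial / complete_trial method names mirror the AxClient calls used in step 4, but nothing else about it reflects how Ax works internally.

```python
import random


class ToyClient:
    """Hypothetical ask/tell client (not Ax): proposes uniform random points in [0, 1]^d."""

    def __init__(self, names, seed=0):
        self.names = list(names)
        self.rng = random.Random(seed)
        self.next_index = 0
        self.pending = {}    # trial_index -> parameters awaiting results
        self.completed = {}  # trial_index -> (parameters, result)

    def get_next_trial(self):
        """Query step: propose a candidate and mark it as pending."""
        params = {name: self.rng.random() for name in self.names}
        trial_index = self.next_index
        self.next_index += 1
        self.pending[trial_index] = params
        return params, trial_index

    def complete_trial(self, trial_index, raw_data):
        """Log step: attach the computed result to a pending trial."""
        params = self.pending.pop(trial_index)
        self.completed[trial_index] = (params, raw_data)


def objective(params):
    # Toy objective: squared distance from the centre of the unit cube.
    return sum((value - 0.5) ** 2 for value in params.values())


client = ToyClient(["x1", "x2"])
for _ in range(5):
    params, trial_index = client.get_next_trial()  # query Ax for a candidate
    result = objective(params)                     # schedule / compute (locally here)
    client.complete_trial(trial_index, result)     # log the result, then repeat

best = min(client.completed, key=lambda i: client.completed[i][1])
print("best trial:", best, "value:", client.completed[best][1])
```

Because scheduling is entirely in the user's hands, the middle line of the loop is the only part that changes when trials are deployed to an external system.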
# only install if we are running in colab
# kaleido only for plotting
import sys
IN_COLAB = 'google.colab' in sys.modules
if IN_COLAB:
    %pip install ax-platform kaleido
from ax.service.ax_client import AxClient, ObjectiveProperties
from ax.utils.measurement.synthetic_functions import hartmann6
from ax.utils.notebook.plotting import render, init_notebook_plotting
import plotly.io as pio
init_notebook_plotting()
pio.renderers.default = "colab"
[WARNING 07-19 14:05:38] ax.service.utils.with_db_settings_base: Ax currently requires a sqlalchemy version below 2.0. This will be addressed in a future release. Disabling SQL storage in Ax for now, if you would like to use SQL storage please install Ax with mysql extras via `pip install ax-platform[mysql]`.
[INFO 07-19 14:05:38] ax.utils.notebook.plotting: Injecting Plotly library into cell. Do not overwrite or delete cell.
[INFO 07-19 14:05:38] ax.utils.notebook.plotting: Please see
(https://ax.dev/tutorials/visualizations.html#Fix-for-plots-that-are-not-rendering)
if visualizations are not rendering.
1. Initialize client
Create a client object to interface with Ax APIs. By default this runs locally without storage.
ax_client = AxClient()
[INFO 07-19 14:05:38] ax.service.ax_client: Starting optimization with verbose logging. To disable logging, set the `verbose_logging` argument to `False`. Note that float values in the logs are rounded to 6 decimal points.
2. Set up experiment
An experiment consists of a search space (parameters and parameter constraints) and optimization configuration (objectives and outcome constraints). Note that:
- Only the `parameters` and `objectives` arguments are required.
- Dictionaries in `parameters` have the following required keys: “name” (the parameter name), “type” (the parameter type: “range”, “choice” or “fixed”), “bounds” for range parameters, “values” for choice parameters, and “value” for fixed parameters.
- Dictionaries in `parameters` can optionally include “value_type” (“int”, “float”, “bool” or “str”), a “log_scale” flag for range parameters, and an “is_ordered” flag for choice parameters.
- `parameter_constraints` should be a list of strings of the form “p1 >= p2” or “p1 + p2 <= some_bound”.
- `outcome_constraints` should be a list of strings of the form “constrained_metric <= some_bound”.
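As a rough illustration of these rules, here is a hypothetical check_parameter helper that validates a single parameter dict. It encodes only the rules stated above; Ax performs its own validation inside create_experiment and may accept further optional keys that this sketch simply ignores.

```python
# Extra key required for each parameter "type", per the rules listed above.
REQUIRED_BY_TYPE = {"range": "bounds", "choice": "values", "fixed": "value"}
VALUE_TYPES = {"int", "float", "bool", "str"}


def check_parameter(spec):
    """Return a list of rule violations for one parameter dict (empty if none)."""
    problems = [f"missing required key {key!r}" for key in ("name", "type") if key not in spec]
    ptype = spec.get("type")
    if ptype not in REQUIRED_BY_TYPE:
        problems.append(f"unknown parameter type {ptype!r}")
        return problems
    if REQUIRED_BY_TYPE[ptype] not in spec:
        problems.append(f"{ptype} parameter needs {REQUIRED_BY_TYPE[ptype]!r}")
    if "value_type" in spec and spec["value_type"] not in VALUE_TYPES:
        problems.append(f"unsupported value_type {spec['value_type']!r}")
    if "log_scale" in spec and ptype != "range":
        problems.append("'log_scale' only applies to range parameters")
    if "is_ordered" in spec and ptype != "choice":
        problems.append("'is_ordered' only applies to choice parameters")
    return problems


print(check_parameter({"name": "x1", "type": "range", "bounds": [0.0, 1.0]}))  # []
print(check_parameter({"name": "lr", "type": "choice", "log_scale": True}))
```

The second call reports two problems: the choice parameter lacks “values”, and “log_scale” is only meaningful for range parameters.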
ax_client.create_experiment(
    name="hartmann_test_experiment",
    parameters=[
        {
            "name": "x1",
            "type": "range",
            "bounds": [0.0, 1.0],
            "value_type": "float",  # Optional, defaults to inference from type of "bounds".
            "log_scale": False,  # Optional, defaults to False.
        },
        {
            "name": "x2",
            "type": "range",
            "bounds": [0.0, 1.0],
        },
        {
            "name": "x3",
            "type": "range",
            "bounds": [0.0, 1.0],
        },
        {
            "name": "x4",
            "type": "range",
            "bounds": [0.0, 1.0],
        },
        {
            "name": "x5",
            "type": "range",
            "bounds": [0.0, 1.0],
        },
        {
            "name": "x6",
            "type": "range",
            "bounds": [0.0, 1.0],
        },
    ],
    objectives={"hartmann6": ObjectiveProperties(minimize=True)},
    parameter_constraints=["x1 + x2 <= 2.0"],  # Optional.
    outcome_constraints=["l2norm <= 1.25"],  # Optional.
)
[INFO 07-19 14:05:38] ax.service.utils.instantiation: Inferred value type of ParameterType.FLOAT for parameter x2. If that is not the expected value type, you can explicitly specify 'value_type' ('int', 'float', 'bool' or 'str') in parameter dict.
[INFO 07-19 14:05:38] ax.service.utils.instantiation: Inferred value type of ParameterType.FLOAT for parameter x3. If that is not the expected value type, you can explicitly specify 'value_type' ('int', 'float', 'bool' or 'str') in parameter dict.
[INFO 07-19 14:05:38] ax.service.utils.instantiation: Inferred value type of ParameterType.FLOAT for parameter x4. If that is not the expected value type, you can explicitly specify 'value_type' ('int', 'float', 'bool' or 'str') in parameter dict.
[INFO 07-19 14:05:38] ax.service.utils.instantiation: Inferred value type of ParameterType.FLOAT for parameter x5. If that is not the expected value type, you can explicitly specify 'value_type' ('int', 'float', 'bool' or 'str') in parameter dict.
[INFO 07-19 14:05:38] ax.service.utils.instantiation: Inferred value type of ParameterType.FLOAT for parameter x6. If that is not the expected value type, you can explicitly specify 'value_type' ('int', 'float', 'bool' or 'str') in parameter dict.
[INFO 07-19 14:05:38] ax.service.utils.instantiation: Created search space: SearchSpace(parameters=[RangeParameter(name='x1', parameter_type=FLOAT, range=[0.0, 1.0]), RangeParameter(name='x2', parameter_type=FLOAT, range=[0.0, 1.0]), RangeParameter(name='x3', parameter_type=FLOAT, range=[0.0, 1.0]), RangeParameter(name='x4', parameter_type=FLOAT, range=[0.0, 1.0]), RangeParameter(name='x5', parameter_type=FLOAT, range=[0.0, 1.0]), RangeParameter(name='x6', parameter_type=FLOAT, range=[0.0, 1.0])], parameter_constraints=[ParameterConstraint(1.0*x1 + 1.0*x2 <= 2.0)]).
[INFO 07-19 14:05:38] ax.modelbridge.dispatch_utils: Using Models.BOTORCH_MODULAR since there is at least one ordered parameter and there are no unordered categorical parameters.
[INFO 07-19 14:05:38] ax.modelbridge.dispatch_utils: Calculating the number of remaining initialization trials based on num_initialization_trials=None max_initialization_trials=None num_tunable_parameters=6 num_trials=None use_batch_trials=False
[INFO 07-19 14:05:38] ax.modelbridge.dispatch_utils: calculated num_initialization_trials=12
[INFO 07-19 14:05:38] ax.modelbridge.dispatch_utils: num_completed_initialization_trials=0 num_remaining_initialization_trials=12
[INFO 07-19 14:05:38] ax.modelbridge.dispatch_utils: `verbose`, `disable_progbar`, and `jit_compile` are not yet supported when using `choose_generation_strategy` with ModularBoTorchModel, dropping these arguments.
[INFO 07-19 14:05:38] ax.modelbridge.dispatch_utils: Using Bayesian Optimization generation strategy: GenerationStrategy(name='Sobol+BoTorch', steps=[Sobol for 12 trials, BoTorch for subsequent trials]). Iterations after 12 will take longer to generate due to model-fitting.
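To make the constraint strings concrete, the sketch below parses the two forms described in step 2 (“p1 >= p2” and “p1 + p2 <= some_bound”) and checks them against a dict of values taken from the trial-0 entry of the optimization log in step 4. parse_constraint and satisfied are hypothetical helpers, not Ax functions, and they ignore any other syntax Ax may support (such as coefficients).

```python
import operator
import re

COMPARATORS = {"<=": operator.le, ">=": operator.ge}


def parse_constraint(text):
    """Split 'a + b <= 2.0' or 'p1 >= p2' into (lhs_names, comparator, rhs)."""
    match = re.fullmatch(r"\s*(.+?)\s*(<=|>=)\s*(.+?)\s*", text)
    if match is None:
        raise ValueError(f"cannot parse constraint: {text!r}")
    lhs, comparator, rhs = match.groups()
    return [term.strip() for term in lhs.split("+")], comparator, rhs


def satisfied(text, values):
    """Evaluate a constraint string against a dict of parameter or metric values."""
    names, comparator, rhs = parse_constraint(text)
    lhs_value = sum(values[name] for name in names)
    try:
        rhs_value = float(rhs)   # "... <= some_bound"
    except ValueError:
        rhs_value = values[rhs]  # "p1 >= p2"
    return COMPARATORS[comparator](lhs_value, rhs_value)


# Values reported for trial 0 in the optimization log.
trial0 = {"x1": 0.107889, "x2": 0.922518, "l2norm": 1.778068}
print(satisfied("x1 + x2 <= 2.0", trial0))  # True
print(satisfied("l2norm <= 1.25", trial0))  # False
```

Because x1 and x2 are both bounded to [0.0, 1.0], x1 + x2 can never exceed 2.0, so the parameter constraint in this experiment never excludes a point and mainly illustrates the syntax. The outcome constraint on l2norm, by contrast, is active: trial 0 violates it.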
3. Define how to evaluate trials
When using Ax as a service, evaluation of the parameterizations suggested by Ax is done either locally or, more commonly, using an external scheduler. Below is a dummy evaluation function that outputs data for two metrics, “hartmann6” and “l2norm”. Note that all returned metrics correspond either to the objectives set on experiment creation or to the metric names mentioned in outcome_constraints.
import numpy as np


def evaluate(parameters):
    # `parameters` is a dict with keys "x1", ..., "x6"; collect the values in order.
    x = np.array([parameters[f"x{i+1}"] for i in range(6)])
    hartmann6_mean = hartmann6(x)
    l2norm_mean = np.sqrt((x**2).sum())
    # The standard error is 0 because we are computing a noiseless synthetic function.
    hartmann6_std = 0.0
    l2norm_std = 0.0
    return {"hartmann6": (hartmann6_mean, hartmann6_std), "l2norm": (l2norm_mean, l2norm_std)}
The result of the evaluation should generally be a mapping of the form {metric_name -> (mean, SEM)}. If the experiment has only one metric (the objective), the evaluation function can instead return a single (mean, SEM) tuple, in which case Ax assumes that the result corresponds to the objective. It can also return only the mean as a float, in which case Ax treats the SEM as unknown and uses a model that can infer it.
For more details on the evaluation function, refer to the “Trial Evaluation” section of the Ax docs at ax.dev.
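The three accepted return shapes can be summarized with a small normalizing function. normalize_raw_data is a hypothetical illustration of the rules above, not how Ax processes raw_data internally; it represents an unknown SEM as None.

```python
def normalize_raw_data(raw, objective_name):
    """Hypothetical helper: map each accepted return shape to {metric: (mean, sem)}."""
    if isinstance(raw, dict):  # {metric_name: (mean, SEM)}
        return {
            name: (float(mean), None if sem is None else float(sem))
            for name, (mean, sem) in raw.items()
        }
    if isinstance(raw, tuple):  # single (mean, SEM) tuple -> the objective
        mean, sem = raw
        return {objective_name: (float(mean), None if sem is None else float(sem))}
    return {objective_name: (float(raw), None)}  # bare mean -> SEM unknown


print(normalize_raw_data((-1.2, 0.1), "hartmann6"))  # {'hartmann6': (-1.2, 0.1)}
print(normalize_raw_data(-1.2, "hartmann6"))         # {'hartmann6': (-1.2, None)}
```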
4. Run optimization loop
With the experiment set up, we can start the optimization loop.
At each step, the user queries the client for a new trial, then submits the evaluation of that trial back to the client.
Note that Ax auto-selects an appropriate optimization algorithm based on the search space. For more advanced use cases that require a specific optimization algorithm, pass a generation_strategy argument to the AxClient constructor. When Bayesian optimization is used, generating new trials may take a few minutes.
for _ in range(25):
    parameters, trial_index = ax_client.get_next_trial()
    # Local evaluation here can be replaced with deployment to an external system.
    ax_client.complete_trial(trial_index=trial_index, raw_data=evaluate(parameters))
[INFO 07-19 14:05:38] ax.service.ax_client: Generated new trial 0 with parameters {'x1': 0.107889, 'x2': 0.922518, 'x3': 0.924539, 'x4': 0.936394, 'x5': 0.734736, 'x6': 0.165537} using model Sobol.
[INFO 07-19 14:05:38] ax.service.ax_client: Completed trial 0 with data: {'hartmann6': (-0.146128, 0.0), 'l2norm': (np.float64(1.778068), 0.0)}.
[INFO 07-19 14:05:38] ax.service.ax_client: Generated new trial 1 with parameters {'x1': 0.443552, 'x2': 0.488982, 'x3': 0.836648, 'x4': 0.320775, 'x5': 0.276616, 'x6': 0.706204} using model Sobol.
[INFO 07-19 14:05:38] ax.service.ax_client: Completed trial 1 with data: {'hartmann6': (-1.418863, 0.0), 'l2norm': (np.float64(1.346833), 0.0)}.
[INFO 07-19 14:05:38] ax.service.ax_client: Generated new trial 2 with parameters {'x1': 0.65414, 'x2': 0.516116, 'x3': 0.319941, 'x4': 0.942945, 'x5': 0.718464, 'x6': 0.659873} using model Sobol.
[INFO 07-19 14:05:38] ax.service.ax_client: Completed trial 2 with data: {'hartmann6': (-0.002137, 0.0), 'l2norm': (np.float64(1.624009), 0.0)}.
[INFO 07-19 14:05:38] ax.service.ax_client: Generated new trial 3 with parameters {'x1': 0.777313, 'x2': 0.713227, 'x3': 0.231293, 'x4': 0.364289, 'x5': 0.751878, 'x6': 0.620248} using model Sobol.
[INFO 07-19 14:05:39] ax.service.ax_client: Completed trial 3 with data: {'hartmann6': (-0.014062, 0.0), 'l2norm': (np.float64(1.499713), 0.0)}.
[INFO 07-19 14:05:39] ax.service.ax_client: Generated new trial 4 with parameters {'x1': 0.314862, 'x2': 0.406853, 'x3': 0.288613, 'x4': 0.288989, 'x5': 0.714391, 'x6': 0.180744} using model Sobol.
[INFO 07-19 14:05:39] ax.service.ax_client: Completed trial 4 with data: {'hartmann6': (-0.199245, 0.0), 'l2norm': (np.float64(0.987169), 0.0)}.
[INFO 07-19 14:05:39] ax.service.ax_client: Generated new trial 5 with parameters {'x1': 0.243049, 'x2': 0.871583, 'x3': 0.80612, 'x4': 0.597061, 'x5': 0.364388, 'x6': 0.141447} using model Sobol.
[INFO 07-19 14:05:39] ax.service.ax_client: Completed trial 5 with data: {'hartmann6': (-1.760953, 0.0), 'l2norm': (np.float64(1.406352), 0.0)}.
[INFO 07-19 14:05:39] ax.service.ax_client: Generated new trial 6 with parameters {'x1': 0.060675, 'x2': 0.27577, 'x3': 0.448922, 'x4': 0.104325, 'x5': 0.647423, 'x6': 0.607618} using model Sobol.
[INFO 07-19 14:05:39] ax.service.ax_client: Completed trial 6 with data: {'hartmann6': (-0.931146, 0.0), 'l2norm': (np.float64(1.039472), 0.0)}.
[INFO 07-19 14:05:39] ax.service.ax_client: Generated new trial 7 with parameters {'x1': 0.504149, 'x2': 0.67087, 'x3': 0.311699, 'x4': 0.49141, 'x5': 0.022197, 'x6': 0.17924} using model Sobol.
[INFO 07-19 14:05:39] ax.service.ax_client: Completed trial 7 with data: {'hartmann6': (-1.337291, 0.0), 'l2norm': (np.float64(1.037059), 0.0)}.
[INFO 07-19 14:05:39] ax.service.ax_client: Generated new trial 8 with parameters {'x1': 0.602793, 'x2': 0.765291, 'x3': 0.639362, 'x4': 0.046317, 'x5': 0.422133, 'x6': 0.370507} using model Sobol.
[INFO 07-19 14:05:39] ax.service.ax_client: Completed trial 8 with data: {'hartmann6': (-0.135917, 0.0), 'l2norm': (np.float64(1.294384), 0.0)}.
[INFO 07-19 14:05:39] ax.service.ax_client: Generated new trial 9 with parameters {'x1': 0.617003, 'x2': 0.130774, 'x3': 0.901677, 'x4': 0.839172, 'x5': 0.240024, 'x6': 0.922479} using model Sobol.
[INFO 07-19 14:05:39] ax.service.ax_client: Completed trial 9 with data: {'hartmann6': (-0.409906, 0.0), 'l2norm': (np.float64(1.680358), 0.0)}.
[INFO 07-19 14:05:39] ax.service.ax_client: Generated new trial 10 with parameters {'x1': 0.786644, 'x2': 0.828206, 'x3': 0.423388, 'x4': 0.919882, 'x5': 0.152168, 'x6': 0.770439} using model Sobol.
[INFO 07-19 14:05:39] ax.service.ax_client: Completed trial 10 with data: {'hartmann6': (-0.008517, 0.0), 'l2norm': (np.float64(1.716656), 0.0)}.
[INFO 07-19 14:05:39] ax.service.ax_client: Generated new trial 11 with parameters {'x1': 0.291381, 'x2': 0.749571, 'x3': 0.10217, 'x4': 0.911147, 'x5': 0.967016, 'x6': 0.310025} using model Sobol.
[INFO 07-19 14:05:39] ax.service.ax_client: Completed trial 11 with data: {'hartmann6': (-0.230108, 0.0), 'l2norm': (np.float64(1.587017), 0.0)}.
How many trials can run in parallel?
By default, Ax restricts the number of trials that can run in parallel for some optimization stages, in order to improve optimization performance and reduce the number of trials the optimization requires. To check the maximum parallelism for each optimization stage:
ax_client.get_max_parallelism()
[(12, 12), (-1, 3)]
The output of this function is a list of tuples of the form (number of trials, max parallelism), so the example above means “the max parallelism is 12 for the first 12 trials and 3 for all subsequent trials.” This is because the first 12 trials are produced quasi-randomly and can all be evaluated at once, while subsequent trials are produced via Bayesian optimization, which converges on the optimal point in fewer trials when parallelism is limited. A MaxParallelismReachedException indicates that the parallelism limit has been reached; refer to the ‘Service API Exceptions Meaning and Handling’ section at the end of the tutorial for how to handle it.
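The (number of trials, max parallelism) pairs can be read as a lookup table. max_parallelism_for is a hypothetical helper, not part of Ax, that finds the limit applying to a given 0-based trial index, treating a trial count of -1 as “all subsequent trials” as described above.

```python
def max_parallelism_for(schedule, trial_index):
    """Return the parallelism limit for a 0-based trial index.

    `schedule` is a list of (number of trials, max parallelism) tuples;
    a trial count of -1 means "all subsequent trials".
    """
    start = 0
    for num_trials, limit in schedule:
        if num_trials == -1 or trial_index < start + num_trials:
            return limit
        start += num_trials
    raise ValueError(f"trial {trial_index} is not covered by the schedule")


schedule = [(12, 12), (-1, 3)]  # the output shown above
print([max_parallelism_for(schedule, i) for i in (0, 11, 12, 24)])  # [12, 12, 3, 3]
```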
How to view all existing trials during optimization?
ax_client.generation_strategy.trials_as_df
[INFO 01-05 19:07:42] ax.modelbridge.generation_strategy: Note that parameter values in dataframe are rounded to 2 decimal points; the values in the dataframe are thus not the exact ones suggested by Ax in trials.
| | Generation Step | Generation Model | Trial Index | Trial Status | Arm Parameterizations |
|---|---|---|---|---|---|
| 0 | 0 | Sobol | 0 | COMPLETED | {'0_0': {'x1': 0.16, 'x2': 0.82, 'x3': 0.55, '... |
| 1 | 0 | Sobol | 1 | COMPLETED | {'1_0': {'x1': 0.08, 'x2': 0.36, 'x3': 0.55, '... |
| 2 | 0 | Sobol | 2 | COMPLETED | {'2_0': {'x1': 0.83, 'x2': 0.78, 'x3': 0.28, '... |
| 3 | 0 | Sobol | 3 | COMPLETED | {'3_0': {'x1': 0.94, 'x2': 0.73, 'x3': 0.61, '... |
| 4 | 0 | Sobol | 4 | COMPLETED | {'4_0': {'x1': 0.28, 'x2': 0.55, 'x3': 0.82, '... |
| 5 | 0 | Sobol | 5 | COMPLETED | {'5_0': {'x1': 0.3, 'x2': 0.96, 'x3': 0.72, 'x... |
| 6 | 0 | Sobol | 6 | COMPLETED | {'6_0': {'x1': 0.31, 'x2': 0.88, 'x3': 0.87, '... |
| 7 | 0 | Sobol | 7 | COMPLETED | {'7_0': {'x1': 0.99, 'x2': 0.92, 'x3': 0.96, '... |
| 8 | 0 | Sobol | 8 | COMPLETED | {'8_0': {'x1': 0.97, 'x2': 0.52, 'x3': 0.4, 'x... |
| 9 | 0 | Sobol | 9 | COMPLETED | {'9_0': {'x1': 0.54, 'x2': 0.04, 'x3': 0.48, '... |
| 10 | 0 | Sobol | 10 | COMPLETED | {'10_0': {'x1': 0.4, 'x2': 1.0, 'x3': 0.31, 'x... |
| 11 | 0 | Sobol | 11 | COMPLETED | {'11_0': {'x1': 0.87, 'x2': 0.95, 'x3': 0.98, ... |
| 12 | 1 | BoTorch | 0 | COMPLETED | {'0_0': {'x1': 0.16, 'x2': 0.82, 'x3': 0.55, '... |
| 13 | 1 | BoTorch | 1 | COMPLETED | {'1_0': {'x1': 0.08, 'x2': 0.36, 'x3': 0.55, '... |
| 14 | 1 | BoTorch | 2 | COMPLETED | {'2_0': {'x1': 0.83, 'x2': 0.78, 'x3': 0.28, '... |
| 15 | 1 | BoTorch | 3 | COMPLETED | {'3_0': {'x1': 0.94, 'x2': 0.73, 'x3': 0.61, '... |
| 16 | 1 | BoTorch | 4 | COMPLETED | {'4_0': {'x1': 0.28, 'x2': 0.55, 'x3': 0.82, '... |
| 17 | 1 | BoTorch | 5 | COMPLETED | {'5_0': {'x1': 0.3, 'x2': 0.96, 'x3': 0.72, 'x... |
| 18 | 1 | BoTorch | 6 | COMPLETED | {'6_0': {'x1': 0.31, 'x2': 0.88, 'x3': 0.87, '... |
| 19 | 1 | BoTorch | 7 | COMPLETED | {'7_0': {'x1': 0.99, 'x2': 0.92, 'x3': 0.96, '... |
| 20 | 1 | BoTorch | 8 | COMPLETED | {'8_0': {'x1': 0.97, 'x2': 0.52, 'x3': 0.4, 'x... |
| 21 | 1 | BoTorch | 9 | COMPLETED | {'9_0': {'x1': 0.54, 'x2': 0.04, 'x3': 0.48, '... |
| 22 | 1 | BoTorch | 10 | COMPLETED | {'10_0': {'x1': 0.4, 'x2': 1.0, 'x3': 0.31, 'x... |
| 23 | 1 | BoTorch | 11 | COMPLETED | {'11_0': {'x1': 0.87, 'x2': 0.95, 'x3': 0.98, ... |
| 24 | 1 | BoTorch | 12 | COMPLETED | {'12_0': {'x1': 0.37, 'x2': 0.95, 'x3': 0.3, '... |
| 25 | 1 | BoTorch | 13 | COMPLETED | {'13_0': {'x1': 0.4, 'x2': 0.99, 'x3': 0.31, '... |
| 26 | 1 | BoTorch | 14 | COMPLETED | {'14_0': {'x1': 0.44, 'x2': 1.0, 'x3': 0.27, '... |
| 27 | 1 | BoTorch | 15 | COMPLETED | {'15_0': {'x1': 0.39, 'x2': 1.0, 'x3': 0.23, '... |
| 28 | 1 | BoTorch | 16 | COMPLETED | {'16_0': {'x1': 0.38, 'x2': 1.0, 'x3': 0.29, '... |
| 29 | 1 | BoTorch | 17 | COMPLETED | {'17_0': {'x1': 0.33, 'x2': 1.0, 'x3': 0.27, '... |
| 30 | 1 | BoTorch | 18 | COMPLETED | {'18_0': {'x1': 0.37, 'x2': 1.0, 'x3': 0.23, '... |
| 31 | 1 | BoTorch | 19 | COMPLETED | {'19_0': {'x1': 0.39, 'x2': 0.95, 'x3': 0.23, ... |
| 32 | 1 | BoTorch | 20 | COMPLETED | {'20_0': {'x1': 0.41, 'x2': 0.9, 'x3': 0.22, '... |
| 33 | 1 | BoTorch | 21 | COMPLETED | {'21_0': {'x1': 0.42, 'x2': 0.85, 'x3': 0.26, ... |
| 34 | 1 | BoTorch | 22 | COMPLETED | {'22_0': {'x1': 0.39, 'x2': 0.83, 'x3': 0.18, ... |
| 35 | 1 | BoTorch | 23 | COMPLETED | {'23_0': {'x1': 0.42, 'x2': 0.9, 'x3': 0.24, '... |
| 36 | 1 | BoTorch | 24 | COMPLETED | {'24_0': {'x1': 0.41, 'x2': 0.87, 'x3': 0.27, ... |
5. Retrieve best parameters
Once it’s complete, we can access the best parameters found, as well as the corresponding metric values.
best_parameters, values = ax_client.get_best_parameters()
best_parameters
{'x1': 0.40817620096757834,
'x2': 0.8731657716481355,
'x3': 0.27026868276996546,
'x4': 0.5723680198802129,
'x5': 0.27955297611077395,
'x6': 0.055291866177425805}
means, covariances = values
means
{'hartmann6': -3.1273235118485294, 'l2norm': 1.1878074124227307}
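For intuition about what “best” means under an outcome constraint, the sketch below picks the best observed trial (lowest hartmann6 subject to l2norm <= 1.25) from the values logged for the 12 Sobol trials in step 4. best_feasible is a hypothetical helper applying a simplified, purely observational rule; Ax's get_best_parameters can rely on model predictions rather than raw observations, so its answer need not match.

```python
# (hartmann6 mean, l2norm mean) per trial, copied from the Sobol-phase log in step 4.
observed = {
    0: (-0.146128, 1.778068),
    1: (-1.418863, 1.346833),
    2: (-0.002137, 1.624009),
    3: (-0.014062, 1.499713),
    4: (-0.199245, 0.987169),
    5: (-1.760953, 1.406352),
    6: (-0.931146, 1.039472),
    7: (-1.337291, 1.037059),
    8: (-0.135917, 1.294384),
    9: (-0.409906, 1.680358),
    10: (-0.008517, 1.716656),
    11: (-0.230108, 1.587017),
}


def best_feasible(observed, l2norm_bound=1.25):
    """Index of the trial with the lowest hartmann6 among those with l2norm <= bound."""
    feasible = [i for i, (_, l2) in observed.items() if l2 <= l2norm_bound]
    if not feasible:
        return None
    return min(feasible, key=lambda i: observed[i][0])


unconstrained = min(observed, key=lambda i: observed[i][0])
print(unconstrained, best_feasible(observed))  # 5 7
```

Ignoring the constraint would pick trial 5, whose l2norm of about 1.41 violates the 1.25 bound; among the feasible trials (4, 6 and 7), trial 7 is best.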
For comparison, the known global minimum of Hartmann6:
hartmann6.fmin
-3.32237
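The value of hartmann6.fmin can be checked independently. Below is a self-contained NumPy version of the six-dimensional Hartmann function using its standard published constants, f(x) = -Σ_i α_i exp(-Σ_j A_ij (x_j - P_ij)^2). hartmann6_np is our own name, and we assume rather than verify that it matches Ax's hartmann6 on [0, 1]^6. Evaluated at the commonly quoted minimizer it should come out very close to -3.32237.

```python
import numpy as np

ALPHA = np.array([1.0, 1.2, 3.0, 3.2])
A = np.array([
    [10.0, 3.0, 17.0, 3.5, 1.7, 8.0],
    [0.05, 10.0, 17.0, 0.1, 8.0, 14.0],
    [3.0, 3.5, 1.7, 10.0, 17.0, 8.0],
    [17.0, 8.0, 0.05, 10.0, 0.1, 14.0],
])
P = 1e-4 * np.array([
    [1312, 1696, 5569, 124, 8283, 5886],
    [2329, 4135, 8307, 3736, 1004, 9991],
    [2348, 1451, 3522, 2883, 3047, 6650],
    [4047, 8828, 8732, 5743, 1091, 381],
])


def hartmann6_np(x):
    """f(x) = -sum_i alpha_i * exp(-sum_j A_ij * (x_j - P_ij)^2) for x in [0, 1]^6."""
    x = np.asarray(x, dtype=float)
    inner = (A * (x - P) ** 2).sum(axis=1)  # one term per row i, shape (4,)
    return -float(np.sum(ALPHA * np.exp(-inner)))


x_star = [0.20169, 0.150011, 0.476874, 0.275332, 0.311652, 0.6573]
print(hartmann6_np(x_star))
```

The best value found in this run (about -3.127 in the output above) is therefore roughly 0.2 above the global minimum.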
6. Plot the response surface and optimization trace
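Here we plot the response surface of the objective, “hartmann6”, over the parameters “x1” and “x2”. Called without arguments, get_contour_plot uses these two parameters and the objective metric, as the log below shows.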
render(ax_client.get_contour_plot())
[INFO 01-05 19:07:42] ax.service.ax_client: Retrieving contour plot with parameter 'x1' on X-axis and 'x2' on Y-axis, for metric 'hartmann6'. Remaining parameters are affixed to the middle of their range.
We can also retrieve a contour plot for the other metric, “l2norm”. Say we are interested in its response surface over parameters “x3” and “x4”:
render(ax_client.get_contour_plot(param_x="x3", param_y="x4", metric_name="l2norm"))
[INFO 01-05 19:07:45] ax.service.ax_client: Retrieving contour plot with parameter 'x3' on X-axis and 'x4' on Y-axis, for metric 'l2norm'. Remaining parameters are affixed to the middle of their range.
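“Remaining parameters are affixed to the middle of their range” means each plot is a 2-D slice through the 6-D space. The sketch below evaluates a function on such a slice, using the simple l2norm metric so it stays self-contained; slice_grid is a hypothetical helper. Ax's contour plots show model predictions rather than direct evaluations of the true function, so this only illustrates the slicing, not what Ax actually plots.

```python
import numpy as np


def l2norm(x):
    return float(np.sqrt(np.sum(np.asarray(x, dtype=float) ** 2)))


def slice_grid(f, lower, upper, ix, iy, n=5):
    """Evaluate f on an n x n grid over dimensions ix (columns) and iy (rows).

    All other dimensions are held at the midpoint of their range.
    """
    lower = np.asarray(lower, dtype=float)
    upper = np.asarray(upper, dtype=float)
    midpoint = (lower + upper) / 2
    xs = np.linspace(lower[ix], upper[ix], n)
    ys = np.linspace(lower[iy], upper[iy], n)
    z = np.empty((n, n))
    for row, y in enumerate(ys):
        for col, x in enumerate(xs):
            point = midpoint.copy()
            point[ix], point[iy] = x, y
            z[row, col] = f(point)
    return xs, ys, z


# x3 and x4 are indices 2 and 3; all six parameters have bounds [0.0, 1.0].
xs, ys, z = slice_grid(l2norm, [0.0] * 6, [1.0] * 6, ix=2, iy=3)
print(z[0, 0], z[-1, -1])
```

At the corner x3 = x4 = 0 the four fixed parameters contribute 4 × 0.5² = 1, giving l2norm 1.0; at x3 = x4 = 1 it rises to √3.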
Here we plot the optimization trace, which shows how the best objective value found so far progresses over the trials:
# `objective_optimum` is optional.
trace = ax_client.get_optimization_trace(objective_optimum=hartmann6.fmin)
render(trace)
c:\Users\sterg\miniconda3\envs\ac-microcourses\Lib\site-packages\numpy\core\_methods.py:173: RuntimeWarning:
invalid value encountered in subtract
c:\Users\sterg\miniconda3\envs\ac-microcourses\Lib\site-packages\numpy\lib\function_base.py:4655: RuntimeWarning:
invalid value encountered in subtract
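A trace of this kind plots, for each trial, the best objective value seen so far; for a minimization problem that is a running minimum, which NumPy computes with np.minimum.accumulate. The sketch uses the hartmann6 values logged for trials 0 to 11 in step 4 and, for simplicity, ignores the l2norm outcome constraint; Ax's own trace may be computed differently.

```python
import numpy as np

# hartmann6 means logged for trials 0-11 (Sobol phase) in step 4.
values = np.array([
    -0.146128, -1.418863, -0.002137, -0.014062, -0.199245, -1.760953,
    -0.931146, -1.337291, -0.135917, -0.409906, -0.008517, -0.230108,
])
best_so_far = np.minimum.accumulate(values)  # running minimum = trace for minimization
gap_to_optimum = best_so_far - (-3.32237)    # distance from hartmann6.fmin
print(best_so_far)
print(gap_to_optimum[-1])
```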
7. Save / reload optimization to JSON / SQL
We can serialize the state of optimization to JSON and save it to a .json file or save it to the SQL backend. For the former:
ax_client.save_to_json_file() # For custom filepath, pass `filepath` argument.
[INFO 01-05 19:07:46] ax.service.ax_client: Saved JSON-serialized state of optimization to `ax_client_snapshot.json`.
restored_ax_client = AxClient.load_from_json_file()  # For custom filepath, pass `filepath` argument.
[INFO 01-05 19:07:46] ax.service.ax_client: Starting optimization with verbose logging. To disable logging, set the `verbose_logging` argument to `False`. Note that float values in the logs are rounded to 6 decimal points.
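The same save-then-restore cycle can be sketched without Ax as a minimal toy, under stated assumptions: the `state` dict and the helpers `save_snapshot` and `load_snapshot` are hypothetical, and a real snapshot's contents are produced by Ax itself. The sketch writes to a custom file path and only replaces the previous snapshot once the new one has been fully written, so an interrupted save should not leave a half-written file behind.

```python
import json
import os
import tempfile
from typing import Any, Dict


def save_snapshot(state: Dict[str, Any], filepath: str) -> None:
    """Write `state` as JSON to `filepath` (hypothetical helper, not part of Ax)."""
    directory = os.path.dirname(os.path.abspath(filepath))
    # Write to a temporary file in the same directory first...
    fd, tmp_path = tempfile.mkstemp(dir=directory, suffix=".tmp")
    try:
        with os.fdopen(fd, "w") as f:
            json.dump(state, f)
        # ...then swap it into place in a single rename.
        os.replace(tmp_path, filepath)
    except BaseException:
        # Clean up the partial temporary file and re-raise the original error.
        os.remove(tmp_path)
        raise


def load_snapshot(filepath: str) -> Dict[str, Any]:
    """Read back a snapshot written by save_snapshot."""
    with open(filepath) as f:
        return json.load(f)


# With Ax itself, the corresponding calls take a `filepath` argument:
#   ax_client.save_to_json_file(filepath="my_snapshot.json")
#   restored_ax_client = AxClient.load_from_json_file(filepath="my_snapshot.json")
```

Writing to a temporary file and renaming it with `os.replace` is one common way to make the swap a single step; on POSIX systems the rename is atomic.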
To store the state of the optimization in an SQL backend, first follow the setup instructions on the Ax website.
Once the SQL backend is set up, pass DBSettings to AxClient on instantiation (note that the SQLAlchemy dependency must be installed; for installation, refer to the optional dependencies on the Ax website):
# NOTE: Requires running `pip install ax-platform[mysql]` and setting up a database per instructions above (out of scope for this tutorial)
# from ax.storage.sqa_store.structs import DBSettings
# # URL is of the form "dialect+driver://username:password@host:port/database".
# db_settings = DBSettings(url="sqlite:///foo.db")
# # Instead of URL, can provide a `creator function`; can specify custom encoders/decoders if necessary.
# new_ax = AxClient(db_settings=db_settings)
When valid DBSettings are passed to AxClient, a unique experiment name (the name argument) is required by ax_client.create_experiment. The state of the optimization is then auto-saved whenever it changes (e.g. when a new trial is added or completed).
To reload an optimization state later, instantiate AxClient with the same DBSettings and use ax_client.load_experiment_from_database(experiment_name="my_experiment").
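The auto-save and reload-by-name behaviour described above can be illustrated with a toy store built on Python's standard-library `sqlite3`. This is a hypothetical sketch, not Ax's storage layer: the table layout, the `ToyExperimentStore` class, the `add_trial`/`complete_trial` helpers and the JSON-encoded `state` are all assumptions made for illustration. It shows why the experiment name must be unique: it is the key under which the state is written and later looked up.

```python
import json
import sqlite3
from typing import Any, Dict, Optional


class ToyExperimentStore:
    """Hypothetical stand-in for a database backend (not Ax's schema)."""

    def __init__(self, url: str = ":memory:") -> None:
        self.conn = sqlite3.connect(url)
        # The experiment name is the primary key, so it must be unique.
        self.conn.execute(
            "CREATE TABLE IF NOT EXISTS experiments "
            "(name TEXT PRIMARY KEY, state TEXT NOT NULL)"
        )

    def save(self, name: str, state: Dict[str, Any]) -> None:
        # Insert a new row, or overwrite the existing row for this name.
        self.conn.execute(
            "INSERT OR REPLACE INTO experiments (name, state) VALUES (?, ?)",
            (name, json.dumps(state)),
        )
        self.conn.commit()

    def load(self, name: str) -> Optional[Dict[str, Any]]:
        row = self.conn.execute(
            "SELECT state FROM experiments WHERE name = ?", (name,)
        ).fetchone()
        return None if row is None else json.loads(row[0])


def add_trial(store: ToyExperimentStore, name: str, state: Dict[str, Any],
              parameters: Dict[str, float]) -> int:
    """Append a running trial, then auto-save because the state changed."""
    state["trials"].append({"parameters": parameters, "status": "RUNNING"})
    store.save(name, state)
    return len(state["trials"]) - 1


def complete_trial(store: ToyExperimentStore, name: str, state: Dict[str, Any],
                   trial_index: int, value: float) -> None:
    """Mark a trial as completed with its value, then auto-save."""
    state["trials"][trial_index].update(status="COMPLETED", value=value)
    store.save(name, state)
```

Reloading is then just a lookup by name, which is the role `ax_client.load_experiment_from_database(experiment_name="my_experiment")` plays for a real Ax database.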
Special Cases
Evaluation failure: should any optimization iterations fail during evaluation, log_trial_failure will ensure that the same trial is not proposed again.
_, trial_index = ax_client.get_next_trial()
ax_client.log_trial_failure(trial_index=trial_index)
[INFO 01-05 19:08:04] ax.service.ax_client: Generated new trial 25 with parameters {'x1': 0.407242, 'x2': 0.873801, 'x3': 0.315026, 'x4': 0.568915, 'x5': 0.315262, 'x6': 0.015128}.
[INFO 01-05 19:08:04] ax.service.ax_client: Registered failure of trial 25.
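In a real loop, `log_trial_failure` is typically called from an exception handler around the evaluation step. A minimal sketch, assuming an evaluation function that raises when it fails: `run_trial` uses only the AxClient methods `get_next_trial`, `complete_trial` and `log_trial_failure`, while the hypothetical `RecordingClient` stands in for AxClient so the block runs without Ax installed. `fragile_objective` and its metric name are likewise made up for illustration.

```python
from typing import Any, Callable, Dict, List, Tuple


def run_trial(client: Any, evaluate: Callable[[Dict[str, float]], Any]) -> bool:
    """Ask for one trial, evaluate it, and report either data or a failure.

    Returns True if the trial was completed, False if it was logged as failed.
    """
    parameters, trial_index = client.get_next_trial()
    try:
        raw_data = evaluate(parameters)
    except Exception:
        # Evaluation crashed: report the failure instead of leaving the
        # trial running with no data.
        client.log_trial_failure(trial_index=trial_index)
        return False
    client.complete_trial(trial_index=trial_index, raw_data=raw_data)
    return True


class RecordingClient:
    """Hypothetical stand-in for AxClient that just records what it is told."""

    def __init__(self, proposals: List[Dict[str, float]]) -> None:
        self._proposals = proposals
        self.completed: Dict[int, Any] = {}
        self.failed: List[int] = []

    def get_next_trial(self) -> Tuple[Dict[str, float], int]:
        index = len(self.completed) + len(self.failed)
        return self._proposals[index], index

    def complete_trial(self, trial_index: int, raw_data: Any) -> None:
        self.completed[trial_index] = raw_data

    def log_trial_failure(self, trial_index: int) -> None:
        self.failed.append(trial_index)


def fragile_objective(parameters: Dict[str, float]) -> Dict[str, Tuple[float, float]]:
    # Hypothetical objective that fails in part of the search space.
    if parameters["x1"] > 0.5:
        raise RuntimeError("simulation diverged")
    return {"objective": (parameters["x1"] ** 2, 0.0)}


client = RecordingClient([{"x1": 0.2}, {"x1": 0.8}])
first = run_trial(client, fragile_objective)   # completes trial 0
second = run_trial(client, fragile_objective)  # logs trial 1 as failed
```

Passing a real `ax_client` instead of `RecordingClient` should work unchanged, since `run_trial` only relies on those three method names.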
Adding custom trials: should there be a need to evaluate a specific parameterization, attach_trial will add it to the experiment.
ax_client.attach_trial(
parameters={"x1": 0.9, "x2": 0.9, "x3": 0.9, "x4": 0.9, "x5": 0.9, "x6": 0.9}
)
[INFO 01-05 19:08:04] ax.core.experiment: Attached custom parameterizations [{'x1': 0.9, 'x2': 0.9, 'x3': 0.9, 'x4': 0.9, 'x5': 0.9, 'x6': 0.9}] as trial 26.
({'x1': 0.9, 'x2': 0.9, 'x3': 0.9, 'x4': 0.9, 'x5': 0.9, 'x6': 0.9}, 26)
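An attached trial still needs evaluation data before the model can learn from it. A minimal sketch, assuming the objective metric is named "hartmann6" (adjust to the names used in your create_experiment call, and add any other metrics the experiment tracks): the Hartmann6 function is re-implemented here in pure Python with its standard published constants, so the block runs without Ax or NumPy, and `evaluate` returns the {metric_name: (mean, SEM)} form that `complete_trial` accepts as `raw_data`.

```python
import math
from typing import Dict, Sequence, Tuple

# Standard Hartmann6 constants: weights alpha, scales A, and centres P.
ALPHA = (1.0, 1.2, 3.0, 3.2)
A = (
    (10.0, 3.0, 17.0, 3.5, 1.7, 8.0),
    (0.05, 10.0, 17.0, 0.1, 8.0, 14.0),
    (3.0, 3.5, 1.7, 10.0, 17.0, 8.0),
    (17.0, 8.0, 0.05, 10.0, 0.1, 14.0),
)
P = tuple(
    tuple(v * 1e-4 for v in row)
    for row in (
        (1312, 1696, 5569, 124, 8283, 5886),
        (2329, 4135, 8307, 3736, 1004, 9991),
        (2348, 1451, 3522, 2883, 3047, 6650),
        (4047, 8828, 8732, 5743, 1091, 381),
    )
)


def hartmann6(x: Sequence[float]) -> float:
    """Hartmann6 on [0, 1]^6; its global minimum is about -3.32237."""
    total = 0.0
    for alpha_i, a_row, p_row in zip(ALPHA, A, P):
        inner = sum(a * (xj - p) ** 2 for a, xj, p in zip(a_row, x, p_row))
        total += alpha_i * math.exp(-inner)
    return -total


def evaluate(parameters: Dict[str, float]) -> Dict[str, Tuple[float, float]]:
    """Map an x1..x6 parameterization to {metric_name: (mean, SEM)}."""
    x = [parameters[f"x{i}"] for i in range(1, 7)]
    # The metric name is an assumption; SEM 0.0 means a noiseless observation.
    return {"hartmann6": (hartmann6(x), 0.0)}


custom = {"x1": 0.9, "x2": 0.9, "x3": 0.9, "x4": 0.9, "x5": 0.9, "x6": 0.9}
raw_data = evaluate(custom)

# With Ax (not run here):
#   parameters, trial_index = ax_client.attach_trial(parameters=custom)
#   ax_client.complete_trial(trial_index=trial_index, raw_data=evaluate(parameters))
```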
Need to run many trials in parallel: for optimal results and optimization efficiency, we strongly recommend sequential optimization (generating a few trials, then waiting for them to be completed with evaluation data). However, if your use case needs to dispatch many trials in parallel before they are updated with data and you are running into the “All trials for current model have been generated, but not enough data has been observed to fit next model” error, instantiate AxClient as AxClient(enforce_sequential_optimization=False).
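When parallel dispatch is unavoidable, the evaluations themselves can run concurrently. A minimal sketch using the standard library's `concurrent.futures`: `evaluate_in_parallel` is a hypothetical helper that takes a mapping of trial index to parameters (for example, collected from several `ax_client.get_next_trial()` calls on a client created with `enforce_sequential_optimization=False`) and returns the successful results and the failures separately, keyed by trial index, so each can be reported with `complete_trial` or `log_trial_failure`.

```python
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Callable, Dict, Tuple


def evaluate_in_parallel(
    pending: Dict[int, Dict[str, float]],
    evaluate: Callable[[Dict[str, float]], Any],
    max_workers: int = 4,
) -> Tuple[Dict[int, Any], Dict[int, BaseException]]:
    """Evaluate pending trials concurrently.

    `pending` maps trial_index -> parameters. Returns (results, failures):
    results maps trial_index -> evaluate(parameters); failures maps
    trial_index -> the exception that evaluation raised.
    """
    results: Dict[int, Any] = {}
    failures: Dict[int, BaseException] = {}
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        futures = {
            idx: pool.submit(evaluate, params) for idx, params in pending.items()
        }
        for idx, future in futures.items():
            error = future.exception()  # blocks until this evaluation finishes
            if error is None:
                results[idx] = future.result()
            else:
                failures[idx] = error
    return results, failures


# Hypothetical wiring with Ax (not run here):
# ax_client = AxClient(enforce_sequential_optimization=False)
# ... ax_client.create_experiment(...) as earlier in the tutorial ...
# pending = {}
# for _ in range(8):
#     parameters, trial_index = ax_client.get_next_trial()
#     pending[trial_index] = parameters
# results, failures = evaluate_in_parallel(pending, evaluate)
# for trial_index, raw_data in results.items():
#     ax_client.complete_trial(trial_index=trial_index, raw_data=raw_data)
# for trial_index in failures:
#     ax_client.log_trial_failure(trial_index=trial_index)
```

Threads suit evaluations that mostly wait on external work, such as jobs on a cluster or a lab instrument; for CPU-bound pure-Python objectives, swapping in `ProcessPoolExecutor` is the usual choice (the evaluation function must then be picklable).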