3.1 Ax Platform Basic Usage
This notebook uses the Ax Platform’s Service API to perform Bayesian optimization. The objective is the Branin function, an analytic function that is commonly used to benchmark optimizers. It takes two parameters as inputs and returns a single value as output. The task is to find the parameter values that minimize the function value.
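For reference, the expression implemented by the `branin` function in this notebook can be written out as below. The search domain matches the bounds passed to `create_experiment`. The listed minimizers follow from setting the squared term to zero where $\cos(x_1) = -1$.

```latex
f(x_1, x_2) = \left(x_2 - \frac{5.1}{4\pi^2}\,x_1^2 + \frac{5}{\pi}\,x_1 - 6\right)^2
            + 10\left(1 - \frac{1}{8\pi}\right)\cos(x_1) + 10,
\qquad x_1 \in [-5, 10],\; x_2 \in [0, 15]

f_{\min} = \frac{10}{8\pi} \approx 0.397887
\quad \text{at} \quad
(x_1, x_2) \in \{(-\pi,\, 12.275),\; (\pi,\, 2.275),\; (3\pi,\, 2.475)\}
```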
# only install if we are running in colab
import sys

IN_COLAB = 'google.colab' in sys.modules
if IN_COLAB:
    %pip install ax-platform
import math

from ax.service.ax_client import AxClient, ObjectiveProperties

obj1_name = "branin"

def branin(x1, x2):
    y = float(
        (x2 - 5.1 / (4 * math.pi**2) * x1**2 + 5.0 / math.pi * x1 - 6.0) ** 2
        + 10 * (1 - 1.0 / (8 * math.pi)) * math.cos(x1)
        + 10
    )
    return y

ax_client = AxClient()
ax_client.create_experiment(
    parameters=[
        {"name": "x1", "type": "range", "bounds": [-5.0, 10.0]},
        {"name": "x2", "type": "range", "bounds": [0.0, 15.0]},
    ],
    objectives={
        obj1_name: ObjectiveProperties(minimize=True),
    },
)

for _ in range(15):
    parameters, trial_index = ax_client.get_next_trial()
    results = branin(
        parameters["x1"],
        parameters["x2"],
    )
    ax_client.complete_trial(trial_index=trial_index, raw_data=results)

best_parameters, metrics = ax_client.get_best_parameters()
[WARNING 07-19 14:05:57] ax.service.utils.with_db_settings_base: Ax currently requires a sqlalchemy version below 2.0. This will be addressed in a future release. Disabling SQL storage in Ax for now, if you would like to use SQL storage please install Ax with mysql extras via `pip install ax-platform[mysql]`.
[INFO 07-19 14:05:57] ax.service.ax_client: Starting optimization with verbose logging. To disable logging, set the `verbose_logging` argument to `False`. Note that float values in the logs are rounded to 6 decimal points.
[INFO 07-19 14:05:57] ax.service.utils.instantiation: Inferred value type of ParameterType.FLOAT for parameter x1. If that is not the expected value type, you can explicitly specify 'value_type' ('int', 'float', 'bool' or 'str') in parameter dict.
[INFO 07-19 14:05:57] ax.service.utils.instantiation: Inferred value type of ParameterType.FLOAT for parameter x2. If that is not the expected value type, you can explicitly specify 'value_type' ('int', 'float', 'bool' or 'str') in parameter dict.
[INFO 07-19 14:05:57] ax.service.utils.instantiation: Created search space: SearchSpace(parameters=[RangeParameter(name='x1', parameter_type=FLOAT, range=[-5.0, 10.0]), RangeParameter(name='x2', parameter_type=FLOAT, range=[0.0, 15.0])], parameter_constraints=[]).
[INFO 07-19 14:05:57] ax.modelbridge.dispatch_utils: Using Models.BOTORCH_MODULAR since there is at least one ordered parameter and there are no unordered categorical parameters.
[INFO 07-19 14:05:57] ax.modelbridge.dispatch_utils: Calculating the number of remaining initialization trials based on num_initialization_trials=None max_initialization_trials=None num_tunable_parameters=2 num_trials=None use_batch_trials=False
[INFO 07-19 14:05:57] ax.modelbridge.dispatch_utils: calculated num_initialization_trials=5
[INFO 07-19 14:05:57] ax.modelbridge.dispatch_utils: num_completed_initialization_trials=0 num_remaining_initialization_trials=5
[INFO 07-19 14:05:57] ax.modelbridge.dispatch_utils: `verbose`, `disable_progbar`, and `jit_compile` are not yet supported when using `choose_generation_strategy` with ModularBoTorchModel, dropping these arguments.
[INFO 07-19 14:05:57] ax.modelbridge.dispatch_utils: Using Bayesian Optimization generation strategy: GenerationStrategy(name='Sobol+BoTorch', steps=[Sobol for 5 trials, BoTorch for subsequent trials]). Iterations after 5 will take longer to generate due to model-fitting.
[INFO 07-19 14:05:57] ax.service.ax_client: Generated new trial 0 with parameters {'x1': 7.46005, 'x2': 11.050261} using model Sobol.
[INFO 07-19 14:05:57] ax.service.ax_client: Completed trial 0 with data: {'branin': (108.433893, None)}.
[INFO 07-19 14:05:57] ax.service.ax_client: Generated new trial 1 with parameters {'x1': 9.75531, 'x2': 0.75612} using model Sobol.
[INFO 07-19 14:05:57] ax.service.ax_client: Completed trial 1 with data: {'branin': (4.965009, None)}.
[INFO 07-19 14:05:57] ax.service.ax_client: Generated new trial 2 with parameters {'x1': 6.443358, 'x2': 2.343874} using model Sobol.
[INFO 07-19 14:05:57] ax.service.ax_client: Completed trial 2 with data: {'branin': (21.005569, None)}.
[INFO 07-19 14:05:57] ax.service.ax_client: Generated new trial 3 with parameters {'x1': 5.215967, 'x2': 8.146018} using model Sobol.
[INFO 07-19 14:05:57] ax.service.ax_client: Completed trial 3 with data: {'branin': (62.69807, None)}.
[INFO 07-19 14:05:57] ax.service.ax_client: Generated new trial 4 with parameters {'x1': 8.841575, 'x2': 8.930721} using model Sobol.
[INFO 07-19 14:05:57] ax.service.ax_client: Completed trial 4 with data: {'branin': (49.646523, None)}.
---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
Cell In[2], line 27
16 ax_client.create_experiment(
17 parameters=[
18 {"name": "x1", "type": "range", "bounds": [-5.0, 10.0]},
(...)
23 },
24 )
26 for _ in range(15):
---> 27 parameters, trial_index = ax_client.get_next_trial()
28 results = branin(
29 parameters["x1"],
30 parameters["x2"],
31 )
32 ax_client.complete_trial(trial_index=trial_index, raw_data=results)
File ~/checkouts/readthedocs.org/user_builds/ac-microcourses/envs/stable/lib/python3.11/site-packages/ax/utils/common/executils.py:163, in retry_on_exception.<locals>.func_wrapper.<locals>.actual_wrapper(*args, **kwargs)
159 wait_interval = min(
160 MAX_WAIT_SECONDS, initial_wait_seconds * 2 ** (i - 1)
161 )
162 time.sleep(wait_interval)
--> 163 return func(*args, **kwargs)
165 # If we are here, it means the retries were finished but
166 # The error was suppressed. Hence return the default value provided.
167 return default_return_on_suppression
File ~/checkouts/readthedocs.org/user_builds/ac-microcourses/envs/stable/lib/python3.11/site-packages/ax/service/ax_client.py:539, in AxClient.get_next_trial(self, ttl_seconds, force, fixed_features)
535 raise OptimizationShouldStop(message=global_stopping_message)
537 try:
538 trial = self.experiment.new_trial(
--> 539 generator_run=self._gen_new_generator_run(
540 fixed_features=fixed_features
541 ),
542 ttl_seconds=ttl_seconds,
543 )
544 except MaxParallelismReachedException as e:
545 if self._early_stopping_strategy is not None:
File ~/checkouts/readthedocs.org/user_builds/ac-microcourses/envs/stable/lib/python3.11/site-packages/ax/service/ax_client.py:1790, in AxClient._gen_new_generator_run(self, n, fixed_features)
1782 fixed_feats = (
1783 InstantiationBase.make_fixed_observation_features(
1784 fixed_features=fixed_features
(...)
1787 else None
1788 )
1789 with with_rng_seed(seed=self._random_seed):
-> 1790 return not_none(self.generation_strategy).gen(
1791 experiment=self.experiment,
1792 n=n,
1793 pending_observations=self._get_pending_observation_features(
1794 experiment=self.experiment
1795 ),
1796 fixed_features=fixed_feats,
1797 )
File ~/checkouts/readthedocs.org/user_builds/ac-microcourses/envs/stable/lib/python3.11/site-packages/ax/modelbridge/generation_strategy.py:370, in GenerationStrategy.gen(self, experiment, data, n, pending_observations, **kwargs)
335 def gen(
336 self,
337 experiment: Experiment,
(...)
341 **kwargs: Any,
342 ) -> GeneratorRun:
343 """Produce the next points in the experiment. Additional kwargs passed to
344 this method are propagated directly to the underlying model's `gen`, along
345 with the `model_gen_kwargs` set on the current generation node.
(...)
368 resuggesting points that are currently being evaluated.
369 """
--> 370 return self._gen_multiple(
371 experiment=experiment,
372 num_generator_runs=1,
373 data=data,
374 n=n,
375 pending_observations=pending_observations,
376 **kwargs,
377 )[0]
File ~/checkouts/readthedocs.org/user_builds/ac-microcourses/envs/stable/lib/python3.11/site-packages/ax/modelbridge/generation_strategy.py:683, in GenerationStrategy._gen_multiple(self, experiment, num_generator_runs, data, n, pending_observations, **model_gen_kwargs)
681 for _ in range(num_generator_runs):
682 try:
--> 683 generator_run = self._curr.gen(
684 n=n,
685 pending_observations=pending_observations,
686 arms_by_signature_for_deduplication=experiment.arms_by_signature,
687 **model_gen_kwargs,
688 )
690 except DataRequiredError as err:
691 # Model needs more data, so we log the error and return
692 # as many generator runs as we were able to produce, unless
693 # no trials were produced at all (in which case its safe to raise).
694 if len(generator_runs) == 0:
File ~/checkouts/readthedocs.org/user_builds/ac-microcourses/envs/stable/lib/python3.11/site-packages/ax/modelbridge/generation_node.py:712, in GenerationStep.gen(self, n, pending_observations, max_gen_draws_for_deduplication, arms_by_signature_for_deduplication, **model_gen_kwargs)
704 def gen(
705 self,
706 n: Optional[int] = None,
(...)
710 **model_gen_kwargs: Any,
711 ) -> GeneratorRun:
--> 712 gr = super().gen(
713 n=n,
714 pending_observations=pending_observations,
715 max_gen_draws_for_deduplication=max_gen_draws_for_deduplication,
716 arms_by_signature_for_deduplication=arms_by_signature_for_deduplication,
717 **model_gen_kwargs,
718 )
719 gr._generation_step_index = self.index
720 return gr
File ~/checkouts/readthedocs.org/user_builds/ac-microcourses/envs/stable/lib/python3.11/site-packages/ax/modelbridge/generation_node.py:272, in GenerationNode.gen(self, n, pending_observations, max_gen_draws_for_deduplication, arms_by_signature_for_deduplication, **model_gen_kwargs)
269 # Keep generating until each of `generator_run.arms` is not a duplicate
270 # of a previous arm, if `should_deduplicate is True`
271 while should_generate_run:
--> 272 generator_run = self._gen(
273 n=n,
274 pending_observations=pending_observations,
275 **model_gen_kwargs,
276 )
277 should_generate_run = (
278 self.should_deduplicate
279 and arms_by_signature_for_deduplication
(...)
283 )
284 )
285 n_gen_draws += 1
File ~/checkouts/readthedocs.org/user_builds/ac-microcourses/envs/stable/lib/python3.11/site-packages/ax/modelbridge/generation_node.py:334, in GenerationNode._gen(self, n, pending_observations, **model_gen_kwargs)
330 if n is None and model_spec.model_gen_kwargs:
331 # If `n` is not specified, ensure that the `None` value does not
332 # override the one set in `model_spec.model_gen_kwargs`.
333 n = model_spec.model_gen_kwargs.get("n", None)
--> 334 return model_spec.gen(
335 n=n,
336 # For `pending_observations`, prefer the input to this function, as
337 # `pending_observations` are dynamic throughout the experiment and thus
338 # unlikely to be specified in `model_spec.model_gen_kwargs`.
339 pending_observations=pending_observations,
340 **model_gen_kwargs,
341 )
File ~/checkouts/readthedocs.org/user_builds/ac-microcourses/envs/stable/lib/python3.11/site-packages/ax/modelbridge/model_spec.py:221, in ModelSpec.gen(self, **model_gen_kwargs)
213 fitted_model = self.fitted_model
214 model_gen_kwargs = consolidate_kwargs(
215 kwargs_iterable=[
216 self.model_gen_kwargs,
(...)
219 keywords=get_function_argument_names(fitted_model.gen),
220 )
--> 221 return fitted_model.gen(**model_gen_kwargs)
File ~/checkouts/readthedocs.org/user_builds/ac-microcourses/envs/stable/lib/python3.11/site-packages/ax/modelbridge/base.py:786, in ModelBridge.gen(self, n, search_space, optimization_config, pending_observations, fixed_features, model_gen_options)
779 base_gen_args = self._get_transformed_gen_args(
780 search_space=search_space,
781 optimization_config=optimization_config,
782 pending_observations=pending_observations,
783 fixed_features=fixed_features,
784 )
785 # Apply terminal transform and gen
--> 786 gen_results = self._gen(
787 n=n,
788 search_space=base_gen_args.search_space,
789 optimization_config=base_gen_args.optimization_config,
790 pending_observations=base_gen_args.pending_observations,
791 fixed_features=base_gen_args.fixed_features,
792 model_gen_options=model_gen_options,
793 )
795 observation_features = gen_results.observation_features
796 best_obsf = gen_results.best_observation_features
File ~/checkouts/readthedocs.org/user_builds/ac-microcourses/envs/stable/lib/python3.11/site-packages/ax/modelbridge/torch.py:721, in TorchModelBridge._gen(self, n, search_space, pending_observations, fixed_features, model_gen_options, optimization_config)
716 observation_features = self._array_to_observation_features(
717 X=gen_results.points.detach().cpu().clone().numpy(),
718 candidate_metadata=gen_results.candidate_metadata,
719 )
720 try:
--> 721 xbest = not_none(self.model).best_point(
722 search_space_digest=search_space_digest,
723 torch_opt_config=torch_opt_config,
724 )
725 except NotImplementedError:
726 xbest = None
File ~/checkouts/readthedocs.org/user_builds/ac-microcourses/envs/stable/lib/python3.11/site-packages/ax/models/torch/botorch_modular/model.py:79, in single_surrogate_only.<locals>.impl(self, *args, **kwargs)
74 if len(self._surrogates) != 1:
75 raise NotImplementedError(
76 f"{f.__name__} not implemented for multi-surrogate case. Found "
77 f"{self.surrogates=}."
78 )
---> 79 return f(self, *args, **kwargs)
File ~/checkouts/readthedocs.org/user_builds/ac-microcourses/envs/stable/lib/python3.11/site-packages/ax/models/torch/botorch_modular/model.py:479, in BoTorchModel.best_point(self, search_space_digest, torch_opt_config)
471 @copy_doc(TorchModel.best_point)
472 @single_surrogate_only
473 def best_point(
(...)
476 torch_opt_config: TorchOptConfig,
477 ) -> Optional[Tensor]:
478 try:
--> 479 return self.surrogate.best_in_sample_point(
480 search_space_digest=search_space_digest,
481 torch_opt_config=torch_opt_config,
482 )[0]
483 except ValueError:
484 return None
File ~/checkouts/readthedocs.org/user_builds/ac-microcourses/envs/stable/lib/python3.11/site-packages/ax/models/torch/botorch_modular/surrogate.py:622, in Surrogate.best_in_sample_point(self, search_space_digest, torch_opt_config, options)
618 if torch_opt_config.is_moo:
619 raise NotImplementedError(
620 "Best observed point is incompatible with MOO problems."
621 )
--> 622 best_point_and_observed_value = best_in_sample_point(
623 Xs=self.Xs,
624 model=self,
625 bounds=search_space_digest.bounds,
626 objective_weights=torch_opt_config.objective_weights,
627 outcome_constraints=torch_opt_config.outcome_constraints,
628 linear_constraints=torch_opt_config.linear_constraints,
629 fixed_features=torch_opt_config.fixed_features,
630 risk_measure=torch_opt_config.risk_measure,
631 options=options,
632 )
633 if best_point_and_observed_value is None:
634 raise ValueError("Could not obtain best in-sample point.")
File ~/checkouts/readthedocs.org/user_builds/ac-microcourses/envs/stable/lib/python3.11/site-packages/ax/models/model_utils.py:464, in best_in_sample_point(Xs, model, bounds, objective_weights, outcome_constraints, linear_constraints, fixed_features, risk_measure, options)
462 raise UnsupportedError(f"Unknown best point method {method}.")
463 i = np.argmax(utility)
--> 464 if utility[i] == -np.Inf:
465 return None
466 else:
File ~/checkouts/readthedocs.org/user_builds/ac-microcourses/envs/stable/lib/python3.11/site-packages/numpy/__init__.py:397, in __getattr__(attr)
394 raise AttributeError(__former_attrs__[attr])
396 if attr in __expired_attributes__:
--> 397 raise AttributeError(
398 f"`np.{attr}` was removed in the NumPy 2.0 release. "
399 f"{__expired_attributes__[attr]}"
400 )
402 if attr == "chararray":
403 warnings.warn(
404 "`np.chararray` is deprecated and will be removed from "
405 "the main namespace in the future. Use an array with a string "
406 "or bytes dtype instead.", DeprecationWarning, stacklevel=2)
AttributeError: `np.Inf` was removed in the NumPy 2.0 release. Use `np.inf` instead.
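The traceback above ends in `ax/models/model_utils.py`, where the installed Ax release still refers to `np.Inf`, an alias that NumPy 2.0 removed. A common workaround is to pin `numpy<2` in the environment, or to upgrade to an Ax release that uses `np.inf` (which release that is has not been checked here). The sketch below uses only the standard library to check whether the installed NumPy would trigger this error. The helper name `removes_np_Inf` is hypothetical and not part of Ax or NumPy.

```python
from importlib import metadata


def removes_np_Inf(numpy_version: str) -> bool:
    """Return True if this NumPy version no longer provides `np.Inf` (2.0+)."""
    major = int(numpy_version.split(".")[0])
    return major >= 2


try:
    installed = metadata.version("numpy")
except metadata.PackageNotFoundError:
    installed = None  # NumPy is not installed in this environment

if installed is None:
    print("NumPy is not installed.")
elif removes_np_Inf(installed):
    print(f"NumPy {installed}: `np.Inf` is gone; pin `numpy<2` or upgrade Ax.")
else:
    print(f"NumPy {installed}: `np.Inf` is still available.")
```

In a notebook, the pin could be applied with `%pip install "numpy<2"` before importing Ax, followed by a kernel restart.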
print(best_parameters, metrics)
{'x1': 9.42045994053506, 'x2': 2.522998247072662} ({'branin': 0.6237735107536508}, {'branin': {'branin': 0.27541560401879245}})
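To our understanding of the Service API, the tuple printed after the best parameters holds the model's predicted mean for each metric, followed by the predicted covariance, so the reported `branin` value is a model estimate and not a direct function evaluation. As a sanity check, the sketch below re-evaluates the Branin function at the reported point and compares it with the known global minimum. The names `KNOWN_MINIMA` and `GLOBAL_MIN` are introduced here for illustration, and the best parameters are copied by hand from the output above.

```python
import math


def branin(x1, x2):
    """Branin function, same expression as the notebook's `branin`."""
    return float(
        (x2 - 5.1 / (4 * math.pi**2) * x1**2 + 5.0 / math.pi * x1 - 6.0) ** 2
        + 10 * (1 - 1.0 / (8 * math.pi)) * math.cos(x1)
        + 10
    )


# Known global minimizers of Branin on x1 in [-5, 10], x2 in [0, 15].
KNOWN_MINIMA = [(-math.pi, 12.275), (math.pi, 2.275), (3 * math.pi, 2.475)]
GLOBAL_MIN = 10 / (8 * math.pi)  # function value at each minimizer

# Best parameters copied from the printed output above.
best = {"x1": 9.42045994053506, "x2": 2.522998247072662}
best_point = (best["x1"], best["x2"])

# Evaluate the true function at the reported point.
observed = branin(*best_point)

# Find which known minimizer the reported point is closest to.
nearest = min(KNOWN_MINIMA, key=lambda p: math.dist(p, best_point))

print(f"true value at best point: {observed:.4f}")
print(f"global minimum value:     {GLOBAL_MIN:.4f}")
print(f"gap to global minimum:    {observed - GLOBAL_MIN:.4f}")
print(f"nearest known minimizer:  ({nearest[0]:.4f}, {nearest[1]:.4f})")
```

A small gap indicates that the 15-trial budget was enough to land close to one of the three global minima.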
After running this, return to the Bayesian optimization tutorial notebook.