3.1 Ax Platform Basic Usage


This notebook uses the Ax Platform’s Service API (`AxClient`) to perform Bayesian optimization on the Branin function, an analytic function commonly used to benchmark optimization algorithms. Branin takes two inputs, `x1` and `x2`, and returns a single scalar value. The goal is to find the inputs that minimize this value.
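For reference, the Branin function is commonly written in the parameterized form below. With these standard constants it matches the `branin` implementation in the code that follows, on the domain x1 ∈ [-5, 10], x2 ∈ [0, 15] used for the search space. It has three global minimizers that share the same minimum value.

$$
f(x_1, x_2) = a\,(x_2 - b\,x_1^2 + c\,x_1 - r)^2 + s\,(1 - t)\cos x_1 + s
$$

$$
a = 1,\quad b = \frac{5.1}{4\pi^2},\quad c = \frac{5}{\pi},\quad r = 6,\quad s = 10,\quad t = \frac{1}{8\pi}
$$

$$
f^\ast = s\,t = \frac{5}{4\pi} \approx 0.397887 \quad\text{at}\quad (x_1, x_2) \in \{(-\pi,\ 12.275),\ (\pi,\ 2.275),\ (3\pi,\ 2.475)\}
$$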

# Install Ax only when running in Google Colab
import sys
IN_COLAB = 'google.colab' in sys.modules
if IN_COLAB:
    %pip install ax-platform
import math
from ax.service.ax_client import AxClient, ObjectiveProperties

obj1_name = "branin"

def branin(x1, x2):
    """Branin test function; its global minimum value is about 0.397887."""
    y = float(
        (x2 - 5.1 / (4 * math.pi**2) * x1**2 + 5.0 / math.pi * x1 - 6.0) ** 2
        + 10 * (1 - 1.0 / (8 * math.pi)) * math.cos(x1)
        + 10
    )

    return y

# Define the search space (two continuous parameters) and the objective to minimize
ax_client = AxClient()
ax_client.create_experiment(
    parameters=[
        {"name": "x1", "type": "range", "bounds": [-5.0, 10.0]},
        {"name": "x2", "type": "range", "bounds": [0.0, 15.0]},
    ],
    objectives={
        obj1_name: ObjectiveProperties(minimize=True),
    },
)

# Ask-tell loop: Ax suggests parameters, we evaluate Branin and report the result back
for _ in range(15):
    parameters, trial_index = ax_client.get_next_trial()
    results = branin(
        parameters["x1"],
        parameters["x2"],
    )
    ax_client.complete_trial(trial_index=trial_index, raw_data=results)

# Best parameterization, plus estimated (means, covariances) for the objective
best_parameters, metrics = ax_client.get_best_parameters()
[WARNING 07-19 14:05:57] ax.service.utils.with_db_settings_base: Ax currently requires a sqlalchemy version below 2.0. This will be addressed in a future release. Disabling SQL storage in Ax for now, if you would like to use SQL storage please install Ax with mysql extras via `pip install ax-platform[mysql]`.
[INFO 07-19 14:05:57] ax.service.ax_client: Starting optimization with verbose logging. To disable logging, set the `verbose_logging` argument to `False`. Note that float values in the logs are rounded to 6 decimal points.
[INFO 07-19 14:05:57] ax.service.utils.instantiation: Inferred value type of ParameterType.FLOAT for parameter x1. If that is not the expected value type, you can explicitly specify 'value_type' ('int', 'float', 'bool' or 'str') in parameter dict.
[INFO 07-19 14:05:57] ax.service.utils.instantiation: Inferred value type of ParameterType.FLOAT for parameter x2. If that is not the expected value type, you can explicitly specify 'value_type' ('int', 'float', 'bool' or 'str') in parameter dict.
[INFO 07-19 14:05:57] ax.service.utils.instantiation: Created search space: SearchSpace(parameters=[RangeParameter(name='x1', parameter_type=FLOAT, range=[-5.0, 10.0]), RangeParameter(name='x2', parameter_type=FLOAT, range=[0.0, 15.0])], parameter_constraints=[]).
[INFO 07-19 14:05:57] ax.modelbridge.dispatch_utils: Using Models.BOTORCH_MODULAR since there is at least one ordered parameter and there are no unordered categorical parameters.
[INFO 07-19 14:05:57] ax.modelbridge.dispatch_utils: Calculating the number of remaining initialization trials based on num_initialization_trials=None max_initialization_trials=None num_tunable_parameters=2 num_trials=None use_batch_trials=False
[INFO 07-19 14:05:57] ax.modelbridge.dispatch_utils: calculated num_initialization_trials=5
[INFO 07-19 14:05:57] ax.modelbridge.dispatch_utils: num_completed_initialization_trials=0 num_remaining_initialization_trials=5
[INFO 07-19 14:05:57] ax.modelbridge.dispatch_utils: `verbose`, `disable_progbar`, and `jit_compile` are not yet supported when using `choose_generation_strategy` with ModularBoTorchModel, dropping these arguments.
[INFO 07-19 14:05:57] ax.modelbridge.dispatch_utils: Using Bayesian Optimization generation strategy: GenerationStrategy(name='Sobol+BoTorch', steps=[Sobol for 5 trials, BoTorch for subsequent trials]). Iterations after 5 will take longer to generate due to model-fitting.
[INFO 07-19 14:05:57] ax.service.ax_client: Generated new trial 0 with parameters {'x1': 7.46005, 'x2': 11.050261} using model Sobol.
[INFO 07-19 14:05:57] ax.service.ax_client: Completed trial 0 with data: {'branin': (108.433893, None)}.
[INFO 07-19 14:05:57] ax.service.ax_client: Generated new trial 1 with parameters {'x1': 9.75531, 'x2': 0.75612} using model Sobol.
[INFO 07-19 14:05:57] ax.service.ax_client: Completed trial 1 with data: {'branin': (4.965009, None)}.
[INFO 07-19 14:05:57] ax.service.ax_client: Generated new trial 2 with parameters {'x1': 6.443358, 'x2': 2.343874} using model Sobol.
[INFO 07-19 14:05:57] ax.service.ax_client: Completed trial 2 with data: {'branin': (21.005569, None)}.
[INFO 07-19 14:05:57] ax.service.ax_client: Generated new trial 3 with parameters {'x1': 5.215967, 'x2': 8.146018} using model Sobol.
[INFO 07-19 14:05:57] ax.service.ax_client: Completed trial 3 with data: {'branin': (62.69807, None)}.
[INFO 07-19 14:05:57] ax.service.ax_client: Generated new trial 4 with parameters {'x1': 8.841575, 'x2': 8.930721} using model Sobol.
[INFO 07-19 14:05:57] ax.service.ax_client: Completed trial 4 with data: {'branin': (49.646523, None)}.
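Each `get_next_trial` / `complete_trial` pair in the loop above is an ask-tell exchange: the client proposes a parameterization, you evaluate it outside Ax, and you report the result back. To make that protocol concrete without Ax, the sketch below uses a hypothetical `RandomSearchClient` whose method names mirror `AxClient` but which simply samples uniformly at random. It does no Bayesian modeling; Ax, per the log above, runs 5 Sobol trials and then switches to BoTorch.

```python
import math
import random


def branin(x1, x2):
    # Same Branin definition as in the notebook
    return (
        (x2 - 5.1 / (4 * math.pi**2) * x1**2 + 5.0 / math.pi * x1 - 6.0) ** 2
        + 10 * (1 - 1.0 / (8 * math.pi)) * math.cos(x1)
        + 10
    )


class RandomSearchClient:
    """Hypothetical ask-tell client: proposes uniform random points, no model."""

    def __init__(self, bounds, seed=0):
        self.bounds = bounds  # {name: (low, high)}
        self.rng = random.Random(seed)
        self.trials = {}  # trial_index -> (parameters, value or None)

    def get_next_trial(self):
        # "Ask": propose a parameterization and register it as a pending trial
        params = {
            name: self.rng.uniform(low, high)
            for name, (low, high) in self.bounds.items()
        }
        trial_index = len(self.trials)
        self.trials[trial_index] = (params, None)
        return params, trial_index

    def complete_trial(self, trial_index, raw_data):
        # "Tell": attach the observed objective value to the trial
        params, _ = self.trials[trial_index]
        self.trials[trial_index] = (params, float(raw_data))

    def get_best_parameters(self):
        # Best *observed* point by raw value; no surrogate model is involved
        completed = [t for t in self.trials.values() if t[1] is not None]
        return min(completed, key=lambda t: t[1])


client = RandomSearchClient({"x1": (-5.0, 10.0), "x2": (0.0, 15.0)})
for _ in range(15):
    parameters, trial_index = client.get_next_trial()
    client.complete_trial(trial_index, branin(parameters["x1"], parameters["x2"]))

best_params, best_value = client.get_best_parameters()
print(best_params, best_value)
```

Keeping the same method names makes it easy to see which part Ax replaces: only the proposal step (`get_next_trial`) and the best-point selection change, while the evaluate-and-report loop stays identical.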
print(best_parameters, metrics)
{'x1': 9.42045994053506, 'x2': 2.522998247072662} ({'branin': 0.6237735107536508}, {'branin': {'branin': 0.27541560401879245}})
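Because Branin's global minima are known analytically, you can sanity-check the result. The sketch below first confirms that the notebook's `branin` returns 5/(4π) at the three known minimizers, then evaluates the function directly at the parameters printed above and finds the nearest minimizer. The hard-coded `best` dictionary is copied from that printed output; substitute your own `best_parameters` if your run differs.

```python
import math


def branin(x1, x2):
    # Same Branin definition as in the notebook
    return (
        (x2 - 5.1 / (4 * math.pi**2) * x1**2 + 5.0 / math.pi * x1 - 6.0) ** 2
        + 10 * (1 - 1.0 / (8 * math.pi)) * math.cos(x1)
        + 10
    )


# Known global minimizers of Branin; all share the minimum value 5 / (4*pi)
minimizers = [(-math.pi, 12.275), (math.pi, 2.275), (3 * math.pi, 2.475)]
f_star = 5 / (4 * math.pi)

for x1, x2 in minimizers:
    assert math.isclose(branin(x1, x2), f_star, rel_tol=1e-9)

# Best parameters printed by ax_client.get_best_parameters() above
best = {"x1": 9.42045994053506, "x2": 2.522998247072662}
point = (best["x1"], best["x2"])
true_value = branin(*point)

# Which known minimizer did the optimizer converge toward?
nearest = min(minimizers, key=lambda m: math.dist(m, point))
print(f"true value at best point: {true_value:.6f} (global min {f_star:.6f})")
print(f"nearest minimizer: {nearest}, distance {math.dist(nearest, point):.4f}")
```

Evaluated directly, the returned point gives a Branin value of roughly 0.40, very close to the global minimum near (3π, 2.475). The first dictionary in `metrics` (about 0.62) differs from that direct evaluation, so it should be read as Ax's estimate for the point rather than the function value itself.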

After running this, return to the Bayesian optimization tutorial notebook.