Service API Example on Hartmann6

The Ax Service API lets the user control how trials are scheduled and how their data are computed, while still interacting with Ax through an easy-to-use interface.

The user iteratively:

  • Queries Ax for candidates

  • Schedules / deploys them however they choose

  • Computes data and logs it to Ax

  • Repeat
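The query / deploy / log cycle above can be sketched in plain Python. ToyClient below is hypothetical and not part of Ax: it proposes points by uniform random search, and its method names merely mirror AxClient.get_next_trial and AxClient.complete_trial, which this tutorial uses in step 4. The evaluate function here is likewise a stand-in.

```python
import random


class ToyClient:
    """Hypothetical stand-in for AxClient: random search over [0, 1]^d."""

    def __init__(self, param_names, seed=0):
        self.param_names = list(param_names)
        self.rng = random.Random(seed)
        self.pending = {}    # trial_index -> parameters awaiting results
        self.completed = {}  # trial_index -> (parameters, result)
        self._next_index = 0

    def get_next_trial(self):
        # "Query": propose a new candidate and hand back its index.
        params = {name: self.rng.random() for name in self.param_names}
        index = self._next_index
        self._next_index += 1
        self.pending[index] = params
        return params, index

    def complete_trial(self, trial_index, raw_data):
        # "Log": attach the user-computed result to the pending trial.
        params = self.pending.pop(trial_index)
        self.completed[trial_index] = (params, raw_data)


def evaluate(params):
    # Toy objective: squared distance from the centre of the unit square.
    return sum((v - 0.5) ** 2 for v in params.values())


client = ToyClient(["x1", "x2"])
for _ in range(10):  # "Repeat"
    params, idx = client.get_next_trial()
    # "Schedule / deploy": here we simply evaluate locally.
    client.complete_trial(trial_index=idx, raw_data=evaluate(params))

best_index = min(client.completed, key=lambda i: client.completed[i][1])
print(best_index, client.completed[best_index])
```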

# only install if we are running in colab
# kaleido only for plotting
import sys
IN_COLAB = 'google.colab' in sys.modules
if IN_COLAB:
    %pip install ax-platform kaleido
from ax.service.ax_client import AxClient, ObjectiveProperties
from ax.utils.measurement.synthetic_functions import hartmann6
from ax.utils.notebook.plotting import render, init_notebook_plotting
import plotly.io as pio

init_notebook_plotting()
pio.renderers.default = "colab"
[WARNING 07-19 14:05:38] ax.service.utils.with_db_settings_base: Ax currently requires a sqlalchemy version below 2.0. This will be addressed in a future release. Disabling SQL storage in Ax for now, if you would like to use SQL storage please install Ax with mysql extras via `pip install ax-platform[mysql]`.
[INFO 07-19 14:05:38] ax.utils.notebook.plotting: Injecting Plotly library into cell. Do not overwrite or delete cell.
[INFO 07-19 14:05:38] ax.utils.notebook.plotting: Please see
    (https://ax.dev/tutorials/visualizations.html#Fix-for-plots-that-are-not-rendering)
    if visualizations are not rendering.

1. Initialize client

Create a client object to interface with Ax APIs. By default this runs locally without storage.

ax_client = AxClient()
[INFO 07-19 14:05:38] ax.service.ax_client: Starting optimization with verbose logging. To disable logging, set the `verbose_logging` argument to `False`. Note that float values in the logs are rounded to 6 decimal points.

2. Set up experiment

An experiment consists of a search space (parameters and parameter constraints) and an optimization configuration (objectives and outcome constraints). Note that:

  • Only the parameters and objectives arguments are required.

  • Dictionaries in parameters have the following required keys: “name” - parameter name, “type” - parameter type (“range”, “choice” or “fixed”), “bounds” for range parameters, “values” for choice parameters, and “value” for fixed parameters.

  • Dictionaries in parameters can optionally include “value_type” (“int”, “float”, “bool” or “str”), “log_scale” flag for range parameters, and “is_ordered” flag for choice parameters.

  • parameter_constraints should be a list of strings of the form “p1 >= p2” or “p1 + p2 <= some_bound”.

  • outcome_constraints should be a list of strings of the form “constrained_metric <= some_bound”.
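As an illustration of the two string formats, the hypothetical helpers below (parse_linear_constraint and is_satisfied are not Ax functions) parse strings such as "x1 + x2 <= 2.0" or "x2 >= x1" and check a parameterization against them. They assume unit coefficients and only "+" between names; Ax's own parser may accept a broader syntax.

```python
import operator

OPS = {"<=": operator.le, ">=": operator.ge}


def parse_linear_constraint(text):
    """Parse 'a + b <= 2.0' or 'a >= b' into (lhs_names, op, rhs).

    rhs is a float bound or a list of names. Hypothetical helper:
    assumes unit coefficients and '+' as the only operator.
    """
    for symbol in OPS:
        if symbol in text:
            lhs, rhs = (side.strip() for side in text.split(symbol))
            break
    else:
        raise ValueError(f"no <= or >= in {text!r}")
    lhs_names = [name.strip() for name in lhs.split("+")]
    try:
        rhs_value = float(rhs)
    except ValueError:
        rhs_value = [name.strip() for name in rhs.split("+")]
    return lhs_names, symbol, rhs_value


def is_satisfied(text, values):
    """Check a constraint string against a dict of parameter or metric values."""
    lhs_names, symbol, rhs = parse_linear_constraint(text)
    lhs_total = sum(values[n] for n in lhs_names)
    rhs_total = rhs if isinstance(rhs, float) else sum(values[n] for n in rhs)
    return OPS[symbol](lhs_total, rhs_total)


params = {"x1": 0.9, "x2": 0.4}
print(is_satisfied("x1 + x2 <= 2.0", params))  # True
print(is_satisfied("x2 >= x1", params))        # False
```

The same check works for an outcome constraint when the dictionary holds metric values, e.g. is_satisfied("l2norm <= 1.25", {"l2norm": 1.04}) returns True.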

ax_client.create_experiment(
    name="hartmann_test_experiment",
    parameters=[
        {
            "name": "x1",
            "type": "range",
            "bounds": [0.0, 1.0],
            "value_type": "float",  # Optional, defaults to inference from type of "bounds".
            "log_scale": False,  # Optional, defaults to False.
        },
        {
            "name": "x2",
            "type": "range",
            "bounds": [0.0, 1.0],
        },
        {
            "name": "x3",
            "type": "range",
            "bounds": [0.0, 1.0],
        },
        {
            "name": "x4",
            "type": "range",
            "bounds": [0.0, 1.0],
        },
        {
            "name": "x5",
            "type": "range",
            "bounds": [0.0, 1.0],
        },
        {
            "name": "x6",
            "type": "range",
            "bounds": [0.0, 1.0],
        },
    ],
    objectives={"hartmann6": ObjectiveProperties(minimize=True)},
    parameter_constraints=["x1 + x2 <= 2.0"],  # Optional.
    outcome_constraints=["l2norm <= 1.25"],  # Optional.
)
[INFO 07-19 14:05:38] ax.service.utils.instantiation: Inferred value type of ParameterType.FLOAT for parameter x2. If that is not the expected value type, you can explicitly specify 'value_type' ('int', 'float', 'bool' or 'str') in parameter dict.
[INFO 07-19 14:05:38] ax.service.utils.instantiation: Inferred value type of ParameterType.FLOAT for parameter x3. If that is not the expected value type, you can explicitly specify 'value_type' ('int', 'float', 'bool' or 'str') in parameter dict.
[INFO 07-19 14:05:38] ax.service.utils.instantiation: Inferred value type of ParameterType.FLOAT for parameter x4. If that is not the expected value type, you can explicitly specify 'value_type' ('int', 'float', 'bool' or 'str') in parameter dict.
[INFO 07-19 14:05:38] ax.service.utils.instantiation: Inferred value type of ParameterType.FLOAT for parameter x5. If that is not the expected value type, you can explicitly specify 'value_type' ('int', 'float', 'bool' or 'str') in parameter dict.
[INFO 07-19 14:05:38] ax.service.utils.instantiation: Inferred value type of ParameterType.FLOAT for parameter x6. If that is not the expected value type, you can explicitly specify 'value_type' ('int', 'float', 'bool' or 'str') in parameter dict.
[INFO 07-19 14:05:38] ax.service.utils.instantiation: Created search space: SearchSpace(parameters=[RangeParameter(name='x1', parameter_type=FLOAT, range=[0.0, 1.0]), RangeParameter(name='x2', parameter_type=FLOAT, range=[0.0, 1.0]), RangeParameter(name='x3', parameter_type=FLOAT, range=[0.0, 1.0]), RangeParameter(name='x4', parameter_type=FLOAT, range=[0.0, 1.0]), RangeParameter(name='x5', parameter_type=FLOAT, range=[0.0, 1.0]), RangeParameter(name='x6', parameter_type=FLOAT, range=[0.0, 1.0])], parameter_constraints=[ParameterConstraint(1.0*x1 + 1.0*x2 <= 2.0)]).
[INFO 07-19 14:05:38] ax.modelbridge.dispatch_utils: Using Models.BOTORCH_MODULAR since there is at least one ordered parameter and there are no unordered categorical parameters.
[INFO 07-19 14:05:38] ax.modelbridge.dispatch_utils: Calculating the number of remaining initialization trials based on num_initialization_trials=None max_initialization_trials=None num_tunable_parameters=6 num_trials=None use_batch_trials=False
[INFO 07-19 14:05:38] ax.modelbridge.dispatch_utils: calculated num_initialization_trials=12
[INFO 07-19 14:05:38] ax.modelbridge.dispatch_utils: num_completed_initialization_trials=0 num_remaining_initialization_trials=12
[INFO 07-19 14:05:38] ax.modelbridge.dispatch_utils: `verbose`, `disable_progbar`, and `jit_compile` are not yet supported when using `choose_generation_strategy` with ModularBoTorchModel, dropping these arguments.
[INFO 07-19 14:05:38] ax.modelbridge.dispatch_utils: Using Bayesian Optimization generation strategy: GenerationStrategy(name='Sobol+BoTorch', steps=[Sobol for 12 trials, BoTorch for subsequent trials]). Iterations after 12 will take longer to generate due to model-fitting.
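Because x1 through x6 share one specification, the parameters list can also be built with a comprehension. The sketch below (the name hartmann_parameters is illustrative) produces a list that could be passed as parameters= to ax_client.create_experiment in place of the hand-written one. Setting value_type for every parameter should also suppress the "Inferred value type" messages, which, as the log shows, appear only for the parameters that omit it.

```python
# Six identical range parameters x1..x6 on [0, 1].
hartmann_parameters = [
    {
        "name": f"x{i}",
        "type": "range",
        "bounds": [0.0, 1.0],
        "value_type": "float",  # explicit, so Ax does not need to infer it
    }
    for i in range(1, 7)
]

print([p["name"] for p in hartmann_parameters])
# ['x1', 'x2', 'x3', 'x4', 'x5', 'x6']
```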

3. Define how to evaluate trials

When using Ax as a service, evaluation of parameterizations suggested by Ax is done either locally or, more commonly, using an external scheduler. Below is a dummy evaluation function that outputs data for two metrics, “hartmann6” and “l2norm”. Note that all returned metrics correspond to either the objectives set on experiment creation or the metric names mentioned in outcome_constraints.

import numpy as np


def evaluate(parameters):
    # `parameters` is a dict with keys "x1", ..., "x6"; collect its values
    # in order into a 6-dimensional array.
    x = np.array([parameters.get(f"x{i+1}") for i in range(6)])

    hartmann6_mean = hartmann6(x)
    l2norm_mean = np.sqrt((x**2).sum())

    # In our case, standard error is 0, since we are computing a synthetic function.
    hartmann6_std = 0.0
    l2norm_std = 0.0
    return {"hartmann6": (hartmann6_mean, hartmann6_std), "l2norm": (l2norm_mean, l2norm_std)}

The result of the evaluation should generally be a mapping of the format {metric_name -> (mean, SEM)}. If there is only one metric in the experiment (the objective), the evaluation function can return a single (mean, SEM) tuple, in which case Ax assumes it corresponds to the objective. It can also return only the mean as a float, in which case Ax treats the SEM as unknown and uses a model that can infer it.

For more details on the evaluation function, refer to the “Trial Evaluation” section of the Ax docs at ax.dev.
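The three result shapes described above can be summarized with a hypothetical normalize_raw_data helper. It illustrates the rule, not Ax's internal code, and maps each form onto {metric_name: (mean, SEM)}, with None standing in for an unknown SEM.

```python
def normalize_raw_data(raw_data, objective_name):
    """Map the accepted evaluation-result shapes onto {metric: (mean, sem)}.

    Hypothetical helper; sem=None marks an unknown standard error.
    """
    if isinstance(raw_data, dict):
        # Already {metric_name: (mean, SEM)}.
        return {name: (float(m), s) for name, (m, s) in raw_data.items()}
    if isinstance(raw_data, tuple):
        # Single (mean, SEM) tuple: attributed to the sole objective.
        mean, sem = raw_data
        return {objective_name: (float(mean), sem)}
    # Bare number: mean only, SEM unknown.
    return {objective_name: (float(raw_data), None)}


print(normalize_raw_data({"hartmann6": (-1.2, 0.0), "l2norm": (1.1, 0.0)}, "hartmann6"))
print(normalize_raw_data((-1.2, 0.0), "hartmann6"))
print(normalize_raw_data(-1.2, "hartmann6"))  # {'hartmann6': (-1.2, None)}
```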

4. Run optimization loop

With the experiment set up, we can start the optimization loop.

At each step, the user queries the client for a new trial, then submits the evaluation of that trial back to the client.

Ax auto-selects an appropriate optimization algorithm based on the search space (here, Sobol for the first 12 trials and BoTorch afterwards, as the log above shows). For more advanced use cases that require a specific optimization algorithm, pass a generation_strategy argument to the AxClient constructor. When Bayesian optimization is used, generating new trials may take a few minutes.

for i in range(25):
    parameters, trial_index = ax_client.get_next_trial()
    # Local evaluation here can be replaced with deployment to external system.
    ax_client.complete_trial(trial_index=trial_index, raw_data=evaluate(parameters))
[INFO 07-19 14:05:38] ax.service.ax_client: Generated new trial 0 with parameters {'x1': 0.107889, 'x2': 0.922518, 'x3': 0.924539, 'x4': 0.936394, 'x5': 0.734736, 'x6': 0.165537} using model Sobol.
[INFO 07-19 14:05:38] ax.service.ax_client: Completed trial 0 with data: {'hartmann6': (-0.146128, 0.0), 'l2norm': (np.float64(1.778068), 0.0)}.
[INFO 07-19 14:05:38] ax.service.ax_client: Generated new trial 1 with parameters {'x1': 0.443552, 'x2': 0.488982, 'x3': 0.836648, 'x4': 0.320775, 'x5': 0.276616, 'x6': 0.706204} using model Sobol.
[INFO 07-19 14:05:38] ax.service.ax_client: Completed trial 1 with data: {'hartmann6': (-1.418863, 0.0), 'l2norm': (np.float64(1.346833), 0.0)}.
[INFO 07-19 14:05:38] ax.service.ax_client: Generated new trial 2 with parameters {'x1': 0.65414, 'x2': 0.516116, 'x3': 0.319941, 'x4': 0.942945, 'x5': 0.718464, 'x6': 0.659873} using model Sobol.
[INFO 07-19 14:05:38] ax.service.ax_client: Completed trial 2 with data: {'hartmann6': (-0.002137, 0.0), 'l2norm': (np.float64(1.624009), 0.0)}.
[INFO 07-19 14:05:38] ax.service.ax_client: Generated new trial 3 with parameters {'x1': 0.777313, 'x2': 0.713227, 'x3': 0.231293, 'x4': 0.364289, 'x5': 0.751878, 'x6': 0.620248} using model Sobol.
[INFO 07-19 14:05:39] ax.service.ax_client: Completed trial 3 with data: {'hartmann6': (-0.014062, 0.0), 'l2norm': (np.float64(1.499713), 0.0)}.
[INFO 07-19 14:05:39] ax.service.ax_client: Generated new trial 4 with parameters {'x1': 0.314862, 'x2': 0.406853, 'x3': 0.288613, 'x4': 0.288989, 'x5': 0.714391, 'x6': 0.180744} using model Sobol.
[INFO 07-19 14:05:39] ax.service.ax_client: Completed trial 4 with data: {'hartmann6': (-0.199245, 0.0), 'l2norm': (np.float64(0.987169), 0.0)}.
[INFO 07-19 14:05:39] ax.service.ax_client: Generated new trial 5 with parameters {'x1': 0.243049, 'x2': 0.871583, 'x3': 0.80612, 'x4': 0.597061, 'x5': 0.364388, 'x6': 0.141447} using model Sobol.
[INFO 07-19 14:05:39] ax.service.ax_client: Completed trial 5 with data: {'hartmann6': (-1.760953, 0.0), 'l2norm': (np.float64(1.406352), 0.0)}.
[INFO 07-19 14:05:39] ax.service.ax_client: Generated new trial 6 with parameters {'x1': 0.060675, 'x2': 0.27577, 'x3': 0.448922, 'x4': 0.104325, 'x5': 0.647423, 'x6': 0.607618} using model Sobol.
[INFO 07-19 14:05:39] ax.service.ax_client: Completed trial 6 with data: {'hartmann6': (-0.931146, 0.0), 'l2norm': (np.float64(1.039472), 0.0)}.
[INFO 07-19 14:05:39] ax.service.ax_client: Generated new trial 7 with parameters {'x1': 0.504149, 'x2': 0.67087, 'x3': 0.311699, 'x4': 0.49141, 'x5': 0.022197, 'x6': 0.17924} using model Sobol.
[INFO 07-19 14:05:39] ax.service.ax_client: Completed trial 7 with data: {'hartmann6': (-1.337291, 0.0), 'l2norm': (np.float64(1.037059), 0.0)}.
[INFO 07-19 14:05:39] ax.service.ax_client: Generated new trial 8 with parameters {'x1': 0.602793, 'x2': 0.765291, 'x3': 0.639362, 'x4': 0.046317, 'x5': 0.422133, 'x6': 0.370507} using model Sobol.
[INFO 07-19 14:05:39] ax.service.ax_client: Completed trial 8 with data: {'hartmann6': (-0.135917, 0.0), 'l2norm': (np.float64(1.294384), 0.0)}.
[INFO 07-19 14:05:39] ax.service.ax_client: Generated new trial 9 with parameters {'x1': 0.617003, 'x2': 0.130774, 'x3': 0.901677, 'x4': 0.839172, 'x5': 0.240024, 'x6': 0.922479} using model Sobol.
[INFO 07-19 14:05:39] ax.service.ax_client: Completed trial 9 with data: {'hartmann6': (-0.409906, 0.0), 'l2norm': (np.float64(1.680358), 0.0)}.
[INFO 07-19 14:05:39] ax.service.ax_client: Generated new trial 10 with parameters {'x1': 0.786644, 'x2': 0.828206, 'x3': 0.423388, 'x4': 0.919882, 'x5': 0.152168, 'x6': 0.770439} using model Sobol.
[INFO 07-19 14:05:39] ax.service.ax_client: Completed trial 10 with data: {'hartmann6': (-0.008517, 0.0), 'l2norm': (np.float64(1.716656), 0.0)}.
[INFO 07-19 14:05:39] ax.service.ax_client: Generated new trial 11 with parameters {'x1': 0.291381, 'x2': 0.749571, 'x3': 0.10217, 'x4': 0.911147, 'x5': 0.967016, 'x6': 0.310025} using model Sobol.
[INFO 07-19 14:05:39] ax.service.ax_client: Completed trial 11 with data: {'hartmann6': (-0.230108, 0.0), 'l2norm': (np.float64(1.587017), 0.0)}.
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
Cell In[6], line 2
      1 for i in range(25):
----> 2     parameters, trial_index = ax_client.get_next_trial()
      3     # Local evaluation here can be replaced with deployment to external system.
      4     ax_client.complete_trial(trial_index=trial_index, raw_data=evaluate(parameters))

[... internal Ax frames omitted ...]

File ~/checkouts/readthedocs.org/user_builds/ac-microcourses/envs/stable/lib/python3.11/site-packages/ax/models/model_utils.py:464, in best_in_sample_point(Xs, model, bounds, objective_weights, outcome_constraints, linear_constraints, fixed_features, risk_measure, options)
    462     raise UnsupportedError(f"Unknown best point method {method}.")
    463 i = np.argmax(utility)
--> 464 if utility[i] == -np.Inf:
    465     return None
    466 else:

[... numpy/__init__.py frame omitted ...]

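AttributeError: `np.Inf` was removed in the NumPy 2.0 release. Use `np.inf` instead.

This error comes from the environment rather than from the tutorial code: the installed Ax release still references np.Inf, which NumPy 2.0 removed, so generating the first BoTorch trial (after the 12 Sobol trials) fails. Running the notebook with NumPy below 2.0 (for example, pip install "numpy<2") avoids this error. The outputs in the sections below come from a separate, successful run (note the different timestamps), in which all 25 trials completed.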

How many trials can run in parallel?

By default, Ax restricts the number of trials that can run in parallel for some optimization stages, in order to improve optimization performance and reduce the number of trials the optimization requires. To check the maximum parallelism for each optimization stage:

ax_client.get_max_parallelism()
[(12, 12), (-1, 3)]

The output of this function is a list of tuples of the form (number of trials, max parallelism), so the example above means “the max parallelism is 12 for the first 12 trials and 3 for all subsequent trials.” This is because the first 12 trials are produced quasi-randomly and can all be evaluated at once, while subsequent trials are produced via Bayesian optimization, which converges on the optimal point in fewer trials when parallelism is limited. MaxParallelismReachedException indicates that the parallelism limit has been reached; refer to the ‘Service API Exceptions Meaning and Handling’ section at the end of the tutorial for handling.
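Reading that list amounts to walking the stages in order. The hypothetical helper max_parallelism_for_trial makes the rule explicit, assuming (as described above) that a first element of -1 means "all remaining trials".

```python
def max_parallelism_for_trial(schedule, trial_index):
    """Return the parallelism limit that applies to a 0-based trial index.

    schedule: list of (num_trials, max_parallelism) tuples, as returned by
    ax_client.get_max_parallelism(); num_trials == -1 is assumed to mean
    "all remaining trials".
    """
    start = 0
    for num_trials, max_parallelism in schedule:
        if num_trials == -1 or trial_index < start + num_trials:
            return max_parallelism
        start += num_trials
    raise ValueError("schedule does not cover this trial index")


schedule = [(12, 12), (-1, 3)]  # the output shown above
print(max_parallelism_for_trial(schedule, 0))   # 12
print(max_parallelism_for_trial(schedule, 11))  # 12
print(max_parallelism_for_trial(schedule, 12))  # 3
```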

How can I view all existing trials during optimization?

ax_client.generation_strategy.trials_as_df
[INFO 01-05 19:07:42] ax.modelbridge.generation_strategy: Note that parameter values in dataframe are rounded to 2 decimal points; the values in the dataframe are thus not the exact ones suggested by Ax in trials.
    Generation Step  Generation Model  Trial Index  Trial Status  Arm Parameterizations
0   0                Sobol             0            COMPLETED     {'0_0': {'x1': 0.16, 'x2': 0.82, 'x3': 0.55, '...
1   0                Sobol             1            COMPLETED     {'1_0': {'x1': 0.08, 'x2': 0.36, 'x3': 0.55, '...
2   0                Sobol             2            COMPLETED     {'2_0': {'x1': 0.83, 'x2': 0.78, 'x3': 0.28, '...
3   0                Sobol             3            COMPLETED     {'3_0': {'x1': 0.94, 'x2': 0.73, 'x3': 0.61, '...
4   0                Sobol             4            COMPLETED     {'4_0': {'x1': 0.28, 'x2': 0.55, 'x3': 0.82, '...
5   0                Sobol             5            COMPLETED     {'5_0': {'x1': 0.3, 'x2': 0.96, 'x3': 0.72, 'x...
6   0                Sobol             6            COMPLETED     {'6_0': {'x1': 0.31, 'x2': 0.88, 'x3': 0.87, '...
7   0                Sobol             7            COMPLETED     {'7_0': {'x1': 0.99, 'x2': 0.92, 'x3': 0.96, '...
8   0                Sobol             8            COMPLETED     {'8_0': {'x1': 0.97, 'x2': 0.52, 'x3': 0.4, 'x...
9   0                Sobol             9            COMPLETED     {'9_0': {'x1': 0.54, 'x2': 0.04, 'x3': 0.48, '...
10  0                Sobol             10           COMPLETED     {'10_0': {'x1': 0.4, 'x2': 1.0, 'x3': 0.31, 'x...
11  0                Sobol             11           COMPLETED     {'11_0': {'x1': 0.87, 'x2': 0.95, 'x3': 0.98, ...
12  1                BoTorch           0            COMPLETED     {'0_0': {'x1': 0.16, 'x2': 0.82, 'x3': 0.55, '...
13  1                BoTorch           1            COMPLETED     {'1_0': {'x1': 0.08, 'x2': 0.36, 'x3': 0.55, '...
14  1                BoTorch           2            COMPLETED     {'2_0': {'x1': 0.83, 'x2': 0.78, 'x3': 0.28, '...
15  1                BoTorch           3            COMPLETED     {'3_0': {'x1': 0.94, 'x2': 0.73, 'x3': 0.61, '...
16  1                BoTorch           4            COMPLETED     {'4_0': {'x1': 0.28, 'x2': 0.55, 'x3': 0.82, '...
17  1                BoTorch           5            COMPLETED     {'5_0': {'x1': 0.3, 'x2': 0.96, 'x3': 0.72, 'x...
18  1                BoTorch           6            COMPLETED     {'6_0': {'x1': 0.31, 'x2': 0.88, 'x3': 0.87, '...
19  1                BoTorch           7            COMPLETED     {'7_0': {'x1': 0.99, 'x2': 0.92, 'x3': 0.96, '...
20  1                BoTorch           8            COMPLETED     {'8_0': {'x1': 0.97, 'x2': 0.52, 'x3': 0.4, 'x...
21  1                BoTorch           9            COMPLETED     {'9_0': {'x1': 0.54, 'x2': 0.04, 'x3': 0.48, '...
22  1                BoTorch           10           COMPLETED     {'10_0': {'x1': 0.4, 'x2': 1.0, 'x3': 0.31, 'x...
23  1                BoTorch           11           COMPLETED     {'11_0': {'x1': 0.87, 'x2': 0.95, 'x3': 0.98, ...
24  1                BoTorch           12           COMPLETED     {'12_0': {'x1': 0.37, 'x2': 0.95, 'x3': 0.3, '...
25  1                BoTorch           13           COMPLETED     {'13_0': {'x1': 0.4, 'x2': 0.99, 'x3': 0.31, '...
26  1                BoTorch           14           COMPLETED     {'14_0': {'x1': 0.44, 'x2': 1.0, 'x3': 0.27, '...
27  1                BoTorch           15           COMPLETED     {'15_0': {'x1': 0.39, 'x2': 1.0, 'x3': 0.23, '...
28  1                BoTorch           16           COMPLETED     {'16_0': {'x1': 0.38, 'x2': 1.0, 'x3': 0.29, '...
29  1                BoTorch           17           COMPLETED     {'17_0': {'x1': 0.33, 'x2': 1.0, 'x3': 0.27, '...
30  1                BoTorch           18           COMPLETED     {'18_0': {'x1': 0.37, 'x2': 1.0, 'x3': 0.23, '...
31  1                BoTorch           19           COMPLETED     {'19_0': {'x1': 0.39, 'x2': 0.95, 'x3': 0.23, ...
32  1                BoTorch           20           COMPLETED     {'20_0': {'x1': 0.41, 'x2': 0.9, 'x3': 0.22, '...
33  1                BoTorch           21           COMPLETED     {'21_0': {'x1': 0.42, 'x2': 0.85, 'x3': 0.26, ...
34  1                BoTorch           22           COMPLETED     {'22_0': {'x1': 0.39, 'x2': 0.83, 'x3': 0.18, ...
35  1                BoTorch           23           COMPLETED     {'23_0': {'x1': 0.42, 'x2': 0.9, 'x3': 0.24, '...
36  1                BoTorch           24           COMPLETED     {'24_0': {'x1': 0.41, 'x2': 0.87, 'x3': 0.27, ...

5. Retrieve best parameters

Once optimization is complete, we can access the best parameters found, as well as the corresponding metric values.

best_parameters, values = ax_client.get_best_parameters()
best_parameters
{'x1': 0.40817620096757834,
 'x2': 0.8731657716481355,
 'x3': 0.27026868276996546,
 'x4': 0.5723680198802129,
 'x5': 0.27955297611077395,
 'x6': 0.055291866177425805}
means, covariances = values
means
{'hartmann6': -3.1273235118485294, 'l2norm': 1.1878074124227307}
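To make "best" concrete, the sketch below picks the best observed trial among the 12 Sobol trials logged in step 4, honoring the outcome constraint l2norm <= 1.25. It is only an illustration on raw observations: get_best_parameters may rely on model predictions instead, so its choice need not match the best raw observation. The values are copied from the step 4 log, not from the run that produced the output above.

```python
# (hartmann6 mean, l2norm mean) for Sobol trials 0-11, copied from the step 4 log.
observations = {
    0: (-0.146128, 1.778068),
    1: (-1.418863, 1.346833),
    2: (-0.002137, 1.624009),
    3: (-0.014062, 1.499713),
    4: (-0.199245, 0.987169),
    5: (-1.760953, 1.406352),
    6: (-0.931146, 1.039472),
    7: (-1.337291, 1.037059),
    8: (-0.135917, 1.294384),
    9: (-0.409906, 1.680358),
    10: (-0.008517, 1.716656),
    11: (-0.230108, 1.587017),
}

L2NORM_BOUND = 1.25  # outcome constraint: l2norm <= 1.25


def best_feasible_trial(obs, bound):
    """Index of the feasible trial with the lowest hartmann6, or None."""
    feasible = {i: h for i, (h, l2) in obs.items() if l2 <= bound}
    if not feasible:
        return None
    return min(feasible, key=feasible.get)  # we are minimizing hartmann6


unconstrained = min(observations, key=lambda i: observations[i][0])
constrained = best_feasible_trial(observations, L2NORM_BOUND)
print(unconstrained, constrained)  # 5 7
```

Trial 5 has the lowest raw hartmann6 value, but its l2norm of about 1.41 violates the constraint, so trial 7 is the best feasible observation among these 12.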

For comparison, the known global minimum of Hartmann6:

hartmann6.fmin
-3.32237
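For reference, Hartmann6 is commonly defined as f(x) = -sum_{i=1..4} alpha_i * exp(-sum_{j=1..6} A_ij * (x_j - P_ij)^2) on [0, 1]^6. The stdlib sketch below uses the standard constants from the benchmark literature (assumed here, not read from Ax's source) and evaluates f at the commonly reported minimizer, which should come out close to hartmann6.fmin.

```python
import math

# Standard Hartmann6 constants (benchmark literature; assumed, not from Ax).
ALPHA = [1.0, 1.2, 3.0, 3.2]
A = [
    [10.0, 3.0, 17.0, 3.5, 1.7, 8.0],
    [0.05, 10.0, 17.0, 0.1, 8.0, 14.0],
    [3.0, 3.5, 1.7, 10.0, 17.0, 8.0],
    [17.0, 8.0, 0.05, 10.0, 0.1, 14.0],
]
P = [
    [0.1312, 0.1696, 0.5569, 0.0124, 0.8283, 0.5886],
    [0.2329, 0.4135, 0.8307, 0.3736, 0.1004, 0.9991],
    [0.2348, 0.1451, 0.3522, 0.2883, 0.3047, 0.6650],
    [0.4047, 0.8828, 0.8732, 0.5743, 0.1091, 0.0381],
]


def hartmann6_ref(x):
    """Reference Hartmann6 on [0, 1]^6 (lower is better)."""
    total = 0.0
    for alpha, a_row, p_row in zip(ALPHA, A, P):
        inner = sum(a * (xj - p) ** 2 for a, xj, p in zip(a_row, x, p_row))
        total += alpha * math.exp(-inner)
    return -total


# Commonly reported global minimizer.
x_star = [0.20169, 0.150011, 0.476874, 0.275332, 0.311652, 0.6573]
print(hartmann6_ref(x_star))
```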

6. Plot the response surface and optimization trace

Here we plot the response surface of the objective, “hartmann6”, over “x1” and “x2”, which get_contour_plot uses when no parameters or metric are specified (as its log line confirms).

render(ax_client.get_contour_plot())
[INFO 01-05 19:07:42] ax.service.ax_client: Retrieving contour plot with parameter 'x1' on X-axis and 'x2' on Y-axis, for metric 'hartmann6'. Remaining parameters are affixed to the middle of their range.

We can also retrieve a contour plot for the other metric, “l2norm”; say we are interested in its response surface over parameters “x3” and “x4”.

render(ax_client.get_contour_plot(param_x="x3", param_y="x4", metric_name="l2norm"))
[INFO 01-05 19:07:45] ax.service.ax_client: Retrieving contour plot with parameter 'x3' on X-axis and 'x4' on Y-axis, for metric 'l2norm'. Remaining parameters are affixed to the middle of their range.
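The contour plots above show the model's predicted surface over a two-parameter slice, with the remaining parameters pinned to the middle of their range, as the log lines note. The sketch below builds the same kind of slice for the true l2norm rather than a model prediction, on a coarse 5 x 5 grid over x3 and x4 with x1, x2, x5 and x6 fixed at 0.5; contour_slice and the grid size are illustrative choices.

```python
import math


def l2norm(params):
    return math.sqrt(sum(v ** 2 for v in params.values()))


def contour_slice(f, param_x, param_y, names, n=5, lower=0.0, upper=1.0):
    """Evaluate f on an n x n grid over (param_x, param_y).

    All other parameters are fixed at the midpoint of [lower, upper].
    Returns (grid, z) with z[row][col] = f at (param_x=grid[col], param_y=grid[row]).
    """
    step = (upper - lower) / (n - 1)
    grid = [lower + k * step for k in range(n)]
    mid = (lower + upper) / 2
    z = []
    for y in grid:
        row = []
        for x in grid:
            params = {name: mid for name in names}
            params[param_x], params[param_y] = x, y
            row.append(f(params))
        z.append(row)
    return grid, z


names = [f"x{i}" for i in range(1, 7)]
grid, z = contour_slice(l2norm, "x3", "x4", names)
print(grid)     # [0.0, 0.25, 0.5, 0.75, 1.0]
print(z[2][2])  # x3 = x4 = 0.5, i.e. sqrt(1.5), about 1.2247
```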

Next, we plot the optimization trace, which shows how the best objective value found so far improves over the course of the trials:

# objective_optimum is optional; it adds the known Hartmann6 minimum to the plot for comparison.
trace = ax_client.get_optimization_trace(objective_optimum=hartmann6.fmin)
render(trace)
c:\Users\sterg\miniconda3\envs\ac-microcourses\Lib\site-packages\numpy\core\_methods.py:173: RuntimeWarning:

invalid value encountered in subtract

c:\Users\sterg\miniconda3\envs\ac-microcourses\Lib\site-packages\numpy\lib\function_base.py:4655: RuntimeWarning:

invalid value encountered in subtract
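An optimization trace plots the best objective value found so far against the trial index. The sketch below computes such a running best from the 12 Sobol hartmann6 values logged in step 4, with and without the outcome constraint. It assumes infeasible trials (l2norm > 1.25) should not count toward the best value; how get_optimization_trace treats them may differ in detail.

```python
import math

# hartmann6 and l2norm means for Sobol trials 0-11, copied from the step 4 log.
hartmann = [-0.146128, -1.418863, -0.002137, -0.014062, -0.199245, -1.760953,
            -0.931146, -1.337291, -0.135917, -0.409906, -0.008517, -0.230108]
l2norms = [1.778068, 1.346833, 1.624009, 1.499713, 0.987169, 1.406352,
           1.039472, 1.037059, 1.294384, 1.680358, 1.716656, 1.587017]


def running_best(values, feasible=None):
    """Best-so-far (minimum) trace; trials marked infeasible are skipped."""
    best = math.inf  # stays inf until the first counted trial
    trace = []
    for i, value in enumerate(values):
        if feasible is None or feasible[i]:
            best = min(best, value)
        trace.append(best)
    return trace


unconstrained = running_best(hartmann)
constrained = running_best(hartmann, [l2 <= 1.25 for l2 in l2norms])
print(unconstrained[-1], constrained[-1])  # -1.760953 -1.337291
print(constrained[:4])                     # [inf, inf, inf, inf]
```

The constrained trace stays at inf until trial 4, the first trial that satisfies l2norm <= 1.25.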

7. Save / reload optimization to JSON / SQL

We can serialize the state of optimization to JSON and save it to a .json file or save it to the SQL backend. For the former:

ax_client.save_to_json_file()  # For custom filepath, pass `filepath` argument.
[INFO 01-05 19:07:46] ax.service.ax_client: Saved JSON-serialized state of optimization to `ax_client_snapshot.json`.
# For a custom filepath, pass the `filepath` argument.
restored_ax_client = AxClient.load_from_json_file()
[INFO 01-05 19:07:46] ax.service.ax_client: Starting optimization with verbose logging. To disable logging, set the `verbose_logging` argument to `False`. Note that float values in the logs are rounded to 6 decimal points.
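Depending on your pandas version, loading a snapshot may emit many FutureWarnings about passing literal JSON to `read_json`; the load itself still completes, as the log line above shows. If these warnings clutter the output, one option is to silence only that warning while loading. The helper name `call_ignoring_read_json_warning` is hypothetical, and the demo uses a stand-in function instead of `AxClient.load_from_json_file` so that it runs without Ax.

```python
import warnings


def call_ignoring_read_json_warning(fn, *args, **kwargs):
    """Call fn(*args, **kwargs), ignoring only FutureWarnings that mention read_json."""
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", message=".*read_json", category=FutureWarning)
        return fn(*args, **kwargs)


def noisy_loader():
    # Stand-in for AxClient.load_from_json_file that emits a similar warning.
    warnings.warn("Passing literal json to 'read_json' is deprecated", FutureWarning)
    return "restored"


# With Ax, the call would presumably look like:
# restored_ax_client = call_ignoring_read_json_warning(AxClient.load_from_json_file)
result = call_ignoring_read_json_warning(noisy_loader)
print(result)  # restored
```

Filtering by message and category, rather than silencing all warnings, keeps any other warnings visible.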

To store the state of the optimization in an SQL backend, first follow the setup instructions on the Ax website.

Having set up the SQL backend, pass DBSettings to AxClient on instantiation. Note that the SQLAlchemy dependency must be installed; for installation, refer to the optional dependencies on the Ax website:

# NOTE: Requires running `pip install ax-platform[mysql]` and setting up a database per instructions above (out of scope for this tutorial)

# from ax.storage.sqa_store.structs import DBSettings

# # URL is of the form "dialect+driver://username:password@host:port/database".
# db_settings = DBSettings(url="sqlite:///foo.db")
# # Instead of URL, can provide a `creator function`; can specify custom encoders/decoders if necessary.
# new_ax = AxClient(db_settings=db_settings)
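The comment above gives the URL form "dialect+driver://username:password@host:port/database". Credentials containing characters such as `@` or `:` need to be percent-encoded before they go into such a URL. A small stdlib sketch; the function name and the `mysql+pymysql` example values are hypothetical and not part of Ax:

```python
from urllib.parse import quote


def make_db_url(dialect, driver, username, password, host, port, database):
    """Assemble a 'dialect+driver://username:password@host:port/database' URL,
    percent-encoding the credentials so characters like '@' stay unambiguous."""
    user = quote(username, safe="")
    pwd = quote(password, safe="")
    return f"{dialect}+{driver}://{user}:{pwd}@{host}:{port}/{database}"


url = make_db_url("mysql", "pymysql", "ax_user", "p@ss:word", "localhost", 3306, "ax_db")
print(url)  # mysql+pymysql://ax_user:p%40ss%3Aword@localhost:3306/ax_db
```

The resulting string would then be passed as `DBSettings(url=url)`.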

When valid DBSettings are passed into AxClient, a unique experiment name is a required argument (name) to ax_client.create_experiment. The state of the optimization is auto-saved any time it changes (e.g., when a new trial is added or completed).

To reload an optimization state later, instantiate AxClient with the same DBSettings and use ax_client.load_experiment_from_database(experiment_name="my_experiment").

Special Cases

Evaluation failure: should any optimization iterations fail during evaluation, log_trial_failure will ensure that the same trial is not proposed again.

_, trial_index = ax_client.get_next_trial()
ax_client.log_trial_failure(trial_index=trial_index)
[INFO 01-05 19:08:04] ax.service.ax_client: Generated new trial 25 with parameters {'x1': 0.407242, 'x2': 0.873801, 'x3': 0.315026, 'x4': 0.568915, 'x5': 0.315262, 'x6': 0.015128}.
[INFO 01-05 19:08:04] ax.service.ax_client: Registered failure of trial 25.
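In practice, log_trial_failure is usually called from the evaluation loop when an evaluation raises. A minimal sketch, assuming an AxClient-like object with the `get_next_trial`, `complete_trial`, and `log_trial_failure` methods used in this tutorial; `evaluate` is a hypothetical user function that returns raw data for `complete_trial`, and `run_trial_safely` is a hypothetical helper name:

```python
def run_trial_safely(client, evaluate):
    """Request one trial, evaluate it, and record either its data or a failure.

    Returns (trial_index, raw_data); raw_data is None if evaluation failed.
    """
    parameters, trial_index = client.get_next_trial()
    try:
        raw_data = evaluate(parameters)
    except Exception:  # Broad on purpose: any evaluation error marks the trial failed.
        client.log_trial_failure(trial_index=trial_index)
        return trial_index, None
    client.complete_trial(trial_index=trial_index, raw_data=raw_data)
    return trial_index, raw_data


# Usage with the tutorial's client (assumption):
# for _ in range(5):
#     run_trial_safely(ax_client, evaluate)
```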

Adding custom trials: should there be a need to evaluate a specific parameterization, attach_trial will add it to the experiment.

ax_client.attach_trial(
    parameters={"x1": 0.9, "x2": 0.9, "x3": 0.9, "x4": 0.9, "x5": 0.9, "x6": 0.9}
)
[INFO 01-05 19:08:04] ax.core.experiment: Attached custom parameterizations [{'x1': 0.9, 'x2': 0.9, 'x3': 0.9, 'x4': 0.9, 'x5': 0.9, 'x6': 0.9}] as trial 26.
({'x1': 0.9, 'x2': 0.9, 'x3': 0.9, 'x4': 0.9, 'x5': 0.9, 'x6': 0.9}, 26)
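One common use of attach_trial is to seed the experiment with parameterizations you already consider interesting, and then report their data. A minimal sketch, assuming an AxClient-like object whose `attach_trial` returns `(parameters, trial_index)` as in the output above and whose `complete_trial` is called as earlier in this tutorial; `evaluate` and `attach_and_complete` are hypothetical names:

```python
def attach_and_complete(client, parameterizations, evaluate):
    """Attach each known parameterization as a custom trial and report its data.

    Returns the list of trial indices that were created.
    """
    indices = []
    for params in parameterizations:
        attached_params, trial_index = client.attach_trial(parameters=params)
        client.complete_trial(trial_index=trial_index, raw_data=evaluate(attached_params))
        indices.append(trial_index)
    return indices


# Usage with the tutorial's client (assumption):
# seeds = [{f"x{i}": 0.9 for i in range(1, 7)}]
# attach_and_complete(ax_client, seeds, evaluate)
```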

Need to run many trials in parallel: for optimal results and optimization efficiency, we strongly recommend sequential optimization (generating a few trials, then waiting for them to be completed with evaluation data). However, if your use case needs to dispatch many trials in parallel before they are updated with data and you are running into the “All trials for current model have been generated, but not enough data has been observed to fit next model” error, instantiate AxClient as AxClient(enforce_sequential_optimization=False).
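The recommended pattern of generating a few trials and then waiting for their data can be sketched with a thread pool: request a small batch, evaluate it concurrently, and report every result before asking for more. This assumes the same AxClient-like `get_next_trial` and `complete_trial` methods used in this tutorial; `run_in_batches` and `batch_size` are hypothetical, and threads are just one option (jobs on a cluster scheduler would follow the same shape).

```python
from concurrent.futures import ThreadPoolExecutor


def run_in_batches(client, evaluate, n_batches, batch_size=3):
    """Generate a small batch of trials, evaluate them concurrently, and
    complete all of them before generating the next batch."""
    for _ in range(n_batches):
        # All client calls stay on this thread; only evaluation runs in parallel.
        batch = [client.get_next_trial() for _ in range(batch_size)]
        with ThreadPoolExecutor(max_workers=batch_size) as pool:
            results = list(pool.map(lambda pair: evaluate(pair[0]), batch))
        for (_, trial_index), raw_data in zip(batch, results):
            client.complete_trial(trial_index=trial_index, raw_data=raw_data)


# Usage with the tutorial's client (assumption):
# run_in_batches(ax_client, evaluate, n_batches=5, batch_size=3)
```

Keeping `batch_size` small stays close to the sequential recommendation above while still using some parallelism.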