Skip to content

Optimize

inspeqtor.optimize

inspeqtor.optimize.get_default_optimizer

get_default_optimizer(
    n_iterations: int,
) -> GradientTransformation

Generate a preset optimizer from the number of training iterations.

Parameters:

Name Type Description Default
n_iterations int

Number of training iterations

required

Returns:

Type Description
GradientTransformation

optax.GradientTransformation: Optax optimizer.

Source code in src/inspeqtor/v1/optimize.py
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
def get_default_optimizer(n_iterations: int) -> optax.GradientTransformation:
    """Generate a preset optimizer from the number of training iterations.

    The optimizer is AdamW driven by a warmup-cosine-decay learning-rate
    schedule: the rate warms up from 1e-6 to a peak of 1e-2 over the first
    10% of the iterations, then decays back to 1e-6 over the full run.

    Args:
        n_iterations (int): Number of training iterations.

    Returns:
        optax.GradientTransformation: Optax optimizer.
    """
    return optax.adamw(
        learning_rate=optax.warmup_cosine_decay_schedule(
            init_value=1e-6,
            peak_value=1e-2,
            # Warm up over the first 10% of training.
            warmup_steps=int(0.1 * n_iterations),
            decay_steps=n_iterations,
            end_value=1e-6,
        )
    )

inspeqtor.optimize.minimize

minimize(
    params: ArrayTree,
    func: Callable[[ndarray], tuple[ndarray, Any]],
    optimizer: GradientTransformation,
    lower: ArrayTree | None = None,
    upper: ArrayTree | None = None,
    maxiter: int = 1000,
    callbacks: list[Callable] = [],
) -> tuple[ArrayTree, list[Any]]

Optimize the loss function with bounded parameters.

Parameters:

Name Type Description Default
params ArrayTree

Initial parameters to be optimized

required
lower ArrayTree

Lower bound of the parameters

None
upper ArrayTree

Upper bound of the parameters

None
func Callable[[ndarray], tuple[ndarray, Any]]

Loss function

required
optimizer GradientTransformation

Instance of optax optimizer

required
maxiter int

Number of optimization step. Defaults to 1000.

1000

Returns:

Type Description
tuple[ArrayTree, list[Any]]

tuple[chex.ArrayTree, list[typing.Any]]: Tuple of parameters and optimization history

Source code in src/inspeqtor/v1/optimize.py
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
def minimize(
    params: chex.ArrayTree,
    func: typing.Callable[[jnp.ndarray], tuple[jnp.ndarray, typing.Any]],
    optimizer: optax.GradientTransformation,
    lower: chex.ArrayTree | None = None,
    upper: chex.ArrayTree | None = None,
    maxiter: int = 1000,
    callbacks: list[typing.Callable] | None = None,
) -> tuple[chex.ArrayTree, list[typing.Any]]:
    """Optimize the loss function with bounded parameters.

    Args:
        params (chex.ArrayTree): Initial parameters to be optimized
        func (typing.Callable[[jnp.ndarray], tuple[jnp.ndarray, typing.Any]]): Loss function
            returning ``(loss, aux)`` where ``aux`` is a mutable mapping.
        optimizer (optax.GradientTransformation): Instance of optax optimizer
        lower (chex.ArrayTree | None, optional): Lower bound of the parameters. Defaults to None.
        upper (chex.ArrayTree | None, optional): Upper bound of the parameters. Defaults to None.
        maxiter (int, optional): Number of optimization steps. Defaults to 1000.
        callbacks (list[typing.Callable] | None, optional): Functions invoked as
            ``callback(step_idx, aux)`` after every step. Defaults to None.

    Returns:
        tuple[chex.ArrayTree, list[typing.Any]]: Tuple of parameters and optimization history
    """
    # A literal `[]` default is created once and shared across every call
    # (mutable-default pitfall); use a None sentinel instead.
    callbacks = [] if callbacks is None else callbacks

    opt_state = optimizer.init(params)
    history = []

    for step_idx in range(maxiter):
        grads, aux = jax.grad(func, has_aux=True)(params)
        updates, opt_state = optimizer.update(grads, opt_state, params)
        params = optax.apply_updates(params, updates)

        if lower is not None and upper is not None:
            # Project parameters back into the [lower, upper] box.
            params = optax.projections.projection_box(params, lower, upper)

        # Log the history (aux is assumed to be a mutable mapping).
        aux["params"] = params
        history.append(aux)

        for callback in callbacks:
            callback(step_idx, aux)

    return params, history

inspeqtor.optimize.stochastic_minimize

stochastic_minimize(
    key: ndarray,
    params: ArrayTree,
    func: Callable[[ndarray, ndarray], tuple[ndarray, Any]],
    optimizer: GradientTransformation,
    lower: ArrayTree | None = None,
    upper: ArrayTree | None = None,
    maxiter: int = 1000,
    callbacks: list[Callable] = [],
) -> tuple[ArrayTree, list[Any]]

Optimize the loss function with bounded parameters.

Parameters:

Name Type Description Default
params ArrayTree

Initial parameters to be optimized

required
lower ArrayTree

Lower bound of the parameters

None
upper ArrayTree

Upper bound of the parameters

None
func Callable[[ndarray], tuple[ndarray, Any]]

Loss function

required
optimizer GradientTransformation

Instance of optax optimizer

required
maxiter int

Number of optimization step. Defaults to 1000.

1000

Returns:

Type Description
tuple[ArrayTree, list[Any]]

tuple[chex.ArrayTree, list[typing.Any]]: Tuple of parameters and optimization history

Source code in src/inspeqtor/v1/optimize.py
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
def stochastic_minimize(
    key: jnp.ndarray,
    params: chex.ArrayTree,
    func: typing.Callable[[jnp.ndarray, jnp.ndarray], tuple[jnp.ndarray, typing.Any]],
    optimizer: optax.GradientTransformation,
    lower: chex.ArrayTree | None = None,
    upper: chex.ArrayTree | None = None,
    maxiter: int = 1000,
    callbacks: list[typing.Callable] | None = None,
) -> tuple[chex.ArrayTree, list[typing.Any]]:
    """Optimize a stochastic loss function with bounded parameters.

    Args:
        key (jnp.ndarray): Jax random key; a fresh key is derived per step.
        params (chex.ArrayTree): Initial parameters to be optimized
        func (typing.Callable[[jnp.ndarray, jnp.ndarray], tuple[jnp.ndarray, typing.Any]]):
            Loss function taking ``(params, key)`` and returning ``(loss, aux)``
            where ``aux`` is a mutable mapping.
        optimizer (optax.GradientTransformation): Instance of optax optimizer
        lower (chex.ArrayTree | None, optional): Lower bound of the parameters. Defaults to None.
        upper (chex.ArrayTree | None, optional): Upper bound of the parameters. Defaults to None.
        maxiter (int, optional): Number of optimization steps. Defaults to 1000.
        callbacks (list[typing.Callable] | None, optional): Functions invoked as
            ``callback(step_idx, aux)`` after every step. Defaults to None.

    Returns:
        tuple[chex.ArrayTree, list[typing.Any]]: Tuple of parameters and optimization history
    """
    # A literal `[]` default is created once and shared across every call
    # (mutable-default pitfall); use a None sentinel instead.
    callbacks = [] if callbacks is None else callbacks

    opt_state = optimizer.init(params)
    history = []

    for step_idx in range(maxiter):
        # Split so each iteration evaluates the loss with fresh randomness.
        key, _ = jax.random.split(key)
        grads, aux = jax.grad(func, has_aux=True)(params, key)
        updates, opt_state = optimizer.update(grads, opt_state, params)
        params = optax.apply_updates(params, updates)

        if lower is not None and upper is not None:
            # Project parameters back into the [lower, upper] box.
            params = optax.projections.projection_box(params, lower, upper)

        # Log the history (aux is assumed to be a mutable mapping).
        aux["params"] = params
        history.append(aux)

        for callback in callbacks:
            callback(step_idx, aux)

    return params, history

inspeqtor.optimize.fit_gaussian_process

fit_gaussian_process(D: Dataset)

Fit the Gaussian process given an instance of Dataset

Parameters:

Name Type Description Default
D Dataset

The gpx.Dataset instance

required

Returns:

Type Description

tuple: The optimized posterior and the optimization history returned by the fit.

Source code in src/inspeqtor/v2/optimize.py
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
def fit_gaussian_process(D: gpx.Dataset):
    """Fit a Gaussian process to the observations in `D`.

    Builds a zero-mean GP prior with an RBF kernel, combines it with a
    Gaussian likelihood, and optimizes the resulting posterior with
    `gpx.fit_scipy` by maximizing the conjugate marginal log-likelihood.

    Args:
        D (gpx.Dataset): The `gpx.Dataset` instance of training observations.

    Returns:
        tuple: `(opt_posterior, history)` — the optimized posterior and the
            optimization history returned by `gpx.fit_scipy`.
    """
    kernel = gpx.kernels.RBF()  # RBF kernel; input treated as 1-dimensional
    meanf = gpx.mean_functions.Zero()
    prior = gpx.gps.Prior(mean_function=meanf, kernel=kernel)

    likelihood = gpx.likelihoods.Gaussian(num_datapoints=D.n)

    # gpx overloads `*` to combine a prior with a likelihood into a posterior.
    posterior = prior * likelihood

    opt_posterior, history = gpx.fit_scipy(
        model=posterior,
        # Negate the conjugate MLL because fit_scipy minimizes the objective.
        objective=lambda p, d: -gpx.objectives.conjugate_mll(p, d),  # type: ignore
        train_data=D,
        trainable=gpx.parameters.Parameter,
        verbose=True,
    )
    return opt_posterior, history

inspeqtor.optimize.predict_with_gaussian_process

predict_with_gaussian_process(
    x, posterior: ConjugatePosterior, D: Dataset
) -> tuple[ndarray, ndarray]
Source code in src/inspeqtor/v2/optimize.py
36
37
38
39
40
41
42
43
44
def predict_with_gaussian_process(
    x, posterior: gpx.gps.ConjugatePosterior, D: gpx.Dataset
) -> tuple[jnp.ndarray, jnp.ndarray]:
    """Evaluate the GP predictive distribution at points `x`.

    Args:
        x: Points at which to evaluate the posterior.
        posterior (gpx.gps.ConjugatePosterior): Fitted GP posterior.
        D (gpx.Dataset): Training data the posterior is conditioned on.

    Returns:
        tuple[jnp.ndarray, jnp.ndarray]: Predictive mean and standard
            deviation of the GP at `x`.
    """
    latent_dist = posterior.predict(x, train_data=D)
    predictive_dist = posterior.likelihood(latent_dist)
    # For the Gaussian process, only mean and variance should be enough?
    predictive_mean = predictive_dist.mean
    predictive_std = jnp.sqrt(predictive_dist.variance)
    return predictive_mean, predictive_std

inspeqtor.optimize.predict_mean_and_std

predict_mean_and_std(
    x: ndarray, D: Dataset
) -> tuple[ndarray, ndarray]

Predict a Gaussian distribution to the given x using the dataset D

Parameters:

Name Type Description Default
x ndarray

The array of points to evaluate the gaussian process.

required
D Dataset

The dataset containing observations from the real process.

required

Returns:

Type Description
tuple[ndarray, ndarray]

tuple[jnp.ndarray, jnp.ndarray]: The arrays of mean and standard deviation of the Gaussian process at points x.

Source code in src/inspeqtor/v2/optimize.py
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
def predict_mean_and_std(
    x: jnp.ndarray, D: gpx.Dataset
) -> tuple[jnp.ndarray, jnp.ndarray]:
    """Predict a Gaussian distribution at the given `x` using the dataset `D`.

    Fits a fresh Gaussian process to `D` on every call, then evaluates its
    predictive distribution at `x`.

    Args:
        x (jnp.ndarray): The array of points to evaluate the gaussian process.
        D (gpx.Dataset): The dataset containing observations from the real process.

    Returns:
        tuple[jnp.ndarray, jnp.ndarray]: The arrays of mean and standard deviation of the Gaussian process at points `x`.
    """
    opt_posterior, _ = fit_gaussian_process(D)

    return predict_with_gaussian_process(x, opt_posterior, D)

inspeqtor.optimize.expected_improvement

expected_improvement(
    y_best: ndarray,
    posterior_mean: ndarray,
    posterior_var: ndarray,
    exploration_factor: float,
) -> ndarray

The expected improvement calculated using posterior mean and variance of the gaussian process. The exploration factor can be adjusted to balance between exploration and exploitation.

Parameters:

Name Type Description Default
y_best ndarray

The current maximum value of y

required
posterior_mean ndarray

The posterior mean of the gaussian process

required
posterior_var ndarray

The posterior variance of the gaussian process

required
exploration_factor float

The factor that balances exploration against exploitation. Set to 0. to maximize exploitation.

required

Returns:

Type Description
ndarray

jnp.ndarray: The expected improvement corresponding to the points given by the posterior arrays.

Source code in src/inspeqtor/v2/optimize.py
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
def expected_improvement(
    y_best: jnp.ndarray,
    posterior_mean: jnp.ndarray,
    posterior_var: jnp.ndarray,
    exploration_factor: float,
) -> jnp.ndarray:
    """Compute the expected-improvement acquisition function.

    The score is computed from the posterior mean and variance of the
    gaussian process; the exploration factor can be adjusted to balance
    exploration against exploitation.

    Args:
        y_best (jnp.ndarray): The current maximum value of y
        posterior_mean (jnp.ndarray): The posterior mean of the gaussian process
        posterior_var (jnp.ndarray): The posterior variance of the gaussian process
        exploration_factor (float): Balances exploration vs. exploitation.
            Set to 0. to maximize exploitation.

    Returns:
        jnp.ndarray: The expected improvement at the points described by the posterior arrays.
    """
    # Formulation follows https://github.com/alonfnt/bayex/blob/main/bayex/acq.py
    posterior_std = jnp.sqrt(posterior_var)
    improvement = posterior_mean - y_best - exploration_factor
    z_score = improvement / posterior_std
    return improvement * norm.cdf(z_score) + posterior_std * norm.pdf(z_score)

inspeqtor.optimize.BayesOptState

The dataclass holding optimization state for the gaussian process.

Source code in src/inspeqtor/v2/optimize.py
91
92
93
94
95
96
@struct.dataclass
class BayesOptState:
    """The dataclass holding optimization state for the gaussian process."""

    # Observations collected so far (inputs X and outcomes y).
    dataset: gpx.Dataset
    # The control sequence whose parameters are being optimized.
    control: ControlSequence

inspeqtor.optimize.init_opt_state

init_opt_state(x, y, control) -> BayesOptState

Function to intialize the optimizer

Parameters:

Name Type Description Default
x ndarray

The input arguments

required
y ndarray

The observation corresponding to the input x

required
control ControlSequence

The instance of control sequence.

required

Returns:

Name Type Description
BayesOptState BayesOptState

The state of optimizer.

Source code in src/inspeqtor/v2/optimize.py
 99
100
101
102
103
104
105
106
107
108
109
110
def init_opt_state(x: jnp.ndarray, y: jnp.ndarray, control: ControlSequence) -> BayesOptState:
    """Function to initialize the optimizer state.

    Args:
        x (jnp.ndarray): The input arguments
        y (jnp.ndarray): The observations corresponding to the input `x`
        control (ControlSequence): The instance of control sequence.

    Returns:
        BayesOptState: The state of optimizer.
    """
    return BayesOptState(dataset=gpx.Dataset(X=x, y=y), control=control)

inspeqtor.optimize.suggest_next_candidates

suggest_next_candidates(
    key: ndarray,
    opt_state: BayesOptState,
    sample_size: int = 1000,
    num_suggest: int = 1,
    exploration_factor: float = 0.0,
) -> ndarray

Sample new candidates for experiment using expected improvement.

Parameters:

Name Type Description Default
key ndarray

The jax random key

required
opt_state BayesOptState

The current optimizer state

required
sample_size int

The internal number of sample size. Defaults to 1000.

1000
num_suggest int

The number of suggestion for next experiment. Defaults to 1.

1
exploration_factor float

The factor that balances exploration against exploitation. Set to 0. to maximize exploitation. Defaults to 0.0.

0.0

Returns:

Type Description
ndarray

jnp.ndarray: The suggested data points to evaluate in the experiment.

Source code in src/inspeqtor/v2/optimize.py
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
def suggest_next_candidates(
    key: jnp.ndarray,
    opt_state: BayesOptState,
    sample_size: int = 1000,
    num_suggest: int = 1,
    exploration_factor: float = 0.0,
) -> jnp.ndarray:
    """Sample new candidates for the experiment via expected improvement.

    Draws `sample_size` random control-parameter candidates, scores each one
    with the expected-improvement acquisition function under a GP fitted to
    the current dataset, and returns the highest-scoring candidates.

    Args:
        key (jnp.ndarray): The jax random key
        opt_state (BayesOptState): The current optimizer state
        sample_size (int, optional): The internal number of sample size. Defaults to 1000.
        num_suggest (int, optional): The number of suggestions for the next experiment. Defaults to 1.
        exploration_factor (float, optional): Balances exploration vs. exploitation.
            Set to 0. to maximize exploitation. Defaults to 0.0.

    Returns:
        jnp.ndarray: The suggested data points to evaluate in the experiment.
    """
    observations = opt_state.dataset.y
    assert isinstance(observations, jnp.ndarray)
    best_observation = jnp.max(observations)

    # Draw candidate control parameters and flatten them to feature vectors.
    ravel_fn, unravel_fn = ravel_unravel_fn(opt_state.control.get_structure())
    sample_keys = jax.random.split(key, sample_size)
    candidate_params = jax.vmap(opt_state.control.sample_params)(sample_keys)
    # Shape (sample_size, ctrl_feature)
    flat_candidates = jax.vmap(ravel_fn)(candidate_params)

    # Fit a GP to the observations and score every candidate.
    gp_mean, gp_var = predict_mean_and_std(flat_candidates, opt_state.dataset)
    acquisition = expected_improvement(
        best_observation, gp_mean, gp_var, exploration_factor=exploration_factor
    )

    # Keep the num_suggest highest-scoring candidates.
    top_indices = jnp.argsort(acquisition, descending=True)[:num_suggest]
    return flat_candidates[top_indices]

inspeqtor.optimize.add_observations

add_observations(
    opt_state: BayesOptState, x, y
) -> BayesOptState

Function to update the optimization state using new data points x and y

Parameters:

Name Type Description Default
opt_state BayesOptState

The current optimization state

required
x ndarray

The input arguments

required
y ndarray

The observation corresponding to the input x

required

Returns:

Name Type Description
BayesOptState BayesOptState

The updated optimization state.

Source code in src/inspeqtor/v2/optimize.py
154
155
156
157
158
159
160
161
162
163
164
165
166
167
def add_observations(opt_state: BayesOptState, x, y) -> BayesOptState:
    """Return a new optimization state extended with data points `x` and `y`.

    Args:
        opt_state (BayesOptState): The current optimization state
        x (jnp.ndarray): The input arguments
        y (jnp.ndarray): The observations corresponding to the input `x`

    Returns:
        BayesOptState: The updated optimization state.
    """
    # gpx.Dataset supports `+` for concatenating observations.
    extended_dataset = opt_state.dataset + gpx.Dataset(X=x, y=y)
    return BayesOptState(dataset=extended_dataset, control=opt_state.control)