How to convert DNN into BNN¶
Goal
This notebook shows how to convert a statistical deep neural network (DNN) that predicts a single point into a Bayesian neural network (BNN) that predicts a distribution instead.
Let us start by importing libraries.
import jax
import jax.numpy as jnp
import inspeqtor as sq
Get synthetic dataset ready 🚀¶
Here we work with a synthetic dataset, so we need to define a simulator, perform an experiment, and prepare the dataset for model training, inference, and benchmarking. Luckily, inspeqtor provides several helper functions and predefined noise models that let users set things up quickly.
def get_data():
    # This is the predefined noise model that we are going to work with.
    data_model = sq.data.library.get_predefined_data_model_m1()
    # Now we use the noise model to generate the experimental data with the simulator.
    sample_size = 100
    exp_data, _, _, _ = sq.data.library.generate_single_qubit_experimental_data(
        key=jax.random.key(0),
        hamiltonian=data_model.total_hamiltonian,
        sample_size=sample_size,
        strategy=sq.physics.library.SimulationStrategy.SHOT,
        qubit_inforamtion=data_model.qubit_information,
        control_sequence=data_model.control_sequence,
        method=sq.physics.library.WhiteboxStrategy.TROTTER,
        trotter_steps=10_000,
    )
    # Now we can prepare the dataset so that it is ready to use.
    whitebox = sq.physics.make_trotterization_solver(
        data_model.ideal_hamiltonian,
        data_model.control_sequence.total_dt,
        data_model.dt,
        trotter_steps=10_000,
        y0=jnp.eye(2, dtype=jnp.complex128),
    )
    loaded_data = sq.data.prepare_data(exp_data, data_model.control_sequence, whitebox)
    # Here, we just bundle things up for convenience.
    key = jax.random.key(0)
    key, random_split_key = jax.random.split(key)
    (
        train_pulse_parameters,
        train_unitaries,
        train_expectation_values,
        test_pulse_parameters,
        test_unitaries,
        test_expectation_values,
    ) = sq.utils.random_split(
        random_split_key,
        int(loaded_data.control_parameters.shape[0] * 0.1),  # Test size
        loaded_data.control_parameters,
        loaded_data.unitaries,
        loaded_data.observed_values,
    )
    shots = loaded_data.experiment_data.config.shots
    train_binaries = sq.utils.eigenvalue_to_binary(
        sq.utils.expectation_value_to_eigenvalue(train_expectation_values, shots)
    )
    train_binaries = jnp.swapaxes(jnp.swapaxes(train_binaries, 1, 2), 0, 1)
    test_binaries = sq.utils.eigenvalue_to_binary(
        sq.utils.expectation_value_to_eigenvalue(test_expectation_values, shots)
    )
    test_binaries = jnp.swapaxes(jnp.swapaxes(test_binaries, 1, 2), 0, 1)
    assert train_binaries.shape == (shots, train_pulse_parameters.shape[0], 18)
    assert test_binaries.shape == (shots, test_pulse_parameters.shape[0], 18)
    train_data = sq.data.DataBundled(
        control_params=sq.control.library.drag_feature_map(train_pulse_parameters),
        unitaries=train_unitaries,
        observables=train_binaries,
        aux=train_expectation_values,
    )
    test_data = sq.data.DataBundled(
        control_params=sq.control.library.drag_feature_map(test_pulse_parameters),
        unitaries=test_unitaries,
        observables=test_binaries,
        aux=test_expectation_values,
    )
    # Return data ready to use.
    return data_model, loaded_data, train_data, test_data
data_model, loaded_data, train_data, test_data = get_data()
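The two nested jnp.swapaxes calls in get_data move the last axis of the binary measurement array to the front; the asserts imply that the input layout is (samples, 18, shots) and the output layout is (shots, samples, 18). As a small sketch with placeholder sizes (the sizes below are illustrative and are not taken from the dataset), the same reordering can be checked against a single jnp.moveaxis call:

```python
import jax.numpy as jnp

# Placeholder sizes for illustration only: 4 samples, 18 expectation values, 7 shots.
n_samples, n_expvals, shots = 4, 18, 7
binaries = jnp.arange(n_samples * n_expvals * shots).reshape(
    n_samples, n_expvals, shots
)

# Same two-step reordering as in get_data:
# (samples, 18, shots) -> (samples, shots, 18) -> (shots, samples, 18).
reordered = jnp.swapaxes(jnp.swapaxes(binaries, 1, 2), 0, 1)

# Equivalent single call: move the last axis (shots) to the front.
moved = jnp.moveaxis(binaries, -1, 0)

print(reordered.shape)
print(bool(jnp.array_equal(reordered, moved)))
```

Either form works; the single call only makes the intent (shots first) easier to read.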
Routes to convert DNN model to BNN model¶
inspeqtor provides several ways to convert a DNN into a BNN, and users can also build a BNN from scratch. For the predefined models, or any model defined with flax, numpyro.contrib.module implements functions that transform a statistical model into a probabilistic model with ease. inspeqtor therefore provides a wrapper function, make_flax_probabilistic_graybox_model, that helps convert a user-defined model into a proper probabilistic Graybox model. Here are examples of how to do it.
For linen models, pass random_flax_module as the flax_module argument of make_flax_probabilistic_graybox_model; for nnx models, pass random_nnx_module instead. In both cases, the adapter function that transforms the model's output into expectation values has to match the model: toggling_unitary_to_expvals for UnitaryModel and observable_to_expvals for WoModel. The alternatives below all assign the same three variables, so only the last one executed takes effect.
from numpyro.contrib.module import random_flax_module

# Alternative 1: linen unitary-based model.
base_model, adapter_fn, flax_module = (
    sq.models.library.linen.UnitaryModel([10, 10]),
    sq.models.adapter.toggling_unitary_to_expvals,
    random_flax_module,
)

# Alternative 2: linen Wo-based model.
base_model, adapter_fn, flax_module = (
    sq.models.library.linen.WoModel([5], [5]),
    sq.models.adapter.observable_to_expvals,
    random_flax_module,
)

from flax import nnx
from numpyro.contrib.module import random_nnx_module

# Alternative 3: nnx unitary-based model.
base_model, adapter_fn, flax_module = (
    sq.models.library.nnx.UnitaryModel([8, 8], rngs=nnx.Rngs(0)),
    sq.models.adapter.toggling_unitary_to_expvals,
    random_nnx_module,
)

# Alternative 4: nnx Wo-based model.
base_model, adapter_fn, flax_module = (
    sq.models.library.nnx.WoModel([8, 4, 6], [6, 4, 5], rngs=nnx.Rngs(0)),
    sq.models.adapter.observable_to_expvals,
    random_nnx_module,
)
Finally, we can define the graybox model with the chosen DNN model. For custom linen and nnx models, you also have to define a corresponding adapter_fn.
graybox_model = sq.models.probabilistic.make_flax_probabilistic_graybox_model(
    name="graybox",
    base_model=base_model,
    adapter_fn=adapter_fn,
    prior=sq.models.probabilistic.dist.Normal(0, 1),
    flax_module=flax_module,
)
nnx.display(base_model)
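Conceptually, the conversion above replaces every weight array of the network with a random variable drawn from the prior. The following is a minimal sketch of that idea in plain JAX, not the implementation of random_flax_module or random_nnx_module; the helper names init_params, sample_from_prior, and forward are hypothetical and exist only for this illustration.

```python
import jax
import jax.numpy as jnp


def init_params(sizes):
    """Point-estimate parameters of a small MLP (all zeros, used only as a shape template)."""
    return [
        {"kernel": jnp.zeros((n_in, n_out)), "bias": jnp.zeros((n_out,))}
        for n_in, n_out in zip(sizes[:-1], sizes[1:])
    ]


def sample_from_prior(key, params, scale=1.0):
    """Draw every parameter array from an independent Normal(0, scale) prior."""
    leaves, treedef = jax.tree_util.tree_flatten(params)
    keys = jax.random.split(key, len(leaves))
    sampled = [
        scale * jax.random.normal(keys[i], leaf.shape)
        for i, leaf in enumerate(leaves)
    ]
    return jax.tree_util.tree_unflatten(treedef, sampled)


def forward(params, x):
    """Ordinary point-prediction forward pass; it does not know the weights are random."""
    for layer in params[:-1]:
        x = jnp.tanh(x @ layer["kernel"] + layer["bias"])
    last = params[-1]
    return x @ last["kernel"] + last["bias"]


template = init_params([8, 5, 3])
x = jnp.ones((2, 8))

# Two draws of the weights give two different predictions for the same input.
pred_a = forward(sample_from_prior(jax.random.key(0), template), x)
pred_b = forward(sample_from_prior(jax.random.key(1), template), x)
print(pred_a.shape, bool(jnp.allclose(pred_a, pred_b)))
```

Repeating the draw many times yields a distribution over predictions instead of a single point, which is the behavior the probabilistic graybox model provides.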
If you need complete control over the model's behavior, you can also define the probabilistic model from scratch, using our primitive Bayesian neural network components, which are designed to be compatible with numpyro.
Below is a $\hat{W}_{O}$-based model defined from scratch to mirror the flax implementation. Note that we use sq.probabilistic.dense_layer for the mathematical operations. The mental model when defining a probabilistic model with our primitives, as in the example below, is that you write the model as if it operated on point predictions, while numpyro handles the distribution part for you.
base_model = sq.models.probabilistic.WoModel("graybox", (5,), (5,))
adapter_fn = sq.models.adapter.observable_to_expvals
graybox_model = sq.models.probabilistic.make_probabilistic_graybox_model(
    base_model, adapter_fn
)
You can inspect the model using sq.models.probabilistic.get_trace. Below, we visualize the trace of the model using nnx.display. Note that get_trace handles the random key for you under the hood; users can supply their own key if desired.
from flax import nnx
nnx.display(
    sq.models.probabilistic.get_trace(graybox_model)(
        test_data.control_params, test_data.unitaries
    )
)
Stochastic Variational Inference of BNN¶
Let us import the tools we need.
import numpyro
from numpyro.infer import (
    SVI,
    TraceMeanField_ELBO,
)
from alive_progress import alive_it
As you will see below, the probabilistic model (the model variable) accepts a graybox model, which can be defined in any of the ways demonstrated previously. In the following code snippet, we also use a custom guide and a custom training loop to demonstrate the flexibility.
model = sq.models.probabilistic.make_probabilistic_model(
    predictive_model=graybox_model,
)
guide = sq.models.probabilistic.auto_diagonal_normal_guide_v3(
    model,
    train_data.control_params,
    train_data.unitaries,
    train_data.observables,
    init_dist_fn=sq.models.probabilistic.bnn_init_dist_fn,
    init_params_fn=sq.models.probabilistic.bnn_init_params_fn,
)
NUM_STEPS = 10_000
optimizer = sq.optimize.get_default_optimizer(NUM_STEPS)
svi = SVI(
    model=model,
    guide=guide,
    optim=numpyro.optim.optax_to_numpyro(optimizer),
    loss=TraceMeanField_ELBO(),
)
svi_state = svi.init(
    rng_key=jax.random.key(0),
    control_parameters=train_data.control_params,
    unitaries=train_data.unitaries,
    observables=train_data.observables,
)
update_fn = sq.models.probabilistic.make_update_fn(
    svi,
    control_parameters=train_data.control_params,
    unitaries=train_data.unitaries,
    observables=train_data.observables,
)
eval_fn = sq.models.probabilistic.make_evaluate_fn(
    svi,
    control_parameters=test_data.control_params,
    unitaries=test_data.unitaries,
    observables=test_data.observables,
)
eval_losses = []
losses = []
# Wrap the update and evaluation functions with jax.jit once, outside the loop.
jitted_update_fn = jax.jit(update_fn)
jitted_eval_fn = jax.jit(eval_fn)
for i in alive_it(range(NUM_STEPS), force_tty=True):
    svi_state, loss = jitted_update_fn(svi_state)
    eval_loss = jitted_eval_fn(svi_state)
    losses.append(loss)
    eval_losses.append(eval_loss)
svi_result = sq.models.probabilistic.SVIRunResult(
    svi.get_params(svi_state), svi_state, jnp.stack(losses), jnp.stack(eval_losses)
)
|████████████████████████████████████████| 10000/10000 [100%] in 41.4s (241.31/s
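The loss minimized above is the negative ELBO. numpyro's TraceMeanField_ELBO uses analytic KL divergences between guide and prior where they are available, and for a diagonal normal guide against a Normal(0, 1) prior that term has a closed form per weight: $\mathrm{KL} = \tfrac{1}{2}(\sigma^2 + \mu^2 - 1) - \ln\sigma$. The sketch below evaluates it in plain Python; the prior Normal(0, 1) matches the one passed to make_flax_probabilistic_graybox_model earlier, while the guide values (loc 0, scale 0.1) and the helper name are assumptions chosen for illustration.

```python
import math


def kl_normal_to_standard_normal(loc, scale):
    """KL( N(loc, scale^2) || N(0, 1) ) for a single scalar weight."""
    return 0.5 * (scale**2 + loc**2 - 1.0) - math.log(scale)


# Illustrative values: a guide factor centred at 0 with standard deviation 0.1.
per_weight = kl_normal_to_standard_normal(loc=0.0, scale=0.1)

# A mean-field guide factorises over weights, so the KL terms simply add up.
# For an 8 x 5 kernel with identical factors:
kernel_kl = 8 * 5 * per_weight

print(round(per_weight, 4), round(kernel_kl, 2))
```

A narrow guide far from the prior pays a larger KL penalty, which the optimizer balances against the data-fit term of the ELBO.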
sq.models.probabilistic.get_trace(guide)()
OrderedDict([('graybox/shared.dense_0.kernel_loc',
{'type': 'param',
'name': 'graybox/shared.dense_0.kernel_loc',
'fn': <function numpyro.util.identity(x, *args, **kwargs)>,
'args': (Array([[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.]], dtype=float64),),
'kwargs': {},
'value': Array([[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.]], dtype=float64),
'scale': None,
'cond_indep_stack': []}),
('graybox/shared.dense_0.bias_loc',
{'type': 'param',
'name': 'graybox/shared.dense_0.bias_loc',
'fn': <function numpyro.util.identity(x, *args, **kwargs)>,
'args': (Array([0., 0., 0., 0., 0.], dtype=float64),),
'kwargs': {},
'value': Array([0., 0., 0., 0., 0.], dtype=float64),
'scale': None,
'cond_indep_stack': []}),
('graybox/pauli_X.dense_0.kernel_loc',
{'type': 'param',
'name': 'graybox/pauli_X.dense_0.kernel_loc',
'fn': <function numpyro.util.identity(x, *args, **kwargs)>,
'args': (Array([[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.]], dtype=float64),),
'kwargs': {},
'value': Array([[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.]], dtype=float64),
'scale': None,
'cond_indep_stack': []}),
('graybox/pauli_X.dense_0.bias_loc',
{'type': 'param',
'name': 'graybox/pauli_X.dense_0.bias_loc',
'fn': <function numpyro.util.identity(x, *args, **kwargs)>,
'args': (Array([0., 0., 0., 0., 0.], dtype=float64),),
'kwargs': {},
'value': Array([0., 0., 0., 0., 0.], dtype=float64),
'scale': None,
'cond_indep_stack': []}),
('graybox/U_X.kernel_loc',
{'type': 'param',
'name': 'graybox/U_X.kernel_loc',
'fn': <function numpyro.util.identity(x, *args, **kwargs)>,
'args': (Array([[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.]], dtype=float64),),
'kwargs': {},
'value': Array([[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.]], dtype=float64),
'scale': None,
'cond_indep_stack': []}),
('graybox/U_X.bias_loc',
{'type': 'param',
'name': 'graybox/U_X.bias_loc',
'fn': <function numpyro.util.identity(x, *args, **kwargs)>,
'args': (Array([0., 0., 0.], dtype=float64),),
'kwargs': {},
'value': Array([0., 0., 0.], dtype=float64),
'scale': None,
'cond_indep_stack': []}),
('graybox/D_X.kernel_loc',
{'type': 'param',
'name': 'graybox/D_X.kernel_loc',
'fn': <function numpyro.util.identity(x, *args, **kwargs)>,
'args': (Array([[0., 0.],
[0., 0.],
[0., 0.],
[0., 0.],
[0., 0.]], dtype=float64),),
'kwargs': {},
'value': Array([[0., 0.],
[0., 0.],
[0., 0.],
[0., 0.],
[0., 0.]], dtype=float64),
'scale': None,
'cond_indep_stack': []}),
('graybox/D_X.bias_loc',
{'type': 'param',
'name': 'graybox/D_X.bias_loc',
'fn': <function numpyro.util.identity(x, *args, **kwargs)>,
'args': (Array([0., 0.], dtype=float64),),
'kwargs': {},
'value': Array([0., 0.], dtype=float64),
'scale': None,
'cond_indep_stack': []}),
('graybox/pauli_Y.dense_0.kernel_loc',
{'type': 'param',
'name': 'graybox/pauli_Y.dense_0.kernel_loc',
'fn': <function numpyro.util.identity(x, *args, **kwargs)>,
'args': (Array([[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.]], dtype=float64),),
'kwargs': {},
'value': Array([[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.]], dtype=float64),
'scale': None,
'cond_indep_stack': []}),
('graybox/pauli_Y.dense_0.bias_loc',
{'type': 'param',
'name': 'graybox/pauli_Y.dense_0.bias_loc',
'fn': <function numpyro.util.identity(x, *args, **kwargs)>,
'args': (Array([0., 0., 0., 0., 0.], dtype=float64),),
'kwargs': {},
'value': Array([0., 0., 0., 0., 0.], dtype=float64),
'scale': None,
'cond_indep_stack': []}),
('graybox/U_Y.kernel_loc',
{'type': 'param',
'name': 'graybox/U_Y.kernel_loc',
'fn': <function numpyro.util.identity(x, *args, **kwargs)>,
'args': (Array([[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.]], dtype=float64),),
'kwargs': {},
'value': Array([[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.]], dtype=float64),
'scale': None,
'cond_indep_stack': []}),
('graybox/U_Y.bias_loc',
{'type': 'param',
'name': 'graybox/U_Y.bias_loc',
'fn': <function numpyro.util.identity(x, *args, **kwargs)>,
'args': (Array([0., 0., 0.], dtype=float64),),
'kwargs': {},
'value': Array([0., 0., 0.], dtype=float64),
'scale': None,
'cond_indep_stack': []}),
('graybox/D_Y.kernel_loc',
{'type': 'param',
'name': 'graybox/D_Y.kernel_loc',
'fn': <function numpyro.util.identity(x, *args, **kwargs)>,
'args': (Array([[0., 0.],
[0., 0.],
[0., 0.],
[0., 0.],
[0., 0.]], dtype=float64),),
'kwargs': {},
'value': Array([[0., 0.],
[0., 0.],
[0., 0.],
[0., 0.],
[0., 0.]], dtype=float64),
'scale': None,
'cond_indep_stack': []}),
('graybox/D_Y.bias_loc',
{'type': 'param',
'name': 'graybox/D_Y.bias_loc',
'fn': <function numpyro.util.identity(x, *args, **kwargs)>,
'args': (Array([0., 0.], dtype=float64),),
'kwargs': {},
'value': Array([0., 0.], dtype=float64),
'scale': None,
'cond_indep_stack': []}),
('graybox/pauli_Z.dense_0.kernel_loc',
{'type': 'param',
'name': 'graybox/pauli_Z.dense_0.kernel_loc',
'fn': <function numpyro.util.identity(x, *args, **kwargs)>,
'args': (Array([[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.]], dtype=float64),),
'kwargs': {},
'value': Array([[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.]], dtype=float64),
'scale': None,
'cond_indep_stack': []}),
('graybox/pauli_Z.dense_0.bias_loc',
{'type': 'param',
'name': 'graybox/pauli_Z.dense_0.bias_loc',
'fn': <function numpyro.util.identity(x, *args, **kwargs)>,
'args': (Array([0., 0., 0., 0., 0.], dtype=float64),),
'kwargs': {},
'value': Array([0., 0., 0., 0., 0.], dtype=float64),
'scale': None,
'cond_indep_stack': []}),
('graybox/U_Z.kernel_loc',
{'type': 'param',
'name': 'graybox/U_Z.kernel_loc',
'fn': <function numpyro.util.identity(x, *args, **kwargs)>,
'args': (Array([[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.]], dtype=float64),),
'kwargs': {},
'value': Array([[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.]], dtype=float64),
'scale': None,
'cond_indep_stack': []}),
('graybox/U_Z.bias_loc',
{'type': 'param',
'name': 'graybox/U_Z.bias_loc',
'fn': <function numpyro.util.identity(x, *args, **kwargs)>,
'args': (Array([0., 0., 0.], dtype=float64),),
'kwargs': {},
'value': Array([0., 0., 0.], dtype=float64),
'scale': None,
'cond_indep_stack': []}),
('graybox/D_Z.kernel_loc',
{'type': 'param',
'name': 'graybox/D_Z.kernel_loc',
'fn': <function numpyro.util.identity(x, *args, **kwargs)>,
'args': (Array([[0., 0.],
[0., 0.],
[0., 0.],
[0., 0.],
[0., 0.]], dtype=float64),),
'kwargs': {},
'value': Array([[0., 0.],
[0., 0.],
[0., 0.],
[0., 0.],
[0., 0.]], dtype=float64),
'scale': None,
'cond_indep_stack': []}),
('graybox/D_Z.bias_loc',
{'type': 'param',
'name': 'graybox/D_Z.bias_loc',
'fn': <function numpyro.util.identity(x, *args, **kwargs)>,
'args': (Array([0., 0.], dtype=float64),),
'kwargs': {},
'value': Array([0., 0.], dtype=float64),
'scale': None,
'cond_indep_stack': []}),
('graybox/shared.dense_0.kernel_scale',
{'type': 'param',
'name': 'graybox/shared.dense_0.kernel_scale',
'fn': <function numpyro.util.identity(x, *args, **kwargs)>,
'args': (Array([[0.1, 0.1, 0.1, 0.1, 0.1],
[0.1, 0.1, 0.1, 0.1, 0.1],
[0.1, 0.1, 0.1, 0.1, 0.1],
[0.1, 0.1, 0.1, 0.1, 0.1],
[0.1, 0.1, 0.1, 0.1, 0.1],
[0.1, 0.1, 0.1, 0.1, 0.1],
[0.1, 0.1, 0.1, 0.1, 0.1],
[0.1, 0.1, 0.1, 0.1, 0.1]], dtype=float64),),
'kwargs': {'constraint': SoftplusPositive(lower_bound=0.0)},
'value': Array([[0.1, 0.1, 0.1, 0.1, 0.1],
[0.1, 0.1, 0.1, 0.1, 0.1],
[0.1, 0.1, 0.1, 0.1, 0.1],
[0.1, 0.1, 0.1, 0.1, 0.1],
[0.1, 0.1, 0.1, 0.1, 0.1],
[0.1, 0.1, 0.1, 0.1, 0.1],
[0.1, 0.1, 0.1, 0.1, 0.1],
[0.1, 0.1, 0.1, 0.1, 0.1]], dtype=float64),
'scale': None,
'cond_indep_stack': []}),
('graybox/shared.dense_0.bias_scale',
{'type': 'param',
'name': 'graybox/shared.dense_0.bias_scale',
'fn': <function numpyro.util.identity(x, *args, **kwargs)>,
'args': (Array([0.1, 0.1, 0.1, 0.1, 0.1], dtype=float64),),
'kwargs': {'constraint': SoftplusPositive(lower_bound=0.0)},
'value': Array([0.1, 0.1, 0.1, 0.1, 0.1], dtype=float64),
'scale': None,
'cond_indep_stack': []}),
('graybox/pauli_X.dense_0.kernel_scale',
{'type': 'param',
'name': 'graybox/pauli_X.dense_0.kernel_scale',
'fn': <function numpyro.util.identity(x, *args, **kwargs)>,
'args': (Array([[0.1, 0.1, 0.1, 0.1, 0.1],
[0.1, 0.1, 0.1, 0.1, 0.1],
[0.1, 0.1, 0.1, 0.1, 0.1],
[0.1, 0.1, 0.1, 0.1, 0.1],
[0.1, 0.1, 0.1, 0.1, 0.1]], dtype=float64),),
'kwargs': {'constraint': SoftplusPositive(lower_bound=0.0)},
'value': Array([[0.1, 0.1, 0.1, 0.1, 0.1],
[0.1, 0.1, 0.1, 0.1, 0.1],
[0.1, 0.1, 0.1, 0.1, 0.1],
[0.1, 0.1, 0.1, 0.1, 0.1],
[0.1, 0.1, 0.1, 0.1, 0.1]], dtype=float64),
'scale': None,
'cond_indep_stack': []}),
('graybox/pauli_X.dense_0.bias_scale',
{'type': 'param',
'name': 'graybox/pauli_X.dense_0.bias_scale',
'fn': <function numpyro.util.identity(x, *args, **kwargs)>,
'args': (Array([0.1, 0.1, 0.1, 0.1, 0.1], dtype=float64),),
'kwargs': {'constraint': SoftplusPositive(lower_bound=0.0)},
'value': Array([0.1, 0.1, 0.1, 0.1, 0.1], dtype=float64),
'scale': None,
'cond_indep_stack': []}),
('graybox/U_X.kernel_scale',
{'type': 'param',
'name': 'graybox/U_X.kernel_scale',
'fn': <function numpyro.util.identity(x, *args, **kwargs)>,
'args': (Array([[0.1, 0.1, 0.1],
[0.1, 0.1, 0.1],
[0.1, 0.1, 0.1],
[0.1, 0.1, 0.1],
[0.1, 0.1, 0.1]], dtype=float64),),
'kwargs': {'constraint': SoftplusPositive(lower_bound=0.0)},
'value': Array([[0.1, 0.1, 0.1],
[0.1, 0.1, 0.1],
[0.1, 0.1, 0.1],
[0.1, 0.1, 0.1],
[0.1, 0.1, 0.1]], dtype=float64),
'scale': None,
'cond_indep_stack': []}),
('graybox/U_X.bias_scale',
{'type': 'param',
'name': 'graybox/U_X.bias_scale',
'fn': <function numpyro.util.identity(x, *args, **kwargs)>,
'args': (Array([0.1, 0.1, 0.1], dtype=float64),),
'kwargs': {'constraint': SoftplusPositive(lower_bound=0.0)},
'value': Array([0.1, 0.1, 0.1], dtype=float64),
'scale': None,
'cond_indep_stack': []}),
('graybox/D_X.kernel_scale',
{'type': 'param',
'name': 'graybox/D_X.kernel_scale',
'fn': <function numpyro.util.identity(x, *args, **kwargs)>,
'args': (Array([[0.1, 0.1],
[0.1, 0.1],
[0.1, 0.1],
[0.1, 0.1],
[0.1, 0.1]], dtype=float64),),
'kwargs': {'constraint': SoftplusPositive(lower_bound=0.0)},
'value': Array([[0.1, 0.1],
[0.1, 0.1],
[0.1, 0.1],
[0.1, 0.1],
[0.1, 0.1]], dtype=float64),
'scale': None,
'cond_indep_stack': []}),
('graybox/D_X.bias_scale',
{'type': 'param',
'name': 'graybox/D_X.bias_scale',
'fn': <function numpyro.util.identity(x, *args, **kwargs)>,
'args': (Array([0.1, 0.1], dtype=float64),),
'kwargs': {'constraint': SoftplusPositive(lower_bound=0.0)},
'value': Array([0.1, 0.1], dtype=float64),
'scale': None,
'cond_indep_stack': []}),
('graybox/pauli_Y.dense_0.kernel_scale',
{'type': 'param',
'name': 'graybox/pauli_Y.dense_0.kernel_scale',
'fn': <function numpyro.util.identity(x, *args, **kwargs)>,
'args': (Array([[0.1, 0.1, 0.1, 0.1, 0.1],
[0.1, 0.1, 0.1, 0.1, 0.1],
[0.1, 0.1, 0.1, 0.1, 0.1],
[0.1, 0.1, 0.1, 0.1, 0.1],
[0.1, 0.1, 0.1, 0.1, 0.1]], dtype=float64),),
'kwargs': {'constraint': SoftplusPositive(lower_bound=0.0)},
'value': Array([[0.1, 0.1, 0.1, 0.1, 0.1],
[0.1, 0.1, 0.1, 0.1, 0.1],
[0.1, 0.1, 0.1, 0.1, 0.1],
[0.1, 0.1, 0.1, 0.1, 0.1],
[0.1, 0.1, 0.1, 0.1, 0.1]], dtype=float64),
'scale': None,
'cond_indep_stack': []}),
('graybox/pauli_Y.dense_0.bias_scale',
{'type': 'param',
'name': 'graybox/pauli_Y.dense_0.bias_scale',
'fn': <function numpyro.util.identity(x, *args, **kwargs)>,
'args': (Array([0.1, 0.1, 0.1, 0.1, 0.1], dtype=float64),),
'kwargs': {'constraint': SoftplusPositive(lower_bound=0.0)},
'value': Array([0.1, 0.1, 0.1, 0.1, 0.1], dtype=float64),
'scale': None,
'cond_indep_stack': []}),
('graybox/U_Y.kernel_scale',
{'type': 'param',
'name': 'graybox/U_Y.kernel_scale',
'fn': <function numpyro.util.identity(x, *args, **kwargs)>,
'args': (Array([[0.1, 0.1, 0.1],
[0.1, 0.1, 0.1],
[0.1, 0.1, 0.1],
[0.1, 0.1, 0.1],
[0.1, 0.1, 0.1]], dtype=float64),),
'kwargs': {'constraint': SoftplusPositive(lower_bound=0.0)},
'value': Array([[0.1, 0.1, 0.1],
[0.1, 0.1, 0.1],
[0.1, 0.1, 0.1],
[0.1, 0.1, 0.1],
[0.1, 0.1, 0.1]], dtype=float64),
'scale': None,
'cond_indep_stack': []}),
('graybox/U_Y.bias_scale',
{'type': 'param',
'name': 'graybox/U_Y.bias_scale',
'fn': <function numpyro.util.identity(x, *args, **kwargs)>,
'args': (Array([0.1, 0.1, 0.1], dtype=float64),),
'kwargs': {'constraint': SoftplusPositive(lower_bound=0.0)},
'value': Array([0.1, 0.1, 0.1], dtype=float64),
'scale': None,
'cond_indep_stack': []}),
('graybox/D_Y.kernel_scale',
{'type': 'param',
'name': 'graybox/D_Y.kernel_scale',
'fn': <function numpyro.util.identity(x, *args, **kwargs)>,
'args': (Array([[0.1, 0.1],
[0.1, 0.1],
[0.1, 0.1],
[0.1, 0.1],
[0.1, 0.1]], dtype=float64),),
'kwargs': {'constraint': SoftplusPositive(lower_bound=0.0)},
'value': Array([[0.1, 0.1],
[0.1, 0.1],
[0.1, 0.1],
[0.1, 0.1],
[0.1, 0.1]], dtype=float64),
'scale': None,
'cond_indep_stack': []}),
('graybox/D_Y.bias_scale',
{'type': 'param',
'name': 'graybox/D_Y.bias_scale',
'fn': <function numpyro.util.identity(x, *args, **kwargs)>,
'args': (Array([0.1, 0.1], dtype=float64),),
'kwargs': {'constraint': SoftplusPositive(lower_bound=0.0)},
'value': Array([0.1, 0.1], dtype=float64),
'scale': None,
'cond_indep_stack': []}),
('graybox/pauli_Z.dense_0.kernel_scale',
{'type': 'param',
'name': 'graybox/pauli_Z.dense_0.kernel_scale',
'fn': <function numpyro.util.identity(x, *args, **kwargs)>,
'args': (Array([[0.1, 0.1, 0.1, 0.1, 0.1],
[0.1, 0.1, 0.1, 0.1, 0.1],
[0.1, 0.1, 0.1, 0.1, 0.1],
[0.1, 0.1, 0.1, 0.1, 0.1],
[0.1, 0.1, 0.1, 0.1, 0.1]], dtype=float64),),
'kwargs': {'constraint': SoftplusPositive(lower_bound=0.0)},
'value': Array([[0.1, 0.1, 0.1, 0.1, 0.1],
[0.1, 0.1, 0.1, 0.1, 0.1],
[0.1, 0.1, 0.1, 0.1, 0.1],
[0.1, 0.1, 0.1, 0.1, 0.1],
[0.1, 0.1, 0.1, 0.1, 0.1]], dtype=float64),
'scale': None,
'cond_indep_stack': []}),
('graybox/pauli_Z.dense_0.bias_scale',
{'type': 'param',
'name': 'graybox/pauli_Z.dense_0.bias_scale',
'fn': <function numpyro.util.identity(x, *args, **kwargs)>,
'args': (Array([0.1, 0.1, 0.1, 0.1, 0.1], dtype=float64),),
'kwargs': {'constraint': SoftplusPositive(lower_bound=0.0)},
'value': Array([0.1, 0.1, 0.1, 0.1, 0.1], dtype=float64),
'scale': None,
'cond_indep_stack': []}),
('graybox/U_Z.kernel_scale',
{'type': 'param',
'name': 'graybox/U_Z.kernel_scale',
'fn': <function numpyro.util.identity(x, *args, **kwargs)>,
'args': (Array([[0.1, 0.1, 0.1],
[0.1, 0.1, 0.1],
[0.1, 0.1, 0.1],
[0.1, 0.1, 0.1],
[0.1, 0.1, 0.1]], dtype=float64),),
'kwargs': {'constraint': SoftplusPositive(lower_bound=0.0)},
'value': Array([[0.1, 0.1, 0.1],
[0.1, 0.1, 0.1],
[0.1, 0.1, 0.1],
[0.1, 0.1, 0.1],
[0.1, 0.1, 0.1]], dtype=float64),
'scale': None,
'cond_indep_stack': []}),
('graybox/U_Z.bias_scale',
{'type': 'param',
'name': 'graybox/U_Z.bias_scale',
'fn': <function numpyro.util.identity(x, *args, **kwargs)>,
'args': (Array([0.1, 0.1, 0.1], dtype=float64),),
'kwargs': {'constraint': SoftplusPositive(lower_bound=0.0)},
'value': Array([0.1, 0.1, 0.1], dtype=float64),
'scale': None,
'cond_indep_stack': []}),
('graybox/D_Z.kernel_scale',
{'type': 'param',
'name': 'graybox/D_Z.kernel_scale',
'fn': <function numpyro.util.identity(x, *args, **kwargs)>,
'args': (Array([[0.1, 0.1],
[0.1, 0.1],
[0.1, 0.1],
[0.1, 0.1],
[0.1, 0.1]], dtype=float64),),
'kwargs': {'constraint': SoftplusPositive(lower_bound=0.0)},
'value': Array([[0.1, 0.1],
[0.1, 0.1],
[0.1, 0.1],
[0.1, 0.1],
[0.1, 0.1]], dtype=float64),
'scale': None,
'cond_indep_stack': []}),
('graybox/D_Z.bias_scale',
{'type': 'param',
'name': 'graybox/D_Z.bias_scale',
'fn': <function numpyro.util.identity(x, *args, **kwargs)>,
'args': (Array([0.1, 0.1], dtype=float64),),
'kwargs': {'constraint': SoftplusPositive(lower_bound=0.0)},
'value': Array([0.1, 0.1], dtype=float64),
'scale': None,
'cond_indep_stack': []}),
('graybox/shared.dense_0.kernel',
{'type': 'sample',
'name': 'graybox/shared.dense_0.kernel',
'fn': <numpyro.distributions.distribution.Independent object at 0x12e9717f0 with batch shape () and event shape (8, 5)>,
'args': (),
'kwargs': {'rng_key': Array((), dtype=key<fry>) overlaying:
[ 928981903 3453687069],
'sample_shape': ()},
'value': Array([[-0.14008841, 0.1432145 , 0.06248107, 0.02004873, 0.02471476],
[ 0.05244771, 0.08618686, 0.12237145, -0.14305551, 0.11400058],
[-0.07499601, 0.04027209, -0.0546837 , 0.04114347, 0.03301723],
[-0.00954418, 0.09128964, 0.0388787 , 0.19317479, 0.15932732],
[ 0.13931197, 0.01366918, 0.07825694, 0.03967606, 0.07148536],
[ 0.22849079, 0.03318426, 0.13805099, 0.05436563, 0.09425814],
[-0.02164755, -0.05206585, -0.15751706, 0.05851865, -0.08691522],
[-0.10058468, 0.13832578, 0.01075193, -0.05132055, -0.05932651]], dtype=float64),
'scale': None,
'is_observed': False,
'intermediates': [],
'cond_indep_stack': [],
'infer': {}}),
('graybox/shared.dense_0.bias',
{'type': 'sample',
'name': 'graybox/shared.dense_0.bias',
'fn': <numpyro.distributions.distribution.Independent object at 0x12e972f90 with batch shape () and event shape (5,)>,
'args': (),
'kwargs': {'rng_key': Array((), dtype=key<fry>) overlaying:
[1353695780 2116000888],
'sample_shape': ()},
'value': Array([1.04531782, 1.07238262, 0.85906295, 1.14733469, 1.00264397], dtype=float64),
'scale': None,
'is_observed': False,
'intermediates': [],
'cond_indep_stack': [],
'infer': {}}),
('graybox/pauli_X.dense_0.kernel',
{'type': 'sample',
'name': 'graybox/pauli_X.dense_0.kernel',
'fn': <numpyro.distributions.distribution.Independent object at 0x12e972190 with batch shape () and event shape (5, 5)>,
'args': (),
'kwargs': {'rng_key': Array((), dtype=key<fry>) overlaying:
[3531307783 465290248],
'sample_shape': ()},
'value': Array([[-0.03808256, -0.02172061, -0.06575782, -0.03350542, -0.03510779],
[ 0.0584175 , 0.11405403, -0.10272196, 0.15822406, 0.11285611],
[-0.08091083, -0.02301288, -0.03948298, -0.1102386 , 0.03338434],
[-0.03640981, 0.02810232, -0.10063834, 0.02922136, 0.04713823],
[-0.10802706, 0.05611299, -0.10468815, 0.03382672, -0.19492614]], dtype=float64),
'scale': None,
'is_observed': False,
'intermediates': [],
'cond_indep_stack': [],
'infer': {}}),
('graybox/pauli_X.dense_0.bias',
{'type': 'sample',
'name': 'graybox/pauli_X.dense_0.bias',
'fn': <numpyro.distributions.distribution.Independent object at 0x12e973850 with batch shape () and event shape (5,)>,
'args': (),
'kwargs': {'rng_key': Array((), dtype=key<fry>) overlaying:
[1539457558 118255239],
'sample_shape': ()},
'value': Array([0.90524273, 0.8990806 , 1.0311893 , 0.98192345, 1.03276893], dtype=float64),
'scale': None,
'is_observed': False,
'intermediates': [],
'cond_indep_stack': [],
'infer': {}}),
('graybox/U_X.kernel',
{'type': 'sample',
'name': 'graybox/U_X.kernel',
'fn': <numpyro.distributions.distribution.Independent object at 0x12e9706e0 with batch shape () and event shape (5, 3)>,
'args': (),
'kwargs': {'rng_key': Array((), dtype=key<fry>) overlaying:
[4274742258 3380111416],
'sample_shape': ()},
'value': Array([[ 0.17632242, -0.11744192, 0.01794445],
[-0.05177495, 0.16960584, -0.06622758],
[ 0.17634344, -0.03464814, 0.16775395],
[ 0.0823993 , 0.01903537, -0.09046388],
[ 0.15659127, 0.10918574, -0.0401434 ]], dtype=float64),
'scale': None,
'is_observed': False,
'intermediates': [],
'cond_indep_stack': [],
'infer': {}}),
('graybox/U_X.bias',
{'type': 'sample',
'name': 'graybox/U_X.bias',
'fn': <numpyro.distributions.distribution.Independent object at 0x12e973af0 with batch shape () and event shape (3,)>,
'args': (),
'kwargs': {'rng_key': Array((), dtype=key<fry>) overlaying:
[1093704277 2843913905],
'sample_shape': ()},
'value': Array([0.9271848 , 1.0555702 , 1.06170073], dtype=float64),
'scale': None,
'is_observed': False,
'intermediates': [],
'cond_indep_stack': [],
'infer': {}}),
('graybox/D_X.kernel',
{'type': 'sample',
'name': 'graybox/D_X.kernel',
'fn': <numpyro.distributions.distribution.Independent object at 0x12e972b30 with batch shape () and event shape (5, 2)>,
'args': (),
'kwargs': {'rng_key': Array((), dtype=key<fry>) overlaying:
[3905300018 1047965080],
'sample_shape': ()},
'value': Array([[-0.05559002, 0.04963322],
[-0.05268756, 0.16933363],
[-0.0222924 , -0.18154015],
[-0.03311026, -0.07193543],
[ 0.07697746, 0.04136223]], dtype=float64),
'scale': None,
'is_observed': False,
'intermediates': [],
'cond_indep_stack': [],
'infer': {}}),
('graybox/D_X.bias',
{'type': 'sample',
'name': 'graybox/D_X.bias',
'fn': <numpyro.distributions.distribution.Independent object at 0x12e972c10 with batch shape () and event shape (2,)>,
'args': (),
'kwargs': {'rng_key': Array((), dtype=key<fry>) overlaying:
[3381965182 2262451415],
'sample_shape': ()},
'value': Array([0.89985102, 1.03960598], dtype=float64),
'scale': None,
'is_observed': False,
'intermediates': [],
'cond_indep_stack': [],
'infer': {}}),
('graybox/pauli_Y.dense_0.kernel',
{'type': 'sample',
'name': 'graybox/pauli_Y.dense_0.kernel',
'fn': <numpyro.distributions.distribution.Independent object at 0x12e972040 with batch shape () and event shape (5, 5)>,
'args': (),
'kwargs': {'rng_key': Array((), dtype=key<fry>) overlaying:
[2772760534 1150241264],
'sample_shape': ()},
'value': Array([[ 0.10520649, -0.09439651, 0.25773657, -0.19807482, 0.08300809],
[ 0.03322761, -0.11226499, 0.13162118, 0.00421367, 0.11259789],
[ 0.02646392, 0.13100174, -0.13760007, 0.07654653, 0.00921974],
[-0.00175621, 0.19364408, -0.04745788, 0.04232322, -0.00253774],
[ 0.00661635, -0.13051184, -0.0986543 , -0.07035516, -0.00834113]], dtype=float64),
'scale': None,
'is_observed': False,
'intermediates': [],
'cond_indep_stack': [],
'infer': {}}),
('graybox/pauli_Y.dense_0.bias',
{'type': 'sample',
'name': 'graybox/pauli_Y.dense_0.bias',
'fn': <numpyro.distributions.distribution.Independent object at 0x12e972820 with batch shape () and event shape (5,)>,
'args': (),
'kwargs': {'rng_key': Array((), dtype=key<fry>) overlaying:
[3949573971 553583122],
'sample_shape': ()},
'value': Array([0.89491221, 0.99243349, 0.99889162, 0.98189325, 1.04351397], dtype=float64),
'scale': None,
'is_observed': False,
'intermediates': [],
'cond_indep_stack': [],
'infer': {}}),
('graybox/U_Y.kernel',
{'type': 'sample',
'name': 'graybox/U_Y.kernel',
'fn': <numpyro.distributions.distribution.Independent object at 0x12e972430 with batch shape () and event shape (5, 3)>,
'args': (),
'kwargs': {'rng_key': Array((), dtype=key<fry>) overlaying:
[3789552139 2399770011],
'sample_shape': ()},
'value': Array([[-0.07285353, 0.04304867, -0.03740092],
[-0.22240879, 0.04847291, -0.16410529],
[-0.01032646, 0.06037331, -0.00884798],
[-0.05031299, 0.02473008, -0.05038715],
[-0.02642271, 0.10363026, -0.0586004 ]], dtype=float64),
'scale': None,
'is_observed': False,
'intermediates': [],
'cond_indep_stack': [],
'infer': {}}),
('graybox/U_Y.bias',
{'type': 'sample',
'name': 'graybox/U_Y.bias',
'fn': <numpyro.distributions.distribution.Independent object at 0x12e973a80 with batch shape () and event shape (3,)>,
'args': (),
'kwargs': {'rng_key': Array((), dtype=key<fry>) overlaying:
[3499959921 3652298783],
'sample_shape': ()},
'value': Array([1.13729885, 0.76343103, 0.95312259], dtype=float64),
'scale': None,
'is_observed': False,
'intermediates': [],
'cond_indep_stack': [],
'infer': {}}),
('graybox/D_Y.kernel',
{'type': 'sample',
'name': 'graybox/D_Y.kernel',
'fn': <numpyro.distributions.distribution.Independent object at 0x12e973b60 with batch shape () and event shape (5, 2)>,
'args': (),
'kwargs': {'rng_key': Array((), dtype=key<fry>) overlaying:
[ 312422605 4273504752],
'sample_shape': ()},
'value': Array([[ 0.10310684, 0.00423635],
[ 0.03049902, -0.00904775],
[-0.07080258, -0.14189202],
[-0.03537874, -0.03072079],
[ 0.02076615, 0.14101123]], dtype=float64),
'scale': None,
'is_observed': False,
'intermediates': [],
'cond_indep_stack': [],
'infer': {}}),
('graybox/D_Y.bias',
{'type': 'sample',
'name': 'graybox/D_Y.bias',
'fn': <numpyro.distributions.distribution.Independent object at 0x12e973690 with batch shape () and event shape (2,)>,
'args': (),
'kwargs': {'rng_key': Array((), dtype=key<fry>) overlaying:
[1301443351 1262990949],
'sample_shape': ()},
'value': Array([0.95188058, 1.00563022], dtype=float64),
'scale': None,
'is_observed': False,
'intermediates': [],
'cond_indep_stack': [],
'infer': {}}),
('graybox/pauli_Z.dense_0.kernel',
{'type': 'sample',
'name': 'graybox/pauli_Z.dense_0.kernel',
'fn': <numpyro.distributions.distribution.Independent object at 0x12e9732a0 with batch shape () and event shape (5, 5)>,
'args': (),
'kwargs': {'rng_key': Array((), dtype=key<fry>) overlaying:
[1115782097 492331180],
'sample_shape': ()},
'value': Array([[-0.1325124 , -0.03843222, 0.02943445, -0.16186975, 0.15639413],
[ 0.09779884, 0.03734697, -0.01805349, 0.07809086, -0.04967616],
[-0.33209741, 0.07484574, 0.08881937, 0.05133246, 0.05045518],
[ 0.00230586, 0.0461714 , -0.13727346, -0.0351774 , 0.05739598],
[-0.01912264, 0.07885565, 0.0339848 , -0.07031864, -0.00646732]], dtype=float64),
'scale': None,
'is_observed': False,
'intermediates': [],
'cond_indep_stack': [],
'infer': {}}),
('graybox/pauli_Z.dense_0.bias',
{'type': 'sample',
'name': 'graybox/pauli_Z.dense_0.bias',
'fn': <numpyro.distributions.distribution.Independent object at 0x12e972cf0 with batch shape () and event shape (5,)>,
'args': (),
'kwargs': {'rng_key': Array((), dtype=key<fry>) overlaying:
[3716532200 665832165],
'sample_shape': ()},
'value': Array([1.05892567, 0.90481273, 1.02666678, 0.93606215, 0.99313125], dtype=float64),
'scale': None,
'is_observed': False,
'intermediates': [],
'cond_indep_stack': [],
'infer': {}}),
('graybox/U_Z.kernel',
{'type': 'sample',
'name': 'graybox/U_Z.kernel',
'fn': <numpyro.distributions.distribution.Independent object at 0x12e972dd0 with batch shape () and event shape (5, 3)>,
'args': (),
'kwargs': {'rng_key': Array((), dtype=key<fry>) overlaying:
[3468366468 328474123],
'sample_shape': ()},
'value': Array([[-0.02167346, 0.02430774, -0.06703646],
[-0.15356107, -0.02620818, -0.11573705],
[ 0.03559265, 0.04111457, -0.10566783],
[-0.08103412, -0.05960112, -0.09647807],
[ 0.15874484, 0.23166607, 0.0868952 ]], dtype=float64),
'scale': None,
'is_observed': False,
'intermediates': [],
'cond_indep_stack': [],
'infer': {}}),
('graybox/U_Z.bias',
{'type': 'sample',
'name': 'graybox/U_Z.bias',
'fn': <numpyro.distributions.distribution.Independent object at 0x12e972890 with batch shape () and event shape (3,)>,
'args': (),
'kwargs': {'rng_key': Array((), dtype=key<fry>) overlaying:
[3568434798 2807364103],
'sample_shape': ()},
'value': Array([1.02646247, 0.98463201, 1.19315067], dtype=float64),
'scale': None,
'is_observed': False,
'intermediates': [],
'cond_indep_stack': [],
'infer': {}}),
('graybox/D_Z.kernel',
{'type': 'sample',
'name': 'graybox/D_Z.kernel',
'fn': <numpyro.distributions.distribution.Independent object at 0x12e973e00 with batch shape () and event shape (5, 2)>,
'args': (),
'kwargs': {'rng_key': Array((), dtype=key<fry>) overlaying:
[3995525659 1921098443],
'sample_shape': ()},
'value': Array([[ 0.1199833 , -0.10544737],
[ 0.1504808 , -0.00809291],
[-0.20515682, 0.11901646],
[-0.02348738, -0.0813381 ],
[-0.03168612, 0.11155693]], dtype=float64),
'scale': None,
'is_observed': False,
'intermediates': [],
'cond_indep_stack': [],
'infer': {}}),
('graybox/D_Z.bias',
{'type': 'sample',
'name': 'graybox/D_Z.bias',
'fn': <numpyro.distributions.distribution.Independent object at 0x12e972580 with batch shape () and event shape (2,)>,
'args': (),
'kwargs': {'rng_key': Array((), dtype=key<fry>) overlaying:
[2938295869 4294411401],
'sample_shape': ()},
'value': Array([0.97315917, 1.04829688], dtype=float64),
'scale': None,
'is_observed': False,
'intermediates': [],
'cond_indep_stack': [],
'infer': {}})])
Visualize the Negative ELBO loss¶
Below, we rescale each loss by the number of samples in its dataset and plot the result with matplotlib.
import matplotlib.pyplot as plt
rescaled_eval_losses = svi_result.eval_losses / test_data.control_params.shape[0]
rescaled_train_losses = svi_result.losses / train_data.control_params.shape[0]
iterations = jnp.arange(len(rescaled_train_losses))
fig, ax = plt.subplots(figsize=(5, 3))
ax = sq.utils.plot_loss_with_moving_average(
iterations,
rescaled_eval_losses,
ax=ax,
color="#6366f1",
label="moving average Test ELBO Loss",
)
ax = sq.utils.plot_loss_with_moving_average(
iterations,
rescaled_train_losses,
ax=ax,
window=1,
annotate_at=[],
color="gray",
alpha=0.25,
label="Train ELBO Loss",
)
ax.set_yscale("log")
ax.set_xlabel("Iterations")
ax.set_ylabel("Rescaled ELBO Loss")
ax.legend()
<matplotlib.legend.Legend at 0x139345a90>
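To make the rescaling and smoothing steps concrete, here is a minimal pure-Python sketch of a per-sample loss followed by a trailing moving average. It is not the implementation behind `sq.utils.plot_loss_with_moving_average`; the function name `moving_average`, the loss values, and the dataset size are all hypothetical and chosen only for illustration.

```python
def moving_average(values, window):
    """Trailing moving average.

    The first `window - 1` points are averaged over the values seen so far.
    """
    if window < 1:
        raise ValueError("window must be at least 1")
    averaged = []
    running_sum = 0.0
    for i, value in enumerate(values):
        running_sum += value
        if i >= window:
            # Drop the value that just left the window.
            running_sum -= values[i - window]
        averaged.append(running_sum / min(i + 1, window))
    return averaged


# Hypothetical total losses and dataset size (not taken from the notebook run).
total_losses = [900.0, 700.0, 800.0, 500.0, 600.0, 400.0]
num_samples = 100

# Rescale so that curves from datasets of different sizes are comparable.
per_sample_losses = [loss / num_samples for loss in total_losses]
smoothed = moving_average(per_sample_losses, window=3)
```

With `window=1` this sketch returns the input unchanged, which corresponds to plotting a raw, unsmoothed curve.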
import tempfile
from pathlib import Path
with tempfile.TemporaryDirectory() as tmpdir:
    # Save the variational parameters, then reload them to verify the round trip.
    model_path = Path(tmpdir)
    # Create the directory (and any missing parents) if it does not already exist.
    model_path.mkdir(parents=True, exist_ok=True)
    model_state = sq.models.shared.ModelData(params=svi_result.params, config={})
    model_state.to_file(model_path / "model.json")
    reloaded_model = sq.models.shared.ModelData.from_file(model_path / "model.json")
    assert reloaded_model == model_state
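The same save-and-reload check can be sketched with only the standard library. This is an illustration of the round-trip pattern, independent of how `ModelData` is implemented; the parameter dictionary and the file name `params.json` are hypothetical, and nested lists stand in for arrays.

```python
import json
import tempfile
from pathlib import Path

# Hypothetical variational parameters; nested lists stand in for arrays.
params = {
    "graybox/D_Z.bias": {"loc": [0.97, 1.05], "scale": [0.1, 0.1]},
}

with tempfile.TemporaryDirectory() as tmpdir:
    path = Path(tmpdir) / "params.json"
    # Serialize the parameters together with an (empty) configuration.
    path.write_text(json.dumps({"params": params, "config": {}}))
    # Read the file back before the temporary directory is removed.
    reloaded = json.loads(path.read_text())

# The reloaded structure should match what was written.
assert reloaded["params"] == params
```

Comparing the reloaded object against the original inside a temporary directory keeps the check free of side effects on the working directory.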
Two ways of making posterior predictive model¶
First, let us pick a test point and import the Predictive helper from numpyro.
from numpyro.infer import Predictive
params = test_data.control_params[0]
unitary = test_data.unitaries[0]
From guide¶
This approach works with a guide obtained from elsewhere. Because we cannot assume that the guide was created with the structural approach in mind, we rely entirely on Predictive to do the work for us.
def make_predictive_fn_v2(
    model,
    guide,
    params,
    shots: int,
):
    # Draw `shots` samples from the guide and push them through the model.
    predictive = Predictive(model, guide=guide, params=params, num_samples=shots)

    def predictive_fn(*args, **kwargs):
        return predictive(*args, **kwargs)

    return predictive_fn
predictive_fn_from_guide = make_predictive_fn_v2(
sq.models.probabilistic.make_probabilistic_model(
predictive_model=graybox_model, log_expectation_values=True
),
guide,
model_state.params,
shots=1000,
)
guide_expectation_values = predictive_fn_from_guide(jax.random.key(0), params, unitary)[
"expectation_values"
]
From variational parameters¶
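With our auto_guide, the variational parameters are kept in a structure that is ready to use: they can be turned directly into a posterior distribution function and passed to the model in place of the priors.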
posterior_fn = sq.models.probabilistic.make_posterior_fn(
model_state.params, sq.models.probabilistic.bnn_init_dist_fn
)
base_model = sq.models.probabilistic.WoModel(
"graybox", (5,), (5,), priors_fn=posterior_fn
)
adapter_fn = sq.models.adapter.observable_to_expvals
graybox_model = sq.models.probabilistic.make_probabilistic_graybox_model(
base_model, adapter_fn
)
posterior_model = sq.models.probabilistic.make_probabilistic_model(
predictive_model=graybox_model,
log_expectation_values=True,
)
posterior_expectation_values = Predictive(model=posterior_model, num_samples=1000)(
jax.random.key(0), params, unitary
)["expectation_values"]
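The idea of building a posterior predictive directly from variational parameters can be sketched on a toy model. The example below is not inspeqtor or numpyro code: the one-feature linear model, the parameter names `w` and `b`, their values, and the assumption of independent normal posteriors are all hypothetical and serve only to show the sample-then-predict loop.

```python
import random
import statistics

# Hypothetical variational parameters for the toy model y = w * x + b.
variational_params = {
    "w": {"loc": 2.0, "scale": 0.1},
    "b": {"loc": -1.0, "scale": 0.05},
}


def sample_posterior(rng, params):
    """Draw one value per site from an independent normal (assumed family)."""
    return {name: rng.gauss(p["loc"], p["scale"]) for name, p in params.items()}


def posterior_predictive(rng, params, x, num_samples):
    """Sample parameters repeatedly and evaluate the model for each draw."""
    draws = []
    for _ in range(num_samples):
        site = sample_posterior(rng, params)
        draws.append(site["w"] * x + site["b"])
    return draws


rng = random.Random(0)
predictions = posterior_predictive(rng, variational_params, x=3.0, num_samples=2000)
predictive_mean = statistics.fmean(predictions)
predictive_std = statistics.stdev(predictions)
```

Under these assumptions the predictive mean should be close to 2.0 * 3.0 - 1.0 = 5.0, and the spread of `predictions` reflects the uncertainty carried by the variational parameters.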
import seaborn as sns
sns.histplot(
{
"posterior": posterior_expectation_values[:, -1],
"guide": guide_expectation_values[:, -1],
}
)
<Axes: ylabel='Count'>
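A histogram comparison can be complemented with a simple numeric check. The sketch below compares the means of two sample sets relative to the standard error of their difference; the two lists are hypothetical stand-ins generated with the standard library, not the arrays produced above, and the helper `summarize` is introduced only for this illustration.

```python
import random
import statistics


def summarize(samples):
    """Return the sample mean and sample standard deviation."""
    return statistics.fmean(samples), statistics.stdev(samples)


rng = random.Random(0)
# Hypothetical stand-ins for two sets of predictive samples of one observable.
samples_a = [rng.gauss(0.3, 0.05) for _ in range(1000)]
samples_b = [rng.gauss(0.3, 0.05) for _ in range(1000)]

mean_a, std_a = summarize(samples_a)
mean_b, std_b = summarize(samples_b)

# Gap between the means, and the standard error of that gap.
mean_gap = abs(mean_a - mean_b)
standard_error = (std_a**2 / len(samples_a) + std_b**2 / len(samples_b)) ** 0.5
```

If the two predictive constructions describe the same distribution, `mean_gap` should be on the order of `standard_error`, and the two standard deviations should be similar.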