Characterization and Calibration using Graybox
Goal
This tutorial guides you through the process of characterizing a device with a predictive model, and then using that model to calibrate a quantum gate. The calibration here uses open-loop optimization, although other approaches are possible.
import jax
import jax.numpy as jnp
import inspeqtor as sq
Generate some synthetic data
data_model = sq.data.library.get_predefined_data_model_m1()
sample_size = 100
# Use the predefined data model to generate synthetic experimental data with the simulator.
exp_data, _, _, _ = sq.data.library.generate_single_qubit_experimental_data(
    key=jax.random.key(0),
    hamiltonian=data_model.total_hamiltonian,
    sample_size=sample_size,
    strategy=sq.physics.library.SimulationStrategy.SHOT,
    qubit_inforamtion=data_model.qubit_information,
    control_sequence=data_model.control_sequence,
    method=sq.physics.library.WhiteboxStrategy.TROTTER,
    trotter_steps=10_000,
)
# Build the whitebox solver from the ideal Hamiltonian, then prepare a ready-to-use dataset.
whitebox = sq.physics.make_trotterization_solver(
    data_model.ideal_hamiltonian,
    data_model.control_sequence.total_dt,
    data_model.dt,
    trotter_steps=10_000,
    y0=jnp.eye(2, dtype=jnp.complex_),
)
loaded_data = sq.data.prepare_data(exp_data, data_model.control_sequence, whitebox)
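For intuition, a Trotterized solver can be pictured as a product of short-time propagators. The sketch below is a minimal NumPy illustration of that idea, not inspeqtor's implementation; the helper names `expm_hermitian` and `trotter_propagator` and the toy Hamiltonian are assumptions made for this example.

```python
import numpy as np

X = np.array([[0, 1], [1, 0]], dtype=complex)
Z = np.array([[1, 0], [0, -1]], dtype=complex)


def expm_hermitian(H, dt):
    """Return exp(-i * H * dt) for a Hermitian matrix H via eigendecomposition."""
    evals, evecs = np.linalg.eigh(H)
    # Scale each eigenvector column by its phase, then rotate back.
    return (evecs * np.exp(-1j * evals * dt)) @ evecs.conj().T


def trotter_propagator(hamiltonian, total_time, steps):
    """Approximate the propagator of a time-dependent H(t) by a product of short steps."""
    dt = total_time / steps
    U = np.eye(2, dtype=complex)
    for k in range(steps):
        t_mid = (k + 0.5) * dt  # evaluate H at the midpoint of each time slice
        U = expm_hermitian(hamiltonian(t_mid), dt) @ U  # later slices act on the left
    return U


# Hypothetical toy drive: a constant Rabi rate about X for a pi rotation.
rabi = 2 * np.pi * 0.1
U_pi = trotter_propagator(lambda t: 0.5 * rabi * X, total_time=np.pi / rabi, steps=100)
```

With this toy drive, `U_pi` equals the X gate up to a global phase of -i. More steps reduce the discretization error when the Hamiltonian varies in time.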
We can inspect the experiment configuration of the predefined data model by printing it:
print(loaded_data.experiment_data.config)
============================================================
EXPERIMENT CONFIGURATION
============================================================
Identifier: 0001
Backend: stardust
Date: 2025-11-21 23:02:53
Description: This is a test experiment
Shots: 1,000
Sample Size: 100
Device Cycle Time: 0.2222 ns
Sequence Duration: 320 dt
Qubits: 1
- QubitInformation(unit='GHz', qubit_idx=0, anharmonicity=-0.2, frequency=5.0005, drive_strength=0.1, date='2025-11-21 23:02:52')
Expectation Values: 18
(States: {'r', '0', '-', '1', 'l', '+'})
(Observables: {'X', 'Z', 'Y'})
Parameter Structure: [('0', 'theta'), ('0', 'beta')]
Tags: test, test2
============================================================
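The 18 expectation values listed in the configuration correspond to 6 initial states times 3 Pauli observables. The sketch below shows how such a set could be computed for a given unitary. It is an illustration only: the dictionary layout, the function name `expectation_values`, and the convention that `r`/`l` are the +1/-1 eigenstates of Y are assumptions, and the ordering may differ from the library's.

```python
import numpy as np

PAULIS = {
    "X": np.array([[0, 1], [1, 0]], dtype=complex),
    "Y": np.array([[0, -1j], [1j, 0]], dtype=complex),
    "Z": np.array([[1, 0], [0, -1]], dtype=complex),
}
s = 1 / np.sqrt(2)
STATES = {
    "+": np.array([s, s], dtype=complex),
    "-": np.array([s, -s], dtype=complex),
    "r": np.array([s, 1j * s], dtype=complex),   # assumed +1 eigenstate of Y
    "l": np.array([s, -1j * s], dtype=complex),  # assumed -1 eigenstate of Y
    "0": np.array([1, 0], dtype=complex),
    "1": np.array([0, 1], dtype=complex),
}


def expectation_values(U):
    """Return {(state, observable): <psi| U^dagger O U |psi>} for all 18 pairs."""
    out = {}
    for state_name, psi in STATES.items():
        phi = U @ psi  # evolve the initial state
        for obs_name, O in PAULIS.items():
            out[(state_name, obs_name)] = float(np.real(np.vdot(phi, O @ phi)))
    return out


# With no evolution, each eigenstate returns +1 or -1 for its own observable.
ideal = expectation_values(np.eye(2, dtype=complex))
```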
The experimental data can be viewed by joining the observed expectation values with the control parameters:
exp_data.observed_dataframe.join(exp_data.parameter_dataframe, on="parameter_id")
| parameter_id | +/X | +/Y | +/Z | -/X | -/Y | -/Z | 0/X | 0/Y | 0/Z | 1/X | 1/Y | 1/Z | l/X | l/Y | l/Z | r/X | r/Y | r/Z | 0/beta | 0/theta |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| u32 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 | f64 |
| 0 | 0.984 | 0.128 | -0.02 | -0.994 | -0.1 | -0.012 | -0.022 | 0.372 | 0.906 | 0.046 | -0.35 | -0.924 | 0.154 | -0.92 | 0.36 | -0.16 | 0.92 | -0.368 | 3.077096 | 5.877436 |
| 1 | 0.962 | -0.214 | -0.082 | -0.978 | 0.228 | 0.052 | -0.082 | -0.476 | 0.888 | 0.038 | 0.474 | -0.892 | -0.234 | -0.836 | -0.506 | 0.212 | 0.856 | 0.46 | 4.624236 | 0.500617 |
| 2 | 0.81 | 0.29 | -0.476 | -0.788 | -0.328 | 0.456 | -0.548 | 0.824 | -0.32 | 0.46 | -0.786 | 0.252 | 0.298 | 0.498 | 0.8 | -0.308 | -0.456 | -0.806 | 8.179866 | 4.171584 |
| 3 | 0.954 | -0.216 | -0.358 | -0.956 | 0.228 | 0.298 | -0.236 | -0.94 | -0.244 | 0.246 | 0.934 | 0.208 | -0.216 | 0.33 | -0.958 | 0.198 | -0.294 | 0.954 | 7.529713 | 1.837246 |
| 4 | 0.972 | -0.054 | -0.216 | -0.976 | 0.072 | 0.2 | -0.216 | -0.574 | -0.782 | 0.188 | 0.62 | 0.77 | -0.03 | 0.798 | -0.618 | 0.088 | -0.784 | 0.594 | 3.19421 | 2.468479 |
| … | … | … | … | … | … | … | … | … | … | … | … | … | … | … | … | … | … | … | … | … |
| 95 | 0.988 | 0.17 | -0.118 | -0.968 | -0.17 | 0.128 | -0.084 | 0.902 | 0.386 | 0.086 | -0.904 | -0.422 | 0.152 | -0.438 | 0.914 | -0.19 | 0.422 | -0.902 | 3.416527 | 5.099819 |
| 96 | 0.954 | -0.218 | -0.148 | -0.972 | 0.212 | 0.196 | -0.14 | -0.914 | 0.34 | 0.15 | 0.9 | -0.416 | -0.244 | -0.372 | -0.914 | 0.156 | 0.308 | 0.886 | 6.3615 | 1.196432 |
| 97 | 0.966 | -0.196 | -0.06 | -0.96 | 0.212 | 0.006 | -0.046 | -0.448 | 0.902 | 0.084 | 0.472 | -0.87 | -0.194 | -0.882 | -0.426 | 0.212 | 0.866 | 0.41 | 5.360042 | 0.48225 |
| 98 | 0.96 | -0.314 | -0.25 | -0.938 | 0.26 | 0.288 | -0.244 | -0.972 | 0.09 | 0.222 | 0.986 | -0.058 | -0.228 | -0.038 | -0.972 | 0.248 | -0.022 | 0.966 | 9.496838 | 1.55557 |
| 99 | 0.966 | -0.202 | -0.066 | -0.974 | 0.204 | 0.036 | -0.008 | -0.338 | 0.936 | 0.062 | 0.37 | -0.952 | -0.188 | -0.93 | -0.33 | 0.23 | 0.906 | 0.3 | 1.424265 | 0.339996 |
Data preprocessing
Since we are going to train a deep neural network, it is good practice to split the dataset into training and testing sets. Here we use the sq.utils.random_split helper function for this task.
# Here, we bundle things up for convenience.
key = jax.random.key(0)
key, random_split_key, training_key = jax.random.split(key, 3)
(
    train_control_parameters,
    train_unitaries,
    train_expectation_values,
    test_control_parameters,
    test_unitaries,
    test_expectation_values,
) = sq.utils.random_split(
    random_split_key,
    int(loaded_data.control_parameters.shape[0] * 0.1),  # Test size: 10% of the samples
    loaded_data.control_parameters,
    loaded_data.unitaries,
    loaded_data.observed_values,
)
We are going to pass the data back and forth a lot, so we use the helper dataclass DataBundled to hold each dataset. An advantage of using a dataclass is that editors can offer code completion for its fields.
train_data = sq.data.DataBundled(
    control_params=sq.control.library.drag_feature_map(train_control_parameters),
    unitaries=train_unitaries,
    observables=train_expectation_values,
)
test_data = sq.data.DataBundled(
    control_params=sq.control.library.drag_feature_map(test_control_parameters),
    unitaries=test_unitaries,
    observables=test_expectation_values,
)
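To make the splitting step concrete, here is a minimal NumPy sketch of a random split that keeps several arrays aligned. It mirrors the call pattern used above (test size first, then the arrays; training arrays returned before test arrays), but `random_split_sketch` is a hypothetical stand-in, not the implementation of sq.utils.random_split.

```python
import numpy as np


def random_split_sketch(rng, test_size, *arrays):
    """Shuffle row indices once, then split every array with the same indices."""
    n = arrays[0].shape[0]
    perm = rng.permutation(n)
    test_idx, train_idx = perm[:test_size], perm[test_size:]
    train = tuple(a[train_idx] for a in arrays)
    test = tuple(a[test_idx] for a in arrays)
    return (*train, *test)


rng = np.random.default_rng(0)
params = np.arange(100 * 2).reshape(100, 2)  # stand-in for control parameters
labels = np.arange(100)                      # stand-in for observed values
train_p, train_y, test_p, test_y = random_split_sketch(
    rng, int(100 * 0.1), params, labels
)
```

Using one shared permutation is what keeps each control parameter row paired with its own unitaries and observations after the split.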
Now we set up the optimizer using the number of epochs.
NUM_EPOCH = 5000
optimizer = sq.optimize.get_default_optimizer(8 * NUM_EPOCH)
In this tutorial, we use the predefined $\hat{W}_{O}$-based model implemented with flax.linen.
model = sq.models.library.linen.WoModel(
    shared_layers=[10],
    pauli_layers=[10],
)
model
WoModel(
# attributes
shared_layers = [10]
pauli_layers = [10]
pauli_operators = ('X', 'Y', 'Z')
NUM_UNITARY_PARAMS = 3
NUM_DIAGONAL_PARAMS = 2
unitary_activation_fn = <lambda>
diagonal_activation_fn = <lambda>
)
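The printed attributes NUM_UNITARY_PARAMS = 3 and NUM_DIAGONAL_PARAMS = 2 suggest that each predicted operator is assembled from a 2x2 unitary and two eigenvalues. The sketch below shows one way such a Hermitian operator could be built and used in an expectation value. The angle parameterization, the function names, and the formula Tr[W U ρ U†] are assumptions for illustration and do not describe the WoModel internals.

```python
import numpy as np


def unitary_from_angles(theta, phi, lam):
    """A generic single-qubit unitary from three angles (assumed parameterization)."""
    c, s = np.cos(theta / 2), np.sin(theta / 2)
    return np.array(
        [
            [c, -np.exp(1j * lam) * s],
            [np.exp(1j * phi) * s, np.exp(1j * (phi + lam)) * c],
        ]
    )


def hermitian_from_params(angles, eigenvalues):
    """Build W = Q diag(eigenvalues) Q^dagger, which is Hermitian for real eigenvalues."""
    Q = unitary_from_angles(*angles)
    return Q @ np.diag(eigenvalues) @ Q.conj().T


def expectation(W, U, psi):
    """Compute Tr[W U rho U^dagger] for the pure state rho = |psi><psi|."""
    phi = U @ psi
    return float(np.real(np.vdot(phi, W @ phi)))


# Zero angles with eigenvalues (+1, -1) recover the ideal Z observable.
W_z = hermitian_from_params((0.0, 0.0, 0.0), (1.0, -1.0))
# Rotating by theta = pi/2 turns the same eigenvalues into the X observable.
W_x = hermitian_from_params((np.pi / 2, 0.0, 0.0), (1.0, -1.0))
```

In this picture, a noiseless device corresponds to operators equal to the ideal Paulis, and the network learns deviations from them.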
Next, we build a loss function with make_loss_fn. This step depends on the implementation of the model that you choose to use.
loss_fn = sq.models.library.linen.make_loss_fn(
    adapter_fn=sq.models.adapter.observable_to_expvals,
    model=model,
    evaluate_fn=lambda x, y, z: sq.models.shared.mse(x, y),
)
import optax
from alive_progress import alive_bar
with alive_bar(NUM_EPOCH, title="Training 🚀", force_tty=True) as bar:
    # Advance the progress bar each time the trainer invokes the callback.
    def callback(
        model_params,
        opt_state: optax.OptState,
        histories: list[sq.models.shared.HistoryEntryV3],
    ):
        bar()
    model_params, opt_state, histories = sq.models.library.linen.train_model(
        training_key,
        train_data=train_data,
        val_data=test_data,  # No separate validation set is used here, so the test set is reused.
        test_data=test_data,
        model=model,
        optimizer=optimizer,
        loss_fn=loss_fn,
        callbacks=[callback],
        NUM_EPOCH=NUM_EPOCH,
    )
    # Alternatively, a lambda gives a more compact callback definition:
    # alt_callback_fn = lambda x, y, z: bar()  # noqa: E731
Training 🚀 |████████████████████████████████████████| 5000/5000 [100%] in 30.0s
Save and load model
We can save the model using ModelData as follows:
import tempfile
from pathlib import Path
with tempfile.TemporaryDirectory() as tmpdir:
    model_path = Path(tmpdir)
    # Create the directory (and its parents) if it does not already exist
    model_path.mkdir(parents=True, exist_ok=True)
    model_state = sq.models.shared.ModelData(
        params=model_params,
        config={
            "shared_layers": model.shared_layers,
            "pauli_layers": model.pauli_layers,
        },
    )
    # Save with,
    model_state.to_file(model_path / "model.json")
    # and load it back using,
    model_state_from_file = sq.models.shared.ModelData.from_file(
        model_path / "model.json"
    )
We can check that both are equal,
model_state == model_state_from_file
True
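As a rough picture of what a JSON round trip of model parameters involves, the sketch below converts a nested dictionary of arrays to lists, writes it to disk, and restores it. This is a generic illustration using the standard library and NumPy; the helper names and the toy `params` tree are assumptions, and ModelData may store its contents differently.

```python
import json
import tempfile
from pathlib import Path

import numpy as np


def to_jsonable(tree):
    """Recursively convert arrays in a nested dict to plain lists."""
    if isinstance(tree, dict):
        return {k: to_jsonable(v) for k, v in tree.items()}
    return np.asarray(tree).tolist()


def from_jsonable(tree):
    """Recursively convert lists in a nested dict back to arrays."""
    if isinstance(tree, dict):
        return {k: from_jsonable(v) for k, v in tree.items()}
    return np.asarray(tree)


def trees_equal(a, b):
    """Compare two nested dicts of arrays leaf by leaf."""
    if isinstance(a, dict):
        return (
            isinstance(b, dict)
            and a.keys() == b.keys()
            and all(trees_equal(a[k], b[k]) for k in a)
        )
    return np.array_equal(a, b)


# Hypothetical parameter tree standing in for trained model parameters.
params = {"dense": {"kernel": np.ones((2, 3)), "bias": np.zeros(3)}}
with tempfile.TemporaryDirectory() as tmpdir:
    path = Path(tmpdir) / "params.json"
    path.write_text(json.dumps(to_jsonable(params)))
    restored = from_jsonable(json.loads(path.read_text()))
```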
import pandas as pd
import matplotlib.pyplot as plt
hist_df = pd.DataFrame(
    [
        {"step": entry.step, "loss": float(entry.loss), "loop": entry.loop}
        for entry in histories
    ]
)
fig, ax = plt.subplots(1, 1, figsize=(10, 3))
sq.utils.plot_loss_with_moving_average(
    x=hist_df.query("loop == 'train'")["step"].to_numpy(),
    y=hist_df.query("loop == 'train'")["loss"].to_numpy(),
    ax=ax,
    annotate_at=[],
    color="red",
    label="Training",
)
sq.utils.plot_loss_with_moving_average(
    x=hist_df.query("loop == 'test'")["step"].to_numpy(),
    y=hist_df.query("loop == 'test'")["loss"].to_numpy(),
    ax=ax,
    annotate_at=[],
    color="blue",
    label="Testing",
)
shots = loaded_data.experiment_data.config.shots
ax.set_yscale("log")
ax.set_xscale("log")
ax.axhline(y=2 / (3 * shots), linestyle="dashed", color="gray")
ax.legend()
sq.utils.set_fontsize(ax, 16)
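The dashed line at 2 / (3 * shots) can be read as a shot-noise floor for the MSE. If an expectation value E is estimated from `shots` independent ±1 outcomes, the estimator has variance (1 - E^2) / shots. Averaging this over E uniform in [-1, 1], which is an assumption made here to reproduce the factor 2/3, gives 2 / (3 * shots). The Monte Carlo sketch below checks this reading with NumPy; it is not part of the library.

```python
import numpy as np

rng = np.random.default_rng(0)
shots = 1000
num_values = 20_000

# Draw "true" expectation values uniformly, then simulate finite-shot estimates.
true_expvals = rng.uniform(-1.0, 1.0, size=num_values)
prob_plus = (1.0 + true_expvals) / 2.0        # probability of a +1 outcome
counts_plus = rng.binomial(shots, prob_plus)  # number of +1 outcomes per value
estimates = 2.0 * counts_plus / shots - 1.0   # (n_plus - n_minus) / shots

empirical_mse = float(np.mean((estimates - true_expvals) ** 2))
predicted_floor = 2.0 / (3.0 * shots)
```

A model whose test loss sits near this line is already about as accurate as the finite-shot data allow, so pushing the loss further down would mostly fit sampling noise.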
Predictive model construction
We can reuse the adapter function that we used for the loss function to create the predictive model with make_predictive_fn.
model = sq.models.library.linen.WoModel(**model_state.config)
predictive_fn = sq.models.library.linen.make_predictive_fn(
    sq.models.adapter.observable_to_expvals, model, model_state.params
)
Here is how the predictive model can be used to predict the expectation values for new control parameters.
key, params_key = jax.random.split(key)
ravel_fn, _ = sq.control.ravel_unravel_fn(loaded_data.control_sequence.get_structure())
sample_params = ravel_fn(loaded_data.control_sequence.sample_params(params_key))
unitary_f = loaded_data.whitebox(sample_params)[-1]
predictive_fn(sq.control.library.drag_feature_map(sample_params), unitary_f)
Array([ 0.97959636, -0.97959636, 0.20069228, -0.20069228, -0.01065217,
0.01065217, -0.20536352, 0.20536352, 0.9390662 , -0.9390662 ,
-0.2756454 , 0.2756454 , -0.06116489, 0.06116489, 0.24733496,
-0.24733496, 0.96699755, -0.96699755], dtype=float64)
More elegantly, we can define a predictive model with the whitebox embedded as follows.
def embed_predictive_model(control_parameters: jnp.ndarray):
    # Solve the ideal dynamics and keep only the final unitary.
    unitary_f = loaded_data.whitebox(control_parameters)[-1]
    return predictive_fn(
        sq.control.library.drag_feature_map(control_parameters), unitary_f
    )
embed_predictive_model(sample_params)
Array([ 0.97959636, -0.97959636, 0.20069228, -0.20069228, -0.01065217,
0.01065217, -0.20536352, 0.20536352, 0.9390662 , -0.9390662 ,
-0.2756454 , 0.2756454 , -0.06116489, 0.06116489, 0.24733496,
-0.24733496, 0.96699755, -0.96699755], dtype=float64)
Control Calibration
As an example, we use the predictive model to calibrate a quantum gate. Specifically, we want to find the control parameters that maximize the average gate fidelity (AGF) with respect to the $\sqrt{X}$ gate. First, we define a cost function whose output the optimizer will minimize.
calculate_agf_sx = sq.physics.direct_AGF_estimation_fn(sq.utils.SX)
@jax.jit
def average_gate_infidelity(params: jnp.ndarray):
    # Predict the expectation values
    predicted_expvals = embed_predictive_model(params)
    # Calculate the average gate fidelity with respect to the SX gate.
    AGF = calculate_agf_sx(predicted_expvals)
    # Return the squared average gate infidelity, plus the AGF for logging.
    return (1 - AGF) ** 2, {"AGF": AGF}
average_gate_infidelity(sample_params)
(Array(0.06228546, dtype=float64), {'AGF': Array(0.75042945, dtype=float64)})
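To show how an average gate fidelity can be obtained directly from the 18 expectation values, the sketch below uses the standard single-qubit identity F = 1/2 + (1/12) Σ_σ Tr[U σ U† Λ(σ)], where Λ(σ) is the difference of the channel outputs for the +1 and -1 eigenstates of σ. This is a NumPy illustration under assumed conventions (dictionary keys, state labels, `r`/`l` as Y eigenstates); it is not the implementation of direct_AGF_estimation_fn.

```python
import numpy as np

PAULIS = {
    "X": np.array([[0, 1], [1, 0]], dtype=complex),
    "Y": np.array([[0, -1j], [1j, 0]], dtype=complex),
    "Z": np.array([[1, 0], [0, -1]], dtype=complex),
}
# (+1, -1) eigenstate labels for each Pauli; the labels are an assumption.
EIGENSTATE_LABELS = {"X": ("+", "-"), "Y": ("r", "l"), "Z": ("0", "1")}
s = 1 / np.sqrt(2)
KETS = {
    "+": np.array([s, s], dtype=complex),
    "-": np.array([s, -s], dtype=complex),
    "r": np.array([s, 1j * s], dtype=complex),
    "l": np.array([s, -1j * s], dtype=complex),
    "0": np.array([1, 0], dtype=complex),
    "1": np.array([0, 1], dtype=complex),
}


def agf_from_expvals(expvals, target):
    """Average gate fidelity to `target` from {(state, observable): value}."""
    total = 0.0
    for sigma_name, sigma in PAULIS.items():
        rotated = target @ sigma @ target.conj().T  # U sigma U^dagger
        plus, minus = EIGENSTATE_LABELS[sigma_name]
        for tau_name, tau in PAULIS.items():
            coeff = np.real(np.trace(tau @ rotated)) / 2.0  # Pauli coefficient
            total += coeff * (expvals[(plus, tau_name)] - expvals[(minus, tau_name)])
    return 0.5 + total / 12.0


def unitary_channel_expvals(V):
    """Noise-free expectation values produced by applying the unitary V."""
    out = {}
    for state_name, psi in KETS.items():
        phi = V @ psi
        for obs_name, O in PAULIS.items():
            out[(state_name, obs_name)] = float(np.real(np.vdot(phi, O @ phi)))
    return out


SX = 0.5 * np.array([[1 + 1j, 1 - 1j], [1 - 1j, 1 + 1j]])
agf_perfect = agf_from_expvals(unitary_channel_expvals(SX), SX)
agf_identity = agf_from_expvals(unitary_channel_expvals(np.eye(2, dtype=complex)), SX)
cost_identity = (1.0 - agf_identity) ** 2  # same squared-infidelity form as above
```

A perfect SX gate gives a fidelity of 1, while doing nothing at all gives 2/3, which is a useful reference point when reading AGF values.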
Now, we can optimize it using sq.optimize.minimize.
init_key = jax.random.key(73)
init_params = ravel_fn(loaded_data.control_sequence.sample_params(init_key))
lower, upper = loaded_data.control_sequence.get_bounds()
steps = 400
optimized_params, aux = sq.optimize.minimize(
    init_params,
    average_gate_infidelity,
    sq.optimize.get_default_optimizer(steps),
    ravel_fn(lower),
    ravel_fn(upper),
    maxiter=steps,
)
optimized_params
Array([1.22120919, 3.96821785], dtype=float64)
aux[-1]
{'AGF': Array(0.97244255, dtype=float64),
'params': Array([1.22120919, 3.96821785], dtype=float64)}
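The optimization above searches within lower and upper bounds on the control parameters. One simple way to respect such bounds is projected gradient descent: take a gradient step, then clip the parameters back into the box. The sketch below illustrates that idea on a toy cost with NumPy. The function `projected_gradient_descent` and the toy cost are hypothetical, and this is not necessarily the algorithm used by sq.optimize.minimize.

```python
import numpy as np


def projected_gradient_descent(cost_and_grad, init, lower, upper, lr=0.1, maxiter=400):
    """Minimize a cost inside box bounds by clipping after every gradient step."""
    x = np.clip(np.asarray(init, dtype=float), lower, upper)
    history = []
    for _ in range(maxiter):
        cost, grad = cost_and_grad(x)
        history.append({"cost": float(cost), "params": x.copy()})
        x = np.clip(x - lr * grad, lower, upper)  # step, then project onto the bounds
    return x, history


# Hypothetical toy cost whose unconstrained minimum (1.5, 7.0) lies outside the box.
target = np.array([1.5, 7.0])


def toy_cost_and_grad(x):
    diff = x - target
    return np.sum(diff**2), 2.0 * diff


lower = np.array([0.0, 0.0])
upper = np.array([2 * np.pi, 5.0])
best, history = projected_gradient_descent(toy_cost_and_grad, [3.0, 1.0], lower, upper)
```

Here the first parameter converges to its unconstrained optimum, while the second stops at the upper bound, which is the closest feasible point.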
Benchmark
Let's check whether the model accurately characterizes the hidden device.
def quantum_device(params: jnp.ndarray):
    # Noise-free expectation values from the final unitary of the full data model.
    return sq.models.shared.get_predict_expectation_value(
        {"X": sq.utils.X, "Y": sq.utils.Y, "Z": sq.utils.Z},
        data_model.solver(params)[-1],
        sq.utils.default_expectation_values_order,
    )
assert isinstance(optimized_params, jnp.ndarray)
real_expvals = quantum_device(optimized_params)
calculate_agf_sx(real_expvals)
Array(0.97119931, dtype=float64)
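The AGF of about 0.9712 on the device is close to the 0.9724 predicted by the model at the optimized parameters. Some final remarks:
- This benchmark does not take the finite-shot effect into account.
- This characterization was done without a model selection process such as hyperparameter tuning.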