Skip to content

v1 API

BOED

src.inspeqtor.v1.boed

AuxEntry

The auxiliary entry returned by the loss function

Source code in src/inspeqtor/v1/boed.py
107
108
109
110
111
class AuxEntry(typing.NamedTuple):
    """The auxiliary entry returned by a loss function."""

    # Per-sample loss terms before averaging (None when not tracked).
    terms: jnp.ndarray | None
    # Estimated expected information gain produced by the loss function.
    eig: jnp.ndarray

safe_shape

safe_shape(a: Any) -> tuple[int, ...] | str

Safely get the shape of the object

Parameters:

Name Type Description Default
a Any

Expect the object to be jnp.ndarray

required

Returns:

Type Description
tuple[int, ...] | str

tuple[int, ...] | str: Either return the shape of a

tuple[int, ...] | str

or string representation of the type

Source code in src/inspeqtor/v1/boed.py
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
def safe_shape(a: typing.Any) -> tuple[int, ...] | str:
    """Safely get the shape of the object

    Args:
        a (typing.Any): Expect the object to be jnp.ndarray

    Returns:
        tuple[int, ...] | str: Either return the shape of `a`
        or string representation of the type
    """
    try:
        assert isinstance(a, jnp.ndarray)
        return a.shape
    except AttributeError:
        return str(type(a))

report_shape

report_shape(a: PyTree) -> PyTree

Report the shape of pytree

Parameters:

Name Type Description Default
a PyTree

The pytree to be reported.

required

Returns:

Type Description
PyTree

jaxtyping.PyTree: The shape of pytree.

Source code in src/inspeqtor/v1/boed.py
31
32
33
34
35
36
37
38
39
40
def report_shape(a: jaxtyping.PyTree) -> jaxtyping.PyTree:
    """Map every leaf of a pytree to its shape.

    Args:
        a (jaxtyping.PyTree): Pytree whose leaves are inspected.

    Returns:
        jaxtyping.PyTree: Pytree of the same structure where each leaf has
        been replaced by its shape (or a type string for non-array leaves).
    """
    return jax.tree.map(lambda leaf: safe_shape(leaf), a)

lexpand

lexpand(a: ndarray, *dimensions: int) -> ndarray

Expand tensor, adding new dimensions on left.

Parameters:

Name Type Description Default
a ndarray

expand the dimension on the left with given dimension arguments.

required

Returns:

Type Description
ndarray

jnp.ndarray: New array with shape (*dimension + a.shape)

Source code in src/inspeqtor/v1/boed.py
43
44
45
46
47
48
49
50
51
52
def lexpand(a: jnp.ndarray, *dimensions: int) -> jnp.ndarray:
    """Broadcast an array to a larger shape by prepending new axes.

    Args:
        a (jnp.ndarray): Array whose shape is extended on the left.
        *dimensions (int): Sizes of the new leading axes.

    Returns:
        jnp.ndarray: `a` broadcast to shape (*dimensions, *a.shape).
    """
    target_shape = (*dimensions, *a.shape)
    return jnp.broadcast_to(a, target_shape)

random_split_index

random_split_index(
    rng_key: ndarray, num_samples: int, test_size: int
) -> tuple[ndarray, ndarray]

Randomly split the indices into two sets: one of size test_size and another containing the rest.

Parameters:

Name Type Description Default
rng_key ndarray

The random key

required
num_samples int

The size of total sample size

required
test_size int

The size of test set

required

Returns:

Type Description
tuple[ndarray, ndarray]

tuple[jnp.ndarray, jnp.ndarray]: Array of train indice and array of test indice.

Source code in src/inspeqtor/v1/boed.py
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
def random_split_index(
    rng_key: jnp.ndarray, num_samples: int, test_size: int
) -> tuple[jnp.ndarray, jnp.ndarray]:
    """Randomly partition indices [0, num_samples) into train and test sets.

    Args:
        rng_key (jnp.ndarray): The random key
        num_samples (int): Total number of samples to index.
        test_size (int): Number of indices assigned to the test set.

    Returns:
        tuple[jnp.ndarray, jnp.ndarray]: (train indices, test indices); the
        test set holds `test_size` indices and the train set the remainder.
    """
    shuffled = jax.random.permutation(rng_key, jnp.arange(num_samples))
    test_idx = shuffled[:test_size]
    train_idx = shuffled[test_size:]
    return train_idx, test_idx

marginal_loss

marginal_loss(
    model: Callable,
    marginal_guide: Callable,
    design: ndarray,
    *args,
    observation_labels: list[str],
    target_labels: list[str],
    num_particles: int,
    evaluation: bool = False,
) -> Callable[
    [ArrayTree, ndarray], tuple[ndarray, AuxEntry]
]

The marginal loss implemented following https://docs.pyro.ai/en/dev/contrib.oed.html#pyro.contrib.oed.eig.marginal_eig

Parameters:

Name Type Description Default
model Callable

The probabilistic model

required
marginal_guide Callable

The custom guide

required
design ndarray

Possible designs of the experiment

required
observation_labels list[str]

The list of string of observations

required
target_labels list[str]

The target latent parameters to be optimized for

required
num_particles int

The number of independent trials

required
evaluation bool

True for actual evaluation of the EIG. Defaults to False.

False

Returns:

Type Description
Callable[[ArrayTree, ndarray], tuple[ndarray, AuxEntry]]

typing.Callable[ [chex.ArrayTree, jnp.ndarray], tuple[jnp.ndarray, AuxEntry] ]: Loss function that return tuple of (1) Total loss, (2.1) Each terms without the average, (2.2) The EIG

Source code in src/inspeqtor/v1/boed.py
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
def marginal_loss(
    model: typing.Callable,
    marginal_guide: typing.Callable,
    design: jnp.ndarray,
    *args,
    observation_labels: list[str],
    target_labels: list[str],
    num_particles: int,
    evaluation: bool = False,
) -> typing.Callable[[chex.ArrayTree, jnp.ndarray], tuple[jnp.ndarray, AuxEntry]]:
    """The marginal loss implemented following
    https://docs.pyro.ai/en/dev/contrib.oed.html#pyro.contrib.oed.eig.marginal_eig

    Args:
        model (typing.Callable): The probabilistic model
        marginal_guide (typing.Callable): The custom guide
        design (jnp.ndarray): Possible designs of the experiment
        observation_labels (list[str]): The list of string of observations
        target_labels (list[str]): The target latent parameters to be optimized for
        num_particles (int): The number of independent trials
        evaluation (bool, optional): True for actual evaluation of the EIG. Defaults to False.

    Returns:
        typing.Callable[ [chex.ArrayTree, jnp.ndarray], tuple[jnp.ndarray, AuxEntry] ]: Loss function that return tuple of (1) Total loss, (2.1) Each terms without the average, (2.2) The EIG
    """

    # Marginal loss
    def loss_fn(param: chex.ArrayTree, key: jnp.ndarray) -> tuple[jnp.ndarray, AuxEntry]:
        # Replicate the design num_particles times along a new leading axis.
        expanded_design = lexpand(design, num_particles)
        # vectorized(model, num_particles)
        # Sample from p(y | d)
        key, subkey = jax.random.split(key)
        trace = handlers.trace(handlers.seed(model, subkey)).get_trace(
            expanded_design,
            *args,
        )
        # Collect the sampled observations keyed by their site names.
        y_dict = {
            observation_label: trace[observation_label]["value"]
            for observation_label in observation_labels
        }

        # Run through q(y | d)
        key, subkey = jax.random.split(key)
        # Condition the guide on the sampled observations and substitute the
        # current variational parameters `param` into its param sites.
        conditioned_marginal_guide = handlers.condition(marginal_guide, data=y_dict)
        cond_trace = handlers.trace(
            handlers.substitute(
                handlers.seed(conditioned_marginal_guide, subkey), data=param
            )
        ).get_trace(
            expanded_design,
            *args,
            observation_labels=observation_labels,
            target_labels=target_labels,
        )
        # Compute the log prob of observing the data
        # terms = -log q(y | d); minimizing its mean trains the guide to
        # approximate the marginal likelihood of the observations.
        terms = -1 * jnp.array(
            [
                cond_trace[observation_label]["fn"].log_prob(
                    cond_trace[observation_label]["value"]
                )
                for observation_label in observation_labels
            ]
        ).sum(axis=0)

        if evaluation:
            # For evaluation, add log p(y | theta, d) from the model trace so
            # that terms ~= log p(y | theta, d) - log q(y | d), whose average
            # is the marginal EIG estimate (see the Pyro reference above).
            terms += jnp.array(
                [
                    trace[observation_label]["fn"].log_prob(
                        trace[observation_label]["value"]
                    )
                    for observation_label in observation_labels
                ]
            ).sum(axis=0)

        agg_loss, loss = _safe_mean_terms_v2(terms)
        return agg_loss, AuxEntry(terms=terms, eig=loss)

    return loss_fn

vnmc_eig_loss

vnmc_eig_loss(
    model: Callable,
    marginal_guide: Callable,
    design: ndarray,
    *args,
    observation_labels: list[str],
    target_labels: list[str],
    num_particles: tuple[int, int],
    evaluation: bool = False,
) -> Callable[
    [ArrayTree, ndarray], tuple[ndarray, AuxEntry]
]

The VNMC loss implemented following https://docs.pyro.ai/en/dev/_modules/pyro/contrib/oed/eig.html#vnmc_eig

Parameters:

Name Type Description Default
model Callable

The probabilistic model

required
marginal_guide Callable

The custom guide

required
design ndarray

Possible designs of the experiment

required
observation_labels list[str]

The list of string of observations

required
target_labels list[str]

The target latent parameters to be optimized for

required
num_particles int

The number of independent trials

required
evaluation bool

True for actual evaluation of the EIG. Defaults to False.

False

Returns:

Type Description
Callable[[ArrayTree, ndarray], tuple[ndarray, AuxEntry]]

typing.Callable[ [chex.ArrayTree, jnp.ndarray], tuple[jnp.ndarray, AuxEntry] ]: Loss function that return tuple of (1) Total loss, (2.1) Each terms without the average, (2.2) The EIG

Source code in src/inspeqtor/v1/boed.py
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
@warn_not_tested_function
def vnmc_eig_loss(
    model: typing.Callable,
    marginal_guide: typing.Callable,
    design: jnp.ndarray,
    *args,
    observation_labels: list[str],
    target_labels: list[str],
    num_particles: tuple[int, int],
    evaluation: bool = False,
) -> typing.Callable[[chex.ArrayTree, jnp.ndarray], tuple[jnp.ndarray, AuxEntry]]:
    """The VNMC loss implemented following
    https://docs.pyro.ai/en/dev/_modules/pyro/contrib/oed/eig.html#vnmc_eig

    Args:
        model (typing.Callable): The probabilistic model
        marginal_guide (typing.Callable): The custom guide
        design (jnp.ndarray): Possible designs of the experiment
        observation_labels (list[str]): The list of string of observations
        target_labels (list[str]): The target latent parameters to be optimized for
        num_particles (tuple[int, int]): (N, M) — N outer samples of (y, theta)
            and M inner samples of theta per observed y.
        evaluation (bool, optional): True for actual evaluation of the EIG. Defaults to False.

    Returns:
        typing.Callable[ [chex.ArrayTree, jnp.ndarray], tuple[jnp.ndarray, AuxEntry] ]: Loss function that return tuple of (1) Total loss, (2.1) Each terms without the average, (2.2) The EIG
    """

    # Marginal loss
    def loss_fn(param: chex.ArrayTree, key: jnp.ndarray) -> tuple[jnp.ndarray, AuxEntry]:
        # N: outer particles for (y, theta); M: inner particles for theta | y.
        N, M = num_particles

        expanded_design = lexpand(design, N)

        # Sample from p(y, theta | d)
        key, subkey = jax.random.split(key)
        trace = handlers.trace(handlers.seed(model, subkey)).get_trace(
            expanded_design,
            *args,
        )
        y_dict = {
            observation_label: trace[observation_label]["value"]
            for observation_label in observation_labels
        }

        # Sample M times from q(theta | y, d) for each y
        key, subkey = jax.random.split(key)
        # Leading shape becomes (M, N, ...): M guide samples per outer draw.
        reexpanded_design = lexpand(expanded_design, M)
        conditioned_marginal_guide = handlers.condition(marginal_guide, data=y_dict)
        cond_trace = handlers.trace(
            handlers.substitute(
                handlers.seed(conditioned_marginal_guide, subkey), data=param
            )
        ).get_trace(
            reexpanded_design,
            *args,
            observation_labels=observation_labels,
            target_labels=target_labels,
        )

        # Guide-sampled latents combined with the observed y for conditioning.
        theta_y_dict = {
            target_label: cond_trace[target_label]["value"]
            for target_label in target_labels
        }
        theta_y_dict.update(y_dict)

        # Re-run that through the model to compute the joint
        key, subkey = jax.random.split(key)
        conditioned_model = handlers.condition(model, data=theta_y_dict)
        conditioned_model_trace = handlers.trace(
            handlers.seed(conditioned_model, subkey)
        ).get_trace(
            reexpanded_design,
            *args,
        )

        # Compute the log prob of observing the data
        # -log q(theta | y, d): guide posterior density of the inner samples.
        terms = -1 * jnp.array(
            [
                cond_trace[target_label]["fn"].log_prob(
                    cond_trace[target_label]["value"]
                )
                for target_label in target_labels
            ]
        ).sum(axis=0)

        # + log p(theta): prior density of the latents under the model.
        terms += jnp.array(
            [
                conditioned_model_trace[target_label]["fn"].log_prob(
                    conditioned_model_trace[target_label]["value"]
                )
                for target_label in target_labels
            ]
        ).sum(axis=0)

        # + log p(y | theta, d): likelihood of the observations under the model.
        terms += jnp.array(
            [
                conditioned_model_trace[observation_label]["fn"].log_prob(
                    conditioned_model_trace[observation_label]["value"]
                )
                for observation_label in observation_labels
            ]
        ).sum(axis=0)

        # Average over the M inner samples in log space:
        # terms = -(logsumexp_M(terms) - log M).
        terms = -jax.scipy.special.logsumexp(terms, axis=0) + jnp.log(M)

        if evaluation:
            # Add log p(y | theta, d) from the outer trace to turn the bound
            # into the VNMC EIG estimate (see the Pyro reference above).
            terms += jnp.array(
                [
                    trace[observation_label]["fn"].log_prob(
                        trace[observation_label]["value"]
                    )
                    for observation_label in observation_labels
                ]
            ).sum(axis=0)

        agg_loss, loss = _safe_mean_terms_v2(terms)
        return agg_loss, AuxEntry(terms=terms, eig=loss)

    return loss_fn

init_params_from_guide

init_params_from_guide(
    marginal_guide: Callable,
    *args,
    key: ndarray,
    design: ndarray,
) -> ArrayTree

Initialize parameters of the marginal guide.

Parameters:

Name Type Description Default
marginal_guide Callable

Marginal guide to be used with marginal eig

required
key ndarray

Random Key

required
design ndarray

Example of the designs of the experiment

required

Returns:

Type Description
ArrayTree

chex.ArrayTree: Random parameters for marginal guide to be optimized.

Source code in src/inspeqtor/v1/boed.py
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
def init_params_from_guide(
    marginal_guide: typing.Callable,
    *args,
    key: jnp.ndarray,
    design: jnp.ndarray,
) -> chex.ArrayTree:
    """Initialize the trainable parameters of a marginal guide.

    Runs the guide once under a seeded trace and collects every site of
    type "param" as the initial parameter pytree.

    Args:
        marginal_guide (typing.Callable): Marginal guide to be used with marginal EIG.
        key (jnp.ndarray): Random key.
        design (jnp.ndarray): Example of the designs of the experiment.

    Returns:
        chex.ArrayTree: Initial parameters for the marginal guide.
    """
    _, seed_key = jax.random.split(key)
    seeded_guide = handlers.seed(marginal_guide, seed_key)
    guide_trace = handlers.trace(seeded_guide).get_trace(
        design, *args, observation_labels=[], target_labels=[]
    )

    # Keep only the sites registered as parameters (skip sample sites etc.).
    return {
        site_name: site["value"]
        for site_name, site in guide_trace.items()
        if site["type"] == "param"
    }

opt_eig_ape_loss

opt_eig_ape_loss(
    loss_fn: Callable[
        [ArrayTree, ndarray], tuple[ndarray, AuxEntry]
    ],
    params: ArrayTree,
    num_steps: int,
    optim: GradientTransformation,
    key: ndarray,
    callbacks: list = [],
) -> ArrayTree

Optimize the EIG loss function.

Parameters:

Name Type Description Default
loss_fn Callable[[ArrayTree, ndarray], tuple[ndarray, AuxEntry]]

Loss function

required
params ArrayTree

Initial parameter

required
num_steps int

Number of optimization step

required
optim GradientTransformation

Optax Optimizer

required
key ndarray

Random key

required

Returns:

Type Description
ArrayTree

tuple[chex.ArrayTree, list[typing.Any]]: Optimized parameters, and optimization history.

Source code in src/inspeqtor/v1/boed.py
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
def opt_eig_ape_loss(
    loss_fn: typing.Callable[
        [chex.ArrayTree, jnp.ndarray], tuple[jnp.ndarray, AuxEntry]
    ],
    params: chex.ArrayTree,
    num_steps: int,
    optim: optax.GradientTransformation,
    key: jnp.ndarray,
    callbacks: list | None = None,
) -> chex.ArrayTree:
    """Optimize the EIG loss function with a simple Optax training loop.

    Args:
        loss_fn (typing.Callable[[chex.ArrayTree, jnp.ndarray], tuple[jnp.ndarray, AuxEntry]]): Loss function
        params (chex.ArrayTree): Initial parameter
        num_steps (int): Number of optimization step
        optim (optax.GradientTransformation): Optax Optimizer
        key (jnp.ndarray): Random key, split once per step
        callbacks (list, optional): Callables invoked with a HistoryEntry
            after each step. Defaults to None (no callbacks).

    Returns:
        chex.ArrayTree: Optimized parameters.
    """
    # NOTE(fix): the default used to be a mutable `[]`, shared across calls.
    callbacks = [] if callbacks is None else callbacks
    # Initialize the optimizer
    opt_state = optim.init(params)
    # jit once; the loop below reuses the compiled function
    loss_fn = jax.jit(loss_fn)

    for step in range(num_steps):
        key, subkey = jax.random.split(key)
        # Compute the loss and its gradient
        (loss, aux), grad = jax.value_and_grad(loss_fn, has_aux=True)(params, subkey)
        # Update the optimizer state and the params
        updates, opt_state = optim.update(grad, opt_state, params)
        params = optax.apply_updates(params, updates)

        # Report progress to the callbacks (logging, history collection, ...).
        entry = HistoryEntry(step=step, loss=loss, aux=aux)
        for callback in callbacks:
            callback(entry)

    return params

estimate_eig

estimate_eig(
    key: ndarray,
    model: Callable,
    marginal_guide: Callable,
    design: ndarray,
    *args,
    optimizer: GradientTransformation,
    num_optimization_steps: int,
    observation_labels: list[str],
    target_labels: list[str],
    num_particles: tuple[int, int] | int,
    final_num_particles: tuple[int, int]
    | int
    | None = None,
    loss_fn: Callable = marginal_loss,
    callbacks: list = [],
) -> tuple[ndarray, dict[str, Any]]

Optimize for marginal EIG

Parameters:

Name Type Description Default
key ndarray

Random key

required
model Callable

Probabilistic model of the experiment

required
marginal_guide Callable

The marginal guide of the experiment

required
design ndarray

Possible designs of the experiment

required
optimizer GradientTransformation

Optax optimizer

required
num_optimization_steps int

Number of the optimization step

required
observation_labels list[str]

The list of string of observations

required
target_labels list[str]

The target latent parameters to be optimized for

required
num_particles int

The number of independent trials

required
final_num_particles int | None

Final independent trials to calculate marginal EIG. Defaults to None.

None

Returns:

Type Description
tuple[ndarray, dict[str, Any]]

tuple[jnp.ndarray, dict[str, typing.Any]]: EIG, and tuple of optimized parameters and optimization history.

Source code in src/inspeqtor/v1/boed.py
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
def estimate_eig(
    key: jnp.ndarray,
    model: typing.Callable,
    marginal_guide: typing.Callable,
    design: jnp.ndarray,
    *args,
    optimizer: optax.GradientTransformation,
    num_optimization_steps: int,
    observation_labels: list[str],
    target_labels: list[str],
    num_particles: tuple[int, int] | int,
    final_num_particles: tuple[int, int] | int | None = None,
    loss_fn: typing.Callable = marginal_loss,
    callbacks: list | None = None,
) -> tuple[jnp.ndarray, dict[str, typing.Any]]:
    """Optimize a guide for the (marginal) EIG objective, then evaluate the EIG.

    Args:
        key (jnp.ndarray): Random key
        model (typing.Callable): Probabilistic model of the experiment
        marginal_guide (typing.Callable): The marginal guide of the experiment
        design (jnp.ndarray): Possible designs of the experiment
        optimizer (optax.GradientTransformation): Optax optimizer
        num_optimization_steps (int): Number of the optimization step
        observation_labels (list[str]): The list of string of observations
        target_labels (list[str]): The target latent parameters to be optimized for
        num_particles (tuple[int, int] | int): The number of independent trials
        final_num_particles (tuple[int, int] | int | None, optional): Independent
            trials for the final EIG evaluation. Defaults to None, which means
            reuse `num_particles`.
        loss_fn (typing.Callable, optional): Loss-function factory. Defaults to marginal_loss.
        callbacks (list, optional): Callables invoked each optimization step.

    Returns:
        tuple[jnp.ndarray, dict[str, typing.Any]]: EIG, and a dict holding the
        optimized guide parameters.
    """
    # NOTE(fix): the default used to be a mutable `[]`, shared across calls.
    callbacks = [] if callbacks is None else callbacks
    # NOTE: the guide parameters are shaped by num_particles during training;
    # evaluating with a different final_num_particles can fail with a shape
    # mismatch.  Defaulting to num_particles also avoids passing None into the
    # loss factory (which previously crashed inside lexpand).
    if final_num_particles is None:
        final_num_particles = num_particles

    # Initialize the parameters by using trace from the marginal_guide
    key, subkey = jax.random.split(key)
    params = init_params_from_guide(
        marginal_guide,
        *args,
        key=subkey,
        design=design,
    )

    # Optimize the loss function first to get the optimal parameters
    # for marginal guide.
    # NOTE(fix): previously the same subkey used for initialization was reused
    # here; draw a fresh key to avoid correlated randomness.
    key, subkey = jax.random.split(key)
    params = opt_eig_ape_loss(
        loss_fn=loss_fn(
            model,
            marginal_guide,
            design,
            *args,
            observation_labels=observation_labels,
            target_labels=target_labels,
            num_particles=num_particles,
            evaluation=False,
        ),
        params=params,
        num_steps=num_optimization_steps,
        optim=optimizer,
        key=subkey,
        callbacks=callbacks,
    )

    key, subkey = jax.random.split(key)
    # Evaluate the loss with evaluation=True to obtain the EIG estimate.
    _, aux = loss_fn(
        model,
        marginal_guide,
        design,
        *args,
        observation_labels=observation_labels,
        target_labels=target_labels,
        num_particles=final_num_particles,
        evaluation=True,
    )(params, subkey)

    return aux.eig, {
        "params": params,
    }

vectorized_for_eig

vectorized_for_eig(model)

Vectorization function for the EIG function

Parameters:

Name Type Description Default
model Any

Probabilistic model.

required
Source code in src/inspeqtor/v1/boed.py
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
def vectorized_for_eig(model):
    """Wrap a probabilistic model so its two leading design axes are vectorized.

    Args:
        model (typing.Any): Probabilistic model.
    """

    def vectorized_model(
        design: jnp.ndarray,
        *args,
        **kwargs,
    ):
        # Same call signature as the probabilistic graybox model.
        # `design` is expected to have shape (extra, design, feature); the
        # first two axes are wrapped in stacked plates.
        batch_sizes = list(design.shape[:2])
        with plate_stack(prefix="vectorized_plate", sizes=batch_sizes):
            return model(design, *args, **kwargs)

    return vectorized_model

Constant

src.inspeqtor.v1.constant

Control

src.inspeqtor.v1.control

BaseControl dataclass

Source code in src/inspeqtor/v1/control.py
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
@dataclass
class BaseControl(ABC):
    """Abstract base class for JSON-serializable control definitions.

    Subclasses implement `get_bounds` and `get_envelope`; every construction
    runs `validate` to check serializability and the parameter contract.
    """

    # duration: int

    def __post_init__(self):
        # self.t_eval = jnp.arange(0, self.duration, 1)
        # Dataclass hook: run the consistency checks on every construction.
        self.validate()

    def validate(self):
        # Validate that all attributes are json serializable
        try:
            json.dumps(self.to_dict())
        except TypeError as e:
            raise TypeError(
                f"Cannot serialize {self.__class__.__name__} to json"
            ) from e

        lower, upper = self.get_bounds()
        # Validate that the sampling function is working
        key = jax.random.key(0)
        params = sample_params(key, lower, upper)
        # waveform = self.get_waveform(params)

        # Parameter dicts must be keyed by strings so they serialize cleanly.
        assert all([isinstance(k, str) for k in params.keys()]), (
            "All key of params dict must be string"
        )
        # NOTE(review): values may be all-float OR all-ndarray, but the
        # message below only mentions float — consider updating the message.
        assert all([isinstance(v, float) for v in params.values()]) or all(
            [isinstance(v, jnp.ndarray) for v in params.values()]
        ), "All value of params dict must be float"
        # assert isinstance(waveform, jax.Array), "Waveform must be jax.Array"

    @abstractmethod
    def get_bounds(
        self, *arg, **kwarg
    ) -> tuple[ParametersDictType, ParametersDictType]: ...

    @abstractmethod
    def get_envelope(self, params: ParametersDictType) -> typing.Callable:
        raise NotImplementedError("get_envelopes method is not implemented")

    # def get_waveform(self, params: ParametersDictType) -> jnp.ndarray:
    #     """Get the discrete waveform of the pulse

    #     Args:
    #         params (ParametersDictType): Control parameter

    #     Returns:
    #         jnp.ndarray: Waveform of the control.
    #     """
    #     return jax.vmap(self.get_envelope(params), in_axes=(0,))((self.t_eval))

    def to_dict(self) -> dict[str, typing.Union[int, float, str]]:
        """Convert the control configuration to dictionary

        Returns:
            dict[str, typing.Union[int, float, str]]: Configuration of the control
        """
        # dataclasses.asdict recursively converts all dataclass fields.
        return asdict(self)

    @classmethod
    def from_dict(cls, data):
        """Construct the control instance from the dictionary.

        Args:
            data (dict): Dictionary for construction of the control instance.

        Returns:
            The instance of the control.
        """
        return cls(**data)

to_dict

to_dict() -> dict[str, Union[int, float, str]]

Convert the control configuration to dictionary

Returns:

Type Description
dict[str, Union[int, float, str]]

dict[str, typing.Union[int, float, str]]: Configuration of the control

Source code in src/inspeqtor/v1/control.py
90
91
92
93
94
95
96
def to_dict(self) -> dict[str, typing.Union[int, float, str]]:
    """Convert the control configuration to dictionary

    Returns:
        dict[str, typing.Union[int, float, str]]: Configuration of the control
    """
    # dataclasses.asdict recursively converts all dataclass fields.
    return asdict(self)

from_dict classmethod

from_dict(data)

Construct the control instance from the dictionary.

Parameters:

Name Type Description Default
data dict

Dictionary for construction of the control instance.

required

Returns:

Type Description

The instance of the control.

Source code in src/inspeqtor/v1/control.py
 98
 99
100
101
102
103
104
105
106
107
108
@classmethod
def from_dict(cls, data):
    """Construct the control instance from the dictionary.

    Args:
        data (dict): Dictionary for construction of the control instance.
            Keys must match the dataclass field names.

    Returns:
        The instance of the control.
    """
    return cls(**data)

ControlSequenceProtocol

Protocol defining the interface for control sequences.

Source code in src/inspeqtor/v1/control.py
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
class ControlSequenceProtocol(typing.Protocol):
    """Protocol defining the interface for control sequences.

    Structural (duck-typed) interface: any class providing these attributes
    and methods is accepted wherever a control sequence is expected.
    """

    # Attributes
    # Named sub-controls composing the sequence, keyed by control name.
    controls: dict[str, BaseControl]
    # Total number of time steps of the sequence — presumably in dt units;
    # TODO confirm against an implementation.
    total_dt: int

    def get_structure(self) -> list[tuple[str, str]]:
        """Get the structure/order of control parameters."""
        ...

    def sample_params(self, key: jax.Array) -> dict[str, ParametersDictType]:
        """Sample control parameters using a random key.

        Args:
            key: Random key for sampling

        Returns:
            Dictionary of sampled control parameters
        """
        ...

    def get_bounds(
        self,
    ) -> tuple[dict[str, ParametersDictType], dict[str, ParametersDictType]]:
        """Get the bounds of the controls.

        Returns:
            Tuple of lower and upper bounds dictionaries
        """
        ...

    def get_envelope(
        self, params_dict: dict[str, ParametersDictType]
    ) -> typing.Callable:
        """Create envelope function with given control parameters.

        Args:
            params_dict: Control parameters to be used

        Returns:
            Envelope function
        """
        ...

    def to_dict(self) -> dict[str, str | dict[str, str | float]]:
        """Convert the control sequence to a dictionary representation."""
        ...

    @classmethod
    def from_dict(
        cls,
        data: dict[str, str | dict[str, str | float]],
        controls: dict[str, type[BaseControl]],
    ) -> typing.Self:
        """Create a control sequence from dictionary data.

        Args:
            data: Dictionary containing control sequence data
            controls: Dictionary mapping control names to control classes

        Returns:
            New control sequence instance
        """
        ...

    def to_file(self, path: str | pathlib.Path) -> None:
        """Save configuration to file.

        Args:
            path: Path to save the sequence configuration
        """
        ...

    @classmethod
    def from_file(
        cls,
        path: str | pathlib.Path,
        controls: dict[str, type[BaseControl]],
    ) -> typing.Self:
        """Load control sequence from file.

        Args:
            path: Path to load the sequence configuration from
            controls: Dictionary mapping control names to control classes

        Returns:
            New control sequence instance
        """
        ...

get_structure

get_structure() -> list[tuple[str, str]]

Get the structure/order of control parameters.

Source code in src/inspeqtor/v1/control.py
118
119
120
def get_structure(self) -> list[tuple[str, str]]:
    """Get the structure/order of control parameters."""
    # Protocol stub — presumably returns (control name, parameter name)
    # pairs in a fixed order; TODO confirm against an implementation.
    ...

sample_params

sample_params(key: Array) -> dict[str, ParametersDictType]

Sample control parameters using a random key.

Parameters:

Name Type Description Default
key Array

Random key for sampling

required

Returns:

Type Description
dict[str, ParametersDictType]

Dictionary of sampled control parameters

Source code in src/inspeqtor/v1/control.py
122
123
124
125
126
127
128
129
130
131
def sample_params(self, key: jax.Array) -> dict[str, ParametersDictType]:
    """Sample control parameters using a random key.

    Args:
        key: Random key for sampling

    Returns:
        Dictionary of sampled control parameters, keyed by control name
    """
    # Protocol stub — implementing classes supply the sampling logic.
    ...

get_bounds

get_bounds() -> tuple[
    dict[str, ParametersDictType],
    dict[str, ParametersDictType],
]

Get the bounds of the controls.

Returns:

Type Description
tuple[dict[str, ParametersDictType], dict[str, ParametersDictType]]

Tuple of lower and upper bounds dictionaries

Source code in src/inspeqtor/v1/control.py
133
134
135
136
137
138
139
140
141
def get_bounds(
    self,
) -> tuple[dict[str, ParametersDictType], dict[str, ParametersDictType]]:
    """Get the bounds of the controls.

    Returns:
        Tuple of lower and upper bounds dictionaries
    """
    # Stub body (`...`): implemented by concrete control sequences.
    ...

get_envelope

get_envelope(
    params_dict: dict[str, ParametersDictType],
) -> Callable

Create envelope function with given control parameters.

Parameters:

Name Type Description Default
params_dict dict[str, ParametersDictType]

Control parameters to be used

required

Returns:

Type Description
Callable

Envelope function

Source code in src/inspeqtor/v1/control.py
143
144
145
146
147
148
149
150
151
152
153
154
def get_envelope(
    self, params_dict: dict[str, ParametersDictType]
) -> typing.Callable:
    """Create envelope function with given control parameters.

    Args:
        params_dict: Control parameters to be used

    Returns:
        Envelope function
    """
    # Stub body (`...`): implemented by concrete control sequences.
    ...

to_dict

to_dict() -> dict[str, str | dict[str, str | float]]

Convert the control sequence to a dictionary representation.

Source code in src/inspeqtor/v1/control.py
156
157
158
def to_dict(self) -> dict[str, str | dict[str, str | float]]:
    """Convert the control sequence to a dictionary representation."""
    # Stub body (`...`): implemented by concrete control sequences.
    ...

from_dict classmethod

from_dict(
    data: dict[str, str | dict[str, str | float]],
    controls: dict[str, type[BaseControl]],
) -> Self

Create a control sequence from dictionary data.

Parameters:

Name Type Description Default
data dict[str, str | dict[str, str | float]]

Dictionary containing control sequence data

required
controls dict[str, type[BaseControl]]

Dictionary mapping control names to control classes

required

Returns:

Type Description
Self

New control sequence instance

Source code in src/inspeqtor/v1/control.py
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
@classmethod
def from_dict(
    cls,
    data: dict[str, str | dict[str, str | float]],
    controls: dict[str, type[BaseControl]],
) -> typing.Self:
    """Create a control sequence from dictionary data.

    Args:
        data: Dictionary containing control sequence data
        controls: Dictionary mapping control names to control classes

    Returns:
        New control sequence instance
    """
    # Stub body (`...`): implemented by concrete control sequences.
    ...

to_file

to_file(path: str | Path) -> None

Save configuration to file.

Parameters:

Name Type Description Default
path str | Path

Path to save the sequence configuration

required
Source code in src/inspeqtor/v1/control.py
177
178
179
180
181
182
183
def to_file(self, path: str | pathlib.Path) -> None:
    """Save configuration to file.

    Args:
        path: Path to save the sequence configuration
    """
    # Stub body (`...`): concrete implementations perform the actual write.
    ...

from_file classmethod

from_file(
    path: str | Path, controls: dict[str, type[BaseControl]]
) -> Self

Load control sequence from file.

Parameters:

Name Type Description Default
path str | Path

Path to load the sequence configuration from

required
controls dict[str, type[BaseControl]]

Dictionary mapping control names to control classes

required

Returns:

Type Description
Self

New control sequence instance

Source code in src/inspeqtor/v1/control.py
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
@classmethod
def from_file(
    cls,
    path: str | pathlib.Path,
    controls: dict[str, type[BaseControl]],
) -> typing.Self:
    """Load control sequence from file.

    Args:
        path: Path to load the sequence configuration from
        controls: Dictionary mapping control names to control classes

    Returns:
        New control sequence instance
    """
    # Stub body (`...`): concrete implementations perform the actual load.
    ...

ControlSequence dataclass

Control sequence, expected to be the sum of atomic controls.

Source code in src/inspeqtor/v1/control.py
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
@dataclass
class ControlSequence:
    """Control sequence, expected to be the sum of atomic controls."""

    # Atomic controls whose envelopes are summed to form the sequence.
    controls: typing.Sequence[BaseControl]
    # Total number of time steps of the sequence.
    total_dt: int
    # Whether to run structural validation on construction.
    validate: bool = True

    def __post_init__(self):
        # Validate each control's parameter structure when requested.
        if self.validate:
            self._validate()

    def _validate(self):
        """Check parameter/bounds structure of every control in the sequence."""
        key = jax.random.key(0)
        subkeys = jax.random.split(key, len(self.controls))
        for pulse_key, pulse in zip(subkeys, self.controls):
            params = sample_params(pulse_key, *pulse.get_bounds())
            # Every parameter name must be a string, every value numeric.
            assert all([isinstance(k, str) for k in params.keys()]), (
                "All key of params dict must be string"
            )
            assert all(
                [isinstance(v, (float, int, jnp.ndarray)) for v in params.values()]
            ), "All value of params dict must be float or jax.Array"

        params = self.sample_params(key)

        # Bounds must share the pytree structure of the sampled parameters.
        lower, upper = self.get_bounds()

        assert jax.tree.structure(lower) == jax.tree.structure(params)
        assert jax.tree.structure(upper) == jax.tree.structure(params)

    def sample_params(self, key: jax.Array) -> list[ParametersDictType]:
        """Sample control parameters for every control in the sequence.

        Args:
            key (jax.Array): Random key

        Returns:
            list[ParametersDictType]: control parameters, one dict per control
        """
        # One independent subkey per control.
        subkeys = jax.random.split(key, len(self.controls))

        params_list: list[ParametersDictType] = []
        for pulse_key, pulse in zip(subkeys, self.controls):
            # Delegates to the module-level sample_params(key, lower, upper).
            params = sample_params(pulse_key, *pulse.get_bounds())
            params_list.append(params)

        return params_list

    def get_envelope(self, params_list: list[ParametersDictType]) -> typing.Callable:
        """Create envelope function with given control parameters

        Args:
            params_list (list[ParametersDictType]): control parameter to be used

        Returns:
            typing.Callable: Envelope function
        """
        callables = []
        for _params, _pulse in zip(params_list, self.controls):
            callables.append(_pulse.get_envelope(_params))

        # The sequence envelope is the pointwise sum of the control envelopes.
        def envelope(t):
            return sum([c(t) for c in callables])

        return envelope

    def get_bounds(self) -> tuple[list[ParametersDictType], list[ParametersDictType]]:
        """Get the bounds of the controls

        Returns:
            tuple[list[ParametersDictType], list[ParametersDictType]]: tuple of list of lower and upper bounds.
        """
        # Bounds are collected in the same order as self.controls.
        lower_bounds = []
        upper_bounds = []
        for pulse in self.controls:
            lower, upper = pulse.get_bounds()
            lower_bounds.append(lower)
            upper_bounds.append(upper)

        return lower_bounds, upper_bounds

    def get_parameter_names(self) -> list[list[str]]:
        """Get the name of the control parameters in the control sequence.

        Returns:
            list[list[str]]: Structured name of control parameters.
        """
        # A fixed key is fine here: only the parameter names are used,
        # the sampled values are discarded.
        key = jax.random.key(0)
        params_list = self.sample_params(key)

        parameter_names = []
        for params in params_list:
            parameter_names.append(list(params.keys()))

        return parameter_names

    def to_dict(self) -> dict[str, typing.Any]:
        """Convert control sequence to dictionary.

        Returns:
            dict[str, typing.Any]: Control sequence configuration dict.
        """
        # asdict(self) provides the scalar fields; the "controls" entry is then
        # replaced with per-control dicts tagged "_name" for reconstruction.
        return {
            **asdict(self),
            "controls": [
                {**pulse.to_dict(), "_name": pulse.__class__.__name__}
                for pulse in self.controls
            ],
        }

    @classmethod
    def from_dict(
        cls, data: dict[str, typing.Any], controls: typing.Sequence[type[BaseControl]]
    ) -> "ControlSequence":
        """Construct the control sequence from dict.

        Args:
            data (dict[str, typing.Any]): Dict containing information for sequence construction
            controls (typing.Sequence[type[BaseControl]]): Constructors of the controls

        Returns:
            ControlSequence: Instance of the control sequence.
        """
        # Work on copies so the caller's dict is not mutated (the previous
        # implementation popped "_name" from, and reassigned keys of, `data`).
        data = dict(data)
        parsed_data = []
        for d, pulse in zip(data["controls"], controls):
            assert isinstance(d, dict), f"Expected dict, got {type(d)}"

            # Drop the "_name" tag before handing the dict to the constructor.
            d = {k: v for k, v in d.items() if k != "_name"}
            parsed_data.append(pulse.from_dict(d))

        data["controls"] = parsed_data
        data["validate"] = True

        return cls(**data)

    def to_file(self, path: typing.Union[str, pathlib.Path]):
        """Save configuration of the pulse to file given folder path.

        Args:
            path (typing.Union[str, pathlib.Path]): Path to the folder to save sequence, will be created if not existed.
        """
        if isinstance(path, str):
            path = pathlib.Path(path)

        # Ensure the folder exists, then write the sequence as JSON.
        os.makedirs(path, exist_ok=True)
        with open(path / "control_sequence.json", "w") as f:
            json.dump(self.to_dict(), f, indent=4)

    @classmethod
    def from_file(
        cls,
        path: typing.Union[str, pathlib.Path],
        controls: typing.Sequence[type[BaseControl]],
    ) -> "ControlSequence":
        """Construct control sequence from path

        Args:
            path (typing.Union[str, pathlib.Path]): Path to configuration of control sequence.
            controls (typing.Sequence[type[BaseControl]]): Constructors of the controls in the sequence.

        Returns:
            ControlSequence: Control sequence instance.
        """
        if isinstance(path, str):
            path = pathlib.Path(path)

        with open(path / "control_sequence.json", "r") as f:
            dict_control_sequence = json.load(f)

        return cls.from_dict(dict_control_sequence, controls=controls)

sample_params

sample_params(key: Array) -> list[ParametersDictType]

Sample control parameter

Parameters:

Name Type Description Default
key Array

Random key

required

Returns:

Type Description
list[ParametersDictType]

list[ParametersDictType]: control parameters

Source code in src/inspeqtor/v1/control.py
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
def sample_params(self, key: jax.Array) -> list[ParametersDictType]:
    """Sample control parameter

    Args:
        key (jax.Array): Random key

    Returns:
        list[ParametersDictType]: control parameters
    """
    # Split key for each pulse
    subkeys = jax.random.split(key, len(self.controls))

    params_list: list[ParametersDictType] = []
    for pulse_key, pulse in zip(subkeys, self.controls):
        # Delegates to the module-level sample_params(key, lower, upper) helper.
        params = sample_params(pulse_key, *pulse.get_bounds())
        params_list.append(params)

    return params_list

get_envelope

get_envelope(
    params_list: list[ParametersDictType],
) -> Callable

Create envelope function with given control parameters

Parameters:

Name Type Description Default
params_list list[ParametersDictType]

control parameter to be used

required

Returns:

Type Description
Callable

typing.Callable: Envelope function

Source code in src/inspeqtor/v1/control.py
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
def get_envelope(self, params_list: list[ParametersDictType]) -> typing.Callable:
    """Create envelope function with given control parameters

    Args:
        params_list (list[ParametersDictType]): control parameter to be used

    Returns:
        typing.Callable: Envelope function
    """
    callables = []
    # NOTE: zip truncates to the shorter of params_list / self.controls.
    for _params, _pulse in zip(params_list, self.controls):
        callables.append(_pulse.get_envelope(_params))

    # Create a function that returns the sum of the envelopes
    def envelope(t):
        return sum([c(t) for c in callables])

    return envelope

get_bounds

get_bounds() -> tuple[
    list[ParametersDictType], list[ParametersDictType]
]

Get the bounds of the controls

Returns:

Type Description
tuple[list[ParametersDictType], list[ParametersDictType]]

tuple[list[ParametersDictType], list[ParametersDictType]]: tuple of list of lower and upper bounds.

Source code in src/inspeqtor/v1/control.py
303
304
305
306
307
308
309
310
311
312
313
314
315
316
def get_bounds(self) -> tuple[list[ParametersDictType], list[ParametersDictType]]:
    """Get the bounds of the controls

    Returns:
        tuple[list[ParametersDictType], list[ParametersDictType]]: tuple of list of lower and upper bounds.
    """
    # Bounds are collected in the same order as self.controls.
    lower_bounds = []
    upper_bounds = []
    for pulse in self.controls:
        lower, upper = pulse.get_bounds()
        lower_bounds.append(lower)
        upper_bounds.append(upper)

    return lower_bounds, upper_bounds

get_parameter_names

get_parameter_names() -> list[list[str]]

Get the name of the control parameters in the control sequence.

Returns:

Type Description
list[list[str]]

list[list[str]]: Structured name of control parameters.

Source code in src/inspeqtor/v1/control.py
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
def get_parameter_names(self) -> list[list[str]]:
    """Get the name of the control parameters in the control sequence.

    Returns:
        list[list[str]]: Structured name of control parameters.
    """
    # Sample the pulse sequence to get the parameter names
    # (a fixed key is fine: only the dict keys are used, values are discarded).
    key = jax.random.key(0)
    params_list = self.sample_params(key)

    # Get the parameter names for each pulse
    parameter_names = []
    for params in params_list:
        parameter_names.append(list(params.keys()))

    return parameter_names

to_dict

to_dict() -> dict[str, Any]

Convert control sequence to dictionary.

Returns:

Type Description
dict[str, Any]

dict[str, typing.Any]: Control sequence configuration dict.

Source code in src/inspeqtor/v1/control.py
335
336
337
338
339
340
341
342
343
344
345
346
347
def to_dict(self) -> dict[str, typing.Any]:
    """Convert control sequence to dictionary.

    Returns:
        dict[str, typing.Any]: Control sequence configuration dict.
    """
    # asdict(self) provides the scalar fields; the "controls" entry is then
    # replaced with per-control dicts tagged "_name" for reconstruction.
    return {
        **asdict(self),
        "controls": [
            {**pulse.to_dict(), "_name": pulse.__class__.__name__}
            for pulse in self.controls
        ],
    }

from_dict classmethod

from_dict(
    data: dict[str, Any],
    controls: Sequence[type[BaseControl]],
) -> ControlSequence

Construct the control sequence from dict.

Parameters:

Name Type Description Default
data dict[str, Any]

Dict contain information for sequence construction

required
control Sequence[type[BasePulse]]

Constructor of the controls

required

Returns:

Name Type Description
ControlSequence ControlSequence

Instance of the control sequence.

Source code in src/inspeqtor/v1/control.py
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
@classmethod
def from_dict(
    cls, data: dict[str, typing.Any], controls: typing.Sequence[type[BaseControl]]
) -> "ControlSequence":
    """Construct the control sequence from dict.

    Args:
        data (dict[str, typing.Any]): Dict containing information for sequence construction
        controls (typing.Sequence[type[BaseControl]]): Constructors of the controls

    Returns:
        ControlSequence: Instance of the control sequence.
    """
    # Work on copies so the caller's dict is not mutated (the previous
    # implementation popped "_name" from, and reassigned keys of, `data`).
    data = dict(data)
    parsed_data = []
    for d, pulse in zip(data["controls"], controls):
        assert isinstance(d, dict), f"Expected dict, got {type(d)}"

        # Drop the "_name" tag before handing the dict to the constructor.
        d = {k: v for k, v in d.items() if k != "_name"}
        parsed_data.append(pulse.from_dict(d))

    data["controls"] = parsed_data
    data["validate"] = True

    return cls(**data)

to_file

to_file(path: Union[str, Path])

Save configuration of the pulse to file given folder path.

Parameters:

Name Type Description Default
path Union[str, Path]

Path to the folder to save sequence, will be created if not existed.

required
Source code in src/inspeqtor/v1/control.py
375
376
377
378
379
380
381
382
383
384
385
386
def to_file(self, path: typing.Union[str, pathlib.Path]):
    """Save configuration of the pulse to file given folder path.

    Args:
        path (typing.Union[str, pathlib.Path]): Path to the folder to save sequence, will be created if not existed.
    """
    if isinstance(path, str):
        path = pathlib.Path(path)

    # Ensure the folder exists, then write the sequence as JSON.
    os.makedirs(path, exist_ok=True)
    with open(path / "control_sequence.json", "w") as f:
        json.dump(self.to_dict(), f, indent=4)

from_file classmethod

from_file(
    path: Union[str, Path],
    controls: Sequence[type[BaseControl]],
) -> ControlSequence

Construct control sequence from path

Parameters:

Name Type Description Default
path Union[str, Path]

Path to configuration of control sequence.

required
controls Sequence[type[BasePulse]]

Constructor of the control in the sequence.

required

Returns:

Name Type Description
ControlSequence ControlSequence

Control sequence instance.

Source code in src/inspeqtor/v1/control.py
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
@classmethod
def from_file(
    cls,
    path: typing.Union[str, pathlib.Path],
    controls: typing.Sequence[type[BaseControl]],
) -> "ControlSequence":
    """Construct control sequence from path

    Args:
        path (typing.Union[str, pathlib.Path]): Path to configuration of control sequence.
        controls (typing.Sequence[type[BaseControl]]): Constructors of the controls in the sequence.

    Returns:
        ControlSequence: Control sequence instance.
    """
    if isinstance(path, str):
        path = pathlib.Path(path)

    # Read the JSON written by to_file, then rebuild via from_dict.
    with open(path / "control_sequence.json", "r") as f:
        dict_control_sequence = json.load(f)

    return cls.from_dict(dict_control_sequence, controls=controls)

sample_params

sample_params(
    key: ndarray,
    lower: ParametersDictType,
    upper: ParametersDictType,
) -> ParametersDictType

Sample parameters with the same shape as the given lower and upper bounds

Parameters:

Name Type Description Default
key ndarray

Random key

required
lower ParametersDictType

Lower bound

required
upper ParametersDictType

Upper bound

required

Returns:

Name Type Description
ParametersDictType ParametersDictType

Dict of the sampled parameters

Source code in src/inspeqtor/v1/control.py
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
def sample_params(
    key: jnp.ndarray, lower: ParametersDictType, upper: ParametersDictType
) -> ParametersDictType:
    """Sample parameters uniformly between the given lower and upper bounds.

    Args:
        key (jnp.ndarray): Random key
        lower (ParametersDictType): Lower bound
        upper (ParametersDictType): Upper bound

    Returns:
        ParametersDictType: Dict of the sampled parameters
    """
    # Generic over the bound structure: one subkey per parameter name,
    # consumed in the iteration order of `lower`.
    sampled: ParametersDictType = {}
    for name in lower.keys():
        subkey, key = jax.random.split(key)
        sampled[name] = jax.random.uniform(
            subkey, shape=(), dtype=float, minval=lower[name], maxval=upper[name]
        )

    return sampled

array_to_list_of_params

array_to_list_of_params(
    array: ndarray, parameter_structure: list[list[str]]
) -> list[ParametersDictType]

Convert the array of control parameter to the list form

Parameters:

Name Type Description Default
array ndarray

Control parameter array

required
parameter_structure list[list[str]]

The structure of the control sequence

required

Returns:

Type Description
list[ParametersDictType]

list[ParametersDictType]: Control parameter in the list form.

Source code in src/inspeqtor/v1/control.py
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
def array_to_list_of_params(
    array: jnp.ndarray, parameter_structure: list[list[str]]
) -> list[ParametersDictType]:
    """Convert a flat control-parameter array into the list-of-dicts form.

    Args:
        array (jnp.ndarray): Control parameter array
        parameter_structure (list[list[str]]): The structure of the control sequence

    Returns:
        list[ParametersDictType]: Control parameter in the list form.
    """
    result: list[ParametersDictType] = []
    cursor = 0
    for names in parameter_structure:
        # Consume len(names) consecutive entries for this sub-pulse.
        entry: ParametersDictType = {
            name: array[cursor + offset] for offset, name in enumerate(names)
        }
        cursor += len(names)
        result.append(entry)

    return result

list_of_params_to_array

list_of_params_to_array(
    params: list[ParametersDictType],
    parameter_structure: list[list[str]],
) -> ndarray

Convert the control parameter in the list form to flatten array form

Parameters:

Name Type Description Default
params list[ParametersDictType]

Control parameter in the list form

required
parameter_structure list[list[str]]

The structure of the control sequence

required

Returns:

Type Description
ndarray

jnp.ndarray: Control parameters array

Source code in src/inspeqtor/v1/control.py
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
def list_of_params_to_array(
    params: list[ParametersDictType], parameter_structure: list[list[str]]
) -> jnp.ndarray:
    """Flatten list-of-dicts control parameters into a single array.

    Args:
        params (list[ParametersDictType]): Control parameter in the list form
        parameter_structure (list[list[str]]): The structure of the control sequence

    Returns:
        jnp.ndarray: Control parameters array
    """
    # Flatten following the structure's (sub-pulse, parameter) order.
    values = [
        params[pulse_idx][name]
        for pulse_idx, names in enumerate(parameter_structure)
        for name in names
    ]
    return jnp.array(values)

get_param_array_converter

get_param_array_converter(
    control_sequence: ControlSequence,
)

This function returns two functions that can convert between a list of parameter dictionaries and a flat array.

array_to_list_of_params_fn, list_of_params_to_array_fn = get_param_array_converter(control_sequence)

Args:
control_sequence (ControlSequence): The pulse sequence object.

Returns:

Type Description

typing.Any: A tuple containing two functions. The first function converts an array to a list of parameter dictionaries, and the second function converts a list of parameter dictionaries to an array.

Source code in src/inspeqtor/v1/control.py
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
def get_param_array_converter(control_sequence: ControlSequence):
    """Build a pair of converters between flat arrays and parameter-dict lists.

    >>> array_to_list_of_params_fn, list_of_params_to_array_fn = get_param_array_converter(control_sequence)

        Args:
        control_sequence (ControlSequence): The pulse sequence object.

    Returns:
        typing.Any: A tuple containing two functions. The first converts an array
        to a list of parameter dictionaries; the second converts a list of
        parameter dictionaries to an array.
    """
    # The structure is captured once; both closures reuse it.
    names = control_sequence.get_parameter_names()

    def to_params(flat: jnp.ndarray) -> list[ParametersDictType]:
        return array_to_list_of_params(flat, names)

    def to_array(plist: list[ParametersDictType]) -> jnp.ndarray:
        return list_of_params_to_array(plist, names)

    return to_params, to_array

construct_control_sequence_reader

construct_control_sequence_reader(
    controls: list[type[BaseControl]] = [],
) -> Callable[[Union[str, Path]], ControlSequence]

Construct the control sequence reader

Parameters:

Name Type Description Default
controls list[type[BasePulse]]

List of control constructor. Defaults to [].

[]

Returns:

Type Description
Callable[[Union[str, Path]], ControlSequence]

typing.Callable[[typing.Union[str, pathlib.Path]], ControlSequence]: Control sequence reader that will automatically construct a control sequence from path.

Source code in src/inspeqtor/v1/control.py
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
def construct_control_sequence_reader(
    controls: list[type[BaseControl]] | None = None,
) -> typing.Callable[[typing.Union[str, pathlib.Path]], ControlSequence]:
    """Construct the control sequence reader

    Args:
        controls (list[type[BaseControl]] | None, optional): List of control
            constructors. Defaults to None (no extra controls).

    Returns:
        typing.Callable[[typing.Union[str, pathlib.Path]], ControlSequence]: Control sequence reader that will automatically construct the control sequence from path.
    """
    # Avoid the mutable-default-argument pitfall; copy so the caller's list
    # is never aliased.
    controls = [] if controls is None else list(controls)

    default_controls: list[type[BaseControl]] = []

    # Merge the default controls with the provided controls
    controls_list = default_controls + controls

    def control_sequence_reader(
        path: typing.Union[str, pathlib.Path],
    ) -> ControlSequence:
        """Construct control sequence from path

        Args:
            path (typing.Union[str, pathlib.Path]): Path of the saved control sequence configuration.

        Returns:
            ControlSequence: Control sequence instance.

        Raises:
            ValueError: A saved control references a class name not provided
                in `controls`.
        """
        if isinstance(path, str):
            path = pathlib.Path(path)

        with open(path / "control_sequence.json", "r") as f:
            control_sequence_dict = json.load(f)

        parsed_controls = []

        for pulse_dict in control_sequence_dict["controls"]:
            for control_class in controls_list:
                if pulse_dict["_name"] == control_class.__name__:
                    parsed_controls.append(control_class)
                    break  # first match wins; avoid duplicate appends
            else:
                # Previously unknown names were silently skipped, which
                # misaligned the zip inside ControlSequence.from_dict;
                # fail loudly instead.
                raise ValueError(
                    f"Unknown control class {pulse_dict['_name']!r}; "
                    "pass its constructor in `controls`."
                )

        return ControlSequence.from_dict(
            control_sequence_dict, controls=parsed_controls
        )

    return control_sequence_reader

get_envelope_transformer

get_envelope_transformer(control_sequence: ControlSequence)

Generate get_envelope function with control parameter array as an input instead of list form

Parameters:

Name Type Description Default
control_sequence ControlSequence

Control sequence instance

required

Returns:

Type Description

typing.Callable[[jnp.ndarray], typing.Any]: Transformed get envelope function

Source code in src/inspeqtor/v1/control.py
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
def get_envelope_transformer(control_sequence: ControlSequence):
    """Wrap `get_envelope` so it accepts a flat control-parameter array.

    Args:
        control_sequence (ControlSequence): Control sequence instance

    Returns:
        typing.Callable[[jnp.ndarray], typing.Any]: Transformed get envelope function
    """
    # Capture the parameter layout once for reuse by the closure.
    names = control_sequence.get_parameter_names()

    def get_envelope(params: jnp.ndarray) -> typing.Callable[..., typing.Any]:
        structured = array_to_list_of_params(params, names)
        return control_sequence.get_envelope(structured)

    return get_envelope

Data

src.inspeqtor.v1.data

Operator dataclass

Dataclass for accessing qubit operators. Support X, Y, Z, Hadamard, S, Sdg, and I gate.

Raises:

Type Description
ValueError

Provided operator is not supported

Source code in src/inspeqtor/v1/data.py
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
@dataclass
class Operator:
    """Dataclass for accessing qubit operators. Supports X, Y, Z, Hadamard, S, Sdg, and I gates.

    Raises:
        ValueError: Provided operator is not supported

    """

    _pauli_x = jnp.array([[0, 1], [1, 0]], dtype=jnp.complex64)
    _pauli_y = jnp.array([[0, -1j], [1j, 0]], dtype=jnp.complex64)
    _pauli_z = jnp.array([[1, 0], [0, -1]], dtype=jnp.complex64)
    _hadamard = jnp.array([[1, 1], [1, -1]], dtype=jnp.complex64) / jnp.sqrt(2)
    _s_gate = jnp.array([[1, 0], [0, 1j]], dtype=jnp.complex64)
    _sdg_gate = jnp.array([[1, 0], [0, -1j]], dtype=jnp.complex64)
    _identity = jnp.array([[1, 0], [0, 1]], dtype=jnp.complex64)

    @classmethod
    def from_label(cls, op: str) -> jnp.ndarray:
        """Initialize the operator from the label

        Args:
            op (str): The label of the operator

        Raises:
            ValueError: Operator not supported

        Returns:
            jnp.ndarray: The operator
        """
        # Dispatch table replaces the original if/elif chain.
        table = {
            "X": cls._pauli_x,
            "Y": cls._pauli_y,
            "Z": cls._pauli_z,
            "H": cls._hadamard,
            "S": cls._s_gate,
            "Sdg": cls._sdg_gate,
            "I": cls._identity,
        }
        try:
            return table[op]
        except KeyError:
            raise ValueError(f"Operator {op} is not supported") from None

    @classmethod
    def to_qutrit(cls, op: jnp.ndarray, value: float = 1.0) -> jnp.ndarray:
        """Add extra dimension to the operator

        Args:
            op (jnp.ndarray): Qubit operator
            value (float, optional): Value to be add at the extra dimension diagonal entry. Defaults to 1.0.

        Returns:
            jnp.ndarray: New operator for qutrit space.
        """
        return add_hilbert_level(op, x=jnp.array([value]))

from_label classmethod

from_label(op: str) -> ndarray

Initialize the operator from the label

Parameters:

Name Type Description Default
op str

The label of the operator

required

Raises:

Type Description
ValueError

Operator not supported

Returns:

Type Description
ndarray

jnp.ndarray: The operator

Source code in src/inspeqtor/v1/data.py
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
@classmethod
def from_label(cls, op: str) -> jnp.ndarray:
    """Initialize the operator from the label

    Args:
        op (str): The label of the operator

    Raises:
        ValueError: Operator not supported

    Returns:
        jnp.ndarray: The operator
    """

    if op == "X":
        operator = cls._pauli_x
    elif op == "Y":
        operator = cls._pauli_y
    elif op == "Z":
        operator = cls._pauli_z
    elif op == "H":
        operator = cls._hadamard
    elif op == "S":
        operator = cls._s_gate
    elif op == "Sdg":
        operator = cls._sdg_gate
    elif op == "I":
        operator = cls._identity
    else:
        raise ValueError(f"Operator {op} is not supported")

    return operator

to_qutrit classmethod

to_qutrit(op: ndarray, value: float = 1.0) -> ndarray

Add extra dimension to the operator

Parameters:

Name Type Description Default
op ndarray

Qubit operator

required
value float

Value to be added at the extra dimension's diagonal entry. Defaults to 1.0.

1.0

Returns:

Type Description
ndarray

jnp.ndarray: New operator for qutrit space.

Source code in src/inspeqtor/v1/data.py
88
89
90
91
92
93
94
95
96
97
98
99
@classmethod
def to_qutrit(cls, op: jnp.ndarray, value: float = 1.0) -> jnp.ndarray:
    """Add extra dimension to the operator

    Args:
        op (jnp.ndarray): Qubit operator
        value (float, optional): Value to be add at the extra dimension diagonal entry. Defaults to 1.0.

    Returns:
        jnp.ndarray: New operator for qutrit space.
    """
    return add_hilbert_level(op, x=jnp.array([value]))

State dataclass

Dataclass for accessing the eigenvector corresponding to an eigenvalue of the Pauli operators X, Y, and Z.

Raises:

Type Description
ValueError

Provided state is not supported

ValueError

Provided state is not qubit

Source code in src/inspeqtor/v1/data.py
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
@dataclass
class State:
    """Dataclass for accessing eigenvector corresponded to eigenvalue of Pauli operator X, Y, and Z.

    Raises:
        ValueError: Provided state is not supported
        ValueError: Provided state is not qubit
    """

    _zero = jnp.array([1, 0], dtype=jnp.complex64)
    _one = jnp.array([0, 1], dtype=jnp.complex64)
    _plus = jnp.array([1, 1], dtype=jnp.complex64) / jnp.sqrt(2)
    _minus = jnp.array([1, -1], dtype=jnp.complex64) / jnp.sqrt(2)
    _right = jnp.array([1, 1j], dtype=jnp.complex64) / jnp.sqrt(2)
    _left = jnp.array([1, -1j], dtype=jnp.complex64) / jnp.sqrt(2)

    @classmethod
    def from_label(cls, state: str, dm: bool = False) -> jnp.ndarray:
        """Initialize the state from the label

        Args:
            state (str): The label of the state
            dm (bool, optional): Initialized as statevector or density matrix. Defaults to False.

        Raises:
            ValueError: State not supported

        Returns:
            jnp.ndarray: The state
        """

        if state in ["0", "Z+"]:
            state_vec = cls._zero
        elif state in ["1", "Z-"]:
            state_vec = cls._one
        elif state in ["+", "X+"]:
            state_vec = cls._plus
        elif state in ["-", "X-"]:
            state_vec = cls._minus
        elif state in ["r", "Y+"]:
            state_vec = cls._right
        elif state in ["l", "Y-"]:
            state_vec = cls._left
        else:
            raise ValueError(f"State {state} is not supported")

        state_vec = state_vec.reshape(2, 1)

        return state_vec if not dm else jnp.outer(state_vec, state_vec.conj())

    @classmethod
    def to_qutrit(cls, state: jnp.ndarray) -> jnp.ndarray:
        """Promote qubit state to qutrit with zero probability

        Args:
            state (jnp.ndarray): Density matrix of 2 x 2 qubit state.

        Raises:
            ValueError: Provided state is not qubit

        Returns:
            jnp.ndarray: Qutrit density matrix
        """
        if state.shape != (2, 2):
            raise ValueError("Shape of the state is not as expected, expect (2, 2)")

        return add_hilbert_level(state, x=jnp.array([0.0]))

from_label classmethod

from_label(state: str, dm: bool = False) -> ndarray

Initialize the state from the label

Parameters:

Name Type Description Default
state str

The label of the state

required
dm bool

Initialized as statevector or density matrix. Defaults to False.

False

Raises:

Type Description
ValueError

State not supported

Returns:

Type Description
ndarray

jnp.ndarray: The state

Source code in src/inspeqtor/v1/data.py
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
@classmethod
def from_label(cls, state: str, dm: bool = False) -> jnp.ndarray:
    """Initialize the state from the label

    Args:
        state (str): The label of the state
        dm (bool, optional): Initialized as statevector or density matrix. Defaults to False.

    Raises:
        ValueError: State not supported

    Returns:
        jnp.ndarray: The state
    """

    if state in ["0", "Z+"]:
        state_vec = cls._zero
    elif state in ["1", "Z-"]:
        state_vec = cls._one
    elif state in ["+", "X+"]:
        state_vec = cls._plus
    elif state in ["-", "X-"]:
        state_vec = cls._minus
    elif state in ["r", "Y+"]:
        state_vec = cls._right
    elif state in ["l", "Y-"]:
        state_vec = cls._left
    else:
        raise ValueError(f"State {state} is not supported")

    state_vec = state_vec.reshape(2, 1)

    return state_vec if not dm else jnp.outer(state_vec, state_vec.conj())

to_qutrit classmethod

to_qutrit(state: ndarray) -> ndarray

Promote qubit state to qutrit with zero probability

Parameters:

Name Type Description Default
state ndarray

Density matrix of 2 x 2 qubit state.

required

Raises:

Type Description
ValueError

Provided state is not qubit

Returns:

Type Description
ndarray

jnp.ndarray: Qutrit density matrix

Source code in src/inspeqtor/v1/data.py
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
@classmethod
def to_qutrit(cls, state: jnp.ndarray) -> jnp.ndarray:
    """Promote qubit state to qutrit with zero probability

    Args:
        state (jnp.ndarray): Density matrix of 2 x 2 qubit state.

    Raises:
        ValueError: Provided state is not qubit

    Returns:
        jnp.ndarray: Qutrit density matrix
    """
    if state.shape != (2, 2):
        raise ValueError("Shape of the state is not as expected, expect (2, 2)")

    return add_hilbert_level(state, x=jnp.array([0.0]))

QubitInformation dataclass

Dataclass to store qubit information

Parameters:

Name Type Description Default
unit str

The string representation of the unit; currently supports "GHz", "2piGHz", "2piHz", or "Hz".

required
qubit_idx int

the index of the qubit.

required
anharmonicity float

Anharmonicity of the qubit, kept for the sake of completeness.

required
frequency float

Qubit frequency.

required
drive_strength float

Drive strength of qubit, might be specific for IBMQ platform.

required

Raises:

Type Description
ValueError

Fail to convert unit to GHz

Source code in src/inspeqtor/v1/data.py
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
@dataclass
class QubitInformation:
    """Dataclass to store qubit information

    Args:
        unit (str): The string representation of unit, currently support "GHz", "2piGHz", "2piHz", or "Hz".
        qubit_idx (int): the index of the qubit.
        anharmonicity (float): Anhamonicity of the qubit, kept for the sake of completeness.
        frequency (float): Qubit frequency.
        drive_strength (float): Drive strength of qubit, might be specific for IBMQ platform.

    Raises:
        ValueError: Fail to convert unit to GHz
    """

    unit: str
    qubit_idx: int
    anharmonicity: float
    frequency: float
    drive_strength: float
    date: str = datetime.now().strftime("%Y-%m-%d %H:%M:%S")

    def __post_init__(self):
        self.convert_unit_to_ghz()

    def convert_unit_to_ghz(self):
        """Convert the unit of data stored in self to unit of GHz

        Raises:
            ValueError: Data stored in the unsupported unit
        """
        if self.unit == "GHz":
            pass
        elif self.unit == "Hz":
            self.anharmonicity = self.anharmonicity * 1e-9
            self.frequency = self.frequency * 1e-9
            self.drive_strength = self.drive_strength * 1e-9
        elif self.unit == "2piGHz":
            self.anharmonicity = self.anharmonicity / (2 * jnp.pi)
            self.frequency = self.frequency / (2 * jnp.pi)
            self.drive_strength = self.drive_strength / (2 * jnp.pi)
        elif self.unit == "2piHz":
            self.anharmonicity = self.anharmonicity / (2 * jnp.pi) * 1e-9
            self.frequency = self.frequency / (2 * jnp.pi) * 1e-9
            self.drive_strength = self.drive_strength / (2 * jnp.pi) * 1e-9
        else:
            raise ValueError("Unit must be GHz, 2piGHz, 2piHz, or Hz")

        # Set unit to GHz
        self.unit = "GHz"

    def to_dict(self):
        return asdict(self)

    @classmethod
    def from_dict(cls, dict_qubit_info: dict):
        return cls(**dict_qubit_info)

convert_unit_to_ghz

convert_unit_to_ghz()

Convert the unit of data stored in self to unit of GHz

Raises:

Type Description
ValueError

Data stored in the unsupported unit

Source code in src/inspeqtor/v1/data.py
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
def convert_unit_to_ghz(self):
    """Convert the unit of data stored in self to unit of GHz

    Raises:
        ValueError: Data stored in the unsupported unit
    """
    if self.unit == "GHz":
        pass
    elif self.unit == "Hz":
        self.anharmonicity = self.anharmonicity * 1e-9
        self.frequency = self.frequency * 1e-9
        self.drive_strength = self.drive_strength * 1e-9
    elif self.unit == "2piGHz":
        self.anharmonicity = self.anharmonicity / (2 * jnp.pi)
        self.frequency = self.frequency / (2 * jnp.pi)
        self.drive_strength = self.drive_strength / (2 * jnp.pi)
    elif self.unit == "2piHz":
        self.anharmonicity = self.anharmonicity / (2 * jnp.pi) * 1e-9
        self.frequency = self.frequency / (2 * jnp.pi) * 1e-9
        self.drive_strength = self.drive_strength / (2 * jnp.pi) * 1e-9
    else:
        raise ValueError("Unit must be GHz, 2piGHz, 2piHz, or Hz")

    # Set unit to GHz
    self.unit = "GHz"

ExpectationValue dataclass

Dataclass to store expectation value information

Parameters:

Name Type Description Default
initial_state str

String representation of the initial state. Currently supports "+", "-", "r", "l", "0", "1".

required
observable str

String representation of the quantum observable. Currently supports "X", "Y", "Z".

required
expectation_value None | float

The expectation value. Defaults to None.

None

Raises:

Type Description
ValueError

Not support initial state

ValueError

Not support observable

ValueError

Not support initial state

ValueError

Not support observable

Source code in src/inspeqtor/v1/data.py
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
@dataclass
class ExpectationValue:
    """Dataclass to store expectation value information

    Args:
        initial_state (str): String representation of inital state. Currently support "+", "-", "r", "l", "0", "1".
        observable (str): String representation of quantum observable.  Currently support "X", "Y", "Z".
        expectation_value (None | float): the expectation value. Default to None

    Raises:
        ValueError: Not support initial state
        ValueError: Not support observable
        ValueError: Not support initial state
        ValueError: Not support observable
    """

    initial_state: str
    observable: str
    expectation_value: None | float = None

    # Not serialized
    initial_statevector: jnp.ndarray = field(init=False)
    initial_density_matrix: jnp.ndarray = field(init=False)
    observable_matrix: jnp.ndarray = field(init=False)

    def __post_init__(self):
        if self.initial_state not in ["+", "-", "r", "l", "0", "1"]:
            raise ValueError(f"Initial state {self.initial_state} is not supported")
        if self.observable not in ["X", "Y", "Z"]:
            raise ValueError(f"Observable {self.observable} is not supported")

        self.initial_statevector = State.from_label(self.initial_state)
        self.initial_density_matrix = State.from_label(self.initial_state, dm=True)
        self.observable_matrix = Operator.from_label(self.observable)

    def to_dict(self):
        return {
            "initial_state": self.initial_state,
            "observable": self.observable,
            "expectation_value": self.expectation_value,
        }

    def __eq__(self, __value: object) -> bool:
        if not isinstance(__value, ExpectationValue):
            return False

        return (
            self.initial_state == __value.initial_state
            and self.observable == __value.observable
            and self.expectation_value == __value.expectation_value
        )

    def __str__(self):
        return f"{self.initial_state}/{self.observable} = {self.expectation_value}"

    # Overwrite the __repr__ method of the class
    def __repr__(self):
        return f'{self.__class__.__name__}(initial_state="{self.initial_state}", observable="{self.observable}", expectation_value={self.expectation_value})'

    @classmethod
    def from_dict(cls, data):
        return cls(**data)

ExperimentConfiguration dataclass

Experiment configuration dataclass

Source code in src/inspeqtor/v1/data.py
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
@dataclass
class ExperimentConfiguration:
    """Experiment configuration dataclass"""

    qubits: typing.Sequence[QubitInformation]
    expectation_values_order: typing.Sequence[ExpectationValue]
    parameter_names: typing.Sequence[
        typing.Sequence[str]
    ]  # Get from the pulse sequence .get_parameter_names()
    backend_name: str
    shots: int
    EXPERIMENT_IDENTIFIER: str
    EXPERIMENT_TAGS: typing.Sequence[str]
    description: str
    device_cycle_time_ns: float
    sequence_duration_dt: int
    instance: str
    sample_size: int
    date: str = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    additional_info: dict[str, str | int | float] = field(default_factory=dict)

    def to_dict(self):
        return {
            **asdict(self),
            "qubits": [qubit.to_dict() for qubit in self.qubits],
            "expectation_values_order": [
                exp.to_dict() for exp in self.expectation_values_order
            ],
        }

    @classmethod
    def from_dict(cls, dict_experiment_config):
        dict_experiment_config["qubits"] = [
            QubitInformation.from_dict(qubit)
            for qubit in dict_experiment_config["qubits"]
        ]

        dict_experiment_config["expectation_values_order"] = [
            ExpectationValue.from_dict(exp)
            for exp in dict_experiment_config["expectation_values_order"]
        ]

        return cls(**dict_experiment_config)

    def to_file(self, path: typing.Union[Path, str]):
        if isinstance(path, str):
            path = Path(path)

        # os.makedirs(path, exist_ok=True)
        path.mkdir(parents=True, exist_ok=True)
        with open(path / "config.json", "w") as f:
            json.dump(self.to_dict(), f, indent=4)

    @classmethod
    def from_file(cls, path: typing.Union[Path, str]):
        if isinstance(path, str):
            path = Path(path)
        with open(path / "config.json", "r") as f:
            dict_experiment_config = json.load(f)

        return cls.from_dict(dict_experiment_config)

ExperimentData dataclass

Dataclass for processing of the characterization dataset. The difference between the preprocess and postprocess datasets is that the postprocess dataset groups the expectation values sharing the same control parameter id into a single row instead of multiple rows.

Parameters:

Name Type Description Default
experiment_config ExperimentConfiguration

Experiment configuration

required
preprocess_data DataFrame

Pandas dataframe containing the preprocess dataset

required
_postprocessed_data DataFrame | None

(pd.DataFrame): Provide this optional argument to skip dataset postprocessing.

None
keep_decimal int

the precision of floating point to keep.

10
Source code in src/inspeqtor/v1/data.py
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
@dataclass
class ExperimentData:
    """Dataclass for processing of the characterization dataset.
    A difference between preprocess and postprocess dataset is that postprocess group
    expectation values same control parameter id within single row instead of multiple rows.

    Args:
        experiment_config (ExperimentConfiguration): Experiment configuration
        preprocess_data (pd.DataFrame): Pandas dataframe containing the preprocess dataset
        _postprocessed_data: (pd.DataFrame): Provide this optional argument to skip dataset postprocessing.
        keep_decimal (int): the precision of floating point to keep.
    """

    experiment_config: ExperimentConfiguration
    preprocess_data: pd.DataFrame
    # optional
    _postprocessed_data: pd.DataFrame | None = field(default=None)

    # Setting
    keep_decimal: int = 10
    # Postprocessing
    postprocessed_data: pd.DataFrame = field(init=False)
    parameter_columns: list[str] = field(init=False)
    parameters: np.ndarray = field(init=False)

    def __post_init__(self):
        self.preprocess_data = self.preprocess_data.round(self.keep_decimal)

        # Validate that self.preprocess_data have all the required columns
        self.validate_preprocess_data()
        logging.info("Preprocess data validated")

        if self._postprocessed_data is not None:
            self.postprocessed_data = self._postprocessed_data
            logging.info("Postprocess data set")

        else:
            post_data = self.transform_preprocess_data_to_postprocess_data()
            logging.info("Preprocess data transformed to postprocess data")

            self.postprocessed_data = post_data.round(self.keep_decimal)

        # Validate the data with schema
        self.validate_postprocess_data(self.postprocessed_data)
        logging.info("Postprocess data validated")

        self.parameter_columns = flatten_parameter_name_with_prefix(
            self.experiment_config.parameter_names
        )
        num_features = len(self.experiment_config.parameter_names[0])
        num_controls = len(self.experiment_config.parameter_names)

        try:
            temp_params = np.array(
                self.postprocessed_data[self.parameter_columns]
                .to_numpy()
                .reshape(
                    (self.experiment_config.sample_size, num_controls, num_features)
                )
            )
        except Exception:
            logging.info(
                "Could not reshape parameters with shape (sample_size, num_controls, num_features), automatically reshaping to (sample_size, -1)"
            )
            temp_params = np.array(
                self.postprocessed_data[self.parameter_columns]
                .to_numpy()
                .reshape((self.experiment_config.sample_size, -1))
            )
            logging.info(f"Parameters reshaped to {temp_params.shape}")

        self.parameters = temp_params

        logging.info("Parameters converted to numpy array")

        assert (
            self.preprocess_data[self.parameter_columns]
            .drop_duplicates(ignore_index=True)
            .equals(
                self.postprocessed_data[self.parameter_columns].drop_duplicates(
                    ignore_index=True
                )
            )
        ), (
            "The preprocess_data and postprocessed_data does not have the same parameters."
        )
        logging.info("Preprocess data and postprocess data have the same parameters")

    def __eq__(self, __value: object) -> bool:
        if not isinstance(__value, ExperimentData):
            return False

        return (
            self.experiment_config == __value.experiment_config
            and self.preprocess_data.equals(__value.preprocess_data)
        )

    def validate_preprocess_data(self):
        """Validate that the preprocess_data have all the required columns.

        Required columns:
            - EXPECTATION_VALUE
            - INITIAL_STATE
            - OBSERVABLE
            - PARAMETERS_ID
        """
        for col in REQUIRED_COLUMNS:
            if col.required:
                assert col.name in self.preprocess_data.columns, (
                    f"Column {col.name} is required but not found in the preprocess_data."
                )

        # Validate that the preprocess_data have all expected parameters columns
        required_parameters_columns = flatten_parameter_name_with_prefix(
            self.experiment_config.parameter_names
        )

        for _col in required_parameters_columns:
            assert _col in self.preprocess_data.columns, (
                f"Column {_col} is required but not found in the preprocess_data."
            )

    def validate_postprocess_data(self, post_data: pd.DataFrame):
        """Validate postprocess dataset, by check the requirements given by `PredefinedCol` instance of each column
        that required in the postprocessed dataset.

        Args:
            post_data (pd.DataFrame): Postprocessed dataset to be validated.
        """
        logging.info("Validating postprocess data")
        # Validate that the postprocess_data have all the required columns
        for col in REQUIRED_COLUMNS:
            if col.required:
                assert col.name in post_data.columns, (
                    f"Column {col.name} is required but not found in the postprocess_data."
                )

        # Validate the check functions
        for col in REQUIRED_COLUMNS:
            for check in col.checks:
                assert all([check(v) for v in post_data[col.name]]), (
                    f"Column {col.name} failed the check function {check}"
                )

        # Validate that the postprocess_data have all expected parameters columns
        required_parameters_columns = flatten_parameter_name_with_prefix(
            self.experiment_config.parameter_names
        )
        for _col in required_parameters_columns:
            assert _col in post_data.columns, (
                f"Column {_col} is required but not found in the postprocess_data."
            )

    def transform_preprocess_data_to_postprocess_data(self) -> pd.DataFrame:
        """Internal method to post process the dataset.

        Todo:
            Use new experimental implementation from_long to wide dataframe

        Raises:
            ValueError: There is duplicate entry of the expectation value.

        Returns:
            pd.DataFrame: Postprocessed experiment dataset.
        """
        # Postprocess the data squeezing the data into the expectation values
        # Required columns: PARAMETERS_ID, OBSERVABLE, INITIAL_STATE, EXPECTATION_VALUE, + experiment_config.parameter_names
        post_data = []

        for params_id in range(self.experiment_config.sample_size):
            # NOTE: Assume that parameters_id starts from 0 and is continuous to sample_size - 1
            rows = self.preprocess_data.loc[
                self.preprocess_data[PARAMETERS_ID.name] == params_id
            ]

            expectation_values = {}
            for _, exp_order in enumerate(
                self.experiment_config.expectation_values_order
            ):
                expectation_value = rows.loc[
                    (rows[OBSERVABLE.name] == exp_order.observable)
                    & (rows[INITIAL_STATE.name] == exp_order.initial_state)
                ][EXPECTATION_VALUE.name].values

                if expectation_value.shape[0] != 1:
                    raise ValueError(
                        f"Expectation value for params_id {params_id}, initial_state {exp_order.initial_state}, observable {exp_order.observable} is not unique. The length is {len(expectation_value)}."
                    )

                expectation_values[
                    f"{EXPECTATION_VALUE.name}/{exp_order.initial_state}/{exp_order.observable}"
                ] = expectation_value[0]

            drop_duplicates_row = rows.drop_duplicates(
                subset=flatten_parameter_name_with_prefix(
                    self.experiment_config.parameter_names
                )
            )
            # Assert that only one row is returned
            assert drop_duplicates_row.shape[0] == 1
            pulse_parameters = drop_duplicates_row.to_dict(orient="records")[0]

            new_row = {
                PARAMETERS_ID.name: params_id,
                **expectation_values,
                **{str(k): v for k, v in pulse_parameters.items()},
            }

            post_data.append(new_row)

        return pd.DataFrame(post_data)

    def get_parameters_dataframe(self) -> pd.DataFrame:
        """Get dataframe with only the columns of control parameters.

        Returns:
            pd.DataFrame: Dataframe with only the columns of control parameters.
        """
        return self.postprocessed_data[self.parameter_columns]

    def get_expectation_values(self) -> np.ndarray:
        """Get the expectation value of the shape (sample_size, num_expectation_value)

        Returns:
            np.ndarray: expectation value of the shape (sample_size, num_expectation_value)
        """
        expectation_value = self.postprocessed_data[
            [
                f"expectation_value/{col.initial_state}/{col.observable}"
                for col in self.experiment_config.expectation_values_order
            ]
        ].to_numpy()

        return np.array(expectation_value)

    def get_parameters_dict_list(self) -> list[list[ParametersDictType]]:
        """Get the list, where each element is list of dict of the control parameters of the dataset.

        Returns:
            list[list[ParametersDictType]]: The list of list of dict of parameter.
        """
        _temp = self.postprocessed_data[self.parameter_columns]

        _params_list = [
            get_parameters_dict_list(self.experiment_config.parameter_names, row)
            for _, row in _temp.iterrows()
        ]

        return _params_list

    def save_to_folder(self, path: typing.Union[Path, str]):
        """Save the experiment data to given folder

        Args:
            path (typing.Union[Path, str]): Path of the folder for experiment data to be saved.
        """
        if isinstance(path, str):
            path = Path(path)

        # os.makedirs(path, exist_ok=True)
        path.mkdir(parents=True, exist_ok=True)
        self.experiment_config.to_file(path)
        self.preprocess_data.to_csv(path / "preprocess_data.csv", index=False)
        self.postprocessed_data.to_csv(path / "postprocessed_data.csv", index=False)

    @classmethod
    def from_folder(cls, path: typing.Union[Path, str]) -> "ExperimentData":
        """Read the experiment data from path

        Args:
            path (typing.Union[Path, str]): path to the folder contain experiment data. Expected to be used with `self.save_to_folder` method.

        Returns:
            ExperimentData: Intance of `ExperimentData` read from path.
        """
        if isinstance(path, str):
            path = Path(path)

        experiment_config = ExperimentConfiguration.from_file(path)
        preprocess_data = pd.read_csv(
            path / "preprocess_data.csv",
        )

        # Check if postprocessed_data exists
        if not (path / "postprocessed_data.csv").exists():
            # if not os.path.exists(path / "postprocessed_data.csv"):
            postprocessed_data = None
        else:
            postprocessed_data = pd.read_csv(
                path / "postprocessed_data.csv",
            )

        return cls(
            experiment_config=experiment_config,
            preprocess_data=preprocess_data,
            _postprocessed_data=postprocessed_data,
        )

    def analysis_sum_of_expectation_values(self) -> pd.DataFrame:
        paulis = ["X", "Y", "Z"]
        initial_states = [("0", "1"), ("+", "-"), ("r", "l")]
        data = {}
        for pauli in paulis:
            for initial_state in initial_states:
                _name = f"{pauli}/{initial_state[0]}/{initial_state[1]}"

                res = (
                    self.postprocessed_data[
                        f"expectation_value/{initial_state[0]}/{pauli}"
                    ]
                    + self.postprocessed_data[
                        f"expectation_value/{initial_state[1]}/{pauli}"
                    ]
                )

                data[_name] = res.to_numpy()

        return pd.DataFrame(data)

validate_preprocess_data

validate_preprocess_data()

Validate that the preprocess_data has all the required columns.

Required columns
  • EXPECTATION_VALUE
  • INITIAL_STATE
  • OBSERVABLE
  • PARAMETERS_ID
Source code in src/inspeqtor/v1/data.py
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
def validate_preprocess_data(self):
    """Check that ``preprocess_data`` carries every mandatory column.

    Required columns:
        - EXPECTATION_VALUE
        - INITIAL_STATE
        - OBSERVABLE
        - PARAMETERS_ID

    Also verifies that every flattened control-parameter column declared by
    the experiment configuration is present.
    """
    available = self.preprocess_data.columns

    # Mandatory predefined columns.
    for required_col in (c for c in REQUIRED_COLUMNS if c.required):
        assert required_col.name in available, (
            f"Column {required_col.name} is required but not found in the preprocess_data."
        )

    # Flattened parameter columns derived from the experiment configuration.
    expected_parameter_cols = flatten_parameter_name_with_prefix(
        self.experiment_config.parameter_names
    )
    for parameter_col in expected_parameter_cols:
        assert parameter_col in available, (
            f"Column {parameter_col} is required but not found in the preprocess_data."
        )

validate_postprocess_data

validate_postprocess_data(post_data: DataFrame)

Validate postprocess dataset, by check the requirements given by PredefinedCol instance of each column that required in the postprocessed dataset.

Parameters:

Name Type Description Default
post_data DataFrame

Postprocessed dataset to be validated.

required
Source code in src/inspeqtor/v1/data.py
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
def validate_postprocess_data(self, post_data: pd.DataFrame):
    """Validate a postprocessed dataset against each `PredefinedCol` requirement.

    Args:
        post_data (pd.DataFrame): Postprocessed dataset to be validated.
    """
    logging.info("Validating postprocess data")

    # Every mandatory predefined column must be present.
    for predefined in REQUIRED_COLUMNS:
        if predefined.required:
            assert predefined.name in post_data.columns, (
                f"Column {predefined.name} is required but not found in the postprocess_data."
            )

    # Every per-value check function attached to a predefined column must hold.
    for predefined in REQUIRED_COLUMNS:
        for check in predefined.checks:
            assert all(check(value) for value in post_data[predefined.name]), (
                f"Column {predefined.name} failed the check function {check}"
            )

    # All flattened control-parameter columns must be present as well.
    for parameter_col in flatten_parameter_name_with_prefix(
        self.experiment_config.parameter_names
    ):
        assert parameter_col in post_data.columns, (
            f"Column {parameter_col} is required but not found in the postprocess_data."
        )

transform_preprocess_data_to_postprocess_data

transform_preprocess_data_to_postprocess_data() -> (
    DataFrame
)

Internal method to post process the dataset.

Todo

Use new experimental implementation from_long to wide dataframe

Raises:

Type Description
ValueError

There is duplicate entry of the expectation value.

Returns:

Type Description
DataFrame

pd.DataFrame: Postprocessed experiment dataset.

Source code in src/inspeqtor/v1/data.py
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
def transform_preprocess_data_to_postprocess_data(self) -> pd.DataFrame:
    """Internal method to post process the dataset.

    Squeezes the long-format preprocess data (one row per
    (parameters_id, initial_state, observable) triple) into a wide format
    with one row per parameters_id and one column per expectation value.

    Todo:
        Use new experimental implementation from_long to wide dataframe

    Raises:
        ValueError: There is duplicate entry of the expectation value.

    Returns:
        pd.DataFrame: Postprocessed experiment dataset.
    """
    # Postprocess the data squeezing the data into the expectation values
    # Required columns: PARAMETERS_ID, OBSERVABLE, INITIAL_STATE, EXPECTATION_VALUE, + experiment_config.parameter_names
    post_data = []

    for params_id in range(self.experiment_config.sample_size):
        # NOTE: Assume that parameters_id starts from 0 and is continuous to sample_size - 1
        rows = self.preprocess_data.loc[
            self.preprocess_data[PARAMETERS_ID.name] == params_id
        ]

        # Collect one wide column per (initial_state, observable) pair, in the
        # order declared by the experiment configuration.
        expectation_values = {}
        for _, exp_order in enumerate(
            self.experiment_config.expectation_values_order
        ):
            expectation_value = rows.loc[
                (rows[OBSERVABLE.name] == exp_order.observable)
                & (rows[INITIAL_STATE.name] == exp_order.initial_state)
            ][EXPECTATION_VALUE.name].values

            # Exactly one measurement is expected per triple; anything else
            # indicates duplicate or missing entries in the preprocess data.
            if expectation_value.shape[0] != 1:
                raise ValueError(
                    f"Expectation value for params_id {params_id}, initial_state {exp_order.initial_state}, observable {exp_order.observable} is not unique. The length is {len(expectation_value)}."
                )

            expectation_values[
                f"{EXPECTATION_VALUE.name}/{exp_order.initial_state}/{exp_order.observable}"
            ] = expectation_value[0]

        # All rows of one params_id share the same control parameters, so
        # deduplicating on the parameter columns must leave a single row.
        drop_duplicates_row = rows.drop_duplicates(
            subset=flatten_parameter_name_with_prefix(
                self.experiment_config.parameter_names
            )
        )
        # Assert that only one row is returned
        assert drop_duplicates_row.shape[0] == 1
        pulse_parameters = drop_duplicates_row.to_dict(orient="records")[0]

        # One wide row: id + expectation-value columns + parameter columns.
        new_row = {
            PARAMETERS_ID.name: params_id,
            **expectation_values,
            **{str(k): v for k, v in pulse_parameters.items()},
        }

        post_data.append(new_row)

    return pd.DataFrame(post_data)

get_parameters_dataframe

get_parameters_dataframe() -> DataFrame

Get dataframe with only the columns of control parameters.

Returns:

Type Description
DataFrame

pd.DataFrame: Dataframe with only the columns of control parameters.

Source code in src/inspeqtor/v1/data.py
709
710
711
712
713
714
715
def get_parameters_dataframe(self) -> pd.DataFrame:
    """Return the postprocessed dataset restricted to control-parameter columns.

    Returns:
        pd.DataFrame: Dataframe with only the columns of control parameters.
    """
    selected_columns = self.parameter_columns
    return self.postprocessed_data[selected_columns]

get_expectation_values

get_expectation_values() -> ndarray

Get the expectation value of the shape (sample_size, num_expectation_value)

Returns:

Type Description
ndarray

np.ndarray: expectation value of the shape (sample_size, num_expectation_value)

Source code in src/inspeqtor/v1/data.py
717
718
719
720
721
722
723
724
725
726
727
728
729
730
def get_expectation_values(self) -> np.ndarray:
    """Get the expectation value of the shape (sample_size, num_expectation_value)

    Returns:
        np.ndarray: expectation value of the shape (sample_size, num_expectation_value)
    """
    # Column order follows the configuration's expectation_values_order.
    column_names = [
        f"expectation_value/{entry.initial_state}/{entry.observable}"
        for entry in self.experiment_config.expectation_values_order
    ]
    selected = self.postprocessed_data[column_names].to_numpy()
    return np.array(selected)

get_parameters_dict_list

get_parameters_dict_list() -> list[
    list[ParametersDictType]
]

Get the list, where each element is list of dict of the control parameters of the dataset.

Returns:

Type Description
list[list[ParametersDictType]]

list[list[ParametersDictType]]: The list of list of dict of parameter.

Source code in src/inspeqtor/v1/data.py
732
733
734
735
736
737
738
739
740
741
742
743
744
745
def get_parameters_dict_list(self) -> list[list[ParametersDictType]]:
    """Get the list, where each element is list of dict of the control parameters of the dataset.

    Returns:
        list[list[ParametersDictType]]: The list of list of dict of parameter.
    """
    parameter_frame = self.postprocessed_data[self.parameter_columns]

    # Delegate each row to the module-level ``get_parameters_dict_list`` helper.
    return [
        get_parameters_dict_list(self.experiment_config.parameter_names, series)
        for _, series in parameter_frame.iterrows()
    ]

save_to_folder

save_to_folder(path: Union[Path, str])

Save the experiment data to given folder

Parameters:

Name Type Description Default
path Union[Path, str]

Path of the folder for experiment data to be saved.

required
Source code in src/inspeqtor/v1/data.py
747
748
749
750
751
752
753
754
755
756
757
758
759
760
def save_to_folder(self, path: typing.Union[Path, str]):
    """Save the experiment data to given folder

    Args:
        path (typing.Union[Path, str]): Path of the folder for experiment data to be saved.
    """
    folder = Path(path) if isinstance(path, str) else path

    # Create the target folder (including missing parents) before writing.
    folder.mkdir(parents=True, exist_ok=True)

    self.experiment_config.to_file(folder)
    self.preprocess_data.to_csv(folder / "preprocess_data.csv", index=False)
    self.postprocessed_data.to_csv(folder / "postprocessed_data.csv", index=False)

from_folder classmethod

from_folder(path: Union[Path, str]) -> ExperimentData

Read the experiment data from path

Parameters:

Name Type Description Default
path Union[Path, str]

path to the folder contain experiment data. Expected to be used with self.save_to_folder method.

required

Returns:

Name Type Description
ExperimentData ExperimentData

Instance of ExperimentData read from path.

Source code in src/inspeqtor/v1/data.py
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
@classmethod
def from_folder(cls, path: typing.Union[Path, str]) -> "ExperimentData":
    """Read the experiment data from path

    Args:
        path (typing.Union[Path, str]): path to the folder containing experiment data.
            Expected to be used with `self.save_to_folder` method.

    Returns:
        ExperimentData: Instance of `ExperimentData` read from path.
    """
    folder = Path(path) if isinstance(path, str) else path

    config = ExperimentConfiguration.from_file(folder)
    pre_df = pd.read_csv(folder / "preprocess_data.csv")

    # The postprocessed dataset is optional; older folders may not have it.
    post_csv = folder / "postprocessed_data.csv"
    post_df = pd.read_csv(post_csv) if post_csv.exists() else None

    return cls(
        experiment_config=config,
        preprocess_data=pre_df,
        _postprocessed_data=post_df,
    )

add_hilbert_level

add_hilbert_level(op: ndarray, x: ndarray) -> ndarray

Add a level to the operator or state

Parameters:

Name Type Description Default
op ndarray

The qubit operator or state

required
is_state bool

True if the operator is a state, False if the operator is an operator

required

Returns:

Type Description
ndarray

jnp.ndarray: The qutrit operator or state

Source code in src/inspeqtor/v1/data.py
25
26
27
28
29
30
31
32
33
34
35
def add_hilbert_level(op: jnp.ndarray, x: jnp.ndarray) -> jnp.ndarray:
    """Add a level to the operator or state.

    Embeds ``op`` as the top-left block and ``x`` as the bottom-right block of
    a block-diagonal matrix, extending the Hilbert space by the dimension of
    ``x`` (e.g. promoting a qubit operator to a qutrit operator).

    Args:
        op (jnp.ndarray): The qubit operator or state
        x (jnp.ndarray): Block placed on the diagonal for the additional level(s)

    Returns:
        jnp.ndarray: The qutrit operator or state
    """
    return jax.scipy.linalg.block_diag(op, x)

flatten_parameter_name_with_prefix

flatten_parameter_name_with_prefix(
    parameter_names: Sequence[Sequence[str]],
) -> list[str]

Create a flatten list of parameter names with prefix parameter/{i}/

Parameters:

Name Type Description Default
parameter_names Sequence[Sequence[str]]

The list of parameter names from the pulse sequence or the experiment configuration

required

Returns:

Type Description
list[str]

list[str]: The flatten list of parameter names with prefix parameter/{i}/

Source code in src/inspeqtor/v1/data.py
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
def flatten_parameter_name_with_prefix(
    parameter_names: typing.Sequence[typing.Sequence[str]],
) -> list[str]:
    """Create a flat list of parameter names with prefix parameter/{i}/

    Args:
        parameter_names (typing.Sequence[typing.Sequence[str]]): The list of parameter names
            from the pulse sequence or the experiment configuration

    Returns:
        list[str]: The flat list of parameter names with prefix parameter/{i}/
    """
    flattened: list[str] = []
    for index, group in enumerate(parameter_names):
        flattened.extend(f"parameter/{index}/{name}" for name in group)
    return flattened

transform_parameter_name

transform_parameter_name(name: str) -> str

Remove "parameter/{i}/" from provided name

Parameters:

Name Type Description Default
name str

Name of the control parameters

required

Returns:

Name Type Description
str str

Name that have "parameter/{i}/" strip.

Source code in src/inspeqtor/v1/data.py
454
455
456
457
458
459
460
461
462
463
464
465
466
def transform_parameter_name(name: str) -> str:
    """Remove "parameter/{i}/" from provided name

    Args:
        name (str): Name of the control parameters

    Returns:
        str: Name with the "parameter/{i}/" prefix stripped, or the input
        unchanged when it carries no such prefix.
    """
    if not name.startswith("parameter/"):
        return name
    # Drop the leading "parameter" and "{i}" path segments; keep the rest.
    return "/".join(name.split("/")[2:])

get_parameters_dict_list

get_parameters_dict_list(
    parameters_name: Sequence[Sequence[str]],
    parameters_row: Series,
) -> list[ParametersDictType]

Get the list of dict containing name and value of each control in the sequence.

Parameters:

Name Type Description Default
parameters_name Sequence[Sequence[str]]

description

required
parameters_row Series

description

required

Returns:

Type Description
list[ParametersDictType]

list[ParametersDictType]: description

Source code in src/inspeqtor/v1/data.py
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
def get_parameters_dict_list(
    parameters_name: typing.Sequence[typing.Sequence[str]], parameters_row: pd.Series
) -> list[ParametersDictType]:
    """Get the list of dicts containing name and value of each control in the sequence.

    Args:
        parameters_name (typing.Sequence[typing.Sequence[str]]): Parameter-name groups,
            one group per control in the sequence.
        parameters_row (pd.Series): One row of the parameters dataframe, keyed by
            "parameter/{i}/{name}" columns.

    Returns:
        list[ParametersDictType]: One dict per control, with the prefix stripped
        from each key.
    """
    recovered_parameters: list[ParametersDictType] = []
    for control_index in range(len(parameters_name)):
        prefix = f"parameter/{control_index}/"
        control_dict: ParametersDictType = {
            # Strip the "parameter/{i}/" prefix from the column name.
            transform_parameter_name(key): value
            for key, value in parameters_row.items()
            # Keep only this control's numeric columns.
            if isinstance(key, str)
            and key.startswith(prefix)
            and isinstance(value, (float, int))
        }
        recovered_parameters.append(control_dict)

    return recovered_parameters

save_to_json

save_to_json(data: dict, path: Union[str, Path])

Save the dictionary as json to the path

Parameters:

Name Type Description Default
data dict

Dict to be save to file

required
path Union[str, Path]

Path to save file.

required
Source code in src/inspeqtor/v1/data.py
817
818
819
820
821
822
823
824
825
826
827
828
829
def save_to_json(data: dict, path: typing.Union[str, Path]):
    """Save the dictionary as json to the path

    Args:
        data (dict): Dict to be saved to file
        path (typing.Union[str, Path]): Path to save file.
    """
    if isinstance(path, str):
        path = Path(path)

    # parents=True so nested, not-yet-existing folders are created too
    # (matches the behavior of `save_pytree_to_json`).
    path.parent.mkdir(parents=True, exist_ok=True)
    with open(path, "w") as f:
        json.dump(data, f, indent=4)

read_from_json

read_from_json(
    path: Union[str, Path],
    dataclass: Union[None, type[DataclassVar]] = None,
) -> Union[dict, DataclassVar]

Construct provided dataclass instance with json file

Parameters:

Name Type Description Default
path Union[str, Path]

Path to json file

required
dataclass Union[None, type[DataclassVar]]

The constructor of the dataclass. Defaults to None.

None

Returns:

Type Description
Union[dict, DataclassVar]

typing.Union[dict, DataclassVar]: Dataclass instance; if dataclass is not provided, return dict instead.

Source code in src/inspeqtor/v1/data.py
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
def read_from_json(
    path: typing.Union[str, Path],
    dataclass: typing.Union[None, type[DataclassVar]] = None,
) -> typing.Union[dict, DataclassVar]:
    """Construct provided `dataclass` instance from a json file

    Args:
        path (typing.Union[str, Path]): Path to json file
        dataclass (typing.Union[None, type[DataclassVar]], optional): The constructor of the dataclass. Defaults to None.

    Returns:
        typing.Union[dict, DataclassVar]: Dataclass instance; if no dataclass is provided, return dict instead.
    """
    source = Path(path) if isinstance(path, str) else path
    with open(source, "r") as f:
        payload = json.load(f)

    # Without a dataclass constructor, hand back the raw dict.
    return payload if dataclass is None else dataclass(**payload)

load_pytree_from_json

load_pytree_from_json(
    path: str | Path, parse_fn=default_parse_fn
)

Load pytree from json

Parameters:

Name Type Description Default
path str | Path

Path to JSON file containing pytree

required
array_keys list[str]

list of key to convert to jnp.numpy. Defaults to [].

required

Raises:

Type Description
ValueError

Provided path is not point to .json file

Returns:

Type Description

typing.Any: Pytree loaded from JSON

Source code in src/inspeqtor/v1/data.py
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
def load_pytree_from_json(path: str | Path, parse_fn=default_parse_fn):
    """Load pytree from json

    Args:
        path (str | Path): Path to JSON file containing pytree
        parse_fn (typing.Callable, optional): Leaf parser applied recursively to the
            loaded JSON. Defaults to `default_parse_fn`.

    Raises:
        ValueError: Provided path does not point to a .json file

    Returns:
        typing.Any: Pytree loaded from JSON
    """
    # Validate the file extension via the raw string so str and Path inputs
    # are handled the same way.
    if str(path).split(".")[-1] != "json":
        raise ValueError("File extension must be json")

    target = Path(path) if isinstance(path, str) else path

    with open(target, "r") as f:
        raw = json.load(f)

    # Reconstruct non-JSON-native leaves (arrays, etc.) with the parser.
    return recursive_parse(raw, parse_fn=parse_fn)

save_pytree_to_json

save_pytree_to_json(pytree, path: str | Path)

Save given pytree to json file, the path must end with extension of .json

Parameters:

Name Type Description Default
pytree Any

The pytree to save

required
path str | Path

File path to save

required
Source code in src/inspeqtor/v1/data.py
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
def save_pytree_to_json(pytree, path: str | Path):
    """Save given pytree to json file; the path must end with extension .json

    Args:
        pytree (typing.Any): The pytree to save
        path (str | Path): File path to save

    """
    # Make leaves JSON-serializable: arrays become nested Python lists...
    serializable = jax.tree.map(
        lambda leaf: leaf.tolist() if isinstance(leaf, jnp.ndarray) else leaf, pytree
    )
    # ...and ParamShape leaves become plain dicts.
    serializable = jax.tree.map(param_shape_to_dict, serializable, is_leaf=is_param_shape)

    target = Path(path) if isinstance(path, str) else path
    target.parent.mkdir(exist_ok=True, parents=True)

    with open(target, "w") as f:
        json.dump(serializable, f, indent=4)

from_long_to_wide

from_long_to_wide(preprocessed_df: DataFrame)

An experimental function to transform preprocess dataframe to postprocess dataframe.

Parameters:

Name Type Description Default
preprocessed_df DataFrame

The preprocess dataframe

required

Returns:

Type Description

pd.DataFrame: The postprocessed dataframe

Source code in src/inspeqtor/v1/data.py
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
def from_long_to_wide(preprocessed_df: pd.DataFrame):
    """An experimental function to transform a preprocess dataframe to a postprocess dataframe.

    Args:
        preprocessed_df (pd.DataFrame): The preprocess dataframe

    Returns:
        pd.DataFrame: The postprocessed dataframe
    """
    # Spread (initial_state, observable) pairs into one expectation-value column each.
    wide_expvals = preprocessed_df.pivot(
        index="parameters_id",
        columns=["initial_state", "observable"],
        values="expectation_value",
    )

    # Flatten the (state, observable) MultiIndex into "expectation_value/{s}/{o}".
    wide_expvals.columns = [
        f"expectation_value/{initial_state}/{observable}"
        for initial_state, observable in wide_expvals.columns
    ]

    # Parameters are constant within a parameters_id group, so keep the first
    # row per group and drop the long-format-only columns.
    parameters = (
        preprocessed_df.groupby("parameters_id")
        .first()
        .drop(["expectation_value", "initial_state", "observable"], axis=1)
    )

    return parameters.join(wide_expvals)

from_wide_to_long_simple

from_wide_to_long_simple(wide_df: DataFrame)

A more concise version to convert a wide DataFrame back to the long format.

Source code in src/inspeqtor/v1/data.py
 993
 994
 995
 996
 997
 998
 999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
def from_wide_to_long_simple(wide_df: pd.DataFrame):
    """
    A more concise version to convert a wide DataFrame back to the long format.
    """
    # Work with the index as an ordinary column.
    frame = wide_df.reset_index()

    # Columns that stay fixed per row: everything except the expectation values.
    fixed_columns = [
        column for column in frame.columns
        if not column.startswith("expectation_value/")
    ]

    # All non-fixed columns are melted into (descriptor, expectation_value) pairs.
    melted = frame.melt(
        id_vars=fixed_columns, var_name="descriptor", value_name="expectation_value"
    )

    # "expectation_value/{state}/{obs}" -> three columns; the first is discarded.
    melted[["_,", "initial_state", "observable"]] = melted["descriptor"].str.split(
        "/", expand=True
    )

    tidy = melted.drop(columns=["descriptor", "_,"])
    return tidy.sort_values("parameters_id").reset_index(drop=True)

Linen

src.inspeqtor.v1.models.linen

loss_fn

loss_fn(
    params: VariableDict,
    control_parameters: ndarray,
    unitaries: ndarray,
    expectation_values: ndarray,
    model: Module,
    predictive_fn: Callable,
    loss_metric: LossMetric,
    calculate_metric_fn: Callable = calculate_metric,
    **model_kwargs,
) -> tuple[ndarray, dict[str, ndarray]]

This function implement a unified interface for nn.Module.

Parameters:

Name Type Description Default
params VariableDict

Model parameters to be optimized

required
control_parameters ndarray

Control parameters parametrized Hamiltonian

required
unitaries ndarray

The Ideal unitary operators corresponding to the control parameters

required
expectation_values ndarray

Experimental expectation values to calculate the loss value

required
model Module

Flax linen Blackbox part of the graybox model.

required
predictive_fn Callable

Function for calculating expectation value from the model

required
loss_metric LossMetric

The choice of loss value to be minimized.

required
calculate_metric_fn Callable

Function for metrics calculation from prediction and experimental value. Defaults to calculate_metric

calculate_metric

Returns:

Type Description
tuple[ndarray, dict[str, ndarray]]

tuple[jnp.ndarray, dict[str, jnp.ndarray]]: The loss value and other metrics.

Source code in src/inspeqtor/v1/models/linen.py
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
@deprecated.deprecated(reason="use make_loss_fn instead")
def loss_fn(
    params: VariableDict,
    control_parameters: jnp.ndarray,
    unitaries: jnp.ndarray,
    expectation_values: jnp.ndarray,
    model: nn.Module,
    predictive_fn: typing.Callable,
    loss_metric: LossMetric,
    calculate_metric_fn: typing.Callable = calculate_metric,
    **model_kwargs,
) -> tuple[jnp.ndarray, dict[str, jnp.ndarray]]:
    """This function implements a unified interface for nn.Module.

    Args:
        params (VariableDict): Model parameters to be optimized
        control_parameters (jnp.ndarray): Control parameters parametrized Hamiltonian
        unitaries (jnp.ndarray): The ideal unitary operators corresponding to the control parameters
        expectation_values (jnp.ndarray): Experimental expectation values to calculate the loss value
        model (nn.Module): Flax linen Blackbox part of the graybox model.
        predictive_fn (typing.Callable): Function for calculating expectation value from the model
        loss_metric (LossMetric): The choice of loss value to be minimized.
        calculate_metric_fn (typing.Callable): Function for metrics calculation from prediction and experimental value. Defaults to calculate_metric

    Returns:
        tuple[jnp.ndarray, dict[str, jnp.ndarray]]: The loss value and other metrics.
    """
    # Predict expectation values with the current model parameters.
    predictions = predictive_fn(
        model=model,
        model_params=params,
        control_parameters=control_parameters,
        unitaries=unitaries,
        **model_kwargs,
    )

    # Compute every metric, then reduce each to its mean over the batch.
    metrics = jax.tree.map(
        jnp.mean,
        calculate_metric_fn(unitaries, expectation_values, predictions),
    )

    # The optimized loss is the selected entry of the metric dict.
    return metrics[loss_metric], metrics

wo_predictive_fn

wo_predictive_fn(
    control_parameters: ndarray,
    unitaries: ndarray,
    model: Module,
    model_params: VariableDict,
    **model_kwargs,
)

To Calculate the metrics of the model 1. MSE Loss between the predicted expectation values and the experimental expectation values 2. Average Gate Fidelity between the Pauli matrices to the Wo_model matrices 3. AGF Loss between the prediction from model and the experimental expectation values

Parameters:

Name Type Description Default
model Module

The model to be used for prediction

required
model_params VariableDict

The model parameters

required
control_parameters ndarray

The pulse parameters

required
unitaries ndarray

Ideal unitaries

required
expectation_values ndarray

Experimental expectation values

required
model_kwargs dict

Model keyword arguments

{}
Source code in src/inspeqtor/v1/models/linen.py
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
def wo_predictive_fn(
    # Input data to the model
    control_parameters: jnp.ndarray,
    unitaries: jnp.ndarray,
    # The model to be used for prediction
    model: nn.Module,
    model_params: VariableDict,
    # model keyword arguments
    **model_kwargs,
):
    """Predict expectation values via a Wo-observable Blackbox model.

    The model maps control parameters to Wo observables, which are combined
    with the ideal unitaries into expectation values.

    Args:
        control_parameters (jnp.ndarray): The pulse parameters
        unitaries (jnp.ndarray): Ideal unitaries
        model (nn.Module): The model to be used for prediction
        model_params (VariableDict): The model parameters
        model_kwargs (dict): Model keyword arguments
    """
    # The Blackbox predicts the Wo observables for the given controls.
    predicted_wo = model.apply(model_params, control_parameters, **model_kwargs)

    return observable_to_expvals(predicted_wo, unitaries)

noisy_unitary_predictive_fn

noisy_unitary_predictive_fn(
    control_parameters: ndarray,
    unitaries: ndarray,
    model: UnitaryModel,
    model_params: VariableDict,
    **model_kwargs,
)

Calculate for unitary-based Blackbox model

Parameters:

Name Type Description Default
model Module

The model to be used for prediction

required
model_params VariableDict

The model parameters

required
control_parameters ndarray

The pulse parameters

required
unitaries ndarray

Ideal unitaries

required
expectation_values ndarray

Experimental expectation values

required
model_kwargs dict

Model keyword arguments

{}

Returns:

Type Description

typing.Any: description

Source code in src/inspeqtor/v1/models/linen.py
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
def noisy_unitary_predictive_fn(
    # Input data to the model
    control_parameters: jnp.ndarray,
    unitaries: jnp.ndarray,
    # The model to be used for prediction
    model: UnitaryModel,
    model_params: VariableDict,
    # model keyword arguments
    **model_kwargs,
):
    """Predict expectation values via a unitary-based Blackbox model.

    Args:
        control_parameters (jnp.ndarray): The pulse parameters
        unitaries (jnp.ndarray): Ideal unitaries
        model (UnitaryModel): The model to be used for prediction
        model_params (VariableDict): The model parameters
        model_kwargs (dict): Model keyword arguments

    Returns:
        typing.Any: Expectation values computed from the predicted unitary parameters.
    """
    # The Blackbox predicts noisy-unitary parameters for the given controls.
    predicted_unitary_params = model.apply(
        model_params, control_parameters, **model_kwargs
    )

    return unitary_to_expvals(predicted_unitary_params, unitaries)

toggling_unitary_predictive_fn

toggling_unitary_predictive_fn(
    control_parameters: ndarray,
    unitaries: ndarray,
    model: UnitaryModel,
    model_params: VariableDict,
    ignore_spam: bool = False,
    **model_kwargs,
)

Calculate for unitary-based Blackbox model

Parameters:

Name Type Description Default
model Module

The model to be used for prediction

required
model_params VariableDict

The model parameters

required
control_parameters ndarray

The pulse parameters

required
unitaries ndarray

Ideal unitaries

required
expectation_values ndarray

Experimental expectation values

required
model_kwargs dict

Model keyword arguments

{}

Returns:

Type Description

typing.Any: description

Source code in src/inspeqtor/v1/models/linen.py
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
def toggling_unitary_predictive_fn(
    # Input data to the model
    control_parameters: jnp.ndarray,
    unitaries: jnp.ndarray,
    # The model to be used for prediction
    model: UnitaryModel,
    model_params: VariableDict,
    # model keyword arguments
    ignore_spam: bool = False,
    **model_kwargs,
):
    """Calcuate for unitary-based Blackbox model

    Args:
        model (sq.model.nn.Module): The model to be used for prediction
        model_params (sq.model.VariableDict): The model parameters
        control_parameters (jnp.ndarray): The pulse parameters
        unitaries (jnp.ndarray): Ideal unitaries
        expectation_values (jnp.ndarray): Experimental expectation values
        model_kwargs (dict): Model keyword arguments

    Returns:
        typing.Any: _description_
    """

    # Predict Unitary parameters
    unitary_params = model.apply(model_params, control_parameters, **model_kwargs)

    if not ignore_spam:
        return toggling_unitary_with_spam_to_expvals(
            output={
                "model_params": unitary_params,
                "spam_params": model_params["spam"],
            },
            unitaries=unitaries,
        )
    else:
        return toggling_unitary_to_expvals(
            unitary_params,  # type: ignore
            unitaries=unitaries,
        )

make_loss_fn_old

make_loss_fn_old(
    predictive_fn: Callable,
    model: Module,
    calculate_metric_fn: Callable = calculate_metric,
    loss_metric: LossMetric = MSEE,
)

Create a loss function for the Blackbox model that computes metrics from predictions and returns the selected loss metric.

Parameters:

Name Type Description Default
predictive_fn Callable

Function for calculating expectation value from the model

required
model Module

Flax linen Blackbox part of the graybox model.

required
loss_metric LossMetric

The choice of loss value to be minimized. Defaults to LossMetric.MSEE.

MSEE
calculate_metric_fn Callable

Function for metrics calculation from prediction and experimental value. Defaults to calculate_metric.

calculate_metric
Source code in src/inspeqtor/v1/models/linen.py
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
@deprecated.deprecated
def make_loss_fn_old(
    predictive_fn: typing.Callable,
    model: nn.Module,
    calculate_metric_fn: typing.Callable = calculate_metric,
    loss_metric: LossMetric = LossMetric.MSEE,
):
    """_summary_

    Args:
        predictive_fn (typing.Callable): Function for calculating expectation value from the model
        model (nn.Module): Flax linen Blackbox part of the graybox model.
        loss_metric (LossMetric): The choice of loss value to be minimized. Defaults to LossMetric.MSEE.
        calculate_metric_fn (typing.Callable): Function for metrics calculation from prediction and experimental value. Defaults to calculate_metric.
    """

    def loss_fn(
        params: VariableDict,
        control_parameters: jnp.ndarray,
        unitaries: jnp.ndarray,
        expectation_values: jnp.ndarray,
        **model_kwargs,
    ) -> tuple[jnp.ndarray, dict[str, jnp.ndarray]]:
        """This function implement a unified interface for nn.Module.

        Args:
            params (VariableDict): Model parameters to be optimized
            control_parameters (jnp.ndarray): Control parameters parametrized Hamiltonian
            unitaries (jnp.ndarray): The Ideal unitary operators corresponding to the control parameters
            expectation_values (jnp.ndarray): Experimental expectation values to calculate the loss value

        Returns:
            tuple[jnp.ndarray, dict[str, jnp.ndarray]]: The loss value and other metrics.
        """
        # Calculate the metrics
        predicted_expectation_value = predictive_fn(
            model=model,
            model_params=params,
            control_parameters=control_parameters,
            unitaries=unitaries,
            **model_kwargs,
        )

        metrics = calculate_metric_fn(
            unitaries, expectation_values, predicted_expectation_value
        )

        # Take mean of all the metrics
        metrics = jax.tree.map(jnp.mean, metrics)

        # ! Grab the metric in the `metrics`
        loss = metrics[loss_metric]

        return (loss, metrics)

    return loss_fn

make_loss_fn_oldv2

make_loss_fn_oldv2(
    adapter_fn: Callable,
    model: Module,
    calculate_metric_fn: Callable = calculate_metric,
    loss_metric: LossMetric = MSEE,
)

Create a loss function that applies the model, converts its output to expectation values via the adapter function, and returns the selected loss metric.

Parameters:

Name Type Description Default
predictive_fn Callable

Function for calculating expectation value from the model

required
model Module

Flax linen Blackbox part of the graybox model.

required
loss_metric LossMetric

The choice of loss value to be minimized. Defaults to LossMetric.MSEE.

MSEE
calculate_metric_fn Callable

Function for metrics calculation from prediction and experimental value. Defaults to calculate_metric.

calculate_metric
Source code in src/inspeqtor/v1/models/linen.py
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
def make_loss_fn_oldv2(
    adapter_fn: typing.Callable,
    model: nn.Module,
    calculate_metric_fn: typing.Callable = calculate_metric,
    loss_metric: LossMetric = LossMetric.MSEE,
):
    """_summary_

    Args:
        predictive_fn (typing.Callable): Function for calculating expectation value from the model
        model (nn.Module): Flax linen Blackbox part of the graybox model.
        loss_metric (LossMetric): The choice of loss value to be minimized. Defaults to LossMetric.MSEE.
        calculate_metric_fn (typing.Callable): Function for metrics calculation from prediction and experimental value. Defaults to calculate_metric.
    """

    def loss_fn(
        params: VariableDict,
        control_parameters: jnp.ndarray,
        unitaries: jnp.ndarray,
        expectation_values: jnp.ndarray,
        **model_kwargs,
    ) -> tuple[jnp.ndarray, dict[str, jnp.ndarray]]:
        """This function implement a unified interface for nn.Module.

        Args:
            params (VariableDict): Model parameters to be optimized
            control_parameters (jnp.ndarray): Control parameters parametrized Hamiltonian
            unitaries (jnp.ndarray): The Ideal unitary operators corresponding to the control parameters
            expectation_values (jnp.ndarray): Experimental expectation values to calculate the loss value

        Returns:
            tuple[jnp.ndarray, dict[str, jnp.ndarray]]: The loss value and other metrics.
        """
        output = model.apply(params, control_parameters, **model_kwargs)
        predicted_expectation_value = adapter_fn(output, unitaries=unitaries)

        # Calculate the metrics
        metrics = calculate_metric_fn(
            unitaries, expectation_values, predicted_expectation_value
        )

        # Take mean of all the metrics
        metrics = jax.tree.map(jnp.mean, metrics)

        # ! Grab the metric in the `metrics`
        loss = metrics[loss_metric]

        return (loss, metrics)

    return loss_fn

make_loss_fn

make_loss_fn(
    adapter_fn: Callable,
    model: Module,
    evaluate_fn: Callable[
        [ndarray, ndarray, ndarray], ndarray
    ],
)

Create a loss function that applies the model, converts its output to expectation values via the adapter function, and evaluates the loss with evaluate_fn.

Parameters:

Name Type Description Default
predictive_fn Callable

Function for calculating expectation value from the model

required
model Module

Flax linen Blackbox part of the graybox model.

required
evaluate_fn Callable[[ndarray, ndarray, ndarray], ndarray]

Take in predicted and experimental expectation values and ideal unitary and return loss value

required
Source code in src/inspeqtor/v1/models/linen.py
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
def make_loss_fn(
    adapter_fn: typing.Callable,
    model: nn.Module,
    evaluate_fn: typing.Callable[[jnp.ndarray, jnp.ndarray, jnp.ndarray], jnp.ndarray],
):
    """_summary_

    Args:
        predictive_fn (typing.Callable): Function for calculating expectation value from the model
        model (nn.Module): Flax linen Blackbox part of the graybox model.
        evaluate_fn ( typing.Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray, jnp.ndarray]): Take in predicted and experimental expectation values and ideal unitary and return loss value
    """

    def loss_fn(
        params: VariableDict,
        control_parameters: jnp.ndarray,
        unitaries: jnp.ndarray,
        expectation_values: jnp.ndarray,
        **model_kwargs,
    ) -> tuple[jnp.ndarray, dict[str, jnp.ndarray]]:
        """This function implement a unified interface for nn.Module.

        Args:
            params (VariableDict): Model parameters to be optimized
            control_parameters (jnp.ndarray): Control parameters parametrized Hamiltonian
            unitaries (jnp.ndarray): The Ideal unitary operators corresponding to the control parameters
            expectation_values (jnp.ndarray): Experimental expectation values to calculate the loss value

        Returns:
            tuple[jnp.ndarray, dict[str, jnp.ndarray]]: The loss value and other metrics.
        """
        output = model.apply(params, control_parameters, **model_kwargs)
        predicted_expectation_value = adapter_fn(output, unitaries=unitaries)

        loss = evaluate_fn(predicted_expectation_value, expectation_values, unitaries)

        return jnp.mean(loss), {}

    return loss_fn

create_step

create_step(
    optimizer: GradientTransformation,
    loss_fn: Callable[..., ndarray]
    | Callable[..., Tuple[ndarray, Any]],
    has_aux: bool = False,
)

The create_step function creates a training step function and a test step function.

loss_fn should have the following signature:

def loss_fn(params: jaxtyping.PyTree, *args) -> jnp.ndarray:
    ...
    return loss_value
where params is the parameters to be optimized, and args are the inputs for the loss function.

Parameters:

Name Type Description Default
optimizer GradientTransformation

optax optimizer.

required
loss_fn Callable[[PyTree, ...], ndarray]

Loss function, which takes in the model parameters, inputs, and targets, and returns the loss value.

required
has_aux bool

Whether the loss function return aux data or not. Defaults to False.

False

Returns:

Type Description

typing.Any: train_step, test_step

Source code in src/inspeqtor/v1/models/linen.py
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
def create_step(
    optimizer: optax.GradientTransformation,
    loss_fn: (
        typing.Callable[..., jnp.ndarray]
        | typing.Callable[..., typing.Tuple[jnp.ndarray, typing.Any]]
    ),
    has_aux: bool = False,
):
    """The create_step function creates a training step function and a test step function.

    loss_fn should have the following signature:
    ```py
    def loss_fn(params: jaxtyping.PyTree, *args) -> jnp.ndarray:
        ...
        return loss_value
    ```
    where `params` is the parameters to be optimized, and `args` are the inputs for the loss function.

    Args:
        optimizer (optax.GradientTransformation): `optax` optimizer.
        loss_fn (typing.Callable[[jaxtyping.PyTree, ...], jnp.ndarray]): Loss function, which takes in the model parameters, inputs, and targets, and returns the loss value.
        has_aux (bool, optional): Whether the loss function return aux data or not. Defaults to False.

    Returns:
        _typing.Any_: train_step, test_step
    """

    # * Generalized training step
    @jax.jit
    def train_step(
        params: jaxtyping.PyTree,
        optimizer_state: optax.OptState,
        *args,
        **kwargs,
    ):
        loss_value, grads = jax.value_and_grad(loss_fn, has_aux=has_aux)(
            params, *args, **kwargs
        )
        updates, opt_state = optimizer.update(grads, optimizer_state, params)
        params = optax.apply_updates(params, updates)

        return params, opt_state, loss_value

    @jax.jit
    def test_step(
        params: jaxtyping.PyTree,
        *args,
        **kwargs,
    ):
        return loss_fn(params, *args, **kwargs)

    return train_step, test_step

train_model

train_model(
    key: ndarray,
    train_data: DataBundled,
    val_data: DataBundled,
    test_data: DataBundled,
    model: Module,
    optimizer: GradientTransformation,
    loss_fn: Callable,
    callbacks: list[Callable] = [],
    NUM_EPOCH: int = 1000,
    model_params: VariableDict | None = None,
    opt_state: OptState | None = None,
)

Train the BlackBox model

Examples:

>>> # The number of epochs break down
... NUM_EPOCH = 150
... # Total number of iterations as 90% of data is used for training
... # 10% of the data is used for testing
... total_iterations = 9 * NUM_EPOCH
... # The step for optimizer if set to 8 * NUM_EPOCH (should be less than total_iterations)
... step_for_optimizer = 8 * NUM_EPOCH
... optimizer = get_default_optimizer(step_for_optimizer)
... # The warmup steps for the optimizer
... warmup_steps = 0.1 * step_for_optimizer
... # The cool down steps for the optimizer
... cool_down_steps = total_iterations - step_for_optimizer
... total_iterations, step_for_optimizer, warmup_steps, cool_down_steps

Parameters:

Name Type Description Default
key ndarray

Random key

required
model Module

The model to be used for training

required
optimizer GradientTransformation

The optimizer to be used for training

required
loss_fn Callable

The loss function to be used for training

required
callbacks list[Callable]

list of callback functions. Defaults to [].

[]
NUM_EPOCH int

The number of epochs. Defaults to 1_000.

1000

Returns:

Name Type Description
tuple

The model parameters, optimizer state, and the histories

Source code in src/inspeqtor/v1/models/linen.py
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
def train_model(
    # Random key
    key: jnp.ndarray,
    # Data
    train_data: DataBundled,
    val_data: DataBundled,
    test_data: DataBundled,
    # Model to be used for training
    model: nn.Module,
    optimizer: optax.GradientTransformation,
    # Loss function to be used
    loss_fn: typing.Callable,
    # Callbacks to be used
    callbacks: list[typing.Callable] = [],
    # Number of epochs
    NUM_EPOCH: int = 1_000,
    # Optional state
    model_params: VariableDict | None = None,
    opt_state: optax.OptState | None = None,
):
    """Train the BlackBox model

    Examples:
        >>> # The number of epochs break down
        ... NUM_EPOCH = 150
        ... # Total number of iterations as 90% of data is used for training
        ... # 10% of the data is used for testing
        ... total_iterations = 9 * NUM_EPOCH
        ... # The step for optimizer if set to 8 * NUM_EPOCH (should be less than total_iterations)
        ... step_for_optimizer = 8 * NUM_EPOCH
        ... optimizer = get_default_optimizer(step_for_optimizer)
        ... # The warmup steps for the optimizer
        ... warmup_steps = 0.1 * step_for_optimizer
        ... # The cool down steps for the optimizer
        ... cool_down_steps = total_iterations - step_for_optimizer
        ... total_iterations, step_for_optimizer, warmup_steps, cool_down_steps

    Args:
        key (jnp.ndarray): Random key
        model (nn.Module): The model to be used for training
        optimizer (optax.GradientTransformation): The optimizer to be used for training
        loss_fn (typing.Callable): The loss function to be used for training
        callbacks (list[typing.Callable], optional): list of callback functions. Defaults to [].
        NUM_EPOCH (int, optional): The number of epochs. Defaults to 1_000.

    Returns:
        tuple: The model parameters, optimizer state, and the histories
    """

    key, loader_key, init_key = jax.random.split(key, 3)

    train_p, train_u, train_ex = (
        train_data.control_params,
        train_data.unitaries,
        train_data.observables,
    )
    val_p, val_u, val_ex = (
        val_data.control_params,
        val_data.unitaries,
        val_data.observables,
    )
    test_p, test_u, test_ex = (
        test_data.control_params,
        test_data.unitaries,
        test_data.observables,
    )

    BATCH_SIZE = val_p.shape[0]

    if model_params is None:
        # Initialize the model parameters if it is None
        model_params = model.init(init_key, train_p[0])

    if opt_state is None:
        # Initalize the optimizer state if it is None
        opt_state = optimizer.init(model_params)  # type: ignore

    # histories: list[dict[str, typing.Any]] = []
    histories: list[HistoryEntryV3] = []

    train_step, eval_step = create_step(
        optimizer=optimizer, loss_fn=loss_fn, has_aux=True
    )

    for (step, batch_idx, is_last_batch, epoch_idx), (
        batch_p,
        batch_u,
        batch_ex,
    ) in dataloader(
        (train_p, train_u, train_ex),
        batch_size=BATCH_SIZE,
        num_epochs=NUM_EPOCH,
        key=loader_key,
    ):
        model_params, opt_state, (loss, aux) = train_step(
            model_params, opt_state, batch_p, batch_u, batch_ex
        )

        histories.append(HistoryEntryV3(step=step, loss=loss, loop="train", aux=aux))

        if is_last_batch:
            # Validation
            (val_loss, aux) = eval_step(model_params, val_p, val_u, val_ex)

            histories.append(
                HistoryEntryV3(step=step, loss=val_loss, loop="val", aux=aux)
            )

            # Testing
            (test_loss, aux) = eval_step(model_params, test_p, test_u, test_ex)

            histories.append(
                HistoryEntryV3(step=step, loss=test_loss, loop="test", aux=aux)
            )

            for callback in callbacks:
                callback(model_params, opt_state, histories)

    return model_params, opt_state, histories

NNX

src.inspeqtor.v1.models.nnx

Blackbox

The abstract class for interfacing the Blackbox model of the Graybox

Source code in src/inspeqtor/v1/models/nnx.py
22
23
24
25
26
27
28
29
class Blackbox(nnx.Module):
    """The abstract class for interfacing the Blackbox model of the Graybox"""

    def __init__(self, *, rngs: nnx.Rngs) -> None:
        super().__init__()

    def __call__(self, *args: typing.Any, **kwds: typing.Any) -> typing.Any:
        raise NotImplementedError()

WoModel

\(\hat{W}_{O}\) based blackbox model.

Source code in src/inspeqtor/v1/models/nnx.py
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
class WoModel(Blackbox):
    """$\\hat{W}_{O}$ based blackbox model."""

    def __init__(
        self, shared_layers: list[int], pauli_layers: list[int], *, rngs: nnx.Rngs
    ):
        """
        Args:
            shared_layers (list[int]): Each integer in the list is a size of the width of each hidden layer in the shared layers.
            pauli_layers (list[int]): Each integer in the list is a size of the width of each hidden layer in the Pauli layers.
            rngs (nnx.Rngs): Random number generator of `nnx`.
        """

        self.shared_layers = nnx.Dict(
            {
                f"shared/{idx}": nnx.Linear(
                    in_features=in_features, out_features=out_features, rngs=rngs
                )
                for idx, (in_features, out_features) in enumerate(
                    zip(shared_layers[:-1], shared_layers[1:])
                )
            }
        )

        self.num_shared_layers = len(shared_layers) - 1
        self.num_pauli_layers = len(pauli_layers) - 1

        self.pauli_layers = nnx.Dict()
        self.unitary_layers = nnx.Dict()
        self.diagonal_layers = nnx.Dict()
        for pauli in ["X", "Y", "Z"]:
            layers = nnx.Dict(
                {
                    f"pauli/{idx}": nnx.Linear(
                        in_features=in_features, out_features=out_features, rngs=rngs
                    )
                    for idx, (in_features, out_features) in enumerate(
                        zip(pauli_layers[:-1], pauli_layers[1:])
                    )
                }
            )

            self.pauli_layers[pauli] = layers

            self.unitary_layers[pauli] = nnx.Linear(
                in_features=pauli_layers[-1], out_features=3, rngs=rngs
            )
            self.diagonal_layers[pauli] = nnx.Linear(
                in_features=pauli_layers[-1], out_features=2, rngs=rngs
            )

    def __call__(self, x: jnp.ndarray):
        for idx in range(self.num_shared_layers):
            layer = self.shared_layers[f"shared/{idx}"]
            x = nnx.relu(layer(x))

        observables: dict[str, jnp.ndarray] = dict()
        for pauli, pauli_layer in self.pauli_layers.items():
            _x = jnp.copy(x)
            for idx in range(self.num_pauli_layers):
                layer = pauli_layer[f"pauli/{idx}"]
                _x = nnx.relu(layer(_x))

            unitary_param = self.unitary_layers[pauli](_x)
            diagonal_param = self.diagonal_layers[pauli](_x)

            unitary_param = 2 * jnp.pi * nnx.hard_sigmoid(unitary_param)
            diagonal_param = (2 * nnx.hard_sigmoid(diagonal_param)) - 1

            observables[pauli] = hermitian(unitary_param, diagonal_param)

        return observables

__init__

__init__(
    shared_layers: list[int],
    pauli_layers: list[int],
    *,
    rngs: Rngs,
)

Parameters:

Name Type Description Default
shared_layers list[int]

Each integer in the list is a size of the width of each hidden layer in the shared layers.

required
pauli_layers list[int]

Each integer in the list is a size of the width of each hidden layer in the Pauli layers.

required
rngs Rngs

Random number generator of nnx.

required
Source code in src/inspeqtor/v1/models/nnx.py
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
def __init__(
    self, shared_layers: list[int], pauli_layers: list[int], *, rngs: nnx.Rngs
):
    """
    Args:
        shared_layers (list[int]): Each integer in the list is a size of the width of each hidden layer in the shared layers.
        pauli_layers (list[int]): Each integer in the list is a size of the width of each hidden layer in the Pauli layers.
        rngs (nnx.Rngs): Random number generator of `nnx`.
    """

    self.shared_layers = nnx.Dict(
        {
            f"shared/{idx}": nnx.Linear(
                in_features=in_features, out_features=out_features, rngs=rngs
            )
            for idx, (in_features, out_features) in enumerate(
                zip(shared_layers[:-1], shared_layers[1:])
            )
        }
    )

    self.num_shared_layers = len(shared_layers) - 1
    self.num_pauli_layers = len(pauli_layers) - 1

    self.pauli_layers = nnx.Dict()
    self.unitary_layers = nnx.Dict()
    self.diagonal_layers = nnx.Dict()
    for pauli in ["X", "Y", "Z"]:
        layers = nnx.Dict(
            {
                f"pauli/{idx}": nnx.Linear(
                    in_features=in_features, out_features=out_features, rngs=rngs
                )
                for idx, (in_features, out_features) in enumerate(
                    zip(pauli_layers[:-1], pauli_layers[1:])
                )
            }
        )

        self.pauli_layers[pauli] = layers

        self.unitary_layers[pauli] = nnx.Linear(
            in_features=pauli_layers[-1], out_features=3, rngs=rngs
        )
        self.diagonal_layers[pauli] = nnx.Linear(
            in_features=pauli_layers[-1], out_features=2, rngs=rngs
        )

UnitaryModel

Unitary-based model, predicting parameters parametrized unitary operator in range \([0, 2\pi]\).

Source code in src/inspeqtor/v1/models/nnx.py
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
class UnitaryModel(Blackbox):
    """Unitary-based model, predicting parameters parametrized unitary operator in range $[0, 2\\pi]$."""

    def __init__(self, hidden_sizes: list[int], *, rngs: nnx.Rngs) -> None:
        """

        Args:
            hidden_sizes (list[int]): Each integer in the list is a size of the width of each hidden layer in the shared layers
            rngs (nnx.Rngs): Random number generator of `nnx`.
        """
        self.hidden_sizes = hidden_sizes
        self.NUM_UNITARY_PARAMS = 4

        self.hidden_layers = nnx.Dict(
            {
                f"hidden_layers/{idx}": nnx.Linear(
                    in_features=hidden_size, out_features=hidden_size, rngs=rngs
                )
                for idx, hidden_size in enumerate(self.hidden_sizes)
            }
        )

        # Initialize the final layer for unitary parameters
        self.final_layer = nnx.Linear(
            in_features=self.hidden_sizes[-1],
            out_features=self.NUM_UNITARY_PARAMS,
            rngs=rngs,
        )

    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        # Apply the hidden layers with ReLU activation
        for _, layer in self.hidden_layers.items():
            x = nnx.relu(layer(x))

        # Apply the final layer and transform the output
        x = self.final_layer(x)
        x = 2 * jnp.pi * nnx.hard_sigmoid(x)

        return x

__init__

__init__(hidden_sizes: list[int], *, rngs: Rngs) -> None

Parameters:

Name Type Description Default
hidden_sizes list[int]

Each integer in the list is a size of the width of each hidden layer in the shared layers

required
rngs Rngs

Random number generator of nnx.

required
Source code in src/inspeqtor/v1/models/nnx.py
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
def __init__(self, hidden_sizes: list[int], *, rngs: nnx.Rngs) -> None:
    """

    Args:
        hidden_sizes (list[int]): Each integer in the list is a size of the width of each hidden layer in the shared layers
        rngs (nnx.Rngs): Random number generator of `nnx`.
    """
    self.hidden_sizes = hidden_sizes
    self.NUM_UNITARY_PARAMS = 4

    self.hidden_layers = nnx.Dict(
        {
            f"hidden_layers/{idx}": nnx.Linear(
                in_features=hidden_size, out_features=hidden_size, rngs=rngs
            )
            for idx, hidden_size in enumerate(self.hidden_sizes)
        }
    )

    # Initialize the final layer for unitary parameters
    self.final_layer = nnx.Linear(
        in_features=self.hidden_sizes[-1],
        out_features=self.NUM_UNITARY_PARAMS,
        rngs=rngs,
    )

UnitarySPAMModel

Composite class of unitary-based model and the SPAM model.

Source code in src/inspeqtor/v1/models/nnx.py
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
class UnitarySPAMModel(Blackbox):
    """Composite class of unitary-based model and the SPAM model."""

    def __init__(
        self, unitary_model: UnitaryModel, spam_params, *, rngs: nnx.Rngs
    ) -> None:
        """

        Args:
            unitary_model (UnitaryModel): Unitary-based model that have already initialized.
            rngs (nnx.Rngs): Random number generator of `nnx`.
        """
        self.unitary_model = unitary_model
        self.spam_params = jax.tree.map(nnx.Param, spam_params)

    def __call__(self, x: jnp.ndarray) -> dict:
        return {"model_params": self.unitary_model(x), "spam_params": self.spam_params}

__init__

__init__(
    unitary_model: UnitaryModel, spam_params, *, rngs: Rngs
) -> None

Parameters:

Name Type Description Default
unitary_model UnitaryModel

Unitary-based model that have already initialized.

required
rngs Rngs

Random number generator of nnx.

required
Source code in src/inspeqtor/v1/models/nnx.py
169
170
171
172
173
174
175
176
177
178
179
def __init__(
    self, unitary_model: UnitaryModel, spam_params, *, rngs: nnx.Rngs
) -> None:
    """

    Args:
        unitary_model (UnitaryModel): Unitary-based model that have already initialized.
        rngs (nnx.Rngs): Random number generator of `nnx`.
    """
    self.unitary_model = unitary_model
    self.spam_params = jax.tree.map(nnx.Param, spam_params)

wo_predictive_fn

wo_predictive_fn(
    control_parameters: ndarray,
    unitaries: ndarray,
    model: WoModel,
) -> ndarray

Adapter function for \(\hat{W}_{O}\) based model to be used with make_loss_fn.

Parameters:

Name Type Description Default
model WoModel

\(\hat{W}_{O}\) based model

required
data DataBundled

A bundled of data for the predictive model training.

required

Returns:

Type Description
ndarray

jnp.ndarray: Predicted expectation values.

Source code in src/inspeqtor/v1/models/nnx.py
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
def wo_predictive_fn(
    # Input data to the model
    control_parameters: jnp.ndarray,
    unitaries: jnp.ndarray,
    model: WoModel,
) -> jnp.ndarray:
    """Adapter function for $\\hat{W}_{O}$ based model to be used with `make_loss_fn`.

    Args:
        model (WoModel): $\\hat{W}_{O}$ based model
        data (DataBundled): A bundled of data for the predictive model training.

    Returns:
        jnp.ndarray: Predicted expectation values.
    """
    output = model(control_parameters)
    return observable_to_expvals(output, unitaries)

noisy_unitary_predictive_fn

noisy_unitary_predictive_fn(
    control_parameters: ndarray,
    unitaries: ndarray,
    model: UnitaryModel,
) -> ndarray

Adapter function for unitary-based model to be used with make_loss_fn

Parameters:

Name Type Description Default
model UnitaryModel

Unitary-based model.

required
data DataBundled

A bundled of data for the predictive model training.

required

Returns:

Type Description
ndarray

jnp.ndarray: Predicted expectation values.

Source code in src/inspeqtor/v1/models/nnx.py
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
def noisy_unitary_predictive_fn(
    # Input data to the model
    control_parameters: jnp.ndarray,
    unitaries: jnp.ndarray,
    model: UnitaryModel,
) -> jnp.ndarray:
    """Adapter function for unitary-based model to be used with `make_loss_fn`

    Args:
        model (UnitaryModel): Unitary-based model.
        data (DataBundled): A bundled of data for the predictive model training.

    Returns:
        jnp.ndarray: Predicted expectation values.
    """
    unitary_params = model(control_parameters)

    return unitary_to_expvals(unitary_params, unitaries)

toggling_unitary_predictive_fn

toggling_unitary_predictive_fn(
    control_parameters: ndarray,
    unitaries: ndarray,
    model: UnitaryModel,
) -> ndarray

Adapter function for rotating toggling frame unitary based model to be used with make_loss_fn

Parameters:

Name Type Description Default
model UnitaryModel

Unitary-based model.

required
data DataBundled

A bundled of data for the predictive model training.

required

Returns:

Type Description
ndarray

jnp.ndarray: Predicted expectation values.

Source code in src/inspeqtor/v1/models/nnx.py
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
def toggling_unitary_predictive_fn(
    # Input data to the model
    control_parameters: jnp.ndarray,
    unitaries: jnp.ndarray,
    model: UnitaryModel,
) -> jnp.ndarray:
    """Adapter function for rotating toggling frame unitary based model to be used with `make_loss_fn`

    Args:
        model (UnitaryModel): Unitary-based model.
        data (DataBundled): A bundle of data for the predictive model training.

    Returns:
        jnp.ndarray: Predicted expectation values.
    """
    unitary_params = model(control_parameters)

    return toggling_unitary_to_expvals(unitary_params, unitaries)

toggling_unitary_with_spam_predictive_fn

toggling_unitary_with_spam_predictive_fn(
    control_parameters: ndarray,
    unitaries: ndarray,
    model: UnitarySPAMModel,
) -> ndarray

Adapter function for a composite rotating toggling frame unitary based model to be used with make_loss_fn

Parameters:

Name Type Description Default
model UnitaryModel

Unitary-based SPAM model.

required
data DataBundled

A bundle of data for the predictive model training.

required

Returns:

Type Description
ndarray

jnp.ndarray: Predicted expectation values.

Source code in src/inspeqtor/v1/models/nnx.py
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
def toggling_unitary_with_spam_predictive_fn(
    # Input data to the model
    control_parameters: jnp.ndarray,
    unitaries: jnp.ndarray,
    model: UnitarySPAMModel,
) -> jnp.ndarray:
    """Adapter function for a composite rotating toggling
    frame unitary based model to be used with `make_loss_fn`

    Args:
        model (UnitaryModel): Unitary-based SPAM model.
        data (DataBundled): A bundle of data for the predictive model training.

    Returns:
        jnp.ndarray: Predicted expectation values.
    """
    params = model(control_parameters)

    return toggling_unitary_with_spam_to_expvals(
        # {
        #     "model_params": params['model_params'],
        #     "spam_params": model.spam_params,
        # },
        params,
        unitaries,
    )

make_loss_fn_old

make_loss_fn_old(
    predictive_fn,
    calculate_metric_fn=calculate_metric,
    loss_metric: LossMetric = MSEE,
)

A function for preparing loss function to be used for model training.

Parameters:

Name Type Description Default
predictive_fn Any

Adaptor function specifically for each model.

required
calculate_metric_fn Any

Function for calculating metrics. Defaults to calculate_metric.

calculate_metric
loss_metric LossMetric

The chosen loss function to be optimized. Defaults to LossMetric.MSEE.

MSEE
Source code in src/inspeqtor/v1/models/nnx.py
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
@deprecated.deprecated
def make_loss_fn_old(
    predictive_fn,
    calculate_metric_fn=calculate_metric,
    loss_metric: LossMetric = LossMetric.MSEE,
):
    """A function for preparing loss function to be used for model training.

    Args:
        predictive_fn (typing.Any): Adaptor function specifically for each model.
        calculate_metric_fn (typing.Any, optional): Function for calculating metrics. Defaults to calculate_metric.
        loss_metric (LossMetric, optional): The chosen loss function to be optimized. Defaults to LossMetric.MSEE.
    """

    def loss_fn(model: Blackbox, data: DataBundled):
        expval = predictive_fn(data.control_params, data.unitaries, model)

        metrics = calculate_metric_fn(data.unitaries, data.observables, expval)
        # Take mean of all the metrics
        metrics = jax.tree.map(jnp.mean, metrics)
        loss = metrics[loss_metric]

        return loss, metrics

    return loss_fn

make_loss_fn_oldv2

make_loss_fn_oldv2(
    adapter_fn,
    calculate_metric_fn=calculate_metric,
    loss_metric: LossMetric = MSEE,
)

A function for preparing loss function to be used for model training.

Parameters:

Name Type Description Default
adapter_fn Any

Adapter function specifically for each model.

required
calculate_metric_fn Any

Function for calculating metrics. Defaults to calculate_metric.

calculate_metric
loss_metric LossMetric

The chosen loss function to be optimized. Defaults to LossMetric.MSEE.

MSEE
Source code in src/inspeqtor/v1/models/nnx.py
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
@deprecated.deprecated("Use new `make_loss_fn` instead")
def make_loss_fn_oldv2(
    adapter_fn,
    calculate_metric_fn=calculate_metric,
    loss_metric: LossMetric = LossMetric.MSEE,
):
    """A function for preparing loss function to be used for model training.

    Args:
        adapter_fn (typing.Any): Adapter function specifically for each model.
        calculate_metric_fn (typing.Any, optional): Function for calculating metrics. Defaults to calculate_metric.
        loss_metric (LossMetric, optional): The chosen loss function to be optimized. Defaults to LossMetric.MSEE.
    """

    def loss_fn(model: Blackbox, data: DataBundled):
        output = model(data.control_params)

        expval = adapter_fn(output, data.unitaries)

        metrics = calculate_metric_fn(data.unitaries, data.observables, expval)
        # Take mean of all the metrics
        metrics = jax.tree.map(jnp.mean, metrics)
        loss = metrics[loss_metric]

        return loss, metrics

    return loss_fn

make_loss_fn

make_loss_fn(adapter_fn, evaluate_fn)

A function for preparing loss function to be used for model training.

Parameters:

Name Type Description Default
adapter_fn Any

Adapter function specifically for each model.

required
evaluate_fn Callable[[ndarray, ndarray, ndarray], ndarray]

Take in predicted and experimental expectation values and ideal unitary and return loss value

required
Source code in src/inspeqtor/v1/models/nnx.py
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
def make_loss_fn(adapter_fn, evaluate_fn):
    """A function for preparing loss function to be used for model training.

    Args:
        adapter_fn (typing.Any): Adapter function specifically for each model.
        evaluate_fn (typing.Callable[[jnp.ndarray, jnp.ndarray, jnp.ndarray], jnp.ndarray]): Takes the predicted and experimental expectation values and the ideal unitaries, and returns the loss value.
    """

    def loss_fn(model: Blackbox, data: DataBundled):
        output = model(data.control_params)

        expval = adapter_fn(output, data.unitaries)

        loss = evaluate_fn(expval, data.observables, data.unitaries)

        return jnp.mean(loss), {}

    return loss_fn

create_step

create_step(
    loss_fn: Callable[
        [Blackbox, DataBundled], tuple[ndarray, Any]
    ],
)

A function to create the training and evaluating step for model. The train step will update the model parameters and optimizer parameters in place.

Parameters:

Name Type Description Default
loss_fn Callable[[Blackbox, DataBundled], tuple[ndarray, Any]]

Loss function returned from make_loss_fn

required

Returns:

Type Description

typing.Any: The tuple of training and eval step functions.

Source code in src/inspeqtor/v1/models/nnx.py
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
def create_step(
    loss_fn: typing.Callable[[Blackbox, DataBundled], tuple[jnp.ndarray, typing.Any]],
):
    """A function to create the traning and evaluating step for model.
    The train step will update the model parameters and optimizer parameters inplace.

    Args:
        loss_fn (typing.Callable[[Blackbox, DataBundled], tuple[jnp.ndarray, typing.Any]]): Loss function returned from `make_loss_fn`

    Returns:
        typing.Any: The tuple of training and eval step functions.
    """

    @nnx.jit
    def train_step(
        model: Blackbox,
        optimizer: nnx.Optimizer,
        metrics: nnx.MultiMetric,
        data: DataBundled,
    ):
        """Train for a single step."""
        model.train()  # Switch to train mode
        grad_fn = nnx.value_and_grad(loss_fn, has_aux=True)
        (loss, aux), grads = grad_fn(model, data)
        metrics.update(loss=loss)  # In-place updates.
        optimizer.update(model, grads)  # In-place updates.

        return loss, aux

    @nnx.jit
    def eval_step(model: Blackbox, metrics: nnx.MultiMetric, data):
        model.eval()
        loss, aux = loss_fn(model, data)
        metrics.update(loss=loss)  # In-place updates.

        return loss, aux

    return train_step, eval_step

reconstruct_model

reconstruct_model(
    model_params, config, Model: type[T]
) -> T

Reconstruct the model from the model parameters, config, and model initializer.

Examples:

>>> _, state = nnx.split(blackbox)
>>> model_params = nnx.to_pure_dict(state)
>>> config = {
...    "shared_layers": [8],
...    "pauli_layers": [8]
... }
>>> model_data = sq.model.ModelData(params=model_params, config=config)
# save and load to and from disk!
>>> blackbox = sq.models.nnx.reconstruct_model(model_data.params, model_data.config, sq.models.nnx.WoModel)

Parameters:

Name Type Description Default
model_params Any

The pytree containing model parameters.

required
config Any

The model configuration for model initialization.

required
Model type[T]

The model initializer.

required

Returns:

Name Type Description
T T

The reconstructed model instance.

Source code in src/inspeqtor/v1/models/nnx.py
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
def reconstruct_model(model_params, config, Model: type[T]) -> T:
    """Reconstruct the model from the model parameters, config, and model initializer.

    Examples:
        >>> _, state = nnx.split(blackbox)
        >>> model_params = nnx.to_pure_dict(state)
        >>> config = {
        ...    "shared_layers": [8],
        ...    "pauli_layers": [8]
        ... }
        >>> model_data = sq.model.ModelData(params=model_params, config=config)
        # save and load to and from disk!
        >>> blackbox = sq.models.nnx.reconstruct_model(model_data.params, model_data.config, sq.models.nnx.WoModel)

    Args:
        model_params (typing.Any): The pytree containing model parameters.
        config (typing.Any): The model configuration for model initialization.
        Model (type[T]): The model initializer.

    Returns:
        T: The reconstructed model instance.
    """
    abstract_model = nnx.eval_shape(lambda: Model(**config, rngs=nnx.Rngs(0)))
    graphdef, abstract_state = nnx.split(abstract_model)
    nnx.replace_by_pure_dict(abstract_state, model_params)

    return nnx.merge(graphdef, abstract_state)

train_model

train_model(
    key: ndarray,
    train_data: DataBundled,
    val_data: DataBundled,
    test_data: DataBundled,
    model: Blackbox,
    optimizer: GradientTransformation,
    loss_fn: Callable,
    callbacks: list[Callable] = [],
    NUM_EPOCH: int = 1000,
    _optimizer: Optimizer | None = None,
)

Train the BlackBox model

Examples:

>>> # The number of epochs break down
... NUM_EPOCH = 150
... # Total number of iterations as 90% of data is used for training
... # 10% of the data is used for testing
... total_iterations = 9 * NUM_EPOCH
... # The step for optimizer is set to 8 * NUM_EPOCH (should be less than total_iterations)
... step_for_optimizer = 8 * NUM_EPOCH
... optimizer = get_default_optimizer(step_for_optimizer)
... # The warmup steps for the optimizer
... warmup_steps = 0.1 * step_for_optimizer
... # The cool down steps for the optimizer
... cool_down_steps = total_iterations - step_for_optimizer
... total_iterations, step_for_optimizer, warmup_steps, cool_down_steps

Parameters:

Name Type Description Default
key ndarray

Random key

required
model Module

The model to be used for training

required
optimizer GradientTransformation

The optimizer to be used for training

required
loss_fn Callable

The loss function to be used for training

required
callbacks list[Callable]

list of callback functions. Defaults to [].

[]
NUM_EPOCH int

The number of epochs. Defaults to 1_000.

1000

Returns:

Name Type Description
tuple

The model parameters, optimizer state, and the histories

Source code in src/inspeqtor/v1/models/nnx.py
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
def train_model(
    # Random key
    key: jnp.ndarray,
    # Data
    train_data: DataBundled,
    val_data: DataBundled,
    test_data: DataBundled,
    # Model to be used for training
    model: Blackbox,
    optimizer: optax.GradientTransformation,
    # Loss function to be used
    loss_fn: typing.Callable,
    # Callbacks to be used
    callbacks: list[typing.Callable] = [],
    # Number of epochs
    NUM_EPOCH: int = 1_000,
    _optimizer: nnx.Optimizer | None = None,
):
    """Train the BlackBox model

    Examples:
        >>> # The number of epochs break down
        ... NUM_EPOCH = 150
        ... # Total number of iterations as 90% of data is used for training
        ... # 10% of the data is used for testing
        ... total_iterations = 9 * NUM_EPOCH
        ... # The step for optimizer is set to 8 * NUM_EPOCH (should be less than total_iterations)
        ... step_for_optimizer = 8 * NUM_EPOCH
        ... optimizer = get_default_optimizer(step_for_optimizer)
        ... # The warmup steps for the optimizer
        ... warmup_steps = 0.1 * step_for_optimizer
        ... # The cool down steps for the optimizer
        ... cool_down_steps = total_iterations - step_for_optimizer
        ... total_iterations, step_for_optimizer, warmup_steps, cool_down_steps

    Args:
        key (jnp.ndarray): Random key
        model (nn.Module): The model to be used for training
        optimizer (optax.GradientTransformation): The optimizer to be used for training
        loss_fn (typing.Callable): The loss function to be used for training
        callbacks (list[typing.Callable], optional): list of callback functions. Defaults to [].
        NUM_EPOCH (int, optional): The number of epochs. Defaults to 1_000.

    Returns:
        tuple: The model parameters, optimizer state, and the histories
    """

    key, loader_key = jax.random.split(key)

    BATCH_SIZE = val_data.control_params.shape[0]

    histories: list[HistoryEntryV3] = []

    if _optimizer is None:
        _optimizer = nnx.Optimizer(
            model,
            optimizer,
            wrt=nnx.Param,
        )

    metrics = nnx.MultiMetric(
        loss=nnx.metrics.Average("loss"),
    )

    train_step, eval_step = create_step(loss_fn=loss_fn)

    for (step, batch_idx, is_last_batch, epoch_idx), (
        batch_p,
        batch_u,
        batch_ex,
    ) in dataloader(
        (
            train_data.control_params,
            train_data.unitaries,
            train_data.observables,
        ),
        batch_size=BATCH_SIZE,
        num_epochs=NUM_EPOCH,
        key=loader_key,
    ):
        train_step(
            model,
            _optimizer,
            metrics,
            DataBundled(
                control_params=batch_p, unitaries=batch_u, observables=batch_ex
            ),
        )

        histories.append(
            HistoryEntryV3(
                step=step, loss=metrics.compute()["loss"], loop="train", aux={}
            )
        )
        metrics.reset()  # Reset the metrics for the train set.

        if is_last_batch:
            # Validation
            eval_step(model, metrics, val_data)
            histories.append(
                HistoryEntryV3(
                    step=step, loss=metrics.compute()["loss"], loop="val", aux={}
                )
            )
            metrics.reset()  # Reset the metrics for the val set.
            # Testing
            eval_step(model, metrics, test_data)
            histories.append(
                HistoryEntryV3(
                    step=step, loss=metrics.compute()["loss"], loop="test", aux={}
                )
            )
            metrics.reset()  # Reset the metrics for the test set.

            for callback in callbacks:
                callback(model, _optimizer, histories)

    return model, _optimizer, histories

Optimize

src.inspeqtor.v1.optimize

get_default_optimizer

get_default_optimizer(
    n_iterations: int,
) -> GradientTransformation

Generate present optimizer from number of training iteration.

Parameters:

Name Type Description Default
n_iterations int

Training iteration

required

Returns:

Type Description
GradientTransformation

optax.GradientTransformation: Optax optimizer.

Source code in src/inspeqtor/v1/optimize.py
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
def get_default_optimizer(n_iterations: int) -> optax.GradientTransformation:
    """Generate present optimizer from number of training iteration.

    Args:
        n_iterations (int): Training iteration

    Returns:
        optax.GradientTransformation: Optax optimizer.
    """
    return optax.adamw(
        learning_rate=optax.warmup_cosine_decay_schedule(
            init_value=1e-6,
            peak_value=1e-2,
            warmup_steps=int(0.1 * n_iterations),
            decay_steps=n_iterations,
            end_value=1e-6,
        )
    )

minimize

minimize(
    params: ArrayTree,
    func: Callable[[ndarray], tuple[ndarray, Any]],
    optimizer: GradientTransformation,
    lower: ArrayTree | None = None,
    upper: ArrayTree | None = None,
    maxiter: int = 1000,
    callbacks: list[Callable] = [],
) -> tuple[ArrayTree, list[Any]]

Optimize the loss function with bounded parameters.

Parameters:

Name Type Description Default
params ArrayTree

Initial parameters to be optimized

required
lower ArrayTree

Lower bound of the parameters

None
upper ArrayTree

Upper bound of the parameters

None
func Callable[[ndarray], tuple[ndarray, Any]]

Loss function

required
optimizer GradientTransformation

Instance of optax optimizer

required
maxiter int

Number of optimization step. Defaults to 1000.

1000

Returns:

Type Description
tuple[ArrayTree, list[Any]]

tuple[chex.ArrayTree, list[typing.Any]]: Tuple of parameters and optimization history

Source code in src/inspeqtor/v1/optimize.py
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
def minimize(
    params: chex.ArrayTree,
    func: typing.Callable[[jnp.ndarray], tuple[jnp.ndarray, typing.Any]],
    optimizer: optax.GradientTransformation,
    lower: chex.ArrayTree | None = None,
    upper: chex.ArrayTree | None = None,
    maxiter: int = 1000,
    callbacks: list[typing.Callable] = [],
) -> tuple[chex.ArrayTree, list[typing.Any]]:
    """Optimize the loss function with bounded parameters.

    Args:
        params (chex.ArrayTree): Initial parameters to be optimized
        lower (chex.ArrayTree): Lower bound of the parameters
        upper (chex.ArrayTree): Upper bound of the parameters
        func (typing.Callable[[jnp.ndarray], tuple[jnp.ndarray, typing.Any]]): Loss function
        optimizer (optax.GradientTransformation): Instance of optax optimizer
        maxiter (int, optional): Number of optimization step. Defaults to 1000.

    Returns:
        tuple[chex.ArrayTree, list[typing.Any]]: Tuple of parameters and optimization history
    """
    opt_state = optimizer.init(params)
    history = []

    for step_idx in range(maxiter):
        grads, aux = jax.grad(func, has_aux=True)(params)
        updates, opt_state = optimizer.update(grads, opt_state, params)
        params = optax.apply_updates(params, updates)

        if lower is not None and upper is not None:
            # Apply projection
            params = optax.projections.projection_box(params, lower, upper)

        # Log the history
        aux["params"] = params
        history.append(aux)

        for callback in callbacks:
            callback(step_idx, aux)

    return params, history

stochastic_minimize

stochastic_minimize(
    key: ndarray,
    params: ArrayTree,
    func: Callable[[ndarray, ndarray], tuple[ndarray, Any]],
    optimizer: GradientTransformation,
    lower: ArrayTree | None = None,
    upper: ArrayTree | None = None,
    maxiter: int = 1000,
    callbacks: list[Callable] = [],
) -> tuple[ArrayTree, list[Any]]

Optimize the loss function with bounded parameters.

Parameters:

Name Type Description Default
params ArrayTree

Initial parameters to be optimized

required
lower ArrayTree

Lower bound of the parameters

None
upper ArrayTree

Upper bound of the parameters

None
func Callable[[ndarray], tuple[ndarray, Any]]

Loss function

required
optimizer GradientTransformation

Instance of optax optimizer

required
maxiter int

Number of optimization step. Defaults to 1000.

1000

Returns:

Type Description
tuple[ArrayTree, list[Any]]

tuple[chex.ArrayTree, list[typing.Any]]: Tuple of parameters and optimization history

Source code in src/inspeqtor/v1/optimize.py
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
def stochastic_minimize(
    key: jnp.ndarray,
    params: chex.ArrayTree,
    func: typing.Callable[[jnp.ndarray, jnp.ndarray], tuple[jnp.ndarray, typing.Any]],
    optimizer: optax.GradientTransformation,
    lower: chex.ArrayTree | None = None,
    upper: chex.ArrayTree | None = None,
    maxiter: int = 1000,
    callbacks: list[typing.Callable] = [],
) -> tuple[chex.ArrayTree, list[typing.Any]]:
    """Optimize the loss function with bounded parameters.

    Args:
        params (chex.ArrayTree): Initial parameters to be optimized
        lower (chex.ArrayTree): Lower bound of the parameters
        upper (chex.ArrayTree): Upper bound of the parameters
        func (typing.Callable[[jnp.ndarray], tuple[jnp.ndarray, typing.Any]]): Loss function
        optimizer (optax.GradientTransformation): Instance of optax optimizer
        maxiter (int, optional): Number of optimization step. Defaults to 1000.

    Returns:
        tuple[chex.ArrayTree, list[typing.Any]]: Tuple of parameters and optimization history
    """
    opt_state = optimizer.init(params)
    history = []

    for step_idx in range(maxiter):
        key, _ = jax.random.split(key)
        grads, aux = jax.grad(func, has_aux=True)(params, key)
        updates, opt_state = optimizer.update(grads, opt_state, params)
        params = optax.apply_updates(params, updates)

        if lower is not None and upper is not None:
            # Apply projection
            params = optax.projections.projection_box(params, lower, upper)

        # Log the history
        aux["params"] = params
        history.append(aux)

        for callback in callbacks:
            callback(step_idx, aux)

    return params, history

Physics

src.inspeqtor.v1.physics

normalizer

normalizer(matrix: ndarray) -> ndarray

Normalize the given matrix with QR decomposition and return matrix Q which is unitary

Parameters:

Name Type Description Default
matrix ndarray

The matrix to normalize to unitary matrix

required

Returns:

Type Description
ndarray

jnp.ndarray: The unitary matrix

Source code in src/inspeqtor/v1/physics.py
81
82
83
84
85
86
87
88
89
90
91
def normalizer(matrix: jnp.ndarray) -> jnp.ndarray:
    """Normalize the given matrix with QR decomposition and return matrix Q
       which is unitary

    Args:
        matrix (jnp.ndarray): The matrix to normalize to unitary matrix

    Returns:
        jnp.ndarray: The unitary matrix
    """
    return jnp.linalg.qr(matrix).Q  # type: ignore

solver

solver(
    args: HamiltonianArgs,
    t_eval: ndarray,
    hamiltonian: Callable[
        [HamiltonianArgs, ndarray], ndarray
    ],
    y0: ndarray,
    t0: float,
    t1: float,
    rtol: float = 1e-07,
    atol: float = 1e-07,
    max_steps: int = int(2**16),
) -> ndarray

Solve the Schrodinger equation using the given Hamiltonian

Parameters:

Name Type Description Default
args HamiltonianArgs

The arguments for the Hamiltonian

required
t_eval ndarray

The time points to evaluate the solution

required
hamiltonian Callable[[HamiltonianArgs, ndarray], ndarray]

The Hamiltonian function

required
y0 ndarray

The initial state, set to jnp.eye(2, dtype=jnp.complex128) for unitary matrix

required
t0 float

The initial time

required
t1 float

The final time

required
rtol float

Relative tolerance of the ODE solver. Defaults to 1e-7.

1e-07
atol float

Absolute tolerance of the ODE solver. Defaults to 1e-7.

1e-07
max_steps int

The maximum number of solver steps. Defaults to int(2**16).

int(2 ** 16)

Returns:

Type Description
ndarray

jnp.ndarray: The solution of the Schrodinger equation at the given time points

Source code in src/inspeqtor/v1/physics.py
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
def solver(
    args: HamiltonianArgs,
    t_eval: jnp.ndarray,
    hamiltonian: typing.Callable[[HamiltonianArgs, jnp.ndarray], jnp.ndarray],
    y0: jnp.ndarray,
    t0: float,
    t1: float,
    rtol: float = 1e-7,
    atol: float = 1e-7,
    max_steps: int = int(2**16),
) -> jnp.ndarray:
    """Solve the Schrodinger equation using the given Hamiltonian

    Args:
        args (HamiltonianArgs): The arguments for the Hamiltonian
        t_eval (jnp.ndarray): The time points to evaluate the solution
        hamiltonian (typing.Callable[[HamiltonianArgs, jnp.ndarray], jnp.ndarray]): The Hamiltonian function
        y0 (jnp.ndarray): The initial state, set to jnp.eye(2, dtype=jnp.complex128) for unitary matrix
        t0 (float): The initial time
        t1 (float): The final time
        rtol (float, optional): Relative tolerance of the ODE solver. Defaults to 1e-7.
        atol (float, optional): Absolute tolerance of the ODE solver. Defaults to 1e-7.
        max_steps (int, optional): The maximum number of solver steps. Defaults to int(2**16).

    Returns:
        jnp.ndarray: The solution of the Schrodinger equation at the given time points
    """

    # * Increase time_step to increase accuracy of solver,
    # *     then you have to increase the max_steps too.
    # * Using just a basic solver
    def rhs(t: jnp.ndarray, y: jnp.ndarray, args: HamiltonianArgs):
        return -1j * hamiltonian(args, t) @ y

    term = diffrax.ODETerm(rhs)  # type: ignore
    solver = diffrax.Tsit5()

    solution = diffrax.diffeqsolve(
        term,
        solver,
        t0=t0,
        t1=t1,
        dt0=None,
        stepsize_controller=diffrax.PIDController(
            rtol=rtol,
            atol=atol,
        ),
        y0=y0,
        args=args,
        saveat=diffrax.SaveAt(ts=t_eval),
        max_steps=max_steps,
    )

    # Normalize the solution
    ys = solution.ys
    # assert isinstance(ys, jnp.ndarray)

    return jax.vmap(normalizer)(ys)  # type: ignore

auto_rotating_frame_hamiltonian

auto_rotating_frame_hamiltonian(
    hamiltonian: Callable[
        [HamiltonianArgs, ndarray], ndarray
    ],
    frame: ndarray,
    explicit_deriv: bool = False,
)

Implement the Hamiltonian in the rotating frame with H_I = U(t) @ H @ U^dagger(t) + i * U(t) @ dU^dagger(t)/dt

Parameters:

Name Type Description Default
hamiltonian Callable

The hamiltonian function

required
frame ndarray

The frame matrix

required
Source code in src/inspeqtor/v1/physics.py
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
def auto_rotating_frame_hamiltonian(
    hamiltonian: typing.Callable[[HamiltonianArgs, jnp.ndarray], jnp.ndarray],
    frame: jnp.ndarray,
    explicit_deriv: bool = False,
):
    """Implement the Hamiltonian in the rotating frame with
    H_I = U(t) @ H @ U^dagger(t) + i * U(t) @ dU^dagger(t)/dt

    Args:
        hamiltonian (Callable): The hamiltonian function
        frame (jnp.ndarray): The frame matrix
    """

    is_diagonal = False
    # Check if the frame is diagonal matrix
    if jnp.count_nonzero(frame - jnp.diag(jnp.diagonal(frame))) == 0:
        is_diagonal = True

    # Check if the jax_enable_x64 is True
    if not jax.config.read("jax_enable_x64") or is_diagonal:

        def frame_unitary(t: jnp.ndarray) -> jnp.ndarray:
            # NOTE: This is the same as the below, as we sure that the frame is diagonal
            return jnp.diag(jnp.exp(1j * jnp.diagonal(frame) * t))

    else:

        def frame_unitary(t: jnp.ndarray) -> jnp.ndarray:
            return jax.scipy.linalg.expm(1j * frame * t)

    def derivative_frame_unitary(t: jnp.ndarray) -> jnp.ndarray:
        # NOTE: Assume that the frame is time independent.
        return 1j * frame @ frame_unitary(t)

    def rotating_frame_hamiltonian_v0(args: HamiltonianArgs, t: jnp.ndarray):
        return frame_unitary(t) @ hamiltonian(args, t) @ jnp.transpose(
            jnp.conjugate(frame_unitary(t))
        ) + 1j * (
            derivative_frame_unitary(t) @ jnp.transpose(jnp.conjugate(frame_unitary(t)))
        )

    def rotating_frame_hamiltonian(args: HamiltonianArgs, t: jnp.ndarray):
        # NOTE: Assume that the product of derivative and conjugate of frame unitary is identity
        return (
            frame_unitary(t)
            @ hamiltonian(args, t)
            @ jnp.transpose(jnp.conjugate(frame_unitary(t)))
            - frame
        )

    return (
        rotating_frame_hamiltonian_v0 if explicit_deriv else rotating_frame_hamiltonian
    )

explicit_auto_rotating_frame_hamiltonian

explicit_auto_rotating_frame_hamiltonian(
    hamiltonian: Callable[
        [HamiltonianArgs, ndarray], ndarray
    ],
    frame: ndarray,
)

Implement the Hamiltonian in the rotating frame with H_I = U(t) @ H @ U^dagger(t) + i * U(t) @ dU^dagger(t)/dt

Note

This is the implementation of auto_rotating_frame_hamiltonian that perform explicit derivative.

Parameters:

Name Type Description Default
hamiltonian Callable

The hamiltonian function

required
frame ndarray

The frame matrix

required
Source code in src/inspeqtor/v1/physics.py
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
def explicit_auto_rotating_frame_hamiltonian(
    hamiltonian: typing.Callable[[HamiltonianArgs, jnp.ndarray], jnp.ndarray],
    frame: jnp.ndarray,
):
    """Implement the Hamiltonian in the rotating frame with
    H_I = U(t) @ H @ U^dagger(t) + i * U(t) @ dU^dagger(t)/dt

    Note:
        This is the implementation of `auto_rotating_frame_hamiltonian`
        that perform explicit derivative.

    Args:
        hamiltonian (Callable): The hamiltonian function
        frame (jnp.ndarray): The frame matrix
    """

    def frame_unitary(t: jnp.ndarray) -> jnp.ndarray:
        return jax.scipy.linalg.expm(1j * frame * t)

    def derivative_frame_unitary(t: jnp.ndarray) -> jnp.ndarray:
        # NOTE: Assume that the frame is time independent.
        return 1j * frame @ frame_unitary(t)

    def rotating_frame_hamiltonian_v0(args: HamiltonianArgs, t: jnp.ndarray):
        return frame_unitary(t) @ hamiltonian(args, t) @ jnp.transpose(
            jnp.conjugate(frame_unitary(t))
        ) + 1j * (
            derivative_frame_unitary(t) @ jnp.transpose(jnp.conjugate(frame_unitary(t)))
        )

    return rotating_frame_hamiltonian_v0

a

a(dims: int) -> ndarray

Annihilation operator of given dims

Parameters:

Name Type Description Default
dims int

Number of states

required

Returns:

Type Description
ndarray

jnp.ndarray: Annihilation operator

Source code in src/inspeqtor/v1/physics.py
242
243
244
245
246
247
248
249
250
251
def a(dims: int) -> jnp.ndarray:
    """Annihilation operator for a `dims`-level system

    Args:
        dims (int): Number of states

    Returns:
        jnp.ndarray: Annihilation operator
    """
    # Lowering amplitudes sqrt(1), ..., sqrt(dims - 1) on the first superdiagonal.
    amplitudes = jnp.sqrt(jnp.arange(1, dims))
    return jnp.diag(amplitudes, k=1)

a_dag

a_dag(dims: int) -> ndarray

Creation operator of given dims

Parameters:

Name Type Description Default
dims int

Number of states

required

Returns:

Type Description
ndarray

jnp.ndarray: Creation operator

Source code in src/inspeqtor/v1/physics.py
254
255
256
257
258
259
260
261
262
263
def a_dag(dims: int) -> jnp.ndarray:
    """Creation operator for a `dims`-level system

    Args:
        dims (int): Number of states

    Returns:
        jnp.ndarray: Creation operator
    """
    # Raising amplitudes sqrt(1), ..., sqrt(dims - 1) on the first subdiagonal.
    amplitudes = jnp.sqrt(jnp.arange(1, dims))
    return jnp.diag(amplitudes, k=-1)

N

N(dims: int) -> ndarray

Number operator of given dims

Parameters:

Name Type Description Default
dims int

Number of states

required

Returns:

Type Description
ndarray

jnp.ndarray: Number operator

Source code in src/inspeqtor/v1/physics.py
266
267
268
269
270
271
272
273
274
275
def N(dims: int) -> jnp.ndarray:
    """Number operator for a `dims`-level system

    Args:
        dims (int): Number of states

    Returns:
        jnp.ndarray: Number operator
    """
    # Diagonal of occupation numbers 0, 1, ..., dims - 1.
    occupations = jnp.arange(dims)
    return jnp.diag(occupations)

gen_hamiltonian_from

gen_hamiltonian_from(
    qubit_informations: list[QubitInformation],
    coupling_constants: list[CouplingInformation],
    dims: int = 2,
) -> dict[str, HamiltonianTerm]

Generate dict of Hamiltonian from given qubits and coupling information.

Parameters:

Name Type Description Default
qubit_informations list[QubitInformation]

Qubit information

required
coupling_constants list[CouplingInformation]

Coupling information

required
dims int

The level of the quantum system. Defaults to 2, i.e. qubit system.

2

Returns:

Type Description
dict[str, HamiltonianTerm]

dict[str, HamiltonianTerm]: Mapping from channel-id hash to Hamiltonian term.

Source code in src/inspeqtor/v1/physics.py
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
def gen_hamiltonian_from(
    qubit_informations: list[QubitInformation],
    coupling_constants: list[CouplingInformation],
    dims: int = 2,
) -> dict[str, HamiltonianTerm]:
    """Generate dict of Hamiltonian terms from the given qubits and coupling information.

    Args:
        qubit_informations (list[QubitInformation]): Qubit information
        coupling_constants (list[CouplingInformation]): Coupling information
        dims (int, optional): The level of the quantum system. Defaults to 2, i.e. qubit system.

    Returns:
        dict[str, HamiltonianTerm]: Mapping from channel-id hash to Hamiltonian term.
    """
    num_qubits = len(qubit_informations)
    two_pi = 2 * jnp.pi
    terms: dict[str, HamiltonianTerm] = {}

    def register(qubit_idx, term_type, controlable, operator):
        # Every term is keyed by the hash of its channel identifier.
        key = ChannelID(qubit_idx=qubit_idx, type=term_type).hash()
        terms[key] = HamiltonianTerm(
            qubit_idx=qubit_idx,
            type=term_type,
            controlable=controlable,
            operator=operator,
        )

    for position, qubit in enumerate(qubit_informations):
        # Static (qubit frequency) term.
        static_op = two_pi * qubit.frequency * (jnp.eye(dims) - 2 * N(dims)) / 2
        register(
            qubit.qubit_idx,
            TermType.STATIC,
            False,
            tensor(static_op, position, num_qubits),
        )

        # Anharmonicity term.
        anharmonic_op = two_pi * qubit.anharmonicity * (N(dims) @ N(dims) - N(dims)) / 2
        register(
            qubit.qubit_idx,
            TermType.ANHAMONIC,
            False,
            tensor(anharmonic_op, position, num_qubits),
        )

        # Drive term (driven at this qubit's own frequency).
        drive_op = two_pi * qubit.drive_strength * (a(dims) + a_dag(dims))
        register(
            qubit.qubit_idx,
            TermType.DRIVE,
            True,
            tensor(drive_op, position, num_qubits),
        )

        # Control term (driven at another qubit's frequency; same operator form as drive).
        control_op = two_pi * qubit.drive_strength * (a(dims) + a_dag(dims))
        register(
            qubit.qubit_idx,
            TermType.CONTROL,
            True,
            tensor(control_op, position, num_qubits),
        )

    for coupling in coupling_constants:
        first, second = coupling.qubit_indices[0], coupling.qubit_indices[1]
        # Exchange coupling: a_i a_j^dag + a_i^dag a_j.
        lower_raise = tensor(a(dims), first, num_qubits) @ tensor(
            a_dag(dims), second, num_qubits
        )
        raise_lower = tensor(a_dag(dims), first, num_qubits) @ tensor(
            a(dims), second, num_qubits
        )
        coupling_op = two_pi * coupling.coupling_strength * (lower_raise + raise_lower)
        register((first, second), TermType.COUPLING, False, coupling_op)

    return terms

hamiltonian_fn

hamiltonian_fn(
    args: dict[str, SignalParameters],
    t: ndarray,
    signals: dict[
        str, Callable[[SignalParameters, ndarray], ndarray]
    ],
    hamiltonian_terms: dict[str, HamiltonianTerm],
    static_terms: list[str],
) -> ndarray

Hamiltonian function to be used whitebox. Expect to be used in partial form, i.e. making signals, hamiltonian_terms, and static_terms static arguments.

Parameters:

Name Type Description Default
args dict[str, SignalParameters]

Control parameter

required
t ndarray

Time to evaluate

required
signals dict[str, Callable[[SignalParameters, ndarray], ndarray]]

Signal function of the control

required
hamiltonian_terms dict[str, HamiltonianTerm]

Dict of Hamiltonian terms, where key is channel

required
static_terms list[str]

List of channel ids specifying the static terms.

required

Returns:

Type Description
ndarray

jnp.ndarray: The total Hamiltonian evaluated at time t.

Source code in src/inspeqtor/v1/physics.py
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
def hamiltonian_fn(
    args: dict[str, SignalParameters],
    t: jnp.ndarray,
    signals: dict[str, typing.Callable[[SignalParameters, jnp.ndarray], jnp.ndarray]],
    hamiltonian_terms: dict[str, HamiltonianTerm],
    static_terms: list[str],
) -> jnp.ndarray:
    """Hamiltonian function to be used as a whitebox.
    Expected to be used in partial form, i.e. making `signals`,
    `hamiltonian_terms`, and `static_terms` static arguments.

    Args:
        args (dict[str, SignalParameters]): Control parameters keyed by channel id
        t (jnp.ndarray): Time to evaluate
        signals (dict[str, typing.Callable[[SignalParameters, jnp.ndarray], jnp.ndarray]]): Signal function of the control
        hamiltonian_terms (dict[str, HamiltonianTerm]): Dict of Hamiltonian terms, where key is channel id
        static_terms (list[str]): List of channel ids specifying the static terms

    Returns:
        jnp.ndarray: The total Hamiltonian evaluated at time t.
    """
    # Controlled terms: each channel's signal amplitude modulates its operator.
    controlled = [
        fn(args[channel], t) * hamiltonian_terms[channel].operator
        for channel, fn in signals.items()
    ]

    # Static terms contribute their bare operators.
    static = [hamiltonian_terms[channel].operator for channel in static_terms]

    return jnp.sum(jnp.array(controlled), axis=0) + jnp.sum(jnp.array(static), axis=0)

signal_func_v3

signal_func_v3(
    get_envelope: Callable,
    drive_frequency: float,
    dt: float,
)

Make the envelope function into signal with drive frequency

Parameters:

Name Type Description Default
get_envelope Callable

The envelope function in unit of dt

required
drive_frequency float

drive frequency in units of GHz

required
dt float

The dt provided will be used to convert the envelope unit to ns; set to 1 if the envelope function is already in units of ns

required
Source code in src/inspeqtor/v1/physics.py
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
def signal_func_v3(get_envelope: typing.Callable, drive_frequency: float, dt: float):
    """Turn an envelope function into a carrier-modulated signal

    Args:
        get_envelope (Callable): The envelope function in units of dt
        drive_frequency (float): Drive frequency in units of GHz
        dt (float): The dt provided is used to convert the envelope argument to ns;
                    set to 1 if the envelope function is already in units of ns
    """
    angular_frequency = 2 * jnp.pi * drive_frequency

    def signal(pulse_parameters: SignalParameters, t: jnp.ndarray):
        envelope = get_envelope(pulse_parameters.pulse_params)(t / dt)
        carrier = jnp.exp(1j * (angular_frequency * t + pulse_parameters.phase))
        # Only the real quadrature reaches the hardware.
        return jnp.real(envelope * carrier)

    return signal

signal_func_v4

signal_func_v4(
    get_envelope: Callable,
    drive_frequency: float,
    dt: float,
)

Make the envelope function into signal with drive frequency

Parameters:

Name Type Description Default
get_envelope Callable

The envelope function in unit of dt

required
drive_frequency float

drive frequency in units of GHz

required
dt float

The dt provided will be used to convert the envelope unit to ns; set to 1 if the envelope function is already in units of ns

required
Source code in src/inspeqtor/v1/physics.py
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
def signal_func_v4(get_envelope: typing.Callable, drive_frequency: float, dt: float):
    """Turn an envelope function into a carrier-modulated signal

    Args:
        get_envelope (Callable): The envelope function in units of dt
        drive_frequency (float): Drive frequency in units of GHz
        dt (float): The dt provided is used to convert the envelope argument to ns;
                    set to 1 if the envelope function is already in units of ns
    """
    angular_frequency = 2 * jnp.pi * drive_frequency

    def signal(control_param: SignalParametersV2, t: jnp.ndarray):
        envelope = get_envelope(control_param.pulse_params)(t / dt)
        carrier = jnp.exp(1j * (angular_frequency * t + control_param.phase))
        # Only the real quadrature reaches the hardware.
        return jnp.real(envelope * carrier)

    return signal

make_signal_fn

make_signal_fn(
    get_envelope: Callable[
        [ControlParam], Callable[[ndarray], ndarray]
    ],
    drive_frequency: float,
    dt: float,
)

Make the envelope function into signal with drive frequency

Parameters:

Name Type Description Default
get_envelope Callable

The envelope function in unit of dt

required
drive_frequency float

drive frequency in units of GHz

required
dt float

The dt provided will be used to convert the envelope unit to ns; set to 1 if the envelope function is already in units of ns

required
Source code in src/inspeqtor/v1/physics.py
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
def make_signal_fn(
    get_envelope: typing.Callable[
        [ControlParam], typing.Callable[[jnp.ndarray], jnp.ndarray]
    ],
    drive_frequency: float,
    dt: float,
):
    """Build a real-valued signal by modulating an envelope with a carrier

    Args:
        get_envelope (Callable): The envelope function in units of dt
        drive_frequency (float): Drive frequency in units of GHz
        dt (float): The dt provided is used to convert the envelope argument to ns;
                    set to 1 if the envelope function is already in units of ns
    """
    angular_frequency = 2 * jnp.pi * drive_frequency

    def signal(control_parameters: ControlParam, t: jnp.ndarray):
        envelope_value = get_envelope(control_parameters)(t / dt)
        # Only the real quadrature reaches the hardware.
        return jnp.real(envelope_value * jnp.exp(1j * (angular_frequency * t)))

    return signal

gate_fidelity

gate_fidelity(U: ndarray, V: ndarray) -> ndarray

Calculate the gate fidelity between U and V

Parameters:

Name Type Description Default
U ndarray

Unitary operator to be targetted

required
V ndarray

Unitary operator to be compared

required

Returns:

Type Description
ndarray

jnp.ndarray: Gate fidelity

Source code in src/inspeqtor/v1/physics.py
493
494
495
496
497
498
499
500
501
502
503
504
505
506
def gate_fidelity(U: jnp.ndarray, V: jnp.ndarray) -> jnp.ndarray:
    """Compute the gate fidelity between unitaries U and V

    Args:
        U (jnp.ndarray): Target unitary operator
        V (jnp.ndarray): Unitary operator to compare against

    Returns:
        jnp.ndarray: Gate fidelity
    """
    # |tr(U^dag V)|^2 normalized so that F(U, U) == 1.
    overlap = jnp.trace(U.conj().T @ V)
    normalization = jnp.sqrt(jnp.trace(U.conj().T @ U) * jnp.trace(V.conj().T @ V))
    return jnp.abs(overlap / normalization) ** 2

check_valid_density_matrix

check_valid_density_matrix(rho: ndarray)

Check if the provided matrix is valid density matrix

Parameters:

Name Type Description Default
rho ndarray

The density matrix to validate.

required
Source code in src/inspeqtor/v1/physics.py
532
533
534
535
536
537
538
539
540
def check_valid_density_matrix(rho: jnp.ndarray):
    """Assert that the provided matrix is a valid density matrix
    (unit trace and Hermitian)

    Args:
        rho (jnp.ndarray): The density matrix to validate
    """
    trace_is_one = jnp.allclose(jnp.trace(rho), 1.0)
    assert trace_is_one, "Density matrix is not trace 1"
    is_hermitian = jnp.allclose(rho, rho.conj().T)
    assert is_hermitian, "Density matrix is not Hermitian"

check_hermitian

check_hermitian(op: ndarray)

Check if the provided matrix is Hermitian

Parameters:

Name Type Description Default
op ndarray

Matrix to be assert

required
Source code in src/inspeqtor/v1/physics.py
543
544
545
546
547
548
549
def check_hermitian(op: jnp.ndarray):
    """Assert that the provided matrix equals its own conjugate transpose

    Args:
        op (jnp.ndarray): Matrix to be asserted
    """
    is_hermitian = jnp.allclose(op, jnp.conjugate(op).T)
    assert is_hermitian, "Matrix is not Hermitian"

direct_AFG_estimation

direct_AFG_estimation(
    coefficients: ndarray, expectation_values: ndarray
) -> ndarray

Calculate single-qubit average gate fidelity from expectation values. This function should be used with direct_AFG_estimation_coefficients.

Examples:

>>> coefficients = direct_AFG_estimation_coefficients(unitary)
... agf = direct_AFG_estimation(coefficients, expectation_value)

Parameters:

Name Type Description Default
coefficients ndarray

The coefficients return from direct_AFG_estimation_coefficients

required
expectation_values ndarray

The expectation values assume to be shape of (..., 18) with order of sq.constant.default_expectation_values_order

required

Returns:

Type Description
ndarray

jnp.ndarray: Average Gate Fidelity

Source code in src/inspeqtor/v1/physics.py
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
def direct_AFG_estimation(
    coefficients: jnp.ndarray,
    expectation_values: jnp.ndarray,
) -> jnp.ndarray:
    """Calculate single-qubit average gate fidelity from expectation values.
    This function should be used with `direct_AFG_estimation_coefficients`.

    Examples:
        >>> coefficients = direct_AFG_estimation_coefficients(unitary)
        ... agf = direct_AFG_estimation(coefficients, expectation_value)

    Args:
        coefficients (jnp.ndarray): The coefficients returned from `direct_AFG_estimation_coefficients`
        expectation_values (jnp.ndarray): The expectation values, assumed to be of shape (..., 18) with the order of `sq.constant.default_expectation_values_order`

    Returns:
        jnp.ndarray: Average Gate Fidelity
    """
    weighted_sum = jnp.dot(coefficients, expectation_values)
    # AGF = 1/2 + (1/12) * sum_k c_k <O_k>.
    return (1 / 2) + ((1 / 12) * weighted_sum)

direct_AFG_estimation_coefficients

direct_AFG_estimation_coefficients(
    target_unitary: ndarray,
) -> ndarray

Compute the expected coefficients to be used for AGF calculation using direct_AFG_estimation. The order of coefficients is the same as sq.constant.default_expectation_values_order

Parameters:

Name Type Description Default
target_unitary ndarray

Target unitary to be computed for coefficient

required

Returns:

Type Description
ndarray

jnp.ndarray: Coefficients for AGF calculation.

Source code in src/inspeqtor/v1/physics.py
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
def direct_AFG_estimation_coefficients(target_unitary: jnp.ndarray) -> jnp.ndarray:
    """Compute the expected coefficients to be used for AGF calculation using `direct_AFG_estimation`.
    The order of coefficients is the same as `sq.constant.default_expectation_values_order`

    Args:
        target_unitary (jnp.ndarray): Target unitary to be computed for coefficient

    Returns:
        jnp.ndarray: Coefficients for AGF calculation.
    """
    adjoint = target_unitary.conj().T
    # One coefficient per (Pauli_i, Pauli_j, eigenstate sign) combination,
    # iterated with the sign innermost to match the expectation-value ordering.
    coefficients = [
        sign * (1 / 2) * jnp.trace(pauli_i @ target_unitary @ pauli_j @ adjoint)
        for pauli_i in [X, Y, Z]
        for pauli_j in [X, Y, Z]
        for sign in [1, -1]
    ]

    return jnp.real(jnp.array(coefficients))

calculate_exp

calculate_exp(
    unitary: ndarray,
    operator: ndarray,
    density_matrix: ndarray,
) -> ndarray

Calculate the expectation value for the given unitary, observable (operator), and initial state (density_matrix). Shapes of all arguments must be broadcastable.

Parameters:

Name Type Description Default
unitary ndarray

Unitary operator

required
operator ndarray

Quantum Observable

required
density_matrix ndarray

Initial state in the form of a density matrix.

required

Returns:

Type Description
ndarray

jnp.ndarray: Expectation value of quantum observable.

Source code in src/inspeqtor/v1/physics.py
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
def calculate_exp(
    unitary: jnp.ndarray, operator: jnp.ndarray, density_matrix: jnp.ndarray
) -> jnp.ndarray:
    """Calculate the expectation value for the given unitary, observable (operator),
    and initial state (density_matrix).
    Shapes of all arguments must be broadcastable.

    Args:
        unitary (jnp.ndarray): Unitary operator
        operator (jnp.ndarray): Quantum observable
        density_matrix (jnp.ndarray): Initial state in the form of a density matrix.

    Returns:
        jnp.ndarray: Expectation value of the quantum observable.
    """
    # Evolve the state: rho' = U rho U^dag (batched over leading axes).
    adjoint = unitary.conj().swapaxes(-2, -1)
    evolved = jnp.matmul(unitary, jnp.matmul(density_matrix, adjoint))
    # <O> = Re tr(rho' O), computed via the diagonal to stay batch-friendly.
    product = jnp.matmul(evolved, operator)
    return jnp.real(jnp.sum(jnp.diagonal(product, axis1=-2, axis2=-1), axis=-1))

unitaries_prod

unitaries_prod(
    prev_unitary: ndarray, curr_unitary: ndarray
) -> tuple[ndarray, ndarray]

Function to be used for trotterization Whitebox

Parameters:

Name Type Description Default
prev_unitary ndarray

Cumulative product of unitary operators.

required
curr_unitary ndarray

The next Unitary operator to be multiply.

required

Returns:

Type Description
tuple[ndarray, ndarray]

tuple[jnp.ndarray, jnp.ndarray]: Product of the previous unitary and the current unitary.

Source code in src/inspeqtor/v1/physics.py
626
627
628
629
630
631
632
633
634
635
636
637
638
639
def unitaries_prod(
    prev_unitary: jnp.ndarray, curr_unitary: jnp.ndarray
) -> tuple[jnp.ndarray, jnp.ndarray]:
    """Scan-step function for the Trotterization whitebox

    Args:
        prev_unitary (jnp.ndarray): Cumulative product of unitary operators so far.
        curr_unitary (jnp.ndarray): The next unitary operator to multiply in.

    Returns:
        tuple[jnp.ndarray, jnp.ndarray]: The updated product, returned twice
        (as the scan carry and as the scan output).
    """
    accumulated = prev_unitary @ curr_unitary
    return accumulated, accumulated

make_trotterization_solver

make_trotterization_solver(
    hamiltonian: Callable[..., ndarray],
    total_dt: int,
    dt: float,
    trotter_steps: int,
    y0: ndarray,
)

Return a whitebox function computed using the Trotterization strategy.

Parameters:

Name Type Description Default
hamiltonian Callable[..., ndarray]

The Hamiltonian function of the system

required
total_dt int

The total duration of control sequence

required
dt float

The duration of time step in nanosecond.

required
trotter_steps int

The number of trotterization step.

required
y0 ndarray

The initial unitary state. Defaults to jnp.eye(2, dtype=jnp.complex128)

required

Returns:

Type Description

typing.Callable[..., jnp.ndarray]: Trotterization Whitebox function

Source code in src/inspeqtor/v1/physics.py
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
def make_trotterization_solver(
    hamiltonian: typing.Callable[..., jnp.ndarray],
    total_dt: int,
    dt: float,
    trotter_steps: int,
    y0: jnp.ndarray,
):
    """Return a whitebox function computed with the Trotterization strategy.

    Args:
        hamiltonian (typing.Callable[..., jnp.ndarray]): The Hamiltonian function of the system
        total_dt (int): The total duration of the control sequence in units of dt
        dt (float): The duration of one time step in nanoseconds
        trotter_steps (int): The number of Trotterization steps
        y0 (jnp.ndarray): The initial unitary, e.g. jnp.eye(2, dtype=jnp.complex128)

    Returns:
        typing.Callable[..., jnp.ndarray]: Trotterization whitebox function
    """
    hamiltonian = jax.jit(hamiltonian)
    times = jnp.linspace(0, total_dt * dt, trotter_steps)

    def whitebox(control_parameters: jnp.ndarray):
        # Evaluate H(t) on the whole time grid in one vectorized call.
        sampled = jax.vmap(hamiltonian, in_axes=(None, 0))(control_parameters, times)
        # One short-time propagator exp(-i * delta_t * H) per grid point.
        propagators = jax.scipy.linalg.expm(-1j * (times[1] - times[0]) * sampled)
        # Accumulate the time-ordered product with a scan.
        # * Nice explanation of scan:
        # * https://www.nelsontang.com/blog/a-friendly-introduction-to-scan-with-jax
        _, cumulative = jax.lax.scan(unitaries_prod, y0, propagators)
        return cumulative

    return whitebox

lindblad_solver

lindblad_solver(
    args: HamiltonianArgs,
    t_eval: ndarray,
    hamiltonian: Callable[
        [HamiltonianArgs, ndarray], ndarray
    ],
    lindblad_ops: list[ndarray],
    rho0: ndarray,
    t0: float,
    t1: float,
    rtol: float = 1e-07,
    atol: float = 1e-07,
    max_steps: int = int(2**16),
) -> ndarray

Solve the Lindblad Master equation without flattening matrices

Source code in src/inspeqtor/v1/physics.py
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
def lindblad_solver(
    args: HamiltonianArgs,
    t_eval: jnp.ndarray,
    hamiltonian: typing.Callable[[HamiltonianArgs, jnp.ndarray], jnp.ndarray],
    lindblad_ops: list[jnp.ndarray],
    rho0: jnp.ndarray,
    t0: float,
    t1: float,
    rtol: float = 1e-7,
    atol: float = 1e-7,
    max_steps: int = int(2**16),
) -> jnp.ndarray:
    """Solve the Lindblad Master equation without flattening matrices

    Integrates d(rho)/dt = -i[H, rho] + sum_L ( L rho L^dag - (1/2){L^dag L, rho} )
    with diffrax, keeping the density matrix in its natural 2-D form.

    Args:
        args (HamiltonianArgs): Arguments forwarded to the Hamiltonian function.
        t_eval (jnp.ndarray): Times at which to save the solution.
        hamiltonian (Callable): Function (args, t) -> H(t).
        lindblad_ops (list[jnp.ndarray]): Lindblad (collapse) operators.
        rho0 (jnp.ndarray): Initial density matrix.
        t0 (float): Integration start time.
        t1 (float): Integration end time.
        rtol (float, optional): Relative tolerance of the step controller. Defaults to 1e-7.
        atol (float, optional): Absolute tolerance of the step controller. Defaults to 1e-7.
        max_steps (int, optional): Maximum number of solver steps. Defaults to 2**16.

    Returns:
        jnp.ndarray: Density matrices at `t_eval`, post-processed to be
        Hermitian with unit trace.
    """

    def commutator(A, B):
        # [A, B] = AB - BA
        return A @ B - B @ A

    def anti_commutator(A, B):
        # {A, B} = AB + BA
        return A @ B + B @ A

    # Direct matrix RHS function - no reshaping needed
    def lindblad_rhs(t, rho, args: HamiltonianArgs):
        H = hamiltonian(args, t)

        # Coherent evolution term: -i[H, ρ]
        coherent_term = -1j * commutator(H, rho)

        # Dissipative terms from all Lindblad operators
        dissipative_term = jnp.zeros_like(rho)

        for L in lindblad_ops:
            L_dag = jnp.conjugate(L.T)
            term1 = L @ rho @ L_dag
            term2 = anti_commutator(L_dag @ L, rho)
            dissipative_term = dissipative_term + (term1 - 0.5 * term2)

        return coherent_term + dissipative_term

    term = diffrax.ODETerm(lindblad_rhs)
    solver = diffrax.Tsit5()
    # solver = diffrax.Dopri5()

    solution = diffrax.diffeqsolve(
        term,
        solver,
        t0=t0,
        t1=t1,
        dt0=None,  # let the adaptive controller choose the initial step size
        stepsize_controller=diffrax.PIDController(
            rtol=rtol,
            atol=atol,
        ),
        y0=rho0,
        args=args,
        saveat=diffrax.SaveAt(ts=t_eval),
        max_steps=max_steps,
    )

    # Process the matrices to ensure hermiticity and unit trace
    def process_density_matrix(rho):
        # Ensure Hermiticity
        rho = 0.5 * (rho + jnp.conjugate(rho.T))
        # Normalize to trace = 1
        return rho / jnp.trace(rho)

    # Process all matrices
    return jax.vmap(process_density_matrix)(solution.ys)

Predefined

src.inspeqtor.v1.predefined

HamiltonianSpec dataclass

Source code in src/inspeqtor/v1/predefined.py
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
@dataclass
class HamiltonianSpec:
    """Specification of the Hamiltonian model and the solver used to integrate it."""

    # Strategy used to compute the unitary (Trotterization or ODE solver).
    method: WhiteboxStrategy
    # Which predefined Hamiltonian to use.
    hamiltonian_enum: HamiltonianEnum = HamiltonianEnum.rotating_transmon_hamiltonian
    # For Trotterization: number of Trotter steps.
    trotter_steps: int = 1000
    # For ODE solver: maximum number of solver steps.
    # NOTE: annotated so that @dataclass treats this as a field; previously it
    # was an un-annotated class attribute and could not be set via __init__.
    max_steps: int = int(2**16)

    def get_hamiltonian_fn(self):
        """Resolve the Hamiltonian enum to its callable implementation.

        Raises:
            ValueError: If the Hamiltonian enum is not supported.
        """
        if self.hamiltonian_enum == HamiltonianEnum.rotating_transmon_hamiltonian:
            return rotating_transmon_hamiltonian
        elif self.hamiltonian_enum == HamiltonianEnum.transmon_hamiltonian:
            return transmon_hamiltonian
        else:
            raise ValueError(f"Unsupport Hamiltonian: {self.hamiltonian_enum}")

    def get_solver(
        self,
        control_sequence: ControlSequence,
        qubit_info: QubitInformation,
        dt: float,
        get_envelope_transformer=get_envelope_transformer,
    ):
        """Return Unitary solver from the given specification of the Hamiltonian and solver

        Args:
            control_sequence (ControlSequence): The control sequence object
            qubit_info (QubitInformation): The qubit information object
            dt (float): The time step size of the device

        Raises:
            ValueError: Unsupported solver method

        Returns:
            typing.Any: The unitary solver
        """
        if self.method == WhiteboxStrategy.TROTTER:
            # Bind qubit and signal into the Hamiltonian so the solver only
            # sees (params, t).
            hamiltonian = partial(
                self.get_hamiltonian_fn(),
                qubit_info=qubit_info,
                signal=make_signal_fn(
                    get_envelope=get_envelope_transformer(
                        control_sequence=control_sequence
                    ),
                    drive_frequency=qubit_info.frequency,
                    dt=dt,
                ),
            )

            whitebox = make_trotterization_solver(
                hamiltonian=hamiltonian,
                total_dt=control_sequence.total_dt,
                dt=dt,
                trotter_steps=self.trotter_steps,
                y0=jnp.eye(2, dtype=jnp.complex128),
            )
            return whitebox
        elif self.method == WhiteboxStrategy.ODE:
            return get_single_qubit_whitebox(
                self.get_hamiltonian_fn(),
                control_sequence,
                qubit_info,
                dt,
                self.max_steps,
            )
        else:
            raise ValueError("Unsupport method")

get_solver

get_solver(
    control_sequence: ControlSequence,
    qubit_info: QubitInformation,
    dt: float,
    get_envelope_transformer=get_envelope_transformer,
)

Return Unitary solver from the given specification of the Hamiltonian and solver

Parameters:

Name Type Description Default
control_sequence ControlSequence

The control sequence object

required
qubit_info QubitInformation

The qubit information object

required
dt float

The time step size of the device

required

Raises:

Type Description
ValueError

Unsupport Solver method

Returns:

Type Description

typing.Any: The unitary solver

Source code in src/inspeqtor/v1/predefined.py
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
def get_solver(
    self,
    control_sequence: ControlSequence,
    qubit_info: QubitInformation,
    dt: float,
    get_envelope_transformer=get_envelope_transformer,
):
    """Return Unitary solver from the given specification of the Hamiltonian and solver

    Args:
        control_sequence (ControlSequence): The control sequence object
        qubit_info (QubitInformation): The qubit information object
        dt (float): The time step size of the device

    Raises:
        ValueError: Unsupported solver method

    Returns:
        typing.Any: The unitary solver
    """
    if self.method == WhiteboxStrategy.TROTTER:
        # Bind the qubit and signal into the Hamiltonian so the Trotterization
        # solver only sees (params, t).
        hamiltonian = partial(
            self.get_hamiltonian_fn(),
            qubit_info=qubit_info,
            signal=make_signal_fn(
                get_envelope=get_envelope_transformer(
                    control_sequence=control_sequence
                ),
                drive_frequency=qubit_info.frequency,
                dt=dt,
            ),
        )

        # Start from the identity unitary on a single qubit.
        whitebox = make_trotterization_solver(
            hamiltonian=hamiltonian,
            total_dt=control_sequence.total_dt,
            dt=dt,
            trotter_steps=self.trotter_steps,
            y0=jnp.eye(2, dtype=jnp.complex128),
        )
        return whitebox
    elif self.method == WhiteboxStrategy.ODE:
        # Delegate to the ODE-based single-qubit whitebox.
        return get_single_qubit_whitebox(
            self.get_hamiltonian_fn(),
            control_sequence,
            qubit_info,
            dt,
            self.max_steps,
        )
    else:
        raise ValueError("Unsupport method")

rotating_transmon_hamiltonian

rotating_transmon_hamiltonian(
    params: HamiltonianArgs,
    t: ndarray,
    qubit_info: QubitInformation,
    signal: Callable[[HamiltonianArgs, ndarray], ndarray],
) -> ndarray

Rotating frame hamiltonian of the transmon model

Parameters:

Name Type Description Default
params HamiltonianParameters

The parameter of the pulse for hamiltonian

required
t ndarray

The time to evaluate the Hamiltonian

required
qubit_info QubitInformation

The information of qubit

required
signal Callable[..., ndarray]

The pulse signal

required

Returns:

Type Description
ndarray

jnp.ndarray: The Hamiltonian

Source code in src/inspeqtor/v1/predefined.py
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
def rotating_transmon_hamiltonian(
    params: HamiltonianArgs,
    t: jnp.ndarray,
    qubit_info: QubitInformation,
    signal: typing.Callable[[HamiltonianArgs, jnp.ndarray], jnp.ndarray],
) -> jnp.ndarray:
    """Rotating-frame Hamiltonian of the transmon model.

    The drive envelope is modulated by cos/sin at the qubit's angular
    frequency and applied along the X and Y axes respectively.

    Args:
        params (HamiltonianArgs): The pulse parameters for the Hamiltonian.
        t (jnp.ndarray): The time at which to evaluate the Hamiltonian.
        qubit_info (QubitInformation): The information of the qubit.
        signal (typing.Callable[[HamiltonianArgs, jnp.ndarray], jnp.ndarray]): The pulse signal.

    Returns:
        jnp.ndarray: The Hamiltonian.
    """
    # Angular qubit frequency and angular drive strength.
    omega_q = 2 * jnp.pi * qubit_info.frequency
    omega_d = 2 * jnp.pi * qubit_info.drive_strength

    # Drive envelope scaled by the drive strength.
    drive = omega_d * signal(params, t)

    x_coeff = drive * jnp.cos(omega_q * t)
    y_coeff = drive * jnp.sin(omega_q * t)

    return x_coeff * X - y_coeff * Y

transmon_hamiltonian

transmon_hamiltonian(
    params: HamiltonianArgs,
    t: ndarray,
    qubit_info: QubitInformation,
    signal: Callable[[HamiltonianArgs, ndarray], ndarray],
) -> ndarray

Lab frame hamiltonian of the transmon model

Parameters:

Name Type Description Default
params HamiltonianParameters

The parameter of the pulse for hamiltonian

required
t ndarray

The time to evaluate the Hamiltonian

required
qubit_info QubitInformation

The information of qubit

required
signal Callable[..., ndarray]

The pulse signal

required

Returns:

Type Description
ndarray

jnp.ndarray: The Hamiltonian

Source code in src/inspeqtor/v1/predefined.py
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
def transmon_hamiltonian(
    params: HamiltonianArgs,
    t: jnp.ndarray,
    qubit_info: QubitInformation,
    signal: typing.Callable[[HamiltonianArgs, jnp.ndarray], jnp.ndarray],
) -> jnp.ndarray:
    """Lab-frame Hamiltonian of the transmon model.

    Sum of a static drift term along Z and a signal-dependent drive
    term along X.

    Args:
        params (HamiltonianArgs): The pulse parameters for the Hamiltonian.
        t (jnp.ndarray): The time at which to evaluate the Hamiltonian.
        qubit_info (QubitInformation): The information of the qubit.
        signal (typing.Callable[[HamiltonianArgs, jnp.ndarray], jnp.ndarray]): The pulse signal.

    Returns:
        jnp.ndarray: The Hamiltonian.
    """
    omega_q = 2 * jnp.pi * qubit_info.frequency
    omega_d = 2 * jnp.pi * qubit_info.drive_strength

    drift = (omega_q / 2) * Z
    drive = omega_d * signal(params, t) * X

    return drift + drive

get_gaussian_control_sequence

get_gaussian_control_sequence(
    qubit_info: QubitInformation, max_amp: float = 0.5
)

Get predefined Gaussian control sequence with single Gaussian pulse.

Parameters:

Name Type Description Default
qubit_info QubitInformation

Qubit information

required
max_amp float

The maximum amplitude. Defaults to 0.5.

0.5

Returns:

Name Type Description
ControlSequence

Control sequence instance

Source code in src/inspeqtor/v1/predefined.py
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
def get_gaussian_control_sequence(
    qubit_info: QubitInformation,
    max_amp: float = 0.5,  # NOTE: Choice of maximum amplitude is arbitrary
):
    """Get predefined Gaussian control sequence with single Gaussian pulse.

    The sequence holds one `GaussianPulse` spanning 320 device timesteps
    with a time step of 2/9, and theta bounded in [0, 2*pi].

    Args:
        qubit_info (QubitInformation): Qubit information
        max_amp (float, optional): The maximum amplitude. Defaults to 0.5.

    Returns:
        ControlSequence: Control sequence instance
    """
    duration_dt = 320
    device_dt = 2 / 9

    pulse = GaussianPulse(
        duration=duration_dt,
        qubit_drive_strength=qubit_info.drive_strength,
        dt=device_dt,
        max_amp=max_amp,
        min_theta=0.0,
        max_theta=2 * jnp.pi,
    )

    return ControlSequence(controls=[pulse], total_dt=duration_dt)

get_two_axis_gaussian_control_sequence

get_two_axis_gaussian_control_sequence(
    qubit_info: QubitInformation, max_amp: float = 0.5
)

Get predefined two-axis Gaussian control sequence.

Parameters:

Name Type Description Default
qubit_info QubitInformation

Qubit information

required
max_amp float

The maximum amplitude. Defaults to 0.5.

0.5

Returns:

Name Type Description
ControlSequence

Control sequence instance

Source code in src/inspeqtor/v1/predefined.py
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
def get_two_axis_gaussian_control_sequence(
    qubit_info: QubitInformation,
    max_amp: float = 0.5,
):
    """Get predefined two-axis Gaussian control sequence.

    The sequence holds one `TwoAxisGaussianPulse` spanning 320 device
    timesteps with a time step of 2/9; both axis angles are bounded in
    [-2*pi, 2*pi].

    Args:
        qubit_info (QubitInformation): Qubit information
        max_amp (float, optional): The maximum amplitude. Defaults to 0.5.

    Returns:
        ControlSequence: Control sequence instance
    """
    duration_dt = 320
    device_dt = 2 / 9

    pulse = TwoAxisGaussianPulse(
        duration=duration_dt,
        qubit_drive_strength=qubit_info.drive_strength,
        dt=device_dt,
        max_amp=max_amp,
        min_theta_x=-2 * jnp.pi,
        max_theta_x=2 * jnp.pi,
        min_theta_y=-2 * jnp.pi,
        max_theta_y=2 * jnp.pi,
    )

    return ControlSequence(controls=[pulse], total_dt=duration_dt)

get_drag_pulse_v2_sequence

get_drag_pulse_v2_sequence(
    qubit_info_drive_strength: float,
    max_amp: float = 0.5,
    min_theta=0.0,
    max_theta=2 * pi,
    min_beta=-2.0,
    max_beta=2.0,
    dt=2 / 9,
)

Get predefined DRAG control sequence with single DRAG pulse.

Parameters:

Name Type Description Default
qubit_info_drive_strength float

The drive strength of the qubit

required
max_amp float

The maximum amplitude. Defaults to 0.5.

0.5

Returns:

Name Type Description
ControlSequence

Control sequence instance

Source code in src/inspeqtor/v1/predefined.py
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
def get_drag_pulse_v2_sequence(
    qubit_info_drive_strength: float,
    max_amp: float = 0.5,  # NOTE: Choice of maximum amplitude is arbitrary
    min_theta=0.0,
    max_theta=2 * jnp.pi,
    min_beta=-2.0,
    max_beta=2.0,
    dt=2 / 9,
):
    """Get predefined DRAG control sequence with single DRAG pulse.

    Args:
        qubit_info_drive_strength (float): The drive strength of the qubit.
        max_amp (float, optional): The maximum amplitude. Defaults to 0.5.
        min_theta (float, optional): Lower bound of the pulse `theta` parameter. Defaults to 0.0.
        max_theta (float, optional): Upper bound of the pulse `theta` parameter. Defaults to 2 * jnp.pi.
        min_beta (float, optional): Lower bound of the pulse `beta` parameter. Defaults to -2.0.
        max_beta (float, optional): Upper bound of the pulse `beta` parameter. Defaults to 2.0.
        dt (float, optional): The time step size of the device. Defaults to 2 / 9.

    Returns:
        ControlSequence: Control sequence instance
    """
    # Fixed sequence length of 320 device timesteps, matching the other
    # predefined sequences in this module.
    total_length = 320
    control_sequence = ControlSequence(
        controls=[
            DragPulseV2(
                duration=total_length,
                qubit_drive_strength=qubit_info_drive_strength,
                dt=dt,
                max_amp=max_amp,
                min_theta=min_theta,
                max_theta=max_theta,
                min_beta=min_beta,
                max_beta=max_beta,
            ),
        ],
        total_dt=total_length,
    )

    return control_sequence

generate_experimental_data

generate_experimental_data(
    key: ndarray,
    hamiltonian: Callable[..., ndarray],
    sample_size: int = 10,
    shots: int = 1000,
    strategy: SimulationStrategy = RANDOM,
    get_qubit_information_fn: Callable[
        [], QubitInformation
    ] = get_mock_qubit_information,
    get_control_sequence_fn: Callable[
        [], ControlSequence
    ] = get_multi_drag_control_sequence_v3,
    max_steps: int = int(2**16),
    method: WhiteboxStrategy = ODE,
    trotter_steps: int = 1000,
    expectation_value_receipt: list[
        ExpectationValue
    ] = default_expectation_values_order,
) -> tuple[
    ExperimentData,
    ControlSequence,
    ndarray,
    Callable[[ndarray], ndarray],
]

Generate simulated dataset

Parameters:

Name Type Description Default
key ndarray

Random key

required
hamiltonian Callable[..., ndarray]

Total Hamiltonian of the device

required
sample_size int

Sample size of the control parameters. Defaults to 10.

10
shots int

Number of shots used to estimate expectation value, will be used if SimulationStrategy is SHOT, otherwise ignored. Defaults to 1000.

1000
strategy SimulationStrategy

Simulation strategy. Defaults to SimulationStrategy.RANDOM.

RANDOM
get_qubit_information_fn Callable[[], QubitInformation]

Function that return qubit information. Defaults to get_mock_qubit_information.

get_mock_qubit_information
get_control_sequence_fn Callable[[], ControlSequence]

Function that return control sequence. Defaults to get_multi_drag_control_sequence_v3.

get_multi_drag_control_sequence_v3
max_steps int

Maximum step of solver. Defaults to int(2**16).

int(2 ** 16)
method WhiteboxStrategy

Unitary solver method. Defaults to WhiteboxStrategy.ODE.

ODE

Raises:

Type Description
NotImplementedError

Not support strategy

Returns:

Type Description
tuple[ExperimentData, ControlSequence, ndarray, Callable[[ndarray], ndarray]]

tuple[ExperimentData, ControlSequence, jnp.ndarray, typing.Callable[[jnp.ndarray], jnp.ndarray]]: tuple of (1) Experiment data, (2) Pulse sequence, (3) Noisy unitary, (4) Noisy solver

Source code in src/inspeqtor/v1/predefined.py
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
def generate_experimental_data(
    key: jnp.ndarray,
    hamiltonian: typing.Callable[..., jnp.ndarray],
    sample_size: int = 10,
    shots: int = 1000,
    strategy: SimulationStrategy = SimulationStrategy.RANDOM,
    get_qubit_information_fn: typing.Callable[
        [], QubitInformation
    ] = get_mock_qubit_information,
    get_control_sequence_fn: typing.Callable[
        [], ControlSequence
    ] = get_multi_drag_control_sequence_v3,
    max_steps: int = int(2**16),
    method: WhiteboxStrategy = WhiteboxStrategy.ODE,
    trotter_steps: int = 1000,
    expectation_value_receipt: list[
        ExpectationValue
    ] = default_expectation_values_order,
) -> tuple[
    ExperimentData,
    ControlSequence,
    jnp.ndarray,
    typing.Callable[[jnp.ndarray], jnp.ndarray],
]:
    """Generate simulated dataset

    Args:
        key (jnp.ndarray): Random key
        hamiltonian (typing.Callable[..., jnp.ndarray]): Total Hamiltonian of the device
        sample_size (int, optional): Sample size of the control parameters. Defaults to 10.
        shots (int, optional): Number of shots used to estimate expectation value, will be used if `SimulationStrategy` is `SHOT`, otherwise ignored. Defaults to 1000.
        strategy (SimulationStrategy, optional): Simulation strategy. Defaults to SimulationStrategy.RANDOM.
        get_qubit_information_fn (typing.Callable[ [], QubitInformation ], optional): Function that return qubit information. Defaults to get_mock_qubit_information.
        get_control_sequence_fn (typing.Callable[ [], ControlSequence ], optional): Function that return control sequence. Defaults to get_multi_drag_control_sequence_v3.
        max_steps (int, optional): Maximum step of solver, used only by the ODE method. Defaults to int(2**16).
        method (WhiteboxStrategy, optional): Unitary solver method. Defaults to WhiteboxStrategy.ODE.
        trotter_steps (int, optional): Number of Trotter steps, used only by the TROTTER method. Defaults to 1000.
        expectation_value_receipt (list[ExpectationValue], optional): Order of (initial state, observable) pairs used to lay out the 18 expectation values. Defaults to default_expectation_values_order.

    Raises:
        NotImplementedError: Not support strategy

    Returns:
        tuple[ExperimentData, ControlSequence, jnp.ndarray, typing.Callable[[jnp.ndarray], jnp.ndarray]]: tuple of (1) Experiment data, (2) Pulse sequence, (3) Noisy unitary, (4) Noisy solver
    """
    qubit_info, control_sequence, config = get_mock_prefined_exp_v1(
        sample_size=sample_size,
        shots=shots,
        get_control_sequence_fn=get_control_sequence_fn,
        get_qubit_information_fn=get_qubit_information_fn,
    )

    # Generate mock expectation value
    # exp_key is reserved for the RANDOM strategy; `key` keeps being split
    # sequentially below, so statement order here affects reproducibility.
    key, exp_key = jax.random.split(key)

    dt = config.device_cycle_time_ns

    # Build the unitary solver for the requested method.
    # NOTE(review): the Trotter path uses complex128 while the ODE path uses
    # complex64 — confirm this dtype difference is intentional.
    if method == WhiteboxStrategy.TROTTER:
        noisy_simulator = jax.jit(
            make_trotterization_solver(
                hamiltonian=hamiltonian,
                total_dt=control_sequence.total_dt,
                dt=dt,
                trotter_steps=trotter_steps,
                y0=jnp.eye(2, dtype=jnp.complex128),
            )
        )
    else:
        # One evaluation point per device timestep over the full sequence.
        t_eval = jnp.linspace(
            0, control_sequence.total_dt * dt, control_sequence.total_dt
        )
        noisy_simulator = jax.jit(
            partial(
                solver,
                t_eval=t_eval,
                hamiltonian=hamiltonian,
                y0=jnp.eye(2, dtype=jnp.complex64),
                t0=0,
                t1=control_sequence.total_dt * dt,
                max_steps=max_steps,
            )
        )

    # Sample control parameters: keep both the structured (per-pulse) form for
    # the dataframe rows and a flat array form for the vmapped simulator.
    control_params_list = []
    parameter_structure = control_sequence.get_parameter_names()
    num_parameters = len(list(itertools.chain.from_iterable(parameter_structure)))
    control_params = jnp.empty(shape=(sample_size, num_parameters))
    for control_idx in range(config.sample_size):
        key, subkey = jax.random.split(key)
        pulse_params = control_sequence.sample_params(subkey)
        control_params_list.append(pulse_params)

        control_params = control_params.at[control_idx].set(
            list_of_params_to_array(pulse_params, parameter_structure)
        )

    # Simulate every sampled control in one batched call.
    unitaries = jax.vmap(noisy_simulator)(control_params)
    SHOTS = config.shots

    # Calculate the expectation values depending on the strategy
    # Keep only the final-time unitary of each trajectory.
    unitaries_f = jnp.asarray(unitaries)[:, -1, :, :]

    assert unitaries_f.shape == (
        sample_size,
        2,
        2,
    ), f"Final unitaries shape is {unitaries_f.shape}"

    if strategy == SimulationStrategy.RANDOM:
        # Just random expectation values with key
        # Uniform values rescaled from [0, 1) to [-1, 1).
        expectation_values = 2 * (
            jax.random.uniform(exp_key, shape=(config.sample_size, 18)) - (1 / 2)
        )
    elif strategy == SimulationStrategy.IDEAL:
        # Noise-free expectation values from the final unitaries.
        expectation_values = calculate_expectation_values(unitaries_f)

    elif strategy == SimulationStrategy.SHOT:
        key, sample_key = jax.random.split(key)
        # The `shot_quantum_device` function will re-calculate the unitary
        expectation_values = shot_quantum_device(
            sample_key,
            control_params,
            noisy_simulator,
            SHOTS,
            expectation_value_receipt,
        )
    else:
        raise NotImplementedError

    assert expectation_values.shape == (
        sample_size,
        18,
    ), f"Expectation values shape is {expectation_values.shape}"

    # Flatten (sample, expectation-value) pairs into one dataframe row each,
    # following the order given by `expectation_value_receipt`.
    rows = []
    for sample_idx in range(config.sample_size):
        for exp_idx, exp in enumerate(expectation_value_receipt):
            row = make_row(
                expectation_value=float(expectation_values[sample_idx, exp_idx]),
                initial_state=exp.initial_state,
                observable=exp.observable,
                parameters_list=control_params_list[sample_idx],
                parameters_id=sample_idx,
            )

            rows.append(row)

    df = pd.DataFrame(rows)

    exp_data = ExperimentData(experiment_config=config, preprocess_data=df)

    return (
        exp_data,
        control_sequence,
        jnp.array(unitaries),
        noisy_simulator,
    )

get_single_qubit_whitebox

get_single_qubit_whitebox(
    hamiltonian: Callable[..., ndarray],
    control_sequence: ControlSequence,
    qubit_info: QubitInformation,
    dt: float,
    max_steps: int = int(2**16),
    get_envelope_transformer=get_envelope_transformer,
) -> Callable[[ndarray], ndarray]

Generate single qubit whitebox

Parameters:

Name Type Description Default
hamiltonian Callable[..., ndarray]

Hamiltonian

required
control_sequence ControlSequence

Control sequence instance

required
qubit_info QubitInformation

Qubit information

required
dt float

Duration of 1 timestep in nanosecond

required
max_steps int

Maximum steps of solver. Defaults to int(2**16).

int(2 ** 16)

Returns:

Type Description
Callable[[ndarray], ndarray]

typing.Callable[[jnp.ndarray], jnp.ndarray]: Whitebox with ODE solver

Source code in src/inspeqtor/v1/predefined.py
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
def get_single_qubit_whitebox(
    hamiltonian: typing.Callable[..., jnp.ndarray],
    control_sequence: ControlSequence,
    qubit_info: QubitInformation,
    dt: float,
    max_steps: int = int(2**16),
    get_envelope_transformer=get_envelope_transformer,
) -> typing.Callable[[jnp.ndarray], jnp.ndarray]:
    """Generate single qubit whitebox

    Args:
        hamiltonian (typing.Callable[..., jnp.ndarray]): Hamiltonian
        control_sequence (ControlSequence): Control sequence instance
        qubit_info (QubitInformation): Qubit information
        dt (float): Duration of 1 timestep in nanosecond
        max_steps (int, optional): Maximum steps of solver. Defaults to int(2**16).
        get_envelope_transformer (typing.Callable, optional): Factory that turns
            the control sequence into an envelope function for the drive signal.
            Defaults to the module-level `get_envelope_transformer`.

    Returns:
        typing.Callable[[jnp.ndarray], jnp.ndarray]: Whitebox with ODE solver
    """
    # One evaluation point per device timestep over the full sequence.
    t_eval = jnp.linspace(0, control_sequence.total_dt * dt, control_sequence.total_dt)

    # Bind the qubit information and drive signal, leaving only the pulse
    # parameters and time as free arguments for the solver.
    hamiltonian = partial(
        hamiltonian,
        qubit_info=qubit_info,
        signal=make_signal_fn(
            get_envelope_transformer(control_sequence),
            qubit_info.frequency,
            dt,
        ),
    )

    whitebox = partial(
        solver,
        t_eval=t_eval,
        hamiltonian=hamiltonian,
        # `jnp.complex128` replaces the legacy `jnp.complex_` alias, which was
        # removed alongside NumPy 2.0 (np.complex_ was an alias of complex128).
        # JAX canonicalizes this to complex64 unless x64 mode is enabled, so
        # runtime behavior is unchanged.
        y0=jnp.eye(2, dtype=jnp.complex128),
        t0=0,
        t1=control_sequence.total_dt * dt,
        max_steps=max_steps,
    )

    return whitebox

load_data_from_path

load_data_from_path(
    path: str | Path,
    hamiltonian_spec: HamiltonianSpec,
    pulse_reader=default_pulse_reader,
) -> LoadedData

Load and prepare the experimental data from given path and hamiltonian spec.

Parameters:

Name Type Description Default
path str | Path

The path to the folder that contain experimental data.

required
hamiltonian_spec HamiltonianSpec

The specification of the Hamiltonian

required
pulse_reader Any

Callable that reads the control sequence from the folder. Defaults to default_pulse_reader.

default_pulse_reader

Returns:

Name Type Description
LoadedData LoadedData

The object containing the necessary information for device characterization.

Source code in src/inspeqtor/v1/predefined.py
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
def load_data_from_path(
    path: str | pathlib.Path,
    hamiltonian_spec: HamiltonianSpec,
    pulse_reader=default_pulse_reader,
) -> LoadedData:
    """Load and prepare the experimental data from given path and hamiltonian spec.

    Args:
        path (str | pathlib.Path): The folder that contains the experimental data.
        hamiltonian_spec (HamiltonianSpec): The specification of the Hamiltonian.
        pulse_reader (typing.Any, optional): Callable that reads the control
            sequence from the folder. Defaults to default_pulse_reader.

    Returns:
        LoadedData: The object containing the information needed for device
        characterization.
    """
    experiment_data = ExperimentData.from_folder(path)
    sequence = pulse_reader(path)

    config = experiment_data.experiment_config
    qubit = config.qubits[0]
    cycle_time_ns = config.device_cycle_time_ns

    # Build the unitary solver for this device configuration, then bundle
    # everything into the characterization-ready container.
    whitebox = hamiltonian_spec.get_solver(sequence, qubit, cycle_time_ns)

    return prepare_data(experiment_data, sequence, whitebox)

save_data_to_path

save_data_to_path(
    path: str | Path,
    experiment_data: ExperimentData,
    control_sequence: ControlSequence,
)

Save the experimental data to the path

Parameters:

Name Type Description Default
path str | Path

The path to folder to save the experimental data

required
experiment_data ExperimentData

The experimental data object

required
control_sequence ControlSequence

The control sequence that used to create the experimental data.

required

Returns:

Name Type Description
None
Source code in src/inspeqtor/v1/predefined.py
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
def save_data_to_path(
    path: str | pathlib.Path,
    experiment_data: ExperimentData,
    control_sequence: ControlSequence,
):
    """Save the experimental data to the path

    Args:
        path (str | pathlib.Path): The folder to save the experimental data into
        experiment_data (ExperimentData): The experimental data object
        control_sequence (ControlSequence): The control sequence that was used to create the experimental data.

    Returns:
        None:
    """
    # Normalize to a Path and make sure the target folder exists.
    target = pathlib.Path(path)
    target.mkdir(parents=True, exist_ok=True)

    experiment_data.save_to_folder(target)
    control_sequence.to_file(target)

    return None

Probabilistic

src.inspeqtor.v1.probabilistic

LearningModel

The learning model.

Source code in src/inspeqtor/v1/probabilistic.py
363
364
365
366
367
class LearningModel(StrEnum):
    """The learning model."""

    TruncatedNormal = auto()
    BernoulliProbs = auto()

make_probabilistic_model

make_probabilistic_model(
    predictive_model: Callable[..., ndarray],
    shots: int = 1,
    block_graybox: bool = False,
    separate_observables: bool = False,
    log_expectation_values: bool = False,
)

Make probabilistic model from the Statistical model with priors

Parameters:

Name Type Description Default
predictive_model Callable[..., ndarray]

Model that maps (control_parameters, unitaries) to an array of expectation values

required
shots int

The number of shots forcing PGM to sample. Defaults to 1.

1
block_graybox bool

If true, the latent variables in Graybox model will be hidden, i.e. not traced by numpyro. Defaults to False.

False
separate_observables bool

If true, the observable will be separate into dict form. Defaults to False.

False
log_expectation_values bool

If true, the predicted expectation values are recorded as a deterministic site. Defaults to False.

False

Returns:

Type Description

typing.Callable: Probabilistic Graybox Model

Source code in src/inspeqtor/v1/probabilistic.py
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
def make_probabilistic_model(
    predictive_model: typing.Callable[..., jnp.ndarray],
    shots: int = 1,
    block_graybox: bool = False,
    separate_observables: bool = False,
    log_expectation_values: bool = False,
):
    """Make probabilistic model from the Statistical model with priors

    Args:
        predictive_model (typing.Callable[..., jnp.ndarray]): Model that maps
            (control_parameters, unitaries) to an array of expectation values.
        shots (int, optional): The number of shots forcing PGM to sample. Defaults to 1.
        block_graybox (bool, optional): If true, the latent variables in Graybox model will be hidden, i.e. not traced by `numpyro`. Defaults to False.
        separate_observables (bool, optional): If true, the observable will be separate into dict form. Defaults to False.
        log_expectation_values (bool, optional): If true, the predicted
            expectation values are recorded under the deterministic site
            "expectation_values". Defaults to False.

    Returns:
        typing.Callable: Probabilistic Graybox Model
    """

    def block_graybox_fn(
        control_parameters: jnp.ndarray,
        unitaries: jnp.ndarray,
    ):
        # Run the graybox under a fresh seed inside handlers.block so its
        # internal sample sites are hidden from the enclosing trace.
        key = numpyro.prng_key()
        with handlers.block(), handlers.seed(rng_seed=key):
            expvals = predictive_model(control_parameters, unitaries)

        return expvals

    graybox_fn = block_graybox_fn if block_graybox else predictive_model

    def bernoulli_model(
        control_parameters: jnp.ndarray,
        unitaries: jnp.ndarray,
        observables: jnp.ndarray | None = None,
    ):
        expvals = graybox_fn(control_parameters, unitaries)

        if log_expectation_values:
            numpyro.deterministic("expectation_values", expvals)

        # Shape of the observation sites: in prediction mode append the fixed
        # 18-observable axis (plus a leading shots axis when shots > 1);
        # otherwise the provided observations fix the shape.
        if observables is None:
            sizes = control_parameters.shape[:-1] + (18,)
            if shots > 1:
                sizes = (shots,) + sizes
        else:
            sizes = observables.shape

        # The plate is for the shots prediction to work properly
        with numpyro.util.optional(
            shots > 1, numpyro.plate_stack("plate", sizes=list(sizes)[:-1])
        ):
            if separate_observables:
                # One Bernoulli site per (initial state, observable) pair.
                expvals_samples = {}

                for idx, exp in enumerate(default_expectation_values_order):
                    s = numpyro.sample(
                        f"obs/{exp.initial_state}/{exp.observable}",
                        dist.BernoulliProbs(
                            probs=expectation_value_to_prob_minus(
                                jnp.expand_dims(expvals[..., idx], axis=-1)
                            )
                        ).to_event(1),  # type: ignore
                        obs=(
                            observables[..., idx] if observables is not None else None
                        ),
                    )

                    expvals_samples[f"obs/{exp.initial_state}/{exp.observable}"] = s

            else:
                # Single joint Bernoulli site over all 18 observables.
                expvals_samples = numpyro.sample(
                    "obs",
                    dist.BernoulliProbs(
                        probs=expectation_value_to_prob_minus(expvals)
                    ).to_event(1),  # type: ignore
                    obs=observables,
                    infer={"enumerate": "parallel"},
                )

        return expvals_samples

    return bernoulli_model

get_args_of_distribution

get_args_of_distribution(x)

Get the arguments used to construct Distribution, if the provided parameter is not Distribution, return it. So that the function can be used with jax.tree.map.

Parameters:

Name Type Description Default
x Any

Maybe Distribution

required

Returns:

Type Description

typing.Any: Argument of Distribution if Distribution is provided.

Source code in src/inspeqtor/v1/probabilistic.py
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
@deprecated
def get_args_of_distribution(x):
    """Return the constructor arguments of a Distribution, or pass `x` through.

    Non-distribution leaves are returned unchanged so this function can be
    mapped over a pytree with `jax.tree.map`.

    Args:
        x (typing.Any): Maybe Distribution

    Returns:
        typing.Any: Argument of Distribution if Distribution is provided.
    """
    return x.get_args() if isinstance(x, dist.Distribution) else x

construct_normal_priors

construct_normal_priors(posterior)

Construct a dict of Normal Distributions with posterior

Parameters:

Name Type Description Default
posterior Any

Dict of Normal distribution arguments

required

Returns:

Type Description

typing.Any: dict of Normal distributions

Source code in src/inspeqtor/v1/probabilistic.py
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
@deprecated
def construct_normal_priors(posterior):
    """Construct a dict of Normal Distributions with posterior

    Args:
        posterior (typing.Any): Dict of Normal distribution arguments,
            keyed by site name, each holding "loc" and "scale" entries.

    Returns:
        typing.Any: dict of Normal distributions
    """
    assert isinstance(posterior, dict)

    result = {}
    for site_name, args in posterior.items():
        assert isinstance(site_name, str)
        assert isinstance(args, dict)
        result[site_name] = dist.Normal(args["loc"], args["scale"])  # type: ignore

    return result

construct_normal_prior_from_samples

construct_normal_prior_from_samples(
    posterior_samples: dict[str, ndarray],
) -> dict[str, Distribution]

Construct a dict of Normal Distributions with posterior sample

Parameters:

Name Type Description Default
posterior_samples dict[str, ndarray]

Posterior sample

required

Returns:

Type Description
dict[str, Distribution]

dict[str, dist.Distribution]: dict of Normal Distributions

Source code in src/inspeqtor/v1/probabilistic.py
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
def construct_normal_prior_from_samples(
    posterior_samples: dict[str, jnp.ndarray],
) -> dict[str, dist.Distribution]:
    """Construct a dict of Normal Distributions with posterior sample

    Each site's Normal is moment-matched to the samples: its location is the
    sample mean and its scale the sample standard deviation, both taken over
    the leading (sample) axis.

    Args:
        posterior_samples (dict[str, jnp.ndarray]): Posterior sample

    Returns:
        dict[str, dist.Distribution]: dict of Normal Distributions
    """
    means = jax.tree.map(lambda s: jnp.mean(s, axis=0), posterior_samples)
    stds = jax.tree.map(lambda s: jnp.std(s, axis=0), posterior_samples)

    return {site: dist.Normal(loc, stds[site]) for site, loc in means.items()}

make_normal_posterior_dist_fn_from_svi_result

make_normal_posterior_dist_fn_from_svi_result(
    key: ndarray,
    guide: Callable,
    params: dict[str, ndarray],
    num_samples: int,
    prefix: str,
) -> Callable[[str, tuple[int, ...]], Distribution]

This function creates a get-posterior function to be used with numpyro.contrib.module.

Parameters:

Name Type Description Default
key ndarray

The random key

required
guide Callable

The guide (variational distribution)

required
params dict[str, ndarray]

The variational parameters

required
num_samples int

The number of samples for approximating the posterior distributions

required

Examples:

prefix = "graybox"
prior_fn = make_normal_posterior_dist_fn_from_svi_result(
    jax.random.key(0), guide, result.params, 10_000, prefix
)
graybox_model = sq.probabilistic.make_flax_probabilistic_graybox_model(
    name=prefix,
    base_model=base_model,
    adapter_fn=sq.probabilistic.observable_to_expvals,
    prior=prior_fn,
)
posterior_model = sq.probabilistic.make_probabilistic_model(
    predictive_model=graybox_model, log_expectation_values=True
)

Returns:

Type Description
Callable[[str, tuple[int, ...]], Distribution]

typing.Callable[[str, tuple[int, ...]], dist.Distribution]: The function that return posterior distribution

Source code in src/inspeqtor/v1/probabilistic.py
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
def make_normal_posterior_dist_fn_from_svi_result(
    key: jnp.ndarray,
    guide: typing.Callable,
    params: dict[str, jnp.ndarray],
    num_samples: int,
    prefix: str,
) -> typing.Callable[[str, tuple[int, ...]], dist.Distribution]:
    """This function creates a get-posterior function to be used with `numpyro.contrib.module`.

    Args:
        key (jnp.ndarray): The random key
        guide (typing.Callable): The guide (variational distribution)
        params (dict[str, jnp.ndarray]): The variational parameters
        num_samples (int): The number of samples for approximating the posterior distributions
        prefix (str): Site-name prefix; the returned function looks up `prefix + "/" + name`

    Examples:
        ```python
        prefix = "graybox"
        prior_fn = make_normal_posterior_dist_fn_from_svi_result(
            jax.random.key(0), guide, result.params, 10_000, prefix
        )
        graybox_model = sq.probabilistic.make_flax_probabilistic_graybox_model(
            name=prefix,
            base_model=base_model,
            adapter_fn=sq.probabilistic.observable_to_expvals,
            prior=prior_fn,
        )
        posterior_model = sq.probabilistic.make_probabilistic_model(
            predictive_model=graybox_model, log_expectation_values=True
        )
        ```

    Returns:
        typing.Callable[[str, tuple[int, ...]], dist.Distribution]: The function that returns the posterior distribution
    """
    posterior_samples = Predictive(model=guide, params=params, num_samples=num_samples)(
        key
    )
    # Hoisted out of the closure: previously the Normal approximations were
    # rebuilt from the samples on every single lookup.
    posterior_dists = construct_normal_prior_from_samples(posterior_samples)

    def posterior_dist_fn(name: str, shape: tuple[int, ...]) -> dist.Distribution:
        # `shape` is part of the expected callback signature but unused here:
        # the distribution's shape is already fixed by the posterior samples.
        return posterior_dists[prefix + "/" + name]

    return posterior_dist_fn

make_predictive_fn

make_predictive_fn(
    posterior_model, learning_model: LearningModel
)

Construct a predictive model from the probabilistic model. This function does not rely on a guide or the variational parameters

Examples:

characterized_result = sq.probabilistic.SVIResult.from_file(
    PGM_model_path / "model.json"
)

base_model = sq.models.library.linen.WoModel(
    shared_layers=characterized_result.config["model_config"]["hidden_sizes"][0],
    pauli_layers=characterized_result.config["model_config"]["hidden_sizes"][1],
)
graybox_model = sq.probabilistic.make_flax_probabilistic_graybox_model(
    name="graybox",
    base_model=base_model,
    adapter_fn=sq.probabilistic.observable_to_expvals,
    prior=dist.Normal(0, 1),
)
model = sq.probabilistic.make_probabilistic_model(
    graybox_probabilistic_model=graybox_model,
)
# initialize guide
guide = sq.probabilistic.auto_diagonal_normal_guide(
    model,
    ml.custom_feature_map(loaded_data.control_parameters),
    loaded_data.unitaries,
    jnp.zeros(shape=(shots, loaded_data.control_parameters.shape[0], 18)),
)
priors = {
    k.strip("graybox/"): v
    for k, v in make_prior_from_params(guide, characterized_result.params).items()
}
graybox_model = sq.probabilistic.make_flax_probabilistic_graybox_model(
    name="graybox",
    base_model=base_model,
    adapter_fn=sq.probabilistic.observable_to_expvals,
    prior=priors,
)
posterior_model = sq.probabilistic.make_probabilistic_model(
    graybox_probabilistic_model=graybox_model,
    shots=shots,
    block_graybox=True,
    log_expectation_values=True,
)
predicetive_fn = sq.probabilistic.make_predictive_fn(
    posterior_model, sq.probabilistic.LearningModel.BernoulliProbs
)

Parameters:

Name Type Description Default
posterior_model Any

probabilsitic model

required
learning_model LearningModel

description

required
Source code in src/inspeqtor/v1/probabilistic.py
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
def make_predictive_fn(
    posterior_model,
    learning_model: LearningModel,
):
    """Construct a predictive function from the probabilistic model.
    Unlike the SVI-based variant, this does not rely on a guide or on
    variational parameters.

    Args:
        posterior_model (typing.Any): The probabilistic model to draw predictions from.
        learning_model (LearningModel): Selects post-processing: Bernoulli outcomes
            are mapped to eigenvalues and averaged; otherwise the raw output is returned.

    Returns:
        typing.Callable: A function `(key, control_params, unitary) -> jnp.ndarray`.
    """

    def _run_seeded(
        key: jnp.ndarray,
        control_params: jnp.ndarray,
        unitary: jnp.ndarray,
    ) -> jnp.ndarray:
        # Seed the model so its internal sampling is deterministic w.r.t. `key`.
        return handlers.seed(posterior_model, key)(  # type: ignore
            control_params, unitary
        )

    def bernoulli_predict(
        key: jnp.ndarray,
        control_params: jnp.ndarray,
        unitary: jnp.ndarray,
    ) -> jnp.ndarray:
        # Binary outcomes -> eigenvalues, averaged over the sample axis.
        outcomes = _run_seeded(key, control_params, unitary)
        return jnp.mean(binary_to_eigenvalue(outcomes), axis=0)

    def normal_predict(
        key: jnp.ndarray,
        control_params: jnp.ndarray,
        unitary: jnp.ndarray,
    ) -> jnp.ndarray:
        return _run_seeded(key, control_params, unitary)

    if learning_model == LearningModel.BernoulliProbs:
        return bernoulli_predict
    return normal_predict

make_pdf

make_pdf(sample: ndarray, bins: int, srange=(-1, 1))

Make the numerical PDF from a given sample using the histogram method

Parameters:

Name Type Description Default
sample ndarray

Sample to make PDF.

required
bins int

The number of interval bin.

required
srange tuple

The range of the pdf. Defaults to (-1, 1).

(-1, 1)

Returns:

Type Description

typing.Any: The approximated numerical PDF

Source code in src/inspeqtor/v1/probabilistic.py
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
def make_pdf(sample: jnp.ndarray, bins: int, srange=(-1, 1)):
    """Build a numerical PDF from the given sample using the histogram method.

    Args:
        sample (jnp.ndarray): Samples to estimate the distribution from.
        bins (int): The number of histogram bins.
        srange (tuple, optional): The support of the histogram. Defaults to (-1, 1).

    Returns:
        typing.Any: Per-bin probabilities (density multiplied by bin width).
    """
    density, edges = jnp.histogram(sample, bins=bins, range=srange, density=True)
    # Density times bin width gives the probability mass in each bin.
    return density * jnp.diff(edges)

safe_kl_divergence

safe_kl_divergence(p: ndarray, q: ndarray)

Calculate the KL divergence where the infinity is converted to zero.

Parameters:

Name Type Description Default
p ndarray

The left PDF

required
q ndarray

The right PDF

required

Returns:

Type Description

jnp.ndarray: The KL divergence

Source code in src/inspeqtor/v1/probabilistic.py
476
477
478
479
480
481
482
483
484
485
486
def safe_kl_divergence(p: jnp.ndarray, q: jnp.ndarray):
    """Calculate the KL divergence, converting infinite contributions to zero.

    Args:
        p (jnp.ndarray): The left PDF
        q (jnp.ndarray): The right PDF

    Returns:
        jnp.ndarray: The KL divergence
    """
    # rel_entr yields +inf wherever p > 0 and q == 0; zero those terms out.
    pointwise = jax.scipy.special.rel_entr(p, q)
    return jnp.nan_to_num(pointwise, posinf=0.0).sum()

kl_divergence

kl_divergence(p: ndarray, q: ndarray)

Calculate the KL divergence

Parameters:

Name Type Description Default
p ndarray

The left PDF

required
q ndarray

The right PDF

required

Returns:

Type Description

jnp.ndarray: The KL divergence

Source code in src/inspeqtor/v1/probabilistic.py
489
490
491
492
493
494
495
496
497
498
499
def kl_divergence(p: jnp.ndarray, q: jnp.ndarray):
    """Calculate the KL divergence.

    Args:
        p (jnp.ndarray): The left PDF
        q (jnp.ndarray): The right PDF

    Returns:
        jnp.ndarray: The KL divergence
    """
    return jax.scipy.special.rel_entr(p, q).sum()

safe_jensenshannon_divergence

safe_jensenshannon_divergence(p: ndarray, q: ndarray)

Calculate the Jensen-Shannon Divergence using the KL divergence. Implemented following: https://docs.scipy.org/doc/scipy/reference/generated/scipy.spatial.distance.jensenshannon.html

Parameters:

Name Type Description Default
p ndarray

The left PDF

required
q ndarray

The right PDF

required

Returns:

Type Description

typing.Any: description

Source code in src/inspeqtor/v1/probabilistic.py
502
503
504
505
506
507
508
509
510
511
512
513
514
515
def safe_jensenshannon_divergence(p: jnp.ndarray, q: jnp.ndarray):
    """Calculate the Jensen-Shannon divergence using the "safe" KL divergence.
    Implemented following: https://docs.scipy.org/doc/scipy/reference/generated/scipy.spatial.distance.jensenshannon.html

    Args:
        p (jnp.ndarray): The left PDF
        q (jnp.ndarray): The right PDF

    Returns:
        typing.Any: The Jensen-Shannon divergence of p and q
    """
    # Pointwise mixture of the two distributions.
    mixture = 0.5 * (p + q)
    return 0.5 * (safe_kl_divergence(p, mixture) + safe_kl_divergence(q, mixture))

jensenshannon_divergence_from_pmf

jensenshannon_divergence_from_pmf(p: ndarray, q: ndarray)

Calculate the Jensen-Shannon Divergence from PMF

Example

key = jax.random.key(0)
key_1, key_2 = jax.random.split(key)
sample_1 = jax.random.normal(key_1, shape=(10000, ))
sample_2 = jax.random.normal(key_2, shape=(10000, ))

# Determine srange from sample
merged_sample = jnp.concat([sample_1, sample_2])
srange = jnp.min(merged_sample), jnp.max(merged_sample)

# https://stats.stackexchange.com/questions/510699/discrete-kl-divergence-with-decreasing-bin-width
# Recommend this book: https://catalog.lib.uchicago.edu/vufind/Record/6093380/TOC
bins = int(2 * (sample_2.shape[0]) ** (1/3))
# bins = 10
dis_1 = sq.probabilistic.make_pdf(sample_1, bins=bins, srange=srange)
dis_2 = sq.probabilistic.make_pdf(sample_2, bins=bins, srange=srange)

jsd = sq.probabilistic.jensenshannon_divergence_from_pdf(dis_1, dis_2)

Parameters:

Name Type Description Default
p ndarray

The 1st probability mass function

required
q ndarray

The 1st probability mass function

required

Returns:

Type Description

jnp.ndarray: The Jensen-Shannon Divergence of p and q

Source code in src/inspeqtor/v1/probabilistic.py
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
def jensenshannon_divergence_from_pmf(p: jnp.ndarray, q: jnp.ndarray):
    """Calculate the Jensen-Shannon divergence from two PMFs.

    Example
    ```python
    key = jax.random.key(0)
    key_1, key_2 = jax.random.split(key)
    sample_1 = jax.random.normal(key_1, shape=(10000, ))
    sample_2 = jax.random.normal(key_2, shape=(10000, ))

    # Determine srange from sample
    merged_sample = jnp.concat([sample_1, sample_2])
    srange = jnp.min(merged_sample), jnp.max(merged_sample)

    # https://stats.stackexchange.com/questions/510699/discrete-kl-divergence-with-decreasing-bin-width
    # Recommend this book: https://catalog.lib.uchicago.edu/vufind/Record/6093380/TOC
    bins = int(2 * (sample_2.shape[0]) ** (1/3))
    dis_1 = sq.probabilistic.make_pdf(sample_1, bins=bins, srange=srange)
    dis_2 = sq.probabilistic.make_pdf(sample_2, bins=bins, srange=srange)

    jsd = sq.probabilistic.jensenshannon_divergence_from_pdf(dis_1, dis_2)
    ```

    Args:
        p (jnp.ndarray): The 1st probability mass function
        q (jnp.ndarray): The 2nd probability mass function

    Returns:
        jnp.ndarray: The Jensen-Shannon divergence of p and q
    """
    # Note for JSD: https://medium.com/data-science/how-to-understand-and-use-jensen-shannon-divergence-b10e11b03fd6
    # Implemented following: https://docs.scipy.org/doc/scipy/reference/generated/scipy.spatial.distance.jensenshannon.html
    mixture = 0.5 * (p + q)
    return 0.5 * (kl_divergence(p, mixture) + kl_divergence(q, mixture))

jensenshannon_divergence_from_sample

jensenshannon_divergence_from_sample(
    sample_1: ndarray, sample_2: ndarray
) -> ndarray

Calculate the Jensen-Shannon Divergence from sample

Parameters:

Name Type Description Default
sample_1 ndarray

The left PDF

required
sample_2 ndarray

The right PDF

required

Returns:

Type Description
ndarray

jnp.ndarray: The Jensen-Shannon Divergence of p and q

Source code in src/inspeqtor/v1/probabilistic.py
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
def jensenshannon_divergence_from_sample(
    sample_1: jnp.ndarray, sample_2: jnp.ndarray
) -> jnp.ndarray:
    """Calculate the Jensen-Shannon divergence from two samples.

    Args:
        sample_1 (jnp.ndarray): The first sample
        sample_2 (jnp.ndarray): The second sample

    Returns:
        jnp.ndarray: The Jensen-Shannon divergence of the two empirical distributions
    """
    combined = jnp.concat([sample_1, sample_2])
    # Bin count follows the 2*n^(1/3) rule of thumb (sized from sample_2).
    n_bins = int(2 * sample_2.shape[0] ** (1 / 3))
    support = jnp.min(combined), jnp.max(combined)

    pmf_1 = make_pdf(sample_1, bins=n_bins, srange=support)
    pmf_2 = make_pdf(sample_2, bins=n_bins, srange=support)

    return jensenshannon_divergence_from_pmf(pmf_1, pmf_2)

batched_matmul

batched_matmul(x, w, b)

A specialized batched matrix multiplication of weight and input x, then add the bias. This function is intended to be used in dense_layer

Parameters:

Name Type Description Default
x ndarray

The input x

required
w ndarray

The weight to multiply to x

required
b ndarray

The bias

required

Returns:

Type Description

jnp.ndarray: Output of the operations.

Source code in src/inspeqtor/v1/probabilistic.py
579
580
581
582
583
584
585
586
587
588
589
590
591
def batched_matmul(x, w, b):
    """A specialized batched matrix multiplication of weight and input x, plus a bias.
    Intended to be used in `dense_layer`.

    Args:
        x (jnp.ndarray): The input x
        w (jnp.ndarray): The weight to multiply with x
        b (jnp.ndarray): The bias

    Returns:
        jnp.ndarray: Output of the operations.
    """
    # "...i,...ij->...j": contract the feature axis of x against the rows of w,
    # broadcasting over any leading batch axes, then add the bias.
    return jnp.einsum("...i,...ij->...j", x, w) + b

get_trace

get_trace(fn, key=key(0))

Convenient function to get a trace of the probabilistic model in numpyro.

Parameters:

Name Type Description Default
fn function

The probabilistic model in numpyro.

required
key ndarray

The random key. Defaults to jax.random.key(0).

key(0)
Source code in src/inspeqtor/v1/probabilistic.py
594
595
596
597
598
599
600
601
602
603
604
605
def get_trace(fn, key=jax.random.key(0)):
    """Convenience wrapper that returns a function recording the numpyro trace of `fn`.

    Args:
        fn (function): The probabilistic model in numpyro.
        key (jnp.ndarray, optional): The random key. Defaults to jax.random.key(0).
    """

    def traced(*args, **kwargs):
        seeded = handlers.seed(fn, key)
        return handlers.trace(seeded).get_trace(*args, **kwargs)

    return traced

default_priors_fn

default_priors_fn(
    name: str, shape: tuple[int, ...]
) -> Distribution

This is a default prior function for the dense_layer

Parameters:

Name Type Description Default
name str

The site name of the parameter; names ending with bias return a LogNormal distribution, otherwise a Normal distribution is returned

required

Returns:

Type Description
Distribution

typing.Any: description

Source code in src/inspeqtor/v1/probabilistic.py
608
609
610
611
612
613
614
615
616
617
618
619
620
621
def default_priors_fn(name: str, shape: tuple[int, ...]) -> dist.Distribution:
    """Default prior function for `dense_layer`.

    Args:
        name (str): The site name of the parameter. Names ending in `bias` get a
            LogNormal(0, 1) prior; all other sites get a Normal(0, 1) prior.
        shape (tuple[int, ...]): The shape the prior is expanded to.

    Returns:
        dist.Distribution: The prior distribution expanded to `shape`.
    """
    base = dist.LogNormal(0, 1) if name.endswith("bias") else dist.Normal(0, 1)
    return base.expand(shape)

dense_layer

dense_layer(
    x: ndarray,
    name: str,
    in_features: int,
    out_features: int,
    priors_fn: Callable[
        [str, tuple[int, ...]], Distribution
    ] = default_priors_fn,
)

A custom probabilistic dense layer for neural network model. This function intended to be used with numpyro

Parameters:

Name Type Description Default
x ndarray

The input x

required
name str

Site name of the layer

required
in_features int

The size of the feature.

required
out_features int

The desired size of the output feature.

required
priors_fn Callable[[str], Distribution]

The prior function to be used for initializing prior distribution. Defaults to default_priors_fn.

default_priors_fn

Returns:

Type Description

typing.Any: Output of the layer given x.

Source code in src/inspeqtor/v1/probabilistic.py
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
def dense_layer(
    x: jnp.ndarray,
    name: str,
    in_features: int,
    out_features: int,
    priors_fn: typing.Callable[
        [str, tuple[int, ...]], dist.Distribution
    ] = default_priors_fn,
):
    """A custom probabilistic dense layer for a neural-network model,
    intended to be used inside a `numpyro` model.

    Args:
        x (jnp.ndarray): The input x
        name (str): Site name of the layer
        in_features (int): The size of the input feature.
        out_features (int): The desired size of the output feature.
        priors_fn (typing.Callable[[str, tuple[int, ...]], dist.Distribution], optional):
            Prior-construction function. Defaults to default_priors_fn.

    Returns:
        typing.Any: Output of the layer given x.
    """
    kernel_site = f"{name}.kernel"
    # Kernel prior covers the trailing (in_features, out_features) axes.
    kernel = numpyro.sample(
        kernel_site,
        priors_fn(kernel_site, (in_features, out_features)).to_event(2),  # type: ignore
    )

    bias_site = f"{name}.bias"
    bias = numpyro.sample(
        bias_site,
        priors_fn(bias_site, (out_features,)).to_event(1),  # type: ignore
    )

    return batched_matmul(x, kernel, bias)  # type: ignore

init_default

init_default(params_name: str)

The initialization function for deterministic dense layer

Parameters:

Name Type Description Default
params_name str

The site name

required

Raises:

Type Description
ValueError

Unsupport site name

Returns:

Type Description

typing.Any: The function to be used for parameters init given site name.

Source code in src/inspeqtor/v1/probabilistic.py
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
def init_default(params_name: str):
    """The initialization function for the deterministic dense layer.

    Args:
        params_name (str): The site name

    Raises:
        ValueError: Unsupported site name

    Returns:
        typing.Any: A callable accepting `shape` (positionally or as a keyword)
        that builds the initial parameter array for the given site name.
    """
    if params_name.endswith("kernel"):
        return jnp.ones
    elif params_name.endswith("bias"):
        # Bug fix: callers (e.g. `dense_deterministic_layer`) invoke the
        # initializer as `init_fn(name)(shape=...)`; the previous
        # `lambda x: ...` rejected the `shape` keyword with a TypeError.
        return lambda shape: 0.1 * jnp.ones(shape)
    else:
        raise ValueError("Unsupport param name")

dense_deterministic_layer

dense_deterministic_layer(
    x,
    name: str,
    in_features: int,
    out_features: int,
    batch_shape: tuple[int, ...] = (),
    init_fn=init_default,
)

The deterministic dense layer, to be used with SVI optimizer.

Parameters:

Name Type Description Default
x Any

The input feature

required
name str

The site name

required
in_features int

The size of the input features

required
out_features int

The desired size of the output features.

required
batch_shape tuple[int, ...]

The batch size of the x. Defaults to ().

()
init_fn Any

Initilization function of the model parameters. Defaults to init_default.

init_default

Returns:

Type Description

typing.Any: The output of the layer given x.

Source code in src/inspeqtor/v1/probabilistic.py
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
def dense_deterministic_layer(
    x,
    name: str,
    in_features: int,
    out_features: int,
    batch_shape: tuple[int, ...] = (),
    init_fn=init_default,
):
    """The deterministic dense layer, to be used with the SVI optimizer.

    Args:
        x (typing.Any): The input feature
        name (str): The site name
        in_features (int): The size of the input features
        out_features (int): The desired size of the output features.
        batch_shape (tuple[int, ...], optional): Leading batch shape of the parameters. Defaults to ().
        init_fn (typing.Any, optional): Initialization function for the model parameters. Defaults to init_default.

    Returns:
        typing.Any: The output of the layer given x.
    """
    # Kernel parameter - trailing shape (in_features, out_features)
    kernel_name = f"{name}.kernel"
    kernel = numpyro.param(
        kernel_name,
        init_fn(kernel_name)(shape=batch_shape + (in_features, out_features)),  # type: ignore
    )

    # Bias parameter - trailing shape (out_features,)
    bias_name = f"{name}.bias"
    bias = numpyro.param(
        bias_name,
        init_fn(bias_name)(shape=batch_shape + (out_features,)),  # type: ignore
    )

    return batched_matmul(x, kernel, bias)  # type: ignore

make_posteriors_fn

make_posteriors_fn(
    key: ndarray, guide, params, num_samples=10000
)

Make the posterior distribution function that will return the posterior of parameter of the given name, from guide and parameters.

Parameters:

Name Type Description Default
guide Any

The guide function

required
params Any

The parameters of the guide

required
num_samples int

The sample size. Defaults to 10000.

10000

Returns:

Type Description

typing.Any: A function of parameter name that return the sample from the posterior distribution of the parameters.

Source code in src/inspeqtor/v1/probabilistic.py
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
@deprecated
def make_posteriors_fn(key: jnp.ndarray, guide, params, num_samples=10000):
    """Build a lookup function that returns the posterior of a parameter by
    name, derived from the guide and its variational parameters.

    Args:
        guide (typing.Any): The guide function
        params (typing.Any): The parameters of the guide
        num_samples (int, optional): The sample size. Defaults to 10000.

    Returns:
        typing.Any: A function mapping a parameter name to its posterior distribution.
    """
    samples = Predictive(model=guide, params=params, num_samples=num_samples)(key)

    # Fit a Normal approximation to each posterior site once, up front.
    posteriors = construct_normal_prior_from_samples(samples)

    def posteriors_fn(param_name: str):
        return posteriors[param_name]

    return posteriors_fn

auto_diagonal_normal_guide

auto_diagonal_normal_guide(
    model,
    *args,
    block_sample: bool = False,
    init_loc_fn=zeros,
    key: ndarray = key(0),
)

Automatically generate a guide from the given model. Expected to be initialized with example inputs of the model, which should also include the observed site. The blocking capability is intended for when the guide is used with its corresponding model inside another model: it avoids site-name duplication while allowing the model to use fresh samples from the guide.

Parameters:

Name Type Description Default
model Any

The probabilistic model.

required
block_sample bool

Flag to block the sample site. Defaults to False.

False
init_loc_fn Any

Initialization of guide parameters function. Defaults to jnp.zeros.

zeros

Returns:

Type Description

typing.Any: description

Source code in src/inspeqtor/v1/probabilistic.py
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
def auto_diagonal_normal_guide(
    model,
    *args,
    block_sample: bool = False,
    init_loc_fn=jnp.zeros,
    key: jnp.ndarray = jax.random.key(0),
):
    """Automatically generate a diagonal-Normal guide from a given model.
    Expected to be initialized with example inputs of the model (including the
    observed site). Blocking is intended for when the guide is used together
    with its corresponding model inside another model: it avoids site-name
    duplication while still letting the model use fresh samples from the guide.

    Args:
        model (typing.Any): The probabilistic model.
        block_sample (bool, optional): Flag to block (hide) the sample sites. Defaults to False.
        init_loc_fn (typing.Any, optional): Initializer for the location parameters. Defaults to jnp.zeros.
        key (jnp.ndarray, optional): Random key used to trace the model. Defaults to jax.random.key(0).

    Returns:
        typing.Any: The guide function.
    """
    # Trace the model once to discover the latent (non-observed) sample sites.
    model_trace = handlers.trace(handlers.seed(model, key)).get_trace(*args)
    sample_sites = [v for k, v in model_trace.items() if v["type"] == "sample"]
    non_observed_sites = [v for v in sample_sites if not v["is_observed"]]
    params_sites = [
        {"name": v["name"], "shape": v["value"].shape} for v in non_observed_sites
    ]

    def _sample_all(params_loc, params_scale):
        # Draw each latent from an independent (diagonal) Normal.
        samples = {}
        for (k_loc, v_loc), (k_scale, v_scale) in zip(
            params_loc.items(), params_scale.items(), strict=True
        ):
            samples[k_loc] = numpyro.sample(
                k_loc,
                dist.Normal(v_loc, v_scale).to_event(),  # type: ignore
            )
        return samples

    def guide(
        *args,
        **kwargs,
    ):
        params_loc = {
            param["name"]: numpyro.param(
                f"{param['name']}_loc", init_loc_fn(param["shape"])
            )
            for param in params_sites
        }

        params_scale = {
            param["name"]: numpyro.param(
                f"{param['name']}_scale",
                0.1 * jnp.ones(param["shape"]),
                constraint=dist.constraints.softplus_positive,
            )
            for param in params_sites
        }

        # Deduplicated: the sampling loop previously appeared verbatim in both
        # branches; only the blocking context differs. This also mirrors the
        # `sample_fn` structure of `auto_diagonal_normal_guide_v2`.
        if block_sample:
            with handlers.block():
                return _sample_all(params_loc, params_scale)
        return _sample_all(params_loc, params_scale)

    return guide

auto_diagonal_normal_guide_v2

auto_diagonal_normal_guide_v2(
    model,
    *args,
    init_dist_fn=init_normal_dist_fn,
    init_params_fn=init_params_fn,
    block_sample: bool = False,
    key: ndarray = key(0),
)

Automatically generate guide from given model. Expected to be initialized with the example input of the model. The given input should also including the observed site. The blocking capability is intended to be used in the when the guide will be used with its corresponding model in anothe model. This is the avoid site name duplication, while allows for model to use newly sample from the guide.

Parameters:

Name Type Description Default
model Any

The probabilistic model.

required
block_sample bool

Flag to block the sample site. Defaults to False.

False
init_loc_fn Any

Initialization of guide parameters function. Defaults to jnp.zeros.

required

Returns:

Type Description

typing.Any: description

Source code in src/inspeqtor/v1/probabilistic.py
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
def auto_diagonal_normal_guide_v2(
    model,
    *args,
    init_dist_fn=init_normal_dist_fn,
    init_params_fn=init_params_fn,
    block_sample: bool = False,
    key: jnp.ndarray = jax.random.key(0),
):
    """Automatically generate a guide from the given model. Expected to be initialized with example inputs of the model.
    The given input should also include the observed site.
    The blocking capability is intended for when the guide will be used with its corresponding model inside another model.
    This avoids site-name duplication while allowing the model to use fresh samples from the guide.

    Args:
        model (typing.Any): The probabilistic model.
        init_dist_fn (typing.Any, optional): Maps a site name to a distribution constructor,
            called as `init_dist_fn(name)(loc, scale)`. Defaults to init_normal_dist_fn.
        init_params_fn (typing.Any, optional): Maps a parameter name and shape to a registered
            variational parameter value. Defaults to init_params_fn.
        block_sample (bool, optional): Flag to block the sample sites. Defaults to False.
        key (jnp.ndarray, optional): Random key used to trace the model. Defaults to jax.random.key(0).

    Returns:
        typing.Any: The guide function.
    """
    # get the trace of the model
    model_trace = handlers.trace(handlers.seed(model, key)).get_trace(*args)
    # Then keep only the sample sites that are not observed (the latents)
    sample_sites = [v for k, v in model_trace.items() if v["type"] == "sample"]
    non_observed_sites = [v for v in sample_sites if not v["is_observed"]]
    params_sites = [
        {"name": v["name"], "shape": v["value"].shape} for v in non_observed_sites
    ]

    def sample_fn(
        params_loc: dict[str, typing.Any], params_scale: dict[str, typing.Any]
    ):
        samples = {}
        # Sample each latent; the distribution family per site comes from init_dist_fn
        for (k_loc, v_loc), (k_scale, v_scale) in zip(
            params_loc.items(), params_scale.items(), strict=True
        ):
            s = numpyro.sample(
                k_loc,
                init_dist_fn(k_loc)(v_loc, v_scale).to_event(),  # type: ignore
            )
            samples[k_loc] = s

        return samples

    def guide(
        *args,
        **kwargs,
    ):
        # Variational location parameters, one per latent site
        params_loc = {
            param["name"]: init_params_fn(f"{param['name']}_loc", param["shape"])
            for param in params_sites
        }

        # Variational scale parameters, one per latent site
        params_scale = {
            param["name"]: init_params_fn(f"{param['name']}_scale", param["shape"])
            for param in params_sites
        }

        # Optionally hide the sample sites from the enclosing trace
        if block_sample:
            with handlers.block():
                samples = sample_fn(params_loc, params_scale)
        else:
            samples = sample_fn(params_loc, params_scale)

        return samples

    return guide

make_predictive_fn_v2

make_predictive_fn_v2(model, guide, params, shots: int)

Make a posterior predictive model function from a model, guide, SVI parameters, and the number of shots. This version relies explicitly on the guide and variational parameters; it may be slower than the first version.

Parameters:

Name Type Description Default
model Any

Probabilistic model.

required
guide Any

Gudie corresponded to the model

required
params Any

SVI parameters of the guide

required
shots int

The number of shots

required

Returns:

Type Description

typing.Any: The posterior predictive model.

Source code in src/inspeqtor/v1/probabilistic.py
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
def make_predictive_fn_v2(
    model,
    guide,
    params,
    shots: int,
):
    """Build a posterior predictive function from a model, guide, SVI parameters,
    and the number of shots. This version relies explicitly on the guide and the
    variational parameters, and may be slower than the first version.

    Args:
        model (typing.Any): Probabilistic model.
        guide (typing.Any): Guide corresponding to the model.
        params (typing.Any): SVI parameters of the guide.
        shots (int): The number of shots.

    Returns:
        typing.Any: The posterior predictive model.
    """
    posterior_predictive = Predictive(
        model,
        guide=guide,
        params=params,
        num_samples=shots,
        return_sites=["obs"],
    )

    def predictive_fn(*args, **kwargs):
        # Only the "obs" site is requested, so extract it directly.
        samples = posterior_predictive(*args, **kwargs)
        return samples["obs"]

    return predictive_fn

make_predictive_SGM_model

make_predictive_SGM_model(
    model: Module,
    model_params,
    output_to_expectation_values_fn,
    shots: int,
)

Make a predictive model from given SGM model, the model parameters, and number of shots.

Parameters:

Name Type Description Default
model Module

Flax model

required
model_params Any

The model parameters.

required
shots int

The number of shots.

required
Source code in src/inspeqtor/v1/probabilistic.py
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
@deprecated
def make_predictive_SGM_model(
    model: nn.Module, model_params, output_to_expectation_values_fn, shots: int
):
    """Make a predictive model from a given SGM model, its parameters, and the number of shots.

    Args:
        model (nn.Module): Flax model.
        model_params (typing.Any): The model parameters.
        output_to_expectation_values_fn (typing.Any): Callable mapping the model
            output and the unitaries to predicted expectation values.
        shots (int): The number of shots.

    Returns:
        typing.Callable: Function ``(key, control_param, unitaries)`` returning a
        finite-shot estimate of the expectation values.
    """

    def predictive_model(
        key: jnp.ndarray, control_param: jnp.ndarray, unitaries: jnp.ndarray
    ):
        # Forward pass through the Flax model, then convert the raw output
        # (together with the unitaries) to expectation values.
        output = model.apply(model_params, control_param)
        predicted_expvals = output_to_expectation_values_fn(output, unitaries)

        # Draw `shots` Bernoulli samples of the -1-outcome probability (one key
        # per shot), map the bits back to +/-1 eigenvalues, and average over
        # the shot axis to get the finite-shot estimate.
        return binary_to_eigenvalue(
            jax.vmap(jax.random.bernoulli, in_axes=(0, None))(
                jax.random.split(key, shots),
                expectation_value_to_prob_minus(predicted_expvals),
            ).astype(jnp.int_)
        ).mean(axis=0)

    return predictive_model

make_predictive_MCDGM_model

make_predictive_MCDGM_model(model: Module, model_params)

Make a predictive model from given Monte-Carlo Dropout Graybox model, and the model parameters.

Parameters:

Name Type Description Default
model Module

Monte-Carlo Dropout Graybox model

required
model_params Any

The model parameters

required
Source code in src/inspeqtor/v1/probabilistic.py
 999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
def make_predictive_MCDGM_model(model: nn.Module, model_params):
    """Make a predictive model from a Monte-Carlo Dropout Graybox model and its parameters.

    Args:
        model (nn.Module): Monte-Carlo Dropout Graybox model.
        model_params (typing.Any): The model parameters.

    Returns:
        typing.Callable: Function ``(key, control_param, unitaries)`` returning
        predicted expectation values from one stochastic forward pass.
    """

    def predictive_model(
        key: jnp.ndarray, control_param: jnp.ndarray, unitaries: jnp.ndarray
    ):
        # Supplying a dropout rng keeps dropout active at inference time, so
        # each call with a fresh key gives a stochastic forward pass.
        stochastic_output = model.apply(
            model_params,
            control_param,
            rngs={"dropout": key},
        )

        return get_predict_expectation_value(
            stochastic_output,  # type: ignore
            unitaries,
            default_expectation_values_order,
        )

    return predictive_model

make_predictive_resampling_model

make_predictive_resampling_model(
    predictive_fn: Callable[[ndarray, ndarray], ndarray],
    shots: int,
) -> Callable[[ndarray, ndarray, ndarray], ndarray]

Make a binary predictive model from given SGM model, the model parameters, and number of shots.

Parameters:

Name Type Description Default
predictive_fn Callable[[ndarray, ndarray], ndarray]

The predictive_fn embedded with the SGM model.

required
shots int

The number of shots.

required

Returns:

Type Description
Callable[[ndarray, ndarray, ndarray], ndarray]

typing.Callable[[jnp.ndarray, jnp.ndarray, jnp.ndarray], jnp.ndarray]: Binary predictive model.

Source code in src/inspeqtor/v1/probabilistic.py
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
def make_predictive_resampling_model(
    predictive_fn: typing.Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray], shots: int
) -> typing.Callable[[jnp.ndarray, jnp.ndarray, jnp.ndarray], jnp.ndarray]:
    """Make a binary predictive model from a predictive function and a number of shots.

    Args:
        predictive_fn (typing.Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray]): The predictive_fn embedded with the SGM model.
        shots (int): The number of shots.

    Returns:
        typing.Callable[[jnp.ndarray, jnp.ndarray, jnp.ndarray], jnp.ndarray]: Binary predictive model.
    """

    def predictive_model(
        key: jnp.ndarray, control_parameters: jnp.ndarray, unitaries: jnp.ndarray
    ):
        expvals = predictive_fn(control_parameters, unitaries)
        prob_minus = expectation_value_to_prob_minus(expvals)

        # One Bernoulli draw per shot (one key each); the bits are mapped back
        # to +/-1 eigenvalues and averaged over the shot axis.
        shot_keys = jax.random.split(key, shots)
        draws = jax.vmap(jax.random.bernoulli, in_axes=(0, None))(shot_keys, prob_minus)
        return binary_to_eigenvalue(draws.astype(jnp.int_)).mean(axis=0)

    return predictive_model

make_probabilistic_graybox_model

make_probabilistic_graybox_model(model, adapter_fn)

This function make a probabilistic graybox model using custom numpyro BNN model

Parameters:

Name Type Description Default
model _type_

description

required
adapter_fn _type_

description

required
Source code in src/inspeqtor/v1/probabilistic.py
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
def make_probabilistic_graybox_model(model, adapter_fn):
    """Make a probabilistic graybox model using a custom numpyro BNN model.

    Args:
        model (typing.Any): Callable mapping control parameters to a model output.
        adapter_fn (typing.Any): Callable mapping ``(model output, unitaries)``
            to expectation values.

    Returns:
        typing.Callable: Function ``(control_parameters, unitaries)`` that
        returns the expectation values and records the raw model output under
        the deterministic numpyro site ``"output"``.
    """

    def probabilistic_graybox_model(control_parameters, unitaries):
        # Leading axes of control_parameters (all but the last two) are sample
        # dimensions; broadcast unitaries to match them.
        samples_shape = control_parameters.shape[:-2]
        unitaries = jnp.broadcast_to(unitaries, samples_shape + unitaries.shape[-3:])

        # Predict from control parameters
        output = model(control_parameters)

        # Record the raw prediction so it can be inspected in the trace.
        numpyro.deterministic("output", output)

        # With unitary and Wo, calculate expectation values
        expvals = adapter_fn(output, unitaries)

        return expvals

    return probabilistic_graybox_model

auto_diagonal_normal_guide_v3

auto_diagonal_normal_guide_v3(
    model,
    *args,
    init_dist_fn=bnn_init_dist_fn,
    init_params_fn=bnn_init_params_fn,
    dist_transform_fn=default_transform_dist_fn,
    block_sample: bool = False,
    key: ndarray = key(0),
)

Automatically generate a guide from the given model. Expected to be initialized with the example input of the model. The given input should also include the observed site. The blocking capability is intended for when the guide is used together with its corresponding model inside another model. This avoids site-name duplication, while still allowing the model to use fresh samples from the guide.

Notes

This version enable even more flexible initialization strategy. This function intended to be able to be compatible with auto marginal guide.

Parameters:

Name Type Description Default
model Any

The probabilistic model.

required
block_sample bool

Flag to block the sample site. Defaults to False.

False
init_loc_fn Any

Initialization of guide parameters function. Defaults to jnp.zeros.

required

Returns:

Type Description

typing.Any: description

Source code in src/inspeqtor/v1/probabilistic.py
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
1241
1242
1243
1244
1245
1246
1247
1248
1249
1250
1251
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
def auto_diagonal_normal_guide_v3(
    model,
    *args,
    init_dist_fn=bnn_init_dist_fn,
    init_params_fn=bnn_init_params_fn,
    dist_transform_fn=default_transform_dist_fn,
    block_sample: bool = False,
    key: jnp.ndarray = jax.random.key(0),
):
    """Automatically generate a guide from the given model.

    Expected to be called with the example input of the model; the input should
    also include the observed site. The blocking capability is intended for when
    the guide is used together with its corresponding model inside another
    model: blocking avoids site-name duplication while still allowing the model
    to use fresh samples from the guide.

    Notes:
        This version enables an even more flexible initialization strategy and
        is intended to be compatible with the auto marginal guide.

    Args:
        model (typing.Any): The probabilistic model.
        *args: Example model inputs used to trace the model's sample sites.
        init_dist_fn (typing.Any, optional): Maps a site name to a distribution
            constructor taking ``(loc, scale)``. Defaults to bnn_init_dist_fn.
        init_params_fn (typing.Any, optional): Maps a parameter name and shape
            to a variational parameter. Defaults to bnn_init_params_fn.
        dist_transform_fn (typing.Any, optional): Transforms the constructed
            distribution for a given site name. Defaults to
            default_transform_dist_fn.
        block_sample (bool, optional): Flag to block the sample sites.
            Defaults to False.
        key (jnp.ndarray, optional): Random key used to seed the model trace.
            Defaults to jax.random.key(0).

    Returns:
        typing.Any: The generated guide function.
    """
    # get the trace of the model
    model_trace = handlers.trace(handlers.seed(model, key)).get_trace(*args)
    # Then get only the sample site with observed equal to false
    sample_sites = [v for k, v in model_trace.items() if v["type"] == "sample"]
    non_observed_sites = [v for v in sample_sites if not v["is_observed"]]
    # Record the name and shape of every latent site so the guide can create
    # matching variational parameters.
    params_sites = [
        {"name": v["name"], "shape": v["value"].shape} for v in non_observed_sites
    ]

    def sample_fn(
        params_loc: dict[str, typing.Any], params_scale: dict[str, typing.Any]
    ):
        samples = {}
        # Sample from Normal distribution
        for (k_loc, v_loc), (k_scale, v_scale) in zip(
            params_loc.items(), params_scale.items(), strict=True
        ):
            s = numpyro.sample(
                k_loc,
                dist_transform_fn(k_loc, init_dist_fn(k_loc)(v_loc, v_scale)),  # type: ignore
            )
            samples[k_loc] = s

        return samples

    def guide(
        *args,
        **kwargs,
    ):
        # One `<name>_loc` / `<name>_scale` variational parameter per latent site.
        params_loc = {
            param["name"]: init_params_fn(f"{param['name']}_loc", param["shape"])
            for param in params_sites
        }

        params_scale = {
            param["name"]: init_params_fn(f"{param['name']}_scale", param["shape"])
            for param in params_sites
        }

        # Optionally hide the sample sites from the enclosing trace.
        with numpyro.util.optional(block_sample, handlers.block()):
            samples = sample_fn(params_loc, params_scale)

        return samples

    return guide

make_posterior_fn

make_posterior_fn(
    params, get_dist_fn: Callable[[str], Any]
)

This function create a posterior function to make a posterior predictive model

Examples:

posterior_model = sq.probabilistic.make_probabilistic_model(
    predictive_model=partial(
            probabilistic_graybox_model,
            priors_fn=make_posterior_fn(result.params, init_bnn_dist_fn),
        ),
    log_expectation_values=True,
)

Parameters:

Name Type Description Default
params _type_

The variational parameters from SVI

required
get_dist_fn Callable[[str], Distribution]

The function that return function given name

required
Source code in src/inspeqtor/v1/probabilistic.py
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
1279
1280
1281
1282
1283
1284
1285
1286
def make_posterior_fn(params, get_dist_fn: typing.Callable[[str], typing.Any]):
    """Create a posterior function for building a posterior predictive model.

    Examples:
        ```python
        posterior_model = sq.probabilistic.make_probabilistic_model(
            predictive_model=partial(
                    probabilistic_graybox_model,
                    priors_fn=make_posterior_fn(result.params, init_bnn_dist_fn),
                ),
            log_expectation_values=True,
        )
        ```

    Args:
        params (_type_): The variational parameters from SVI.
        get_dist_fn (typing.Callable[[str], dist.Distribution]): Maps a site
            name to a distribution constructor taking ``(loc, scale)``.
    """

    def posterior_fn(name: str, shape: tuple[int, ...]):
        # `shape` is accepted for interface compatibility; the variational
        # parameters already carry the correct shapes.
        loc = params[f"{name}_loc"]
        scale = params[f"{name}_scale"]
        return get_dist_fn(name)(loc, scale)

    return posterior_fn

Utilities

src.inspeqtor.v1.utils

center_location

center_location(
    num_of_pulse: int, total_time_dt: int | float
) -> ndarray

Create an array of location equally that centered each pulse.

Parameters:

Name Type Description Default
num_of_pulse int

The number of the pulse in the sequence to be equally center.

required
total_time_dt int | float

The total bins of the sequence.

required

Returns:

Type Description
ndarray

jnp.ndarray: The array of location equally that centered each pulse.

Source code in src/inspeqtor/v1/utils.py
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
def center_location(num_of_pulse: int, total_time_dt: int | float) -> jnp.ndarray:
    """Create an array of location equally that centered each pulse.

    Args:
        num_of_pulse (int): The number of the pulse in the sequence to be equally center.
        total_time_dt (int | float): The total bins of the sequence.

    Returns:
        jnp.ndarray: The array of location equally that centered each pulse.
    """
    center_locations = (
        jnp.array([(k - 0.5) / num_of_pulse for k in range(1, num_of_pulse + 1)])
        * total_time_dt
    )
    return center_locations

drag_envelope_v2

drag_envelope_v2(
    amp: float | ndarray,
    sigma: float | ndarray,
    beta: float | ndarray,
    center: float | ndarray,
    final_amp: float | ndarray = 1.0,
)

Drag pulse following: https://docs.quantum.ibm.com/api/qiskit/qiskit.pulse.library.Drag_class.rst#drag

Parameters:

Name Type Description Default
amp float | ndarray

The amplitude of the pulse

required
sigma float | ndarray

The standard deviation of the pulse

required
beta float | ndarray

DRAG coefficient.

required
center float | ndarray

Center location of the pulse

required
final_amp float | ndarray

Final amplitude of the control. Defaults to 1.0.

1.0

Returns:

Type Description

typing.Callable: DRAG envelope function

Source code in src/inspeqtor/v1/utils.py
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
def drag_envelope_v2(
    amp: float | jnp.ndarray,
    sigma: float | jnp.ndarray,
    beta: float | jnp.ndarray,
    center: float | jnp.ndarray,
    final_amp: float | jnp.ndarray = 1.0,
):
    """Drag pulse following: https://docs.quantum.ibm.com/api/qiskit/qiskit.pulse.library.Drag_class.rst#drag

    Args:
        amp (float | jnp.ndarray): The amplitude of the pulse
        sigma (float | jnp.ndarray): The standard deviation of the pulse
        beta (float | jnp.ndarray): DRAG coefficient.
        center (float | jnp.ndarray): Center location of the pulse
        final_amp (float | jnp.ndarray, optional): Final amplitude of the control. Defaults to 1.0.

    Returns:
        typing.Callable: DRAG envelope function
    """

    def g(t):
        return jnp.exp(-((t - center) ** 2) / (2 * sigma**2))

    def g_prime(t):
        return amp * (g(t) - g(-1)) / (1 - g(-1))

    def envelop(t):
        return final_amp * g_prime(t) * (1 + 1j * beta * (t - center) / sigma**2)

    return envelop

detune_hamiltonian

detune_hamiltonian(
    hamiltonian: Callable[
        [HamiltonianArgs, ndarray], ndarray
    ],
    detune: float,
) -> Callable[[HamiltonianArgs, ndarray], ndarray]

Detune the Hamiltonian in Z-axis with detuning coefficient

Parameters:

Name Type Description Default
hamiltonian Callable[[HamiltonianArgs, ndarray], ndarray]

Hamiltonian function to be detuned

required
detune float

Detuning coefficient

required

Returns:

Type Description
Callable[[HamiltonianArgs, ndarray], ndarray]

typing.Callable: Detuned Hamiltonian.

Source code in src/inspeqtor/v1/utils.py
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
@warn_not_tested_function
def detune_hamiltonian(
    hamiltonian: typing.Callable[[HamiltonianArgs, jnp.ndarray], jnp.ndarray],
    detune: float,
) -> typing.Callable[[HamiltonianArgs, jnp.ndarray], jnp.ndarray]:
    """Detune the Hamiltonian along the Z-axis with a detuning coefficient.

    Args:
        hamiltonian (typing.Callable[[HamiltonianArgs, jnp.ndarray], jnp.ndarray]): Hamiltonian function to be detuned.
        detune (float): Detuning coefficient.

    Returns:
        typing.Callable: Detuned Hamiltonian.
    """

    def detuned_hamiltonian(
        params: HamiltonianArgs,
        t: jnp.ndarray,
        *args,
        **kwargs,
    ) -> jnp.ndarray:
        # Add a constant detune * Z term on top of the original Hamiltonian.
        base_term = hamiltonian(params, t, *args, **kwargs)
        return base_term + detune * Z

    return detuned_hamiltonian

prepare_data

prepare_data(
    exp_data: ExperimentData,
    control_sequence: ControlSequence,
    whitebox: Callable,
) -> LoadedData

Prepare the data for easy accessing from experiment data, control sequence, and Whitebox.

Parameters:

Name Type Description Default
exp_data ExperimentData

ExperimentData instance

required
control_sequence ControlSequence

Control sequence of the experiment

required
whitebox Callable

Ideal unitary solver.

required

Returns:

Name Type Description
LoadedData LoadedData

LoadedData instance

Source code in src/inspeqtor/v1/utils.py
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
def prepare_data(
    exp_data: ExperimentData,
    control_sequence: ControlSequence,
    whitebox: typing.Callable,
) -> LoadedData:
    """Prepare the data for easy access from experiment data, control sequence, and Whitebox.

    Args:
        exp_data (ExperimentData): `ExperimentData` instance
        control_sequence (ControlSequence): Control sequence of the experiment
        whitebox (typing.Callable): Ideal unitary solver.

    Returns:
        LoadedData: `LoadedData` instance
    """
    logging.info(f"Loaded data from {exp_data.experiment_config.EXPERIMENT_IDENTIFIER}")

    control_parameters = jnp.array(exp_data.parameters)
    # * Flatten (size, pulses, features) -> (size, pulses * features) if needed.
    if control_parameters.ndim == 3:
        control_parameters = control_parameters.reshape(control_parameters.shape[0], -1)

    expectation_values = jnp.array(exp_data.get_expectation_values())
    unitaries = jax.vmap(whitebox)(control_parameters)

    logging.info(
        f"Finished preparing the data for the experiment {exp_data.experiment_config.EXPERIMENT_IDENTIFIER}"
    )

    # Keep only the final-time unitary for each control setting.
    return LoadedData(
        experiment_data=exp_data,
        control_parameters=control_parameters,
        unitaries=unitaries[:, -1, :, :],
        observed_values=expectation_values,
        control_sequence=control_sequence,
        whitebox=whitebox,
    )

random_split

random_split(
    key: ndarray, test_size: int, *data_arrays: ndarray
)

The random_split function splits the data into training and testing sets.

Examples:

>>> key = jax.random.key(0)
>>> x = jnp.arange(10)
>>> y = jnp.arange(10)
>>> x_train, y_train, x_test, y_test = random_split(key, 2, x, y)
>>> assert x_train.shape[0] == 8 and y_train.shape[0] == 8
>>> assert x_test.shape[0] == 2 and y_test.shape[0] == 2

Parameters:

Name Type Description Default
key ndarray

Random key.

required
test_size int

The size of the test set. Must be less than the size of the data.

required

Returns:

Type Description

typing.Sequence[jnp.ndarray]: The training and testing sets in the same order.

Source code in src/inspeqtor/v1/utils.py
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
def random_split(key: jnp.ndarray, test_size: int, *data_arrays: jnp.ndarray):
    """The random_split function splits the data into training and testing sets.

    Examples:
        >>> key = jax.random.key(0)
        >>> x = jnp.arange(10)
        >>> y = jnp.arange(10)
        >>> x_train, y_train, x_test, y_test = random_split(key, 2, x, y)
        >>> assert x_train.shape[0] == 8 and y_train.shape[0] == 8
        >>> assert x_test.shape[0] == 2 and y_test.shape[0] == 2

    Args:
        key (jnp.ndarray): Random key.
        test_size (int): The size of the test set. Must be less than the size of the data.
        *data_arrays (jnp.ndarray): Arrays to split; all are indexed by the same
            permutation of the first axis.

    Returns:
        typing.Sequence[jnp.ndarray]: The training and testing sets in the same order.
    """
    # * General random split
    idx = jax.random.permutation(key, data_arrays[0].shape[0])
    # Slice the permutation once instead of materializing a full permuted copy
    # of every array and slicing it twice.
    train_idx = idx[test_size:]
    test_idx = idx[:test_size]

    train_data = [data[train_idx] for data in data_arrays]
    test_data = [data[test_idx] for data in data_arrays]

    return (*train_data, *test_data)

dataloader

dataloader(
    arrays: Sequence[ndarray],
    batch_size: int,
    num_epochs: int,
    *,
    key: ndarray,
)

The dataloader function creates a generator that yields batches of data.

Parameters:

Name Type Description Default
arrays Sequence[ndarray]

The list or tuple of arrays to be batched.

required
batch_size int

The size of the batch.

required
num_epochs int

The number of epochs. If set to -1, the generator will run indefinitely.

required
key ndarray

The random key.

required

Returns:

Name Type Description
None

stop the generator.

Yields:

Type Description

typing.Any: (step, batch_idx, is_last_batch, epoch_idx), (array_batch, ...)

Source code in src/inspeqtor/v1/utils.py
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
def dataloader(
    arrays: typing.Sequence[jnp.ndarray],
    batch_size: int,
    num_epochs: int,
    *,
    key: jnp.ndarray,
):
    """The dataloader function creates a generator that yields batches of data.

    Args:
        arrays (typing.Sequence[jnp.ndarray]): The list or tuple of arrays to be batched.
        batch_size (int): The size of the batch.
        num_epochs (int): The number of epochs. If set to -1, the generator will run indefinitely.
        key (jnp.ndarray): The random key.

    Returns:
        None: stop the generator.

    Yields:
        typing.Any: (step, batch_idx, is_last_batch, epoch_idx), (array_batch, ...)
    """
    # * General dataloader
    # Check that all arrays have the same size in the first dimension
    dataset_size = arrays[0].shape[0]
    # assert all(array.shape[0] == dataset_size for array in arrays)
    # Generate random indices
    indices = jnp.arange(dataset_size)
    # `step` counts batches globally across epochs; `epoch_idx` counts epochs.
    step = 0
    epoch_idx = 0
    while True:
        # With num_epochs == -1 this never matches, so the loop runs forever.
        if epoch_idx == num_epochs:
            return None
        # Fresh shuffle for this epoch, then advance the key so later epochs
        # use a different permutation.
        perm = jax.random.permutation(key, indices)
        (key,) = jax.random.split(key, 1)
        batch_idx = 0
        start = 0
        end = batch_size
        is_last_batch = False
        while not is_last_batch:
            # The final batch may be smaller than batch_size (slice clips at the end).
            batch_perm = perm[start:end]
            # Check if this is the last batch
            is_last_batch = end >= dataset_size
            yield (
                (step, batch_idx, is_last_batch, epoch_idx),
                tuple(array[batch_perm] for array in arrays),
            )
            start = end
            end = start + batch_size
            step += 1
            batch_idx += 1

        epoch_idx += 1

expectation_value_to_prob_plus

expectation_value_to_prob_plus(
    expectation_value: ndarray,
) -> ndarray

Calculate the probability of -1 and 1 for the given expectation value E[O] = -1 * P[O = -1] + 1 * P[O = 1], where P[O = -1] + P[O = 1] = 1 Thus, E[O] = -1 * (1 - P[O = 1]) + 1 * P[O = 1] E[O] = 2 * P[O = 1] - 1 -> P[O = 1] = (E[O] + 1) / 2 Args: expectation_value (jnp.ndarray): Expectation value of quantum observable

Returns:

Type Description
ndarray

jnp.ndarray: Probability of measuring plus eigenvector

Source code in src/inspeqtor/v1/utils.py
269
270
271
272
273
274
275
276
277
278
279
280
281
282
def expectation_value_to_prob_plus(expectation_value: jnp.ndarray) -> jnp.ndarray:
    """Convert an expectation value to the probability of measuring +1.

    For a binary observable O with eigenvalues {-1, 1}:
    E[O] = -1 * P[O = -1] + 1 * P[O = 1] with P[O = -1] + P[O = 1] = 1,
    hence P[O = 1] = (E[O] + 1) / 2.

    Args:
        expectation_value (jnp.ndarray): Expectation value of quantum observable

    Returns:
        jnp.ndarray: Probability of measuring plus eigenvector
    """
    # Multiplying by 0.5 is bit-identical to dividing by 2.
    return 0.5 * (expectation_value + 1)

expectation_value_to_prob_minus

expectation_value_to_prob_minus(
    expectation_value: ndarray,
) -> ndarray

Convert quantum observable expectation value to probability of measuring -1.

For a binary quantum observable \(\hat{O}\) with eigenvalues \(b = \{-1, 1\}\), this function calculates the probability of measuring the eigenvalue -1 given its expectation value.

Derivation: $$ \langle \hat{O} \rangle = -1 \cdot \Pr(b=-1) + 1 \cdot \Pr(b = 1) $$ With the constraint \(\Pr(b = -1) + \Pr(b = 1) = 1\):

\[ \langle \hat{O} \rangle = -1 \cdot \Pr(b=-1) + 1 \cdot (1 - \Pr(b=-1)) \ \langle \hat{O} \rangle = -\Pr(b=-1) + 1 - \Pr(b=-1) \ \langle \hat{O} \rangle = 1 - 2\Pr(b=-1) \ \Pr(b=-1) = \frac{1 - \langle \hat{O} \rangle}{2} \]

Parameters:

Name Type Description Default
expectation_value ndarray

Expectation value of the quantum observable, must be in range [-1, 1].

required

Returns:

Type Description
ndarray

jnp.ndarray: Probability of measuring the -1 eigenvalue.

Source code in src/inspeqtor/v1/utils.py
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
def expectation_value_to_prob_minus(expectation_value: jnp.ndarray) -> jnp.ndarray:
    """Convert an expectation value to the probability of measuring -1.

    For a binary observable $\\hat{O}$ with eigenvalues $b = \\{-1, 1\\}$:

    $$
        \\langle \\hat{O} \\rangle = -1 \\cdot \\Pr(b=-1) + 1 \\cdot \\Pr(b = 1),
        \\qquad \\Pr(b = -1) + \\Pr(b = 1) = 1
    $$

    which gives $\\Pr(b=-1) = (1 - \\langle \\hat{O} \\rangle) / 2$.

    Args:
        expectation_value (jnp.ndarray): Expectation value of the quantum observable,
            must be in range [-1, 1].

    Returns:
        jnp.ndarray: Probability of measuring the -1 eigenvalue.
    """
    # Multiplying by 0.5 is bit-identical to dividing by 2.
    return 0.5 * (1 - expectation_value)

expectation_value_to_eigenvalue

expectation_value_to_eigenvalue(
    expectation_value: ndarray, SHOTS: int
) -> ndarray

Convert expectation value to eigenvalue

Parameters:

Name Type Description Default
expectation_value ndarray

Expectation value of quantum observable

required
SHOTS int

The number of shots used to produce expectation value

required

Returns:

Type Description
ndarray

jnp.ndarray: Array of eigenvalues

Source code in src/inspeqtor/v1/utils.py
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
def expectation_value_to_eigenvalue(
    expectation_value: jnp.ndarray, SHOTS: int
) -> jnp.ndarray:
    """Convert expectation value to a deterministic array of eigenvalues.

    For each expectation value, produces SHOTS eigenvalues in {-1, 1}: the
    first round(P[+1] * SHOTS) entries are +1 and the rest are -1, so their
    mean approximates the expectation value.

    Args:
        expectation_value (jnp.ndarray): Expectation value of quantum observable
        SHOTS (int): The number of shots used to produce expectation value

    Returns:
        jnp.ndarray: Array of eigenvalues, shape ``expectation_value.shape + (SHOTS,)``
    """
    # P[+1] = (E + 1) / 2, with a trailing length-1 axis so it broadcasts
    # against the shot axis below.
    prob_plus = (jnp.expand_dims(expectation_value, -1) + 1) / 2
    num_plus = jnp.around(prob_plus * SHOTS).astype(jnp.int32)
    shot_positions = jnp.broadcast_to(
        jnp.arange(SHOTS), expectation_value.shape + (SHOTS,)
    )
    return jnp.where(shot_positions < num_plus, 1, -1).astype(jnp.int32)

eigenvalue_to_binary

eigenvalue_to_binary(eigenvalue: ndarray) -> ndarray

Convert eigenvalue -1 to binary 1, and eigenvalue 1 to binary 0. This implementation should be differentiable

Parameters:

Name Type Description Default
eigenvalue ndarray

Eigenvalue to convert to bit value

required

Returns:

Type Description
ndarray

jnp.ndarray: Binary array

Source code in src/inspeqtor/v1/utils.py
339
340
341
342
343
344
345
346
347
348
349
350
def eigenvalue_to_binary(eigenvalue: jnp.ndarray) -> jnp.ndarray:
    """Convert eigenvalue -1 to binary 1, and eigenvalue 1 to binary 0.

    This implementation should be differentiable.

    Args:
        eigenvalue (jnp.ndarray): Eigenvalue to convert to bit value

    Returns:
        jnp.ndarray: Binary array
    """
    # b = (1 - e) / 2 maps e = -1 -> 1 and e = 1 -> 0.
    return (-1 * eigenvalue + 1) / 2

binary_to_eigenvalue

binary_to_eigenvalue(binary: ndarray) -> ndarray

Convert 1 to -1, and 0 to 1 This implementation should be differentiable

Parameters:

Name Type Description Default
binary ndarray

Bit value to convert to eigenvalue

required

Returns:

Type Description
ndarray

jnp.ndarray: Eigenvalue array

Source code in src/inspeqtor/v1/utils.py
353
354
355
356
357
358
359
360
361
362
363
364
def binary_to_eigenvalue(binary: jnp.ndarray) -> jnp.ndarray:
    """Convert bit 1 to eigenvalue -1, and bit 0 to eigenvalue 1.

    This implementation should be differentiable.

    Args:
        binary (jnp.ndarray): Bit value to convert to eigenvalue

    Returns:
        jnp.ndarray: Eigenvalue array
    """
    # e = 1 - 2b maps b = 0 -> 1 and b = 1 -> -1;
    # algebraically identical to -1 * (2b - 1).
    return 1 - 2 * binary

recursive_vmap

recursive_vmap(func, in_axes)

Perform recursive vmap on the given axis

Note
def func(x):
    assert x.ndim == 1
    return x ** 2
x = jnp.arange(10)
x_test = jnp.broadcast_to(x, (2, 3, 4,) + x.shape)
x_test.shape, recursive_vmap(func, (0,) * (x_test.ndim - 1))(x_test).shape
((2, 3, 4, 10), (2, 3, 4, 10))

Examples:

>>> def func(x):
...     assert x.ndim == 1
...     return x ** 2
>>> x = jnp.arange(10)
>>> x_test = jnp.broadcast_to(x, (2, 3, 4,) + x.shape)
>>> x_test.shape, recursive_vmap(func, (0,) * (x_test.ndim - 1))(x_test).shape
((2, 3, 4, 10), (2, 3, 4, 10))

Parameters:

Name Type Description Default
func Any

The function for vmap

required
in_axes Any

The axes for vmap

required

Returns:

Type Description

typing.Any: description

Source code in src/inspeqtor/v1/utils.py
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
def recursive_vmap(func, in_axes):
    """Apply nested `jax.vmap` over several leading batch axes.

    Each entry of `in_axes` produces one `jax.vmap` wrapping (outermost
    first), so a function written for unbatched input can be applied to
    an arbitrarily nested batch of inputs.

    Examples:
        >>> def func(x):
        ...     assert x.ndim == 1
        ...     return x ** 2
        >>> x = jnp.arange(10)
        >>> x_test = jnp.broadcast_to(x, (2, 3, 4,) + x.shape)
        >>> x_test.shape, recursive_vmap(func, (0,) * (x_test.ndim - 1))(x_test).shape
        ((2, 3, 4, 10), (2, 3, 4, 10))

    Args:
        func (typing.Any): The function to vectorize with vmap
        in_axes (typing.Any): Sequence of `in_axes` specifications, one per
            `jax.vmap` application

    Returns:
        typing.Any: `func` wrapped in one `jax.vmap` per entry of `in_axes`
    """
    if not in_axes:
        # Base case: no more axes to vectorize over
        return func

    # Apply vmap over the first axis specified in in_axes
    vmap_func = jax.vmap(func, in_axes=in_axes[0])

    # Recursively apply vmap over the remaining axes
    return recursive_vmap(vmap_func, in_axes[1:])

calculate_shots_expectation_value

calculate_shots_expectation_value(
    key: ndarray,
    initial_state: ndarray,
    unitary: ndarray,
    operator: ndarray,
    shots: int,
) -> ndarray

Calculate finite-shots estimate of expectation value

Parameters:

Name Type Description Default
key ndarray

Random key

required
initial_state ndarray

Initial state

required
unitary ndarray

Unitary operator

required
operator ndarray

The observable (Pauli) operator whose expectation value is estimated; eigenvalues are assumed to be +1 and -1.

required
shots int

Number of shot to be used in estimation of expectation value

required

Returns:

Type Description
ndarray

jnp.ndarray: Finite-shot estimate expectation value

Source code in src/inspeqtor/v1/utils.py
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
def calculate_shots_expectation_value(
    key: jnp.ndarray,
    initial_state: jnp.ndarray,
    unitary: jnp.ndarray,
    operator: jnp.ndarray,
    shots: int,
) -> jnp.ndarray:
    """Calculate finite-shots estimate of expectation value

    Args:
        key (jnp.ndarray): Random key
        initial_state (jnp.ndarray): Initial state (density matrix)
        unitary (jnp.ndarray): Unitary operator
        operator (jnp.ndarray): The observable (Pauli) operator whose
            expectation value is estimated; eigenvalues assumed to be +/-1.
        shots (int): Number of shots to be used in estimation of expectation value

    Returns:
        jnp.ndarray: Finite-shot estimate of the expectation value
    """
    # Exact expectation value Tr(U rho U^dagger O); .real drops the
    # numerical imaginary residue expected for a Hermitian observable.
    expval = jnp.trace(unitary @ initial_state @ unitary.conj().T @ operator).real
    # Probability of sampling the +1 eigenvalue outcome.
    prob = expectation_value_to_prob_plus(expval)

    # Draw `shots` outcomes in {+1, -1} and average them.
    return jax.random.choice(
        key, jnp.array([1, -1]), shape=(shots,), p=jnp.array([prob, 1 - prob])
    ).mean()

shot_quantum_device

shot_quantum_device(
    key: ndarray,
    control_parameters: ndarray,
    solver: Callable[[ndarray], ndarray],
    SHOTS: int,
    expectation_value_receipt: Sequence[
        ExpectationValue
    ] = default_expectation_values_order,
) -> ndarray

This is the shot estimate expectation value quantum device

Parameters:

Name Type Description Default
control_parameters ndarray

The control parameters to be fed to the simulator

required
key ndarray

Random key

required
solver Callable[[ndarray], ndarray]

The ODE solver for propagator

required
SHOTS int

The number of shots used to estimate expectation values

required

Returns:

Type Description
ndarray

jnp.ndarray: The expectation value of shape (control_parameters.shape[0], 18)

Source code in src/inspeqtor/v1/utils.py
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
def shot_quantum_device(
    key: jnp.ndarray,
    control_parameters: jnp.ndarray,
    solver: typing.Callable[[jnp.ndarray], jnp.ndarray],
    SHOTS: int,
    expectation_value_receipt: typing.Sequence[
        ExpectationValue
    ] = default_expectation_values_order,
) -> jnp.ndarray:
    """This is the shot estimate expectation value quantum device

    Args:
        key (jnp.ndarray): Random key
        control_parameters (jnp.ndarray): The control parameters to be fed to the simulator
        solver (typing.Callable[[jnp.ndarray], jnp.ndarray]): The ODE solver for propagator
        SHOTS (int): The number of shots used to estimate expectation values
        expectation_value_receipt (typing.Sequence[ExpectationValue], optional):
            The (initial state, observable) pairs to measure. Defaults to
            default_expectation_values_order.

    Returns:
        jnp.ndarray: The expectation values of shape
        (control_parameters.shape[0], len(expectation_value_receipt))
    """

    # One output column per experiment in the receipt (18 for the default
    # order); previously this was hard-coded to 18 regardless of the receipt.
    expectation_values = jnp.zeros(
        (control_parameters.shape[0], len(expectation_value_receipt))
    )
    # Final-time propagator for every set of control parameters.
    unitaries = jax.vmap(solver)(control_parameters)[:, -1, :, :]

    for idx, exp in enumerate(expectation_value_receipt):
        # Fresh subkey per experiment; one key per control-parameter row.
        key, sample_key = jax.random.split(key)
        sample_keys = jax.random.split(sample_key, num=unitaries.shape[0])

        expectation_value = jax.vmap(
            calculate_shots_expectation_value,
            in_axes=(0, None, 0, None, None),
        )(
            sample_keys,
            exp.initial_density_matrix,
            unitaries,
            exp.observable_matrix,
            SHOTS,
        )

        expectation_values = expectation_values.at[..., idx].set(expectation_value)

    return expectation_values

Visualization

src.inspeqtor.v1.visualization

format_expectation_values

format_expectation_values(
    expvals: ndarray,
) -> dict[str, dict[str, ndarray]]

This function formats expectation values of shape (18, N) to a dictionary with the initial state as outer key and the observable as inner key.

Parameters:

Name Type Description Default
expvals ndarray

Expectation values of shape (18, N). Assumes that order is as in default_expectation_values_order.

required

Returns:

Type Description
dict[str, dict[str, ndarray]]

dict[str, dict[str, jnp.ndarray]]: A dictionary with the initial state as outer key and the observable as inner key.

Source code in src/inspeqtor/v1/visualization.py
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
def format_expectation_values(
    expvals: jnp.ndarray,
) -> dict[str, dict[str, jnp.ndarray]]:
    """Reshape expectation values of shape (18, N) into a nested dictionary
    keyed by initial state (outer) and observable (inner).

    Args:
        expvals (jnp.ndarray): Expectation values of shape (18, N). Assumes that order is as in default_expectation_values_order.

    Returns:
        dict[str, dict[str, jnp.ndarray]]: A dictionary with the initial state as outer key and the observable as inner key.
    """
    formatted: dict[str, dict[str, jnp.ndarray]] = {}
    # Row order of `expvals` matches default_expectation_values_order.
    for row, entry in enumerate(default_expectation_values_order):
        inner = formatted.setdefault(entry.initial_state, {})
        inner[entry.observable] = expvals[row]

    return formatted

plot_loss_with_moving_average

plot_loss_with_moving_average(
    x: ndarray | ndarray,
    y: ndarray | ndarray,
    ax: Axes,
    window: int = 50,
    annotate_at: list[float] = [0.2, 0.4, 0.6, 0.8, 1.0],
    **kwargs,
) -> Axes

Plot the moving average of the given argument y

Parameters:

Name Type Description Default
x ndarray | ndarray

The horizontal axis

required
y ndarray | ndarray

The vertical axis

required
ax Axes

Axes object

required
window int

The moving average window. Defaults to 50.

50
annotate_at list[float]

The list of fractional positions (0.0-1.0) along the x axis at which to annotate the moving-average value. Defaults to [0.2, 0.4, 0.6, 0.8, 1.0].

[0.2, 0.4, 0.6, 0.8, 1.0]

Returns:

Name Type Description
Axes Axes

Axes object.

Source code in src/inspeqtor/v1/visualization.py
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
def plot_loss_with_moving_average(
    x: jnp.ndarray | np.ndarray,
    y: jnp.ndarray | np.ndarray,
    ax: Axes,
    window: int = 50,
    annotate_at: list[float] = [0.2, 0.4, 0.6, 0.8, 1.0],
    **kwargs,
) -> Axes:
    """Plot the moving average of the given argument y

    Args:
        x (jnp.ndarray | np.ndarray): The horizontal axis
        y (jnp.ndarray | np.ndarray): The vertical axis
        ax (Axes): Axes object
        window (int, optional): The moving average window. Defaults to 50.
        annotate_at (list[float], optional): Fractional positions (0.0-1.0)
            along x at which to annotate the moving-average value.
            Defaults to [0.2, 0.4, 0.6, 0.8, 1.0].

    Returns:
        Axes: Axes object.
    """
    # Rolling mean over `window` samples; the first window-1 entries are NaN.
    moving_average = pd.Series(np.asarray(y)).rolling(window=window).mean()

    ax.plot(
        x,
        moving_average,
        **kwargs,
    )

    for percentile in annotate_at:
        # Map the fractional position to a data index.
        idx = int(percentile * (len(x) - 1))

        loss_value = moving_average.iloc[idx]

        # Skip annotation if the moving average value is not available (e.g., at the beginning)
        if pd.isna(loss_value):
            continue

        ax.annotate(
            f"{loss_value:.3g}",
            xy=(float(x[idx].item()), float(loss_value)),
            xytext=(-10, 10),  # Offset the text for better readability
            textcoords="offset points",
            ha="center",
            va="bottom",
        )

    return ax

assert_list_of_axes

assert_list_of_axes(axes) -> list[Axes]

Assert that the provided object is a numpy array of Axes and return it as a flattened list

Parameters:

Name Type Description Default
axes Any

Expected to be numpy array of Axes

required

Returns:

Type Description
list[Axes]

list[Axes]: The list of Axes

Source code in src/inspeqtor/v1/visualization.py
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
def assert_list_of_axes(axes) -> list[Axes]:
    """Assert that the provided object is a numpy array of Axes and
    return it as a flattened list.

    Args:
        axes (typing.Any): Expected to be numpy array of Axes
            (e.g. as returned by ``plt.subplots``)

    Returns:
        list[Axes]: The list of Axes
    """
    assert isinstance(axes, np.ndarray)
    flat = axes.flatten()

    # Every element must itself be an Axes instance.
    assert all(isinstance(item, Axes) for item in flat)
    return flat.tolist()

set_fontsize

set_fontsize(ax: Axes, fontsize: float | int)

Set all fontsize of the Axes object

Parameters:

Name Type Description Default
ax Axes

The Axes object which fontsize to be changed.

required
fontsize float | int

The fontsize.

required
Source code in src/inspeqtor/v1/visualization.py
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
def set_fontsize(ax: Axes, fontsize: float | int):
    """Set all fontsize of the Axes object

    Applies the fontsize to the title, axis labels, tick labels, and
    (if any labeled artists exist) the legend.

    Args:
        ax (Axes): The Axes object which fontsize to be changed.
        fontsize (float | int): The fontsize.
    """
    for item in (
        [ax.title, ax.xaxis.label, ax.yaxis.label]
        + ax.get_xticklabels()
        + ax.get_yticklabels()
    ):
        item.set_fontsize(fontsize)

    # get_legend_handles_labels() returns (handles, labels); the previous
    # code bound them to misleadingly swapped names.
    handles, labels = ax.get_legend_handles_labels()

    # Only (re)create the legend when there are labeled artists; calling
    # ax.legend([], []) would draw a spurious empty legend box.
    if handles:
        ax.legend(handles, labels, fontsize=fontsize)