Skip to content

Models - Library - NNX

Flax NNX-based neural network models for quantum device characterization.

inspeqtor.models.library.nnx

inspeqtor.models.library.nnx.WoModel

\(\hat{W}_{O}\) based blackbox model.

Source code in src/inspeqtor/v1/models/nnx.py
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
class WoModel(Blackbox):
    """$\\hat{W}_{O}$ based blackbox model.

    A shared MLP trunk processes the control parameters, then three
    independent MLP heads (one per Pauli label "X", "Y", "Z") each predict
    the parameters of a Hermitian observable: three unitary parameters
    scaled to $[0, 2\\pi]$ and two diagonal parameters scaled to $[-1, 1]$,
    which are combined by `hermitian`.
    """

    def __init__(
        self, shared_layers: list[int], pauli_layers: list[int], *, rngs: nnx.Rngs
    ):
        """
        Args:
            shared_layers (list[int]): Each integer in the list is a size of the width of each hidden layer in the shared layers.
            pauli_layers (list[int]): Each integer in the list is a size of the width of each hidden layer in the Pauli layers.
            rngs (nnx.Rngs): Random number generator of `nnx`.
        """

        # Consecutive widths are paired so layer idx maps
        # shared_layers[idx] -> shared_layers[idx + 1]; the "shared/{idx}"
        # keys are looked up again in `__call__`.
        self.shared_layers = nnx.Dict(
            {
                f"shared/{idx}": nnx.Linear(
                    in_features=in_features, out_features=out_features, rngs=rngs
                )
                for idx, (in_features, out_features) in enumerate(
                    zip(shared_layers[:-1], shared_layers[1:])
                )
            }
        )

        # N widths describe N - 1 Linear layers.
        self.num_shared_layers = len(shared_layers) - 1
        self.num_pauli_layers = len(pauli_layers) - 1

        # One independent MLP head per Pauli observable, plus a final
        # projection to 3 unitary parameters and 2 diagonal parameters each.
        self.pauli_layers = nnx.Dict()
        self.unitary_layers = nnx.Dict()
        self.diagonal_layers = nnx.Dict()
        for pauli in ["X", "Y", "Z"]:
            layers = nnx.Dict(
                {
                    f"pauli/{idx}": nnx.Linear(
                        in_features=in_features, out_features=out_features, rngs=rngs
                    )
                    for idx, (in_features, out_features) in enumerate(
                        zip(pauli_layers[:-1], pauli_layers[1:])
                    )
                }
            )

            self.pauli_layers[pauli] = layers

            self.unitary_layers[pauli] = nnx.Linear(
                in_features=pauli_layers[-1], out_features=3, rngs=rngs
            )
            self.diagonal_layers[pauli] = nnx.Linear(
                in_features=pauli_layers[-1], out_features=2, rngs=rngs
            )

    def __call__(self, x: jnp.ndarray):
        """Return a dict mapping each Pauli label to its Hermitian observable."""
        # Shared trunk with ReLU activations.
        for idx in range(self.num_shared_layers):
            x = nnx.relu(self.shared_layers[f"shared/{idx}"](x))

        observables: dict[str, jnp.ndarray] = dict()
        for pauli, pauli_layer in self.pauli_layers.items():
            # JAX arrays are immutable, so each head can start from the shared
            # activation directly; the previous `jnp.copy(x)` was redundant work.
            _x = x
            for idx in range(self.num_pauli_layers):
                _x = nnx.relu(pauli_layer[f"pauli/{idx}"](_x))

            unitary_param = self.unitary_layers[pauli](_x)
            diagonal_param = self.diagonal_layers[pauli](_x)

            # Bound the unitary parameters to [0, 2*pi] and the diagonal
            # parameters to [-1, 1] via hard sigmoids.
            unitary_param = 2 * jnp.pi * nnx.hard_sigmoid(unitary_param)
            diagonal_param = (2 * nnx.hard_sigmoid(diagonal_param)) - 1

            observables[pauli] = hermitian(unitary_param, diagonal_param)

        return observables

__init__

__init__(
    shared_layers: list[int],
    pauli_layers: list[int],
    *,
    rngs: Rngs,
)

Parameters:

Name Type Description Default
shared_layers list[int]

Each integer in the list is a size of the width of each hidden layer in the shared layers.

required
pauli_layers list[int]

Each integer in the list is a size of the width of each hidden layer in the Pauli layers.

required
rngs Rngs

Random number generator of nnx.

required
Source code in src/inspeqtor/v1/models/nnx.py
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
def __init__(
    self, shared_layers: list[int], pauli_layers: list[int], *, rngs: nnx.Rngs
):
    """
    Args:
        shared_layers (list[int]): Each integer in the list is a size of the width of each hidden layer in the shared layers.
        pauli_layers (list[int]): Each integer in the list is a size of the width of each hidden layer in the Pauli layers.
        rngs (nnx.Rngs): Random number generator of `nnx`.
    """

    # Consecutive widths are paired so layer idx maps
    # shared_layers[idx] -> shared_layers[idx + 1]; the "shared/{idx}" keys
    # are looked up again in `__call__`.
    self.shared_layers = nnx.Dict(
        {
            f"shared/{idx}": nnx.Linear(
                in_features=in_features, out_features=out_features, rngs=rngs
            )
            for idx, (in_features, out_features) in enumerate(
                zip(shared_layers[:-1], shared_layers[1:])
            )
        }
    )

    # N widths describe N - 1 Linear layers.
    self.num_shared_layers = len(shared_layers) - 1
    self.num_pauli_layers = len(pauli_layers) - 1

    # One independent MLP head per Pauli observable, plus a final projection
    # to 3 unitary parameters and 2 diagonal parameters for each head.
    self.pauli_layers = nnx.Dict()
    self.unitary_layers = nnx.Dict()
    self.diagonal_layers = nnx.Dict()
    for pauli in ["X", "Y", "Z"]:
        layers = nnx.Dict(
            {
                f"pauli/{idx}": nnx.Linear(
                    in_features=in_features, out_features=out_features, rngs=rngs
                )
                for idx, (in_features, out_features) in enumerate(
                    zip(pauli_layers[:-1], pauli_layers[1:])
                )
            }
        )

        self.pauli_layers[pauli] = layers

        self.unitary_layers[pauli] = nnx.Linear(
            in_features=pauli_layers[-1], out_features=3, rngs=rngs
        )
        self.diagonal_layers[pauli] = nnx.Linear(
            in_features=pauli_layers[-1], out_features=2, rngs=rngs
        )

inspeqtor.models.library.nnx.UnitaryModel

Unitary-based model, predicting parameters parametrized unitary operator in range \([0, 2\pi]\).

Source code in src/inspeqtor/v1/models/nnx.py
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
class UnitaryModel(Blackbox):
    """Unitary-based model, predicting parameters parametrized unitary operator in range $[0, 2\\pi]$."""

    def __init__(self, hidden_sizes: list[int], *, rngs: nnx.Rngs) -> None:
        """

        Args:
            hidden_sizes (list[int]): Each integer in the list is a size of the width of each hidden layer in the shared layers
            rngs (nnx.Rngs): Random number generator of `nnx`.
        """
        self.hidden_sizes = hidden_sizes
        self.NUM_UNITARY_PARAMS = 4

        # Chain the hidden layers: layer idx maps hidden_sizes[idx - 1] ->
        # hidden_sizes[idx], with the first layer assuming the input width
        # equals hidden_sizes[0]. The previous construction made every layer
        # square (hidden_size -> hidden_size), which produced mismatched layer
        # shapes as soon as two consecutive widths differed; for a
        # uniform-width list this chaining is parameter-for-parameter
        # identical to the old behavior.
        in_sizes = [self.hidden_sizes[0], *self.hidden_sizes[:-1]]
        self.hidden_layers = nnx.Dict(
            {
                f"hidden_layers/{idx}": nnx.Linear(
                    in_features=in_features, out_features=out_features, rngs=rngs
                )
                for idx, (in_features, out_features) in enumerate(
                    zip(in_sizes, self.hidden_sizes)
                )
            }
        )

        # Initialize the final layer for unitary parameters
        self.final_layer = nnx.Linear(
            in_features=self.hidden_sizes[-1],
            out_features=self.NUM_UNITARY_PARAMS,
            rngs=rngs,
        )

    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        # Apply the hidden layers with ReLU activation
        for _, layer in self.hidden_layers.items():
            x = nnx.relu(layer(x))

        # Apply the final layer and squash the raw output into [0, 2*pi]
        x = self.final_layer(x)
        x = 2 * jnp.pi * nnx.hard_sigmoid(x)

        return x

__init__

__init__(hidden_sizes: list[int], *, rngs: Rngs) -> None

Parameters:

Name Type Description Default
hidden_sizes list[int]

Each integer in the list is a size of the width of each hidden layer in the shared layers

required
rngs Rngs

Random number generator of nnx.

required
Source code in src/inspeqtor/v1/models/nnx.py
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
def __init__(self, hidden_sizes: list[int], *, rngs: nnx.Rngs) -> None:
    """

    Args:
        hidden_sizes (list[int]): Each integer in the list is a size of the width of each hidden layer in the shared layers
        rngs (nnx.Rngs): Random number generator of `nnx`.
    """
    self.hidden_sizes = hidden_sizes
    self.NUM_UNITARY_PARAMS = 4

    # NOTE(review): every hidden layer here is square
    # (hidden_size -> hidden_size), so two consecutive entries of
    # `hidden_sizes` with different widths would produce mismatched layer
    # shapes, and the first layer assumes the input width equals
    # hidden_sizes[0] — confirm all widths are intended to be equal.
    self.hidden_layers = nnx.Dict(
        {
            f"hidden_layers/{idx}": nnx.Linear(
                in_features=hidden_size, out_features=hidden_size, rngs=rngs
            )
            for idx, hidden_size in enumerate(self.hidden_sizes)
        }
    )

    # Initialize the final layer for unitary parameters
    self.final_layer = nnx.Linear(
        in_features=self.hidden_sizes[-1],
        out_features=self.NUM_UNITARY_PARAMS,
        rngs=rngs,
    )

inspeqtor.models.library.nnx.UnitarySPAMModel

Composite class of unitary-based model and the SPAM model.

Source code in src/inspeqtor/v1/models/nnx.py
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
class UnitarySPAMModel(Blackbox):
    """Composite class of unitary-based model and the SPAM model."""

    def __init__(
        self, unitary_model: UnitaryModel, spam_params, *, rngs: nnx.Rngs
    ) -> None:
        """

        Args:
            unitary_model (UnitaryModel): Unitary-based model that has already been initialized.
            spam_params: Pytree of SPAM parameters; every leaf is wrapped in
                `nnx.Param` so it is trained alongside the unitary model.
            rngs (nnx.Rngs): Random number generator of `nnx`. (Currently
                unused; accepted for interface consistency with the other models.)
        """
        self.unitary_model = unitary_model
        self.spam_params = jax.tree.map(nnx.Param, spam_params)

    def __call__(self, x: jnp.ndarray) -> dict:
        # Model parameters depend on the input; SPAM parameters are global
        # learned state returned alongside them.
        return {"model_params": self.unitary_model(x), "spam_params": self.spam_params}

__init__

__init__(
    unitary_model: UnitaryModel, spam_params, *, rngs: Rngs
) -> None

Parameters:

Name Type Description Default
unitary_model UnitaryModel

Unitary-based model that have already initialized.

required
rngs Rngs

Random number generator of nnx.

required
Source code in src/inspeqtor/v1/models/nnx.py
169
170
171
172
173
174
175
176
177
178
179
def __init__(
    self, unitary_model: UnitaryModel, spam_params, *, rngs: nnx.Rngs
) -> None:
    """

    Args:
        unitary_model (UnitaryModel): Unitary-based model that has already been initialized.
        spam_params: Pytree of SPAM parameters; every leaf is wrapped in
            `nnx.Param` so it is trained alongside the unitary model.
        rngs (nnx.Rngs): Random number generator of `nnx`. (Currently
            unused; accepted for interface consistency with the other models.)
    """
    self.unitary_model = unitary_model
    self.spam_params = jax.tree.map(nnx.Param, spam_params)

inspeqtor.models.library.nnx.train_model

train_model(
    key: ndarray,
    train_data: DataBundled,
    val_data: DataBundled,
    test_data: DataBundled,
    model: Blackbox,
    optimizer: GradientTransformation,
    loss_fn: Callable,
    callbacks: list[Callable] = [],
    NUM_EPOCH: int = 1000,
    _optimizer: Optimizer | None = None,
)

Train the BlackBox model

Examples:

>>> # The number of epochs break down
... NUM_EPOCH = 150
... # Total number of iterations as 90% of data is used for training
... # 10% of the data is used for testing
... total_iterations = 9 * NUM_EPOCH
... # The step for the optimizer is set to 8 * NUM_EPOCH (should be less than total_iterations)
... step_for_optimizer = 8 * NUM_EPOCH
... optimizer = get_default_optimizer(step_for_optimizer)
... # The warmup steps for the optimizer
... warmup_steps = 0.1 * step_for_optimizer
... # The cool down steps for the optimizer
... cool_down_steps = total_iterations - step_for_optimizer
... total_iterations, step_for_optimizer, warmup_steps, cool_down_steps

Parameters:

Name Type Description Default
key ndarray

Random key

required
model Module

The model to be used for training

required
optimizer GradientTransformation

The optimizer to be used for training

required
loss_fn Callable

The loss function to be used for training

required
callbacks list[Callable]

list of callback functions. Defaults to [].

[]
NUM_EPOCH int

The number of epochs. Defaults to 1_000.

1000

Returns:

Name Type Description
tuple

The model parameters, optimizer state, and the histories

Source code in src/inspeqtor/v1/models/nnx.py
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
def train_model(
    # Random key
    key: jnp.ndarray,
    # Data
    train_data: DataBundled,
    val_data: DataBundled,
    test_data: DataBundled,
    # Model to be used for training
    model: Blackbox,
    optimizer: optax.GradientTransformation,
    # Loss function to be used
    loss_fn: typing.Callable,
    # Callbacks to be used; `None` avoids the shared-mutable-default pitfall
    callbacks: list[typing.Callable] | None = None,
    # Number of epochs
    NUM_EPOCH: int = 1_000,
    _optimizer: nnx.Optimizer | None = None,
):
    """Train the BlackBox model

    Examples:
        >>> # The number of epochs break down
        ... NUM_EPOCH = 150
        ... # Total number of iterations as 90% of data is used for training
        ... # 10% of the data is used for testing
        ... total_iterations = 9 * NUM_EPOCH
        ... # The step for the optimizer is set to 8 * NUM_EPOCH (should be less than total_iterations)
        ... step_for_optimizer = 8 * NUM_EPOCH
        ... optimizer = get_default_optimizer(step_for_optimizer)
        ... # The warmup steps for the optimizer
        ... warmup_steps = 0.1 * step_for_optimizer
        ... # The cool down steps for the optimizer
        ... cool_down_steps = total_iterations - step_for_optimizer
        ... total_iterations, step_for_optimizer, warmup_steps, cool_down_steps

    Args:
        key (jnp.ndarray): Random key
        train_data (DataBundled): Training split
        val_data (DataBundled): Validation split; its leading dimension also
            fixes the training batch size
        test_data (DataBundled): Test split
        model (nn.Module): The model to be used for training
        optimizer (optax.GradientTransformation): The optimizer to be used for training
        loss_fn (typing.Callable): The loss function to be used for training
        callbacks (list[typing.Callable] | None, optional): list of callback functions. Defaults to None (no callbacks).
        NUM_EPOCH (int, optional): The number of epochs. Defaults to 1_000.
        _optimizer (nnx.Optimizer | None, optional): Pre-built optimizer state,
            e.g. to resume training. Defaults to None (a fresh one is created).

    Returns:
        tuple: The trained model, the `nnx.Optimizer`, and the list of history entries
    """

    # Normalize the sentinel; a fresh list per call, never a shared default.
    if callbacks is None:
        callbacks = []

    key, loader_key = jax.random.split(key)

    # Train batches are sized to match the validation split.
    BATCH_SIZE = val_data.control_params.shape[0]

    histories: list[HistoryEntryV3] = []

    if _optimizer is None:
        _optimizer = nnx.Optimizer(
            model,
            optimizer,
            wrt=nnx.Param,
        )

    metrics = nnx.MultiMetric(
        loss=nnx.metrics.Average("loss"),
    )

    train_step, eval_step = create_step(loss_fn=loss_fn)

    for (step, batch_idx, is_last_batch, epoch_idx), (
        batch_p,
        batch_u,
        batch_ex,
    ) in dataloader(
        (
            train_data.control_params,
            train_data.unitaries,
            train_data.observables,
        ),
        batch_size=BATCH_SIZE,
        num_epochs=NUM_EPOCH,
        key=loader_key,
    ):
        train_step(
            model,
            _optimizer,
            metrics,
            DataBundled(
                control_params=batch_p, unitaries=batch_u, observables=batch_ex
            ),
        )

        histories.append(
            HistoryEntryV3(
                step=step, loss=metrics.compute()["loss"], loop="train", aux={}
            )
        )
        metrics.reset()  # Reset the metrics for the train set.

        # Validation/testing and callbacks run once per epoch, after the
        # final batch of that epoch.
        if is_last_batch:
            # Validation
            eval_step(model, metrics, val_data)
            histories.append(
                HistoryEntryV3(
                    step=step, loss=metrics.compute()["loss"], loop="val", aux={}
                )
            )
            metrics.reset()  # Reset the metrics for the val set.
            # Testing
            eval_step(model, metrics, test_data)
            histories.append(
                HistoryEntryV3(
                    step=step, loss=metrics.compute()["loss"], loop="test", aux={}
                )
            )
            metrics.reset()  # Reset the metrics for the test set.

            for callback in callbacks:
                callback(model, _optimizer, histories)

    return model, _optimizer, histories

inspeqtor.models.library.nnx.make_predictive_fn

make_predictive_fn(adapter_fn, model: Blackbox)
Source code in src/inspeqtor/v1/models/nnx.py
329
330
331
332
333
334
335
336
337
def make_predictive_fn(adapter_fn, model: Blackbox):
    """Bind `model` and `adapter_fn` into a single prediction callable.

    The returned function evaluates the model on the control parameters and
    post-processes the raw output together with the ideal unitaries.
    """

    def predictive_fn(
        control_parameters: jnp.ndarray, unitaries: jnp.ndarray
    ) -> jnp.ndarray:
        raw_output = model(control_parameters)
        return adapter_fn(raw_output, unitaries)

    return predictive_fn

inspeqtor.models.library.nnx.create_step

create_step(
    loss_fn: Callable[
        [Blackbox, DataBundled], tuple[ndarray, Any]
    ],
)

A function to create the training and evaluation steps for a model. The train step will update the model parameters and optimizer parameters in place.

Parameters:

Name Type Description Default
loss_fn Callable[[Blackbox, DataBundled], tuple[ndarray, Any]]

Loss function returned from make_loss_fn

required

Returns:

Type Description

typing.Any: The tuple of training and eval step functions.

Source code in src/inspeqtor/v1/models/nnx.py
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
def create_step(
    loss_fn: typing.Callable[[Blackbox, DataBundled], tuple[jnp.ndarray, typing.Any]],
):
    """Build the jitted training and evaluation step functions for a model.

    The training step mutates the model parameters, the optimizer state, and
    the metrics in place; the evaluation step only updates the metrics.

    Args:
        loss_fn (typing.Callable[[Blackbox, DataBundled], tuple[jnp.ndarray, typing.Any]]): Loss function returned from `make_loss_fn`

    Returns:
        typing.Any: The `(train_step, eval_step)` tuple of functions.
    """

    @nnx.jit
    def train_step(
        model: Blackbox,
        optimizer: nnx.Optimizer,
        metrics: nnx.MultiMetric,
        data: DataBundled,
    ):
        """Run a single optimization step, updating state in place."""
        model.train()  # Switch to train mode
        grad_fn = nnx.value_and_grad(loss_fn, has_aux=True)
        (loss, extras), grads = grad_fn(model, data)
        metrics.update(loss=loss)  # In-place update of the running metrics.
        optimizer.update(model, grads)  # In-place update of params and state.
        return loss, extras

    @nnx.jit
    def eval_step(model: Blackbox, metrics: nnx.MultiMetric, data):
        """Score `data` without touching the model parameters."""
        model.eval()
        loss, extras = loss_fn(model, data)
        metrics.update(loss=loss)  # In-place update of the running metrics.
        return loss, extras

    return train_step, eval_step

inspeqtor.models.library.nnx.make_loss_fn

make_loss_fn(adapter_fn, evaluate_fn)

A function for preparing loss function to be used for model training.

Parameters:

Name Type Description Default
adapter_fn Any

Adaptor function specifically for each model.

required
evaluate_fn Callable[[ndarray, ndarray, ndarray], ndarray]

Take in predicted and experimental expectation values and ideal unitary and return loss value

required
Source code in src/inspeqtor/v1/models/nnx.py
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
def make_loss_fn(adapter_fn, evaluate_fn):
    """A function for preparing loss function to be used for model training.

    Args:
        adapter_fn (typing.Any): Adapter function specific to each model; maps
            the raw model output and the ideal unitaries to predicted
            expectation values. (The previous docstring misnamed this
            parameter `predictive_fn`.)
        evaluate_fn (typing.Callable[[jnp.ndarray, jnp.ndarray, jnp.ndarray], jnp.ndarray]):
            Take in predicted and experimental expectation values and ideal
            unitary and return loss value.

    Returns:
        typing.Callable: `loss_fn(model, data)` returning `(mean_loss, aux_dict)`.
    """

    # Annotations are quoted (lazy) so the closure can be created even when
    # the project types are not importable at call time.
    def loss_fn(model: "Blackbox", data: "DataBundled"):
        output = model(data.control_params)

        # Predicted expectation values for the batch.
        expval = adapter_fn(output, data.unitaries)

        loss = evaluate_fn(expval, data.observables, data.unitaries)

        # Scalar mean loss plus an empty aux dict, matching `has_aux=True`.
        return jnp.mean(loss), {}

    return loss_fn