Skip to content

Models - Library - Linen

Flax Linen-based neural network models for quantum device characterization.

inspeqtor.models.library.linen

inspeqtor.models.library.linen.WoModel

Source code in src/inspeqtor/v1/models/linen.py
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
class WoModel(nn.Module):
    """Predict one ``hermitian(...)`` output per Pauli observable label.

    The input first passes through a shared MLP trunk (``shared_layers``,
    Dense + ReLU per entry). Then, for every label in ``pauli_operators``, a
    separate MLP branch (``pauli_layers``) feeds two named heads:

    * ``U_{op}``: a Dense layer with ``NUM_UNITARY_PARAMS`` outputs, passed
      through ``unitary_activation_fn`` (by default ``2*pi*hard_sigmoid``,
      i.e. values in ``[0, 2*pi]``).
    * ``D_{op}``: a Dense layer with ``NUM_DIAGONAL_PARAMS`` outputs, passed
      through ``diagonal_activation_fn`` (by default ``2*hard_sigmoid - 1``,
      i.e. values in ``[-1, 1]``).

    The two head outputs are combined with ``hermitian`` and stored under the
    Pauli label in the returned dict.
    """

    # Hidden widths of the trunk shared by all Pauli branches.
    shared_layers: typing.Sequence[int] = (20, 10)
    # Hidden widths of each per-Pauli branch.
    pauli_layers: typing.Sequence[int] = (20, 10)
    # Labels of the observables; one output entry is produced per label.
    pauli_operators: typing.Sequence[str] = ("X", "Y", "Z")

    # Output width of each ``U_{op}`` head.
    NUM_UNITARY_PARAMS: int = 3
    # Output width of each ``D_{op}`` head.
    NUM_DIAGONAL_PARAMS: int = 2

    # Squashes the ``U_{op}`` head output into [0, 2*pi].
    unitary_activation_fn: typing.Callable[[jnp.ndarray], jnp.ndarray] = (
        lambda x: 2 * jnp.pi * nn.hard_sigmoid(x)
    )
    # Squashes the ``D_{op}`` head output into [-1, 1].
    diagonal_activation_fn: typing.Callable[[jnp.ndarray], jnp.ndarray] = (
        lambda x: (2 * nn.hard_sigmoid(x)) - 1
    )

    @nn.compact
    def __call__(self, x: jnp.ndarray) -> dict[str, jnp.ndarray]:
        """Apply the model.

        Args:
            x: Input features; in this file the model is applied to control
                parameters.

        Returns:
            Mapping from each label in ``pauli_operators`` to
            ``hermitian(unitary_params, diag_params)`` for that branch.
        """
        # Shared trunk: Dense + ReLU for each hidden size.
        for hidden_size in self.shared_layers:
            x = nn.Dense(features=hidden_size)(x)
            x = nn.relu(x)

        Wos: dict[str, jnp.ndarray] = {}
        for op in self.pauli_operators:
            # Per-operator branch. JAX arrays are immutable, so the trunk
            # output can be reused directly; no defensive copy is needed.
            branch = x
            for hidden_size in self.pauli_layers:
                branch = nn.Dense(features=hidden_size)(branch)
                branch = nn.relu(branch)

            # Unitary head: NUM_UNITARY_PARAMS outputs, squashed by the
            # unitary activation.
            unitary_params = nn.Dense(
                features=self.NUM_UNITARY_PARAMS, name=f"U_{op}"
            )(branch)
            unitary_params = self.unitary_activation_fn(unitary_params)

            # Diagonal head: NUM_DIAGONAL_PARAMS outputs (2 by default),
            # squashed by the diagonal activation.
            diag_params = nn.Dense(
                features=self.NUM_DIAGONAL_PARAMS, name=f"D_{op}"
            )(branch)
            diag_params = self.diagonal_activation_fn(diag_params)

            Wos[op] = hermitian(unitary_params, diag_params)

        return Wos

inspeqtor.models.library.linen.UnitaryModel

Source code in src/inspeqtor/v1/models/linen.py
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
class UnitaryModel(nn.Module):
    """MLP that predicts ``NUM_UNITARY_PARAMS`` values in ``[0, 2*pi]``.

    Each entry of ``hidden_sizes`` adds a Dense + ReLU layer. A final Dense
    layer with ``NUM_UNITARY_PARAMS`` outputs is squashed by
    ``2*pi*hard_sigmoid``.
    """

    # Width of each hidden Dense layer, applied in order.
    hidden_sizes: list[int]

    # Number of outputs of the final Dense layer.
    NUM_UNITARY_PARAMS: int = 4

    @nn.compact
    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        """Apply the MLP to ``x`` and return the squashed output parameters."""
        # Hidden stack: Dense + ReLU for each hidden size.
        for hidden_size in self.hidden_sizes:
            x = nn.Dense(features=hidden_size)(x)
            x = nn.relu(x)

        # Output layer with NUM_UNITARY_PARAMS features (4 by default).
        x = nn.Dense(features=self.NUM_UNITARY_PARAMS)(x)
        # hard_sigmoid maps into [0, 1]; scaling by 2*pi gives [0, 2*pi].
        x = 2 * jnp.pi * nn.hard_sigmoid(x)

        return x

inspeqtor.models.library.linen.UnitarySPAMModel

Source code in src/inspeqtor/v1/models/linen.py
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
class UnitarySPAMModel(nn.Module):
    """``UnitaryModel`` combined with a learnable SPAM parameter.

    The unitary parameters come from an inner ``UnitaryModel``. A separate
    parameter named ``"spam_params"`` is created with shape
    ``init_spam_params.shape`` and initialised through ``init_fn``. It is
    passed through ``unflatten_fn`` before being returned.
    """

    # Hidden layer widths forwarded to the inner UnitaryModel.
    hidden_sizes: list[int]
    # Output width forwarded to the inner UnitaryModel.
    NUM_UNITARY_PARAMS: int = 4

    @nn.compact
    def __call__(self, x: jnp.ndarray):
        """Return ``{"model_params": ..., "spam_params": ...}`` for input ``x``."""
        unitary_model = UnitaryModel(
            hidden_sizes=self.hidden_sizes,
            NUM_UNITARY_PARAMS=self.NUM_UNITARY_PARAMS,
        )
        model_output = unitary_model(x)

        def _init_spam(rng, shape):
            # Delegates to the module-level ``init_fn``.
            return init_fn(rng, shape)

        raw_spam = self.param("spam_params", _init_spam, init_spam_params.shape)

        return {
            "model_params": model_output,
            "spam_params": unflatten_fn(raw_spam),
        }

inspeqtor.models.library.linen.train_model

train_model(
    key: ndarray,
    train_data: DataBundled,
    val_data: DataBundled,
    test_data: DataBundled,
    model: Module,
    optimizer: GradientTransformation,
    loss_fn: Callable,
    callbacks: list[Callable] = [],
    NUM_EPOCH: int = 1000,
    model_params: VariableDict | None = None,
    opt_state: OptState | None = None,
)

Train the BlackBox model

Examples:

>>> # The number of epochs break down
... NUM_EPOCH = 150
... # Total number of iterations as 90% of data is used for training
... # 10% of the data is used for testing
... total_iterations = 9 * NUM_EPOCH
... # The number of optimizer steps is set to 8 * NUM_EPOCH (should be less than total_iterations)
... step_for_optimizer = 8 * NUM_EPOCH
... optimizer = get_default_optimizer(step_for_optimizer)
... # The warmup steps for the optimizer
... warmup_steps = 0.1 * step_for_optimizer
... # The cool down steps for the optimizer
... cool_down_steps = total_iterations - step_for_optimizer
... total_iterations, step_for_optimizer, warmup_steps, cool_down_steps

Parameters:

Name Type Description Default
key ndarray

Random key

required
model Module

The model to be used for training

required
optimizer GradientTransformation

The optimizer to be used for training

required
loss_fn Callable

The loss function to be used for training

required
callbacks list[Callable]

list of callback functions. Defaults to [].

[]
NUM_EPOCH int

The number of epochs. Defaults to 1_000.

1000

Returns:

Name Type Description
tuple

The model parameters, optimizer state, and the histories

Source code in src/inspeqtor/v1/models/linen.py
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
def train_model(
    # Random key
    key: jnp.ndarray,
    # Data
    train_data: DataBundled,
    val_data: DataBundled,
    test_data: DataBundled,
    # Model to be used for training
    model: nn.Module,
    optimizer: optax.GradientTransformation,
    # Loss function to be used
    loss_fn: typing.Callable,
    # Callbacks to be used
    callbacks: list[typing.Callable] | None = None,
    # Number of epochs
    NUM_EPOCH: int = 1_000,
    # Optional state
    model_params: VariableDict | None = None,
    opt_state: optax.OptState | None = None,
):
    """Train the BlackBox model.

    Runs ``NUM_EPOCH`` epochs of mini-batch training on ``train_data``.
    Whenever the dataloader flags the last batch of an epoch, it does three
    things: evaluates the current parameters on the full ``val_data`` and
    ``test_data``, records the results, and invokes every callback.

    Examples:
        >>> # The number of epochs break down
        ... NUM_EPOCH = 150
        ... # Total number of iterations as 90% of data is used for training
        ... # 10% of the data is used for testing
        ... total_iterations = 9 * NUM_EPOCH
        ... # The number of optimizer steps is set to 8 * NUM_EPOCH (should be less than total_iterations)
        ... step_for_optimizer = 8 * NUM_EPOCH
        ... optimizer = get_default_optimizer(step_for_optimizer)
        ... # The warmup steps for the optimizer
        ... warmup_steps = 0.1 * step_for_optimizer
        ... # The cool down steps for the optimizer
        ... cool_down_steps = total_iterations - step_for_optimizer
        ... total_iterations, step_for_optimizer, warmup_steps, cool_down_steps

    Args:
        key (jnp.ndarray): Random key, split into a dataloader key and a
            parameter-initialization key.
        train_data (DataBundled): Training set. Its ``control_params``,
            ``unitaries`` and ``observables`` are batched by the dataloader.
        val_data (DataBundled): Validation set, evaluated in full at the end
            of each epoch. Its size also sets the training batch size.
        test_data (DataBundled): Test set, evaluated in full at the end of
            each epoch.
        model (nn.Module): The model to be used for training.
        optimizer (optax.GradientTransformation): The optimizer to be used for
            training.
        loss_fn (typing.Callable): Called as
            ``loss_fn(params, control_params, unitaries, observables)``. It must
            return ``(loss, aux)`` because the step functions are built with
            ``has_aux=True``.
        callbacks (list[typing.Callable] | None, optional): Functions called as
            ``callback(model_params, opt_state, histories)`` at the end of each
            epoch. Defaults to None, meaning no callbacks.
        NUM_EPOCH (int, optional): The number of epochs. Defaults to 1_000.
        model_params (VariableDict | None, optional): Initial model
            parameters. If None, they are created with ``model.init`` on the
            first training sample.
        opt_state (optax.OptState | None, optional): Initial optimizer state.
            If None, it is created with ``optimizer.init(model_params)``.

    Returns:
        tuple: ``(model_params, opt_state, histories)``. ``histories`` is a list
        of ``HistoryEntryV3`` records whose ``loop`` is ``"train"`` (one per
        step), ``"val"`` or ``"test"`` (one each per epoch end).
    """
    if callbacks is None:
        # Normalize here instead of using a shared mutable ``[]`` default.
        callbacks = []

    # Only the loader and init keys are used; the remaining key is discarded.
    _, loader_key, init_key = jax.random.split(key, 3)

    train_p, train_u, train_ex = (
        train_data.control_params,
        train_data.unitaries,
        train_data.observables,
    )
    val_p, val_u, val_ex = (
        val_data.control_params,
        val_data.unitaries,
        val_data.observables,
    )
    test_p, test_u, test_ex = (
        test_data.control_params,
        test_data.unitaries,
        test_data.observables,
    )

    # The training batch size is deliberately tied to the validation-set size.
    # Each epoch therefore takes presumably about len(train) / len(val)
    # optimizer steps; the exact count depends on the dataloader.
    BATCH_SIZE = val_p.shape[0]

    if model_params is None:
        # Initialize the model parameters from the first training sample.
        model_params = model.init(init_key, train_p[0])

    if opt_state is None:
        # Initialize the optimizer state if it is None.
        opt_state = optimizer.init(model_params)  # type: ignore

    histories: list[HistoryEntryV3] = []

    # loss_fn is expected to return (loss, aux), hence has_aux=True.
    train_step, eval_step = create_step(
        optimizer=optimizer, loss_fn=loss_fn, has_aux=True
    )

    for (step, _batch_idx, is_last_batch, _epoch_idx), (
        batch_p,
        batch_u,
        batch_ex,
    ) in dataloader(
        (train_p, train_u, train_ex),
        batch_size=BATCH_SIZE,
        num_epochs=NUM_EPOCH,
        key=loader_key,
    ):
        model_params, opt_state, (loss, aux) = train_step(
            model_params, opt_state, batch_p, batch_u, batch_ex
        )

        histories.append(HistoryEntryV3(step=step, loss=loss, loop="train", aux=aux))

        if is_last_batch:
            # Validation on the full validation set.
            (val_loss, aux) = eval_step(model_params, val_p, val_u, val_ex)

            histories.append(
                HistoryEntryV3(step=step, loss=val_loss, loop="val", aux=aux)
            )

            # Testing on the full test set.
            (test_loss, aux) = eval_step(model_params, test_p, test_u, test_ex)

            histories.append(
                HistoryEntryV3(step=step, loss=test_loss, loop="test", aux=aux)
            )

            for callback in callbacks:
                callback(model_params, opt_state, histories)

    return model_params, opt_state, histories

inspeqtor.models.library.linen.make_predictive_fn

make_predictive_fn(
    adapter_fn: adapter_fn_type,
    model: Module,
    model_params: Any,
)
Source code in src/inspeqtor/v1/models/linen.py
504
505
506
507
508
509
510
511
512
513
514
def make_predictive_fn(
    adapter_fn: adapter_fn_type, model: nn.Module, model_params: typing.Any
):
    """Bind a Flax model and its parameters into a prediction function.

    Args:
        adapter_fn: Called as ``adapter_fn(model_output, unitaries)`` to turn
            the raw model output into predictions.
        model: Flax Linen module whose ``apply`` is used.
        model_params: Variables passed to ``model.apply``.

    Returns:
        A function ``predictive_fn(control_parameters, unitaries)`` that runs
        the model on ``control_parameters`` and returns ``adapter_fn``'s result.
    """

    def predictive_fn(
        control_parameters: jnp.ndarray, unitaries: jnp.ndarray
    ) -> jnp.ndarray:
        # Run the model, then let the adapter convert its output.
        return adapter_fn(model.apply(model_params, control_parameters), unitaries)

    return predictive_fn

inspeqtor.models.library.linen.create_step

create_step(
    optimizer: GradientTransformation,
    loss_fn: Callable[..., ndarray]
    | Callable[..., Tuple[ndarray, Any]],
    has_aux: bool = False,
)

The create_step function creates a training step function and a test step function.

loss_fn should have the following signature:

def loss_fn(params: jaxtyping.PyTree, *args) -> jnp.ndarray:
    ...
    return loss_value
where params is the parameters to be optimized, and args are the inputs for the loss function.

Parameters:

Name Type Description Default
optimizer GradientTransformation

optax optimizer.

required
loss_fn Callable[[PyTree, ...], ndarray]

Loss function, which takes in the model parameters, inputs, and targets, and returns the loss value.

required
has_aux bool

Whether the loss function returns auxiliary data alongside the loss value. Defaults to False.

False

Returns:

Type Description

typing.Any: train_step, test_step

Source code in src/inspeqtor/v1/models/linen.py
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
def create_step(
    optimizer: optax.GradientTransformation,
    loss_fn: (
        typing.Callable[..., jnp.ndarray]
        | typing.Callable[..., typing.Tuple[jnp.ndarray, typing.Any]]
    ),
    has_aux: bool = False,
):
    """Build jit-compiled training and evaluation step functions.

    ``loss_fn`` must have the following signature:
    ```py
    def loss_fn(params: jaxtyping.PyTree, *args) -> jnp.ndarray:
        ...
        return loss_value
    ```
    where `params` are the parameters to be optimized and `args` are the
    remaining inputs of the loss function. When ``has_aux`` is True,
    ``loss_fn`` returns ``(loss_value, aux)`` instead.

    Args:
        optimizer (optax.GradientTransformation): `optax` optimizer.
        loss_fn (typing.Callable[[jaxtyping.PyTree, ...], jnp.ndarray]): Loss function, which takes in the model parameters, inputs, and targets, and returns the loss value.
        has_aux (bool, optional): Whether the loss function returns auxiliary data. Defaults to False.

    Returns:
        _typing.Any_: ``(train_step, test_step)``.
        ``train_step(params, optimizer_state, *args, **kwargs)`` returns
        ``(new_params, new_optimizer_state, loss_fn_output)``.
        ``test_step(params, *args, **kwargs)`` returns ``loss_fn``'s output unchanged.
    """
    # Loss plus gradient w.r.t. the first argument (params).
    loss_and_grad = jax.value_and_grad(loss_fn, has_aux=has_aux)

    @jax.jit
    def train_step(
        params: jaxtyping.PyTree,
        optimizer_state: optax.OptState,
        *args,
        **kwargs,
    ):
        loss_output, grads = loss_and_grad(params, *args, **kwargs)
        updates, new_optimizer_state = optimizer.update(
            grads, optimizer_state, params
        )
        new_params = optax.apply_updates(params, updates)
        return new_params, new_optimizer_state, loss_output

    @jax.jit
    def test_step(
        params: jaxtyping.PyTree,
        *args,
        **kwargs,
    ):
        # Evaluation only: no gradients, no parameter update.
        return loss_fn(params, *args, **kwargs)

    return train_step, test_step

inspeqtor.models.library.linen.make_loss_fn

make_loss_fn(
    adapter_fn: Callable,
    model: Module,
    evaluate_fn: Callable[
        [ndarray, ndarray, ndarray], ndarray
    ],
)

Create a loss function for a Flax Linen model.

Parameters:

Name Type Description Default
adapter_fn Callable

Function that converts the model output (together with the ideal unitaries) into predicted expectation values

required
model Module

Flax linen Blackbox part of the graybox model.

required
evaluate_fn Callable[[ndarray, ndarray, ndarray], ndarray]

Takes the predicted expectation values, the experimental expectation values, and the ideal unitaries, and returns the loss value

required
Source code in src/inspeqtor/v1/models/linen.py
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
def make_loss_fn(
    adapter_fn: typing.Callable,
    model: nn.Module,
    evaluate_fn: typing.Callable[[jnp.ndarray, jnp.ndarray, jnp.ndarray], jnp.ndarray],
):
    """Create a loss function for a Flax Linen model.

    The returned ``loss_fn`` performs four steps:

    1. Runs ``model`` on the control parameters.
    2. Maps the model output to predicted expectation values with ``adapter_fn``.
    3. Scores those predictions against the experimental values with ``evaluate_fn``.
    4. Averages the score with ``jnp.mean``.

    It returns ``(loss, aux)``, the form expected by
    ``create_step(..., has_aux=True)``. ``aux`` is currently always an empty
    dict.

    Args:
        adapter_fn (typing.Callable): Called as
            ``adapter_fn(model_output, unitaries=unitaries)`` to produce the
            predicted expectation values.
        model (nn.Module): Flax Linen blackbox part of the graybox model.
        evaluate_fn (typing.Callable[[jnp.ndarray, jnp.ndarray, jnp.ndarray], jnp.ndarray]):
            Called as ``evaluate_fn(predicted, experimental, unitaries)``. Its
            result is averaged with ``jnp.mean`` to give the scalar loss.

    Returns:
        typing.Callable: ``loss_fn(params, control_parameters, unitaries,
        expectation_values, **model_kwargs)``.
    """

    def loss_fn(
        params: VariableDict,
        control_parameters: jnp.ndarray,
        unitaries: jnp.ndarray,
        expectation_values: jnp.ndarray,
        **model_kwargs,
    ) -> tuple[jnp.ndarray, dict[str, jnp.ndarray]]:
        """Compute the mean loss of ``model`` under ``params``.

        Args:
            params (VariableDict): Model parameters to be optimized.
            control_parameters (jnp.ndarray): Control parameters of the parametrized Hamiltonian.
            unitaries (jnp.ndarray): The ideal unitary operators corresponding to the control parameters.
            expectation_values (jnp.ndarray): Experimental expectation values to calculate the loss value.
            **model_kwargs: Extra keyword arguments forwarded to ``model.apply``.

        Returns:
            tuple[jnp.ndarray, dict[str, jnp.ndarray]]: The mean loss value
            and an (currently empty) dict of auxiliary metrics.
        """
        output = model.apply(params, control_parameters, **model_kwargs)
        predicted_expectation_value = adapter_fn(output, unitaries=unitaries)

        loss = evaluate_fn(predicted_expectation_value, expectation_values, unitaries)

        return jnp.mean(loss), {}

    return loss_fn