Skip to content

Control

inspeqtor.control

inspeqtor.control.BaseControl dataclass

Source code in src/inspeqtor/v2/control.py
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
@dataclass
class BaseControl(ABC):
    """Abstract base class for an atomic control.

    Subclasses are dataclasses whose fields hold the control configuration and
    that implement ``get_bounds`` and ``get_envelope``. Every instance is
    validated right after construction (see ``validate``).
    """

    def __post_init__(self):
        # Fail fast on a misconfigured control.
        self.validate()

    def validate(self):
        """Check that the control is JSON-serializable and can be sampled.

        Raises:
            TypeError: If ``self.to_dict()`` cannot be serialized to JSON.
            AssertionError: If the parameters sampled from ``get_bounds()`` do
                not all have string keys, or if their values are neither all
                floats nor all ``jnp.ndarray``.
        """
        # Validate that all attributes are json serializable
        try:
            json.dumps(self.to_dict())
        except TypeError as e:
            raise TypeError(
                f"Cannot serialize {self.__class__.__name__} to json"
            ) from e

        # Validate that sampling from the declared bounds is working
        lower, upper = self.get_bounds()
        key = jax.random.key(0)
        params = sample_params(key, lower, upper)

        assert all([isinstance(k, str) for k in params.keys()]), (
            "All key of params dict must be string"
        )
        # Values may be plain floats or JAX arrays (``sample_params`` draws
        # them with ``jax.random.uniform``), but not a mixture of both.
        assert all([isinstance(v, float) for v in params.values()]) or all(
            [isinstance(v, jnp.ndarray) for v in params.values()]
        ), "All values of params dict must be float or all must be jnp.ndarray"

    @abstractmethod
    def get_bounds(
        self, *arg, **kwarg
    ) -> tuple[ParametersDictType, ParametersDictType]:
        """Return the ``(lower, upper)`` bounds of the control parameters."""
        ...

    @abstractmethod
    def get_envelope(self, params: ParametersDictType) -> typing.Callable:
        """Return the envelope function of this control for ``params``."""
        raise NotImplementedError("get_envelope method is not implemented")

    def to_dict(self) -> dict[str, typing.Union[int, float, str]]:
        """Convert the control configuration to dictionary

        Returns:
            dict[str, typing.Union[int, float, str]]: Configuration of the control
        """
        return asdict(self)

    def to_dict_new(self) -> dict[str, typing.Union[int, float, str]]:
        """Convert the control configuration to a dictionary that includes the class name.

        Returns:
            dict[str, typing.Union[int, float, str]]: Configuration of the control,
            plus a ``"classname"`` entry holding the class name.
        """
        return {**asdict(self), "classname": self.__class__.__name__}

    @classmethod
    def from_dict(cls, data):
        """Construct the control instance from the dictionary.

        Args:
            data (dict): Keyword arguments forwarded to the constructor.

        Returns:
            The instance of the control.
        """
        return cls(**data)

to_dict

to_dict() -> dict[str, Union[int, float, str]]

Convert the control configuration to dictionary

Returns:

Type Description
dict[str, Union[int, float, str]]

dict[str, typing.Union[int, float, str]]: Configuration of the control

Source code in src/inspeqtor/v2/control.py
91
92
93
94
95
96
97
def to_dict(self) -> dict[str, typing.Union[int, float, str]]:
    """Serialize the control's dataclass fields into a plain dictionary.

    Returns:
        dict[str, typing.Union[int, float, str]]: Mapping from field name to
        value, as produced by ``dataclasses.asdict``.
    """
    config = asdict(self)
    return config

to_dict_new

to_dict_new() -> dict[str, Union[int, float, str]]

Convert the control configuration to dictionary

Returns:

Type Description
dict[str, Union[int, float, str]]

dict[str, typing.Union[int, float, str]]: Configuration of the control

Source code in src/inspeqtor/v2/control.py
 99
100
101
102
103
104
105
def to_dict_new(self) -> dict[str, typing.Union[int, float, str]]:
    """Serialize the control to a dict that also records its class name.

    Returns:
        dict[str, typing.Union[int, float, str]]: The dataclass fields of the
        control plus a ``"classname"`` entry holding the class name.
    """
    serialized = asdict(self)
    serialized["classname"] = type(self).__name__
    return serialized

from_dict classmethod

from_dict(data)

Construct the control instance from the dictionary.

Parameters:

Name Type Description Default
data dict

Dictionary for construction of the control instance.

required

Returns:

Type Description

The instance of the control.

Source code in src/inspeqtor/v2/control.py
107
108
109
110
111
112
113
114
115
116
117
@classmethod
def from_dict(cls, data):
    """Build a control instance from a dictionary of constructor arguments.

    Args:
        data (dict): Keyword arguments forwarded to the control's constructor.

    Returns:
        The newly created control instance.
    """
    instance = cls(**data)
    return instance

inspeqtor.control.ControlSequence dataclass

Control sequence, expected to be a sum of atomic controls.

Source code in src/inspeqtor/v2/control.py
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
@dataclass
class ControlSequence:
    """Control sequence, expected to be a sum of atomic controls.

    Attributes:
        controls: Mapping from control name to its ``BaseControl`` instance.
        total_dt: Total length of the sequence (presumably in units of ``dt``;
            not enforced here).
        structure: Ordered ``(control_name, parameter_name)`` key paths used to
            ravel/unravel parameters. When ``None``, it is derived from the
            lower bounds in ``__post_init__``.
    """

    controls: dict[str, BaseControl]
    total_dt: int
    structure: typing.Sequence[typing.Sequence[str]] | None = field(default=None)

    def __post_init__(self):
        # Cache the bounds; they are reused by every sampling call.
        self.lower, self.upper = self.get_bounds()

        if self.structure is None:
            # Default order: controls in dict order, then each control's
            # parameters in the key order of ``self.lower[ctrl_key]``.
            self.auto_order = []
            for ctrl_key in self.controls.keys():
                for ctrl_param_key in self.lower[ctrl_key].keys():
                    self.auto_order.append((ctrl_key, ctrl_param_key))
            self.structure = self.auto_order
        else:
            self.auto_order = self.structure

    def get_structure(self) -> typing.Sequence[typing.Sequence[str]]:
        """Return the ordered key paths used to ravel/unravel parameters."""
        return self.auto_order

    def sample_params_v1(self, key: jnp.ndarray) -> dict[str, ParametersDictType]:
        """Sample control parameter, one control at a time.

        Args:
            key (jnp.ndarray): Random key

        Returns:
            dict[str, ParametersDictType]: control parameters keyed by control name
        """
        params_dict: dict[str, ParametersDictType] = {}
        for idx, ctrl_key in enumerate(self.controls.keys()):
            # Derive an independent key per control from its position.
            subkey = jax.random.fold_in(key, idx)
            params = sample_params(subkey, self.lower[ctrl_key], self.upper[ctrl_key])
            params_dict[ctrl_key] = params

        return params_dict

    def sample_params(self, key: jnp.ndarray) -> dict[str, ParametersDictType]:
        """Sample control parameters (delegates to ``sample_params_v2``)."""
        return self.sample_params_v2(key)

    def sample_params_v2(self, key: jnp.ndarray) -> dict[str, ParametersDictType]:
        """Sample control parameter from the cached nested bounds.

        Args:
            key (jnp.ndarray): Random key

        Returns:
            dict[str, ParametersDictType]: control parameters
        """
        return nested_sample(key, merge_lower_upper(self.lower, self.upper))

    def get_bounds(
        self,
    ) -> tuple[dict[str, ParametersDictType], dict[str, ParametersDictType]]:
        """Get the bounds of the controls

        Returns:
            tuple[dict[str, ParametersDictType], dict[str, ParametersDictType]]:
            the lower and upper bounds, each keyed by control name.
        """

        lower_bounds = jax.tree.map(
            lambda x: x.get_bounds()[0],
            self.controls,
            is_leaf=lambda x: isinstance(x, BaseControl),
        )
        upper_bounds = jax.tree.map(
            lambda x: x.get_bounds()[1],
            self.controls,
            is_leaf=lambda x: isinstance(x, BaseControl),
        )

        return lower_bounds, upper_bounds

    def get_envelope(
        self, params_dict: dict[str, ParametersDictType]
    ) -> typing.Callable:
        """Create envelope function with given control parameters

        Args:
            params_dict (dict[str, ParametersDictType]): control parameters keyed by control name

        Returns:
            typing.Callable: Envelope function returning the sum of all control envelopes at ``t``
        """
        callables = []
        for params_key, params_val in params_dict.items():
            callables.append(self.controls[params_key].get_envelope(params_val))

        # Create a function that returns the sum of the envelopes
        def envelope(t):
            return sum([c(t) for c in callables])

        return envelope

    def to_dict_new(self) -> dict[str, str | dict[str, str | float]]:
        """Convert self to dict, embedding each control's class name.

        Each control is serialized with ``BaseControl.to_dict_new`` so that its
        dict carries a ``"classname"`` entry, which ``from_dict_new`` requires
        to pick the class when rebuilding the control.

        Returns:
            dict[str, str | dict[str, str | float]]: dict contain argument necessary for re-initialization.
        """
        return {
            **asdict(self),
            "classname": {k: v.__class__.__name__ for k, v in self.controls.items()},
            "controls": jax.tree.map(
                lambda x: x.to_dict_new(),
                self.controls,
                is_leaf=lambda x: isinstance(x, BaseControl),
            ),
        }

    def to_dict(self) -> dict[str, str | dict[str, str | float]]:
        """Convert self to dict

        Returns:
            dict[str, str | dict[str, str | float]]: dict contain argument necessary for re-initialization.
        """
        return {
            **asdict(self),
            "classname": {k: v.__class__.__name__ for k, v in self.controls.items()},
        }

    @classmethod
    def from_dict(
        cls,
        data: dict[str, str | dict[str, str | float]],
        controls: dict[str, type[BaseControl]],
    ) -> "ControlSequence":
        """Construct self with the provided dictionary

        Args:
            data (dict[str, str  |  dict[str, str  |  float]]): The dictionary contain initialization arguments
            controls (dict[str, type[BaseControl]]): The map of control name and class of the control

        Returns:
            ControlSequence: the instance of control sequence
        """
        controls_data = data["controls"]
        assert isinstance(controls_data, dict)

        instantiated_controls = {}

        # Match each control with its class by name rather than by position:
        # zipping two dicts would fail when their orders differ and would
        # silently drop controls when their lengths differ.
        for ctrl_key, ctrl_data in controls_data.items():
            assert ctrl_key in controls, (
                f"No control class provided for control {ctrl_key!r}"
            )
            assert isinstance(ctrl_data, dict), f"Expected dict, got {type(ctrl_data)}"
            instantiated_controls[ctrl_key] = controls[ctrl_key].from_dict(ctrl_data)

        total_dt = data["total_dt"]
        assert isinstance(total_dt, int)
        structure = data["structure"]
        assert isinstance(structure, list)
        # Explicitly convert each item in the structure to be tuple.
        structure = [tuple(item) for item in structure]

        return cls(
            controls=instantiated_controls, total_dt=total_dt, structure=structure
        )

    @classmethod
    def from_dict_new(
        cls,
        data: dict[str, str | dict[str, str | float]],
        controls: dict[str, type[BaseControl]],
    ) -> "ControlSequence":
        """Construct self from a dictionary whose control dicts carry a ``"classname"``.

        Args:
            data (dict[str, str | dict[str, str | float]]): Initialization arguments;
                every control dict under ``data["controls"]`` must contain ``"classname"``.
            controls (dict[str, type[BaseControl]]): Map from class name to control class.

        Returns:
            ControlSequence: the instance of control sequence
        """
        controls_data = data["controls"]
        assert isinstance(controls_data, dict)

        def check_if_control_dict(leave) -> bool:
            # A dict tagged with "classname" describes one control: treat it as a leaf.
            if isinstance(leave, dict):
                if "classname" in leave:
                    return True

            return False

        def initialize_control(control_data: dict) -> BaseControl:
            cls_name = control_data["classname"]
            clean_data = {k: v for k, v in control_data.items() if k != "classname"}
            return controls[cls_name].from_dict(clean_data)

        # Initialize controls
        initialized_controls = jax.tree.map(
            initialize_control, controls_data, is_leaf=check_if_control_dict
        )

        total_dt = data["total_dt"]
        assert isinstance(total_dt, int)
        structure = data["structure"]
        assert isinstance(structure, list)
        # Explicitly convert each item in the structure to be tuple.
        structure = [tuple(item) for item in structure]

        return cls(
            controls=initialized_controls, total_dt=total_dt, structure=structure
        )

    def to_file(self, path: typing.Union[str, pathlib.Path]):
        """Save configuration of the pulse to file given folder path.

        Args:
            path (typing.Union[str, pathlib.Path]): Path to the folder to save sequence, will be created if not existed.
        """
        if isinstance(path, str):
            path = pathlib.Path(path)

        # Honour the documented contract: create the folder if it is missing.
        path.mkdir(parents=True, exist_ok=True)
        save_pytree_to_json(self.to_dict(), path / "control_sequence.json")

    @classmethod
    def from_file(
        cls,
        path: typing.Union[str, pathlib.Path],
        controls: dict[str, type[BaseControl]],
    ):
        """Initialize itself from ``<path>/control_sequence.json``.

        Args:
            path (typing.Union[str, pathlib.Path]): Path to the folder containing ``control_sequence.json``.
            controls (dict[str, type[BaseControl]]): The map of control name and class of the control

        Returns:
            ControlSequence: the instance of control sequence
        """
        if isinstance(path, str):
            path = pathlib.Path(path)

        ctrl_loaded_dict = load_pytree_from_json(
            path / "control_sequence.json", lambda k, v: (True, v)
        )

        return cls.from_dict(ctrl_loaded_dict, controls=controls)

sample_params_v1

sample_params_v1(
    key: ndarray,
) -> dict[str, ParametersDictType]

Sample control parameter

Parameters:

Name Type Description Default
key ndarray

Random key

required

Returns:

Type Description
dict[str, ParametersDictType]

dict[str, ParametersDictType]: control parameters

Source code in src/inspeqtor/v2/control.py
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
def sample_params_v1(self, key: jnp.ndarray) -> dict[str, ParametersDictType]:
    """Sample parameters for each control independently.

    Each control receives its own key, derived with ``jax.random.fold_in``
    from ``key`` and the control's position in ``self.controls``.

    Args:
        key (jnp.ndarray): Random key

    Returns:
        dict[str, ParametersDictType]: sampled parameters keyed by control name
    """
    return {
        name: sample_params(
            jax.random.fold_in(key, position), self.lower[name], self.upper[name]
        )
        for position, name in enumerate(self.controls.keys())
    }

sample_params_v2

sample_params_v2(
    key: ndarray,
) -> dict[str, ParametersDictType]

Sample control parameter

Parameters:

Name Type Description Default
key ndarray

Random key

required

Returns:

Type Description
dict[str, ParametersDictType]

dict[str, ParametersDictType]: control parameters

Source code in src/inspeqtor/v2/control.py
168
169
170
171
172
173
174
175
176
177
def sample_params_v2(self, key: jnp.ndarray) -> dict[str, ParametersDictType]:
    """Sample control parameters from the cached nested bounds.

    Args:
        key (jnp.ndarray): Random key

    Returns:
        dict[str, ParametersDictType]: sampled control parameters
    """
    merged_bounds = merge_lower_upper(self.lower, self.upper)
    return nested_sample(key, merged_bounds)

get_bounds

get_bounds() -> tuple[
    dict[str, ParametersDictType],
    dict[str, ParametersDictType],
]

Get the bounds of the controls

Returns:

Type Description
tuple[dict[str, ParametersDictType], dict[str, ParametersDictType]]

tuple[dict[str, ParametersDictType], dict[str, ParametersDictType]]: the lower and upper bounds, each keyed by control name.

Source code in src/inspeqtor/v2/control.py
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
def get_bounds(
    self,
) -> tuple[dict[str, ParametersDictType], dict[str, ParametersDictType]]:
    """Collect the lower and upper bounds of every control.

    Returns:
        tuple[dict[str, ParametersDictType], dict[str, ParametersDictType]]:
        the lower bounds and the upper bounds, each keyed like ``self.controls``.
    """

    def is_control(node) -> bool:
        # Stop the tree traversal at BaseControl instances.
        return isinstance(node, BaseControl)

    lower_bounds = jax.tree.map(
        lambda ctrl: ctrl.get_bounds()[0], self.controls, is_leaf=is_control
    )
    upper_bounds = jax.tree.map(
        lambda ctrl: ctrl.get_bounds()[1], self.controls, is_leaf=is_control
    )
    return lower_bounds, upper_bounds

get_envelope

get_envelope(
    params_dict: dict[str, ParametersDictType],
) -> Callable

Create envelope function with given control parameters

Parameters:

Name Type Description Default
params_dict dict[str, ParametersDictType]

control parameter to be used

required

Returns:

Type Description
Callable

typing.Callable: Envelope function

Source code in src/inspeqtor/v2/control.py
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
def get_envelope(
    self, params_dict: dict[str, ParametersDictType]
) -> typing.Callable:
    """Build the combined envelope of the sequence for the given parameters.

    Args:
        params_dict (dict[str, ParametersDictType]): parameters keyed by control name

    Returns:
        typing.Callable: function of ``t`` returning the sum of every control's envelope
    """
    envelopes = [
        self.controls[name].get_envelope(values)
        for name, values in params_dict.items()
    ]

    def envelope(t):
        # Pointwise sum of the individual envelopes.
        return sum([env(t) for env in envelopes])

    return envelope

to_dict_new

to_dict_new() -> dict[str, str | dict[str, str | float]]

Convert self to dict

Returns:

Type Description
dict[str, str | dict[str, str | float]]

dict[str, str | dict[str, str | float]]: dict contain argument necessary for re-initialization.

Source code in src/inspeqtor/v2/control.py
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
def to_dict_new(self) -> dict[str, str | dict[str, str | float]]:
    """Convert self to dict, embedding each control's class name.

    Each control is serialized with ``BaseControl.to_dict_new`` so that its dict
    carries a ``"classname"`` entry, which ``from_dict_new`` requires to choose
    the class when rebuilding the control.

    Returns:
        dict[str, str | dict[str, str | float]]: dict contain argument necessary for re-initialization.
    """
    return {
        **asdict(self),
        "classname": {k: v.__class__.__name__ for k, v in self.controls.items()},
        "controls": jax.tree.map(
            lambda x: x.to_dict_new(),
            self.controls,
            is_leaf=lambda x: isinstance(x, BaseControl),
        ),
    }

to_dict

to_dict() -> dict[str, str | dict[str, str | float]]

Convert self to dict

Returns:

Type Description
dict[str, str | dict[str, str | float]]

dict[str, str | dict[str, str | float]]: dict contain argument necessary for re-initialization.

Source code in src/inspeqtor/v2/control.py
238
239
240
241
242
243
244
245
246
247
def to_dict(self) -> dict[str, str | dict[str, str | float]]:
    """Convert self to dict

    Returns:
        dict[str, str | dict[str, str | float]]: dict contain argument necessary for re-initialization.
    """
    return {
        **asdict(self),
        "classname": {k: v.__class__.__name__ for k, v in self.controls.items()},
    }

from_dict classmethod

from_dict(
    data: dict[str, str | dict[str, str | float]],
    controls: dict[str, type[BaseControl]],
) -> ControlSequence

Construct self with the provided dictionary

Parameters:

Name Type Description Default
data dict[str, str | dict[str, str | float]]

The dictionary contain initialization arguments

required
controls dict[str, type[BaseControl]]

The map of control name and class of the control

required

Returns:

Name Type Description
ControlSequence ControlSequence

the instance of control sequence

Source code in src/inspeqtor/v2/control.py
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
@classmethod
def from_dict(
    cls,
    data: dict[str, str | dict[str, str | float]],
    controls: dict[str, type[BaseControl]],
) -> "ControlSequence":
    """Construct self with the provided dictionary

    Args:
        data (dict[str, str  |  dict[str, str  |  float]]): The dictionary contain initialization arguments
        controls (dict[str, type[BaseControl]]): The map of control name and class of the control

    Returns:
        ControlSequence: the instance of control sequence
    """
    controls_data = data["controls"]
    assert isinstance(controls_data, dict)

    instantiated_controls = {}

    # Match each control with its class by name rather than by position:
    # zipping two dicts would fail when their orders differ and would silently
    # drop controls when their lengths differ.
    for ctrl_key, ctrl_data in controls_data.items():
        assert ctrl_key in controls, (
            f"No control class provided for control {ctrl_key!r}"
        )
        assert isinstance(ctrl_data, dict), f"Expected dict, got {type(ctrl_data)}"
        instantiated_controls[ctrl_key] = controls[ctrl_key].from_dict(ctrl_data)

    total_dt = data["total_dt"]
    assert isinstance(total_dt, int)
    structure = data["structure"]
    assert isinstance(structure, list)
    # Explicitly convert each item in the structure to be tuple.
    structure = [tuple(item) for item in structure]

    return cls(
        controls=instantiated_controls, total_dt=total_dt, structure=structure
    )

to_file

to_file(path: Union[str, Path])

Save configuration of the pulse to file given folder path.

Parameters:

Name Type Description Default
path Union[str, Path]

Path to the folder in which to save the sequence; it will be created if it does not exist.

required
Source code in src/inspeqtor/v2/control.py
324
325
326
327
328
329
330
331
332
333
def to_file(self, path: typing.Union[str, pathlib.Path]):
    """Save configuration of the pulse to file given folder path.

    Writes ``to_dict()`` to ``<path>/control_sequence.json``.

    Args:
        path (typing.Union[str, pathlib.Path]): Path to the folder to save sequence, will be created if not existed.
    """
    if isinstance(path, str):
        path = pathlib.Path(path)

    # Honour the documented contract: create the folder if it is missing.
    path.mkdir(parents=True, exist_ok=True)
    save_pytree_to_json(self.to_dict(), path / "control_sequence.json")

from_file classmethod

from_file(
    path: Union[str, Path],
    controls: dict[str, type[BaseControl]],
)

Initialize itself from a file.

Parameters:

Name Type Description Default
path Union[str, Path]

Path to the folder containing control_sequence.json.

required
controls dict[str, type[BaseControl]]

The map of control name and class of the control

required

Returns:

Name Type Description
ControlSequence

the instance of control sequence

Source code in src/inspeqtor/v2/control.py
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
@classmethod
def from_file(
    cls,
    path: typing.Union[str, pathlib.Path],
    controls: dict[str, type[BaseControl]],
):
    """Load a control sequence from ``<path>/control_sequence.json``.

    Args:
        path (typing.Union[str, pathlib.Path]): Folder that contains ``control_sequence.json``.
        controls (dict[str, type[BaseControl]]): Map from control name to the class used to rebuild it.

    Returns:
        ControlSequence: the reconstructed control sequence
    """
    folder = pathlib.Path(path) if isinstance(path, str) else path
    json_file = folder / "control_sequence.json"
    loaded = load_pytree_from_json(json_file, lambda k, v: (True, v))
    return cls.from_dict(loaded, controls=controls)

inspeqtor.control.control_waveform

control_waveform(
    param: ParametersDictType,
    t_eval: ndarray,
    control: BaseControl,
) -> ndarray
Source code in src/inspeqtor/v2/control.py
360
361
362
363
364
365
def control_waveform(
    param: ParametersDictType,
    t_eval: jnp.ndarray,
    control: BaseControl,
) -> jnp.ndarray:
    """Evaluate one control's envelope at every time point in ``t_eval``.

    The envelope returned by ``control.get_envelope(param)`` is vectorized with
    ``jax.vmap`` over the leading axis of ``t_eval``.
    """
    envelope_fn = control.get_envelope(param)
    vectorized_envelope = jax.vmap(envelope_fn)
    return vectorized_envelope(t_eval)

inspeqtor.control.sequence_waveform

sequence_waveform(
    params: dict[str, ParametersDictType],
    t_eval: ndarray,
    control_seqeunce: ControlSequence,
) -> ndarray

Compute the total waveform of a control sequence by evaluating each control's envelope at t_eval with its own parameters and summing the results.

Parameters:

Name Type Description Default
params dict[str, ParametersDictType]

Control parameters keyed by control name.

required
t_eval ndarray

Time points at which the envelopes are evaluated.

required
control_seqeunce ControlSequence

The control sequence that provides the controls.

required

Returns:

Type Description
ndarray

The summed waveform of the controls, accumulated onto a complex zero array shaped like t_eval.

Example

params = control_sequence.sample_params(jax.random.key(0)) total_waveform = sequence_waveform(params, t_eval, control_sequence)

Source code in src/inspeqtor/v2/control.py
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
def sequence_waveform(
    params: dict[str, ParametersDictType],
    t_eval: jnp.ndarray,
    control_seqeunce: ControlSequence,
) -> jnp.ndarray:
    """Compute the total waveform of a control sequence.

    Each control's envelope is evaluated at ``t_eval`` with its own parameters
    (via ``control_waveform``) and the results are summed.

    Args:
        params (dict[str, ParametersDictType]): Parameters keyed by control name;
            every key must be a key of ``control_seqeunce.controls``.
        t_eval (jnp.ndarray): Time points at which the envelopes are evaluated.
        control_seqeunce (ControlSequence): The sequence providing the controls.

    Returns:
        jnp.ndarray: Sum of the per-control waveforms, accumulated onto a
        complex64 zero array shaped like ``t_eval``.

    Raises:
        KeyError: If ``params`` contains a name that is not a control of the sequence.

    Example:
        params = control_sequence.sample_params(jax.random.key(0))
        total_waveform = sequence_waveform(params, t_eval, control_sequence)
    """
    # Create base waveform
    total_waveform = jnp.zeros_like(t_eval, dtype=jnp.complex64)

    # Look each control up by name instead of zipping the two dicts by
    # position: a parameter dict built with jax.tree utilities may be ordered
    # differently from ``controls``, and zip would silently drop extras.
    for ctrl_key, param_val in params.items():
        control = control_seqeunce.controls[ctrl_key]
        total_waveform += control_waveform(param_val, t_eval, control)

    return total_waveform

inspeqtor.control.ravel_unravel_fn

ravel_unravel_fn(structure: Iterable[Iterable[str]])

Return the ravel and unravel functions for the provided control-sequence structure.

Parameters:

Name Type Description Default
structure Iterable[Iterable[str]]

The structure of the pytree

required

Returns:

Type Description

tuple[typing.Callable, typing.Callable]: The first element is a function that converts structured parameters to an array; the second is a function that reverses the action of the first.

Source code in src/inspeqtor/v2/control.py
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
def ravel_unravel_fn(structure: typing.Iterable[typing.Iterable[str]]):
    """Return the ravel and unravel functions for the given parameter structure.

    Args:
        structure (typing.Iterable[typing.Iterable[str]]): Ordered key paths into
            the parameter pytree, e.g. ``ControlSequence.get_structure()``. Each
            key path is used as a dict key by ``unravel_fn``, so it must be hashable.

    Returns:
        tuple[typing.Callable, typing.Callable]: ``ravel_fn`` converts structured
        parameters into an array ordered by ``structure``; ``unravel_fn``
        reverses that conversion.
    """
    # Materialize once: both closures iterate over the structure on every call,
    # and a one-shot iterator (e.g. a generator) would otherwise be exhausted
    # after the first call and silently yield nothing afterwards.
    key_paths = list(structure)

    def ravel_fn(param: ParametersDictType):
        return jnp.array(
            [get_value_by_keys(param, dict_keys) for dict_keys in key_paths]
        )

    def unravel_fn(param: jnp.ndarray):
        return unflatten_dict(
            {dict_keys: param[idx] for idx, dict_keys in enumerate(key_paths)}
        )

    return ravel_fn, unravel_fn

inspeqtor.control.sample_params

sample_params(
    key: ndarray,
    lower: ParametersDictType,
    upper: ParametersDictType,
) -> ParametersDictType

Sample parameters with the same structure as the given lower and upper bounds

Parameters:

Name Type Description Default
key ndarray

Random key

required
lower ParametersDictType

Lower bound

required
upper ParametersDictType

Upper bound

required

Returns:

Name Type Description
ParametersDictType ParametersDictType

Dict of the sampled parameters

Source code in src/inspeqtor/v2/control.py
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
def sample_params(
    key: jnp.ndarray, lower: ParametersDictType, upper: ParametersDictType
) -> ParametersDictType:
    """Draw one uniform sample per parameter between its lower and upper bound.

    The result has the same keys as ``lower``. Before each draw the key is split;
    the first half is used for sampling and the second half is carried forward.

    Args:
        key (jnp.ndarray): Random key
        lower (ParametersDictType): Lower bound of each parameter
        upper (ParametersDictType): Upper bound of each parameter

    Returns:
        ParametersDictType: Dict of the sampled (scalar) parameters
    """
    # Generic: only the keys of ``lower`` determine what gets sampled.
    sampled: ParametersDictType = {}
    rng = key
    for param_name in lower.keys():
        subkey, rng = jax.random.split(rng)
        low, high = lower[param_name], upper[param_name]
        sampled[param_name] = jax.random.uniform(
            subkey, shape=(), minval=low, maxval=high
        )
    return sampled

inspeqtor.control.construct_control_sequence_reader

construct_control_sequence_reader(
    controls: list[type[BaseControl]] = [],
) -> Callable[[Union[str, Path]], ControlSequence]

Construct the control sequence reader

Parameters:

Name Type Description Default
controls list[type[BaseControl]]

List of control classes (constructors). Defaults to an empty list.

[]

Returns:

Type Description
Callable[[Union[str, Path]], ControlSequence]

typing.Callable[[typing.Union[str, pathlib.Path]], ControlSequence]: Control sequence reader that automatically constructs a control sequence from a path.

Source code in src/inspeqtor/v2/control.py
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
def construct_control_sequence_reader(
    controls: list[type[BaseControl]] | None = None,
) -> typing.Callable[[typing.Union[str, pathlib.Path]], ControlSequence]:
    """Construct the control sequence reader

    Args:
        controls (list[type[BaseControl]] | None, optional): List of control
            classes the reader can rebuild, matched by class name. Defaults to
            an empty list.

    Returns:
        typing.Callable[[typing.Union[str, pathlib.Path]], ControlSequence]: Control
        sequence reader that automatically constructs a control sequence from a path.
    """
    default_controls: list[type[BaseControl]] = []

    # Merge the default controls with the provided controls. ``None`` replaces
    # the former mutable ``[]`` default; callers omitting it see no difference.
    controls_list = default_controls + list(controls if controls is not None else [])
    control_dict = {ctrl.__name__: ctrl for ctrl in controls_list}

    def control_sequence_reader(
        path: typing.Union[str, pathlib.Path],
    ) -> ControlSequence:
        """Construct control sequence from path

        Args:
            path (typing.Union[str, pathlib.Path]): Folder containing the saved
                ``control_sequence.json`` configuration.

        Returns:
            ControlSequence: Control sequence instance.

        Raises:
            KeyError: If the file refers to a control class that was not registered.
        """
        if isinstance(path, str):
            path = pathlib.Path(path)

        control_sequence_dict = load_pytree_from_json(
            path / "control_sequence.json", lambda k, v: (True, v)
        )

        classnames = control_sequence_dict["classname"]
        assert isinstance(classnames, dict)

        parsed_controls = {}
        for ctrl_key, ctrl_classname in classnames.items():
            if ctrl_classname not in control_dict:
                raise KeyError(
                    f"Unknown control class {ctrl_classname!r} for control "
                    f"{ctrl_key!r}; registered classes: {sorted(control_dict)}"
                )
            parsed_controls[ctrl_key] = control_dict[ctrl_classname]

        return ControlSequence.from_dict(
            control_sequence_dict, controls=parsed_controls
        )

    return control_sequence_reader

inspeqtor.control.ravel_transform

ravel_transform(
    fn: Callable,
    control_sequence: ControlSequence,
    at: int = 0,
) -> Callable

Transform the positional argument at index `at` of the function `fn` with the `unravel_fn` of the control sequence.

Note
signal_fn = sq.control.ravel_transform(
    sq.physics.signal_func_v5(control_sequence.get_envelope, qubit_info.frequency, dt),
    control_sequence,
)

Parameters:

Name Type Description Default
fn Callable

The function to be transformed

required
control_sequence ControlSequence

The control sequence used to produce unravel_fn.

required

Returns:

Type Description
Callable

typing.Callable: A wrapped function whose argument at index `at` is transformed by unravel_fn.

Source code in src/inspeqtor/v2/control.py
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
def ravel_transform(
    fn: typing.Callable, control_sequence: ControlSequence, at: int = 0
) -> typing.Callable:
    """Transform the argument at index `at` of the function `fn` with `unravel_fn` of the control sequence

    Note:
        ```python
        signal_fn = sq.control.ravel_transform(
            sq.physics.signal_func_v5(control_sequence.get_envelope, qubit_info.frequency, dt),
            control_sequence,
        )
        ```

    Args:
        fn (typing.Callable): The function to be transformed
        control_sequence (ControlSequence): The control sequence used to produce `unravel_fn`.
        at (int, optional): Index of the positional argument of `fn` that is passed
            through `unravel_fn` before `fn` is called. Defaults to 0 (the first argument).

    Returns:
        typing.Callable: A function called like `fn`, except that its positional
        argument at index `at` is transformed by `unravel_fn` before being forwarded.
    """
    import functools

    _, unravel_fn = ravel_unravel_fn(control_sequence.get_structure())

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        # `args` is an immutable tuple; copy it so the selected entry can be replaced.
        list_args = list(args)
        list_args[at] = unravel_fn(list_args[at])

        return fn(*list_args, **kwargs)

    return wrapper

inspeqtor.control.ParametersDictType module-attribute

ParametersDictType = PyTree[Union[Float, ndarray]]

inspeqtor.control.nested_sample

nested_sample(
    key: ndarray, bounds, sample_fn=uniform_sample
)

Sample from nested bounds with custom sampling function sample_fn

Parameters:

Name Type Description Default
key ndarray

Random key

required
bounds _type_

Bound of the control parameter

required
sample_fn _type_

Custom sampling function. Defaults to uniform_sample.

uniform_sample

Returns:

Name Type Description
_type_

Control parameters sampled from the bounds

Source code in src/inspeqtor/v2/control.py
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
def nested_sample(key: jnp.ndarray, bounds, sample_fn=uniform_sample):
    """Draw one sample for every entry of a (possibly nested) bounds dictionary.

    The bounds are flattened with `flatten_dict`. Each flat entry is sampled with
    its own key, derived as `jax.random.fold_in(key, position)`. The samples are
    then rebuilt into the original nesting with `unflatten_dict`.

    Args:
        key (jnp.ndarray): Random key
        bounds (_type_): Bound of the control parameter
        sample_fn (_type_, optional): Custom sampling function, called as
            `sample_fn(subkey, bound)`. Defaults to uniform_sample.

    Returns:
        _type_: Control parameters sampled from the bounds
    """
    flat_bounds = flatten_dict(bounds)

    flat_samples = {}
    for position, (name, bound) in enumerate(flat_bounds.items()):
        subkey = jax.random.fold_in(key, position)
        flat_samples[name] = sample_fn(subkey, bound)

    return unflatten_dict(flat_samples)

inspeqtor.control.check_bounds

check_bounds(
    param: ParametersDictType, bounds: Bounds
) -> bool

Check whether the given control parameters violate their bounds.

Parameters:

Name Type Description Default
param _type_

Control parameter

required
bounds _type_

Bound of control parameter

required

Returns:

Name Type Description
bool bool

True if no parameter violates its bound (every value lies strictly between its lower and upper bound), otherwise False

Source code in src/inspeqtor/v2/control.py
448
449
450
451
452
453
454
455
456
457
458
459
460
461
def check_bounds(param: ParametersDictType, bounds: Bounds) -> bool:
    """Check whether the control parameters lie inside their bounds.

    Each leaf of `param` is compared with the `(lower, upper)` pair at the same
    position in `bounds`. The comparison is strict on both sides, so a value
    equal to either end of its bound counts as a violation.

    Args:
        param (_type_): Control parameter
        bounds (_type_): Bound of control parameter, holding a `(lower, upper)`
            pair for every leaf of `param`

    Returns:
        bool: `True` if no parameter violates its bound, otherwise `False`
        (a JAX boolean when the parameters are arrays)
    """

    def _inside(value, bound):
        return (bound[0] < value) & (value < bound[1])

    def _both(acc, ok):
        return acc & ok

    per_leaf = jax.tree.map(_inside, param, bounds)
    return jax.tree.reduce(_both, per_leaf, initializer=True)

inspeqtor.control.merge_lower_upper

merge_lower_upper(lower: LowerBound, upper: UpperBound)

Merge lower and upper bound into bounds

Parameters:

Name Type Description Default
lower _type_

The lower bound

required
upper _type_

The upper bound

required

Returns:

Name Type Description
_type_

Bound from the lower and upper.

Source code in src/inspeqtor/v2/control.py
398
399
400
401
402
403
404
405
406
407
408
def merge_lower_upper(lower: LowerBound, upper: UpperBound):
    """Pair up the matching leaves of `lower` and `upper` into bounds.

    Args:
        lower (_type_): The lower bound
        upper (_type_): The upper bound, following the tree structure of `lower`

    Returns:
        _type_: A tree shaped like `lower` whose leaves are `(lower, upper)` tuples.
    """

    def _pair(low, high):
        return (low, high)

    return jax.tree.map(_pair, lower, upper)

inspeqtor.control.split_bounds

split_bounds(
    bounds: Bounds,
) -> tuple[LowerBound, UpperBound]

Create lower and upper bound from bounds

Parameters:

Name Type Description Default
bounds _type_

The bounds to extract the lower and upper bound

required

Returns:

Name Type Description
_type_ tuple[LowerBound, UpperBound]

The lower and upper bound

Source code in src/inspeqtor/v2/control.py
411
412
413
414
415
416
417
418
419
420
421
422
def split_bounds(bounds: Bounds) -> tuple[LowerBound, UpperBound]:
    """Split bounds into separate lower-bound and upper-bound trees.

    Every tuple in `bounds` is treated as a `(lower, upper)` pair rather than
    as a container to recurse into.

    Args:
        bounds (_type_): The bounds to extract the lower and upper bound

    Returns:
        _type_: The lower and upper bound
    """

    def _is_pair(node):
        return isinstance(node, tuple)

    lower = jax.tree.map(lambda pair: pair[0], bounds, is_leaf=_is_pair)
    upper = jax.tree.map(lambda pair: pair[1], bounds, is_leaf=_is_pair)
    return lower, upper

inspeqtor.control.get_envelope

get_envelope(
    param: ParametersDictType, seq: ControlSequence
)

Return an envelope function that sums the envelopes of all controls in seq, built with control parameters param

Parameters:

Name Type Description Default
param _type_

Control parameter

required
seq ControlSequence

Control Sequence

required

Returns:

Name Type Description
_type_

A function of time that returns the sum of the envelopes of all controls in seq with parameters param

Source code in src/inspeqtor/v2/control.py
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
def get_envelope(param: ParametersDictType, seq: ControlSequence):
    """Build the combined envelope function of every control in `seq`.

    Args:
        param (_type_): Control parameter, structured like `seq.controls`
        seq (ControlSequence): Control Sequence

    Returns:
        _type_: A function of time returning the sum of the envelopes of all
        controls in `seq`, each built from its entry in `param`. When there
        are no controls, it returns `0.0`.
    """

    def _is_control(node):
        return isinstance(node, BaseControl)

    # One envelope callable per control, in the same tree layout as seq.controls.
    envelopes = jax.tree.map(
        lambda control, control_param: control.get_envelope(control_param),
        seq.controls,
        param,
        is_leaf=_is_control,
    )

    def envelope(t):
        def _accumulate(total, single_envelope):
            return single_envelope(t) + total

        return jax.tree.reduce(_accumulate, envelopes, initializer=0.0)

    return envelope

inspeqtor.control.envelope_fn

envelope_fn(
    param: ParametersDictType,
    t: ndarray,
    seq: ControlSequence,
)

Return the summed envelope of all controls in control sequence seq, given parameters param, at time t

Parameters:

Name Type Description Default
param _type_

Control parameter

required
t ndarray

Time to evaluate the envelope

required
seq ControlSequence

The control sequence to get the envelope

required

Returns:

Name Type Description
_type_

Sum of the envelopes of all controls in seq, evaluated at time t with parameters param

Source code in src/inspeqtor/v2/control.py
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
def envelope_fn(param: ParametersDictType, t: jnp.ndarray, seq: ControlSequence):
    """Return the summed envelope of all controls in `seq` with parameters `param` at time `t`.

    This is exactly ``get_envelope(param, seq)(t)``. Delegating to
    `get_envelope` keeps both code paths consistent. It also means an empty
    control sequence yields ``0.0``, whereas the previous initializer-less
    ``jax.tree.reduce`` raised ``TypeError`` on a tree with no leaves.

    Args:
        param (_type_): Control parameter
        t (jnp.ndarray): Time to evaluate the envelope
        seq (ControlSequence): The control sequence to get the envelope

    Returns:
        _type_: Sum of the envelopes of all controls in `seq`, evaluated at time `t` with parameters `param`
    """
    return get_envelope(param, seq)(t)

Library

inspeqtor.control.library

inspeqtor.control.library.DragPulse dataclass

Source code in src/inspeqtor/v1/predefined.py
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
@dataclass
class DragPulse(BaseControl):
    """DRAG pulse whose Gaussian width is derived from the rotation angle `theta`.

    `theta` is the only sampled parameter; its bounds are `[min_theta, max_theta]`.
    """

    duration: int
    beta: float
    qubit_drive_strength: float
    dt: float
    amp: float = 0.25

    min_theta: float = 0.0
    max_theta: float = 2 * jnp.pi
    final_amp: float = 1.0

    def __post_init__(self):
        # Sample times 0, 1, ..., duration - 1.
        # NOTE(review): overrides BaseControl.__post_init__, so validate() is not called here.
        self.t_eval = jnp.arange(self.duration)

    def get_bounds(self) -> tuple[ParametersDictType, ParametersDictType]:
        lower: ParametersDictType = {"theta": self.min_theta}
        upper: ParametersDictType = {"theta": self.max_theta}
        return lower, upper

    def get_waveform(self, params: ParametersDictType) -> jnp.ndarray:
        envelope = self.get_envelope(params)
        return envelope(self.t_eval)

    def get_envelope(
        self, params: ParametersDictType
    ) -> typing.Callable[..., typing.Any]:
        # Area (in units of dt) the pulse must cover for a rotation of `theta`.
        # NOTE: Choice of area is arbitrary e.g. pi pulse
        area_in_time = params["theta"] / (2 * jnp.pi * self.qubit_drive_strength)
        area = area_in_time / self.dt
        # Width chosen so that a Gaussian of height `amp` has this area.
        sigma = area / (self.amp * jnp.sqrt(2 * jnp.pi))

        return drag_envelope_v2(
            amp=self.amp,
            sigma=sigma.astype(float),
            beta=self.beta,
            center=self.duration // 2,
            final_amp=self.final_amp,
        )

inspeqtor.control.library.DragPulseV2 dataclass

Source code in src/inspeqtor/v2/predefined.py
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
@dataclass
class DragPulseV2(BaseControl):
    """DRAG pulse built on top of `GaussianPulse`.

    The real (in-phase) part is `GaussianPulse`'s Gaussian for the rotation
    angle `theta`. The imaginary (quadrature) part is `beta` times the time
    derivative of that Gaussian. The sampled parameters are `theta` and `beta`.
    """

    duration: int
    qubit_drive_strength: float
    dt: float
    max_amp: float = 0.25

    min_theta: float = 0.0
    max_theta: float = 2 * jnp.pi

    min_beta: float = 0.0
    max_beta: float = 2.0

    def __post_init__(self):
        # Reuse GaussianPulse for the fixed pulse shape (sigma, centre, correction factor).
        self.gaussian_pulse = GaussianPulse(
            duration=self.duration,
            qubit_drive_strength=self.qubit_drive_strength,
            dt=self.dt,
            max_amp=self.max_amp,
            min_theta=self.min_theta,
            max_theta=self.max_theta,
        )
        self.t_eval = self.gaussian_pulse.t_eval

    def get_bounds(
        self,
    ) -> tuple[ParametersDictType, ParametersDictType]:
        # Start from the Gaussian's `theta` bounds and add the DRAG coefficient.
        lower, upper = self.gaussian_pulse.get_bounds()

        lower["beta"] = self.min_beta
        upper["beta"] = self.max_beta

        return lower, upper

    def get_envelope(
        self, params: ParametersDictType
    ) -> typing.Callable[..., typing.Any]:
        # The area of Gaussian to be rotate to,
        area = (
            params["theta"] / self.gaussian_pulse.correction
        )  # NOTE: Choice of area is arbitrary e.g. pi pulse

        real_component = gaussian_envelope(
            amp=area,
            center=self.gaussian_pulse.center_position,
            sigma=self.gaussian_pulse.sigma,
        )
        # jax.grad only accepts a scalar time; vectorizing it lets the envelope
        # be evaluated on an array of times (e.g. `self.t_eval`). For scalar `t`
        # this is identical to calling jax.grad directly.
        real_derivative = jnp.vectorize(jax.grad(real_component))

        def envelope_fn(t):
            return real_component(t) + 1j * params["beta"] * real_derivative(t)

        return envelope_fn

inspeqtor.control.library.MultiDragPulseV3 dataclass

Source code in src/inspeqtor/v1/predefined.py
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
@dataclass
class MultiDragPulseV3(BaseControl):
    """Multi-term DRAG control whose terms are indexed by `(i, j)` with `0 <= j <= i < order`.

    The k-th term (in that enumeration order) takes its bounds from
    `amp_bound[k]` and `sigma_bound[k]`. A single `beta` shared by all terms is
    bounded by `global_beta_bound`.
    """

    duration: int
    order: int = 1
    amp_bound: list[list[float]] = field(default_factory=list)  # [[0.0, 1.0],]
    sigma_bound: list[list[float]] = field(default_factory=list)  # [[0.1, 5.0],]
    global_beta_bound: list[float] = field(default_factory=list)  # [-2.0, 2.0]

    def __post_init__(self):
        self.t_eval = jnp.arange(self.duration, dtype=jnp.float64)

    def get_bounds(self) -> tuple[ParametersDictType, ParametersDictType]:
        lower: ParametersDictType = {}
        upper: ParametersDictType = {}

        terms = [(i, j) for i in range(self.order) for j in range(i + 1)]
        for k, (i, j) in enumerate(terms):
            prefix = f"{i}/{j}"
            lower[f"{prefix}/amp"] = self.amp_bound[k][0]
            lower[f"{prefix}/sigma"] = self.sigma_bound[k][0]
            upper[f"{prefix}/amp"] = self.amp_bound[k][1]
            upper[f"{prefix}/sigma"] = self.sigma_bound[k][1]

        lower["beta"], upper["beta"] = (
            self.global_beta_bound[0],
            self.global_beta_bound[1],
        )

        return lower, upper

    def get_envelope(
        self, params: ParametersDictType
    ) -> typing.Callable[..., typing.Any]:
        # NOTE(review): this resolves to a module-level `get_envelope(params, order, duration)`
        # helper of the defining module, not a method of this class — confirm.
        return get_envelope(params, self.order, self.duration)

inspeqtor.control.library.GaussianPulse dataclass

Source code in src/inspeqtor/v2/predefined.py
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
@dataclass
class GaussianPulse(BaseControl):
    """Gaussian pulse parameterised by the rotation angle `theta`.

    The width `sigma` and the centre are fixed at construction. Only the area
    passed to the Gaussian envelope depends on `theta`.
    """

    duration: int
    qubit_drive_strength: float
    dt: float
    max_amp: float = 0.25

    min_theta: float = 0.0
    max_theta: float = 2 * jnp.pi

    def __post_init__(self):
        self.t_eval = jnp.arange(self.duration, dtype=jnp.float_)

        # Dividing by this factor cancels the prefactor in front of the Hamiltonian.
        correction = 2 * jnp.pi * self.qubit_drive_strength * self.dt
        self.correction = correction

        # The standard deviation of the Gaussian is kept fixed for the given max_amp.
        self.sigma = jnp.sqrt(2 * jnp.pi) / (self.max_amp * correction)

        # Centre the pulse in the middle of the duration.
        self.center_position = self.duration // 2

    def get_bounds(
        self,
    ) -> tuple[ParametersDictType, ParametersDictType]:
        lower: ParametersDictType = {"theta": self.min_theta}
        upper: ParametersDictType = {"theta": self.max_theta}
        return lower, upper

    def get_envelope(
        self, params: ParametersDictType
    ) -> typing.Callable[..., typing.Any]:
        # Gaussian area that yields a rotation of `theta`.
        # NOTE: Choice of area is arbitrary e.g. pi pulse
        target_area = params["theta"] / self.correction

        return gaussian_envelope(
            amp=target_area,
            center=self.center_position,
            sigma=self.sigma,
        )

inspeqtor.control.library.TwoAxisGaussianPulse dataclass

Source code in src/inspeqtor/v1/predefined.py
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
@dataclass
class TwoAxisGaussianPulse(BaseControl):
    """Gaussian pulse with independent rotation angles on the X and Y axes.

    Both axes share the same width and centre. The X rotation drives the real
    part of the envelope and the Y rotation drives the imaginary part.
    """

    duration: int
    qubit_drive_strength: float
    dt: float
    max_amp: float = 0.25

    # Rotation angles for both axes
    min_theta_x: float = 0.0
    max_theta_x: float = 2 * jnp.pi
    min_theta_y: float = 0.0
    max_theta_y: float = 2 * jnp.pi

    def __post_init__(self):
        self.t_eval = jnp.arange(self.duration, dtype=jnp.float_)

        # Dividing by this factor cancels the prefactor in front of the Hamiltonian.
        factor = 2 * jnp.pi * self.qubit_drive_strength * self.dt
        self.correction = factor

        # Shared Gaussian width, fixed by max_amp.
        self.sigma = jnp.sqrt(2 * jnp.pi) / (self.max_amp * factor)

        # Both Gaussians sit in the middle of the pulse window.
        self.center_position = self.duration // 2

    def get_bounds(
        self,
    ) -> tuple[ParametersDictType, ParametersDictType]:
        lower: ParametersDictType = {
            "theta_x": self.min_theta_x,
            "theta_y": self.min_theta_y,
        }
        upper: ParametersDictType = {
            "theta_x": self.max_theta_x,
            "theta_y": self.max_theta_y,
        }
        return lower, upper

    def get_envelope(
        self, params: ParametersDictType
    ) -> typing.Callable[..., typing.Any]:
        # Gaussian areas giving the requested rotation on each axis.
        x_envelope = gaussian_envelope(
            amp=params["theta_x"] / self.correction,
            center=self.center_position,
            sigma=self.sigma,
        )
        y_envelope = gaussian_envelope(
            amp=params["theta_y"] / self.correction,
            center=self.center_position,
            sigma=self.sigma,
        )

        def envelope_fn(t):
            # X axis -> real part, Y axis -> imaginary part.
            return x_envelope(t) + 1j * y_envelope(t)

        return envelope_fn

inspeqtor.control.library.gaussian_envelope

gaussian_envelope(amp, center, sigma)
Source code in src/inspeqtor/v1/predefined.py
100
101
102
103
104
105
106
def gaussian_envelope(amp, center, sigma):
    """Return a Gaussian of total area `amp`, centred at `center` with width `sigma`.

    The returned function evaluates
    ``amp / (sqrt(2*pi) * sigma) * exp(-(t - center)**2 / (2 * sigma**2))``
    at time(s) `t`.
    """

    def g_fn(t):
        normalization = amp / (jnp.sqrt(2 * jnp.pi) * sigma)
        exponent = -((t - center) ** 2) / (2 * sigma**2)
        return normalization * jnp.exp(exponent)

    return g_fn

inspeqtor.control.library.polynomial_feature_map

polynomial_feature_map(x: ndarray, degree: int)
Source code in src/inspeqtor/v1/predefined.py
812
813
def polynomial_feature_map(x: jnp.ndarray, degree: int):
    """Stack the element-wise powers ``x**1, ..., x**degree`` along the last axis.

    The output's last axis is ``degree`` times as long as that of `x`.
    """
    powers = tuple(x**exponent for exponent in range(1, degree + 1))
    return jnp.concatenate(powers, axis=-1)

inspeqtor.control.library.predefined_controls module-attribute

inspeqtor.control.library.default_control_reader module-attribute

default_control_reader = construct_control_sequence_reader(
    controls=predefined_controls
)

inspeqtor.control.library.get_drag_pulse_v2_sequence

get_drag_pulse_v2_sequence(
    qubit_info_drive_strength: float,
    max_amp: float = 0.5,
    min_theta=0.0,
    max_theta=2 * pi,
    min_beta=-2.0,
    max_beta=2.0,
    dt=2 / 9,
)

Get predefined DRAG control sequence with single DRAG pulse.

Parameters:

Name Type Description Default
qubit_info_drive_strength float

Drive strength of the qubit

required
max_amp float

The maximum amplitude. Defaults to 0.5.

0.5

Returns:

Name Type Description
ControlSequence

Control sequence instance

Source code in src/inspeqtor/v2/predefined.py
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
def get_drag_pulse_v2_sequence(
    qubit_info_drive_strength: float,
    max_amp: float = 0.5,  # NOTE: Choice of maximum amplitude is arbitrary
    min_theta=0.0,
    max_theta=2 * jnp.pi,
    min_beta=-2.0,
    max_beta=2.0,
    dt=2 / 9,
    total_length: int = 320,
):
    """Get predefined DRAG control sequence with single DRAG pulse.

    Args:
        qubit_info_drive_strength (float): Drive strength of the qubit, forwarded
            to `DragPulseV2.qubit_drive_strength`.
        max_amp (float, optional): The maximum amplitude. Defaults to 0.5.
        min_theta (float, optional): Lower bound of the rotation angle `theta`. Defaults to 0.0.
        max_theta (float, optional): Upper bound of the rotation angle `theta`. Defaults to 2*pi.
        min_beta (float, optional): Lower bound of the DRAG coefficient `beta`. Defaults to -2.0.
        max_beta (float, optional): Upper bound of the DRAG coefficient `beta`. Defaults to 2.0.
        dt (float, optional): Time step forwarded to the pulse. Defaults to 2/9.
        total_length (int, optional): Number of time steps, used both as the pulse
            `duration` and as the sequence `total_dt`. Defaults to 320.

    Returns:
        ControlSequence: Control sequence instance
    """
    control_sequence = ControlSequence(
        controls={
            "0": DragPulseV2(
                duration=total_length,
                qubit_drive_strength=qubit_info_drive_strength,
                dt=dt,
                max_amp=max_amp,
                min_theta=min_theta,
                max_theta=max_theta,
                min_beta=min_beta,
                max_beta=max_beta,
            ),
        },
        total_dt=total_length,
    )

    return control_sequence

inspeqtor.control.library.drag_feature_map

drag_feature_map(
    x: ndarray,
    degree: int = 4,
    correction: tuple[float, ...] = (2 * pi, 10),
) -> ndarray
Source code in src/inspeqtor/v2/predefined.py
643
644
645
646
647
648
649
650
651
def drag_feature_map(
    x: jnp.ndarray,
    degree: int = 4,
    correction: tuple[float, ...] = (2 * jnp.pi, 10),
    beta_shift: float = 5,
) -> jnp.ndarray:
    """Normalize DRAG parameters and expand them into polynomial features.

    Args:
        x (jnp.ndarray): Array whose last axis holds the rotation angle at index 0
            and the DRAG `beta` at index 1. Any further entries are left unchanged.
        degree (int, optional): Highest power passed to `polynomial_feature_map`. Defaults to 4.
        correction (tuple[float, ...], optional): `(angle_scale, beta_scale)`. The angle
            is divided by `angle_scale` and the shifted beta by `beta_scale`.
            Defaults to `(2 * pi, 10)`.
        beta_shift (float, optional): Offset added to beta before scaling. Defaults
            to 5, the value that was previously hard-coded.

    Returns:
        jnp.ndarray: Polynomial features whose last axis is `degree` times as long as `x`'s.
    """
    # For angle, we normalize by correction[0] (2 pi by default)
    x = x.at[..., 0].set(x[..., 0] / (correction[0]))
    # For beta, shift first and then normalize by correction[1]
    x = x.at[..., 1].set((x[..., 1] + beta_shift) / correction[1])

    return polynomial_feature_map(x, degree=degree)