Source code for qrisp.alg_primitives.unbalanced_w_state

"""********************************************************************************
* Copyright (c) 2026 the Qrisp authors
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* This Source Code may also be made available under the following Secondary
* Licenses when the conditions for such availability set forth in the Eclipse
* Public License, v. 2.0 are satisfied: GNU General Public License, version 2
* with the GNU Classpath Exception which is
* available at https://www.gnu.org/software/classpath/license.html.
*
* SPDX-License-Identifier: EPL-2.0 OR GPL-2.0 WITH Classpath-exception-2.0
********************************************************************************
"""

from collections.abc import Sequence

import jax.numpy as jnp
import numpy as np

from qrisp.circuit import Qubit
from qrisp.core import QuantumVariable, p, x, xxyy
from qrisp.jasp import check_for_tracing_mode, jrange
from qrisp.typing import NDArrayLike


def unbalanced_w_state(qv: QuantumVariable | Sequence[Qubit], amplitudes: NDArrayLike) -> None:
    r"""Prepare a generalized W state, i.e. an unbalanced Dicke state of Hamming
    weight 1, on the given :ref:`QuantumVariable`.

    The resulting quantum state is

    .. math::

        |\psi\rangle \;=\; \sum_{i=0}^{n-1} a_i \,|e_i\rangle

    where :math:`|e_i\rangle` is the computational basis state with a single
    ``1`` at position :math:`i`, and :math:`a_i` are the (possibly complex)
    amplitudes given by ``amplitudes``. The input array is automatically
    normalized so that :math:`\sum_i |a_i|^2 = 1`.

    Parameters
    ----------
    qv : QuantumVariable or Sequence[Qubit]
        A freshly allocated :ref:`QuantumVariable` (or a sequence of qubits)
        in the :math:`|0\dots0\rangle` state whose size matches
        ``len(amplitudes)``.
    amplitudes : NDArrayLike
        A 1-D sequence of complex (or real) target amplitudes, one per qubit.
        Its length must equal ``qv.size``.

    Raises
    ------
    ValueError
        If ``amplitudes`` is not one-dimensional or is empty. Outside of
        tracing mode, also if ``len(amplitudes) != qv.size`` or if the
        amplitude vector is zero. These two value checks need concrete data
        and are skipped while tracing; a zero vector then produces NaN values
        (division by a zero norm) instead of an error.

    Notes
    -----

    **Algorithm.** The circuit distributes a single excitation across all
    qubits using a linear chain of ``XXYY`` gates:

    0. Precompute all required :math:`\theta_i` angles using
       :math:`r_i = \sqrt{ \sum_{ j = i }^{ n - 1 }{ |a_j| ^ 2 } }` and
       :math:`\theta_i = 2\arccos(|a_i|\,/\,r_i)`, where :math:`r_i` is the
       precomputed remaining (undistributed) amplitude magnitude.
    1. Apply ``X`` to qubit 0, producing :math:`|10\dots0\rangle`.
    2. For each qubit :math:`i = 0, \dots, n{-}2`:

       a. Apply ``XXYY(θ, π/2)`` on qubits :math:`(i,\, i{+}1)`, using the
          precomputed angle :math:`\theta = \theta_i`. In the
          single-excitation subspace this acts as a parametrized partial
          swap, leaving magnitude :math:`|a_i|` on qubit :math:`i` and
          passing the rest to qubit :math:`i{+}1`.
       b. Apply a phase gate :math:`P(\arg a_i)` on qubit :math:`i` to
          imprint the correct complex phase.

    3. Apply :math:`P(\arg a_{n-1})` on the last qubit.

    **Resources.** The circuit uses :math:`n{-}1` ``XXYY`` gates (each
    decomposable into 2 CNOTs + single-qubit rotations) and :math:`n` phase
    gates, yielding :math:`\mathcal{O}(n)` depth and gate count.

    Examples
    --------

    >>> import numpy as np
    >>> from qrisp import QuantumVariable, unbalanced_w_state
    >>> a = np.array([1j, 2, 3, 4])
    >>> qv = QuantumVariable(4)
    >>> unbalanced_w_state(qv, a)
    >>> print(qv.qs.statevector())

    """
    a = jnp.asarray(amplitudes, dtype=complex)

    # Shapes are static even in tracing mode, so this structural check can
    # always run. Without it a 0-d input fails with an opaque IndexError on
    # ``a.shape[0]``, and a 2-D input is silently flattened by
    # ``jnp.vdot`` / ``jnp.cumsum`` below.
    if a.ndim != 1:
        raise ValueError(f"Amplitudes must be a 1-D array, got an array of shape {a.shape}.")
    n = a.shape[0]  # Use the static shape of amplitudes

    # Value-dependent checks need concrete data and are skipped while tracing.
    if not check_for_tracing_mode():
        if len(qv) != n:
            raise ValueError(f"Length of amplitudes ({n}) must match qv.size ({len(qv)}).")
        if np.linalg.norm(np.asarray(amplitudes, dtype=complex)) < 1e-15:
            raise ValueError("Amplitude vector must be non-zero.")

    # Outside tracing mode an empty vector has already been rejected above
    # (either by the length check or because its norm is 0). Under tracing,
    # reject it here instead of indexing empty arrays further down.
    if n == 0:
        raise ValueError("Amplitude vector must be non-zero.")

    # Normalize so that <a|a> = 1
    norm = jnp.sqrt(jnp.vdot(a, a).real)
    a = a / norm
    abs_a = jnp.abs(a)
    phases = jnp.angle(a)

    # A single qubit needs no redistribution: flip it and imprint the phase.
    if n == 1:
        x(qv[0])
        p(phases[0], qv[0])
        return

    # --- Step 0: precompute the rotation angles.
    # r_i = sqrt( sum_{j=i}^{n-1} |a_j|^2 ) is the magnitude still to be
    # distributed when the chain reaches qubit i. Worked example for n = 4:
    #   abs_a_squared : [a0^2, a1^2, a2^2, a3^2]
    #   flip          : [a3^2, a2^2, a1^2, a0^2]
    #   cumsum        : [a3^2, a3^2 + a2^2, a3^2 + a2^2 + a1^2, a3^2 + a2^2 + a1^2 + a0^2]
    #   flip          : [a3^2 + a2^2 + a1^2 + a0^2, a3^2 + a2^2 + a1^2, a3^2 + a2^2, a3^2]
    abs_a_squared = abs_a**2
    remaining_arr = jnp.sqrt(jnp.flip(jnp.cumsum(jnp.flip(abs_a_squared))))

    # Ratios |a_i| / r_i for arccos. The last ratio is not needed because the
    # final qubit simply keeps whatever amplitude reaches it. Since
    # r_i >= |a_i|, a (near-)zero r_i implies a (near-)zero |a_i|; that 0/0
    # case is mapped to a ratio of 1 so that arccos(1) = 0 (identity rotation).
    numerators = abs_a[:-1]
    denominators = remaining_arr[:-1]
    nonzero = denominators > 1e-15
    # Substitute safe denominators *before* dividing so no NaN is produced.
    denominators_no_zeroes = jnp.where(nonzero, denominators, 1.0)
    ratio_arr = jnp.where(nonzero, numerators / denominators_no_zeroes, 1.0)

    # Choose θ_i so that cos(θ_i / 2) = |a_i| / r_i, i.e. qubit i retains
    # exactly magnitude |a_i|. Clipping keeps arccos in its domain if rounding
    # pushes a ratio slightly above 1.
    theta_arr = 2 * jnp.arccos(jnp.clip(ratio_arr, -1.0, 1.0))

    # --- Step 1: place the single excitation on qubit 0 ---> |10...0>
    x(qv[0])

    # --- Step 2: redistribute amplitude along the qubit chain
    for i in jrange(n - 1):
        # XXYY(θ, π/2) performs a parametrized partial swap in the
        # single-excitation subspace {|01>, |10>}:
        # |10> -> cos(θ/2)|10> - sin(θ/2)|01>
        xxyy(theta_arr[i], jnp.pi / 2, qv[i], qv[i + 1])
        # Qubit i is not touched by later XXYY gates, so its phase can be
        # imprinted now.
        p(phases[i], qv[i])

    # --- Step 3: imprint the phase on the last qubit
    p(phases[n - 1], qv[n - 1])