qml.labs.dla.run_opt
- run_opt(cost, theta, n_epochs=500, optimizer=None, verbose=False, interrupt_tol=None)
Boilerplate JAX optimization.
- Parameters:
cost (callable) – Cost function with scalar valued real output
theta (Iterable) – Initial values for argument of cost
n_epochs (int) – Number of optimization iterations
optimizer (optax.GradientTransformation) – optax optimizer. Default is optax.adam(learning_rate=0.1).
verbose (bool) – Whether progress is output during optimization
interrupt_tol (float) – If not None, interrupt the optimization if the norm of the gradient is smaller than interrupt_tol.
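To illustrate what a gradient-norm interrupt such as interrupt_tol can look like in practice, here is a minimal sketch. It is not the run_opt implementation: the helper name descend and its step_size argument are hypothetical, plain gradient descent stands in for an optax optimizer, and the use of the Euclidean norm and the point at which values are recorded are assumptions.

```python
import jax
import jax.numpy as jnp


def cost(x):
    return x**2


def descend(cost, theta, n_epochs=500, step_size=0.1, interrupt_tol=None):
    """Hypothetical helper: plain gradient descent with a gradient-norm stop."""
    value_and_grad = jax.value_and_grad(cost)
    thetas, values, grads = [], [], []
    for _ in range(n_epochs):
        value, grad = value_and_grad(theta)
        # Record the current parameters, cost value and gradient.
        thetas.append(theta)
        values.append(value)
        grads.append(grad)
        # Assumed stopping rule: Euclidean norm of the gradient below the tolerance.
        grad_norm = jnp.sqrt(jnp.sum(jnp.abs(grad) ** 2))
        if interrupt_tol is not None and grad_norm < interrupt_tol:
            break
        # Plain gradient descent step (stand-in for an optax update).
        theta = theta - step_size * grad
    return thetas, values, grads


thetas, values, grads = descend(cost, jnp.array(0.4), interrupt_tol=1e-3)
```

With the tolerance set, the loop in this sketch ends well before n_epochs iterations, because the gradient of x**2 shrinks geometrically under a fixed step size.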
Example
from pennylane.labs.dla import run_opt
import jax
import jax.numpy as jnp
import optax

jax.config.update("jax_enable_x64", True)

def cost(x):
    return x**2

x0 = jnp.array(0.4)

thetas, energy, gradients = run_opt(cost, x0)
When no optimizer is passed, we use optax.adam(learning_rate=0.1). We can also use other optimizers, like optax.lbfgs.
>>> optimizer = optax.lbfgs(learning_rate=0.1, memory_size=1000)
>>> thetas, energy, gradients = run_opt(cost, x0, optimizer=optimizer)