Batch IPM (vmap) vs Sequential Ipopt
discopt’s pure-JAX interior-point method (IPM) backend [Nocedal and Wright, 2006] supports jax.vmap, so it can solve multiple NLP
relaxations simultaneously — one per branch-and-bound (B&B) node in a batch. This notebook
compares:
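| Mode | Description |
|---|---|
| Batch IPM | Solve all N relaxations simultaneously with `solve_nlp_batch` (vectorized with `jax.vmap`) |
| Sequential Ipopt | Solve each relaxation individually via cyipopt |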
The batch IPM solves all nodes in a batch as a single vectorized operation, which is expected to be especially advantageous on GPU (this notebook runs on CPU only).
import os
# Force the CPU backend and enable 64-bit floats (later cells build jnp.float64
# arrays). These are set before importing discopt/jax so they are already in
# place when JAX initializes.
os.environ["JAX_PLATFORMS"] = "cpu"
os.environ["JAX_ENABLE_X64"] = "1"
import time
import discopt.modeling as dm
import jax
import numpy as np
# Sanity check: report which devices JAX picked up (expected: CPU only).
print(f"JAX devices: {jax.devices()}")
print("discopt loaded successfully")
JAX devices: [CpuDevice(id=0)]
discopt loaded successfully
1. Direct vmap Comparison
First, let’s compare two ways of solving N independent NLP relaxations:
- vmap’d IPM: solve all N simultaneously with `solve_nlp_batch`
- Sequential Ipopt: solve each one individually via cyipopt
import jax.numpy as jnp
from discopt._jax.ipm import IPMOptions, solve_nlp_batch
from discopt._jax.nlp_evaluator import NLPEvaluator
from discopt.solvers.nlp_ipopt import solve_nlp as solve_nlp_ipopt
def make_rosenbrock_model():
    """Return the unconstrained 2-D Rosenbrock model on the box [-5, 5] x [-5, 5].

    Objective: (1 - x)^2 + 100 * (y - x^2)^2, whose minimum is at (1, 1).
    """
    model = dm.Model("rosenbrock")
    x, y = (model.continuous(var_name, lb=-5, ub=5) for var_name in ("x", "y"))
    model.minimize((1 - x) ** 2 + 100 * (y - x**2) ** 2)
    return model
def make_constrained_model():
    """Return a small quadratic model: min x^2 + y^2 on [-5, 5]^2 s.t. x + y >= 1."""
    model = dm.Model("constrained")
    x, y = (model.continuous(var_name, lb=-5, ub=5) for var_name in ("x", "y"))
    model.minimize(x**2 + y**2)
    model.subject_to(x + y >= 1)
    return model
def make_exp_model():
    """Return a small NLP: min exp(x) + y^2 on [-2, 2]^2 s.t. x + y >= 1."""
    model = dm.Model("exp_nlp")
    x, y = (model.continuous(var_name, lb=-2, ub=2) for var_name in ("x", "y"))
    model.minimize(dm.exp(x) + y**2)
    model.subject_to(x + y >= 1)
    return model
# Marker that the problem-builder cell executed.
print("Problem builders defined")
Problem builders defined
def benchmark_batch_vs_sequential(make_fn, batch_sizes, n_repeats=3):
    """Compare vmap batch IPM vs sequential Ipopt for varying batch sizes.

    For each batch size, random starting points and randomly tightened
    variable bounds (simulating B&B nodes) are generated. The batch is then
    solved (a) in a single ``solve_nlp_batch`` call and (b) one starting point
    at a time with Ipopt. Reported wall times are medians over ``n_repeats``
    timed runs; the batch IPM gets one untimed warm-up call first so JIT
    compilation is excluded.

    Parameters
    ----------
    make_fn : callable
        Zero-argument builder returning a ``dm.Model``.
    batch_sizes : iterable of int
        Batch sizes to benchmark.
    n_repeats : int, default 3
        Number of timed repetitions per solver and batch size; must be >= 1.

    Returns
    -------
    list of dict
        One dict per batch size with keys ``batch_size``, ``batch_time``,
        ``seq_time``, ``speedup`` (seq_time / batch_time), ``batch_obj_mean``
        and ``seq_obj_mean``.

    Raises
    ------
    ValueError
        If ``n_repeats < 1`` (no timed run would produce a result).
    """
    if n_repeats < 1:
        raise ValueError(f"n_repeats must be >= 1, got {n_repeats}")

    m = make_fn()
    ev = NLPEvaluator(m)
    n = ev.n_variables
    lb, ub = ev.variable_bounds
    obj_fn = ev._obj_fn
    cons_fn = ev._cons_fn
    m_cons = ev.n_constraints

    # Constraint bounds: JAX arrays for the IPM, (lo, hi) pairs for Ipopt.
    if m_cons > 0:
        from discopt.solvers.nlp_ipopt import _infer_constraint_bounds

        cl, cu = _infer_constraint_bounds(m)
        g_l = jnp.array(cl, dtype=jnp.float64)
        g_u = jnp.array(cu, dtype=jnp.float64)
        constraint_bounds = list(zip(cl.tolist(), cu.tolist()))
    else:
        g_l = None
        g_u = None
        constraint_bounds = None

    # Loop-invariant setup, hoisted out of the batch-size loop.
    # Bounds are clipped to [-5, 5] so uniform sampling is well defined even
    # when a variable bound is infinite.
    lb_clip = np.clip(lb, -5, 5)
    ub_clip = np.clip(ub, -5, 5)
    ipm_opts = IPMOptions(max_iter=200)

    results = []
    for batch_size in batch_sizes:
        # Re-seeded per batch size so each batch size is reproducible on its own.
        rng = np.random.default_rng(42)
        x0_batch = rng.uniform(lb_clip, ub_clip, size=(batch_size, n))
        # Random tightened bounds (simulating B&B nodes).
        xl_batch = np.maximum(lb, lb_clip + rng.uniform(0, 0.5, size=(batch_size, n)))
        xu_batch = np.minimum(ub, ub_clip - rng.uniform(0, 0.5, size=(batch_size, n)))
        xu_batch = np.maximum(xu_batch, xl_batch + 0.1)  # keep every box non-empty
        # Starting points were drawn from the untightened box, so they can lie
        # outside a node's [xl, xu]; move each one into its own node box.
        # Both solvers use these same starting points.
        x0_batch = np.clip(x0_batch, xl_batch, xu_batch)

        # --- Batch IPM (vmap) ---
        x0_jax = jnp.array(x0_batch, dtype=jnp.float64)
        xl_jax = jnp.array(xl_batch, dtype=jnp.float64)
        xu_jax = jnp.array(xu_batch, dtype=jnp.float64)

        # Warm-up call triggers JIT compilation. JAX dispatches asynchronously,
        # so block on the result; otherwise the first timed run could also
        # absorb the tail of this computation.
        warmup = solve_nlp_batch(obj_fn, cons_fn, x0_jax, xl_jax, xu_jax, g_l, g_u, ipm_opts)
        jax.block_until_ready(warmup.x)

        batch_times = []
        for _ in range(n_repeats):
            t0 = time.perf_counter()
            state = solve_nlp_batch(obj_fn, cons_fn, x0_jax, xl_jax, xu_jax, g_l, g_u, ipm_opts)
            jax.block_until_ready(state.x)
            batch_times.append(time.perf_counter() - t0)
        batch_time = np.median(batch_times)
        batch_objs = np.asarray(state.obj)

        # --- Sequential Ipopt ---
        # NOTE(review): Ipopt is not given the tightened per-node bounds
        # (xl_batch / xu_batch), so it solves over the model's original bounds
        # and its objective values need not match the batch IPM's. Pass the node
        # bounds here if solve_nlp_ipopt supports them.
        seq_times = []
        for _ in range(n_repeats):
            t0 = time.perf_counter()
            seq_objs = []
            for x0 in x0_batch:
                res = solve_nlp_ipopt(
                    ev,
                    x0,
                    constraint_bounds=constraint_bounds,
                    options={"print_level": 0, "max_iter": 200},
                )
                seq_objs.append(res.objective if res.objective is not None else np.nan)
            seq_times.append(time.perf_counter() - t0)
        seq_time = np.median(seq_times)

        speedup = seq_time / batch_time if batch_time > 0 else 0
        results.append(
            {
                "batch_size": batch_size,
                "batch_time": batch_time,
                "seq_time": seq_time,
                "speedup": speedup,
                "batch_obj_mean": float(np.nanmean(batch_objs)),
                "seq_obj_mean": float(np.nanmean(seq_objs)),
            }
        )
    return results
# Marker that the benchmark-definition cell executed.
print("Benchmark function defined")
Benchmark function defined
# Batch sizes and NLP test problems for the direct vmap-vs-Ipopt comparison.
batch_sizes = [1, 2, 4, 8, 16, 32]
test_cases = [
    ("Rosenbrock (unconstrained)", make_rosenbrock_model),
    ("Constrained quadratic", make_constrained_model),
    ("Exponential NLP", make_exp_model),
]
for name, make_fn in test_cases:
    print(f"\n{'=' * 60}")
    print(f" {name}")
    print(f"{'=' * 60}")
    results = benchmark_batch_vs_sequential(make_fn, batch_sizes)
    header = f"{'Batch':>6s} {'vmap IPM':>10s} {'Seq Ipopt':>10s} {'Speedup':>8s}"
    print(header)
    print("-" * len(header))
    for r in results:
        # Field widths mirror the header: each time is 9 digits + "s" (10 chars)
        # and the speedup is 7 digits + "x" (8 chars), so the columns line up.
        print(
            f"{r['batch_size']:>6d} "
            f"{r['batch_time']:>9.4f}s "
            f"{r['seq_time']:>9.4f}s "
            f"{r['speedup']:>7.2f}x"
        )
============================================================
Rosenbrock (unconstrained)
============================================================
******************************************************************************
This program contains Ipopt, a library for large-scale nonlinear optimization.
Ipopt is released as open source code under the Eclipse Public License (EPL).
For more information visit https://github.com/coin-or/Ipopt
******************************************************************************
Batch vmap IPM Seq Ipopt Speedup
-------------------------------------
1 0.3181s 0.0030s 0.01x
2 0.3022s 0.0061s 0.02x
4 0.3120s 0.0156s 0.05x
8 0.3199s 0.0292s 0.09x
16 0.3280s 0.0541s 0.16x
32 0.3282s 0.1224s 0.37x
============================================================
Constrained quadratic
============================================================
Batch vmap IPM Seq Ipopt Speedup
-------------------------------------
1 0.4695s 0.0015s 0.00x
2 0.5151s 0.0026s 0.01x
4 0.5153s 0.0050s 0.01x
8 0.5164s 0.0099s 0.02x
16 0.5489s 0.0207s 0.04x
32 0.5329s 0.0406s 0.08x
============================================================
Exponential NLP
============================================================
Batch vmap IPM Seq Ipopt Speedup
-------------------------------------
1 0.4820s 0.0015s 0.00x
2 0.5345s 0.0027s 0.00x
4 0.5199s 0.0052s 0.01x
8 0.5329s 0.0117s 0.02x
16 0.5920s 0.0236s 0.04x
32 0.5902s 0.0454s 0.08x
2. B&B Solve Comparison
Now compare end-to-end MINLP solving, where any per-batch advantage of the batch IPM would compound over many B&B iterations.
def make_minlp_1():
    """Build a small MINLP with two continuous and three binary variables.

    x, y are continuous in [0, 5]; z1, z2, z3 are binary. Minimizes
    x^2 + y^2 + z1 + 2*z2 + 3*z3 subject to x + y >= 1 and x + z1 + z2 + z3 >= 2.
    """
    model = dm.Model("minlp1")
    x, y = (model.continuous(var_name, lb=0, ub=5) for var_name in ("x", "y"))
    z1, z2, z3 = (model.binary(var_name) for var_name in ("z1", "z2", "z3"))
    model.minimize(x**2 + y**2 + z1 + 2 * z2 + 3 * z3)
    model.subject_to(x + y >= 1)
    model.subject_to(x + z1 + z2 + z3 >= 2)
    return model
def make_minlp_2():
    """Build a small MINLP with one general integer variable.

    x, y are continuous in [0, 5] and n is an integer in [0, 5]. Minimizes the
    squared distance of (x, y, n) from (2, 3, 1.5) subject to x + y + n <= 8.
    """
    model = dm.Model("minlp2")
    x, y = (model.continuous(var_name, lb=0, ub=5) for var_name in ("x", "y"))
    count = model.integer("n", lb=0, ub=5)
    model.minimize((x - 2) ** 2 + (y - 3) ** 2 + (count - 1.5) ** 2)
    model.subject_to(x + y + count <= 8)
    return model
# End-to-end B&B comparison: each MINLP is solved with every
# (nlp_solver backend, batch size) configuration below.
minlp_test_cases = [
    ("3-binary MINLP", make_minlp_1),
    ("Integer MINLP", make_minlp_2),
]
solver_configs = [("ipm", 1), ("ipm", 8), ("ipm", 16), ("ipopt", 1)]

for name, make_fn in minlp_test_cases:
    print(f"\n--- {name} ---")
    for backend, batch_sz in solver_configs:
        # Build a fresh model for every configuration.
        m = make_fn()
        # NOTE: this wall time may include JIT compilation triggered by the
        # solve, so the first IPM configuration per model can look slower.
        t0 = time.perf_counter()
        r = m.solve(
            nlp_solver=backend,
            batch_size=batch_sz,
            max_nodes=1000,
            time_limit=60,
        )
        elapsed = time.perf_counter() - t0
        label = f"{backend}(batch={batch_sz})"
        # Guard the fixed-point format: presumably r.objective is None when no
        # incumbent is found (e.g. node/time limit hit); confirm against discopt.
        obj_str = f"{r.objective:>10.4f}" if r.objective is not None else f"{'n/a':>10s}"
        print(
            f" {label:<20s} obj={obj_str} "
            f"status={r.status:<12s} nodes={r.node_count:>4d} "
            f"time={elapsed:.3f}s rust={r.rust_time:.3f}s jax={r.jax_time:.3f}s"
        )
--- 3-binary MINLP ---
ipm(batch=1) obj= 2.0000 status=optimal nodes= 3 time=0.247s rust=0.000s jax=0.206s
ipm(batch=8) obj= 2.0000 status=optimal nodes= 3 time=0.238s rust=0.000s jax=0.236s
ipm(batch=16) obj= 2.0000 status=optimal nodes= 3 time=0.003s rust=0.000s jax=0.001s
ipopt(batch=1) obj= 2.0000 status=optimal nodes= 3 time=0.002s rust=0.000s jax=0.001s
--- Integer MINLP ---
ipm(batch=1) obj= 0.2500 status=optimal nodes= 3 time=0.859s rust=0.000s jax=0.179s
ipm(batch=8) obj= 0.2500 status=optimal nodes= 3 time=0.267s rust=0.000s jax=0.243s
ipm(batch=16) obj= 0.2500 status=optimal nodes= 3 time=0.050s rust=0.000s jax=0.025s
ipopt(batch=1) obj= 0.2500 status=optimal nodes= 3 time=0.050s rust=0.000s jax=0.026s
Summary
Key observations:
The vmap batch IPM amortizes its fixed per-solve overhead across all nodes in a batch (JIT compilation itself is excluded by the warm-up call), giving near-constant time for small-to-medium batch sizes.
Sequential Ipopt time scales roughly linearly with batch size, because each node is a separate cyipopt call with its own Python overhead.
For these small problems on CPU, sequential Ipopt was faster at every batch size tested (all speedups are below 1x). The vmap advantage is expected to grow with problem and batch size, and to be much more pronounced on GPU.
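In B&B, running the batch IPM with `batch_size > 1` reduced wall time relative to `batch_size = 1` in the runs above, by solving multiple node relaxations simultaneously. Note, however, that the first IPM run for each model likely also includes JIT compilation, so this comparison is not clean. Ipopt remained the fastest backend on these small instances.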