Pure-JAX Interior Point Method#
This tutorial covers discopt’s built-in Interior Point Method (IPM) solver backend, implemented entirely in JAX. Unlike external solvers (IPOPT, HiGHS), this backend runs natively in JAX and benefits from:
JIT compilation – the solver is compiled once, then reused with near-zero overhead
jax.vmap – solve batches of LP/QP instances in parallel with vectorized linear algebra
Automatic differentiation – gradients through the solve for decision-focused learning
Learning objectives:
Understand the intuition behind barrier/IPM methods
Solve LPs, QPs, and NLPs using both the high-level modeling API and low-level JAX routines
Use batch solving to handle multiple problem instances simultaneously
Observe JIT compilation speedups and cache reuse
References: The primal-dual IPM formulation follows [Nocedal and Wright, 2006] (Ch. 19) and [Wright, 1997]. The Mehrotra predictor-corrector strategy is from [Mehrotra, 1992].
import os
os.environ["JAX_PLATFORMS"] = "cpu"
os.environ["JAX_ENABLE_X64"] = "1"
import time
import discopt.modeling as dm
import jax.numpy as jnp
import numpy as np
IPM Theory: The Barrier Method#
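Interior point methods solve constrained optimization problems by replacing hard inequality constraints with a smooth logarithmic barrier penalty. Consider the problem:
\[\min_x \; f(x) \quad \text{s.t.} \quad c(x) = 0, \quad x \geq 0\]
The barrier reformulation is:
\[\min_x \; f(x) - \mu \sum_i \ln x_i \quad \text{s.t.} \quad c(x) = 0\]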
where \(\mu > 0\) is the barrier parameter. As \(\mu \to 0\), the barrier solution converges to the original optimum. The set of solutions parameterized by \(\mu\) traces the central path – a smooth curve through the interior of the feasible region.
The primal-dual IPM works with the KKT conditions directly, maintaining primal variables \(x\), dual variables \(y\) (for equality constraints), and bound multipliers \(z\). At each iteration, we solve a Newton system on the perturbed KKT conditions and take a step along the search direction, using a fraction-to-boundary rule to stay strictly interior [Nocedal and Wright, 2006].
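To make the central path concrete, here is a minimal sketch on a one-dimensional toy problem that is not part of discopt: minimize \(x\) subject to \(x \geq 1\). The barrier subproblem \(\min_x x - \mu \ln(x - 1)\) has the closed-form minimizer \(x^*(\mu) = 1 + \mu\), so the iterates can be checked by hand. The function name barrier_minimizer and the value tau = 0.995 are illustrative assumptions, not discopt defaults.

```python
def barrier_minimizer(mu, x0, tau=0.995, tol=1e-12, max_iter=100):
    """Minimize x - mu*log(x - 1) over x > 1 with damped Newton steps."""
    x = x0
    for _ in range(max_iter):
        s = x - 1.0                  # slack of the constraint x >= 1
        grad = 1.0 - mu / s          # first derivative of the barrier function
        hess = mu / s**2             # second derivative (always positive)
        step = -grad / hess          # Newton direction
        # Fraction-to-boundary rule: never remove more than tau of the slack.
        alpha = 1.0
        if step < 0.0:
            alpha = min(1.0, tau * s / -step)
        x += alpha * step
        if abs(grad) < tol:
            break
    return x


x = 3.0                              # strictly interior starting point
path = []
for mu in [1.0, 0.1, 0.01, 0.001]:
    x = barrier_minimizer(mu, x)     # warm start from the previous barrier solution
    path.append((mu, x))
    print(f"mu = {mu:<6g} x*(mu) = {x:.8f}   (analytic: {1.0 + mu:.8f})")
```

Each solve is warm-started from the previous point, and as mu shrinks the barrier minimizers approach the constrained optimum x = 1 from the interior, which is the one-dimensional picture of the central path.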
Mehrotra Predictor-Corrector#
The key algorithmic improvement is the Mehrotra predictor-corrector strategy [Mehrotra, 1992]:
Predictor step (affine): Solve the Newton system with \(\mu = 0\) to get an aggressive search direction.
Centering parameter: Compute \(\sigma = (\mu_{\text{aff}} / \mu)^3\), where \(\mu_{\text{aff}}\) is the complementarity after the affine step.
Corrector step: Re-solve with the centering term \(\sigma \mu\) and second-order correction from the predictor’s cross-products.
The Cholesky factorization from the predictor step is reused for the corrector, so the corrector step is nearly free. This typically reduces iteration counts by 30–50% compared to a basic barrier method.
discopt’s LP and QP solvers always use Mehrotra predictor-corrector. The NLP solver
enables it by default but allows disabling it via IPMOptions(predictor_corrector=False).
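The three steps above can be sketched in plain NumPy for a small dense LP in the form \(\min c^\top x\) s.t. \(Ax = b\), \(x \geq 0\). This is a textbook-style illustration under stated assumptions, not discopt's implementation: the names mehrotra_lp and step_to_boundary, the all-ones starting point, and the damping factor eta = 0.99 are all choices made for this example. The test problem is the slack form of the LP used in the high-level example further down, with the upper bounds of 10 dropped because they are inactive at the optimum.

```python
import numpy as np


def step_to_boundary(v, dv, eta=1.0):
    """Largest alpha in (0, 1] keeping v + alpha*dv >= 0, damped by eta."""
    neg = dv < 0
    if not neg.any():
        return 1.0
    return min(1.0, eta * float(np.min(-v[neg] / dv[neg])))


def mehrotra_lp(c, A, b, tol=1e-9, max_iter=50, eta=0.99):
    """Solve min c@x s.t. A@x = b, x >= 0 (dense, infeasible start)."""
    m, n = A.shape
    x, y, z = np.ones(n), np.zeros(m), np.ones(n)
    for it in range(max_iter):
        r_b = A @ x - b                  # primal residual
        r_c = A.T @ y + z - c            # dual residual
        mu = x @ z / n                   # average complementarity
        if max(np.linalg.norm(r_b), np.linalg.norm(r_c), mu) < tol:
            break
        # Factor A diag(x/z) A^T once; both Newton solves below reuse L.
        L = np.linalg.cholesky((A * (x / z)) @ A.T)

        def newton(r_xz):
            """Newton step for a given right-hand side of the complementarity row."""
            t = (r_xz + x * r_c) / z
            dy = np.linalg.solve(L.T, np.linalg.solve(L, -r_b - A @ t))
            dz = -r_c - A.T @ dy
            dx = (r_xz - x * dz) / z
            return dx, dy, dz

        # 1. Predictor (affine) step: aim for mu = 0.
        dx_a, _, dz_a = newton(-x * z)
        a_p = step_to_boundary(x, dx_a)
        a_d = step_to_boundary(z, dz_a)
        mu_aff = (x + a_p * dx_a) @ (z + a_d * dz_a) / n
        # 2. Centering parameter sigma = (mu_aff / mu)^3.
        sigma = (mu_aff / mu) ** 3
        # 3. Corrector: centering term plus second-order correction.
        dx, dy, dz = newton(-x * z - dx_a * dz_a + sigma * mu)
        a_p = step_to_boundary(x, dx, eta)
        a_d = step_to_boundary(z, dz, eta)
        x, y, z = x + a_p * dx, y + a_d * dy, z + a_d * dz
    return x, y, z, it


# Slack form of: max 5x + 4y  s.t.  x + y <= 8,  2x + y <= 12,  x, y >= 0
c = np.array([-5.0, -4.0, 0.0, 0.0])
A = np.array([[1.0, 1.0, 1.0, 0.0], [2.0, 1.0, 0.0, 1.0]])
b = np.array([8.0, 12.0])
x, y, z, iters = mehrotra_lp(c, A, b)
print(f"x = {x[:2].round(6)}, objective = {c @ x:.6f}, iterations = {iters}")
```

Note how the factor L is computed once per iteration and shared by the predictor and corrector solves, which is the reuse described above.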
Linear Programming with the JAX IPM#
High-level API#
The simplest way to use the JAX IPM is through Model.solve(). When the model
is detected as a pure LP, discopt automatically dispatches to lp_ipm_solve.
# Simple LP: maximize 5x + 4y subject to constraints
m = dm.Model("lp_demo")
x = m.continuous("x", lb=0.0, ub=10.0)
y = m.continuous("y", lb=0.0, ub=10.0)
m.minimize(-(5 * x + 4 * y))
m.subject_to(x + y <= 8, name="resource")
m.subject_to(2 * x + y <= 12, name="labor")
result = m.solve(nlp_solver="ipm")
print(f"Status: {result.status}")
print(f"Objective: {-result.objective:.6f}")
print(f"x = {result.x['x']:.4f}, y = {result.x['y']:.4f}")
Status: optimal
Objective: 36.000000
x = 4.0000, y = 4.0000
Low-level API#
For maximum control, call lp_ipm_solve directly with raw arrays.
The LP is in standard form: \(\min c^\top x\) s.t. \(Ax = b\), \(x_l \leq x \leq x_u\).
from discopt._jax.lp_ipm import lp_ipm_solve
# min -2x1 - 3x2 - 4x3
# s.t. 3x1 + 2x2 + x3 = 10
# 0 <= x <= 5
c = jnp.array([-2.0, -3.0, -4.0])
A = jnp.array([[3.0, 2.0, 1.0]])
b = jnp.array([10.0])
x_l = jnp.zeros(3)
x_u = jnp.full(3, 5.0)
state = lp_ipm_solve(c, A, b, x_l, x_u)
print(f"Optimal x: {state.x}")
print(f"Objective: {state.obj:.8f}")
print(f"Iterations: {state.iteration}")
print(f"Converged: {state.converged} (1=optimal)")
Optimal x: [2.05387298e-10 2.50000000e+00 5.00000000e+00]
Objective: -27.50000000
Iterations: 6
Converged: 1 (1=optimal)
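Because the low-level routine expects equality constraints, an inequality-constrained model has to be rewritten with slack variables first. The following NumPy sketch shows one way to build the arrays for the high-level example above; it only prepares and checks the data and does not call discopt. The variable names (A_ub, A_eq, and so on) are illustrative, and the finite slack bounds rely on A_ub and the variables being non-negative in this particular problem.

```python
import numpy as np

# Inequality form of the high-level example: A_ub @ v <= b_ub, 0 <= v <= 10
c_ub = np.array([-5.0, -4.0])
A_ub = np.array([[1.0, 1.0], [2.0, 1.0]])
b_ub = np.array([8.0, 12.0])

m_rows, n_vars = A_ub.shape
# One slack per row turns A_ub @ v <= b_ub into [A_ub | I] @ [v; s] = b_ub with s >= 0.
A_eq = np.hstack([A_ub, np.eye(m_rows)])
c_eq = np.concatenate([c_ub, np.zeros(m_rows)])      # slacks carry no cost
x_l = np.zeros(n_vars + m_rows)
# A_ub >= 0 and v >= 0 here, so each slack is at most b_ub: finite upper bounds.
x_u = np.concatenate([np.full(n_vars, 10.0), b_ub])

# Check with the known optimum v = (4, 4): both rows are active, so s = 0.
x_star = np.array([4.0, 4.0, 0.0, 0.0])
print("A_eq @ x_star =", A_eq @ x_star)
print("objective     =", c_eq @ x_star)
```

The resulting c_eq, A_eq, b_ub, x_l, and x_u have the shapes the standard form above requires and could be converted with jnp.array before being passed on.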
JIT Compilation Speedup#
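The first call to lp_ipm_solve with a given combination of array shapes triggers JIT compilation. Subsequent calls
with the same array shapes reuse the compiled code and run much faster. Note that the problem below has the same
shapes (3 variables, 1 equality constraint) as the previous example, so if that cell has already run in the same
session, the "cold" call also hits the cache, which would explain why both timings in the output below are about 0.1 ms.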
# Time first call (includes JIT compilation)
c2 = jnp.array([-1.0, -2.0, -1.5])
A2 = jnp.array([[1.0, 1.0, 1.0]])
b2 = jnp.array([6.0])
xl2 = jnp.zeros(3)
xu2 = jnp.full(3, 4.0)
t0 = time.perf_counter()
_ = lp_ipm_solve(c2, A2, b2, xl2, xu2)
cold_time = time.perf_counter() - t0
# Time second call (JIT-cached)
xu3 = jnp.full(3, 3.0) # different data, same shape
t0 = time.perf_counter()
_ = lp_ipm_solve(c2, A2, b2, xl2, xu3)
warm_time = time.perf_counter() - t0
print(f"Cold (JIT compile): {cold_time * 1000:.1f} ms")
print(f"Warm (cached): {warm_time * 1000:.1f} ms")
print(f"Speedup: {cold_time / warm_time:.0f}x")
Cold (JIT compile): 0.1 ms
Warm (cached): 0.1 ms
Speedup: 2x
Quadratic Programming with the JAX IPM#
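The QP solver handles problems of the form:
\[\min_x \; \tfrac{1}{2} x^\top Q x + c^\top x \quad \text{s.t.} \quad Ax = b, \quad x_l \leq x \leq x_u\]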
The Schur complement formulation factors \(W = Q + \text{diag}(\Sigma)\) via Cholesky,
then forms the smaller \(m \times m\) Schur complement \(S = A W^{-1} A^\top\).
A Cholesky-based inertia correction detects indefiniteness via NaN in the factor
and adds regularization, which is approximately 3x faster than eigvalsh [Nocedal and Wright, 2006].
from discopt._jax.qp_ipm import qp_ipm_solve
# Markowitz portfolio: 5 assets
# min 0.5 x'Qx (minimize variance)
# s.t. mu'x = target_return, sum(x) = 1, x >= 0
np.random.seed(42)
n_assets = 5
returns = np.array([0.12, 0.10, 0.07, 0.03, 0.15])
# Build a realistic covariance matrix
F = np.random.randn(n_assets, n_assets) * 0.1
Q = jnp.array(F.T @ F + 0.01 * np.eye(n_assets)) # PSD by construction
c_qp = jnp.zeros(n_assets) # no linear term
target = 0.10
A_qp = jnp.array(
    [
        returns,  # expected return constraint
        np.ones(n_assets),  # budget constraint
    ]
)
b_qp = jnp.array([target, 1.0])
xl_qp = jnp.zeros(n_assets)
xu_qp = jnp.ones(n_assets)
state = qp_ipm_solve(Q, c_qp, A_qp, b_qp, xl_qp, xu_qp)
print(f"Portfolio weights: {np.round(np.array(state.x), 4)}")
print(f"Portfolio variance (obj): {state.obj:.6f}")
print(f"Expected return: {np.array(returns @ state.x):.4f}")
print(f"Sum of weights: {np.sum(np.array(state.x)):.4f}")
print(f"Iterations: {state.iteration}, Converged: {state.converged}")
Portfolio weights: [2.624e-01 1.020e-02 5.201e-01 1.000e-04 2.072e-01]
Portfolio variance (obj): 0.005387
Expected return: 0.1000
Sum of weights: 1.0000
Iterations: 7, Converged: 1
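The Schur complement solve and the Cholesky-based regularization described at the start of this section can be sketched in NumPy as follows. This is an illustration under assumptions, not discopt's code: NumPy signals a failed factorization by raising LinAlgError rather than by returning NaN as described above for the JAX version, and the regularization schedule (delta0, growth) and the function names are made up for this example.

```python
import numpy as np


def regularized_cholesky(W, delta0=1e-8, growth=10.0, max_tries=20):
    """Return (L, delta) with L @ L.T = W + delta*I, growing delta until it works."""
    delta = 0.0
    for _ in range(max_tries):
        try:
            return np.linalg.cholesky(W + delta * np.eye(W.shape[0])), delta
        except np.linalg.LinAlgError:      # W + delta*I is not positive definite
            delta = delta0 if delta == 0.0 else delta * growth
    raise np.linalg.LinAlgError("could not regularize W")


def schur_kkt_solve(W, A, r1, r2):
    """Solve [[W, A^T], [A, 0]] [dx; dy] = [r1; r2] via S = A W^{-1} A^T."""
    L, _ = regularized_cholesky(W)

    def w_solve(v):
        return np.linalg.solve(L.T, np.linalg.solve(L, v))

    S = A @ w_solve(A.T)                   # small m x m Schur complement
    dy = np.linalg.solve(S, A @ w_solve(r1) - r2)
    dx = w_solve(r1 - A.T @ dy)
    return dx, dy


Q = np.array([[2.0, 0.5, 0.0, 0.0],
              [0.5, 1.0, 0.0, 0.0],
              [0.0, 0.0, 3.0, 0.2],
              [0.0, 0.0, 0.2, 1.5]])
sigma = np.array([0.1, 0.2, 0.3, 0.4])     # stand-in for the barrier diagonal
W = Q + np.diag(sigma)
A = np.array([[1.0, 1.0, 1.0, 1.0], [0.12, 0.10, 0.07, 0.03]])
r1 = np.array([1.0, -2.0, 0.5, 0.0])
r2 = np.array([0.3, -0.1])

dx, dy = schur_kkt_solve(W, A, r1, r2)

# Reference: solve the full (n + m) x (n + m) KKT system directly.
K = np.block([[W, A.T], [A, np.zeros((2, 2))]])
ref = np.linalg.solve(K, np.concatenate([r1, r2]))
print("matches full KKT solve:", np.allclose(np.concatenate([dx, dy]), ref))

# An indefinite block triggers the regularization loop.
_, delta = regularized_cholesky(np.diag([1.0, -5e-4]))
print(f"regularization added: {delta:.1e}")
```

The payoff of the Schur form is that only an n x n Cholesky factor and an m x m system are needed, instead of factoring the larger indefinite KKT matrix.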
Nonlinear Programming with the JAX IPM#
For general nonlinear problems, ipm_solve takes callable objective and constraint
functions. JAX’s automatic differentiation provides exact gradients and Hessians.
The solver uses an augmented KKT system with inertia correction to handle
non-convexity, and an \(\ell_1\) merit function with backtracking line search for
globalization [Nocedal and Wright, 2006].
High-level API#
# Constrained Rosenbrock: min (1-x)^2 + 100(y-x^2)^2 s.t. x+y >= 1
m = dm.Model("rosenbrock")
x = m.continuous("x", lb=-2.0, ub=2.0)
y = m.continuous("y", lb=-2.0, ub=2.0)
m.minimize((1 - x) ** 2 + 100 * (y - x**2) ** 2)
m.subject_to(x + y >= 1.0, name="sum_bound")
result = m.solve(nlp_solver="ipm")
print(f"Status: {result.status}")
print(f"Objective: {result.objective:.8f}")
print(f"x = {result.x['x']:.6f}, y = {result.x['y']:.6f}")
******************************************************************************
This program contains Ipopt, a library for large-scale nonlinear optimization.
Ipopt is released as open source code under the Eclipse Public License (EPL).
For more information visit https://github.com/coin-or/Ipopt
******************************************************************************
Status: iteration_limit
Objective: 0.00000002
x = 0.999861, y = 0.999718
Low-level NLP API#
The low-level ipm_solve takes Python callables for the objective and constraints.
JAX computes gradients and Hessians automatically via jax.grad and jax.hessian.
from discopt._jax.ipm import IPMOptions, ipm_solve
def obj_fn(x):
    return (1 - x[0]) ** 2 + 100 * (x[1] - x[0] ** 2) ** 2

def con_fn(x):
    return jnp.array([x[0] + x[1]])  # >= 1
x0 = jnp.array([0.0, 0.0])
x_l = jnp.array([-2.0, -2.0])
x_u = jnp.array([2.0, 2.0])
g_l = jnp.array([1.0]) # x + y >= 1
g_u = jnp.array([1e20]) # no upper bound
state = ipm_solve(obj_fn, con_fn, x0, x_l, x_u, g_l=g_l, g_u=g_u)
print(f"Optimal x: {state.x}")
print(f"Objective: {state.obj:.8f}")
print(f"Iterations: {state.iteration}")
print(f"Converged: {state.converged} (1=optimal, 2=acceptable)")
Optimal x: [0.61879562 0.38120438]
Objective: 0.14560702
Iterations: 6
Converged: 1 (1=optimal, 2=acceptable)
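The high-level and low-level runs above report different points for what is nominally the same problem, so it is worth knowing how to sanity-check a reported solution. One simple approach is to evaluate the KKT conditions by hand. The sketch below uses a hand-derived gradient of the Rosenbrock objective and plain NumPy (no discopt calls); the least-squares multiplier estimate is an illustrative choice.

```python
import numpy as np


def f(v):
    x, y = v
    return (1 - x) ** 2 + 100 * (y - x**2) ** 2


def grad_f(v):
    x, y = v
    return np.array([-2 * (1 - x) - 400 * x * (y - x**2), 200 * (y - x**2)])


grad_g = np.array([1.0, 1.0])              # gradient of g(x, y) = x + y

points = {
    "low-level result": np.array([0.61879562, 0.38120438]),
    "unconstrained minimizer": np.array([1.0, 1.0]),
}
report = {}
for name, v in points.items():
    g_val = v.sum()
    # For min f s.t. g >= 1, stationarity reads grad_f = lam * grad_g with lam >= 0.
    lam = grad_f(v) @ grad_g / (grad_g @ grad_g)     # least-squares multiplier
    resid = np.linalg.norm(grad_f(v) - lam * grad_g)
    report[name] = (f(v), g_val, lam, resid)
    print(f"{name}: f = {f(v):.8f}, g = {g_val:.6f}, lam = {lam:.4f}, resid = {resid:.1e}")
```

The point (1, 1) is feasible (g = 2) with f = 0, so it minimizes the problem as stated, in line with the high-level result. At the point printed by the low-level run the constraint is active and the implied multiplier is negative, which is what one would expect for the variant with \(x + y = 1\) or \(x + y \leq 1\) rather than \(x + y \geq 1\). If a run reproduces that output, the g_l and g_u arguments and their sign conventions deserve a second look.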
Batch Solving – The Killer Feature#
The most distinctive capability of the JAX IPM is batch solving via jax.vmap.
Instead of solving problems one at a time in a Python loop, vmap maps the solver
across a batch dimension, executing vectorized BLAS operations across all instances
simultaneously.
This is especially powerful for:
Branch-and-bound – solving LP/QP relaxations at many tree nodes in parallel
Parametric studies – sweeping over problem parameters
Efficient frontiers – computing portfolios for many return targets at once
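The idea behind batching can be illustrated without JAX. The sketch below is a NumPy analog, not a use of jax.vmap or of discopt's batch API: it builds the normal-equations matrix \(A D_k A^\top\) for 16 instances that share \(A\) but differ in a positive diagonal \(D_k\), then solves all of them in one stacked call instead of a Python loop. The sizes and random data are arbitrary choices for this example.

```python
import numpy as np

rng = np.random.default_rng(0)
batch, m, n = 16, 5, 10
A = np.abs(rng.standard_normal((m, n)))        # shared constraint matrix
D = rng.uniform(0.5, 2.0, size=(batch, n))     # per-instance positive scalings
rhs = rng.standard_normal((batch, m))          # per-instance right-hand sides

# Loop version: one m x m solve per instance.
loop_sol = np.stack(
    [np.linalg.solve((A * D[k]) @ A.T, rhs[k]) for k in range(batch)]
)

# Batched version: build every A D_k A^T at once, then solve in a single call.
M = np.einsum("ij,bj,kj->bik", A, D, A)        # shape (batch, m, m)
batch_sol = np.linalg.solve(M, rhs[..., None])[..., 0]

print("shapes:", M.shape, batch_sol.shape)
print("loop and batch agree:", np.allclose(loop_sol, batch_sol))
```

jax.vmap applies the same transformation automatically to a whole solver: every array gains a leading batch axis and each linear-algebra call operates on the stacked data.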
Batch LP: Serial vs. Vectorized#
We generate 16 LP instances that share the same cost vector and constraint matrix but have different variable bounds.
from discopt._jax.lp_ipm import lp_ipm_solve_batch
n, m_cons, batch = 10, 5, 16
np.random.seed(42)
c = jnp.array(np.random.randn(n))
A = jnp.array(np.abs(np.random.randn(m_cons, n))) # non-negative entries; feasibility is confirmed via converged below
b = jnp.ones(m_cons) * 5.0
xl_batch = jnp.zeros((batch, n))
xu_batch = jnp.array(np.random.uniform(1, 5, (batch, n)))
print(f"Batch size: {batch} LP instances")
print(f"Each LP: {n} variables, {m_cons} equality constraints")
Batch size: 16 LP instances
Each LP: 10 variables, 5 equality constraints
# --- Serial solve (Python loop) ---
# Warm up JIT first
_ = lp_ipm_solve(c, A, b, xl_batch[0], xu_batch[0])
t0 = time.perf_counter()
serial_objs = []
for i in range(batch):
    s = lp_ipm_solve(c, A, b, xl_batch[i], xu_batch[i])
    serial_objs.append(float(s.obj))
serial_time = time.perf_counter() - t0
# --- Batch solve (vmap) ---
# Warm up batch JIT
_ = lp_ipm_solve_batch(c, A, b, xl_batch, xu_batch)
t0 = time.perf_counter()
batch_state = lp_ipm_solve_batch(c, A, b, xl_batch, xu_batch)
batch_time = time.perf_counter() - t0
print(f"Serial ({batch} solves): {serial_time * 1000:.1f} ms")
print(f"Batch ({batch} solves): {batch_time * 1000:.1f} ms")
print(f"Speedup: {serial_time / batch_time:.1f}x")
print(f"\nAll converged: {bool(jnp.all(batch_state.converged == 1))}")
print(f"Objectives: {np.round(np.array(batch_state.obj), 4)}")
Serial (16 solves): 2.5 ms
Batch (16 solves): 0.6 ms
Speedup: 4.0x
All converged: True
Objectives: [2.0405 1.8342 2.1068 1.8342 2.0922 3.6069 1.8693 3.3075 2.2554 1.8342
1.9026 1.8342 4.81 1.8342 3.9202 3.0492]
Batch QP: Efficient Frontier in One Call#
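Compute an entire Markowitz efficient frontier by solving 16 portfolio QPs with different return targets. Each instance has the same covariance matrix and constraint matrix; only the return target, which enters as the right-hand side of the return constraint, differs. Because the batch API varies bounds rather than the right-hand side, the code below uses a loop that reuses the JIT-compiled solver instead of a single batched call.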
n_points = 16
targets = np.linspace(0.04, 0.14, n_points)
# Reuse Q, c_qp, A_qp, and the bounds from the portfolio example above.
# The return target enters through the equality-constraint RHS b_t.
# The batch API varies bounds, not the RHS, so a batched solve would require
# embedding the target as an extra variable fixed by lb = ub = target.
# Simpler approach used here: loop over the targets and rely on the JIT cache.
variances = []
for target_ret in targets:
    b_t = jnp.array([target_ret, 1.0])
    s = qp_ipm_solve(Q, c_qp, A_qp, b_t, xl_qp, xu_qp)
    variances.append(float(s.obj))
print("Return Target | Portfolio Variance")
print("-" * 37)
for t, v in zip(targets, variances):
    print(f" {t:.2%} | {v:.6f}")
Return Target | Portfolio Variance
-------------------------------------
4.00% | 0.028549
4.67% | 0.018461
5.33% | 0.011890
6.00% | 0.008783
6.67% | 0.007225
7.33% | 0.006182
8.00% | 0.005631
8.67% | 0.005383
9.33% | 0.005302
10.00% | 0.005387
10.67% | 0.005867
11.33% | 0.006947
12.00% | 0.008628
12.67% | 0.010909
13.33% | 0.013791
14.00% | 0.018140
Mehrotra Predictor-Corrector: Impact on Convergence#
The NLP IPM supports toggling the Mehrotra predictor-corrector strategy via
IPMOptions(predictor_corrector=True/False). Let us compare iteration counts
on a simple constrained NLP [Mehrotra, 1992].
# Solve the same Rosenbrock NLP with and without predictor-corrector
opts_pc = IPMOptions(predictor_corrector=True)
opts_no_pc = IPMOptions(predictor_corrector=False)
state_pc = ipm_solve(obj_fn, con_fn, x0, x_l, x_u, g_l=g_l, g_u=g_u, options=opts_pc)
state_no_pc = ipm_solve(obj_fn, con_fn, x0, x_l, x_u, g_l=g_l, g_u=g_u, options=opts_no_pc)
print(
    f"With predictor-corrector: {state_pc.iteration} iterations, "
    f"obj = {state_pc.obj:.8f}, converged = {state_pc.converged}"
)
print(
    f"Without predictor-corrector: {state_no_pc.iteration} iterations, "
    f"obj = {state_no_pc.obj:.8f}, converged = {state_no_pc.converged}"
)
if int(state_no_pc.iteration) > int(state_pc.iteration):
    saved = int(state_no_pc.iteration) - int(state_pc.iteration)
    print(f"\nPredictor-corrector saved {saved} iterations.")
With predictor-corrector: 6 iterations, obj = 0.14560702, converged = 1
Without predictor-corrector: 1000 iterations, obj = 0.14560702, converged = 3
Predictor-corrector saved 994 iterations.
JIT Cache Reuse#
JAX traces and compiles the solver the first time it encounters a given combination of array shapes. Subsequent calls with the same shapes (but different data) skip compilation entirely and execute the cached XLA program.
This is why the IPM is especially fast in B&B, where thousands of subproblems share the same variable/constraint dimensions.
# Force new shapes to trigger fresh JIT
n_new = 20
c_new = jnp.array(np.random.randn(n_new))
A_new = jnp.array(np.random.randn(3, n_new))
b_new = jnp.ones(3) * 10.0
xl_new = jnp.zeros(n_new)
xu_new = jnp.ones(n_new) * 5.0
# First call: JIT compilation
t0 = time.perf_counter()
_ = lp_ipm_solve(c_new, A_new, b_new, xl_new, xu_new)
first_call = time.perf_counter() - t0
# Second call: cached
xu_new2 = jnp.ones(n_new) * 4.0
t0 = time.perf_counter()
_ = lp_ipm_solve(c_new, A_new, b_new, xl_new, xu_new2)
second_call = time.perf_counter() - t0
# Third call: cached
xu_new3 = jnp.ones(n_new) * 3.0
t0 = time.perf_counter()
_ = lp_ipm_solve(c_new, A_new, b_new, xl_new, xu_new3)
third_call = time.perf_counter() - t0
print(f"1st call (JIT compile): {first_call * 1000:8.1f} ms")
print(f"2nd call (cached): {second_call * 1000:8.1f} ms")
print(f"3rd call (cached): {third_call * 1000:8.1f} ms")
1st call (JIT compile): 161.9 ms
2nd call (cached): 0.1 ms
3rd call (cached): 0.1 ms
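The caching behavior seen in these timings can be pictured with a toy model. The decorator below is not how jax.jit is implemented; it only mimics the lookup rule that matters here, namely that compiled code is keyed on argument shapes and dtypes rather than on values. The names toy_jit and objective are invented for this sketch.

```python
import numpy as np


def toy_jit(fn):
    """Toy model of shape-keyed caching; a real JIT also traces and compiles."""
    cache = {}

    def wrapper(*arrays):
        key = tuple((a.shape, a.dtype.str) for a in arrays)
        if key not in cache:
            wrapper.compilations += 1      # a real JIT would trace + compile here
            cache[key] = fn
        return cache[key](*arrays)

    wrapper.compilations = 0
    return wrapper


@toy_jit
def objective(c, x):
    return float(c @ x)


objective(np.ones(3), np.arange(3.0))      # new shapes: "compiles"
objective(np.ones(3), np.full(3, 2.0))     # same shapes, new data: cache hit
objective(np.ones(20), np.zeros(20))       # new shapes: "compiles" again
print("compilations:", objective.compilations)
```

Only the first and third calls count as compilations, matching the pattern above where changing xu_new to different values of the same shape stays on the fast path.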
Limitations#
The pure-JAX IPM is best suited for small-to-medium problems and batch workloads. Keep these limitations in mind:
CPU-only on macOS – the JAX Metal (Apple GPU) backend is currently broken (UNIMPLEMENTED: default_memory_space). Set JAX_PLATFORMS=cpu.
Dense linear algebra – the KKT system is solved with dense Cholesky, giving \(O(n^3)\) cost per iteration. For large sparse problems (> 200 variables), consider the cyipopt or iterative IPM backends.
Memory – all problem data (matrices, vectors, batch dimensions) must fit in RAM. Batch size is limited by available memory.
JIT overhead – the first call for each unique shape combination incurs a compilation cost of several hundred milliseconds. This is amortized over subsequent calls.
Exercise: Parametric LP Sweep#
Solve 10 LP instances that differ only in the right-hand side \(b\). The LP is:
\[\min_x \; c^\top x \quad \text{s.t.} \quad Ax = b_i, \quad 0 \leq x \leq 5\]
where \(b_i = i \cdot \mathbf{1}\) for \(i = 1, \ldots, 10\).
Task: Complete the code below to solve all 10 instances using the batch API.
Hint: The batch API varies bounds, not the RHS. You can either (a) solve in a JIT-cached loop, or (b) reformulate by adding a slack variable to encode the varying RHS as a bound.
# Exercise: Parametric LP sweep
n_ex = 5
m_ex = 2
np.random.seed(123)
c_ex = jnp.array(np.random.randn(n_ex))
A_ex = jnp.array(np.abs(np.random.randn(m_ex, n_ex))) # non-negative entries; Ax=b can still be infeasible for large b
xl_ex = jnp.zeros(n_ex)
xu_ex = jnp.full(n_ex, 5.0)
# TODO: Solve for b_i = i * ones(m_ex), i = 1..10
# Store objectives in a list called 'objectives'
# Check state.converged to detect infeasible instances!
# objectives = ???
# Uncomment to check:
# for i, obj in enumerate(objectives, 1):
#     print(f"b = {i} * ones => obj = {obj:.4f}")
# Solution: JIT-cached loop over varying RHS
objectives = []
for i in range(1, 11):
    b_i = jnp.ones(m_ex) * float(i)
    state_i = lp_ipm_solve(c_ex, A_ex, b_i, xl_ex, xu_ex)
    objectives.append((float(state_i.obj), int(state_i.converged)))
for i, (obj, conv) in enumerate(objectives, 1):
    status = "optimal" if conv == 1 else "infeasible/failed"
    print(f"b = {i:2d} * ones => obj = {obj:8.4f} ({status})")
b = 1 * ones => obj = -0.8817 (optimal)
b = 2 * ones => obj = -1.7633 (optimal)
b = 3 * ones => obj = -2.6450 (optimal)
b = 4 * ones => obj = -3.5266 (optimal)
b = 5 * ones => obj = -4.4083 (optimal)
b = 6 * ones => obj = -5.2899 (optimal)
b = 7 * ones => obj = -6.1716 (optimal)
b = 8 * ones => obj = -6.8782 (optimal)
b = 9 * ones => obj = -7.3274 (optimal)
b = 10 * ones => obj = -7.7766 (optimal)
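Option (b) from the hint can be sketched as follows. The idea is to move the varying right-hand side into an extra block of variables \(t\) with \(Ax - t = 0\), and then fix \(t = b_i\) through per-instance bounds. The code only builds and checks the arrays in NumPy; whether the batch API accepts variables with equal lower and upper bounds is an assumption not confirmed here, and a slightly relaxed interval may be needed in practice. All names and the random data are illustrative.

```python
import numpy as np

rng = np.random.default_rng(123)
n, m, batch = 5, 2, 10
c = rng.standard_normal(n)
A = np.abs(rng.standard_normal((m, n)))

# Shared data: [A | -I] @ [x; t] = 0, which is the same as A @ x = t.
A_aug = np.hstack([A, -np.eye(m)])
b_aug = np.zeros(m)
c_aug = np.concatenate([c, np.zeros(m)])         # t carries no cost

# Per-instance bounds: 0 <= x <= 5, and t fixed at b_i = i * ones(m).
rhs = np.arange(1, batch + 1)[:, None] * np.ones((batch, m))
xl_batch = np.hstack([np.zeros((batch, n)), rhs])
xu_batch = np.hstack([np.full((batch, n), 5.0), rhs])

# Consistency check: any x together with t = A @ x satisfies the augmented system.
x0 = rng.uniform(0.0, 1.0, n)
v0 = np.concatenate([x0, A @ x0])
print("augmented residual:", np.abs(A_aug @ v0 - b_aug).max())
print("bound array shape:", xl_batch.shape)
```

With this layout the cost vector, constraint matrix, and right-hand side are identical across instances, and only the bound arrays carry the instance index, which is the shape of input described in the hint.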
Detecting infeasibility: what nan means#
When the IPM cannot find a feasible solution, the returned objective and solution
may be nan, and converged will not be 1. This commonly happens when the equality
constraints \(Ax = b\) are incompatible with the variable bounds – for example, when
\(A\) has mixed-sign entries and a large RHS pushes the system outside the feasible box.
Always check state.converged == 1 (optimal) or == 2 (acceptable) before using
the solution. A nan objective is the solver’s way of signaling that the iterates
diverged during the search.
# Infeasible LP: mixed-sign A makes Ax=b impossible for some b with 0 <= x <= 5
c_inf = jnp.array([-1.0, -2.0, -1.0, 0.5, -0.3])
A_inf = jnp.array(
    [
        [1.65, -2.43, -0.43, 1.27, -0.87],
        [-0.68, -0.09, 1.49, -0.64, -0.44],
    ]
)
xl_inf = jnp.zeros(5)
xu_inf = jnp.full(5, 5.0)
# Feasible case: b = [1, 1]
state_ok = lp_ipm_solve(c_inf, A_inf, jnp.array([1.0, 1.0]), xl_inf, xu_inf)
print(f"b=[1,1] => obj = {state_ok.obj:.4f}, converged = {state_ok.converged}")
# Infeasible case: b = [5, 5] (impossible with these bounds)
state_bad = lp_ipm_solve(c_inf, A_inf, jnp.array([5.0, 5.0]), xl_inf, xu_inf)
print(f"b=[5,5] => obj = {state_bad.obj}, converged = {state_bad.converged}")
print()
print("Convergence codes: 0=running, 1=optimal, 2=acceptable, 3=max_iter")
print("A nan objective with converged=3 means the iterates diverged.")
print("This typically indicates the problem is infeasible.")
print("Always check converged == 1 before trusting the solution.")
b=[1,1] => obj = -16.4683, converged = 1
b=[5,5] => obj = nan, converged = 3
Convergence codes: 0=running, 1=optimal, 2=acceptable, 3=max_iter
A nan objective with converged=3 means the iterates diverged.
This typically indicates the problem is infeasible.
Always check converged == 1 before trusting the solution.
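A small guard function makes this check hard to forget. The sketch below uses a hypothetical State tuple as a stand-in with only the obj and converged fields printed above; it is not discopt's state class, and the helper name usable is invented for this example. The status codes are the ones listed in the output above.

```python
import math
from collections import namedtuple

# Hypothetical stand-in exposing only the two fields used below.
State = namedtuple("State", ["obj", "converged"])

STATUS = {0: "running", 1: "optimal", 2: "acceptable", 3: "max_iter"}


def usable(state, allow_acceptable=True):
    """Return True if the solve finished with an accepted code and a finite objective."""
    ok_codes = {1, 2} if allow_acceptable else {1}
    return int(state.converged) in ok_codes and math.isfinite(float(state.obj))


for s in [State(-16.4683, 1), State(float("nan"), 3)]:
    label = STATUS.get(int(s.converged), "unknown")
    print(f"obj = {s.obj}, status = {label}, usable = {usable(s)}")
```

Checking both the code and the finiteness of the objective covers the case where a solve stops at the iteration limit with nan iterates.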
Summary#
The pure-JAX IPM is the recommended backend when:
You need to solve many small-to-medium problems (batch solving via vmap)
You want JIT cache reuse across repeated solves of the same structure
You need differentiable solving for sensitivity analysis or learning
You are working in a JAX-native pipeline (no external C dependencies)
For large sparse problems (hundreds of variables, sparse constraint matrices),
consider the cyipopt backend (nlp_solver="ipopt") or the iterative IPM
(IPMOptions(linear_solver="lineax_cg")).
Key API summary#
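| Level | LP | QP | NLP |
|---|---|---|---|
| High-level | Model.solve(nlp_solver="ipm") | Model.solve(nlp_solver="ipm") | Model.solve(nlp_solver="ipm") |
| Low-level | lp_ipm_solve | qp_ipm_solve | ipm_solve |
| Batch | lp_ipm_solve_batch | | – |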
See also:
Solver selection guide for choosing the right backend
cyipopt tutorial for the IPOPT backend
Sensitivity analysis for differentiating through solves