GNN Branching Policy

In branch-and-bound (B&B), the solver repeatedly picks a fractional integer variable and splits the search into two children. Which variable it picks at each node — the branching rule — has an outsized effect on how large the search tree grows, often more than any other algorithmic choice [Lodi and Zarpellon, 2017]. A poor rule can blow the tree up exponentially; a good one keeps it small.

The gold standard for variable selection is strong branching: tentatively branch on every candidate, solve both child relaxations, and keep the variable that most improves the dual bound. Strong branching produces very small trees but is prohibitively expensive — it solves two relaxations per candidate, per node.
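The selection step can be sketched in a few lines. This is a minimal illustration, not discopt's code: it assumes a minimization problem, assumes the two child relaxation bounds for each candidate are already computed, and combines the two bound improvements with the product rule (the scoring that the training section of this notebook describes). The function name and the numbers are hypothetical.

```python
def strong_branching_choice(parent_bound, child_bounds, eps=1e-6):
    """Pick the candidate whose two children improve the bound the most.

    child_bounds maps a candidate index to (down_bound, up_bound), the dual
    bounds of its two child relaxations. Minimization is assumed, so a child
    bound above the parent bound counts as an improvement.
    """
    best_var, best_score = None, float("-inf")
    for var, (down, up) in child_bounds.items():
        # Clamp at eps so one zero-gain child does not erase the other gain.
        gain_down = max(down - parent_bound, eps)
        gain_up = max(up - parent_bound, eps)
        score = gain_down * gain_up
        if score > best_score:
            best_var, best_score = var, score
    return best_var, best_score


# Hypothetical child bounds for three fractional candidates at one node.
parent = 10.0
children = {0: (10.5, 12.0), 1: (11.0, 11.2), 2: (10.0, 14.0)}
var, score = strong_branching_choice(parent, children)
print(var, score)
```

Candidate 2 has the largest single gain, but its down child does not move the bound at all, so the product rule prefers candidate 1, where both children improve. The cost the paragraph above refers to is hidden in `child_bounds`: filling it requires two relaxation solves per candidate.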

A now-classic line of work asks: can we learn a fast surrogate that imitates strong branching? Khalil et al. [2016] first cast branching as a learning-to-rank problem over hand-crafted features. Gasse et al. [2019] then showed that a graph convolutional neural network over the natural bipartite variable-constraint graph of a MILP can imitate strong branching with high fidelity at a small fraction of its per-node cost, beating expert-designed heuristics on several benchmark families.

discopt implements this idea as an optional branching policy. This notebook documents the architecture, shows how to switch it on through Model.solve(branching_policy="gnn"), and states plainly what the policy does and does not deliver today.

import os

os.environ["JAX_PLATFORMS"] = "cpu"
os.environ["JAX_ENABLE_X64"] = "1"

import discopt.modeling as dm
from discopt._jax.problem_graph import build_graph
from discopt._jax.gnn_policy import (
    VAR_FEAT_DIM,
    CON_FEAT_DIM,
    HIDDEN_DIM,
    N_ROUNDS,
    select_branch_variable_gnn,
)

1. The bipartite-graph encoding

Following Gasse et al. [2019], discopt represents the optimization problem at a B&B node as a bipartite graph with two kinds of nodes:

  • Variable nodes — one per decision variable. Each carries a feature vector [value, lb, ub, is_integer, fractionality, log_n_vars, log_n_cons] (the relaxation value, the node-local bounds, an integrality flag, how fractional the value is, and two problem-size features).

  • Constraint nodes — one per constraint, with features [lhs_value, slack, sense_encoding, log_n_vars, log_n_cons] (the constraint body value, its slack, and an encoding of <= / == / >=).

An edge connects variable i to constraint j whenever variable i appears in constraint j’s body. This is precisely the variable-constraint incidence structure of the model.

The size features (log_n_vars, log_n_cons) are attached to every node so a policy trained on small instances degrades gracefully on larger ones — a deliberate mitigation of train/test distribution mismatch.

This encoding is built by discopt._jax.problem_graph.build_graph. Let’s build a tiny model and inspect the resulting graph.

m = dm.Model("graph_demo")
a = m.integer("a", lb=0, ub=4)
b = m.integer("b", lb=0, ub=4)
m.maximize(a + 2 * b)
m.subject_to(a + b <= 5)
m.subject_to(2 * a + b <= 7)

# A pretend fractional relaxation solution to encode at this node.
import numpy as np

solution = np.array([2.5, 1.5])
graph = build_graph(m, solution)

print(f"variable nodes : {graph.n_vars}  (feature dim {graph.var_features.shape[1]})")
print(f"constraint nodes: {graph.n_cons}  (feature dim {graph.con_features.shape[1]})")
print(f"edges          : {graph.edge_indices.shape[1]}")
print()
print("variable features [value, lb, ub, is_integer, fractionality, log_nv, log_nc]:")
print(np.asarray(graph.var_features))
print()
print("constraint features [lhs, slack, sense, log_nv, log_nc]:")
print(np.asarray(graph.con_features))

variable nodes : 2  (feature dim 7)
constraint nodes: 2  (feature dim 5)
edges          : 4

variable features [value, lb, ub, is_integer, fractionality, log_nv, log_nc]:
[[2.5        0.         4.         1.         0.5        1.09861229
  1.09861229]
 [1.5        0.         4.         1.         0.5        1.09861229
  1.09861229]]

constraint features [lhs, slack, sense, log_nv, log_nc]:
[[-1.          1.         -1.          1.09861229  1.09861229]
 [-0.5         0.5        -1.          1.09861229  1.09861229]]
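The printed numbers can be reproduced by hand, which is a useful check on what each feature means. The sketch below does not call discopt. It writes the demo model in matrix form and recomputes the features under three assumptions inferred from the output above rather than from the library source: fractionality is the distance to the nearest integer, the size feature is log(1 + n), and edges are stored as a 2 x E array with variable indices in the first row.

```python
import numpy as np

# Demo model in matrix form: A @ x <= b, both variables integer in [0, 4].
A = np.array([[1.0, 1.0], [2.0, 1.0]])
b = np.array([5.0, 7.0])
x = np.array([2.5, 1.5])

# Fractionality: distance to the nearest integer (0.5 is the maximum).
fractionality = np.abs(x - np.round(x))

# Constraint body in normalized form (A @ x - b <= 0) and its slack.
lhs = A @ x - b
slack = -lhs

# One edge per nonzero coefficient: rows of A are constraints, columns variables.
con_idx, var_idx = np.nonzero(A)
edges = np.stack([var_idx, con_idx])  # assumed layout: shape (2, n_edges)

# Assumed size feature: log(1 + n), which gives 1.0986... for n = 2.
size_feature = np.log1p(x.size)

print("fractionality:", fractionality)
print("lhs:", lhs, "slack:", slack)
print("edges:", edges.shape[1], "size feature:", size_feature)
```

Both constraints involve both variables, hence the four edges, and both variables sit exactly halfway between integers, hence the two fractionality values of 0.5.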

2. The scoring model and the branching hook

On this bipartite graph, discopt runs a small message-passing GNN (discopt._jax.gnn_policy):

  1. Embed variable and constraint features into a hidden space (tanh MLPs).

  2. Message passing for N_ROUNDS rounds, each round propagating variable → constraint and then constraint → variable along the edges, aggregating by sum. This lets each variable’s representation absorb information about the constraints it participates in and their neighboring variables.

  3. Read out a single scalar branching score per variable node.

The variable with the highest score among the fractional integer candidates is the one the policy recommends. A larger Equinox-based variant (discopt._jax.gnn_branching.BranchingGNN) adds an imitation-learning training loop.
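The three steps above can be sketched in plain NumPy. This is an illustration of the technique, not discopt's implementation: the weights are random stand-ins for trained parameters, and the layer shapes, the concatenate-then-tanh update, and the helper names `dense` and `forward` are all assumptions made for the example. The feature rows are the ones printed in section 1.

```python
import numpy as np

rng = np.random.default_rng(0)
VAR_DIM, CON_DIM, HIDDEN, ROUNDS = 7, 5, 16, 2


def dense(n_in, n_out):
    """Random weights and zero bias for one layer (stand-in for training)."""
    return rng.normal(scale=0.1, size=(n_in, n_out)), np.zeros(n_out)


def forward(var_x, con_x, edges, params):
    """Return one branching score per variable node."""
    var_idx, con_idx = edges
    # Step 1: embed both node types into the hidden space.
    h_var = np.tanh(var_x @ params["embed_var"][0] + params["embed_var"][1])
    h_con = np.tanh(con_x @ params["embed_con"][0] + params["embed_con"][1])
    # Step 2: message passing with sum aggregation along the edges.
    for (w_vc, b_vc), (w_cv, b_cv) in params["rounds"]:
        msg = np.zeros_like(h_con)
        np.add.at(msg, con_idx, h_var[var_idx])  # variable -> constraint
        h_con = np.tanh(np.concatenate([h_con, msg], axis=1) @ w_vc + b_vc)
        msg = np.zeros_like(h_var)
        np.add.at(msg, var_idx, h_con[con_idx])  # constraint -> variable
        h_var = np.tanh(np.concatenate([h_var, msg], axis=1) @ w_cv + b_cv)
    # Step 3: read out a scalar score per variable.
    w_out, b_out = params["readout"]
    return (h_var @ w_out + b_out)[:, 0]


params = {
    "embed_var": dense(VAR_DIM, HIDDEN),
    "embed_con": dense(CON_DIM, HIDDEN),
    "rounds": [(dense(2 * HIDDEN, HIDDEN), dense(2 * HIDDEN, HIDDEN))
               for _ in range(ROUNDS)],
    "readout": dense(HIDDEN, 1),
}

log3 = np.log(3.0)
var_x = np.array([[2.5, 0.0, 4.0, 1.0, 0.5, log3, log3],
                  [1.5, 0.0, 4.0, 1.0, 0.5, log3, log3]])
con_x = np.array([[-1.0, 1.0, -1.0, log3, log3],
                  [-0.5, 0.5, -1.0, log3, log3]])
edges = np.array([[0, 1, 0, 1],   # variable index of each edge
                  [0, 0, 1, 1]])  # constraint index of each edge

scores = forward(var_x, con_x, edges, params)
# Only fractional integer variables are branching candidates.
is_candidate = (var_x[:, 3] == 1.0) & (var_x[:, 4] > 1e-6)
choice = int(np.argmax(np.where(is_candidate, scores, -np.inf)))
print(scores, choice)
```

Masking non-candidates with negative infinity before the argmax keeps the recommendation restricted to variables the solver could actually branch on, whatever the raw scores are.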

Crucially, the GNN only scores — it never changes the math. Inside the solver loop, for each open node the orchestrator builds the graph, asks the policy for a recommended variable index, and passes that as a branch hint to the Rust TreeManager. The Rust core still performs the actual branching, bound propagation, and correctness checks. The policy reorders the search; it cannot make the answer wrong.

print(f"VAR_FEAT_DIM = {VAR_FEAT_DIM}   # variable node feature width")
print(f"CON_FEAT_DIM = {CON_FEAT_DIM}   # constraint node feature width")
print(f"HIDDEN_DIM   = {HIDDEN_DIM}  # message-passing hidden width")
print(f"N_ROUNDS     = {N_ROUNDS}   # var->con->var message-passing rounds")

# Ask the policy for a recommended branch variable on the graph above.
# With params=None it falls back to most-fractional scoring (the default
# behaviour when no trained weights are supplied).
choice = select_branch_variable_gnn(graph, params=None)
print(f"\nrecommended branch variable index: {choice}")

VAR_FEAT_DIM = 7   # variable node feature width
CON_FEAT_DIM = 5   # constraint node feature width
HIDDEN_DIM   = 16  # message-passing hidden width
N_ROUNDS     = 2   # var->con->var message-passing rounds

recommended branch variable index: 0
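The fallback rule itself is simple enough to write out. The sketch below is a plain-Python version of most-fractional selection, not the library's code. It assumes ties are broken in favor of the lowest index, which is consistent with the index 0 printed above (both variables have fractionality 0.5) but is not confirmed from the discopt source.

```python
import math


def most_fractional(values, is_integer, tol=1e-6):
    """Return the index of the most fractional integer variable, or None."""
    best_idx, best_frac = None, tol
    for idx, (value, integer) in enumerate(zip(values, is_integer)):
        if not integer:
            continue
        # Distance to the nearest integer, in [0, 0.5].
        frac = abs(value - math.floor(value + 0.5))
        if frac > best_frac:  # strict comparison: ties keep the lowest index
            best_idx, best_frac = idx, frac
    return best_idx


print(most_fractional([2.5, 1.5], [True, True]))    # 0
print(most_fractional([2.0, 1.25], [True, True]))   # 1
print(most_fractional([2.0, 3.0], [True, True]))    # None
```

Returning None when every integer variable is already integral mirrors the situation where a node needs no branching at all.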

3. Solving with the GNN policy

The policy is selected per-solve with the branching_policy keyword. The default is "fractional" (most-fractional branching); pass "gnn" to route variable selection through the graph network. Both must return the same optimum — only the order in which nodes are explored can differ.

We use a small nonconvex MINLP that genuinely branches (its continuous relaxation is not integer-feasible), so the node count is meaningful.

def build_minlp():
    m = dm.Model("branch_demo")
    x = m.integer("x", lb=0, ub=5)
    y = m.integer("y", lb=0, ub=5)
    m.maximize(x * y + 2 * x)        # bilinear -> nonconvex
    m.subject_to(2 * x + 3 * y <= 12)
    m.subject_to(x + y <= 5)
    return m


results = {}
for policy in ["fractional", "gnn"]:
    res = build_minlp().solve(branching_policy=policy)
    results[policy] = res

print(f"{'policy':<12s} {'status':<10s} {'objective':>12s} {'nodes':>7s}")
print("-" * 44)
for policy, res in results.items():
    print(f"{policy:<12s} {res.status:<10s} {res.objective:>12.4f} {res.node_count:>7d}")

policy       status        objective   nodes
--------------------------------------------
fractional   optimal         12.0000       3
gnn          optimal         12.0000       3

# Correctness invariant: the policy changes search order, not the answer.
obj_frac = results["fractional"].objective
obj_gnn = results["gnn"].objective
assert abs(obj_frac - obj_gnn) < 1e-4, (obj_frac, obj_gnn)
print(f"Both policies reach the same optimum: {obj_gnn:.4f}  ✓")

Both policies reach the same optimum: 12.0000  ✓
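Because the instance has only 36 integer points, the reported optimum can also be confirmed independently of any solver. The following check enumerates every point of the box, keeps the feasible ones, and records the best objective value. It uses only the model data from build_minlp above.

```python
from itertools import product

best_value, best_points = float("-inf"), []
for x, y in product(range(6), repeat=2):  # x, y in {0, ..., 5}
    if 2 * x + 3 * y <= 12 and x + y <= 5:
        value = x * y + 2 * x
        if value > best_value:
            best_value, best_points = value, [(x, y)]
        elif value == best_value:
            best_points.append((x, y))

print(best_value, best_points)  # 12 [(3, 2), (4, 1)]
```

The optimum of 12 is attained at two points, so two correct solvers, or two branching policies, may legitimately return different (x, y) pairs while agreeing on the objective. That is why the assertion above compares objective values rather than solutions.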

On this instance both policies explore the same number of nodes. That is expected: discopt ships the GNN scaffold — graph encoder, message-passing forward pass, strong-branching data collector, and imitation-learning loop — but no pre-trained weights are bundled. In the solver’s "gnn" path the scorer is currently invoked with params=None, which falls back to most-fractional scoring, so the search coincides with the default on small problems. The value today is the plumbing: a verified, correctness-preserving hook where a trained policy drops in.

4. When to use it, status, and how training works

When a learned policy helps. The payoff from imitating strong branching grows with tree size: it is most worthwhile on hard combinatorial MILPs where strong branching would shrink the tree dramatically but is too slow to run at every node, and where many similar instances are solved repeatedly so that the offline training cost amortizes [Gasse et al., 2019, Lodi and Zarpellon, 2017]. For small or easy models the default "fractional" rule is already adequate and adds no graph-construction or inference overhead.

How the policy is trained. discopt’s discopt._jax.gnn_branching module mirrors the Gasse et al. [2019] recipe:

  1. collect_strong_branching_data runs B&B and, at each node, computes the true strong-branching score for every fractional candidate (solving both child relaxations and scoring by the product of bound improvements, as in Gasse et al.). It records (graph, best_variable) pairs — the expert labels.

  2. train_branching_gnn fits the BranchingGNN by imitation learning: a cross-entropy loss that pushes the network’s softmax over candidate scores toward the strong-branching choice.

  3. The trained GNNBranchingPolicy then scores variables in well under a millisecond per node after JIT warmup — fast enough to call at every node.
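The loss in step 2 can be written down directly. The sketch below is a plain-Python version of the per-node imitation loss, under the assumption that the softmax is taken over the fractional candidates only. The function name and the example scores are hypothetical, and discopt's train_branching_gnn may batch or mask differently.

```python
import math


def imitation_loss(scores, candidates, expert_choice):
    """Cross-entropy between softmax(scores over candidates) and the expert."""
    if expert_choice not in candidates:
        raise ValueError("the expert label must be one of the candidates")
    logits = [scores[i] for i in candidates]
    shift = max(logits)  # subtract the max for numerical stability
    log_norm = shift + math.log(sum(math.exp(z - shift) for z in logits))
    # -log softmax(expert) = logsumexp(candidate logits) - logit(expert)
    return log_norm - scores[expert_choice]


scores = [0.2, 1.5, -0.3, 0.9]  # hypothetical network outputs, one per variable
candidates = [0, 1, 3]          # fractional integer variables at this node
loss_agree = imitation_loss(scores, candidates, expert_choice=1)
loss_disagree = imitation_loss(scores, candidates, expert_choice=0)
print(loss_agree, loss_disagree)
```

The loss is smaller when the expert picked the variable the network already scores highest (index 1) than when it picked a low-scoring one (index 0), so gradient descent on this quantity pushes the network's ranking toward the strong-branching choice.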

Current status and limitations.

  • No weights ship with the package; the in-solver "gnn" path falls back to most-fractional until a model is trained and wired in, so today it is correctness-preserving but performance-neutral.

  • The encoding uses a deliberately small feature set and hidden width for fast inference; it is not yet tuned to match the published benchmark gains.

  • Like all imitation policies, quality is bounded by the strong-branching expert and by how well the training instances cover the target distribution [Khalil et al., 2016, Lodi and Zarpellon, 2017].

The takeaway: branching_policy="gnn" is a safe, optional hook built on the standard architecture in which a bipartite GNN imitates strong branching. It is useful today for experimentation, and any speedup depends on training a model and supplying its weights.