"""
Core gamma-PRS algorithm for uniform graph colouring.
Implements the partial rejection sampling method from:
"Uniform Sampling of Graph Colourings via Soft Colouring and Partial
Rejection Sampling"
"""
import numpy as np
from collections import deque
from .utils import (
preprocess_graph, has_improper_edge_vec, connected_components_mask,
)
# ---------------------------------------------------------------------------
# Vectorised core: n_v, bad vertices, resampling set
# ---------------------------------------------------------------------------
def compute_n_v_vec(G, colors, u_values, gamma):
    """Count, for every vertex, its active neighbours of the same colour.

    For a vertex ``v`` the value ``n_v(gamma, x)`` is the number of
    neighbours ``w`` with ``c_w == c_v`` and ``u_w > gamma ** d_w``.

    Parameters
    ----------
    G : dict
        Graph record; the keys ``'n'``, ``'k'``, ``'A'`` and ``'degrees'``
        are read.  ``A`` is multiplied against length-``n`` indicator
        vectors (presumably a sparse adjacency matrix -- confirm against
        ``preprocess_graph``).
    colors : np.ndarray
        Colour of each vertex; compared against the values ``1..k``.
    u_values : np.ndarray
        Per-vertex auxiliary value ``u_v``.
    gamma : float
        Current gamma level.

    Returns
    -------
    np.ndarray
        ``int32`` array of length ``n`` holding ``n_v`` for each vertex.
    """
    adjacency = G['A']
    num_vertices = G['n']
    num_colours = G['k']

    # A vertex is "active" when its u-value exceeds gamma ** degree.
    is_active = u_values > gamma ** G['degrees']

    result = np.zeros(num_vertices, dtype=np.int32)
    for colour in range(1, num_colours + 1):
        same_colour = colors == colour
        # Indicator of active vertices carrying this colour; one product
        # with A counts such neighbours for every vertex at once.
        indicator = np.logical_and(same_colour, is_active).astype(np.float64)
        neighbour_counts = np.asarray(adjacency @ indicator).ravel()
        # Only vertices of this colour care about this colour's counts.
        result[same_colour] += neighbour_counts[same_colour].astype(np.int32)
    return result
def find_bad_vertices_vec(G, colors, u_values, gamma, mask=None):
    """Flag the bad vertices ``{v : u_v > gamma ** n_v}``.

    ``n_v`` is obtained from ``compute_n_v_vec`` over the whole graph; the
    optional ``mask`` only filters the resulting flags.

    Parameters
    ----------
    G : dict
        Graph record forwarded to ``compute_n_v_vec``.
    colors : np.ndarray
        Colour of each vertex.
    u_values : np.ndarray
        Per-vertex auxiliary value ``u_v``.
    gamma : float
        Current gamma level.
    mask : np.ndarray or None
        If given, vertices outside the mask are never reported as bad.

    Returns
    -------
    np.ndarray
        Boolean array, ``True`` at every bad vertex (restricted to
        ``mask`` when supplied).
    """
    neighbour_counts = compute_n_v_vec(G, colors, u_values, gamma)
    is_bad = u_values > gamma ** neighbour_counts
    if mask is None:
        return is_bad
    # In-place AND into the freshly created array; the caller's mask is
    # left untouched.
    np.bitwise_and(is_bad, mask, out=is_bad)
    return is_bad
def find_resampling_set_vec(G, colors, u_values, gamma, mask=None):
    """Algorithm 1: build the resampling set R by BFS from the bad vertices.

    The search starts at ``Bad(x, gamma)`` and keeps expanding through
    non-passive vertices (``u_w > gamma ** d_w``).  Passive vertices that
    are reached are added to R as a one-layer boundary but are not
    expanded further.

    Parameters
    ----------
    G : dict
        Graph record; the keys ``'n'``, ``'adj'`` and ``'degrees'`` are
        read here.  ``adj[v]`` is iterated to obtain the neighbours of
        ``v``.
    colors : np.ndarray
        Colour of each vertex.
    u_values : np.ndarray
        Per-vertex auxiliary value ``u_v``.
    gamma : float
        Current gamma level.
    mask : np.ndarray or None
        If given, restricts both the bad vertices and the vertices the
        search may visit.

    Returns
    -------
    np.ndarray
        Boolean array of length ``n``; all ``False`` when there is no bad
        vertex.
    """
    seeds = find_bad_vertices_vec(G, colors, u_values, gamma, mask)
    num_vertices = G['n']
    if not seeds.any():
        return np.zeros(num_vertices, dtype=bool)

    neighbours = G['adj']
    is_passive = u_values <= gamma ** G['degrees']
    in_bounds = np.ones(num_vertices, dtype=bool) if mask is None else mask

    resample = seeds.copy()
    seen = seeds.copy()
    fringe = np.zeros(num_vertices, dtype=bool)

    frontier = deque(np.nonzero(seeds)[0])
    while frontier:
        current = frontier.popleft()
        for nbr in neighbours[current]:
            if seen[nbr] or not in_bounds[nbr]:
                continue
            seen[nbr] = True
            if is_passive[nbr]:
                # Passive vertices join R but stop the expansion.
                fringe[nbr] = True
                continue
            resample[nbr] = True
            frontier.append(nbr)

    return resample | fringe
def resample_vertices_vec(colors, u_values, k, R_mask, rng):
    """Draw fresh colours and u-values, in place, for the vertices in R.

    Parameters
    ----------
    colors : np.ndarray
        Colour array, overwritten at the selected positions with integers
        drawn from ``1..k``.
    u_values : np.ndarray
        u-value array, overwritten at the selected positions with draws
        from ``rng.random``.
    k : int
        Number of colours.
    R_mask : np.ndarray
        Boolean selector of the vertices to resample.
    rng : np.random.Generator
        Random source; left untouched when nothing is selected.

    Returns
    -------
    None
    """
    chosen = np.nonzero(R_mask)[0]
    count = chosen.size
    if count:
        # Colours are drawn before u-values; the order fixes the RNG stream.
        colors[chosen] = rng.integers(1, k + 1, size=count)
        u_values[chosen] = rng.random(size=count)
# ---------------------------------------------------------------------------
# Recursive gamma-PRS (Algorithm 4)
# ---------------------------------------------------------------------------
def gamma_prs_recursive(G, colors, u_values, gamma_seq, ell, mask,
                        rng, stats, depth=0, max_depth=100000):
    """Algorithm 4: gamma-PRS(G, x, ell) -- recursive implementation.

    While the masked region contains a bad vertex at level ``ell``, the
    resampling set R is computed and resampled in place.  Then, for every
    item returned by ``connected_components_mask(G['adj'], R, G['n'])``,
    the procedure is re-run at each level ``0..ell`` with that item as the
    mask.

    Parameters
    ----------
    G : dict
        Graph record; ``'k'``, ``'adj'`` and ``'n'`` are read here and the
        whole record is forwarded to the helper functions.
    colors, u_values : np.ndarray
        Current state; modified in place.
    gamma_seq : sequence of float
        Gamma value per level; ``gamma_seq[ell]`` is used here.
    ell : int
        Level to run at.
    mask : np.ndarray or None
        Region the call is restricted to.
    rng : np.random.Generator
        Random source used for resampling.
    stats : dict
        ``'resample_count'`` and ``'vertices_resampled'`` are incremented.
    depth : int
        Current recursion depth (internal bookkeeping).
    max_depth : int
        Depth above which ``RecursionError`` is raised.

    Raises
    ------
    RecursionError
        If ``depth`` exceeds ``max_depth``.  NOTE: the interpreter's own
        recursion limit (``sys.getrecursionlimit()``, typically 1000) is
        normally reached long before the default ``max_depth``.
    """
    if depth > max_depth:
        raise RecursionError(f"gamma-PRS exceeded max depth {max_depth}")

    gamma_ell = gamma_seq[ell]
    k = G['k']

    while True:
        # find_resampling_set_vec already computes the bad set: it returns
        # an all-False array exactly when no vertex is bad and otherwise a
        # set containing every bad vertex.  One call therefore serves as
        # both the termination test and the resampling set, avoiding a
        # second full bad-vertex computation per iteration.
        R = find_resampling_set_vec(G, colors, u_values, gamma_ell, mask)
        if not np.any(R):
            break

        resample_vertices_vec(colors, u_values, k, R, rng)
        stats['resample_count'] += 1
        stats['vertices_resampled'] += int(np.sum(R))

        # Each component is used as the mask of the recursive calls
        # (presumably one boolean mask per connected component of R --
        # confirm against connected_components_mask).
        for comp in connected_components_mask(G['adj'], R, G['n']):
            for j in range(ell + 1):
                gamma_prs_recursive(G, colors, u_values, gamma_seq, j,
                                    comp, rng, stats, depth + 1, max_depth)
# ---------------------------------------------------------------------------
# Iterative gamma-PRS at a single level
# ---------------------------------------------------------------------------
def gamma_prs_iterative(G, colors, u_values, gamma, mask,
                        rng, stats, max_iter=10**7):
    """Iterative PRS at a single gamma level.

    Repeatedly computes the resampling set R and resamples it in place
    until ``Bad(x, gamma)`` is empty within ``mask``.

    Parameters
    ----------
    G : dict
        Graph record; ``'k'`` is read here and the whole record is
        forwarded to the helper functions.
    colors, u_values : np.ndarray
        Current state; modified in place.
    gamma : float
        Gamma level to run at.
    mask : np.ndarray or None
        Region the run is restricted to.
    rng : np.random.Generator
        Random source used for resampling.
    stats : dict
        ``'resample_count'`` and ``'vertices_resampled'`` are incremented.
    max_iter : int
        Maximum number of resampling rounds.

    Raises
    ------
    RuntimeError
        If bad vertices remain after ``max_iter`` rounds (this includes
        ``max_iter <= 0``, where no round is attempted).
    """
    k = G['k']
    for _ in range(max_iter):
        # find_resampling_set_vec returns an all-False array exactly when
        # no vertex is bad and otherwise a set containing every bad
        # vertex, so a single call replaces the separate bad-vertex check
        # and halves the per-round work.
        R = find_resampling_set_vec(G, colors, u_values, gamma, mask)
        if not np.any(R):
            return
        resample_vertices_vec(colors, u_values, k, R, rng)
        stats['resample_count'] += 1
        stats['vertices_resampled'] += int(np.sum(R))
    raise RuntimeError(f"Iterative PRS did not converge in {max_iter} iterations")
# ---------------------------------------------------------------------------
# Main entry point: proper colouring through PRS
# ---------------------------------------------------------------------------
def prs_graph_coloring(graph, k, gamma_base=0.9, max_levels=1000,
                       seed=None, recursive=False):
    """Sample a proper k-colouring of ``graph`` with gamma-PRS.

    Starting from a uniformly random colouring and random u-values, the
    sampler runs one gamma level after another (``gamma_base ** 0``,
    ``gamma_base ** 1``, ...) until no edge joins two vertices of the
    same colour.

    Parameters
    ----------
    graph : nx.Graph
        The input graph; passed to ``preprocess_graph``.
    k : int
        Number of colours (must be >= chromatic number).
    gamma_base : float
        Base of the gamma sequence: ``gamma_ell = gamma_base ** ell``.
    max_levels : int
        Safety limit on the number of levels.
    seed : int or None
        Seed for ``np.random.default_rng``.
    recursive : bool
        If true, each level runs the recursive Algorithm 4; otherwise
        the single-level iterative PRS is used.

    Returns
    -------
    colors : dict
        Mapping node -> colour (1-indexed).
    stats : dict
        Run statistics with keys ``'levels'``, ``'resample_count'`` and
        ``'vertices_resampled'``.

    Raises
    ------
    RuntimeError
        If the colouring is still improper after ``max_levels`` levels.
    """
    rng = np.random.default_rng(seed)
    G = preprocess_graph(graph, k)
    n = G['n']

    # Initial state: colours first, then u-values (fixes the RNG stream).
    colors = rng.integers(1, k + 1, size=n)
    u_values = rng.random(size=n).astype(np.float64)

    gamma_seq = [gamma_base ** level for level in range(max_levels)]
    stats = {
        'levels': 0,
        'resample_count': 0,
        'vertices_resampled': 0,
    }

    whole_graph = np.ones(n, dtype=bool)
    edge_pairs = G['edge_pairs']
    level = 0
    while has_improper_edge_vec(colors, edge_pairs):
        if level >= max_levels:
            raise RuntimeError(f"Did not converge within {max_levels} levels")
        if not recursive:
            gamma_prs_iterative(G, colors, u_values, gamma_seq[level],
                                whole_graph, rng, stats)
        else:
            gamma_prs_recursive(G, colors, u_values, gamma_seq, level,
                                whole_graph, rng, stats)
        level += 1

    stats['levels'] = level
    node_list = G['node_list']
    color_dict = {node_list[i]: int(colors[i]) for i in range(n)}
    return color_dict, stats