# Source code for parkol.cftp_bc20

"""
CFTP with Bhandari & Chakraborty (2020) bounding chain for graph colouring.

Implements Algorithm 1 (PerfectSampler) from:
  Bhandari & Chakraborty, "Improved bounds for perfect sampling of
  k-colorings in graphs", STOC 2020.

Two-phase structure:
  1. Collapsing phase: reduce all bounding lists to size <= 2
     via SPRUCEUP + CONTRACT operations.
  2. Coalescing phase: reduce all lists to size 1 via CONTRACT operations
     on randomly chosen vertices.

Condition for correctness and polynomial runtime: k > 3 * Delta.

Colors are integers in {0, 1, ..., k-1} internally.
The public interface returns colors in {1, ..., k}.
"""

import numpy as np
import math


# ===================================================================
# Core primitives
# ===================================================================

def _compress_gen(adj, L, v, A, k, delta, rng):
    """Generate a COMPRESS update record for vertex v with set A.

    Parameters
    ----------
    adj, L, delta :
        Not used by this function.
    v : int
        Vertex the record applies to.
    A : iterable of int
        Colours that are randomly permuted to form the front of ``M``.
    k : int
        Number of colours; candidates for ``c_1`` are drawn from
        ``range(k)`` minus ``A``.
    rng : numpy.random.Generator
        Source of randomness. Draw order is: ``random()``,
        ``permutation(A)``, ``choice(complement)``.

    Returns
    -------
    dict
        ``{'type': 'compress', 'v', 'tau', 'M', 'L_post_v'}`` where
        ``M = sigma + (c_1,)`` (a permutation of A followed by one colour
        outside A) and ``L_post_v = A | {c_1}``.

    Notes
    -----
    If A already contains every colour in ``range(k)`` the candidate list
    for ``c_1`` is empty and ``rng.choice`` fails.
    """
    A_list = list(A)
    A_set = set(A_list)

    tau = float(rng.random())
    # rng.permutation returns numpy integer scalars; coerce to plain ints so
    # records (and any colouring decoded from them) hold ordinary Python ints.
    sigma = [int(c) for c in rng.permutation(A_list)]
    complement = [c for c in range(k) if c not in A_set]
    c_1 = int(rng.choice(complement))

    return {
        'type': 'compress',
        'v': v,
        'tau': tau,
        'M': tuple(sigma) + (c_1,),
        'L_post_v': A_set | {c_1},
    }


def _compress_decode(record, chi, adj, L, k, delta):
    """Decode a COMPRESS record to update a coloring.

    Parameters
    ----------
    record : dict
        COMPRESS record; reads ``'v'``, ``'tau'`` and ``'M'``, where ``M``
        is a colour permutation followed by one extra colour ``c_1``.
    chi : list of int
        Current colouring; ``chi[v]`` is overwritten in place.
    adj : list of list of int
        Adjacency lists.
    L :
        Not used by this function.
    k : int
        Number of colours.
    delta : int
        Value used in the acceptance threshold for ``c_1``.

    Returns
    -------
    None
        ``chi`` is modified in place. If ``c_1`` is not taken and every
        colour of the permutation is used by a neighbour, ``chi[v]`` is
        left unchanged.
    """
    v = record['v']
    tau = record['tau']
    M = record['M']
    # The encoder always writes M as sigma + (c_1,), so split on the record's
    # own layout rather than assuming len(sigma) == delta.
    sigma = M[:-1]
    c_1 = M[-1]

    nbr_colors = {chi[w] for w in adj[v]}
    num_nbr_colors = len(nbr_colors)

    if num_nbr_colors >= k:
        p_chi = 1.0
    else:
        p_chi = 1.0 - (k - delta) / (k - num_nbr_colors)

    if c_1 not in nbr_colors and tau >= p_chi:
        chi[v] = c_1
        return

    # Otherwise take the first permutation colour no neighbour is using.
    for c in sigma:
        if c not in nbr_colors:
            chi[v] = c
            break


def _contract_gen(adj, L, v, k, delta, rng):
    """Generate a CONTRACT update record for vertex v.

    Parameters
    ----------
    adj : list of list of int
        Adjacency lists.
    L : list of set of int
        Current bounding lists; only the neighbours' lists are read.
    v : int
        Vertex the record applies to.
    k : int
        Number of colours.
    delta : int
        Value used in the threshold ``p_L``.
    rng : numpy.random.Generator
        Source of randomness. Draw order is: ``random()``, then the draw
        for ``c_1``, then (only if a candidate exists) the draw for ``c_2``.

    Returns
    -------
    dict
        ``{'type': 'contract', 'v', 'tau', 'M', 'L_post_v', 'S_L_size',
        'Q_L_size'}``. ``M`` is ``(c_1,)`` or ``(c_1, c_2)`` and
        ``L_post_v`` is the matching set.
    """
    nbrs = adj[v]

    # S_L: every colour appearing in some neighbour's bounding list.
    S_L = set()
    for w in nbrs:
        S_L.update(L[w])

    # Q_L: colours of neighbours whose bounding list is a singleton.
    Q_L = set()
    for w in nbrs:
        if len(L[w]) == 1:
            Q_L.update(L[w])

    tau = float(rng.random())

    complement = [c for c in range(k) if c not in S_L]
    if complement:
        c_1 = int(rng.choice(complement))
    else:
        # No colour lies outside S_L: fall back to any colour in range(k).
        c_1 = int(rng.integers(0, k))

    s_minus_q = [c for c in S_L if c not in Q_L]
    c_2 = int(rng.choice(s_minus_q)) if s_minus_q else None

    # Same guard as the forward bounding chain: a non-positive denominator
    # is treated as "always keep the singleton" instead of dividing by zero.
    denom = k - delta
    if denom > 0:
        p_L = 1.0 - (len(S_L) - len(Q_L)) / denom
    else:
        p_L = 1.0

    if tau <= p_L or c_2 is None:
        L_post_v = {c_1}
        M = (c_1,)
    else:
        L_post_v = {c_1, c_2}
        M = (c_1, c_2)

    return {
        'type': 'contract',
        'v': v,
        'tau': tau,
        'M': M,
        'L_post_v': L_post_v,
        'S_L_size': len(S_L),
        'Q_L_size': len(Q_L),
    }


def _contract_decode(record, chi, adj, L, k, delta):
    """Apply a CONTRACT record to the colouring ``chi`` in place.

    ``chi[v]`` becomes the record's second colour only when the record has
    one, ``tau`` exceeds the threshold computed from the neighbours'
    bounding lists and current colours, and no neighbour currently uses that
    colour; otherwise it becomes the record's first colour. ``delta`` is not
    used.
    """
    target = record['v']
    draw = record['tau']
    proposal = record['M']
    neighbours = adj[target]

    # Only the sizes of these two sets matter below.
    union_colors = set()
    pinned_colors = set()
    for u in neighbours:
        union_colors.update(L[u])
        if len(L[u]) == 1:
            pinned_colors.update(L[u])

    used = {chi[u] for u in neighbours}

    if len(used) >= k:
        threshold = 1.0
    else:
        spread = len(union_colors) - len(pinned_colors)
        threshold = 1.0 - spread / (k - len(used))

    primary = proposal[0]
    secondary = proposal[1] if len(proposal) > 1 else None

    take_secondary = (
        not (draw <= threshold)
        and secondary is not None
        and secondary not in used
    )
    chi[target] = secondary if take_secondary else primary


# ===================================================================
# SPRUCEUP helper
# ===================================================================

def _find_covering_set(adj, L, v, w, k, delta):
    """Greedily pick up to ``delta`` colours hitting the lists around ``w``.

    The non-empty bounding lists of the neighbours of ``w`` are the targets.
    Each round adds the colour contained in the most still-unhit targets
    (first such colour in iteration order on ties). The result is then
    padded with the smallest unused colours of ``range(k)`` until it holds
    ``delta`` colours or the colours run out. ``v`` is not used.
    """
    targets = [L[u] for u in adj[w] if len(L[u]) > 0]

    chosen = set()
    pending = set(range(len(targets)))
    palette = set(range(k))

    while len(chosen) < delta and pending:
        candidates = palette - chosen
        if not candidates:
            break
        # max() keeps the first maximal element, matching a strict ">" scan.
        pick = max(
            candidates,
            key=lambda c: sum(1 for i in pending if c in targets[i]),
        )
        chosen.add(pick)
        pending = {i for i in pending if pick not in targets[i]}

    shortfall = delta - len(chosen)
    if shortfall > 0:
        spare = [c for c in range(k) if c not in chosen]
        chosen.update(spare[:shortfall])

    return chosen


def _spruceup(adj, L, ordering, i, k, delta, rng):
    """SPRUCEUP: emit a COMPRESS record for each later neighbour of v_i.

    For every neighbour of ``ordering[i]`` that appears after position ``i``
    in ``ordering``, a covering set is computed, a COMPRESS record is
    generated from it, and that neighbour's entry in ``L`` is replaced by a
    copy of the record's ``L_post_v``. Returns the records in the order the
    neighbours were visited.
    """
    pivot = ordering[i]
    later = set(ordering[i + 1:])

    emitted = []
    for nbr in adj[pivot]:
        if nbr not in later:
            continue
        cover = _find_covering_set(adj, L, pivot, nbr, k, delta)
        record = _compress_gen(adj, L, nbr, cover, k, delta, rng)
        L[nbr] = record['L_post_v'].copy()
        emitted.append(record)

    return emitted


# ===================================================================
# Phase 1: COLLAPSE
# ===================================================================

def _generate_collapse_records(adj, n, k, delta, ordering, rng):
    """Build the record sequence of the collapsing phase.

    Vertices are visited in ``ordering``. For each one, SPRUCEUP records are
    emitted first; then, if the union of the neighbours' bounding lists has
    fewer than ``k - delta`` colours, a regular CONTRACT record is generated.
    Otherwise a CONTRACT record with ``tau = 0.0`` is written for a colour
    chosen among those outside that union, and nothing is emitted when no
    such colour exists.
    """
    bounds = [set(range(k)) for _ in range(n)]
    out = []
    threshold = k - delta

    for pos in range(n):
        vertex = ordering[pos]

        out.extend(_spruceup(adj, bounds, ordering, pos, k, delta, rng))

        seen = set()
        for u in adj[vertex]:
            seen.update(bounds[u])

        if len(seen) < threshold:
            step = _contract_gen(adj, bounds, vertex, k, delta, rng)
            bounds[vertex] = step['L_post_v'].copy()
            out.append(step)
            continue

        free = [c for c in range(k) if c not in seen]
        if not free:
            continue

        forced = int(rng.choice(free))
        bounds[vertex] = {forced}
        out.append({
            'type': 'contract',
            'v': vertex,
            'tau': 0.0,
            'M': (forced,),
            'L_post_v': {forced},
            'S_L_size': len(seen),
            'Q_L_size': 0,
        })

    return out


# ===================================================================
# Phase 2: COALESCE
# ===================================================================

def _generate_coalesce_records(adj, n, k, delta, T_prime, L_init, rng):
    """Build the record sequence of the coalescing phase.

    Starting from copies of ``L_init``, draws ``T_prime`` vertices uniformly
    from ``range(n)``. A vertex whose neighbours' bounding lists together
    hold fewer than ``k - delta`` colours gets a CONTRACT record (and its
    local bounding list is updated); any other draw is recorded as a
    ``'noop'``.
    """
    bounds = [entry.copy() for entry in L_init]
    threshold = k - delta
    out = []

    for _ in range(T_prime):
        picked = int(rng.integers(0, n))

        seen = set()
        for u in adj[picked]:
            seen.update(bounds[u])

        if len(seen) >= threshold:
            out.append({'type': 'noop', 'v': picked})
            continue

        step = _contract_gen(adj, bounds, picked, k, delta, rng)
        bounds[picked] = step['L_post_v'].copy()
        out.append(step)

    return out


# ===================================================================
# Forward bounding chain
# ===================================================================

def _run_bounding_chain_forward(records, adj, n, k, delta):
    """Replay ``records`` on fresh bounding lists and report coalescence.

    Every vertex starts with the full list ``range(k)``. COMPRESS records
    overwrite the vertex's list with a copy of ``L_post_v``. CONTRACT records
    are re-evaluated against the current neighbour lists. Returns
    ``(lists, coalesced)`` where ``coalesced`` is True when every list has
    exactly one colour.
    """
    bounds = [set(range(k)) for _ in range(n)]
    span = k - delta

    for step in records:
        kind = step['type']
        if kind == 'noop':
            continue

        node = step['v']

        if kind == 'compress':
            bounds[node] = step['L_post_v'].copy()
            continue
        if kind != 'contract':
            continue

        union_colors = set()
        pinned_colors = set()
        for u in adj[node]:
            union_colors.update(bounds[u])
            if len(bounds[u]) == 1:
                pinned_colors.update(bounds[u])

        draw = step['tau']
        proposal = step['M']
        first = proposal[0]
        second = proposal[1] if len(proposal) > 1 else None

        if first in union_colors:
            # First colour clashes with the current neighbour lists: only
            # narrow the existing list, and only if something survives.
            kept = bounds[node] & step['L_post_v']
            if kept:
                bounds[node] = kept
            continue

        if span <= 0:
            p_L = 1.0
        else:
            p_L = 1.0 - (len(union_colors) - len(pinned_colors)) / span

        keep_pair = (
            not (draw <= p_L)
            and second is not None
            and second not in pinned_colors
            and second in union_colors
        )
        bounds[node] = {first, second} if keep_pair else {first}

    return bounds, all(len(bounds[u]) == 1 for u in range(n))


def _decode_coloring(records, adj, n, k, delta):
    """Replay ``records`` to produce a colouring (list of 0-indexed colours).

    The colouring starts at all zeros and the bounding lists at the full
    colour range. For each COMPRESS or CONTRACT record the colour of the
    record's vertex is decoded first, using the bounding lists as they stood
    before the record, and then that vertex's bounding list is updated.
    """
    coloring = [0] * n
    bounds = [set(range(k)) for _ in range(n)]
    span = k - delta

    for step in records:
        kind = step['type']
        if kind == 'noop':
            continue

        node = step['v']

        if kind == 'compress':
            _compress_decode(step, coloring, adj, bounds, k, delta)
            bounds[node] = step['L_post_v'].copy()
            continue
        if kind != 'contract':
            continue

        _contract_decode(step, coloring, adj, bounds, k, delta)

        union_colors = set()
        pinned_colors = set()
        for u in adj[node]:
            union_colors.update(bounds[u])
            if len(bounds[u]) == 1:
                pinned_colors.update(bounds[u])

        draw = step['tau']
        proposal = step['M']
        first = proposal[0]
        second = proposal[1] if len(proposal) > 1 else None

        if first in union_colors:
            # Clash with the current neighbour lists: narrow only.
            kept = bounds[node] & step['L_post_v']
            if kept:
                bounds[node] = kept
            continue

        if span <= 0:
            p_L = 1.0
        else:
            p_L = 1.0 - (len(union_colors) - len(pinned_colors)) / span

        keep_pair = (
            not (draw <= p_L)
            and second is not None
            and second not in pinned_colors
            and second in union_colors
        )
        bounds[node] = {first, second} if keep_pair else {first}

    return coloring


# ===================================================================
# Main CFTP entry point
# ===================================================================

def cftp_bc20(graph, k, seed=None, max_doubling=20):
    """Perfect sampling of k-colorings via BC20 CFTP.

    Parameters
    ----------
    graph : nx.Graph
        Only ``nodes()``, ``neighbors(v)`` and ``number_of_edges()`` are
        called on it.
    k : int
        Number of colors. Must satisfy k > 3 * Delta, where Delta is the
        maximum number of neighbours of any node.
    seed : int or None
        Seed for the master random generator; each epoch gets its own
        generator seeded from it.
    max_doubling : int
        Maximum number of epochs to prepend before giving up.

    Returns
    -------
    colors : dict
        Sampled k-coloring (node -> colour, 1-indexed), intended to be a
        uniformly random proper colouring. Empty for an empty graph.
    stats : dict
        ``'T'`` (per-epoch length times number of epochs used),
        ``'T_prime'``, ``'T_total_per_epoch'``, ``'doublings'`` and
        ``'n_records'``. All zero for an empty graph.

    Raises
    ------
    ValueError
        If ``k <= 3 * Delta``.
    RuntimeError
        If the bounding chain has not coalesced after ``max_doubling``
        epochs.
    """
    rng_master = np.random.default_rng(seed)

    node_list = list(graph.nodes())
    n = len(node_list)
    node_to_idx = {v: i for i, v in enumerate(node_list)}

    adj = [[node_to_idx[w] for w in graph.neighbors(v)] for v in node_list]
    delta = max((len(a) for a in adj), default=0)

    if k <= 3 * delta:
        raise ValueError(
            f"BC20 requires k > 3*Delta. Got k={k}, Delta={delta}, "
            f"3*Delta={3 * delta}. Need k >= {3 * delta + 1}."
        )

    if n == 0:
        # Nothing to colour; the coalescing phase would otherwise try to
        # draw a vertex from an empty range.
        return {}, {
            'T': 0,
            'T_prime': 0,
            'T_total_per_epoch': 0,
            'doublings': 0,
            'n_records': 0,
        }

    m = graph.number_of_edges()

    if n <= 1:
        T_prime = 1
    else:
        ratio = (k - delta) / (k - 3 * delta)
        T_prime = max(1, int(math.ceil(2 * ratio * n * math.log(n))))

    T_total = T_prime + m + n

    # Vertices in order of decreasing degree.
    ordering = sorted(range(n), key=lambda v: len(adj[v]), reverse=True)

    epoch_records = []

    for doubling in range(max_doubling):
        epoch_seed = int(rng_master.integers(0, 2**62))
        epoch_rng = np.random.default_rng(epoch_seed)

        collapse_recs = _generate_collapse_records(
            adj, n, k, delta, ordering, epoch_rng)

        # Bounding lists at the end of the collapse phase: the last record
        # written for each vertex wins.
        L_after_collapse = [set(range(k)) for _ in range(n)]
        for rec in collapse_recs:
            if rec['type'] != 'noop':
                L_after_collapse[rec['v']] = rec['L_post_v'].copy()

        coalesce_recs = _generate_coalesce_records(
            adj, n, k, delta, T_prime, L_after_collapse, epoch_rng)

        # Newer epochs go in front, so they are replayed first.
        epoch_records.insert(0, collapse_recs + coalesce_recs)
        all_records = [rec for erecs in epoch_records for rec in erecs]

        _, coalesced = _run_bounding_chain_forward(
            all_records, adj, n, k, delta)

        if coalesced:
            chi = _decode_coloring(all_records, adj, n, k, delta)
            # int() keeps numpy integer scalars out of the public result.
            colors = {node_list[i]: int(chi[i]) + 1 for i in range(n)}
            return colors, {
                'T': T_total * (doubling + 1),
                'T_prime': T_prime,
                'T_total_per_epoch': T_total,
                'doublings': doubling + 1,
                'n_records': len(all_records),
            }

    raise RuntimeError(
        f"BC20 CFTP did not coalesce after {max_doubling} doublings "
        f"(T_total per epoch = {T_total})"
    )
# ===================================================================
# Component solver interface (for use in hybrid)
# ===================================================================
def _apply_boundary_record(rec, lists, adj_local, bdy_forbidden, k, delta):
    """Update ``lists[rec['v']]`` in place for one COMPRESS/CONTRACT record.

    Boundary-forbidden colours of the vertex are removed from a COMPRESS
    list (unless that would empty it) and are counted in both the neighbour
    union and the singleton set of a CONTRACT re-evaluation. Records of any
    other type leave ``lists`` untouched.
    """
    v = rec['v']

    if rec['type'] == 'compress':
        restricted = rec['L_post_v'] - bdy_forbidden[v]
        lists[v] = restricted if restricted else rec['L_post_v'].copy()
        return
    if rec['type'] != 'contract':
        return

    S_L = set()
    Q_L = set()
    for w in adj_local[v]:
        S_L.update(lists[w])
        if len(lists[w]) == 1:
            Q_L.update(lists[w])
    S_L.update(bdy_forbidden[v])
    Q_L.update(bdy_forbidden[v])

    tau = rec['tau']
    M = rec['M']
    c_1 = M[0]
    c_2 = M[1] if len(M) > 1 else None

    denom = k - delta
    p_L = 1.0 - (len(S_L) - len(Q_L)) / denom if denom > 0 else 1.0

    if c_1 in S_L:
        # First colour clashes with the current lists: narrow only, and
        # only if something survives.
        new_L = lists[v] & rec['L_post_v']
        if new_L:
            lists[v] = new_L
    elif tau <= p_L or c_2 is None:
        lists[v] = {c_1}
    elif c_2 not in Q_L and c_2 in S_L:
        lists[v] = {c_1, c_2}
    else:
        lists[v] = {c_1}


def cftp_bc20_on_component(graph, k, component_vertices, boundary_colors,
                           seed=None, max_doubling=20):
    """BC20 CFTP on a subgraph with fixed boundary colors.

    Parameters
    ----------
    graph : nx.Graph
        Only ``neighbors(v)`` and ``degree(v)`` are called on it.
    k : int
        Number of colors. Must satisfy k > 3 * Delta, where Delta is the
        maximum ``graph.degree`` over the component vertices (so neighbours
        outside the component count too).
    component_vertices : set or list
        Vertices to colour.
    boundary_colors : dict
        Fixed colors in {1, ..., k} for vertices outside the component.
        Neighbours that are neither in the component nor in this dict are
        ignored.
    seed : int or None
        Seed for the master random generator.
    max_doubling : int
        Maximum number of epochs to prepend before giving up.

    Returns
    -------
    colors : dict
        Coloring of the component vertices (node -> colour, 1-indexed).
        Empty when ``component_vertices`` is empty.
    stats : dict
        ``'T'`` and ``'doublings'``.

    Raises
    ------
    ValueError
        If ``k <= 3 * Delta``.
    RuntimeError
        If the bounding chain has not coalesced after ``max_doubling``
        epochs.
    """
    rng_master = np.random.default_rng(seed)

    comp_list = list(component_vertices)
    n_comp = len(comp_list)
    if n_comp == 0:
        return {}, {'T': 0, 'doublings': 0}

    comp_set = set(comp_list)
    comp_to_idx = {v: i for i, v in enumerate(comp_list)}

    adj_local = [[] for _ in range(n_comp)]
    bdy_forbidden = [set() for _ in range(n_comp)]

    for i, v in enumerate(comp_list):
        for w in graph.neighbors(v):
            if w in comp_set:
                adj_local[i].append(comp_to_idx[w])
            elif w in boundary_colors:
                bdy_forbidden[i].add(boundary_colors[w] - 1)  # 0-indexed

    delta = max(graph.degree(v) for v in comp_list)

    if k <= 3 * delta:
        raise ValueError(
            f"BC20 requires k > 3*Delta. Got k={k}, Delta={delta}."
        )

    m_comp = sum(len(a) for a in adj_local) // 2

    if n_comp <= 1:
        T_prime = 1
    else:
        ratio = (k - delta) / (k - 3 * delta)
        T_prime = max(
            1, int(math.ceil(2 * ratio * n_comp * math.log(n_comp))))

    T_total = T_prime + m_comp + n_comp

    # Component vertices in order of decreasing in-component degree.
    ordering = sorted(range(n_comp), key=lambda v: len(adj_local[v]),
                      reverse=True)

    initial_L = [set(range(k)) - bdy_forbidden[i] for i in range(n_comp)]

    epoch_records = []

    for doubling in range(max_doubling):
        epoch_seed = int(rng_master.integers(0, 2**62))
        epoch_rng = np.random.default_rng(epoch_seed)

        # Collapse phase with boundary-aware initial L
        L = [s.copy() for s in initial_L]
        collapse_recs = []

        for i in range(n_comp):
            v_i = ordering[i]
            after_set = set(ordering[i + 1:])

            for w in adj_local[v_i]:
                if w in after_set:
                    A = _find_covering_set(adj_local, L, v_i, w, k,
                                           min(delta, k - 1))
                    # Pad A up to delta colours, then cap it at delta.
                    while len(A) < delta:
                        for c in range(k):
                            if c not in A:
                                A.add(c)
                                break
                    A_trimmed = set(list(A)[:delta])
                    rec = _compress_gen(adj_local, L, w, A_trimmed, k,
                                        delta, epoch_rng)
                    L[w] = rec['L_post_v'].copy()
                    collapse_recs.append(rec)

            S_L = set()
            for w in adj_local[v_i]:
                S_L.update(L[w])
            S_L.update(bdy_forbidden[v_i])

            if len(S_L) < k - delta:
                rec = _contract_gen(adj_local, L, v_i, k, delta, epoch_rng)
                L[v_i] = rec['L_post_v'].copy()
                collapse_recs.append(rec)
            else:
                available = [c for c in range(k) if c not in S_L]
                if available:
                    c_1 = int(epoch_rng.choice(available))
                    L[v_i] = {c_1}
                    collapse_recs.append({
                        'type': 'contract',
                        'v': v_i,
                        'tau': 0.0,
                        'M': (c_1,),
                        'L_post_v': {c_1},
                        'S_L_size': len(S_L),
                        'Q_L_size': 0,
                    })

        # Coalesce phase
        L_after = [s.copy() for s in initial_L]
        for rec in collapse_recs:
            if rec['type'] != 'noop':
                L_after[rec['v']] = rec['L_post_v'].copy()

        coalesce_recs = _generate_coalesce_records(
            adj_local, n_comp, k, delta, T_prime, L_after, epoch_rng)

        # Newer epochs go in front, so they are replayed first.
        epoch_records.insert(0, collapse_recs + coalesce_recs)

        # Forward bounding chain with boundary-restricted initial lists
        all_records = [rec for erecs in epoch_records for rec in erecs]

        L_fwd = [s.copy() for s in initial_L]
        for rec in all_records:
            if rec['type'] == 'noop':
                continue
            _apply_boundary_record(rec, L_fwd, adj_local, bdy_forbidden,
                                   k, delta)

        coalesced = all(len(L_fwd[v]) == 1 for v in range(n_comp))
        if not coalesced:
            continue

        # Decode coloring with boundary: colour the vertex from the lists
        # as they stood before the record, then update its list.
        chi = [0] * n_comp
        L_dec = [s.copy() for s in initial_L]

        for rec in all_records:
            if rec['type'] == 'noop':
                continue
            if rec['type'] == 'compress':
                _compress_decode(rec, chi, adj_local, L_dec, k, delta)
            elif rec['type'] == 'contract':
                _contract_decode_with_boundary(
                    rec, chi, adj_local, L_dec, bdy_forbidden, k, delta)
            _apply_boundary_record(rec, L_dec, adj_local, bdy_forbidden,
                                   k, delta)

        # NOTE(review): this post-hoc repair recolours vertices
        # deterministically; confirm whether that is acceptable for the
        # intended output distribution.
        chi = _fix_boundary_conflicts(chi, adj_local, bdy_forbidden,
                                      n_comp, k)

        # int() keeps numpy integer scalars out of the public result.
        colors = {comp_list[i]: int(chi[i]) + 1 for i in range(n_comp)}

        return colors, {
            'T': T_total * (doubling + 1),
            'doublings': doubling + 1,
        }

    raise RuntimeError(
        f"BC20 CFTP on component did not coalesce after {max_doubling} "
        f"doublings"
    )
def _contract_decode_with_boundary(record, chi, adj, L, bdy_forbidden, k,
                                   delta):
    """CONTRACT decode that accounts for boundary forbidden colors.

    Works like the plain CONTRACT decode, except that the colours in
    ``bdy_forbidden[v]`` are added to the neighbour-list union, to the
    singleton set and to the set of colours treated as used by neighbours.
    ``chi[v]`` is overwritten in place; ``delta`` is not used.
    """
    v = record['v']
    tau = record['tau']
    M = record['M']
    nbrs = adj[v]

    # Only the sizes of S_L and Q_L are used below.
    S_L = set()
    Q_L = set()
    for w in nbrs:
        S_L.update(L[w])
        if len(L[w]) == 1:
            Q_L.update(L[w])
    S_L.update(bdy_forbidden[v])
    Q_L.update(bdy_forbidden[v])

    nbr_colors = set(bdy_forbidden[v])
    for w in nbrs:
        nbr_colors.add(chi[w])

    num_nbr_colors = len(nbr_colors)
    if num_nbr_colors >= k:
        p_chi = 1.0
    else:
        p_chi = 1.0 - (len(S_L) - len(Q_L)) / (k - num_nbr_colors)

    c_1 = M[0]
    c_2 = M[1] if len(M) > 1 else None

    if tau <= p_chi or c_2 is None or c_2 in nbr_colors:
        chi[v] = c_1
    else:
        chi[v] = c_2


def _fix_boundary_conflicts(chi, adj, bdy_forbidden, n, k):
    """Fix any boundary conflicts in the decoded coloring.

    First recolours every vertex whose colour is boundary-forbidden, then
    makes up to 10 sweeps recolouring any vertex whose colour is
    boundary-forbidden or shared with a neighbour. A recoloured vertex gets
    the smallest colour in ``range(k)`` that is neither forbidden nor used
    by a neighbour; if none exists it is left as is. ``chi`` is modified in
    place and also returned.
    """
    for v in range(n):
        if chi[v] in bdy_forbidden[v]:
            nbr_colors = set(bdy_forbidden[v])
            for w in adj[v]:
                nbr_colors.add(chi[w])
            for c in range(k):
                if c not in nbr_colors:
                    chi[v] = c
                    break

    max_passes = 10
    for _ in range(max_passes):
        changed = False
        for v in range(n):
            nbr_colors = set(bdy_forbidden[v])
            for w in adj[v]:
                nbr_colors.add(chi[w])
            if chi[v] in nbr_colors:
                for c in range(k):
                    if c not in nbr_colors:
                        chi[v] = c
                        changed = True
                        break
        if not changed:
            break

    return chi