Viewing revision from 2025-11-12 at 20:10 EST (a9157dd99c00).

anna a

Nondeterminism monad in Python

started 2025-11-10 at 17:17 EST | updated 2025-11-12 at 20:10 EST | written in Maryland | 4563 words

tags: haskell, monads, python, generators, tricks

This post contains some fairly lengthy exposition; if you want to skip right to the content promised by the title, go to General nondeterminism with generators.

Many of the most curious and useful features of functional programming languages like Haskell derive from their ability (often unencumbered by the norms and constraints of industrial software engineering) to restate common algorithmic problems in novel ways, i.e., to perform a change of basis into a domain more suited to the problem. One such frame shift (or rather, category of such shifts) is widely known as declarative programming (as opposed to imperative programming, for example). It concerns programming languages, libraries, and techniques based on stating the problem domain or constraint system, as well as the desired objective or target (the “what”), at a high level, and leaving the low-level algorithmic details to the optimizer or runtime (the “how”). In some cases this takes the form of a domain-specific optimization or constraint-solving library; in others it is integrated more tightly with a language’s semantics and execution model.

One self-contained and useful tool from this paradigm is “nondeterminism” (a somewhat overloaded term; I am not talking here about the kind of nondeterminism people mention with respect to, e.g., reproducibility of software artifacts or experiments). The premise is that we delineate the ways in which a program can branch, or the alternatives to be chosen between (potentially with nesting, recursion, and other complications), and then search for (“solve” for) some solution among the set of possible paths or choices. In other words, the nondeterminism interface should, to some extent, abstract away the construction of the search space and the execution of the search procedure; the programmer need only be concerned with which choices are available and how they interact (e.g., how state evolves over time depending on the branch taken at each step).

The classical implementation of this scheme is described in section 4.3 of the well-known Structure and Interpretation of Computer Programs: McCarthy’s amb (“ambiguous”) operator, usually implemented in Lisp. I will omit the details of the implementation since there are decent explanations elsewhere, but the idea rhymes with what I’ve described above: we consider “splitting” execution into different paths corresponding to the available combinations of initial values, and return results only from the path (or paths) where execution “succeeded” (some specified criterion was met).

Nondeterminism in the list monad

Haskell has what I would consider a somewhat more principled implementation of nondeterminism using monads. In particular, the built-in list type forms a monad, with \xs f -> concat (map f xs) as the >>= (bind) operation (and the singleton list constructor, i.e., return x = [x], as return/pure). This means that the resulting list will be constructed by passing each element in xs to f to yield a new list, then concatenating the results:

ghci> [1, 2, 3] >>= (\x -> [x * 2, x * 3])
[2,3,4,6,6,9]

(see this Wikibooks entry for a more detailed explanation)
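For comparison, the same two list-monad operations are easy to sketch in Python. The names `bind` and `unit` are chosen here for illustration only; they are not part of any library:

```python
from typing import Callable, Iterable, TypeVar

A = TypeVar("A")
B = TypeVar("B")


def bind(xs: Iterable[A], f: Callable[[A], list[B]]) -> list[B]:
    # >>= for lists: apply f to every element, then concatenate the lists
    # (the analogue of \xs f -> concat (map f xs))
    return [y for x in xs for y in f(x)]


def unit(x: A) -> list[A]:
    # return/pure for lists: the singleton list
    return [x]


print(bind([1, 2, 3], lambda x: [x * 2, x * 3]))  # [2, 3, 4, 6, 6, 9]
```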

As you may expect, this means that we can stack multiple “branching” operations to recursively expand every possible path:

ghci> [1, 2, 3] >>= (\x -> [x * 2, x * 3]) >>= (\y -> [y + 4, 0])
[6,0,7,0,8,0,10,0,10,0,13,0]

The equivalent code in do-notation looks like this:

xs = do
    x <- [1, 2, 3]
    y <- [x * 2, x * 3]
    z <- [y + 4, 0]
    return z

Perhaps now it is clear how this is useful: we can trivially iterate over all combinations of choices for x, y, and z merely by specifying what the choices are, forgoing cumbersome and unscalable combinatorics logic. If you’re familiar with Haskell’s list comprehension notation, you may recognize it as syntactic sugar for these same list-monad operations:

ghci> [((x, y), x * y) | x <- [1, 2, 3], y <- [4, 5, 6]]
[((1,4),4),((1,5),5),((1,6),6),((2,4),8),((2,5),10),((2,6),12),((3,4),12),((3,5),15),((3,6),18)]

ghci> [1, 2, 3] >>= (\x -> [4, 5, 6] >>= (\y -> [((x, y), x * y)]))
[((1,4),4),((1,5),5),((1,6),6),((2,4),8),((2,5),10),((2,6),12),((3,4),12),((3,5),15),((3,6),18)]

We can do more interesting things as well; Control.Monad exports a function called guard with the following (quite general) definition:

guard True  = pure ()
guard False = empty

…which for [], specializes to:

ghci> (guard True) :: [()]
[()]
ghci> (guard False) :: [()]
[]

guard lets us access a general “cancellation” action for applicative functors (specifically, Alternatives); in the context of the list monad, we can think of this as conditionally pruning a branch from our computation by ignoring the results accumulated so far in that branch and returning an empty list. Let’s say we want to find all pairs of integers in 1..10 x 1..10 with even products, and annotate them with those products:

import Control.Monad

xs = do
  a <- [1..10]
  b <- [1..10]
  guard (even (a * b))
  return ((a, b), a * b)

main = mapM_ print xs
((1,2),2)
((1,4),4)
((1,6),6)
((1,8),8)
((1,10),10)
((2,1),2)
((2,2),4)
((2,3),6)
((2,4),8)
((2,5),10)
((2,6),12)
...

Haskell has convenient syntax sugar for this too:

ghci> mapM_ print [((x, y), x * y) | x <- [1..10], y <- [1..10], even (x * y)]
((1,2),2)
((1,4),4)
((1,6),6)
((1,8),8)
((1,10),10)
((2,1),2)
...
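The guard example carries over to Python the same way. In this sketch, `bind` and `guard` are hypothetical helpers (not from any library): `guard` returns `[()]` to keep a branch alive and `[]` to prune it, and Python's own list comprehension with an `if` clause computes the same list, much like the Haskell sugar above:

```python
def bind(xs, f):
    # list-monad bind: map f over xs and concatenate the resulting lists
    return [y for x in xs for y in f(x)]


def guard(cond: bool) -> list[tuple[()]]:
    # [()] keeps the current branch alive; [] prunes it
    return [()] if cond else []


# do { a <- [1..10]; b <- [1..10]; guard (even (a * b)); return ((a, b), a * b) }
xs = bind(range(1, 11), lambda a:
          bind(range(1, 11), lambda b:
               bind(guard(a * b % 2 == 0), lambda _:
                    [((a, b), a * b)])))

# the comprehension form of the same search
ys = [((a, b), a * b) for a in range(1, 11) for b in range(1, 11) if a * b % 2 == 0]

print(xs[:3])    # [((1, 2), 2), ((1, 4), 4), ((1, 6), 6)]
print(xs == ys)  # True
```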

(There is a tremendous amount of fascinating monad/applicative/traversable/alternative machinery that works with almost all of Haskell’s basic types, and which I would recommend having a look at if the above interests you at all; another example I’m fond of is sequence [[1..2], [3..6]], which exploits the fact that lists are both Traversable and Monad.)

It is important to note that – unlike with naive implementations that iterate over all combinations of elements from a handful of statically known sets, and only check which combinations would have survived at the end – this approach really does prune branches each time guard is invoked, avoiding much unnecessary work, and has all the execution semantics you would expect of a handwritten “iterate over items in source collections → map transformations over each element and collect the results → filter/prune → …” approach.
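The pruning is easy to observe with a counter. This Python sketch (hypothetical helpers `bind`, `guard`, and `search`) collects pairs from 1..10 x 1..10 whose first component is even, once with the guard placed right after `a` is chosen and once with the check deferred to the innermost step. Both variants produce the same pairs, but the early guard reaches only half as many `(a, b)` combinations:

```python
def bind(xs, f):
    # list-monad bind: map f over xs and concatenate the results
    return [y for x in xs for y in f(x)]


def guard(cond: bool) -> list[tuple[()]]:
    # [()] keeps the branch alive, [] prunes it
    return [()] if cond else []


def search(prune_early: bool) -> tuple[list[tuple[int, int]], int]:
    """Pairs (a, b) with even a, plus the number of (a, b) pairs visited."""
    visited = 0

    def choose_b(a: int) -> list[tuple[int, int]]:
        def leaf(b: int) -> list[tuple[int, int]]:
            nonlocal visited
            visited += 1
            # late variant: the condition on a is only checked down here
            ok = True if prune_early else a % 2 == 0
            return bind(guard(ok), lambda _: [(a, b)])

        return bind(range(1, 11), leaf)

    def choose_a(a: int) -> list[tuple[int, int]]:
        # early variant: odd a is discarded before any b is chosen
        ok = a % 2 == 0 if prune_early else True
        return bind(guard(ok), lambda _: choose_b(a))

    pairs = bind(range(1, 11), choose_a)
    return pairs, visited


early, n_early = search(prune_early=True)
late, n_late = search(prune_early=False)
print(early == late, n_early, n_late)  # True 50 100
```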

(The main deficiency, aside from the mild to moderate performance concerns (cache locality, laziness) that apply to Haskell’s execution model more generally, is that Data.Set cannot be made into a Monad: its operations require an Ord constraint on the element type, which a Monad instance has no way to impose. So we cannot use the obvious optimization strategy: “implement nondeterminism backed by a Set, so that at each step/branching point the universe is automatically collapsed down into unique values”. There are some packages that claim to implement a performant, monad-compatible set datatype, the implementation details of which I know not.)
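Python has no type-class restriction of this kind (hashability is only checked at runtime), so the set-backed strategy is easy to sketch there. `set_bind` below is a hypothetical helper; applied to the stacked example from earlier, it merges duplicate values before each subsequent step:

```python
from typing import Callable, Hashable, Iterable, TypeVar

A = TypeVar("A", bound=Hashable)
B = TypeVar("B", bound=Hashable)


def set_bind(xs: Iterable[A], f: Callable[[A], Iterable[B]]) -> frozenset[B]:
    # like list bind, but duplicate branches collapse into one at every step
    return frozenset(y for x in xs for y in f(x))


step1 = set_bind({1, 2, 3}, lambda x: {x * 2, x * 3})
# step1 holds 5 values (the two 6s merged), so the next f runs 5 times, not 6
step2 = set_bind(step1, lambda y: {y + 4, 0})

print(sorted(step1))  # [2, 3, 4, 6, 9]
print(sorted(step2))  # [0, 6, 7, 8, 10, 13]
```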

Another very compelling feature of Haskell, which is a bit of a diversion from the main subject of this post but still worth bringing up, is that monad transformers can be used to mix and match nondeterminism with other kinds of effects – for example, the early termination/short-circuiting behavior of the Maybe monad, or the hermetic state manipulation features of State. As a brief example, we can use StateT over the list monad to iterate over all combinations of some transformations (successively applied to an initial value) while maintaining a “history” of each transformation trace, then print them all out:

import Control.Monad
import Control.Monad.State.Lazy

test :: StateT [Int] [] ()
test = do
    x'@(x:_) <- get
    rule <- lift [(+4), (*2), (`rem` 3)]
    put $ (rule x):x'

main :: IO ()
main = do
    mapM_ (print . reverse) $ execStateT (replicateM 3 test) [1]
[1,5,9,13]
[1,5,9,18]
[1,5,9,0]
[1,5,10,14]
[1,5,10,20]
[1,5,10,1]
[1,5,2,6]
[1,5,2,4]
[1,5,2,2]
[1,2,6,10]
[1,2,6,12]
[1,2,6,0]
[1,2,4,8]
[1,2,4,8]
[1,2,4,1]
[1,2,2,6]
[1,2,2,4]
[1,2,2,2]
[1,1,5,9]
[1,1,5,10]
[1,1,5,2]
[1,1,2,6]
[1,1,2,4]
[1,1,2,2]
[1,1,1,5]
[1,1,1,2]
[1,1,1,1]

(The initial value is 1; the three transforms are “add 4”, “multiply by 2”, and “take the remainder mod 3”; we do three transformations in a row. As you would expect, we get 3^3 = 27 results.)
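Under the hood, a `StateT s [] a` computation amounts to a function from a state to a list of (result, new state) pairs. A rough Python rendering with the state threaded by hand (the names `RULES`, `step`, and `run_n` are illustrative, not from any library) reproduces the same 27 traces in the same order:

```python
from typing import Callable

# the three transformations from the Haskell example (Python's % agrees with
# rem for the non-negative values that occur here)
RULES: list[Callable[[int], int]] = [lambda x: x + 4, lambda x: x * 2, lambda x: x % 3]


def step(history: list[int]) -> list[tuple[None, list[int]]]:
    # one StateT action: a branch per rule, each prepending its new value
    # to its own copy of the history (like put $ (rule x):x')
    return [(None, [rule(history[0])] + history) for rule in RULES]


def run_n(action, n: int, state: list[int]) -> list[list[int]]:
    # analogue of execStateT (replicateM n action) state: keep final states only
    states = [state]
    for _ in range(n):
        states = [new for s in states for (_result, new) in action(s)]
    return states


traces = [list(reversed(h)) for h in run_n(step, 3, [1])]
for t in traces:
    print(t)
```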

This page has some nice illustrations of the control flow implied by various stacks of monad transformers. Finding the correct ordering of monad transformers in the stack, and mentally modeling the relevant types, is sometimes nontrivial; nevertheless, they can certainly improve concision in situations that call for them.

General nondeterminism with generators

Now that the basic idea has been motivated and demonstrated in a language more suited to it (I think this is actually a fairly good illustration of the utility of monads as “programmable semicolons” in languages with good syntax/compiler support for them), we can get to the main question: can we implement a reasonably ergonomic and performant version of this in Python? It is clear enough that McCarthy’s amb operator can be implemented in basically any programming language; Rosetta Code contains several Python implementations that seem fairly clean and well-behaved, including a (somewhat impractical) transliteration of the list monad described above (as well as more Haskell examples).

This is somewhat brittle, since we are generally forced to route every operation in our code through a set of predefined combinators. Clearly, we would prefer to go beyond amb-like data structures: we want to interleave Python’s control flow with some amount of automated branching logic. Unless we want to either:

…this seemingly requires some way to interrupt execution at specific points and backtrack/rewind to those points, modifying execution state each time before resuming to inject the state of the current “branch”.

Coroutines seem well-suited to this purpose; Python’s generators, though not quite the same, provide enough functionality to do what we have described above. In particular, generators allow us to temporarily stop execution, returning control (and an arbitrary value) to the caller, then later resume execution while passing back (“sending”) a new value. We now have the basics of a workable approach: at each “ambiguous” expression in the program, just stop execution, run the remainder of the program once with each possible value for that expression, and coalesce the results. The actual code implementing this turns out to be quite short.
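A tiny standalone example of the generator protocol this relies on (the generator `ask_twice` is purely illustrative): `.send(None)` starts the generator and runs it to the first `yield`, each later `.send(v)` resumes it with `v` as the value of the pending `yield` expression, and a `return` surfaces as a `StopIteration` carrying the return value:

```python
def ask_twice():
    a = yield "first?"   # pauses here; a becomes whatever is sent next
    b = yield "second?"  # pauses again; b becomes the following sent value
    return (a, b)


g = ask_twice()
print(g.send(None))  # first?
print(g.send(10))    # second?
try:
    g.send(20)
except StopIteration as stop:
    final = stop.value  # the generator's return value
print(final)  # (10, 20)
```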

Here is a minimal example of the sort of function we would like to support nondeterminism for; we have a yield statement enclosing each “source” of ambiguity/Amb expression (and these can use values from prior ones normally, no extra magic needed):

def test(a: int) -> set[tuple[tuple[int, int, int], int]]:
    x = yield Amb([1, 2, 3])
    y = yield Amb([2, 4, a])
    if x == y:
        z = yield Amb([0, a * 2])
    else:
        z = 4
    return ((x, y, z), (x + y) * z)

We’ll implement a very thin class to store these intermediate values and to flag stop points/junctions:

class Amb:
    def __init__(self, xs):
        self.xs = xs

Python generators aren’t readily cloneable, so we’ll write a helper to advance through one yield statement for each element of xs, using .send to set the values we want for each Amb expression in order (this is the main source of inefficiency in this implementation):

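def send_n(g, xs):
    try:
        # run to the first yield (or to the end, if nothing is yielded)
        v = g.send(None)
        # replay the choices made so far, one per yield
        for x in xs:
            v = g.send(x)
        return v
    except StopIteration as stop:
        # the generator returned; its return value rides on the exception
        return stop.value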

We catch StopIteration to intercept the final return value from the generator. Now we can put together our amb decorator, which takes an Amb-annotated generator function and returns a function with the nondeterminism effects applied (note that a real-world version of this should probably use functools.wraps or similar to copy relevant metadata from the original function: its docstring, name, etc.):

import itertools
from pprint import pprint

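def amb(f):
    def go(g, xs):
        # replay a fresh generator with the choices in xs
        if isinstance(v := send_n(g(), xs), Amb):
            # junction: run the rest of the function once per alternative, union the results
            return set(itertools.chain.from_iterable(go(g, xs + [x]) for x in v.xs))
        else:
            # the generator returned: a single finished branch
            return {v}

    def r(*args, **kwargs):
        return go(lambda: f(*args, **kwargs), [])

    return r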

This does exactly what we described above: go sends the current “branch” (the ordered list of values to pass to the yields/Ambs) into a fresh generator and receives whatever is yielded or returned. If it’s a plain value, the function has returned, and we wrap the raw return value in a singleton set. If we get an Amb, we have hit a “breakpoint” or junction, at which we evaluate the rest of the code once for each value inside the Amb and take the union of the resulting sets.

If we decorate our test function with amb and evaluate pprint(list(test(5))), we get:

[((1, 5, 4), 24),
 ((2, 4, 4), 24),
 ((3, 2, 4), 20),
 ((1, 2, 4), 12),
 ((3, 4, 4), 28),
 ((2, 5, 4), 28),
 ((3, 5, 4), 32),
 ((1, 4, 4), 20),
 ((2, 2, 0), 0),
 ((2, 2, 10), 40)]

Excellent!

Let’s try recreating our earlier StateT example as faithfully as possible:

@amb
def statet_test(start: int) -> list[list[int]]:
    r = [start]
    for i in range(3):
        rule = yield Amb([lambda x: x + 4, lambda x: x * 2, lambda x: x % 3])
        start = rule(start)
        r.append(start)
    return r

for r in statet_test(1):
    print(r)

Surprisingly, this is about as concise as the Haskell version (albeit much slower)! I’ve temporarily changed the backing type in go from a set to a list, for two reasons: we want to get our results in the same order, and we cannot have a set[list[int]], since lists are unhashable. Either reason alone would justify the switch. Here are the results (matching the ones from the Haskell version):

[1, 5, 9, 13]
[1, 5, 9, 18]
[1, 5, 9, 0]
[1, 5, 10, 14]
[1, 5, 10, 20]
[1, 5, 10, 1]
[1, 5, 2, 6]
[1, 5, 2, 4]
[1, 5, 2, 2]
[1, 2, 6, 10]
[1, 2, 6, 12]
[1, 2, 6, 0]
[1, 2, 4, 8]
[1, 2, 4, 8]
[1, 2, 4, 1]
[1, 2, 2, 6]
[1, 2, 2, 4]
[1, 2, 2, 2]
[1, 1, 5, 9]
[1, 1, 5, 10]
[1, 1, 5, 2]
[1, 1, 2, 6]
[1, 1, 2, 4]
[1, 1, 2, 2]
[1, 1, 1, 5]
[1, 1, 1, 2]
[1, 1, 1, 1]
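The list-backed go used for this example isn't shown above. A sketch, assuming it differs from the set version only in the collection it builds (branch results are concatenated in order rather than unioned); `amb_list` is a hypothetical name used to keep it apart from the set-based decorator:

```python
import itertools


class Amb:
    def __init__(self, xs):
        self.xs = xs


def send_n(g, xs):
    # replay a fresh generator, feeding it the choices made so far
    try:
        v = g.send(None)
        for x in xs:
            v = g.send(x)
        return v
    except StopIteration as stop:
        return stop.value


def amb_list(f):
    def go(g, xs):
        v = send_n(g(), xs)
        if isinstance(v, Amb):
            # concatenate branch results in order instead of taking a set union
            return list(itertools.chain.from_iterable(go(g, xs + [x]) for x in v.xs))
        return [v]  # singleton list: one finished branch

    def r(*args, **kwargs):
        return go(lambda: f(*args, **kwargs), [])

    return r


@amb_list
def statet_test(start: int) -> list[list[int]]:
    r = [start]
    for _ in range(3):
        rule = yield Amb([lambda x: x + 4, lambda x: x * 2, lambda x: x % 3])
        start = rule(start)
        r.append(start)
    return r


results = statet_test(1)
```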

Just to round it out, here’s a more interesting example using string manipulation:

from random import shuffle

@amb
def string_test() -> list[str]:
    a = yield Amb(["where [content] go",
                   "what [content] do",
                   f"{yield Amb(['how', 'why'])} [content] do it"])
    content = (yield Amb(['will', 'did'])) + ' ' + (yield Amb(['you', 'she', 'he']))
    return a.replace('[content]', content) + (yield Amb(["?", "...?"]))

s = string_test()
shuffle(s)
pprint(s[:20])
['where did he go?',
 'what will he do...?',
 'why did you do it?',
 'what did you do...?',
 'why will he do it...?',
 'what will you do...?',
 'where will you go...?',
 'how did he do it?',
 'why did he do it?',
 'where will you go...?',
 'where will she go...?',
 'where did she go?',
 'what did you do?',
 'why will you do it?',
 'why did he do it...?',
 'where did she go...?',
 'how will she do it?',
 'how did she do it?',
 'what did she do?',
 'what will he do?']

If you squint at the definitions, it might be clear that our two branches in go map almost directly onto the bind (concat) and return (singleton) methods of the list monad. Indeed, we could swap in the behavior of a different monad and get the results we expect:

class Maybe[T]:  # PEP 695 generic class syntax (Python 3.12+)
    def __init__(self, x: T, isjust: bool) -> None:
        self.x: T = x
        self.isjust: bool = isjust

    def __str__(self):
        if self.isjust:
            return f"Just({self.x})"
        else:
            return "Nothing"

def send_n(g, xs):
    [elided]

def run(f):
    def go(g, xs):
        if isinstance(v := send_n(g(), xs), Maybe):
            if v.isjust:
                # bind: unwrap the value, send it back in, and continue
                return go(g, xs + [v.x])
            else:
                # Nothing: stop here, skipping the rest of the function
                return Maybe(None, False)
        else:
            # return: wrap the function's plain return value in Just
            return Maybe(v, True)

    def r(*args, **kwargs):
        return go(lambda: f(*args, **kwargs), [])

    return r

@run
def test(a: int) -> Maybe[tuple[tuple[int, int, int], int]]:
    x = yield Maybe(7, True)
    y = yield Maybe(3, True)
    if x == a:
        z = yield Maybe(None, False)
    else:
        z = yield Maybe(x + y - 5, True)
    return ((x, y, z), (x + y) * z)


print(5, test(5))
print(7, test(7))
5 Just(((7, 3, 5), 50))
7 Nothing

In the first example, x == a evaluates to False and the “monadic state” is set to Just(x + y - 5); in the second, it evaluates to True and the state is Nothing, which short-circuits evaluation.

(n.b.: this is just illustrative – another, much better, way to implement something like this if you have a “scalar” short-circuiting monad like Maybe/Option or Either/Result is by using a custom exception handler that, for example, catches exceptions thrown by .unwrap calls on Nothing/Err values and transforms them back into the appropriate type, basically circumventing the need to actually thread handlers for the wrapper type through your function; this has the benefit of handling deeply nested call stacks with little additional effort, and is almost surely faster)
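A minimal sketch of that exception-based alternative, assuming hypothetical `Just`/`Nothing` types with an `unwrap` method and a `maybe_block` decorator (none of these come from a library). Note that the helper `add_xy` contains no Maybe plumbing at all, yet a missing key deep inside it still turns the whole result into Nothing:

```python
import functools
from dataclasses import dataclass
from typing import Any


class _NothingError(Exception):
    """Raised by Nothing.unwrap(); caught and converted by maybe_block."""


@dataclass(frozen=True)
class Just:
    value: Any

    def unwrap(self) -> Any:
        return self.value


class _Nothing:
    def unwrap(self) -> Any:
        # short-circuit: abandon the whole enclosing computation
        raise _NothingError()

    def __repr__(self) -> str:
        return "Nothing"


Nothing = _Nothing()


def maybe_block(f):
    # run f as ordinary code; an unwrap() on Nothing at any call depth aborts
    # it, and the abort is converted back into Nothing
    @functools.wraps(f)
    def wrapper(*args, **kwargs):
        try:
            return Just(f(*args, **kwargs))
        except _NothingError:
            return Nothing

    return wrapper


def lookup(d: dict, key: str) -> "Just | _Nothing":
    # a Maybe-returning dictionary lookup
    return Just(d[key]) if key in d else Nothing


def add_xy(d: dict) -> int:
    # plain helper: no Maybe handling, even though it may short-circuit
    return lookup(d, "x").unwrap() + lookup(d, "y").unwrap()


@maybe_block
def total(d: dict) -> int:
    return add_xy(d) * 2


print(total({"x": 1, "y": 2}))  # Just(value=6)
print(total({"x": 1}))          # Nothing
```

One caveat of this design: a broad `except Exception` anywhere between the failing `unwrap` and the decorator would swallow the signal.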

When we look back at how do-notation desugars in Haskell, the correspondence to the control flow used above is even clearer:

do { x1 <- action1
   ; x2 <- action2
   ; mk_action3 x1 x2 }
action1 >>= (\ x1 -> action2 >>= (\ x2 -> mk_action3 x1 x2 ))

(example from Wikibooks: “Haskell/do notation”; CC BY-SA 4.0)

As a final note, you can probably convince yourself without too much effort that the implicit control flow in cases where we select specific branches (i.e., perform goal-directed search) directly mirrors the backtracking that would occur in, say, an equivalent hand-programmed tree search algorithm, or a typical continuation-based implementation of amb in Lisp. Wikipedia’s description of this process is somewhat more precise:

If all alternatives fail at a particular choice point, then an entire branch fails, and the program will backtrack further, to an older choice point. One complication is that, because any choice is tentative and may be remade, the system must be able to restore old program states by undoing side-effects caused by partially executing a branch that eventually failed.
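The eager go above always explores every branch before returning. Turning go itself into a generator makes the backtracking visible: alternatives are tried in order, a failed guard simply yields nothing so control falls back to the previous choice point, and a caller who only wants the first solution can stop early. A sketch, where `amb_lazy`, `guard`, and `triples` are hypothetical names:

```python
class Amb:
    def __init__(self, xs):
        self.xs = xs


def send_n(g, xs):
    # replay a fresh generator with the choices made so far
    try:
        v = g.send(None)
        for x in xs:
            v = g.send(x)
        return v
    except StopIteration as stop:
        return stop.value


def guard(c: bool) -> Amb:
    # one branch if the condition holds, none otherwise
    return Amb([None]) if c else Amb([])


def amb_lazy(f):
    def go(g, xs):
        v = send_n(g(), xs)
        if isinstance(v, Amb):
            for x in v.xs:  # try each alternative in order...
                yield from go(g, xs + [x])  # ...and fall back here once it is exhausted
        else:
            yield v  # a completed branch: hand the solution to the caller

    def r(*args, **kwargs):
        return go(lambda: f(*args, **kwargs), [])

    return r


@amb_lazy
def triples(n: int):
    x = yield Amb(range(1, n + 1))
    y = yield Amb(range(x + 1, n + 1))
    z = yield Amb(range(y + 1, n + 1))
    yield guard(x * x + y * y == z * z)
    return (x, y, z)


# branches after the first solution are never explored
print(next(triples(50)))  # (3, 4, 5)
```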

Back to the list/set version. This is a nice toy, but is it compatible with more complex control flow? TODO

Scaling it up

This is certainly interesting, but even ignoring the efficiency loss from having to reconstruct the entire function state (by replaying the generator) each time we follow a branch, the performance leaves much to be desired: Python is not a fast language. To illustrate, we’ll implement a naive function that generates the set of Pythagorean triples (x, y, z) with 1 ≤ x < y < z ≤ 200:

def guard(c: bool) -> Amb:
    # one branch (sending None) if the condition holds, zero branches otherwise
    return Amb([None]) if c else Amb([])

@amb
def pythagorean_triples(n: int) -> set[tuple[int, int, int]]:
    x = yield Amb(range(1, n+1))
    y = yield Amb(range(x+1, n+1))  # y > x avoids double-counting
    z = yield Amb(range(y+1, n+1))
    yield guard(x ** 2 + y ** 2 == z ** 2)
    return (x, y, z)

print(len(t := pythagorean_triples(200)))
pprint(t)

time python amb.py gives:

127
{(3, 4, 5),
 (5, 12, 13),
 (6, 8, 10),
 (7, 24, 25),
 (8, 15, 17),
 (9, 12, 15),
 (9, 40, 41),
 (10, 24, 26),
 (11, 60, 61),
 (12, 16, 20),
 (12, 35, 37),
 (13, 84, 85),
 (14, 48, 50),
 (15, 20, 25),

 ...

 python amb.py  6.88s user 0.02s system 99% cpu 6.932 total

Vectorized combinatorial sugar

This is somewhat better than I expected, but still impractical for most real-world problems. For, e.g., Monte Carlo simulations, we want performance within an order of magnitude of C/C++ code (or at least Haskell). The ideal case would be to somehow vectorize the annotated function with minimal input from the user, probably using NumPy or a similar library that provides Python bindings to efficient array operations. In particular, if each step (or bind operation) in our function can be represented by f(a, b, ...), with each argument being some (nondeterministic) expression derived from an Amb term, we want to implicitly generate the Cartesian product of all a_i, b_j, ... and pass it to a vectorized version of f. This will inevitably require some syntactic and semantic tradeoffs relative to the generator-based version. For now, we’ll impose the constraint that our decorated function only takes as inputs, computes with, or returns integers or floats.

Unfortunately, just swapping NumPy arrays into the code we showed earlier (instead of lists/sets) probably wouldn’t be of much use: we’d still end up iterating through every element in native Python. We would need a vectorized version of every operation involved in the code so that we could process many branches in parallel. The other extreme is full vectorization that ignores control flow: after each Amb, we could unroll all the Ambs encountered up to that point, take their Cartesian product, and compute every relevant operation on every combination of inputs, accepting some wasted work in exchange for forgoing extremely expensive native Python logic. In the case of our Pythagorean triple calculator, we don’t lose much, since we need to consider (nearly) every combination of elements from our three Amb expressions anyway (clearly, a simpler collection-based solution like some of the ones shown on Rosetta Code would also work for this problem).
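For the Pythagorean-triple case, that fully vectorized extreme can be sketched directly with NumPy broadcasting rather than with a wrapper class. `pythagorean_triples_np` is a hypothetical name, and the sketch deliberately computes the whole n × n × n cube, including the combinations with y ≤ x that the generator version never visits:

```python
import numpy as np


def pythagorean_triples_np(n: int) -> set[tuple[int, int, int]]:
    # evaluate every (x, y, z) combination at once instead of branch by branch
    sq = np.arange(1, n + 1) ** 2
    # broadcasting builds an (n, n, n) boolean cube:
    # hit[i, j, k] is True iff (i+1)^2 + (j+1)^2 == (k+1)^2
    hit = sq[:, None, None] + sq[None, :, None] == sq[None, None, :]
    i, j, k = np.nonzero(hit)
    # keep x < y to avoid double-counting; z > y then holds automatically
    keep = i < j
    return {(int(a) + 1, int(b) + 1, int(c) + 1)
            for a, b, c in zip(i[keep], j[keep], k[keep])}


t = pythagorean_triples_np(200)
print(len(t))
```

The cube holds n³ booleans (8 MB at n = 200), so memory grows cubically; that is the wasted work traded for staying out of the interpreter.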

If we are willing to dispense for the moment with some more complex control flow, we may as well just replace the generator-based interceptor with a special object that broadcasts over common arithmetic operations; if we do something like Wrapped([1, 2]) + Wrapped([3, 4]), for example, we would expect to get Wrapped([1 + 3, 1 + 4, 2 + 3, 2 + 4]) == Wrapped([4, 5, 5, 6]). This is certainly cleaner than the generator approach: we traverse the code only once, instead of once per branch/combination. The main trouble is with if-statements: we must follow both branches at different places in the array, ideally rewriting the if as an np.where or similar (we can perform this translation manually, but figuring out how to avoid it is an interesting exercise). We’ll come back to that.
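Here is what that manual rewrite looks like for a small, hypothetical function `branchy`. The inputs are first expanded into their Cartesian product with `np.meshgrid` (x-major order); both branch expressions are then evaluated for every combination, and `np.where` selects, per element, which result each combination keeps:

```python
import numpy as np


def branchy(x: int, y: int) -> int:
    # scalar version with an ordinary if-statement
    if x > y:
        return x - y
    else:
        return (x + y) * 2


# Cartesian product of the two ambiguous inputs, flattened x-major
xs, ys = np.meshgrid([1, 3, 5], [2, 4], indexing="ij")
xs, ys = xs.ravel(), ys.ravel()

# both branches are computed for every lane; np.where picks per element
vectorized = np.where(xs > ys, xs - ys, (xs + ys) * 2)
print(vectorized.tolist())  # [6, 10, 1, 14, 3, 1]
```

Because `np.where` evaluates both branch expressions everywhere, a branch that is only safe under its condition (say, a division guarded by `y != 0`) is still computed, and may emit warnings, for the other lanes.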

Let’s create a wrapper class that stands in for scalar values in our program and automatically performs the broadcasting described above:

import numpy as np
import operator

class Amb2:
    def __init__(self, data):
        if not isinstance(data, np.ndarray):
            data = np.array(data)
        self.data = data

def makeop(name: str):
    op = getattr(operator, name)
    # operator.and_/or_ have a trailing underscore; the dunders __and__/__or__ don't
    dunder = name.rstrip('_')

    def tmp(xs: Amb2, ys: Amb2) -> Amb2:
        # Cartesian product of the operands (x-major), then op elementwise
        a, b = np.meshgrid(xs.data, ys.data, indexing='ij')
        return Amb2(op(a.ravel(), b.ravel()))

    setattr(Amb2, f'__{dunder}__', tmp)

for f in ['add', 'mul', 'sub', 'truediv', 'floordiv', 'pow', 'and_', 'or_', 'xor']:
    makeop(f)

def amb(f):
    def r(*args, **kwargs):
        return f(*args, **kwargs).data
    return r

@amb
def test() -> np.ndarray:
    a = Amb2([1, 3, 5])
    b = Amb2([2, 4, 6])
    return a + b

print(test())
[ 3  5  7  5  7  9  7  9 11]

It’s probably also worthwhile to support scalar types mixed into our code without additional effort from the programmer:

def makeop(name: str):
    op = getattr(operator, name)
    dunder = name.rstrip('_')

    def as_amb(v: Amb2 | int | float) -> Amb2:
        # promote plain scalars to one-element Amb2 values
        if isinstance(v, Amb2):
            return v
        assert isinstance(v, (int, float))
        return Amb2([v])

    def tmp(xs: Amb2 | int | float, ys: Amb2 | int | float) -> Amb2:
        a, b = np.meshgrid(as_amb(xs).data, as_amb(ys).data, indexing='ij')
        return Amb2(op(a.ravel(), b.ravel()))

    def rtmp(xs: Amb2, ys: int | float) -> Amb2:
        # reflected form (e.g. 1 - amb calls amb.__rsub__(1)): swap the operands back
        return tmp(ys, xs)

    setattr(Amb2, f'__{dunder}__', tmp)
    setattr(Amb2, f'__r{dunder}__', rtmp)

...

@amb
def test() -> np.ndarray:
    a = Amb2([1, 3, 5])
    b = Amb2([2, 4, 6])
    return 1.5 * (a + b + 2)

print(test())
[ 7.5 10.5 13.5 10.5 13.5 16.5 13.5 16.5 19.5]

Augmented assignment is also supported, even if the left-hand side is not (yet) an Amb2:

@amb
def test2() -> np.ndarray:
    a = 1
    for _ in range(3):
        a += Amb2([2, 3])
    return a

print(test2())
[ 7  8  8  9  8  9  9 10]

For/while loops (and conditionals), however, only work when the loop condition is “primitive” and does not contain any Amb expressions. For example, a while-loop with a condition derived from an Amb would not behave as expected: Amb2 defines no comparison operators (so < and friends raise a TypeError, while == falls back to identity) and does not override truthiness, so such a condition cannot express “true for some branches, false for others”. One could imagine making this work by detecting when we drop into a loop, overriding bool-coercion to keep the loop running until the condition is false for all values in the Amb, and “masking” assignments so that they only affect the members of the target value for which the corresponding values of any ambiguous expressions in context still make the loop condition true (perhaps maintaining a stack of contexts for nested loops/conditional statements), all in an efficient vectorized fashion. This is nontrivial.
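A sketch of that masking idea on a plain NumPy array rather than on Amb2 (the function `masked_while` and its loop body are made up for illustration): the loop keeps running until the condition is false in every lane, and each iteration only updates the lanes whose condition still holds:

```python
import numpy as np


def masked_while(x: np.ndarray) -> np.ndarray:
    # vectorized form of, per element:  while x < 10: x = x * 2
    # (assumes positive inputs, so every lane eventually exits)
    x = x.copy()
    active = x < 10  # lanes still inside the loop
    while active.any():
        # masked assignment: lanes that already exited keep their value
        x = np.where(active, x * 2, x)
        active = x < 10
    return x


print(masked_while(np.array([1, 3, 5, 12])).tolist())  # [16, 12, 10, 12]
```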

The main issue here is that we discard provenance information describing where the values were derived from: a - a, for instance, pairs every value of a with every value of a instead of yielding zeros. We could cache the expression tree used to generate the concrete values for each Amb, but this wouldn’t be of much use if we wanted to, e.g., perform another computation using the same variables and use it to filter the results of the original computation.
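One way to keep provenance, an approach not taken in this post and sketched here only as an assumption, is to give each ambiguous variable its own array axis, so that the shape itself records which choices a value depends on. Reusing the same variable then stays aligned, combining different variables still broadcasts to their Cartesian product, and a second computation on a variable can filter an earlier result:

```python
import numpy as np

# each ambiguous variable gets its own broadcast axis
a = np.array([1, 2, 3]).reshape(3, 1)  # choices for a live on axis 0
b = np.array([10, 20]).reshape(1, 2)   # choices for b live on axis 1

same = a - a   # shape (3, 1): each choice of a meets only itself
cross = a + b  # shape (3, 2): every combination of a and b

# a second computation on the same variable, used to filter the first result
keep = np.broadcast_to(a % 2 == 1, cross.shape)

# for contrast, treating the two uses of a as independent choices (as a plain
# Cartesian-product wrapper would) produces 9 values, most of them spurious
flat = (np.array([1, 2, 3])[:, None] - np.array([1, 2, 3])[None, :]).ravel()

print(same.ravel().tolist())  # [0, 0, 0]
print(cross[keep].tolist())   # [11, 21, 13, 23]
print(flat.tolist())          # [0, -1, -2, 1, 0, -1, 2, 1, 0]
```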

Let’s return to our earlier generator-based decorator, which handles arbitrary Python control flow more precisely.

TODO

End

I hope you have learned something useful (or at least entertaining) from this post, or at least found some of the links therein interesting. Thanks for reading!