Probabilistic Programming with Typed Traces in Lean

Paul Horsfall

June 2025

Introduction

I’ve been experimenting with implementing ideas from probabilistic programming in a dependently typed language. This article documents one particular path I’ve explored in Lean.

In particular, it describes a DSL in which a probabilistic program’s sample space is reflected in its type, and then uses this language to express some very simple models and inference algorithms.

Background

I’m interested in probabilistic programming languages for Bayesian inference. In this setting, a probabilistic program is written to describe a probabilistic model (assumed to describe some process of interest), data generated by the process are combined with the model as “observations”, then (approximate) inference is performed to obtain posterior beliefs.

In this work, a probabilistic program is a regular functional program augmented with an effectful flip operation, taking a success probability as argument and returning a Bool.

Intuitively, we interpret such a program as a stochastic procedure that draws samples from a Bernoulli distribution at flip statements, and accumulates these sampled values in a data structure called a “trace”. More generally, a probabilistic program can be thought of as representing a probability distribution over such traces. We can write interpreters to e.g. sample from this distribution, or compute the probability density of a trace under the distribution, etc.1

Here, probabilistic programs also have a factor operation, which is used to weight samples from the program. This is the basic mechanism on which the ability to incorporate observed data is built.

Inference procedures can be built compositionally from the parts just described. My focus here is on the use of types to help ensure the correctness of such compositions.

import Lean.Data
set_option autoImplicit true

Traces

I’m primarily interested in statically describing the possible shapes that the trace generated by a given probabilistic program can take. (Equivalently, the support of the probability distribution over traces.) When performing inference for a probabilistic program, we often have to give an auxiliary probabilistic program (proposal, guide, etc.), and these two programs need to be compatible in some sense. Often this is checked at run-time, but here I’m trying to do this statically with types.

Intuitively, a program that flips two coins might generate a trace with type something like Vec Bool 2, capturing both the number of random choices made and their types.

More generally, a program might sample from primitive distributions of differing types, so we’ll use heterogeneous lists as the basis of the trace representation.

-- typically called `HList`, but i'm calling it `Tr` to be suggestive
-- of "trace"
abbrev Tr : List Type -> Type
  | [] => Unit
  | t::ts => t × Tr ts

example : Tr [Bool, Bool] := (true, false, ())
example : Tr [Float, Bool] := (0.0, false, ())

One of the appeals of probabilistic programs is that they can describe complex distributions with structural choices. For example, a program might first flip a coin, and then flip a second coin only if the first came up heads. I accommodate this with the following type:

-- a branching structure
structure Br (s t : List Type) where
  val : Bool
  sub : match val with | true => Tr s | false => Tr t

infixr:55 " <> " => Br

-- i'm defining `Br` rather than using the general dependent pair type
-- `Sigma` in order to give a `toString` instance.
instance [ToString (Tr s)] [ToString (Tr t)] : ToString (Br s t) where
  toString | ⟨true, tr⟩ => s!"⟨true, {tr}⟩"
           | ⟨false, tr⟩ => s!"⟨false, {tr}⟩"

The following examples show how this type can be used to capture the shape of traces generated by the previous example. The branch itself (<>) introduces a Bool value into the trace, corresponding to the first coin flip. When this comes up heads (true) we take the left branch, which demands we give a second Bool; the result of the second coin flip. When the first flip comes up tails (false), we take the right branch, which is empty.

example : [Bool] <> [] := ⟨true , (true, ())⟩
example : [Bool] <> [] := ⟨true , (false, ())⟩
example : [Bool] <> [] := ⟨false, ()⟩

When the primitive types within a trace take on finitely many values, then so does the trace itself, as witnessed by this typeclass:

class Finite (a : Type) where
  enumerate : List a

instance : Finite Bool where
  enumerate := [true, false]

instance : Finite Unit where
  enumerate := [()]

def product (xs : List a) (ys : List b) : List (a × b) :=
  xs.flatMap fun x => ys.map fun y => (x, y)

-- we need instances for nested products, which `Tr` reduces to,
-- rather than `Tr` itself.
open Finite in
instance [Finite a] [Finite b] : Finite (a × b) where
  enumerate := product enumerate enumerate

open Finite in
instance [Finite (Tr s)] [Finite (Tr t)] : Finite (s <> t) where
  enumerate := enumerate.map (⟨true, ·⟩) ++ enumerate.map (⟨false, ·⟩)

This shows (without proof) that the three values given in the previous examples are the only values that inhabit [Bool] <> []:

example : (Finite.enumerate : List ([Bool] <> [])) =
        [⟨true,  (true,  ())⟩,
         ⟨true,  (false, ())⟩,
         ⟨false, ()⟩]
        := rfl
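
The same instances compose for larger trace types. As a rough sketch (assuming the definitions above are in scope), counting the enumerated traces shows how products multiply and branches add:

```lean
-- two independent Bools: 2 * 2 = 4 traces
#eval (Finite.enumerate : List (Tr [Bool, Bool])).length

-- a branch contributes the traces of both sides: 4 + 2 = 6
#eval (Finite.enumerate : List ([Bool, Bool] <> [Bool])).length
```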

Syntax

I’m using a variation on the Freer monad to represent the syntax of probabilistic programs. Since I’m not concerned with re-use right now, I hard-code this Freer-like structure to use one particular set of primitive operations.

The operations are the aforementioned flip and factor, plus a branching operation cond to introduce the trace branching structure described above.

Under the sampling semantics, cond p q flips a fair coin and uses the result to choose whether to sample from program p or q. (Similar to Lean’s built-in cond : Bool -> a -> a -> a.) While p and q must have a common return type, they can have different trace types. This can’t be implemented with if, since it would require p and q to also have the same trace type.

With Freer, the datatype for operations typically records the type of any arguments as well as the return type of the operation. Here, Op is extended to also (optionally) record the type of the value that the operation extends the trace with. e.g. flip extends the trace with a Bool, and factor does not extend the trace.

Similarly, P is extended to record the type of the trace associated with the program. In particular, pure (aka return) produces an empty trace, and the step constructor, which extends a program with an operation, simply extends the trace with the type indicated by the operation, if present.

def optcons : Option a -> List a -> List a
  | .none, xs => xs
  | .some x, xs => x :: xs

infixr:55 " :? " => optcons
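
To make the behaviour of optcons concrete, here is a small sketch (the element types are arbitrary choices for illustration): an operation that records a type extends the list, and one that records nothing leaves it unchanged.

```lean
-- `some t` conses `t` onto the trace type
example : (some Bool) :? [Nat] = [Bool, Nat] := rfl

-- `none` leaves the trace type as it was
example : (none : Option Type) :? [Nat] = [Nat] := rfl
```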

mutual

  -- primitive operations
  inductive Op : Type -> Option Type -> Type 1 where
    | flip : Float -> Op Bool (.some Bool)
    | factor : Float -> Op Unit .none
    | cond : P s a -> P t a -> Op a (.some (s <> t))

  -- Freer-like structure
  inductive P : List Type -> Type -> Type 1 where
    | pure : a -> P [] a
    | step : Op a t -> (a -> P ts b) -> P (t :? ts) b

end

We can make probabilistic programs for each of the primitive operations:

def P.flip (pr : Float) : P [Bool] Bool :=
  step (.flip pr) pure

def P.factor (pr : Float) : P [] Unit :=
  step (.factor pr) pure

def P.cond (p : P s a) (q : P t a) : P [s <> t] a :=
  step (.cond p q) pure

The presence of an additional type index means P can’t be a Monad2. However, we can still give a sensible definition of bind:

theorem optcons_append : s :? (t ++ u) = (s :? t) ++ u :=
    by cases s; unfold optcons; rfl; rfl

def P.bind : P s a -> (a -> P t b) -> P (s ++ t) b
  | pure x, k => k x
  | step op k, k' =>
     -- cast `P (s :? (t ++ u)) a` to `P ((s :? t) ++ u) a`
     congrArg (P . _) optcons_append ▸
     step op fun x => (k x).bind k'

The implementation here is identical to the implementation of bind for Freer; only the types differ.

We now have enough machinery to be able to write down some probabilistic programs3. Here’s a simple medical diagnosis example4:

open P in
def medical : P [Bool, Bool] Bool :=
  flip 0.01 |>.bind λ hasDisease =>
  let testProb := if hasDisease then 0.8 else 0.096
  flip testProb |>.bind λ testResult =>
  pure hasDisease
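
The branching example from the Traces section can be written in the same style. This is only a sketch: flipTwiceIfHeads is a name I made up, and because cond flips the first coin itself, that coin is necessarily fair here.

```lean
-- `cond` makes the first (fair) flip. on `true` we flip again, on
-- `false` we return without making any further choices.
def flipTwiceIfHeads : P [[Bool] <> []] Bool :=
  P.cond (P.flip 0.5) (P.pure false)
```

The trace type [[Bool] <> []] is the branching type shown earlier, now wrapped in a one-element list because cond contributes a single entry to the trace.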

Interpreters

In this section I implement several interpreters for probabilistic programs.

The first produces a stochastic procedure that samples traces from the distribution described by the program. This is implemented as a state monad that maintains the state of a random number generator. factor is interpreted as weighting the generated samples, which is tracked by stacking a second state monad.

abbrev SampW := StateT Float (StateM StdGen)

-- sample a value in [0,1) (in base monad)
def _uniform : StateM StdGen Float := fun g =>
  let N := stdRange.snd
  let (n, g') := stdNext g
  ((n-1).toFloat / N.toFloat, g')

-- lift `_uniform` into transformer stack. (there might be a more
-- idiomatic way to do this?)
def SampW.uniform : SampW Float := _uniform

def SampW.flip (pr : Float) : SampW Bool :=
  return (<- uniform) < pr

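-- accumulate the weight in log space: a weight of 0 corresponds to
-- probability 1, and each `factor` adds `log pr`
def SampW.weight (pr : Float) : SampW Unit :=
  .modifyGet λ w => ((), w + .log pr)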

-- run a monadic computation multiple times, collecting the results in
-- a list
def replicate [Monad m] : Nat -> m a -> m (List a)
  | 0, _ => return []
  | i+1, act => return (<- act) :: (<- replicate i act)

-- run a sampler once, and throw out the weight
def SampW.run1 (seed : Nat) (m : SampW a) : a :=
  m |> (StateT.run' · 0)
    |> (StateT.run' · (mkStdGen seed))
    |> Id.run

-- run a sampler `n` times, returning `n` weighted samples
def SampW.run (seed : Nat) (n : Nat) (m : SampW a) : List (a × Float) :=
  m |> (StateT.run · 0)
    |> replicate n
    |> (StateT.run' · (mkStdGen seed))
    |> Id.run

With that we can implement the sampler. In order to implement this compositionally we need the sampler to generate return values as well as traces. Something similar crops up in most (perhaps all) of the interpreters we implement.

def P.sampler' : P t a -> SampW (a × Tr t)
  | pure x => return (x, ())
  | step (.flip pr) k => do
         let b <- .flip pr
         let (x, tr) <- sampler' (k b)
         return (x, (b, tr))
  | step (.factor pr) k => do
         .weight pr
         sampler' (k ())
  | step (.cond p q) k => do
         -- sample a bool to determine branch to follow
         let b <- .flip 0.5
         match b with
         -- run left branch, then the rest of the program
         | true => let (x, tr) <- sampler' p
                   let (y, tr') <- sampler' (k x)
                   return (y, (⟨true, tr⟩, tr'))
         -- run right branch, then the rest of the program
         | false => let (x, tr) <- sampler' q
                    let (y, tr') <- sampler' (k x)
                    return (y, (⟨false, tr⟩, tr'))

-- a variant that drops the return value, leaving just the trace
def P.sampler (p : P t a) : SampW (Tr t) :=
  do let (_, tr) <- p.sampler'; return tr
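
A usage sketch, assuming the medical program and the run helpers defined above (the seed is an arbitrary choice, and the sampled values depend on it):

```lean
-- a single draw from the prior, of type `Tr [Bool, Bool]`
#eval medical.sampler.run1 (seed := 0)

-- three draws, each paired with its accumulated weight
#eval medical.sampler.run (seed := 0) (n := 3)
```

Since medical contains no factor statements, every weight in the second result stays at its initial value of 0. Note that SampW.weight adds log pr, so these weights live in log space.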

A second interpreter produces the density function for the distribution over traces.

def bernoulliDensity (pr : Float) : Bool -> Float
  | true => pr | false => 1 - pr

def P.density' : P t a -> Tr t -> (a × Float)
  | pure x, () => (x, 1.0)
  | step (.flip pr) k, (x, tr) =>
          let (y, pr') := density' (k x) tr
          -- probability of drawing `x` from a Bernoulli * density of
          -- rest of program
          (y, bernoulliDensity pr x * pr')
  | step (.factor pr) k, tr =>
         let (x, pr') := density' (k ()) tr
         (x, pr * pr')
  | step (.cond p q) k, (⟨b, tr⟩, tr') =>
         -- compute the density of the appropriate branch, given the
         -- bool in the trace
         let (x, pr) := match b with
         | true => density' p tr
         | false => density' q tr
         -- compute the density of the rest of the program
         let (y, pr') := density' (k x) tr'
         (y, 0.5 * pr * pr')

def P.density (p : P t a) (tr : Tr t) : Float :=
  p.density' tr |>.snd
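
As a quick check against the medical example (a sketch, assuming medical is in scope), the density of a trace is the product of the Bernoulli probabilities of its entries:

```lean
-- 0.01 (has disease) * 0.8 (positive test given disease)
#eval medical.density (true, true, ())

-- 0.99 (no disease) * 0.096 (false positive)
#eval medical.density (false, true, ())
```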

Finally, we can also interpret a probabilistic program as a function from traces to return values.

def P.retval : P t a -> Tr t -> a
  | pure x, () => x
  | step (.flip _) k, (x, tr) => retval (k x) tr
  | step (.factor _) k, tr => retval (k ()) tr
  | step (.cond p q) k, (⟨b, tr⟩, tr') =>
         let x := match b with | true => retval p tr | false => retval q tr
         retval (k x) tr'
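
For instance (a small sketch using medical from above), distinct traces can map to the same return value, because medical returns only the first flip:

```lean
-- both traces have `hasDisease = true`, so both evaluate to `true`
#eval medical.retval (true, true, ())
#eval medical.retval (true, false, ())
```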

Observing Data

When performing Bayesian inference we typically want to “condition” the model on observed data.

A direct approach to this is to incorporate the observed data into the model itself. Intuitively, to observe the result of a flip, we want to replace the flip with a factor statement that weights the execution by the probability of the observed value.

We can capture this pattern with this program:

def P.flipObs (pr : Float) (val : Bool) : P [] Bool :=
  factor (bernoulliDensity pr val) |>.bind λ () =>
  pure val

For example, to add the observation that testResult came out true to the medical example, we can re-write it like so:

open P in
def medicalPosTest : P [Bool] Bool :=
  flip 0.01 |>.bind λ hasDisease =>
  let testProb := if hasDisease then 0.8 else 0.096
  -- flip testProb |>.bind λ testResult =>
  flipObs testProb true |>.bind λ testResult =>
  pure hasDisease

This approach is straight-forward to implement, but it can be fiddly to use. For this reason, practical systems often offer some way to describe observed data (aka condition the model) without having to modify the model’s source code.

One way to implement that idea here is to treat conditioning as a program transform, applying exactly the transform we applied manually above.

Knowing the type of the program’s trace allows us to ensure that only sensible observations can be made. e.g. You can’t observe a value at a non-existent flip, or observe anything other than a Bool from a flip.

In particular, to make an observation one has to give a (type-safe) index into the program’s trace to indicate which choice is being observed, and an observed value of the corresponding type. The result is a program in which the observed choice no longer appears in the trace type.

-- remove an element from a list by index
def wo : (xs : List a) -> Fin xs.length -> List a
  | _ :: xs, ⟨0, _⟩ => xs
  | x :: xs, ⟨i+1, h⟩ => x :: wo xs ⟨i, by simp at h; exact h⟩

example : wo ["a", "b", "c"] ⟨1, by simp⟩ = ["a", "c"] := rfl

-- note: we can't make an observation of a `pure` computation. the
-- types ensure we can skip the `pure` case altogether.
def P.obs : P t a -> (i : Fin t.length) -> t.get i -> P (wo t i) a
  -- replace `flip` with `factor`, continue with observed value.
  -- (analogous to `flipObs`.)
  | step (.flip pr) k, ⟨0, _⟩, val =>
         step (.factor (bernoulliDensity pr val)) fun () => k val
  -- skip over this `flip`, pushing observation into continuation.
  | step (.flip pr) k, ⟨i+1, _⟩, val =>
         step (.flip pr) fun x => (k x).obs ⟨i, _⟩ val
  -- there's no value that could be observed at factor. we just need
  -- to push the observation into the continuation.
  | step (.factor pr) k, ⟨i, _⟩, val =>
         step (.factor pr) fun () => (k ()).obs ⟨i, _⟩ val
  -- here we're observing an entire sub-trace, so we compute its
  -- density under the model indicated by the first value in the
  -- sub-trace. then replace the `cond` with `factor`, much as we did
  -- for `flip`.
  | step (.cond p q) k, ⟨0, _⟩, ⟨b, tr⟩ =>
         let (x, pr) := match b with
                        | true => density' p tr
                        | false => density' q tr
         step (.factor (0.5 * pr)) fun () => k x
  -- keep going.
  | step (.cond p q) k, ⟨i+1, _⟩, val =>
         step (.cond p q) fun x => (k x).obs ⟨i, _⟩ val

With this, we can express medicalPosTest without modifying or rewriting medical like so:

def medicalPosTest' : P [Bool] Bool :=
  medical.obs ⟨1, by simp⟩ true
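
The same transform works at any index in the trace. As a sketch (medicalHasDisease is a hypothetical name, not part of the development above), observing the first flip instead removes hasDisease from the trace and leaves only the test result:

```lean
-- hypothetical example: observe `hasDisease = true` at index 0
def medicalHasDisease : P [Bool] Bool :=
  medical.obs ⟨0, by simp⟩ true

-- both evaluate the joint density 0.01 * 0.8: the observed flip has
-- become a `factor`, so it still contributes to the density
#eval medicalHasDisease.density (true, ())
#eval medical.density (true, true, ())
```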

Inference

Enumeration

As a warm-up, I’ll implement exhaustive enumeration for probabilistic programs. i.e. We’ll reify the exact distribution over traces represented by the program. This is only possible when the set of traces the program can generate is finite.

This is simple to express in terms of machinery we’ve already assembled. We simply enumerate all possible traces, compute the density of each, then normalise the resulting distribution.

def normalize (xs : List (a × Float)) : List (a × Float) :=
  let z := xs.foldr (fun (_, p) p' => p + p') 0
  xs.map fun (x, p) => (x, p / z)

def P.enumerate [Finite (Tr t)] (p : P t a) : List (Tr t × Float) :=
  let support := Finite.enumerate
  support.map (fun tr => (tr, p.density tr)) |> normalize

Here’s what the distribution on traces looks like for the medical and medicalPosTest examples.

/-- info: [((true, true, ()),   0.008000),
           ((true, false, ()),  0.002000),
           ((false, true, ()),  0.095040),
           ((false, false, ()), 0.894960)] -/
#guard_msgs (whitespace := lax) in
#eval medical.enumerate

/-- info: [((true, ()),  0.077640),
           ((false, ()), 0.922360)] -/
#guard_msgs (whitespace := lax) in
#eval medicalPosTest.enumerate

It’s sometimes convenient to think about the (marginal) distribution on program return values rather than traces. To implement this, we use the retval interpretation to map traces to return values5. Since retval is not injective (in general), we need a notion of equality in order to sum the probabilities of paths returning the same value.

open Lean (AssocList)

def update [BEq a] (xs : AssocList a Float) (x : a) (pr : Float)
  : AssocList a Float :=
    match xs.find? x with
    | .some pr' => xs.replace x (pr + pr')
    | .none => xs.insert x pr

def marginalise [BEq b] (f : a -> b) (xs : List (a × Float))
  : List (b × Float) :=
    xs.foldr (fun (x, pr) ys => update ys (f x) pr) .empty |>.toList

-- marginal distribution on return values
def P.marginal [Finite (Tr t)] [BEq a] (p : P t a) : List (a × Float) :=
  p.enumerate |> marginalise p.retval

We can use this to see how the prior probabilities on hasDisease compare to the posterior probabilities having observed a positive test result in the medical example:

/-- info: [(true, 0.010000), (false, 0.990000)] -/
#guard_msgs in
#eval medical.marginal

/-- info: [(true, 0.077640), (false, 0.922360)] -/
#guard_msgs in
#eval medicalPosTest'.marginal

Markov Chain Monte Carlo

For a more interesting example we can implement MCMC inference using a version of Metropolis-Hastings.

The basic idea is to generate samples by performing a random walk in the space of traces. We’ll have a probabilistic program propose moves, and use the Metropolis-Hastings acceptance rule to ensure that the resulting distribution on traces converges to the probability distribution represented by a target program.

The core of the implementation is a stochastic procedure mh that takes a single step in such a walk. (i.e. Given the current trace, it samples the next trace in the walk.) mh is parameterised by target program p and proposal q. The proposal is a conditional distribution over traces, represented as a function from traces to probabilistic programs.

The implementation makes use of the sampling and density interpretations. The type system ensures that the target and proposal are defined over the same space of traces6.

-- note: generates un-weighted samples
def mh (p : P t a) (q : Tr t -> P t b) (old : Tr t) : SampW (Tr t) := do
  -- sample from the proposal
  let new <- (q old).sampler
  let rho := (p.density new * (q new).density old)
           / (p.density old * (q old).density new)
  let alpha := min 1 rho -- acceptance probability (avoids shadowing `p`)
  if <- .flip alpha
    then return new -- accept
    else return old -- reject

As an example, we can perform MCMC (posterior) inference for the medical diagnosis program. medicalPosTest has a trace with type [Bool], so we’ll write a proposal that negates the single Bool in this trace with high probability.

def negateOneBool : Tr [Bool] -> P [Bool] Bool
  | (true, ()) => P.flip 0.01
  | (false, ()) => P.flip 0.99
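
To see what the acceptance rule does with this proposal, here is a sketch that evaluates the ratio rho from mh by hand for a single move from (false, ()) to (true, ()). It uses the hand-written medicalPosTest as the target.

```lean
#eval
  let old : Tr [Bool] := (false, ())
  let new : Tr [Bool] := (true, ())
  -- (0.008 * 0.99) / (0.09504 * 0.99): the proposal terms cancel
  (medicalPosTest.density new * (negateOneBool new).density old)
    / (medicalPosTest.density old * (negateOneBool old).density new)
```

The proposal is symmetric in this case, so the ratio reduces to 0.008 / 0.09504, roughly 0.084. A proposed move from false to true is therefore accepted about 8% of the time, while the reverse move has a ratio above 1 and is always accepted.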

We need a little more machinery in order to perform inference:

def iterate [Monad m] (f : a -> m a) (x : a) : Nat -> m (List a)
  | 0 => return []
  | n+1 => do let y <- f x
              let ys <- iterate f y n
              return y :: ys

def expectation (f : a -> Float) (xs : List a) : Float :=
  xs.foldr (fun x z => f x + z) 0 / xs.length.toFloat

def probOf [DecidableEq a] (x : a) : List a -> Float :=
  expectation fun x' => if x = x' then 1 else 0

We can now use MCMC to compute posterior expectations for the medical diagnosis example. For simplicity I manually specify an initial trace from which to start the walk.

def trace0 := (false, ())
def step := mh medicalPosTest' negateOneBool
def mcmc := iterate step trace0 1000
def samples := mcmc.run1 (seed := 0)

/-- info: 0.081000 -/
#guard_msgs in
#eval probOf (true, ()) samples

Example

Truncated Geometric

The geometric distribution, implemented by repeatedly flipping coins and counting the number of tails seen before a head is obtained, is a classic example of a probabilistic program that makes structural choices. We can’t implement an unbounded geometric distribution this way in this language, but we can implement a truncated geometric.

I’ll first describe the space of traces generated by such a geometric distribution, truncated to n steps:

abbrev Geometric : (n : Nat) -> List Type
  | 0 => []
  | n+1 => [[] <> Geometric n]
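
Unfolding the definition for a small n may help in reading this type. A sketch, assuming the definitions above:

```lean
-- each level nests one more branch inside the `false` side
example : Geometric 2 = [[] <> [[] <> []]] := rfl

-- with n = 1 the only choice is the branch itself
example : Tr (Geometric 1) := (⟨true, ()⟩, ())
example : Tr (Geometric 1) := (⟨false, ()⟩, ())
```

The true side is always the empty list because the program stops as soon as a head is seen. The false side carries the trace of the remaining, shorter geometric.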

We can enumerate this type for some particular n, and see that all but one of the traces look like a list of Bools comprising some number of false entries (possibly zero) terminated with a true. The exception is the last trace, where because of truncation, the trace is comprised entirely of false entries.

/-- info:
  [(⟨true, ()⟩, ()),
   (⟨false, (⟨true, ()⟩, ())⟩, ()),
   (⟨false, (⟨false, (⟨true, ()⟩, ())⟩, ())⟩, ()),
   (⟨false, (⟨false, (⟨false, (⟨true, ()⟩, ())⟩, ())⟩, ())⟩, ()),
   (⟨false, (⟨false, (⟨false, (⟨false, ()⟩, ())⟩, ())⟩, ())⟩, ())]
-/
#guard_msgs (whitespace := lax) in
#eval (Finite.enumerate : List (Tr (Geometric 4)))

It’s perhaps most natural to implement the (truncated) geometric recursively:

theorem append_nil : xs ++ [] = xs := by simp

open P in
def geometric : (n : Nat) -> P (Geometric n) Nat
  | 0 => pure 0
  | n+1 => cond -- branch on a flip
                -- `true` branch
                (pure 0)
                -- `false` branch
                -- cast `P (Geometric n ++ []) Nat` to `P (Geometric n) Nat`
                (congrArg (P . _) append_nil ▸
                (geometric n |>.bind λ m => -- recursive call
                 pure $ m + 1))             -- +1 for the `false` we got here

As a sanity check we can look at the marginal distribution on return values for this program:

/-- info: [(0, 0.500000),
           (1, 0.250000),
           (2, 0.125000),
           (3, 0.062500),
           (4, 0.062500)] -/
#guard_msgs (whitespace := lax) in
#eval (geometric 4).marginal

Next Steps

I more or less achieved what I hoped with this project, though in its current state this is far from being a practical tool. Here are some specific areas that I might look to improve in future work.

Having cond in the language allows models beyond Bayesian networks to be expressed, but it’s likely there are many probabilistic programs of interest that are not expressible.

Having to specify observations by giving an index into the trace is cumbersome. Some systems would allow flip statements to be labeled with human readable names. Reflecting this in a trace’s type seems to require something like row types. As an alternative I’m curious whether human readable names can be layered on top of the current implementation as an optional convenience.

When observing the result of cond, we have to give a Bool as well as an entire sub-trace for the corresponding branch. A more fine-grained approach would be to only require the Bool to be given, and have the corresponding sub-trace merged into the top-level trace.

Requiring a target and auxiliary program to have a common trace type is overly strict. One step toward loosening this might be to allow different types that are related by a permutation.

The algorithms implemented here are all relatively naive. It would be interesting to see whether more sophisticated algorithms (such as those requiring automatic differentiation) can be implemented in this setting.

Related Work

The trace types paper of Lew et al. covers similar ground and much more besides; it was my main inspiration. I also draw heavily on my experience with WebPPL and Pyro. The Metropolis-Hastings implementation is based on the approach in Gen.


  1. In a more practical system flip would be replaced by sample which would allow draws from any primitive distribution. I’m restricting myself to only supporting draws from a Bernoulli distribution for simplicity. I expect that relaxing this would be straight-forward.

  2. Though it may be a “graded monad”, since the extra type index combines monoidally.

  3. Without a Monad instance we can’t use do notation, though it would be possible to create custom notation using macros.

  4. This example is taken from http://probmods.org/chapters/conditioning.html

  5. This isn’t an efficient approach. We compute the return value when computing the density, then throw it out, only to recompute it here.

  6. This isn’t sufficient to ensure correctness. e.g. q needs to represent a normalized distribution, i.e. not call factor, which isn’t enforced by the types. I’ve intentionally not addressed that here in order to keep the focus on traces.