December 15, 2021

Bean Machine: Composable, Fast Probabilistic Inference on PyTorch

By: Rodrigo de Salvo Braz, JP Chen, Brad Cottel, Johann George, Sweta Karlekar, Lily Li, Feynman Liang, Eric Lippert, Sepehr Akhavan Masouleh, Erik Meijer, Stephanie Nail, Neeraj Pradhan, Kinjal Shah, Todd Small, Walid Taha, Michael Tingley, Narges Torabi, Xiaoyan Wang, Stefan Webb, Zitong Zhou

Today, we’re excited to announce an early beta release of Bean Machine, a PyTorch-based probabilistic programming system that makes it easy to represent and to learn about uncertainties in the machine learning models that we work with every day.

Bean Machine enables you to develop domain-specific probabilistic models and to learn about unobserved properties of those models with automatic, uncertainty-aware learning algorithms. Compared to other machine learning approaches, probabilistic modeling offers numerous benefits:

  1. Uncertainty estimation. Predictions are quantified with reliable measures of uncertainty in the form of probability distributions. An analyst can understand not only the system's prediction, but also the relative likelihood of other possible predictions.
  2. Expressivity. It's easy to encode a rich model directly in source code. This allows one to match the structure of the model to the structure of the problem.
  3. Interpretability. Because the model matches the domain, one can query intermediate learned properties within the model. This means users are not just working with a “black box” but can interpret why a particular prediction was made, and in turn this can aid them in the model development process.

Though powerful, probabilistic modeling does take some getting used to. If this is your first exposure to the topic, we welcome you to check out a short overview of the concept in the Fabulous Adventures in Coding blog.

We on the Bean Machine development team believe that the usability of a system forms the bedrock for its success, and we’ve taken care to center Bean Machine’s design around a declarative philosophy within the PyTorch ecosystem. This, we hope, makes using Bean Machine simple and intuitive — whether that’s authoring a model, or advanced tinkering with its learning strategies. A declarative style means that data scientists and ML engineers can simply write out the math for their model directly in Python, and allow Bean Machine to do the hard work of inferring the possible distributions for predictions based on this declaration of their model.

Probabilistic modeling involves a few major steps: (1) modeling; (2) handling data; (3) learning; and (4) analysis. In the next sections, we’ll give a brief tour of how to perform each of these in Bean Machine. This post is just a light overview, and we welcome you to check out the website for more information on each of these sections.

Before we dive in, please be aware that Bean Machine is still in a beta state. APIs may change rapidly as the language evolves, and we welcome feedback (and pull requests, when appropriate) along the way.

Modeling

Bean Machine’s modeling revolves around the concept of a “generative model.” A generative model is a domain-specific probabilistic model that describes the underlying process for an area under study, before any data has been seen. Though Bean Machine models can use arbitrary Python code, they are primarily composed of declarations of random variables: uncertain quantities that represent either unknown values or observed values.

While we’ve tried to keep Bean Machine’s syntax as accessible as possible, probabilistic modeling does require some understanding of probability distributions and statistical modeling. If this is your first exposure to these concepts, we highly recommend the excellent YouTube series Statistical Rethinking, or the free, online tutorial Bayesian Methods for Hackers.

To illustrate Bean Machine’s syntax, let’s consider a Gaussian mixture model. For instance, let’s say we observe temperature readings from two cities for the last 10 days. We want to learn the temperature profile of those cities, but we don’t know which city each reading came from. You could write this as a generative model in Bean Machine:

```python
import beanmachine.ppl as bm
import torch
import torch.distributions as dist

city_count = 2

@bm.random_variable
def city_temp_mean(city_index: int):
    """Temperature mean in deg Fahrenheit for the given city."""
    return dist.Normal(70, 50)

@bm.random_variable
def city_temp_precision(city_index: int):
    """Temperature precision for the given city; its reciprocal is used as the std dev below."""
    return dist.Gamma(1.0, 1.0)

@bm.random_variable
def city_assignment(i: int):
    """Index for the hypothesized city that temperature i came from."""
    return dist.Categorical(torch.ones((city_count,)))

@bm.random_variable
def city_temp(i: int):
    """Temperature reading i, in deg Fahrenheit."""
    city_index = city_assignment(i).item()
    return dist.Normal(
        city_temp_mean(city_index), 1.0 / city_temp_precision(city_index)
    )
```

In this model, we assume temperatures are observed from individual cities, and that each city has its own temperature profile, specified as a Normal distribution with a city-specific mean and standard deviation.

Why do we call models built in this way generative? Generative means that the model is not only useful for analyzing a problem of interest, but also for generating data according to an assumed underlying process. That is, the model we build—which is our understanding of that underlying process—allows us to simulate the range of outcomes or predictions we might see. This enables us to measure uncertainty in our predictions and, ultimately, learn probability distributions for values of interest.
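To make “generative” concrete, here is a minimal, stdlib-only sketch (it does not use Bean Machine) that runs the temperature model’s story forward: draw per-city parameters from the same priors, then draw a city and a reading for each observation. The function name simulate_readings and the use of Python’s random module are our own illustrative choices. Note that Python’s gammavariate takes a scale parameter equal to 1 / rate, so torch’s Gamma(1.0, 1.0) corresponds to gammavariate(1.0, 1.0).

```python
import random
from typing import List, Tuple

CITY_COUNT = 2  # mirrors city_count in the model above


def simulate_readings(
    num_readings: int, seed: int = 0
) -> Tuple[List[float], List[float], List[Tuple[int, float]]]:
    """Run the temperature model's generative story forward once."""
    rng = random.Random(seed)
    # Per-city parameters drawn from the same priors as the model.
    means = [rng.gauss(70.0, 50.0) for _ in range(CITY_COUNT)]
    # torch's Gamma(1.0, 1.0) has rate 1.0; gammavariate takes scale = 1 / rate.
    precisions = [rng.gammavariate(1.0, 1.0) for _ in range(CITY_COUNT)]
    readings = []
    for _ in range(num_readings):
        city = rng.randrange(CITY_COUNT)  # uniform Categorical over cities
        # As in the model, 1 / precision is used as the Normal's standard deviation.
        temp = rng.gauss(means[city], 1.0 / precisions[city])
        readings.append((city, temp))
    return means, precisions, readings
```

Calling it with different seeds yields different plausible datasets; looking at many such draws is one way to see the range of outcomes the model considers plausible before seeing any data.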

There are a few things to notice. First, Bean Machine uses a declarative style: every random quantity in the model corresponds to a Python function that is decorated with @bm.random_variable and returns a distribution. And although Bean Machine uses a declarative style, arbitrary Python code, including stochastic control flow, is allowed within a random variable function.

Furthermore, random variable functions have parameters that are used to define their logical identity. For example, city_temp(2) represents the temperature reading with index 2, and city_assignment(2) represents the city that that temperature was drawn from.

You’ll also notice that we’ve defined this model at the top level (e.g., it could be in a Jupyter Notebook). We’ve done this for maximum prototyping convenience; however, you can encapsulate your model within a well-documented class, or even spread it across multiple files, and everything will still work as intended.

Finally, you’ll notice that the model is purely generative, in that there is no notion of data or observations anywhere in the model. This separation is powerful: it enables Bean Machine to easily support an array of predictive and diagnostic procedures. (city_count is a modeling assumption, not part of the data.)

Data

In Bean Machine, we typically collect data in Python dictionaries that associate observed values with the random variables of a specific model.

While the model captures a hypothesis about a generative process, conditioning on observed data guides the model so that it reflects these observations. Bean Machine’s syntax makes it easy to “bind” observed data to particular random variables. Later, we can use inference to sample values for other random variables that are consistent with those observations.
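Conditioning is easiest to see on a single discrete variable. The plain-Python sketch below (not Bean Machine) uses Bayes’ rule to compute the posterior probability that one reading came from each city. Purely for illustration, it assumes the city means and standard deviations are already known; in the actual model these are themselves uncertain and are inferred jointly with the assignments. The helper names and the example values (means of 80°F and 50°F, standard deviation 5°F) are hypothetical.

```python
import math
from typing import List, Sequence


def normal_pdf(x: float, mean: float, std: float) -> float:
    """Density of a Normal(mean, std) distribution at x."""
    z = (x - mean) / std
    return math.exp(-0.5 * z * z) / (std * math.sqrt(2.0 * math.pi))


def assignment_posterior(
    temp: float, means: Sequence[float], stds: Sequence[float]
) -> List[float]:
    """P(city = k | temp) for each city k, given fixed (hypothetical) city parameters.

    With a uniform prior over cities, the prior factor is the same for every k and
    cancels during normalization, so only the likelihoods are needed.
    """
    likelihoods = [normal_pdf(temp, m, s) for m, s in zip(means, stds)]
    total = sum(likelihoods)
    return [lik / total for lik in likelihoods]
```

For example, a 65°F reading sits exactly halfway between the two hypothetical means, so it gets probability 0.5 for each city, while an 80°F reading is attributed to the warmer city with probability very close to 1.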

Let’s imagine we collected 40 temperature readings from 2 cities, but again recall that we don’t know which readings were taken from which city. This is exactly the kind of observation data we might want to bind to city_temp. Bean Machine’s syntax for binding is very simple — we’ll build up a dictionary mapping random variables to their bindings:

```python
from torch import tensor

observations = {
    city_temp(0): tensor(93.),
    city_temp(1): tensor(65.),
    city_temp(2): tensor(81.),
    # etc...
}
```

Here we’re just building an ordinary Python dictionary where the keys are random variables and the values are observations for the city_temp random variable. This is perhaps a surprising type for a dictionary key at first, since it looks like it might sample from a random variable. However, a call such as city_temp(0) is actually a convenient syntax for referencing a random variable — which we use here to bind data. Later, Bean Machine will use these observations to constrain its learning process, rather than sampling for them.

The argument represents the index of the observation: city_temp(0) is the first observation, city_temp(1) is the second, and so on. We’ve written out each key in the dict here for clarity; in practice, you would likely use a dictionary comprehension (loop) to populate the observations dictionary less verbosely.
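The following toy sketch mimics the behavior described above without depending on Bean Machine. A hypothetical toy_random_variable decorator makes each call return a hashable identifier (function name plus arguments) rather than a sample, so equal calls produce equal dictionary keys. The sketch then builds the observations dictionary with a comprehension. ToyRVId and toy_random_variable are illustrative names only, not Bean Machine’s internals; with Bean Machine itself, you would use the real @bm.random_variable decorator and wrap each value in torch.tensor.

```python
from functools import wraps
from typing import Any, Callable, NamedTuple, Tuple


class ToyRVId(NamedTuple):
    """Hypothetical random-variable reference: the function's name plus its arguments."""

    name: str
    args: Tuple[Any, ...]


def toy_random_variable(fn: Callable) -> Callable:
    """Hypothetical decorator: calls return an identifier instead of drawing a sample."""

    @wraps(fn)
    def wrapper(*args: Any) -> ToyRVId:
        return ToyRVId(fn.__name__, args)

    return wrapper


@toy_random_variable
def city_temp(i: int):
    """In a real model this body would return a distribution; the toy never runs it."""


# The first three readings from the post; the remaining readings are not listed there.
readings = [93.0, 65.0, 81.0]
observations = {city_temp(i): temp for i, temp in enumerate(readings)}
```

Because ToyRVId is a NamedTuple, two calls with the same arguments compare equal and hash identically, which is exactly what lets a call like city_temp(1) serve as a dictionary key.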

One important point: we’ve chosen to bind data specifically to city_temp only because it logically represents the observations that we’ve gathered. There’s nothing special about this random variable. You can bind data to any set of random variables, and can use this to explore the hypothesis space. That’s the power of generative modeling.

All we’ve done so far is to define our model (in the previous section) and to build up a Python dictionary of observations. Now let’s look at how to use these.

Learning

Learning is the process of improving knowledge based on observations. In the probabilistic setting, learning is called “inference,” and consists of computing distributions for variables of interest known as “queried variables.” Using the model presented above, we might be interested in querying about the mean and precision of the cities’ temperature distributions. It could also be useful to query the city assignments for each temperature — city_assignment is the model’s prediction for which city a temperature should be attributed or assigned to. Let's build a list of these queries. It's quite similar to how we bind observations to random variables:

```python
queries = (
    [city_temp_mean(k) for k in range(city_count)]
    + [city_temp_precision(k) for k in range(city_count)]
    + [city_assignment(i) for i, _ in enumerate(observations)]
)
```

The list comprehensions above simply concatenate the two city_temp_mean variables, the two city_temp_precision variables, and one city_assignment variable per temperature reading. Together, these form the set of queries we want Bean Machine’s inference algorithm to answer by running repeated randomized executions of the model.

We can now instruct Bean Machine to use the observations to compute the distributions for our queries. Rather than closed-form parametric distributions, Bean Machine returns so-called empirical distributions, that is, collections of samples that represent the distributions. Empirical distributions enable Bean Machine to fully automate the inference process without the user having to think about learning algorithms. There are many ways to perform this inference procedure, so we’ll just use one of them for illustration purposes in this post. Here’s the code to get these distributions (don’t worry about CompositionalInference just yet):

```python
samples = bm.CompositionalInference().infer(
    queries=queries,
    observations=observations,
    num_samples=1000,
)
```

The variable samples now contains values that represent the distributions of our queries, computed with an inference method called CompositionalInference so that they are consistent with the provided observations.

CompositionalInference is a powerful abstraction that we’ll cover in more detail shortly. For now, you can think of it as a flexible inference method that is appropriate for many kinds of models, including those with discrete random variables, stochastic control flow, and high dimensionality.

Analysis

The value in samples is a rich, DataFrame-like object that represents the distribution for each of our queries. We can index into it using the same convenient syntax that we used to bind observations and specify queries. Let’s look at the first city’s distribution of temperature means, for example:

```python
samples[city_temp_mean(0)]
```

Output:

```
tensor([[36.008, 57.030, 61.968, ..., 50.662, 50.885, 52.002],
        [53.815, 54.146, 50.922, ..., 51.523, 52.788, 56.023],
        [50.302, 57.825, 59.648, ..., 51.312, 53.801, 54.359],
        [49.759, 60.713, 60.082, ..., 54.080, 54.611, 53.771]])
```

You might have expected this to return a length-1000 vector, but instead we got a 4x1000 matrix. This is because, by default, Bean Machine runs 4 independent “chains” of inference with identical settings, so that we can later check that they agree and that inference ran as expected. If you want to analyze just one chain, you can do so with samples.get_chain().
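Multiple chains are useful because chains that have converged to the same distribution should look alike. One classic way to quantify this is the Gelman-Rubin potential scale reduction factor (R-hat), which compares between-chain and within-chain variance; values close to 1 suggest the chains agree. The sketch below is the textbook, non-split version in plain Python, applied to a list of per-chain sample lists (for example, the rows of the 4x1000 tensor above converted to Python floats). ArviZ provides a more refined R-hat (arviz.rhat), so treat this as an illustration of the idea rather than a replacement.

```python
import math
import statistics
from typing import Sequence


def gelman_rubin_rhat(chains: Sequence[Sequence[float]]) -> float:
    """Classic (non-split) R-hat for a set of equal-length chains of scalar samples."""
    n = len(chains[0])  # samples per chain
    chain_means = [statistics.mean(chain) for chain in chains]
    # W: average within-chain variance (sample variance, n - 1 denominator).
    within = statistics.mean([statistics.variance(chain) for chain in chains])
    # B: variance of the chain means across chains, scaled by the chain length.
    between = n * statistics.variance(chain_means)
    # Pooled estimate of the posterior variance.
    pooled = (n - 1) / n * within + between / n
    return math.sqrt(pooled / within)
```

Identical chains give a value just under 1, while shifting one chain’s samples away from the others pushes R-hat well above 1, which is the kind of disagreement multiple chains are meant to expose.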

There are lots of interesting analyses you can do with samples. It might be useful to plot our temperatures and city profiles, for example. We’ll leave out the data cleaning and visualization code, but check out Bean Machine’s tutorials for examples.

We’ve plotted expected values for the city temperature profiles, as well as the observed temperatures shaded by their probability of coming from each of the two cities (blue means City 0, red means City 1). However, the full potential of probabilistic modeling is realized only when examining not just these expected values, but also their credible intervals. Bean Machine interoperates nicely with a framework called ArviZ to provide succinct summaries of probabilistic information:

```python
import arviz

arviz.summary(samples.to_inference_data(), hdi_prob=0.89)
```

Output:

```
                          mean     sd  hdi_5.5%  hdi_94.5%    ess
city_temp_mean(0,)      78.387  5.709    69.396     87.378  207.0
city_temp_precision(0,)  0.096  0.153     0.041      0.101  157.0
city_temp_mean(1,)      52.127  2.758    47.618     55.859  217.0
city_temp_precision(1,)  0.150  0.133     0.073      0.183  207.0
(...remaining output truncated...)
```

This summarizes many useful statistics about the queries, such as means, standard deviations, 89% highest-density intervals (a form of credible interval), and effective sample sizes (ess).
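The hdi_5.5% and hdi_94.5% columns are the endpoints of an 89% highest-density interval (HDI): the narrowest interval that contains 89% of the samples. A minimal plain-Python sketch of the usual sorted-samples approach follows. It assumes a unimodal distribution and is not ArviZ’s exact implementation, which also offers options such as handling multimodal samples.

```python
import math
from typing import Sequence, Tuple


def hdi(samples: Sequence[float], prob: float = 0.89) -> Tuple[float, float]:
    """Narrowest interval [lo, hi] spanning a `prob` fraction of the sorted samples."""
    ordered = sorted(samples)
    n = len(ordered)
    span = int(math.floor(prob * n))  # index distance between interval endpoints
    candidates = n - span  # number of candidate intervals to compare
    widths = [ordered[i + span] - ordered[i] for i in range(candidates)]
    start = min(range(candidates), key=widths.__getitem__)  # narrowest candidate
    return ordered[start], ordered[start + span]
```

For a skewed sample such as [1, 2, ..., 9, 100] with prob=0.5, the interval stays in the dense region, (1, 6), instead of splitting the excluded mass evenly between the two tails as an equal-tailed interval would.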

Please note that we’ve written the above example with education in mind. In practice, you would want to tensorize the model to minimize inference runtime. Check out our tutorials for some great examples.

In the next few sections, we’ll go over some more advanced Bean Machine functionality. Or, feel free to head over to our docs if you just want to get started.

Compositional inference

Efficient inference for continuous variables (as opposed to discrete ones) typically relies on gradient information to produce accurate results quickly, and Bean Machine computes these gradients for you entirely behind the scenes. However, gradients are not available for discrete random variables, and the above model uses one, namely city_assignment(i). So what’s going on here? The answer is that Bean Machine supports a rich library of inference methods, and CompositionalInference in particular can combine and compose these methods as appropriate for the problem at hand.

With CompositionalInference, Bean Machine automatically selected an appropriate inference method for each random variable. For example, Bean Machine discovered that city_temp_mean and city_temp_precision were continuous, and thus used gradient information to fit these random variables. Conversely, city_assignment is a discrete random variable, so Bean Machine used a gradient-free sampling method to fit it.

CompositionalInference offers powerful default behavior, but it is often desirable to have more explicit control over the exact inference strategy used. CompositionalInference exposes a convenient API for this:

```python
inference_method = bm.CompositionalInference(
    {
        city_temp_mean: bm.SingleSiteNoUTurnSampler(),
        city_temp_precision: bm.SingleSiteNoUTurnSampler(),
        city_assignment: bm.SingleSiteUniformMetropolisHastings(),
    }
)
samples = inference_method.infer(
    queries=queries,
    observations=observations,
    num_samples=1000,
)
```

This is similar to how we used CompositionalInference previously, but now we’re passing it a configuration dictionary that binds particular inference methods to random variable families. The keys of this dictionary are random variable functions. Whenever Bean Machine encounters a random variable belonging to one of these families, it uses the corresponding inference method to fit that random variable.

In this example, we use the No-U-Turn Sampler (NUTS), a gradient-based inference method, to fit our continuous random variable families city_temp_mean and city_temp_precision. For city_assignment, we use a uniform Metropolis-Hastings sampler, which proposes new values drawn uniformly from the variable’s support and then accepts or rejects them; this lets it explore each of the possible city assignments for each temperature without needing gradients. Bean Machine offers a rich library of inference methods for use in this fashion.
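The SingleSiteUniformMetropolisHastings method used for city_assignment is a gradient-free sampler. For intuition, here is a minimal plain-Python sketch of Metropolis-Hastings with a uniform proposal for a single categorical variable. It is our own illustration of the general technique, not Bean Machine’s implementation. In the temperature model, log_prob(k) would be the log-likelihood of a reading under city k plus the log prior, with the other variables held fixed. Because the proposal is uniform and independent of the current value, the proposal densities cancel and the acceptance ratio reduces to a ratio of target probabilities.

```python
import math
import random
from typing import Callable, List


def uniform_mh_categorical(
    log_prob: Callable[[int], float],
    num_categories: int,
    num_samples: int,
    seed: int = 0,
) -> List[int]:
    """Metropolis-Hastings for one discrete variable with a uniform proposal over its support."""
    rng = random.Random(seed)
    current = rng.randrange(num_categories)
    samples = []
    for _ in range(num_samples):
        proposal = rng.randrange(num_categories)  # uniform, so proposal terms cancel
        log_accept = log_prob(proposal) - log_prob(current)
        if log_accept >= 0 or rng.random() < math.exp(log_accept):
            current = proposal
        samples.append(current)
    return samples
```

With target weights 1 and 3 for two categories, the fraction of draws equal to 1 settles near 0.75, even though no gradients were ever computed.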

Multi-site inference

We refer to the modular sampling scheme discussed in the previous section as “single-site.” This is because each “site” — that is, a specific instance of a random variable — can have an inference method tailored to it. Internally, Bean Machine iteratively samples a value for a random variable based on the assignments of other random variables, and then moves on and repeats the process for the next random variable. This modularity makes it possible to build up models with sophisticated structure and site-specific inference methods without worrying about any implementation details of how exactly the inference is performed.
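The single-site scheme can be sketched in a few lines of plain Python. Below, state maps hypothetical site names to scalar values, and each sweep updates one site at a time with a random-walk Metropolis step while holding all the other sites fixed. For simplicity every site uses the same proposal here, whereas Bean Machine lets each site use its own inference method; this illustrates the idea and is not Bean Machine’s internals.

```python
import math
import random
from typing import Callable, Dict


def single_site_sweep(
    log_joint: Callable[[Dict[str, float]], float],
    state: Dict[str, float],
    step: float,
    rng: random.Random,
) -> Dict[str, float]:
    """One sweep of single-site random-walk Metropolis: visit each site in turn,
    propose a new value for that site only, and accept or reject it while all
    other sites keep their current values."""
    state = dict(state)  # work on a copy so the caller's dict is not mutated
    for site in list(state):
        proposal = dict(state)
        proposal[site] = state[site] + rng.gauss(0.0, step)  # perturb one site
        log_accept = log_joint(proposal) - log_joint(state)  # symmetric proposal
        if log_accept >= 0 or rng.random() < math.exp(log_accept):
            state = proposal
    return state
```

Repeatedly calling single_site_sweep and recording the state after each call produces a Markov chain whose samples approximate the distribution defined by log_joint.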

However, it’s quite common for models to want to leverage information across sites. In many models, several random variables are tightly correlated. In these settings, using that correlation information during inference can help to reduce the number of samples that it takes for your model to converge to the correct results. Bean Machine offers another modular tool to let you take advantage of correlations in your model. We call this tool “multi-site inference.”

Multi-site inference is best illustrated with an example. Let’s say we wanted to exploit correlations between the mean and precision of temperature estimates in the model. Here’s how to achieve that:

```python
inference_method = bm.CompositionalInference(
    {
        city_assignment: bm.SingleSiteUniformMetropolisHastings(),
        (city_temp_mean, city_temp_precision): bm.GlobalNoUTurnSampler(),
    }
)
```

Notice that we’ve bundled two different sites together in the keys of our configuration dictionary. This hints to Bean Machine that these sites may be correlated, and that it should try to take advantage of that. When sampling a new value for a particular city_temp_mean, Bean Machine will also sample a new value for city_temp_precision, and then decide jointly whether to accept both. This is a potentially useful optimization in this model, and is especially useful in models that have a deep latent structure, such as Hidden Markov Models.
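The bundling idea can be sketched by grouping site names into blocks: every site in a block gets a new proposed value at once, and a single accept/reject decision covers the whole block. This plain-Python sketch uses a simple random-walk proposal and hypothetical site names. It shows only the joint decision described above, not Bean Machine’s multi-site machinery, and a plain random walk makes no special use of the correlation structure beyond that joint decision.

```python
import math
import random
from typing import Callable, Dict, Sequence, Tuple


def blocked_sweep(
    log_joint: Callable[[Dict[str, float]], float],
    state: Dict[str, float],
    blocks: Sequence[Tuple[str, ...]],
    step: float,
    rng: random.Random,
) -> Dict[str, float]:
    """One sweep of blocked random-walk Metropolis: each block is proposed and
    accepted or rejected as a unit, with sites outside the block held fixed."""
    state = dict(state)  # do not mutate the caller's dict
    for block in blocks:
        proposal = dict(state)
        for site in block:
            proposal[site] = state[site] + rng.gauss(0.0, step)
        log_accept = log_joint(proposal) - log_joint(state)  # one joint decision
        if log_accept >= 0 or rng.random() < math.exp(log_accept):
            state = proposal  # every site in the block moves together
    return state
```

Passing blocks=[("a",), ("b",)] recovers one-site-at-a-time updates, while blocks=[("a", "b")] mirrors bundling two sites under one key as in the configuration above.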

You’ll notice that we’ve also used the GlobalNoUTurnSampler inference method here. While all inference methods support multi-site inference as outlined in the previous paragraph, some inference methods offer further algorithmic improvements. These method-specific improvements might, for example, allow the method to exploit correlations when using gradients. As long as you choose an appropriate inference method, Bean Machine will figure all of this out for you automatically. You can find a list of supported inference methods in the Bean Machine documentation.

Higher-order inference methods

One of the advantages single-site inference brings is that it allows Bean Machine to deal with a small subcomponent of your model when sampling new values for a particular random variable. That is, in order to update one particular random variable, it doesn’t necessarily need to examine the entire model. This property is especially helpful to support operations that scale poorly with the size of your model.

In particular, the Bean Machine team is developing inference methods that use 2nd-order gradient information, though this is still an active area of research. In a typical (non-single-site) setting, using 2nd-order gradients can be prohibitively expensive, as it requires inverting a matrix whose dimensions grow with the number of random variables in the model.

Bean Machine’s single-site capabilities limit this operation to just the size of the site you’re dealing with (be that single- or multi-site). Thus, something that scaled with the number of random variables in your entire model before may now scale favorably if you can decompose the model into a number of multi-sites. In this setting, we’ve found it quite tractable to use 2nd-order gradients in inference. Bean Machine provides the Newtonian Monte Carlo (NMC) inference method to take advantage of 2nd-order gradients.

```python
inference_method = bm.CompositionalInference(
    {
        city_assignment: bm.SingleSiteUniformMetropolisHastings(),
        city_temp_mean: bm.SingleSiteNewtonianMonteCarlo(),
        city_temp_precision: bm.SingleSiteNewtonianMonteCarlo(),
    }
)
```
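The core second-order idea can be illustrated in one dimension: expand the log density to second order around the current value and use the resulting Gaussian as a proposal, with its mean given by a Newton step and its variance given by the inverse of the negative second derivative. The plain-Python sketch below shows only that step, assuming an unconstrained real-valued site with a locally concave log density; Bean Machine’s NMC implementation involves more than this. In d dimensions the division becomes a linear solve with a d-by-d Hessian, typically costing on the order of d³ operations, so keeping each site small keeps this step cheap.

```python
from typing import Callable, Tuple


def newton_proposal(
    x: float,
    grad: Callable[[float], float],
    hess: Callable[[float], float],
) -> Tuple[float, float]:
    """Mean and variance of a Gaussian proposal from a second-order Taylor
    expansion of the log density around x."""
    h = hess(x)  # second derivative of the log density at x
    if h >= 0:
        raise ValueError("the log density must be locally concave at x")
    mean = x - grad(x) / h  # Newton step toward the mode
    variance = -1.0 / h  # inverse of the negative second derivative
    return mean, variance
```

For a Normal target with mean 1 and standard deviation 2, whose log density has gradient -(x - 1) / 4 and second derivative -0.25, the proposal from any starting point is exactly that Normal (mean 1.0, variance 4.0), since a Gaussian’s log density is already quadratic; for non-Gaussian targets it is only a local approximation.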

Bean Machine Graph (BMG) inference

Bean Machine affords good inference performance for tensorized models; however, many probabilistic models have a rich or sparse structure that is difficult to write in terms of just a handful of large tensor operations. And in many cases, these are exactly the problems for which probabilistic modeling is most compelling.

To address this, we are in the process of developing Bean Machine Graph (BMG) Inference, a specialized combination of a compiler and a fast, independent runtime that is optimized to run inference even for un-tensorized models. By design, BMG Inference has essentially the same interface as other Bean Machine inference methods, the idea being that you can use it with your modeling code as-is. The following code illustrates this for our running example:

```python
from beanmachine.ppl.inference.bmg_inference import BMGInference

samples = BMGInference().infer(
    queries=queries,
    observations=observations,
    num_samples=1000,
)
```

Here we have simply replaced inference_method with BMGInference. BMG Inference routinely achieves speedups of 1 to 2 orders of magnitude.

Behind the scenes, BMGInference is using a custom compiler to interpret your Bean Machine model and translate it to a specialized implementation with no Python dependencies. As a user, this should all just work automatically. Currently, the feature set it supports is limited, although we are rapidly building it up to support common modeling problems — both in terms of modeling code, as well as compositional inference capabilities. We’re including it in this beta release of Bean Machine as an indication of where we’re hoping to evolve the framework; for now, please consider it a very experimental preview, and we encourage you to check out its documentation for the latest on what model features it supports.

What now?

There’s lots more to learn about Bean Machine. Head over to our docs or tutorials to get started.

Etymology

The system’s language and logo are inspired by the bean machine (also known as the Galton board), a physical device for visualizing probability distributions and a pre-computing example of a probabilistic system.