Variational inference
In the last chapter, we saw that inference in probabilistic models is often intractable, and we learned about algorithms that provide approximate solutions to the inference problem (e.g. marginal inference) by using subroutines that involve sampling random variables. Most sampling-based inference algorithms are instances of Markov Chain Monte Carlo (MCMC); two popular MCMC methods are Gibbs sampling and Metropolis-Hastings.
Unfortunately, these sampling-based methods have several important shortcomings.
- Although they are guaranteed to find a globally optimal solution given enough time, it is difficult to tell how close they are to a good solution given the finite amount of time that they have in practice.
- In order to quickly reach a good solution, MCMC methods require choosing an appropriate sampling technique (e.g. a good proposal in Metropolis-Hastings). Choosing this technique can be an art in itself.
In this chapter, we are going to look at an alternative approach to approximate inference called the variational family of algorithms.
Inference as optimization
The main idea of variational methods is to cast inference as an optimization problem.
Suppose we are given an intractable probability distribution $p$. Variational techniques will try to solve an optimization problem over a class of tractable distributions $\mathcal{Q}$ in order to find a $q \in \mathcal{Q}$ that is most similar to $p$. We will then query $q$ (rather than $p$) in order to get an approximate solution.
The main differences between sampling and variational techniques are that:
- Unlike sampling-based methods, variational approaches will almost never find the globally optimal solution.
- However, we will always know if they have converged. In some cases, we will even have bounds on their accuracy.
- In practice, variational inference methods often scale better and are more amenable to techniques like stochastic gradient optimization, parallelization over multiple processors, and acceleration using GPUs.
Although sampling methods were historically invented first (in the 1940s), variational techniques have been steadily gaining popularity and are currently the more widely used inference technique.
The Kullback-Leibler divergence
To formulate inference as an optimization problem, we need to choose an approximating family $\mathcal{Q}$ and an optimization objective $J(q)$. This objective needs to capture the similarity between $q$ and $p$; the field of information theory provides us with a tool for this called the Kullback-Leibler (KL) divergence.
Formally, the KL divergence between two distributions $q$ and $p$ with discrete support is defined as
$$KL(q\|p) = \sum_x q(x) \log \frac{q(x)}{p(x)}.$$
In information theory, this function is used to measure differences in information contained within two distributions. The KL divergence has the following properties that make it especially useful in our setting:
- $KL(q\|p) \geq 0$ for all $q, p$.
- $KL(q\|p) = 0$ if and only if $q = p$.
These can be proven as an exercise. Note, however, that $KL(q\|p) \neq KL(p\|q)$, i.e. the KL divergence is not symmetric. This is why we say that it’s a divergence, but not a distance. We will come back to this distinction shortly.
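The definition and both properties are easy to check numerically. The following is a minimal sketch, assuming distributions are stored as NumPy arrays over the same finite support; the function name `kl` is our own. It uses the standard convention $0 \log 0 = 0$ and returns infinity when $q(x) > 0$ but $p(x) = 0$.

```python
import numpy as np

def kl(q, p):
    """KL(q || p) for discrete distributions given as arrays on a shared support."""
    q = np.asarray(q, dtype=float)
    p = np.asarray(p, dtype=float)
    mask = q > 0  # terms with q(x) = 0 contribute 0 by convention
    with np.errstate(divide="ignore"):
        # log p(x) = -inf where p(x) = 0, which makes the divergence infinite
        return float(np.sum(q[mask] * (np.log(q[mask]) - np.log(p[mask]))))

q = np.array([0.5, 0.4, 0.1])
p = np.array([0.3, 0.3, 0.4])
print(kl(q, p), kl(p, q))  # two different positive numbers: KL is not symmetric
print(kl(q, q))            # 0.0
```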
The variational lower bound
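How do we perform variational inference with a KL divergence? First, let’s fix a form for $p$. We’ll assume that $p$ is a general (discrete, for simplicity) undirected model of the form
$$p(x_1, \dots, x_n; \theta) = \frac{\tilde p(x_1, \dots, x_n; \theta)}{Z(\theta)} = \frac{1}{Z(\theta)} \prod_k \phi_k(x_k; \theta),$$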
where the $\phi_k$ are the factors, each defined over its own subset $x_k$ of the variables, and $Z(\theta)$ is the normalization constant. This formulation captures virtually all the distributions in which we might want to perform approximate inference, such as marginal distributions of directed models $p(x \mid e) = p(x, e)/p(e)$ with evidence $e$.
Given this formulation, optimizing $KL(q\|p)$ directly is not possible because of the potentially intractable normalization constant $Z(\theta)$. In fact, even evaluating $KL(q\|p)$ is not possible, because we need to evaluate $p$.
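To make the intractability concrete, here is a sketch on a hypothetical toy model: a chain of binary variables with random pairwise factors (`factors` and `p_tilde` are illustrative names). Evaluating $\tilde p(x)$ at a single point only multiplies a few table entries, but the brute-force $Z(\theta)$ sums over all $K^n$ assignments.

```python
import itertools
import numpy as np

# Hypothetical toy model: n binary variables on a chain, one pairwise factor
# phi(x_i, x_{i+1}) per edge, stored as a K x K table of positive values.
K, n = 2, 10
rng = np.random.default_rng(0)
factors = [((i, i + 1), rng.uniform(0.5, 2.0, size=(K, K))) for i in range(n - 1)]

def p_tilde(x):
    """Unnormalized probability: the product of all factor values at assignment x."""
    value = 1.0
    for scope, table in factors:
        value *= table[tuple(x[v] for v in scope)]
    return value

# Brute-force partition function: a sum over all K**n joint assignments.
assignments = list(itertools.product(range(K), repeat=n))
Z = sum(p_tilde(x) for x in assignments)
print(len(assignments), Z)  # 1024 terms; each extra variable multiplies this count by K
```

With $n = 10$ this is instant; with $n = 100$ binary variables the same loop would need $2^{100}$ terms.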
Instead, we will work with the following objective, which has the same form as the KL divergence, but only involves the unnormalized probability $\tilde p(x) = \prod_k \phi_k(x_k; \theta)$:
$$J(q) = \sum_x q(x) \log \frac{q(x)}{\tilde p(x)}.$$
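This function is not only tractable, it also has the following important property:
$$J(q) = \sum_x q(x) \log \frac{q(x)}{\tilde p(x)} = \sum_x q(x) \log \frac{q(x)}{p(x)} - \log Z(\theta) = KL(q\|p) - \log Z(\theta).$$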
Since $KL(q\|p) \geq 0$, we get by rearranging terms that
$$\log Z(\theta) = KL(q\|p) - J(q) \geq -J(q).$$
Thus, $-J(q)$ is a lower bound on the log partition function $\log Z(\theta)$. In many cases, $Z(\theta)$ has an interesting interpretation. For example, we may be trying to compute the marginal probability $p(x \mid D) = p(x, D)/p(D)$ of variables $x$ given observed data $D$ that plays the role of evidence. We assume that $p(x, D)$ is directed. In this case, minimizing $J(q)$ amounts to maximizing a lower bound on the log-likelihood $\log p(D)$ of the observed data.
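The identity $J(q) = KL(q\|p) - \log Z(\theta)$ and the resulting bound can be verified on a small random example. This sketch assumes a hypothetical unnormalized target over 8 joint states, small enough that $Z(\theta)$ can be computed exactly for comparison; note that `J` itself never touches $Z(\theta)$.

```python
import numpy as np

rng = np.random.default_rng(1)
# Hypothetical unnormalized model over 8 joint states (stands in for p~(x)).
p_tilde = rng.uniform(0.1, 5.0, size=8)
log_Z = np.log(p_tilde.sum())
p = p_tilde / p_tilde.sum()

def J(q):
    """J(q) = sum_x q(x) log(q(x) / p~(x)); needs only the unnormalized p~."""
    return float(np.sum(q * (np.log(q) - np.log(p_tilde))))

def kl(q, p):
    """KL(q || p) for full-support discrete distributions."""
    return float(np.sum(q * (np.log(q) - np.log(p))))

q = rng.dirichlet(np.ones(8))    # an arbitrary full-support approximation
print(J(q), kl(q, p) - log_Z)    # equal up to rounding
print(-J(q) <= log_Z)            # True: -J(q) is a lower bound on log Z
print(np.isclose(-J(p), log_Z))  # True: the bound is tight when q = p
```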
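Because $-J(q)$ lower-bounds $\log Z(\theta)$, it is called the variational lower bound or the evidence lower bound (ELBO); it is often written in the form
$$\log Z(\theta) \geq \mathbb{E}_{q(x)}\left[\log \tilde p(x) - \log q(x)\right].$$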
Crucially, the difference between $\log Z(\theta)$ and $-J(q)$ is precisely $KL(q\|p)$. Thus, by maximizing the evidence lower bound, we are minimizing $KL(q\|p)$ by “squeezing” it between $-J(q)$ and $\log Z(\theta)$.
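Because the ELBO is an expectation under $q$, it can be estimated from samples of $q$ alone, which is what makes it compatible with the stochastic optimization techniques mentioned earlier. The sketch below assumes a hypothetical random target over 6 states and an arbitrarily chosen sample size, and compares a simple Monte Carlo estimate with the exact value.

```python
import numpy as np

rng = np.random.default_rng(2)
# Hypothetical setup: unnormalized target p~ and approximation q over 6 states.
p_tilde = rng.uniform(0.1, 3.0, size=6)
q = rng.dirichlet(np.ones(6))

# Exact ELBO: E_q[log p~(x) - log q(x)], feasible only because the space is tiny.
elbo_exact = float(np.sum(q * (np.log(p_tilde) - np.log(q))))

# Monte Carlo ELBO: average of log p~(x) - log q(x) over samples x ~ q.
# This needs only samples from q and pointwise evaluations of p~ and q.
samples = rng.choice(len(q), size=100_000, p=q)
elbo_mc = float(np.mean(np.log(p_tilde[samples]) - np.log(q[samples])))

log_Z = float(np.log(p_tilde.sum()))
print(elbo_exact, elbo_mc, log_Z)  # elbo_mc is close to elbo_exact, which is <= log_Z
```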
On the choice of KL divergence
To recap, we have just defined an optimization objective for variational inference (the variational lower bound) and we have shown that maximizing the lower bound leads to minimizing the divergence $KL(q\|p)$.
Recall how we said earlier that $KL(q\|p) \neq KL(p\|q)$; both divergences equal zero when $q = p$, but assign different penalties when $q \neq p$. This raises the question: why did we choose one over the other and how do they differ?
Perhaps the most important difference is computational: optimizing $KL(q\|p)$ involves an expectation with respect to $q$, while $KL(p\|q)$ requires computing expectations with respect to $p$, which is typically intractable even to evaluate.
However, choosing this particular divergence affects the returned solution when the approximating family $\mathcal{Q}$ does not contain the true $p$. Observe that $KL(q\|p)$, which is called the I-projection or information projection, is infinite if $p(x) = 0$ and $q(x) > 0$:
$$KL(q\|p) = \sum_x q(x) \log \frac{q(x)}{p(x)}.$$
This means that if $p(x) = 0$ we must have $q(x) = 0$. We say that $KL(q\|p)$ is zero-forcing for $q$ and it will typically under-estimate the support of $p$.
On the other hand, $KL(p\|q)$, known as the M-projection or the moment projection, is infinite if $q(x) = 0$ and $p(x) > 0$. Thus, if $p(x) > 0$ we must have $q(x) > 0$. We say that $KL(p\|q)$ is zero-avoiding for $q$ and it will typically over-estimate the support of $p$.
The figure below illustrates this phenomenon graphically.
Due to the properties that we just described, we often call $KL(p\|q)$ the inclusive KL divergence, while $KL(q\|p)$ is the exclusive KL divergence.
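The zero-forcing versus zero-avoiding behavior can be reproduced with a brute-force search. This sketch assumes a hypothetical bimodal discrete target and an approximating family $\mathcal{Q}$ of single discretized Gaussians indexed by a grid of means and widths (all values chosen for illustration); it picks the member minimizing each divergence. Grid search stands in for a real optimizer here.

```python
import numpy as np

x = np.arange(-10, 31)  # integer support, wide enough that truncation is negligible

def discrete_gaussian(mu, sigma):
    """Gaussian bump on the grid x, renormalized to sum to one."""
    w = np.exp(-0.5 * ((x - mu) / sigma) ** 2)
    return w / w.sum()

def kl(a, b):
    """KL(a || b) with 0 log 0 = 0; infinite where a > 0 but b = 0."""
    mask = a > 0
    with np.errstate(divide="ignore"):
        return float(np.sum(a[mask] * (np.log(a[mask]) - np.log(b[mask]))))

# Bimodal target p: equal mixture of bumps centered at 5 and 15.
p = 0.5 * discrete_gaussian(5, 1.5) + 0.5 * discrete_gaussian(15, 1.5)

# Approximating family Q: single discretized Gaussians on a (mu, sigma) grid.
family = [(float(mu), float(s))
          for mu in np.arange(0, 20.5, 0.5)
          for s in np.arange(0.5, 8.01, 0.25)]
exclusive = min(family, key=lambda th: kl(discrete_gaussian(*th), p))  # I-projection
inclusive = min(family, key=lambda th: kl(p, discrete_gaussian(*th)))  # M-projection
print("argmin KL(q||p):", exclusive)  # locks onto one mode with a narrow width
print("argmin KL(p||q):", inclusive)  # centered between the modes with a wide width
```

With these settings, the exclusive solution should sit on one of the two modes with a width close to that mode's, while the inclusive solution should be centered near 10 and broad enough to cover both modes.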
Mean-field inference
The next step in our development of variational inference concerns the choice of approximating family $\mathcal{Q}$. The machine learning literature contains dozens of proposed ways to parametrize this class of distributions; these include exponential families, neural networks, Gaussian processes, latent variable models, and many other types of models.
However, one of the most widely used classes of distributions is simply the set of fully-factored $q(x) = q_1(x_1) q_2(x_2) \cdots q_n(x_n)$; here each $q_i(x_i)$ is a categorical distribution over a one-dimensional discrete variable, which can be described as a one-dimensional table.
This choice of $\mathcal{Q}$ turns out to be easy to optimize over and works surprisingly well. It is perhaps the most popular choice when optimizing the variational bound; variational inference with this choice of $\mathcal{Q}$ is called mean-field inference. It consists of solving the following optimization problem:
$$\min_{q_1, \dots, q_n} J(q).$$
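One reason the fully-factored family is convenient: for a factored $q$, the objective $J(q)$ splits into per-variable entropy terms and per-factor expectations, each involving only the variables in one factor's scope. The sketch below assumes a hypothetical chain of binary variables with pairwise factors (`J_meanfield` and `J_bruteforce` are illustrative names) and checks the local formula against brute-force enumeration.

```python
import itertools
import numpy as np

# Hypothetical toy model: a chain of n binary variables with pairwise factor tables.
K, n = 2, 6
rng = np.random.default_rng(3)
factors = [((i, i + 1), rng.uniform(0.5, 2.0, size=(K, K))) for i in range(n - 1)]

# Fully-factored q: one categorical table q_i(x_i) per variable (n*K numbers, not K**n).
q = [rng.dirichlet(np.ones(K)) for _ in range(n)]

def J_meanfield(q):
    """J(q) = sum_i E_{q_i}[log q_i] - sum_k E_q[log phi_k], using only local tables."""
    neg_entropy = sum(float(np.sum(qi * np.log(qi))) for qi in q)
    expected_log_factors = 0.0
    for (i, j), table in factors:
        # Under a factored q, the pair (x_i, x_j) has joint q_i(x_i) q_j(x_j).
        expected_log_factors += float(np.sum(np.outer(q[i], q[j]) * np.log(table)))
    return neg_entropy - expected_log_factors

def J_bruteforce(q):
    """Same objective, computed by enumerating all K**n joint assignments."""
    total = 0.0
    for x in itertools.product(range(K), repeat=n):
        qx = np.prod([q[v][x[v]] for v in range(n)])
        log_p_tilde = sum(np.log(t[x[a], x[b]]) for (a, b), t in factors)
        total += qx * (np.log(qx) - log_p_tilde)
    return float(total)

print(J_meanfield(q), J_bruteforce(q))  # agree up to rounding
```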
The standard way of solving the mean-field problem is via coordinate descent over the $q_j$: we iterate over $j = 1, 2, \dots, n$ and for each $j$ we optimize $KL(q\|p)$ over $q_j$ while keeping the other “coordinates” $q_{-j} = \prod_{i \neq j} q_i$ fixed.
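Interestingly, the optimization problem for one coordinate has a simple closed form solution:
$$\log q_j(x_j) \leftarrow \mathbb{E}_{q_{-j}}\left[\log \tilde p(x)\right] + \text{normalization constant}.$$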
Notice that both sides of the above equation contain univariate functions of $x_j$: we are thus replacing $q_j(x_j)$ with another function of the same form. The constant term is a normalization constant for the new distribution.
Notice also that on the right-hand side, we are taking an expectation of a sum of factors
$$\log \tilde p(x) = \sum_k \log \phi_k(x_k).$$
Of these, only factors belonging to the Markov blanket of $x_j$ are a function of $x_j$ (simply by the definition of the Markov blanket); the rest are constant with respect to $x_j$ and can be pushed into the constant term.
This leaves us with an expectation over a much smaller number of factors; if the Markov blanket of $x_j$ is small (as is often the case), we are able to analytically compute $q_j(x_j)$. For example, if the variables are discrete with $K$ possible values, and there are $F$ factors and $N$ variables in the Markov blanket of $x_j$, then computing the expectation takes $O(K F K^N)$ time: for each value of $x_j$ we sum over all $K^N$ assignments of the $N$ variables, and in each case, we sum over the $F$ factors.
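The update can be implemented using only the factors that touch $x_j$, and the loop structure mirrors the $O(K F K^N)$ count above: values of $x_j$, then blanket assignments, then factors. This is a minimal sketch assuming factors are stored as `(scope, table)` pairs of positive NumPy arrays; `mean_field_update` and the toy model are hypothetical illustrations, not a library API.

```python
import itertools
import numpy as np

def mean_field_update(j, q, factors, K):
    """One closed-form coordinate update for q_j (a sketch).

    q is a list of length-K probability arrays; factors is a list of
    (scope, table) pairs, where scope is a tuple of variable indices and table
    is a positive array with one axis per variable in scope. Only factors whose
    scope contains j matter; the rest are absorbed into the constant.
    """
    local = [(scope, table) for scope, table in factors if j in scope]
    blanket = sorted({v for scope, _ in local for v in scope if v != j})
    log_q = np.zeros(K)
    for xj in range(K):
        # Enumerate all K**N assignments of the Markov blanket (N = len(blanket)).
        for values in itertools.product(range(K), repeat=len(blanket)):
            assign = dict(zip(blanket, values))
            assign[j] = xj
            weight = np.prod([q[v][assign[v]] for v in blanket])  # q_{-j} probability
            log_phi = sum(np.log(table[tuple(assign[v] for v in scope)])
                          for scope, table in local)               # sum over F factors
            log_q[xj] += weight * log_phi
    # Exponentiate and renormalize; subtracting the max avoids overflow.
    new = np.exp(log_q - log_q.max())
    return new / new.sum()

# Example (hypothetical model): a chain 0-1-2-3 plus a unary factor on x_0.
rng = np.random.default_rng(4)
K = 2
factors = [((0, 1), rng.uniform(0.5, 2.0, (K, K))),
           ((1, 2), rng.uniform(0.5, 2.0, (K, K))),
           ((2, 3), rng.uniform(0.5, 2.0, (K, K))),
           ((0,), rng.uniform(0.5, 2.0, K))]
q = [rng.dirichlet(np.ones(K)) for _ in range(4)]
q[1] = mean_field_update(1, q, factors, K)  # uses only factors (0, 1) and (1, 2)
print(q[1])
```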
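Repeating these updates yields a procedure that iteratively fits a fully-factored $q(x) = q_1(x_1) q_2(x_2) \cdots q_n(x_n)$ that approximates $p$ in terms of $KL(q\|p)$. After each step of coordinate descent, we increase (or at least do not decrease) the variational lower bound, tightening it around $\log Z(\theta)$.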
In the end, the factors $q_j(x_j)$ will not quite equal the true marginal distributions $p(x_j)$, but they will often be good enough for many practical purposes, such as determining $\max_{x_j} p(x_j)$.
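Putting the pieces together, here is a sketch of the full coordinate descent loop (equivalently, coordinate ascent on the lower bound) on a hypothetical six-variable binary chain, specialized to pairwise factors. It records the lower bound after every sweep and, since the model is tiny, compares the result with the exact $\log Z(\theta)$ and exact marginals computed by enumeration. The number of sweeps (50) is an arbitrary choice.

```python
import itertools
import numpy as np

# Hypothetical toy model: a chain of n binary variables with pairwise factors.
K, n = 2, 6
rng = np.random.default_rng(5)
factors = [((i, i + 1), rng.uniform(0.3, 3.0, size=(K, K))) for i in range(n - 1)]

def elbo(q):
    """-J(q) = E_q[log p~(x)] - E_q[log q(x)] for a fully-factored q."""
    value = -sum(float(np.sum(qi * np.log(qi))) for qi in q)
    for (i, j), table in factors:
        value += float(np.sum(np.outer(q[i], q[j]) * np.log(table)))
    return value

def update(j, q):
    """Closed-form mean-field update for q_j using only factors that touch x_j."""
    log_qj = np.zeros(K)
    for (a, b), table in factors:
        if a == j:
            log_qj += np.log(table) @ q[b]  # E_{q_b}[log phi(x_j, x_b)]
        elif b == j:
            log_qj += q[a] @ np.log(table)  # E_{q_a}[log phi(x_a, x_j)]
    new = np.exp(log_qj - log_qj.max())
    return new / new.sum()

q = [np.full(K, 1.0 / K) for _ in range(n)]  # uniform initialization
history = [elbo(q)]
for sweep in range(50):
    for j in range(n):
        q[j] = update(j, q)
    history.append(elbo(q))

# Exact log Z and marginals by enumeration (only feasible because n is tiny).
states = list(itertools.product(range(K), repeat=n))
weights = np.array([np.prod([t[x[a], x[b]] for (a, b), t in factors]) for x in states])
log_Z = float(np.log(weights.sum()))
marginals = np.zeros((n, K))
for x, w in zip(states, weights):
    for v in range(n):
        marginals[v, x[v]] += w / weights.sum()

print(history[0], history[-1], log_Z)  # the bound never decreases and stays below log Z
print(np.round(np.array(q), 3))        # mean-field q_j(x_j)
print(np.round(marginals, 3))          # exact p(x_j)
```

Comparing the last two printouts shows how far the mean-field $q_j(x_j)$ are from the true marginals for this particular model.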