Variational message passing

This section briefly describes the variational message passing (VMP) framework, which is currently the only inference engine implemented in BayesPy. The variational Bayesian (VB) inference engine in BayesPy assumes that the posterior approximation factorizes with respect to nodes and plates. VMP updates one node at a time (the plates within a node are updated simultaneously) and iterates over all nodes in turn until convergence.
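
For orientation, the following is a minimal sketch of how this engine is typically invoked through BayesPy's public API; the toy model and data are illustrative only:

import numpy as np
from bayespy.nodes import GaussianARD, Gamma
from bayespy.inference import VB

# Illustrative toy model: scalar data with unknown mean and precision.
mu = GaussianARD(0, 1e-6)               # broad Gaussian prior on the mean
tau = Gamma(1e-6, 1e-6)                 # broad Gamma prior on the precision
y = GaussianARD(mu, tau, plates=(100,))

# Observe synthetic data and iterate the VMP updates.
y.observe(np.random.randn(100) + 3)
Q = VB(y, mu, tau)
Q.update(repeat=100)

Q.update(repeat=100) iterates the node-wise updates described below and tracks the lower bound (see the Lower bound section).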

Standard update equation

The general update equation for the factorized approximation of node \boldsymbol{\theta} is the following:

(1)
\log q(\boldsymbol{\theta})
&=
\langle \log p\left( \boldsymbol{\theta} | \mathrm{pa}(\boldsymbol{\theta}) \right) \rangle
+ \sum_{\mathbf{x} \in \mathrm{ch}(\boldsymbol{\theta})}
  \langle \log p(\mathbf{x}|\mathrm{pa}(\mathbf{x})) \rangle
+ \mathrm{const},

where \mathrm{pa}(\boldsymbol{\theta}) and \mathrm{ch}(\boldsymbol{\theta}) are the sets of parents and children of \boldsymbol{\theta}, respectively. Thus, the posterior approximation of a node is updated by summing the expectations of all log densities in which the node variable appears. The expectations are taken over the approximate distributions of all variables other than \boldsymbol{\theta}. In fact, not all variables are needed, because the non-constant part depends only on the Markov blanket of \boldsymbol{\theta}. This leads to a local optimization scheme which uses messages from neighbouring nodes.
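
For example, in a simple illustrative model where observations y_1, \ldots, y_N are Gaussian with mean \mu and fixed precision \tau, and \mu has a Gaussian prior with fixed hyperparameters m_0 and \tau_0, equation (1) for \mu reduces to

\log q(\mu)
&=
\langle \log \mathcal{N}(\mu | m_0, \tau_0^{-1}) \rangle
+ \sum_{n=1}^{N} \langle \log \mathcal{N}(y_n | \mu, \tau^{-1}) \rangle
+ \mathrm{const},

where only the prior of \mu and the likelihoods of its children y_n contribute, and the expectations are trivial because the hyperparameters are fixed and the y_n are observed. This small example is reused below to make the general formulas concrete.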

The messages are simple for conjugate exponential family models. An exponential family distribution has the following log probability density function:

(2)
\log p(\mathbf{x}|\mathbf{\Theta})
&=
\mathbf{u}_{\mathbf{x}}(\mathbf{x})^{\mathrm{T}}
\boldsymbol{\phi}_{\mathbf{x}}(\mathbf{\Theta})
+ g_{\mathbf{x}}(\mathbf{\Theta})
+ f_{\mathbf{x}}(\mathbf{x}),

where \mathbf{\Theta}=\{\boldsymbol{\theta}_j\} is the set of parents, \mathbf{u} is the sufficient statistic vector, \boldsymbol{\phi} is the natural parameter vector, g is the negative log normalizer, and f is the log base measure. Note that the log density is linear with respect to the terms that are functions of \mathbf{x}: \mathbf{u} and f. If a parent has a conjugate prior, (2) is also linear with respect to the parent’s sufficient statistic vector. Thus, (2) can be re-organized with respect to a parent \boldsymbol{\theta}_j as

\log p(\mathbf{x}|\mathbf{\Theta})
&=
\mathbf{u}_{\boldsymbol{\theta}_j}(\boldsymbol{\theta}_j)^{\mathrm{T}}
\boldsymbol{\phi}_{\mathbf{x}\rightarrow\boldsymbol{\theta}_j}
(\mathbf{x}, \{\boldsymbol{\theta}_k\}_{k\neq j})
+ \mathrm{const},

where \mathbf{u}_{\boldsymbol{\theta}_j} is the sufficient statistic vector of \boldsymbol{\theta}_j and the constant part is constant with respect to \boldsymbol{\theta}_j. Thus, the update equation (1) for \boldsymbol{\theta}_j can be written as

\log q(\boldsymbol{\theta}_j)
&=
\mathbf{u}_{\boldsymbol{\theta}_j}(\boldsymbol{\theta}_j)^{\mathrm{T}}
  \langle \boldsymbol{\phi}_{\boldsymbol{\theta}_j} \rangle
+ f_{\boldsymbol{\theta}_j}(\boldsymbol{\theta}_j)
+
\mathbf{u}_{\boldsymbol{\theta}_j}(\boldsymbol{\theta}_j)^{\mathrm{T}}
\sum_{\mathbf{x} \in \mathrm{ch}(\boldsymbol{\theta}_j)}
      \langle \boldsymbol{\phi}_{\mathbf{x}\rightarrow\boldsymbol{\theta}_j} \rangle
+ \mathrm{const},
\\
&=
\mathbf{u}_{\boldsymbol{\theta}_j}(\boldsymbol{\theta}_j)^{\mathrm{T}}
\left(
  \langle \boldsymbol{\phi}_{\boldsymbol{\theta}_j} \rangle
  + \sum_{\mathbf{x} \in \mathrm{ch}(\boldsymbol{\theta}_j)}
      \langle \boldsymbol{\phi}_{\mathbf{x}\rightarrow\boldsymbol{\theta}_j} \rangle
\right)
+ f_{\boldsymbol{\theta}_j}(\boldsymbol{\theta}_j)
+ \mathrm{const},

where the summation is over all the child nodes of \boldsymbol{\theta}_j. Because of the conjugacy, \langle\boldsymbol{\phi}_{\boldsymbol{\theta}_j}\rangle depends (multi)linearly on the expectations of the parents’ sufficient statistics. Similarly, \langle \boldsymbol{\phi}_{\mathbf{x}\rightarrow\boldsymbol{\theta}_j} \rangle depends (multi)linearly on the expectations of the children’s and co-parents’ sufficient statistics. This gives the following update equation for the natural parameter vector of the posterior approximation q(\boldsymbol{\theta}_j):

(3)
\tilde{\boldsymbol{\phi}}_j
&=
\langle \boldsymbol{\phi}_{\boldsymbol{\theta}_j} \rangle
+ \sum_{\mathbf{x} \in \mathrm{ch}(\boldsymbol{\theta}_j)}
  \langle \boldsymbol{\phi}_{\mathbf{x}\rightarrow\boldsymbol{\theta}_j} \rangle.
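
Continuing the illustrative Gaussian-mean example, update (3) can be computed directly. The sketch below (plain NumPy, not BayesPy's internal implementation; all names are illustrative) forms the prior term and the child messages for \mu, whose sufficient statistics are \mathbf{u}(\mu) = [\mu, \mu^2]^{\mathrm{T}}:

import numpy as np

m0, tau0 = 0.0, 1e-3                 # fixed prior mean and precision of mu
tau = 2.0                            # fixed observation precision (co-parent)
y = np.array([0.8, 1.2, 0.9, 1.1])   # observed children

# <phi_mu>: expectation of the prior natural parameter vector of mu
phi_prior = np.array([tau0 * m0, -0.5 * tau0])

# <phi_{y_n -> mu}>: message from each observed child y_n
phi_children = np.array([[tau * yn, -0.5 * tau] for yn in y])

# Update (3): prior term plus the sum of the child messages
phi_tilde = phi_prior + phi_children.sum(axis=0)

# Recover the mean/variance parameterization of q(mu)
post_var = -0.5 / phi_tilde[1]
post_mean = phi_tilde[0] * post_var
print(post_mean, post_var)

Because \tau is fixed and the y_n are observed, this single update already gives the exact conjugate posterior; with unknown co-parents the messages would instead be functions of their current moments and the update would be iterated.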

Variational messages

The update equation (3) leads to a message passing scheme: the term \langle \boldsymbol{\phi}_{\boldsymbol{\theta}_j} \rangle is a function of the parents’ sufficient statistic vector and the term \langle
\boldsymbol{\phi}_{\mathbf{x}\rightarrow\boldsymbol{\theta}_j} \rangle can be interpreted as a message from the child node \mathbf{x}. Thus, the message from the child node \mathbf{x} to the parent node \boldsymbol{\theta} is

\mathbf{m}_{\mathbf{x}\rightarrow\boldsymbol{\theta}}
&\equiv
\langle \boldsymbol{\phi}_{\mathbf{x}\rightarrow\boldsymbol{\theta}} \rangle,

which can be computed as a function of the sufficient statistic vectors of the child node \mathbf{x} and of the co-parent nodes of \boldsymbol{\theta}. The message from the parent node \boldsymbol{\theta} to the child node \mathbf{x} is simply the expectation of its sufficient statistic vector:

\mathbf{m}_{\mathbf{\boldsymbol{\theta}}\rightarrow\mathbf{x}}
&\equiv
\langle \mathbf{u}_{\boldsymbol{\theta}} \rangle.
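
In the illustrative Gaussian-mean example, with the precision \tau either fixed or given its own conjugate prior, these definitions give

\mathbf{m}_{y_n \rightarrow \mu}
&=
\begin{bmatrix} \langle \tau \rangle y_n \\ -\frac{1}{2} \langle \tau \rangle \end{bmatrix},
\qquad
\mathbf{m}_{\mu \rightarrow y_n}
=
\begin{bmatrix} \langle \mu \rangle \\ \langle \mu^2 \rangle \end{bmatrix},

so the child-to-parent message depends on the observed value y_n and the moments of the co-parent \tau (with \langle \tau \rangle = \tau if the precision is fixed), while the parent-to-child message is simply the moments of \mu.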

In order to compute the expectation of the sufficient statistic vector we need to write q(\boldsymbol{\theta}) as

\log q(\boldsymbol{\theta}) &=
\mathbf{u}(\boldsymbol{\theta})^{\mathrm{T}}
\tilde{\boldsymbol{\phi}}
+ \tilde{g}(\tilde{\boldsymbol{\phi}})
+ f(\boldsymbol{\theta}),

where \tilde{\boldsymbol{\phi}} is the natural parameter vector of q(\boldsymbol{\theta}). Now, the expectation of the sufficient statistic vector is defined as

(4)
\langle \mathbf{u}_{\boldsymbol{\theta}} \rangle
&=
- \frac{\partial \tilde{g}}{\partial \tilde{\boldsymbol{\phi}}_{\boldsymbol{\theta}}}
  (\tilde{\boldsymbol{\phi}}_{\boldsymbol{\theta}}).

We call this expectation of the sufficient statistic vector the moments vector.
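
As an illustration, a Gaussian node with sufficient statistics \mathbf{u}(\theta) = [\theta, \theta^2]^{\mathrm{T}} has negative log normalizer \tilde{g}(\tilde{\boldsymbol{\phi}}) = \frac{\tilde{\phi}_1^2}{4 \tilde{\phi}_2} + \frac{1}{2} \log(-2 \tilde{\phi}_2), so (4) gives the moments

\langle \theta \rangle
&=
- \frac{\partial \tilde{g}}{\partial \tilde{\phi}_1}
= - \frac{\tilde{\phi}_1}{2 \tilde{\phi}_2},
\qquad
\langle \theta^2 \rangle
=
- \frac{\partial \tilde{g}}{\partial \tilde{\phi}_2}
= \frac{\tilde{\phi}_1^2}{4 \tilde{\phi}_2^2} - \frac{1}{2 \tilde{\phi}_2},

which are the mean and second moment of a Gaussian with mean -\tilde{\phi}_1 / (2 \tilde{\phi}_2) and variance -1 / (2 \tilde{\phi}_2).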

Lower bound

Computing the VB lower bound is not necessary for finding the posterior approximation, but it is extremely useful for monitoring convergence and detecting possible bugs. The VB lower bound can be written as

\mathcal{L} = \langle \log p(\mathbf{Y}, \mathbf{X}) \rangle - \langle \log
q(\mathbf{X}) \rangle,

where \mathbf{Y} is the set of all observed variables and \mathbf{X} is the set of all latent variables. It can also be written as

\mathcal{L} = \sum_{\mathbf{y} \in \mathbf{Y}} \langle \log p(\mathbf{y} | \mathrm{pa}(\mathbf{y})) \rangle
+ \sum_{\mathbf{x} \in \mathbf{X}} \left[ \langle \log p(\mathbf{x} |
  \mathrm{pa}(\mathbf{x})) \rangle - \langle \log q(\mathbf{x}) \rangle \right],

which shows that observed and latent variables contribute differently to the lower bound. These contributions have simple forms for exponential family nodes. Observed exponential family nodes contribute to the lower bound as follows:

\langle \log p(\mathbf{y}|\mathrm{pa}(\mathbf{y})) \rangle &=
\mathbf{u}(\mathbf{y})^T \langle \boldsymbol{\phi} \rangle
+ \langle g \rangle + f(\mathbf{y}),

where \mathbf{y} is the observed data. On the other hand, latent exponential family nodes contribute to the lower bound as follows:

\langle \log p(\mathbf{x}|\boldsymbol{\theta}) \rangle
- \langle \log q(\mathbf{x}) \rangle &= \langle \mathbf{u} \rangle^T (\langle
\boldsymbol{\phi} \rangle - \tilde{\boldsymbol{\phi}} )
+ \langle g \rangle - \tilde{g}.

If a node is partially observed and partially unobserved, these formulas are applied plate-wise appropriately.
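
Continuing the illustrative Gaussian-mean example, both contributions can be evaluated directly from the quantities defined above (again a plain NumPy sketch with illustrative names, not BayesPy code):

import numpy as np

m0, tau0, tau = 0.0, 1e-3, 2.0
y = np.array([0.8, 1.2, 0.9, 1.1])

# Posterior natural parameters of q(mu) from update (3)
phi_tilde = np.array([tau0 * m0 + tau * y.sum(),
                      -0.5 * (tau0 + len(y) * tau)])
v = -0.5 / phi_tilde[1]              # posterior variance
m = phi_tilde[0] * v                 # posterior mean
u_mu = np.array([m, m**2 + v])       # moments <u> = [<mu>, <mu^2>]

# Latent node mu: <u>^T (<phi> - phi_tilde) + <g> - g_tilde
phi_prior = np.array([tau0 * m0, -0.5 * tau0])
g_prior = -0.5 * tau0 * m0**2 + 0.5 * np.log(tau0)
g_tilde = phi_tilde[0]**2 / (4 * phi_tilde[1]) + 0.5 * np.log(-2 * phi_tilde[1])
L_mu = u_mu @ (phi_prior - phi_tilde) + g_prior - g_tilde

# Observed nodes y_n: u(y)^T <phi> + <g> + f(y)
u_y = np.stack([y, y**2], axis=-1)               # u(y_n) = [y_n, y_n^2]
phi_y = np.array([tau * m, -0.5 * tau])          # <phi> uses <mu>
g_y = -0.5 * tau * u_mu[1] + 0.5 * np.log(tau)   # <g> uses <mu^2>
f_y = -0.5 * np.log(2 * np.pi)
L_y = np.sum(u_y @ phi_y + g_y + f_y)

L = L_mu + L_y                                   # VB lower bound
print(L)

For this conjugate toy model q(\mu) is exact, so the bound equals the log marginal likelihood; in general it is only a lower bound that should increase monotonically over the VMP iterations.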

Terms

To summarize, implementing VMP requires writing the following terms for each stochastic exponential family node (a schematic skeleton of such a node is sketched after these lists):

\langle \boldsymbol{\phi} \rangle : the expectation of the prior natural parameter vector

Computed as a function of the messages from parents.

\tilde{\boldsymbol{\phi}} : natural parameter vector of the posterior approximation

Computed as a sum of \langle \boldsymbol{\phi} \rangle and the messages from children.

\langle \mathbf{u} \rangle : the posterior moments vector

Computed as a function of \tilde{\boldsymbol{\phi}} as defined in (4).

\mathbf{u}(\mathbf{x}) : the moments vector for given data

Computed as a function of the observed data \mathbf{x}.

\langle g \rangle : the expectation of the negative log normalizer of the prior

Computed as a function of parent moments.

\tilde{g} : the negative log normalizer of the posterior approximation

Computed as a function of \tilde{\boldsymbol{\phi}}.

f(\mathbf{x}) : the log base measure for given data

Computed as a function of the observed data \mathbf{x}.

\langle \boldsymbol{\phi}_{\mathbf{x}\rightarrow\boldsymbol{\theta}}
\rangle : the message to parent \boldsymbol{\theta}

Computed as a function of the moments of this node and the other parents.

Deterministic nodes require only the following terms:

\langle \mathbf{u} \rangle : the posterior moments vector

Computed as a function of the messages from the parents.

\mathbf{m} : the message to a parent

Computed as a function of the messages from the other parents and all children.
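
As a rough sketch of how these terms map to code, one might organize a stochastic exponential family node as follows. The class and method names are purely illustrative and do not follow BayesPy's actual internal class structure:

class ExponentialFamilyNode:
    """Hypothetical skeleton of the terms listed above (illustrative only)."""

    def compute_phi_from_parents(self, *parent_moments):
        """<phi>: expectation of the prior natural parameter vector,
        computed from the parents' moments."""
        raise NotImplementedError()

    def compute_moments_and_cgf(self, phi):
        """Return (<u>, g_tilde): the posterior moments (4) and the negative
        log normalizer of q, both as functions of phi_tilde."""
        raise NotImplementedError()

    def compute_fixed_moments_and_f(self, x):
        """Return (u(x), f(x)): moments and log base measure for observed data."""
        raise NotImplementedError()

    def compute_cgf_from_parents(self, *parent_moments):
        """<g>: expectation of the prior negative log normalizer."""
        raise NotImplementedError()

    def compute_message_to_parent(self, index, moments, *parent_moments):
        """<phi_{x -> theta}>: message to the parent at position `index`,
        given this node's moments and the co-parents' moments."""
        raise NotImplementedError()

    def update(self, parent_moments, child_messages):
        """Update q: phi_tilde is <phi> plus the sum of the child messages (3)."""
        self.phi = self.compute_phi_from_parents(*parent_moments)
        for message in child_messages:
            self.phi = self.phi + message
        self.u, self.g = self.compute_moments_and_cgf(self.phi)

A deterministic node would only need the corresponding methods for computing its moments from the parents' messages and for passing a message to a parent given the other parents' and the children's messages.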