# Source code for bayespy.inference.vmp.nodes.concat_gaussian

```import numpy as np

from bayespy.utils import misc
from bayespy.utils import linalg
from .gaussian import GaussianMoments
from .deterministic import Deterministic

class ConcatGaussian(Deterministic):
    """Concatenate Gaussian vectors along the variable axis (not plate axis).

    NOTE: This concatenates on the variable axis! That is, the dimensionality
    of the resulting Gaussian vector is the sum of the dimensionalities of the
    input Gaussian vectors.

    The cross-covariance between two different inputs is taken to be zero:
    their block of the second moment is the outer product of their means,
    as in a factorized variational Bayes approximation.

    TODO: Add support for Gaussian arrays and arbitrary concatenation axis.
    """

    def __init__(self, *nodes, **kwargs):
        """Create a node whose value is the concatenation of ``nodes``.

        Parameters
        ----------
        *nodes
            Gaussian vector nodes (or values accepted by
            ``_ensure_moments``) to concatenate, in order.
        **kwargs
            Passed on to the ``Deterministic`` constructor.

        Raises
        ------
        ValueError
            If some input is not a Gaussian vector (exactly one variable
            axis).
        """
        # TODO: For arbitrary concatenation, first convert each input with
        # ``node._convert(GaussianMoments)`` (keeping non-nodes as they are),
        # take the shape from the ``dims`` of the first Gaussian node,
        # raise ValueError if no shape can be determined, and use its length
        # as ``ndim`` instead of the hard-coded ``ndim=1`` below.
        nodes = [self._ensure_moments(node, GaussianMoments, ndim=1)
                 for node in nodes]

        # Make sure all parents are Gaussian vectors. This check runs before
        # the vector lengths are read, so a bad input gets a clear error.
        # ``dims[0]`` is the shape of the mean (same convention as the
        # ``dims=((D,), (D, D))`` passed to the base class below).
        if any(len(node.dims[0]) != 1 for node in nodes):
            raise ValueError("Input nodes must be (Gaussian) vectors")

        # Block boundaries inside the concatenated vector: input k occupies
        # indices slices[k]:slices[k+1]; the last entry is the total length.
        self.slices = tuple(
            np.cumsum([0] + [node.dims[0][0] for node in nodes])
        )
        D = self.slices[-1]

        self._moments = GaussianMoments((D,))
        self._parent_moments = [node._moments for node in nodes]

        super().__init__(*nodes, dims=((D,), (D, D)), **kwargs)

    def _compute_moments(self, *u_nodes):
        """Compute the moments ``[<x>, <x x^T>]`` of the concatenated vector.

        Parameters
        ----------
        *u_nodes
            One moment list per parent: ``u_nodes[k][0]`` is the mean and
            ``u_nodes[k][1]`` the second moment of the k-th input vector.

        Returns
        -------
        list
            ``[x, xx]``: the concatenated mean and the full second-moment
            matrix (block diagonal plus mean outer products off-diagonal).
        """
        r = self.slices

        x = misc.concatenate(*[u[0] for u in u_nodes], axis=-1)
        xx = misc.block_diag(*[u[1] for u in u_nodes])

        # Explicitly broadcast xx to plates of x. The product is a new array,
        # so the block assignments below do not write into parent moments.
        x_plates = np.shape(x)[:-1]
        xx = np.ones(x_plates)[..., None, None] * xx

        # Compute the cross-covariance terms using the means of each variable
        # (because covariances are zero for factorized nodes in the VB
        # approximation). Each off-diagonal pair is filled symmetrically.
        for a in range(len(u_nodes)):
            rows = slice(r[a], r[a + 1])
            for b in range(a):
                cols = slice(r[b], r[b + 1])
                xa_xb = linalg.outer(u_nodes[a][0], u_nodes[b][0], ndim=1)
                xx[..., rows, cols] = xa_xb
                xx[..., cols, rows] = misc.T(xa_xb)

        return [x, xx]

    def _compute_message_to_parent(self, i, m, *u_nodes):
        """Compute the message from this node to parent ``i``.

        Parameters
        ----------
        i : int
            Index of the parent receiving the message.
        m : list
            Message ``[m0, m1]`` to this node; ``m0`` is indexed on its last
            axis (length D) and ``m1`` on its last two axes (D x D).
        *u_nodes
            Moment lists of all parents (``u[0]`` is the mean).

        Returns
        -------
        list
            ``[m0_i, m1_i]``, the part of the message that concerns parent i.
        """
        r = self.slices
        own = slice(r[i], r[i + 1])

        # Pick the proper parts from the message arrays
        m0 = m[0][..., own]
        m1 = m[1][..., own, own]

        # Handle cross-covariance terms by using the mean of the covariate
        # node. Both blocks (i, j) and (j, i) of m1 contain terms linear in
        # x_i, contributing (M_ij + M_ji^T) <x_j>; for a symmetric m1 this
        # is 2 M_ij <x_j>.
        for j, u in enumerate(u_nodes):
            if j == i:
                continue
            other = slice(r[j], r[j + 1])
            m0 = m0 + np.einsum('...ij,...j->...i', m[1][..., own, other], u[0])
            m0 = m0 + np.einsum('...ji,...j->...i', m[1][..., other, own], u[0])

        return [m0, m1]
```