# Source code for bayespy.inference.vmp.nodes.gate

```################################################################################
# Copyright (C) 2014 Jaakko Luttinen
#
################################################################################

"""
"""

import numpy as np

from bayespy.utils import misc

from .node import Node, Moments
from .deterministic import Deterministic
from .categorical import CategoricalMoments
from .concatenate import Concatenate

class Gate(Deterministic):
    """
    Deterministic gating of one node.

    Gating is performed over one plate axis: the categorical parent ``Z``
    selects an index along the gated plate axis of the parent ``X``, and that
    plate axis is not part of the plates of the resulting node.

    Note: You should not use gating for several variables which are parents of
    the same node if the gates use the same gate assignments.  In such a case,
    the results will be wrong.  The reason is a general one: A stochastic node
    may not be a parent of another node via several paths unless at most one
    path has no other stochastic nodes between them.
    """

    def __init__(self, Z, X, gated_plate=-1, moments=None, **kwargs):
        """
        Constructor for the gating node.

        Parameters
        ----------

        Z : Categorical-like node
            A variable which chooses the index along the gated plate axis.
            Its number of categories must equal the size of the gated plate
            axis of ``X``.

        X : node
            The node whose plate axis is gated

        gated_plate : int (optional)
            The index of the plate axis to be gated (by default, -1, that is,
            the last axis).  Must be negative.

        moments : Moments (optional)
            If given, ``X`` is first converted to a node with moments of this
            class (using the conversion keyword arguments of this instance).
            Must be provided if ``X`` is not already a node.

        **kwargs
            Passed on to the ``Deterministic`` constructor.

        Raises
        ------

        ValueError
            If ``gated_plate`` is not negative, if ``X`` is not a node (after
            the optional conversion), if ``X`` does not have the gated plate
            axis, or if the categories of ``Z`` do not match the size of that
            axis.
        """

        if gated_plate >= 0:
            raise ValueError("Cluster plate must be negative integer")
        self.gated_plate = gated_plate

        if moments is not None:
            X = self._ensure_moments(
                X,
                moments.__class__,
                **moments.get_instance_conversion_kwargs()
            )

        if not isinstance(X, Node):
            raise ValueError("X must be a node or moments should be provided")

        X_moments = X._moments
        # The gated node has the same kind of moments as the gated parent
        self._moments = X_moments
        dims = X.dims
        if len(X.plates) < abs(gated_plate):
            raise ValueError("The gated node does not have the plate axis "
                             "that is gated")
        # The size of the gated plate axis is the number of categories of Z
        K = X.plates[gated_plate]

        Z = self._ensure_moments(Z, CategoricalMoments, categories=K)

        self._parent_moments = (Z._moments, X_moments)

        if Z.dims != ((K,),):
            raise ValueError("Inconsistent number of clusters")

        self.K = K

        super().__init__(Z, X, dims=dims, **kwargs)

    @staticmethod
    def _expand_gate(z0, ndim):
        """
        Return the first moment of Z with its last (category) axis moved
        behind ``ndim`` inserted singleton axes.

        The result has shape ``plates + (1,)*ndim + (K,)`` so that it
        broadcasts against arrays laid out as ``plates + dims + (K,)``.
        """
        z = np.reshape(z0, np.shape(z0) + (1,) * ndim)
        return misc.moveaxis(z, -ndim-1, -1)

    def _compute_moments(self, u_Z, u_X):
        """
        Compute the moments of the gated node.

        For each moment of ``X``, multiply it with the first moment of ``Z``
        and sum over the gated plate axis.  If a moment of ``X`` does not have
        the gated plate axis (i.e., it broadcasts over it), a singleton axis is
        used in its place.
        """
        u = []
        for i, u_Xi in enumerate(u_X):
            # Make the moments of Z and X broadcastable and move the gated plate
            # to be the last axis in the moments, then sum-product over that
            # axis
            ndim = len(self.dims[i])
            z = self._expand_gate(u_Z[0], ndim)
            gated_axis = self.gated_plate - ndim
            if np.ndim(u_Xi) < abs(gated_axis):
                # X broadcasts over the gated plate: use a singleton gate axis
                x = np.expand_dims(u_Xi, -1)
            else:
                x = misc.moveaxis(u_Xi, gated_axis, -1)
            u.append(misc.sum_product(z, x, axes_to_sum=-1))
        return u

    def _compute_message_to_parent(self, index, m_child, u_Z, u_X):
        """
        Compute the message to parent ``index`` (0 for Z, 1 for X).

        For Z, each child message is multiplied with the corresponding moment
        of X and summed over the variable axes, giving a message with the
        gated plate (of size K) as the last axis.  For X, each child message
        is multiplied with the first moment of Z, and the resulting gate axis
        is placed at the position of the gated plate.

        Raises
        ------

        ValueError
            If ``index`` is not 0 or 1.
        """
        if index == 0:
            m0 = 0
            # Compute Child * X, sum over variable axes and move the gated axis
            # to be the last.  Need to do some shape changing in order to make
            # Child and X to broadcast properly.
            for i, m_child_i in enumerate(m_child):
                ndim = len(self.dims[i])
                c = m_child_i[..., None]
                c = misc.moveaxis(c, -1, -ndim-1)
                gated_axis = self.gated_plate - ndim
                x = u_X[i]
                if np.ndim(x) < abs(gated_axis):
                    x = np.expand_dims(x, -ndim-1)
                else:
                    x = misc.moveaxis(x, gated_axis, -ndim-1)
                axes = tuple(range(-ndim, 0))
                m0 = m0 + misc.sum_product(c, x, axes_to_sum=axes)

            # Make sure the variable axis does not use broadcasting
            m0 = m0 * np.ones(self.K)

            # Send the message
            return [m0]

        elif index == 1:

            m = []
            for i, m_child_i in enumerate(m_child):
                # Make the moments of Z and the message from children
                # broadcastable. The gated plate is handled as the last axis in
                # the arrays and moved to the correct position at the end.

                # Add variable axes to Z moments
                ndim = len(self.dims[i])
                z = self._expand_gate(u_Z[0], ndim)
                # Axis index of the gated plate
                gated_axis = self.gated_plate - ndim
                # Add the gate axis to the message from the children
                c = np.expand_dims(m_child_i, -1)
                # Compute the message to parent
                mi = z * c
                # Add extra leading axes if necessary so that the gated axis
                # position exists
                if np.ndim(mi) < abs(gated_axis):
                    nleading = abs(gated_axis) - np.ndim(mi)
                    mi = np.reshape(mi, (1,) * nleading + np.shape(mi))
                # Move the axis to the correct position
                mi = misc.moveaxis(mi, -1, gated_axis)
                m.append(mi)

            return m

        else:
            raise ValueError("Invalid parent index")

    def _compute_weights_to_parent(self, index, weights):
        """
        Map plate weights of this node to parent ``index``.

        Z (index 0) receives the weights unchanged.  For X (index 1) a
        singleton axis is inserted at the gated plate position when
        ``weights`` has enough dimensions for it.
        """
        if index == 0:
            return weights
        elif index == 1:
            if self.gated_plate >= 0:
                raise ValueError("Gated plate axis must be negative")
            return (
                np.expand_dims(weights, axis=self.gated_plate)
                if np.ndim(weights) >= abs(self.gated_plate) else
                weights
            )
        else:
            raise ValueError("Invalid parent index")

    def _compute_plates_to_parent(self, index, plates):
        """
        Return the plates of parent ``index`` given the plates of this node.

        Z (index 0) has the same plates.  For X (index 1) the gated plate axis
        of size ``K`` is inserted back at the gated position.
        """
        if index == 0:
            return plates
        elif index == 1:
            plates = list(plates)
            # Add the cluster plate axis
            if self.gated_plate < 0:
                knd = len(plates) + self.gated_plate + 1
            else:
                raise RuntimeError("Cluster plate axis must be negative")
            plates.insert(knd, self.K)

            return tuple(plates)
        else:
            raise ValueError("Invalid parent index")

    def _compute_plates_from_parent(self, index, plates):
        """
        Return the plates this node gets from parent ``index``.

        Plates of Z (index 0) pass through unchanged.  For X (index 1) the
        gated plate axis is removed if the parent has it.
        """
        if index == 0:
            return plates
        elif index == 1:
            plates = list(plates)
            # Remove the cluster plate, if the parent has it
            if len(plates) >= abs(self.gated_plate):
                plates.pop(self.gated_plate)
            return tuple(plates)
        else:
            raise ValueError("Invalid parent index")

def Choose(z, *nodes):
    """Choose plate elements from nodes based on a categorical variable.

    Plate element ``i`` of the result is taken from ``nodes[z[i]]``.

    For instance:

    .. testsetup::

        from bayespy.nodes import *

    .. code-block:: python

        >>> import bayespy as bp
        >>> z = [0, 0, 2, 1]
        >>> x0 = bp.nodes.GaussianARD(0, 1)
        >>> x1 = bp.nodes.GaussianARD(10, 1)
        >>> x2 = bp.nodes.GaussianARD(20, 1)
        >>> x = bp.nodes.Choose(z, x0, x1, x2)
        >>> print(x.get_moments()[0])
        [ 0.  0. 20. 10.]

    This is basically just a thin wrapper over applying Gate node over the
    concatenation of the nodes.

    Parameters
    ----------

    z : Categorical-like node or array-like
        The choice for each plate element.  It is converted to categorical
        moments with ``len(nodes)`` categories.

    *nodes : nodes
        The nodes to choose from.  At least one node is required.

    Raises
    ------

    ValueError
        If no nodes are given.
    """
    if not nodes:
        raise ValueError("Choose requires at least one node to choose from")
    categories = len(nodes)
    z = Deterministic._ensure_moments(
        z,
        CategoricalMoments,
        categories=categories
    )
    # Give each node an extra trailing axis, combine them with Concatenate
    # and gate over that axis (Gate's default gated_plate=-1).
    expanded = [node[..., None] for node in nodes]
    combined = Concatenate(*expanded)
    return Gate(z, combined)
```