Source code for bayespy.plot

################################################################################
# Copyright (C) 2011-2013 Jaakko Luttinen
#
# This file is licensed under the MIT License.
################################################################################


"""
Functions for plotting nodes.

Functions
=========

.. currentmodule:: bayespy.plot

.. autosummary::
   :toctree: generated/

   pdf
   contour
   plot
   hinton
   gaussian_mixture_2d
   ellipse_from_cov
   ellipse_from_precision

Plotters
========

.. autosummary::
   :toctree: generated/

   Plotter
   PDFPlotter
   ContourPlotter
   HintonPlotter
   FunctionPlotter
   GaussianTimeseriesPlotter
   CategoricalMarkovChainPlotter
"""


import os, sys

############################################################################
# A STUPID WORKAROUND FOR A MATPLOTLIB 1.4.0 BUG RELATED TO INTERACTIVE MODE
# See: https://github.com/matplotlib/matplotlib/issues/3505
import __main__
if hasattr(__main__, '__file__'):
    sys.ps1 = ('WORKAROUND FOR A BUG #3505 IN MATPLOTLIB.\n'
               'IF YOU SEE THIS MESSAGE, TRY MATPLOTLIB!=1.4.0.')
# This workaround does not work on Python shell, only on stand-alone scripts
# and IPython. A better solution: require MPL!=1.4.0.
#############################################################################

import numpy as np
import scipy.sparse as sp
import scipy
from scipy import special
import matplotlib.pyplot as plt
from matplotlib import animation
from matplotlib import colors
from matplotlib import patches
#from matplotlib.pyplot import *

from bayespy.inference.vmp.nodes.categorical import CategoricalMoments
from bayespy.inference.vmp.nodes.gaussian import (GaussianMoments,
                                                  GaussianWishartMoments)
from bayespy.inference.vmp.nodes.beta import BetaMoments
from bayespy.inference.vmp.nodes.beta import DirichletMoments
from bayespy.inference.vmp.nodes.bernoulli import BernoulliMoments
from bayespy.inference.vmp.nodes.categorical import CategoricalMoments
from bayespy.inference.vmp.nodes.gamma import GammaMoments
from bayespy.inference.vmp.nodes.node import Node, Moments

from bayespy.utils import (misc,
                           random,
                           linalg)


# Users can use pyplot via this module
import matplotlib
mpl = matplotlib
pyplot = plt


def interactive(function):
    """A decorator for forcing functions to use the interactive mode.

    Parameters
    ----------

    function : callable
        The function to be decorated
    """

    def new_function(*args, **kwargs):
        if mpl.is_interactive():
            was_interactive = True
        else:
            was_interactive = False
            mpl.interactive(True)

        retval = function(*args, **kwargs)

        if not was_interactive:
            mpl.interactive(False)

        return retval

    return new_function


def _subplots(plotfunc, *args, fig=None, kwargs=None):
    """Create a collection of subplots

    Each subplot is created with the same plotting function.

    Inputs are given as pairs:

    (x, 3), (y, 2), ...

    where x,y,... are the input arrays and 3,2,... are the ndim
    parameters.  The last ndim axes of each array are interpreted as a
    single element to the plotting function.

    All high-level plotting functions should wrap low-level plotting
    functions with this function in order to generate subplots for
    plates.
    """

    if kwargs is None:
        kwargs = {}

    if fig is None:
        fig = plt.gcf()

    # Parse shape and plates of each input array
    shapes = [np.shape(x)[-n:] if n > 0 else ()
              for (x,n) in args]
    plates = [np.shape(x)[:-n] if n > 0 else np.shape(x)
              for (x,n) in args]

    # Get the full grid shape of the subplots
    broadcasted_plates = misc.broadcasted_shape(*plates)

    # Subplot indexing layout
    M = np.prod(broadcasted_plates[-2::-2])
    N = np.prod(broadcasted_plates[-1::-2])
    strides_subplot = [np.prod(broadcasted_plates[(j+2)::2]) * N
                       if ((len(broadcasted_plates)-j) % 2) == 0 else
                       np.prod(broadcasted_plates[(j+2)::2])
                       for j in range(len(broadcasted_plates))]

    # Plot each subplot
    for ind in misc.nested_iterator(broadcasted_plates):

        # Get the list of inputs for this subplot
        broadcasted_args = []
        for n in range(len(args)):
            i = misc.safe_indices(ind, plates[n])
            broadcasted_args.append(args[n][0][i])

        # Plot the subplot using the given function
        ind_subplot = np.einsum('i,i', ind, strides_subplot)
        axes = fig.add_subplot(M, N, ind_subplot+1)
        plotfunc(*broadcasted_args, axes=axes, **kwargs)


[docs]def pdf(Z, x, *args, name=None, axes=None, fig=None, **kwargs): """ Plot probability density function of a scalar variable. Parameters ---------- Z : node or function Stochastic node or log pdf function x : array Grid points """ # TODO: Make it possible to plot a plated variable using _subplots function. if axes is None and fig is None: axes = plt.gca() else: if fig is None: fig = plt.gcf() axes = fig.add_subplot(111) try: lpdf = Z.logpdf(x) except AttributeError: lpdf = Z(x) p = np.exp(lpdf) retval = axes.plot(x, p, *args, **kwargs) if name is None: try: name = Z.name except AttributeError: pass if name: axes.set_title(r'$q(%s)$' % (name)) axes.set_xlabel(r'$%s$' % (name)) return retval
[docs]def contour(Z, x, y, n=None, axes=None, fig=None, **kwargs): """ Plot 2-D probability density function of a 2-D variable. Parameters ---------- Z : node or function Stochastic node or log pdf function x : array Grid points on x axis y : array Grid points on y axis """ # TODO: Make it possible to plot a plated variable using _subplots function. if axes is None and fig is None: axes = plt.gca() else: if fig is None: fig = plt.gcf() axes = fig.add_subplot(111) XY = misc.grid(x, y) try: lpdf = Z.logpdf(XY) except AttributeError: lpdf = Z(XY) p = np.exp(lpdf) shape = (np.size(x), np.size(y)) X = np.reshape(XY[:,0], shape) Y = np.reshape(XY[:,1], shape) P = np.reshape(p, shape) if n is not None: levels = np.linspace(0, np.amax(P), num=n+2)[1:-1] return axes.contour(X, Y, P, levels, **kwargs) else: return axes.contour(X, Y, P, **kwargs)
def plot_gaussian_mc(X, scale=2, **kwargs): """ Plot Gaussian Markov chain as a 1-D function Parameters ---------- X : node Node with Gaussian Markov chain moments. """ timeseries_gaussian(X, axis=-2, scale=scale, **kwargs) def plot_bernoulli(X, axis=-1, scale=2, **kwargs): """ Plot Bernoulli node as a 1-D function """ X = X._ensure_moments(X, BernoulliMoments) u_X = X.get_moments() z = u_X[0] return _timeseries_mean_and_error(z, None, axis=axis, **kwargs) def plot_gaussian(X, axis=-1, scale=2, **kwargs): """ Plot Gaussian node as a 1-D function Parameters ---------- X : node Node with Gaussian moments. axis : int The index of the time axis. """ X = X._ensure_moments(X, GaussianMoments, ndim=0) u_X = X.get_moments() x = u_X[0] xx = misc.get_diag(u_X[1], ndim=len(X.dims[0])) std = scale * np.sqrt(xx - x**2) #std = scale * np.sqrt(np.einsum('...ii->...i', xx) - x**2) return _timeseries_mean_and_error(x, std, axis=axis, **kwargs)
[docs]def plot(Y, axis=-1, scale=2, center=False, **kwargs): """ Plot a variable or an array as 1-D function with errorbars """ if misc.is_numeric(Y): return _timeseries_mean_and_error(Y, None, axis=axis, center=center, **kwargs) if isinstance(Y, Node): # Try Bernoulli plotting try: Y = Y._ensure_moments(Y, BernoulliMoments) except BernoulliMoments.NoConverterError: pass else: return plot_bernoulli(Y, axis=axis, scale=scale, center=center, **kwargs) # Try Gaussian plotting try: Y = Y._ensure_moments(Y, GaussianMoments, ndim=0) except GaussianMoments.NoConverterError: pass else: return plot_gaussian(Y, axis=axis, scale=scale, center=center, **kwargs) (mu, var) = Y.get_mean_and_variance() std = np.sqrt(var) return _timeseries_mean_and_error(mu, std, axis=axis, scale=scale, center=center, **kwargs)
# Some backward compatibility def timeseries_gaussian_mc(*args, center=True, **kwargs): return plot_gaussian_mc(*args, center=center, **kwargs) def timeseries_gaussian(*args, center=True, **kwargs): return plot_gaussian(*args, center=center, **kwargs) timeseries_normal = timeseries_gaussian def timeseries(*args, center=True, **kwargs): return plot(*args, center=center, **kwargs) def _timeseries_mean_and_error(y, std, *args, axis=-1, center=True, fig=None, axes=None, **kwargs): # TODO/FIXME: You must multiply by ones(plates) in order to plot # broadcasted plates properly if fig is None: fig = plt.gcf() y = np.atleast_1d(y) shape = list(np.shape(y)) # Get and remove the length of the time axis T = shape.pop(axis) # Move time axis to first y = np.rollaxis(y, axis) if std is not None: std = np.rollaxis(std, axis) y = np.reshape(y, (T, -1)) if std is not None: std = np.reshape(std, (T, -1)) # Remove 1s shape = [s for s in shape if s > 1] # Calculate number of rows and columns shape = misc.multiply_shapes(shape, (1,1)) if len(shape) > 2: raise Exception("Can plot only in 2 dimensions (rows and columns)") (M, N) = shape # Prefer plotting to rows if M == 1: M = N N = 1 # Plot each timeseries if axes is None: ax0 = fig.add_subplot(M, N, 1) for i in range(M*N): if axes is None: if i > 0: # Share x axis between all subplots ax = fig.add_subplot(M, N, i+1, sharex=ax0) else: ax = ax0 # Autoscale the axes to data and use tight y and x axes ax.autoscale(enable=True, tight=True) ax.set_ylim(auto=True) if i < (M-1)*N: # Remove x tick labels from other than the last row plt.setp(ax.get_xticklabels(), visible=False) else: ax = axes[i] if std is None: errorplot(y=y[:,i], axes=ax, **kwargs) else: if len(args) > 0: raise Exception("Can't handle extra arguments") errorplot(y=y[:,i], error=std[:,i], axes=ax, **kwargs) if center: # Center the zero level on y-axis ylim = ax.get_ylim() vmax = np.max(np.abs(ylim)) ax.set_ylim([-vmax, vmax]) if axes is None: # Remove height space between subplots fig.subplots_adjust(hspace=0) def _blob(axes, x, y, area, colour): """ Draws a square-shaped blob with the given area (< 1) at the given coordinates. """ hs = np.sqrt(area) / 2 xcorners = np.array([x - hs, x + hs, x + hs, x - hs]) ycorners = np.array([y - hs, y - hs, y + hs, y + hs]) axes.fill(xcorners, ycorners, colour, edgecolor=colour) def _rectangle(axes, x, y, width, height, **kwargs): _x = x - width/2 _y = y - height/2 rectangle = plt.Rectangle((_x, _y), width, height, **kwargs) axes.add_patch(rectangle) return
[docs]def gaussian_mixture_2d(X, alpha=None, scale=2, fill=False, axes=None, **kwargs): """ Plot Gaussian mixture as ellipses in 2-D Parameters ---------- X : Mixture node alpha : Dirichlet-like node (optional) Probabilities for the clusters scale : float (optional) Scale for the covariance ellipses (by default, 2) """ if axes is None: axes = plt.gca() mu_Lambda = X._ensure_moments(X.parents[1], GaussianWishartMoments) (mu, _, Lambda, _) = mu_Lambda.get_moments() mu = np.linalg.solve(Lambda, mu) if len(mu_Lambda.plates) != 1: raise NotImplementedError("Not yet implemented for more plates") K = mu_Lambda.plates[0] width = np.zeros(K) height = np.zeros(K) angle = np.zeros(K) for k in range(K): m = mu[k] L = Lambda[k] (u, W) = scipy.linalg.eigh(L) u[0] = np.sqrt(1/u[0]) u[1] = np.sqrt(1/u[1]) width[k] = 2*u[0] height[k] = 2*u[1] angle[k] = np.arctan(W[0,1] / W[0,0]) angle = 180 * angle / np.pi mode_height = 1 / (width * height) # Use cluster probabilities to adjust alpha channel if alpha is not None: # Compute the normalized probabilities in a numerically stable way logsum_p = misc.logsumexp(alpha.u[0], axis=-1, keepdims=True) logp = alpha.u[0] - logsum_p p = np.exp(logp) # Visibility is based on cluster mode peak height visibility = mode_height * p visibility /= np.amax(visibility) else: visibility = np.ones(K) for k in range(K): ell = mpl.patches.Ellipse(mu[k], scale*width[k], scale*height[k], angle=(180+angle[k]), fill=fill, alpha=visibility[k], **kwargs) axes.add_artist(ell) plt.axis('equal') # If observed, plot the data too if np.any(X.observed): mask = np.array(X.observed) * np.ones(X.plates, dtype=np.bool) y = X.u[0][mask] plt.plot(y[:,0], y[:,1], 'r.') return
def _hinton(W, error=None, vmax=None, square=False, axes=None): """ Draws a Hinton diagram for visualizing a weight matrix. Temporarily disables matplotlib interactive mode if it is on, otherwise this takes forever. Originally copied from http://wiki.scipy.org/Cookbook/Matplotlib/HintonDiagrams """ if axes is None: axes = plt.gca() W = misc.atleast_nd(W, 2) (height, width) = W.shape if not vmax: #vmax = 2**np.ceil(np.log(np.max(np.abs(W)))/np.log(2)) if error is not None: vmax = np.max(np.abs(W) + error) else: vmax = np.max(np.abs(W)) axes.fill(0.5+np.array([0,width,width,0]), 0.5+np.array([0,0,height,height]), 'gray') if square: axes.set_aspect('equal') axes.set_ylim(0.5, height+0.5) axes.set_xlim(0.5, width+0.5) axes.set_xticks([]) axes.set_yticks([]) axes.invert_yaxis() for x in range(width): for y in range(height): _x = x+1 _y = y+1 w = W[y,x] _w = np.abs(w) if w > 0: _c = 'white' else: _c = 'black' if error is not None: e = error[y,x] if e < 0: print(e, _w, vmax) raise Exception("BUG? Negative error") if _w + e > vmax: print(e, _w, vmax) raise Exception("BUG? Value+error greater than max") _rectangle(axes, _x, _y, min(1, np.sqrt((_w+e)/vmax)), min(1, np.sqrt((_w+e)/vmax)), edgecolor=_c, fill=False) _blob(axes, _x, _y, min(1, _w/vmax), _c) def matrix(A, axes=None): if axes is None: axes = plt.gca() A = np.atleast_2d(A) vmax = np.max(np.abs(A)) return axes.imshow(A, interpolation='nearest', cmap='RdBu_r', vmin=-vmax, vmax=vmax) def new_matrix(A, vmax=None): A = np.atleast_2d(A) if vmax is None: vmax = np.max(np.abs(A)) (M, N) = np.shape(A) for i in range(M): for j in range(N): pass def gaussian_hinton(X, rows=None, cols=None, scale=1, fig=None): """ Plot the Hinton diagram of a Gaussian node """ if fig is None: fig = plt.gcf() # Get mean and second moment X = X._ensure_moments(X, GaussianMoments, ndim=0) (x, xx) = X.get_moments() ndim = len(X.dims[0]) shape = X.get_shape(0) size = len(X.get_shape(0)) # Compute standard deviation xx = misc.get_diag(xx, ndim=ndim) std = np.sqrt(xx - x**2) # Force explicit elements when broadcasting x = x * np.ones(shape) std = std * np.ones(shape) if rows is None: rows = np.nan if cols is None: cols = np.nan # Preprocess the axes to 0,...,ndim if rows < 0: rows += size if cols < 0: cols += size if rows < 0 or rows >= size: raise ValueError("Row axis invalid") if cols < 0 or cols >= size: raise ValueError("Column axis invalid") # Remove non-row and non-column axes that have length 1 squeezed_shape = list(shape) for i in reversed(range(len(shape))): if shape[i] == 1 and i != rows and i != cols: squeezed_shape.pop(i) if i < cols: cols -= 1 if i < rows: rows -= 1 x = np.reshape(x, squeezed_shape) std = np.reshape(std, squeezed_shape) if np.ndim(x) < 2: cols += 2 - np.ndim(x) rows += 2 - np.ndim(x) x = np.atleast_2d(x) std = np.atleast_2d(std) size = np.ndim(x) if np.isnan(cols): if rows != size - 1: cols = size - 1 else: cols = size - 2 if np.isnan(rows): if cols != size - 1: rows = size - 1 else: rows = size - 2 # Put the row and column axes to the end axes = [i for i in range(size) if i not in (rows, cols)] + [rows, cols] x = np.transpose(x, axes=axes) std = np.transpose(std, axes=axes) vmax = np.max(np.abs(x) + scale*std) if scale == 0: _subplots(_hinton, (x, 2), fig=fig, kwargs=dict(vmax=vmax)) else: def plotfunc(z, e, **kwargs): return _hinton(z, error=e, **kwargs) _subplots(plotfunc, (x, 2), (scale*std, 2), fig=fig, kwargs=dict(vmax=vmax)) def _hinton_figure(x, rows=None, cols=None, fig=None, square=True): """ Plot the Hinton diagram of a Gaussian node """ scale = 0 std = 0 if fig is None: fig = plt.gcf() # Get mean and second moment shape = np.shape(x) size = np.ndim(x) if rows is None: rows = np.nan if cols is None: cols = np.nan # Preprocess the axes to 0,...,ndim if rows < 0: rows += size if cols < 0: cols += size if rows < 0 or rows >= size: raise ValueError("Row axis invalid") if cols < 0 or cols >= size: raise ValueError("Column axis invalid") # Remove non-row and non-column axes that have length 1 squeezed_shape = list(shape) for i in reversed(range(len(shape))): if shape[i] == 1 and i != rows and i != cols: squeezed_shape.pop(i) if i < cols: cols -= 1 if i < rows: rows -= 1 x = np.reshape(x, squeezed_shape) size = np.ndim(x) if np.isnan(cols): if rows != size - 1: cols = size - 1 else: cols = size - 2 if np.isnan(rows): if cols != size - 1: rows = size - 1 else: rows = size - 2 # Put the row and column axes to the end if np.ndim(x) >= 2: axes = [i for i in range(size) if i not in (rows, cols)] + [rows, cols] x = np.transpose(x, axes=axes) #std = np.transpose(std, axes=axes) vmax = np.max(np.abs(x) + scale*std) kw = dict(vmax=vmax, square=square) if scale == 0: _subplots(_hinton, (x, 2), fig=fig, kwargs=kw) else: def plotfunc(z, e, **kwargs): return _hinton(z, error=e, **kwargs) _subplots(plotfunc, (x, 2), (scale*std, 2), fig=fig, kwargs=kw) # For backwards compatibility: gaussian_array = gaussian_hinton def timeseries_categorical_mc(Z, fig=None): if fig is None: fig = plt.gcf() # Make sure that the node is categorical Z = Z._ensure_moments(Z, CategoricalMoments, categories=None) # Get expectations (and broadcast explicitly) z = Z._message_to_child()[0] * np.ones(Z.get_shape(0)) # Compute the subplot layout z = misc.atleast_nd(z, 4) if np.ndim(z) != 4: raise ValueError("Can not plot arrays with over 4 axes") M = np.shape(z)[0] N = np.shape(z)[1] # Plot Hintons for i in range(M): for j in range(N): axes = fig.add_subplot(M, N, i*N+j+1) _hinton(z[i,j].T, vmax=1.0, square=False, axes=axes) def gamma_hinton(alpha, square=True, **kwargs): """ Plot a beta distributed random variable as a Hinton diagram """ # Make sure that the node is beta alpha = alpha._ensure_moments(alpha, GammaMoments) # Compute exp( <log p> ) x = alpha.get_moments()[0] # Explicit broadcasting x = x * np.ones(alpha.plates) # Plot Hinton diagram return _hinton_figure(x, square=square, **kwargs) def beta_hinton(P, square=True): """ Plot a beta distributed random variable as a Hinton diagram """ # Make sure that the node is beta P = P._ensure_moments(P, BetaMoments) # Compute exp( <log p> ) p = np.exp(P._message_to_child()[0][...,0]) # Explicit broadcasting p = p * np.ones(P.plates) # Plot Hinton diagram return _hinton(p, vmax=1.0, square=square) def dirichlet_hinton(P, square=True): """ Plot a beta distributed random variable as a Hinton diagram """ # Make sure that the node is beta P = P._ensure_moments(P, DirichletMoments) # Compute exp( <log p> ) p = np.exp(P._message_to_child()[0]) # Explicit broadcasting p = p * np.ones(P.plates+(1,)) # Plot Hinton diagram return _hinton(p, vmax=1.0, square=square) def bernoulli_hinton(Z, square=True): """ Plot a Bernoulli distributed random variable as a Hinton diagram """ # Make sure that the node is Bernoulli Z = Z._ensure_moments(Z, BernoulliMoments) # Get <Z> z = Z._message_to_child()[0] # Explicit broadcasting z = z * np.ones(Z.plates) # Plot Hinton diagram return _hinton(z, vmax=1.0, square=square) def categorical_hinton(Z, square=True): """ Plot a Bernoulli distributed random variable as a Hinton diagram """ # Make sure that the node is Bernoulli Z = Z._ensure_moments(Z, CategoricalMoments, categories=None) # Get <Z> z = Z._message_to_child()[0] # Explicit broadcasting z = z * np.ones(Z.plates+(1,)) # Plot Hinton diagram return _hinton(np.squeeze(z), vmax=1.0, square=square)
[docs]def hinton(X, **kwargs): r""" Plot the Hinton diagram of a node The keyword arguments depend on the node type. For some node types, the diagram also shows uncertainty with non-filled rectangles. Currently, beta-like, Gaussian-like and Dirichlet-like nodes are supported. Parameters ---------- X : node """ if hasattr(X, "_ensure_moments"): try: X = X._ensure_moments(X, GaussianMoments, ndim=0) except Moments.NoConverterError: pass else: return gaussian_hinton(X, **kwargs) try: X = X._ensure_moments(X, GammaMoments) except Moments.NoConverterError: pass else: return gamma_hinton(X, **kwargs) try: X = X._ensure_moments(X, BetaMoments) except Moments.NoConverterError: pass else: return beta_hinton(X, **kwargs) try: X = X._ensure_moments(X, DirichletMoments) except Moments.NoConverterError: pass else: return dirichlet_hinton(X, **kwargs) try: X = X._ensure_moments(X, BernoulliMoments) except Moments.NoConverterError: pass else: return bernoulli_hinton(X, **kwargs) try: X = X._ensure_moments(X, CategoricalMoments, categories=None) except Moments.NoConverterError: pass else: return categorical_hinton(X, **kwargs) return _hinton_figure(X, **kwargs)
[docs]class Plotter(): r""" Wrapper for plotting functions and base class for node plotters The purpose of this class is to collect all the parameters needed by a plotting function and provide a callable interface which needs only the node as the input. Plotter instances are callable objects that plot a given node using a specified plotting function. Parameters ---------- plotter : function Plotting function to use args : defined by the plotting function Additional inputs needed by the plotting function kwargs : defined by the plotting function Additional keyword arguments supported by the plotting function Examples -------- First, create a gamma variable: >>> import numpy as np >>> from bayespy.nodes import Gamma >>> x = Gamma(4, 5) The probability density function can be plotted as: >>> import bayespy.plot as bpplt >>> bpplt.pdf(x, np.linspace(0.1, 10, num=100)) # doctest: +ELLIPSIS [<matplotlib.lines.Line2D object at 0x...>] However, this can be problematic when one needs to provide a plotting function for the inference engine as the inference engine gives only the node as input. Thus, we need to create a simple plotter wrapper: >>> p = bpplt.Plotter(bpplt.pdf, np.linspace(0.1, 10, num=100)) Now, this callable object ``p`` needs only the node as the input: >>> p(x) # doctest: +ELLIPSIS [<matplotlib.lines.Line2D object at 0x...>] Thus, it can be given to the inference engine to use as a plotting function: >>> x = Gamma(4, 5, plotter=p) >>> x.plot() # doctest: +ELLIPSIS [<matplotlib.lines.Line2D object at 0x...>] """
[docs] def __init__(self, plotter, *args, **kwargs): self._args = args self._kwargs = kwargs self._plotter = plotter
def __call__(self, X, fig=None, **kwargs): """ Plot the node using the specified plotting function Parameters ---------- X : node The plotted node """ kwargs_all = self._kwargs.copy() kwargs_all.update(kwargs) return self._plotter(X, *self._args, fig=fig, **kwargs_all)
[docs]class PDFPlotter(Plotter): r""" Plotter of probability density function of a scalar node Parameters ---------- x_grid : array Numerical grid on which the density function is computed and plotted See also -------- pdf """
[docs] def __init__(self, x_grid, **kwargs): super().__init__(pdf, x_grid, **kwargs)
[docs]class ContourPlotter(Plotter): r""" Plotter of probability density function of a two-dimensional node Parameters ---------- x1_grid : array Grid for the first dimension x2_grid : array Grid for the second dimension See also -------- contour """
[docs] def __init__(self, x1_grid, x2_grid, **kwargs): super().__init__(contour, x1_grid, x2_grid, **kwargs)
[docs]class HintonPlotter(Plotter): r""" Plotter of the Hinton diagram of a node See also -------- hinton """
[docs] def __init__(self, **kwargs): super().__init__(hinton, **kwargs)
[docs]class FunctionPlotter(Plotter): r""" Plotter of a node as a 1-dimensional function See also -------- plot """
[docs] def __init__(self, **kwargs): super().__init__(plot, **kwargs)
class GaussianMarkovChainPlotter(Plotter): r""" Plotter of a Gaussian Markov chain as a timeseries """ def __init__(self, **kwargs): super().__init__(timeseries_gaussian_mc, **kwargs)
[docs]class GaussianTimeseriesPlotter(Plotter): r""" Plotter of a Gaussian node as a timeseries """
[docs] def __init__(self, **kwargs): super().__init__(timeseries_gaussian, **kwargs)
class GaussianHintonPlotter(Plotter): r""" Plotter of a Gaussian node as a Hinton diagram """ def __init__(self, **kwargs): super().__init__(gaussian_array, **kwargs)
[docs]class CategoricalMarkovChainPlotter(Plotter): r""" Plotter of a Categorical timeseries """
[docs] def __init__(self, **kwargs): super().__init__(timeseries_categorical_mc, **kwargs)
def matrix_animation(A, filename=None, fps=25, fig=None, **kwargs): if fig is None: fig = plt.gcf() axes = fig.add_subplot(111) A = np.atleast_3d(A) vmax = np.max(np.abs(A)) x = axes.imshow(A[0], interpolation='nearest', cmap='RdBu_r', vmin=-vmax, vmax=vmax, **kwargs) s = axes.set_title('t = %d' % 0) def animate(nframe): s.set_text('t = %d' % nframe) x.set_array(A[nframe]) return (x, s) anim = animation.FuncAnimation(fig, animate, frames=np.shape(A)[0], interval=1000/fps, blit=False, repeat=False) return anim def save_animation(anim, filename, fps=25, bitrate=5000, fig=None): # A bug in numpy/matplotlib causes this not to work in python3.3: # https://github.com/matplotlib/matplotlib/issues/1891 # # So the following command does not work currently.. # # anim.save(filename, fps=fps) if fig is None: fig = plt.gcf() writer = animation.FFMpegFileWriter(fps=fps, bitrate=bitrate) writer.setup(fig, filename, 100) anim.save(filename, fps=fps, writer=writer, bitrate=bitrate) return def binary_matrix(A, axes=None): if axes is None: axes = plt.gca() A = np.atleast_2d(A) G = np.zeros(np.shape(A) + (3,)) G[A] = [0,0,0] G[np.logical_not(A)] = [1,1,1] axes.imshow(G, interpolation='nearest') def gaussian_mixture_logpdf(x, w, mu, Sigma): # Shape(x) = (N, D) # Shape(w) = (K,) # Shape(mu) = (K, D) # Shape(Sigma) = (K, D, D) # Shape(result) = (N,) # Dimensionality D = np.shape(x)[-1] # Cholesky decomposition of the covariance matrix U = linalg.chol(Sigma) # Reshape x: # Shape(x) = (N, 1, D) x = np.expand_dims(x, axis=-2) # (x-mu) and (x-mu)'*inv(Sigma)*(x-mu): # Shape(v) = (N, K, D) # Shape(z) = (N, K) v = x - mu z = np.einsum('...i,...i', v, linalg.chol_solve(U, v)) # Log-determinant of Sigma: # Shape(ldet) = (K,) ldet = linalg.chol_logdet(U) # Compute log pdf for each cluster: # Shape(lpdf) = (N, K) lpdf = misc.gaussian_logpdf(z, 0, 0, ldet, D) def matrixplot(A, colorbar=False, axes=None): if axes is None: axes = plt.gca() if sp.issparse(A): A = A.toarray() axes.imshow(A, interpolation='nearest') if colorbar: plt.colorbar(ax=axes) def contourplot(x1, x2, y, colorbar=False, filled=True, axes=None): """ Plots 2D contour plot. x1 and x2 are 1D vectors, y contains the function values. y.size must be x1.size*x2.size. """ if axes is None: axes = plt.gca() y = np.reshape(y, (len(x2),len(x1))) if filled: axes.contourf(x1, x2, y) else: axes.contour(x1, x2, y) if colorbar: plt.colorbar(ax=axes) def errorplot(y=None, error=None, x=None, lower=None, upper=None, color=(0,0,0,1), fillcolor=None, axes=None, **kwargs): if axes is None: axes = plt.gca() # Default inputs if x is None: x = np.arange(np.size(y)) # Parse errors (lower=lower/error/upper, upper=upper/error/lower) if lower is None: if error is not None: lower = error elif upper is not None: lower = upper if upper is None: if error is not None: upper = error elif lower is not None: upper = lower # Plot errors if (lower is not None) and (upper is not None): l = y - lower u = y + upper if fillcolor is None: color = colors.ColorConverter().to_rgba(color) fillcolor = tuple(color[:3]) + (0.2 * color[3],) axes.fill_between(x, l, u, facecolor=fillcolor, edgecolor=(0, 0, 0, 0), linewidth=1, interpolate=True) # Plot function axes.plot(x, y, color=color, **kwargs) def plotmatrix(X): """ Creates a matrix of marginal plots. On diagonal, are marginal plots of each variable. Off-diagonal plot (i,j) shows the joint marginal density of x_i and x_j. """ return X.plotmatrix() def _pdf_t(mu, s2, nu, axes=None, scale=4, color='k'): """ """ if axes is None: axes = plt.gca() s = np.sqrt(s2) x = np.linspace(mu-scale*s, mu+scale*s, num=100) y2 = (x-mu)**2 / s2 lpdf = random.t_logpdf(y2, np.log(s2), nu, 1) p = np.exp(lpdf) return axes.plot(x, p, color=color) def _pdf_gamma(a, b, axes=None, scale=4, color='k'): """ """ if axes is None: axes = plt.gca() if np.size(a) != 1 or np.size(b) != 1: raise ValueError("Parameters must be scalars") mean = a/b v = scale*np.sqrt(a/b**2) m = max(0, mean-v) n = mean + v x = np.linspace(m, n, num=100) logx = np.log(x) lpdf = random.gamma_logpdf(b*x, logx, a*logx, a*np.log(b), special.gammaln(a)) p = np.exp(lpdf) return axes.plot(x, p, color=color) def _contour_t(mu, Cov, nu, axes=None, scale=4, transpose=False, colors='k'): """ """ if axes is None: axes = plt.gca() if np.shape(mu) != (2,) or np.shape(Cov) != (2,2) or np.shape(nu) != (): print(np.shape(mu), np.shape(Cov), np.shape(nu)) raise ValueError("Only 2-d t-distribution allowed") if transpose: mu = mu[[1,0]] Cov = Cov[np.ix_([1,0],[1,0])] s = np.sqrt(np.diag(Cov)) x0 = np.linspace(mu[0]-scale*s[0], mu[0]+scale*s[0], num=100) x1 = np.linspace(mu[1]-scale*s[1], mu[1]+scale*s[1], num=100) X0X1 = misc.grid(x0, x1) Y = X0X1 - mu L = linalg.chol(Cov) logdet_Cov = linalg.chol_logdet(L) Z = linalg.chol_solve(L, Y) Z = linalg.inner(Y, Z, ndim=1) lpdf = random.t_logpdf(Z, logdet_Cov, nu, 2) p = np.exp(lpdf) shape = (np.size(x0), np.size(x1)) X0 = np.reshape(X0X1[:,0], shape) X1 = np.reshape(X0X1[:,1], shape) P = np.reshape(p, shape) return axes.contour(X0, X1, P, colors=colors) def _contour_gaussian_gamma(mu, s2, a, b, axes=None, transpose=False): """ """ pass
[docs]def ellipse_from_cov(xy, cov, scale=2, **kwargs): """ Create an ellipse from a covariance matrix. Parameters ---------- xy : np.ndarray position of the ellipse cov : np.ndarray covariance matrix scale : float scale of the ellipse (default is three standard deviations) kwargs : dict keyword arguments passed on to `matplotlib.patches.Ellipse` Returns ------- ellipse : matplotlib.patches.Ellipse """ evals, evecs = np.linalg.eigh(cov) angle = np.arctan2(*evecs[::-1, 0]) width, height = scale * np.sqrt(evals) return patches.Ellipse(xy, width, height, np.rad2deg(angle), **kwargs)
[docs]def ellipse_from_precision(xy, precision, scale=2, **kwargs): """ Create an ellipse from a covariance matrix. Parameters ---------- xy : np.ndarray position of the ellipse cov : np.ndarray covariance matrix scale : float scale of the ellipse (default is three standard deviations) kwargs : dict keyword arguments passed on to `matplotlib.patches.Ellipse` Returns ------- ellipse : matplotlib.patches.Ellipse """ return ellipse_from_cov(xy, np.linalg.inv(precision), scale, **kwargs)