{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "%matplotlib inline"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\nGibbs Sampling\n==============\n\nThis example presents an illustration of the MLFM to learn the model\n\n\\begin{align}\\dot{\\mathbf{x}}(t)\\end{align}\n\nWe do the usual imports and generate some simulated data\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import numpy as np\nimport matplotlib.pyplot as plt\nfrom pydygp.probabilitydistributions import (Normal,\n                                             GeneralisedInverseGaussian,\n                                             ChiSquare,\n                                             Gamma,\n                                             InverseGamma)\nfrom sklearn.gaussian_process.kernels import RBF\nfrom pydygp.liealgebras import so\nfrom pydygp.linlatentforcemodels import GibbsMLFMAdapGrad\n\nnp.random.seed(15)\n\n\ngmlfm = GibbsMLFMAdapGrad(so(3), R=1, lf_kernels=(RBF(), ))\n\nbeta = np.row_stack(([0.]*3,\n                     np.random.normal(size=3)))\n\nx0 = np.eye(3)\n\n# Time points to solve the model at\ntt = np.linspace(0., 6., 9)\n\n# Data and true forces\nData, lf = gmlfm.sim(x0, tt, beta=beta, size=3)\n\n# vectorise and stack the data\nY = np.column_stack((y.T.ravel() for y in Data))\n\nlogpsi_prior = GeneralisedInverseGaussian(a=5, b=5, p=-1).logtransform()\nloggamma_prior = Gamma(a=2.00, b=10.0).logtransform() * gmlfm.dim.K\nbeta_prior = Normal(scale=1.) * beta.size\n\nfitopts = {'logpsi_is_fixed': True, 'logpsi_prior': logpsi_prior,\n           'loggamma_is_fixed': False, 'loggamma_prior': loggamma_prior,\n           'beta_is_fixed': False, 'beta_prior': beta_prior,\n           'beta0': beta,\n           }\n\nnsample = 100\ngibbsRV = gmlfm.gibbsfit(tt, Y,\n                         sample=('g', 'beta', 'x'),\n                         size=nsample,\n                         **fitopts)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Learning the Coefficient Matrix\n-------------------------------\n\nThe goal in fitting models of dynamic systems is to learn the dynamics,\nand more subtly learn the dynamics of the model independent of the\nstate variables.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "aijRV = []\nfor g, b in zip(gibbsRV['g'], gibbsRV['beta']):\n    _beta = b.reshape((2, 3))\n    aijRV.append(gmlfm._component_functions(g, _beta))\naijRV = np.array(aijRV)\n\n# True component functions\nttd = np.linspace(0., tt[-1], 100)\naaTrue = gmlfm._component_functions(lf[0](ttd), beta, N=ttd.size)\n\n# Make some plots\ninds = [(0, 1), (0, 2), (1, 2)]\nsymbs = ['+', '+', '+']\ncolors = ['slateblue', 'peru', 'darkseagreen']\n\nfig = plt.figure()\nfor nt, (ind, symb) in enumerate(zip(inds, symbs)):\n\n    i, j = ind\n\n    ax = fig.add_subplot(1, 3, nt+1,\n                         adjustable='box', aspect=5.)\n    ax.plot(ttd, aaTrue[i, j, :], alpha=0.8,\n            label=r\"$a^*_{{ {}{} }}$\".format(i+1, j+1),\n            color=colors[nt])\n    ax.plot(tt, aijRV[:, i, j, :].T, 'k' + symb, alpha=0.1)\n\n    ax.set_title(r\"$a_{{ {}{} }}$\".format(i+1, j+1))\n    ax.set_ylim((-.7, .7))\n    ax.legend()\n\nfor i in range(gibbsRV['beta'].shape[-1]):\n    fig, ax = plt.subplots()\n    ax.hist(gibbsRV['beta'][:, i], density=True)\n\nplt.show()"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.7.0"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}