{ "cells": [ { "cell_type": "markdown", "metadata": { "colab_type": "text", "execution": {}, "id": "view-in-github" }, "source": [ "\"Open   \"Open" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "# Tutorial 2: Hidden Markov Model\n", "\n", "**Week 3, Day 2: Hidden Dynamics**\n", "\n", "**By Neuromatch Academy**\n", "\n", "**Content creators:** Yicheng Fei with help from Jesse Livezey and Xaq Pitkow\n", "\n", "**Content reviewers:** John Butler, Matt Krause, Meenakshi Khosla, Spiros Chavlis, Michael Waskom\n", "\n", "**Production editors:** Ella Batty, Gagana B, Spiros Chavlis" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "---\n", "# Tutorial objectives\n", "\n", "*Estimated timing of tutorial: 1 hour, 5 minutes*\n", "\n", "The world around us is often changing, but we only have noisy sensory measurements. Similarly, neural systems switch between discrete states (e.g. sleep/wake) which are observable only indirectly, through their impact on neural activity. **Hidden Markov Models** (HMM) let us reason about these unobserved (also called hidden or latent) states using a time series of measurements.\n", "\n", "Here we'll learn how changing the HMM's transition probability and measurement noise impacts the data. We'll look at how uncertainty increases as we predict the future, and how to gain information from the measurements.\n", "\n", "We will use a binary latent variable $s_t \\in \\{0,1\\}$ that switches randomly between the two states, and a 1D Gaussian emission model $m_t|s_t \\sim \\mathcal{N}(\\mu_{s_t},\\sigma^2_{s_t})$ that provides evidence about the current state.\n", "\n", "By the end of this tutorial, you should be able to:\n", "- Describe how the hidden states in a Hidden Markov model evolve over time, both in words, mathematically, and in code\n", "- Estimate hidden states from data using forward inference in a Hidden Markov model\n", "- Describe how measurement noise and state transition probabilities affect uncertainty in predictions in the future and the ability to estimate hidden states.\n", "\n", "
\n", "\n", "**Summary of Exercises**\n", "1. Generate data from an HMM.\n", "2. Calculate how predictions propagate in a Markov Chain without evidence.\n", "3. Combine new evidence and prediction from past evidence to estimate hidden states." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "remove-input" ] }, "outputs": [], "source": [ "# @markdown\n", "from IPython.display import IFrame\n", "from ipywidgets import widgets\n", "out = widgets.Output()\n", "with out:\n", " print(f\"If you want to download the slides: https://osf.io/download/zsfbn/\")\n", " display(IFrame(src=f\"https://mfr.ca-1.osf.io/render?url=https://osf.io/zsfbn/?direct%26mode=render%26action=download%26mode=render\", width=730, height=410))\n", "display(out)" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "---\n", "# Setup" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Install and import feedback gadget\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Install and import feedback gadget\n", "\n", "!pip3 install vibecheck datatops --quiet\n", "\n", "from vibecheck import DatatopsContentReviewContainer\n", "def content_review(notebook_section: str):\n", " return DatatopsContentReviewContainer(\n", " \"\", # No text prompt\n", " notebook_section,\n", " {\n", " \"url\": \"https://pmyvdlilci.execute-api.us-east-1.amazonaws.com/klab\",\n", " \"name\": \"neuromatch_cn\",\n", " \"user_key\": \"y1x3mpx5\",\n", " },\n", " ).render()\n", "\n", "\n", "feedback_prefix = \"W3D2_T2\"" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "execution": {} }, "outputs": [], "source": [ "# Imports\n", "import numpy as np\n", "import time\n", "from scipy import stats\n", "from scipy.optimize import linear_sum_assignment\n", "from collections import namedtuple\n", "import matplotlib.pyplot as plt\n", "from matplotlib import patches" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Figure Settings\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Figure Settings\n", "import logging\n", "logging.getLogger('matplotlib.font_manager').disabled = True\n", "\n", "import ipywidgets as widgets # interactive display\n", "from ipywidgets import interactive, interact, HBox, Layout,VBox\n", "from IPython.display import HTML\n", "%config InlineBackend.figure_format = 'retina'\n", "plt.style.use(\"https://raw.githubusercontent.com/NeuromatchAcademy/course-content/NMA2020/nma.mplstyle\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Plotting Functions\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Plotting Functions\n", "\n", "def plot_hmm1(model, states, measurements, flag_m=True):\n", " \"\"\"Plots HMM states and measurements for 1d states and measurements.\n", "\n", " Args:\n", " model (hmmlearn model): hmmlearn model used to get state means.\n", " states (numpy array of floats): Samples of the states.\n", " measurements (numpy array of floats): Samples of the states.\n", " \"\"\"\n", " T = states.shape[0]\n", " nsteps = states.size\n", " aspect_ratio = 2\n", " fig, ax1 
= plt.subplots(figsize=(8,4))\n",
 "  states_forplot = list(map(lambda s: model.means[s], states))\n",
 "  ax1.step(np.arange(nsteps), states_forplot, \"-\", where=\"mid\", alpha=1.0, c=\"green\")\n",
 "  ax1.set_xlabel(\"Time\")\n",
 "  ax1.set_ylabel(\"Latent State\", c=\"green\")\n",
 "  ax1.set_yticks([-1, 1])\n",
 "  ax1.set_yticklabels([\"-1\", \"+1\"])\n",
 "  ax1.set_xticks(np.arange(0,T,10))\n",
 "  ymin = min(measurements)\n",
 "  ymax = max(measurements)\n",
 "\n",
 "  ax2 = ax1.twinx()\n",
 "  ax2.set_ylabel(\"Measurements\", c=\"crimson\")\n",
 "\n",
 "  # show measurement gaussian\n",
 "  if flag_m:\n",
 "    ax2.plot([T,T], ax2.get_ylim(), color=\"maroon\", alpha=0.6)\n",
 "    for i in range(model.n_components):\n",
 "      mu = model.means[i]\n",
 "      scale = np.sqrt(model.vars[i])\n",
 "      rv = stats.norm(mu, scale)\n",
 "      num_points = 50\n",
 "      domain = np.linspace(mu-3*scale, mu+3*scale, num_points)\n",
 "\n",
 "      left = np.repeat(float(T), num_points)\n",
 "      # left = np.repeat(0.0, num_points)\n",
 "      offset = rv.pdf(domain)\n",
 "      offset *= T / 15\n",
 "      lbl = \"measurement\" if i == 0 else \"\"\n",
 "      # ax2.fill_betweenx(domain, left, left-offset, alpha=0.3, lw=2, color=\"maroon\", label=lbl)\n",
 "      ax2.fill_betweenx(domain, left+offset, left, alpha=0.3, lw=2, color=\"maroon\", label=lbl)\n",
 "  ax2.scatter(np.arange(nsteps), measurements, c=\"crimson\", s=4)\n",
 "  ax2.legend(loc=\"upper left\")\n",
 "  ax1.set_ylim(ax2.get_ylim())\n",
 "  plt.show(fig)\n",
 "\n",
 "\n",
 "def plot_marginal_seq(predictive_probs, switch_prob):\n",
 "  \"\"\"Plots the sequence of marginal predictive distributions.\n",
 "\n",
 "  Args:\n",
 "    predictive_probs (list of numpy vectors): sequence of predictive probability vectors\n",
 "    switch_prob (float): Probability of switching states.\n",
 "  \"\"\"\n",
 "  T = len(predictive_probs)\n",
 "  prob_neg = [p_vec[0] for p_vec in predictive_probs]\n",
 "  prob_pos = [p_vec[1] for p_vec in predictive_probs]\n",
 "  fig, ax = plt.subplots()\n",
 "  ax.plot(np.arange(T), prob_neg, color=\"blue\")\n",
 "  ax.plot(np.arange(T), prob_pos, color=\"orange\")\n",
 "  ax.legend([\n",
 "    \"prob in state -1\", \"prob in state 1\"\n",
 "  ])\n",
 "  ax.text(T/2, 0.05, \"switching probability={}\".format(switch_prob), fontsize=12,\n",
 "          bbox=dict(boxstyle=\"round\", facecolor=\"wheat\", alpha=0.6))\n",
 "  ax.set_xlabel(\"Time\")\n",
 "  ax.set_ylabel(\"Probability\")\n",
 "  ax.set_title(\"Forgetting curve in a changing world\")\n",
 "  plt.show(fig)\n",
 "\n",
 "\n",
 "def plot_evidence_vs_noevidence(posterior_matrix, predictive_probs):\n",
 "  \"\"\"Plots the average posterior probabilities with evidence vs. 
no evidence\n", "\n", " Args:\n", " posterior_matrix: (2d numpy array of floats): The posterior probabilities in state 1 from evidence (samples, time)\n", " predictive_probs (numpy array of floats): Predictive probabilities in state 1 without evidence\n", " \"\"\"\n", " nsample, T = posterior_matrix.shape\n", " posterior_mean = posterior_matrix.mean(axis=0)\n", " fig, ax = plt.subplots(1)\n", " ax.plot([0.0, T], [0., 0.], color=\"red\", linestyle=\"dashed\")\n", " ax.plot(np.arange(T), predictive_probs, c=\"orange\", linewidth=2, label=\"No evidence\")\n", " ax.scatter(np.tile(np.arange(T), (nsample, 1)), posterior_matrix, s=0.8,\n", " c=\"green\", alpha=0.3, label=\"With evidence(Sample)\")\n", " ax.plot(np.arange(T), posterior_mean, c='green',\n", " linewidth=2, label=\"With evidence(Average)\")\n", " ax.legend()\n", " ax.set_yticks([0.0, 0.25, 0.5, 0.75, 1.0])\n", " ax.set_xlabel(\"Time\")\n", " ax.set_ylabel(\"Probability in State +1\")\n", " ax.set_title(\"Gain confidence with evidence\")\n", " plt.show(fig)\n", "\n", "\n", "def plot_forward_inference(model, states, measurements, states_inferred,\n", " predictive_probs, likelihoods, posterior_probs,\n", " t=None, flag_m=True, flag_d=True, flag_pre=True,\n", " flag_like=True, flag_post=True):\n", " \"\"\"Plot ground truth state sequence with noisy measurements, and ground truth states v.s. inferred ones\n", "\n", " Args:\n", " model (instance of hmmlearn.GaussianHMM): an instance of HMM\n", " states (numpy vector): vector of 0 or 1(int or Bool), the sequences of true latent states\n", " measurements (numpy vector of numpy vector): the un-flattened Gaussian measurements at each time point, element has size (1,)\n", " states_inferred (numpy vector): vector of 0 or 1(int or Bool), the sequences of inferred latent states\n", " \"\"\"\n", " T = states.shape[0]\n", " if t is None:\n", " t = T-1\n", " nsteps = states.size\n", " fig, ax1 = plt.subplots(figsize=(11,6))\n", " # true states\n", " states_forplot = list(map(lambda s: model.means[s], states))\n", " ax1.step(np.arange(nstep)[:t+1], states_forplot[:t+1], \"-\", where=\"mid\", alpha=1.0, c=\"green\", label=\"true\")\n", " ax1.step(np.arange(nstep)[t+1:], states_forplot[t+1:], \"-\", where=\"mid\", alpha=0.3, c=\"green\", label=\"\")\n", " # Posterior curve\n", " delta = model.means[1] - model.means[0]\n", " states_interpolation = model.means[0] + delta * posterior_probs[:,1]\n", " if flag_post:\n", " ax1.step(np.arange(nstep)[:t+1], states_interpolation[:t+1], \"-\", where=\"mid\", c=\"grey\", label=\"posterior\")\n", "\n", " ax1.set_xlabel(\"Time\")\n", " ax1.set_ylabel(\"Latent State\", c=\"green\")\n", " ax1.set_yticks([-1, 1])\n", " ax1.set_yticklabels([\"-1\", \"+1\"])\n", " ax1.legend(bbox_to_anchor=(0,1.02,0.2,0.1), borderaxespad=0, ncol=2)\n", "\n", " ax2 = ax1.twinx()\n", " ax2.set_ylim(\n", " min(-1.2, np.min(measurements)),\n", " max(1.2, np.max(measurements))\n", " )\n", " if flag_d:\n", " ax2.scatter(np.arange(nstep)[:t+1], measurements[:t+1], c=\"crimson\", s=4, label=\"measurement\")\n", " ax2.set_ylabel(\"Measurements\", c=\"crimson\")\n", "\n", " # show measurement distributions\n", " if flag_m:\n", " for i in range(model.n_components):\n", " mu = model.means[i]\n", " scale = np.sqrt(model.vars[i])\n", " rv = stats.norm(mu, scale)\n", " num_points = 50\n", " domain = np.linspace(mu-3*scale, mu+3*scale, num_points)\n", "\n", " left = np.repeat(float(T), num_points)\n", " offset = rv.pdf(domain)\n", " offset *= T /15\n", " lbl = \"\"\n", " ax2.fill_betweenx(domain, 
left+offset, left, alpha=0.3, lw=2, color=\"maroon\", label=lbl)\n",
 "  ymin, ymax = ax2.get_ylim()\n",
 "  width = 0.1 * (ymax-ymin) / 2.0\n",
 "  centers = [-1.0, 1.0]\n",
 "  bar_scale = 15\n",
 "\n",
 "  # Predictions\n",
 "  data = predictive_probs\n",
 "  if flag_pre:\n",
 "    for i in range(model.n_components):\n",
 "      domain = np.array([centers[i]-1.5*width, centers[i]-0.5*width])\n",
 "      left = np.array([t,t])\n",
 "      offset = np.array([data[t,i]]*2)\n",
 "      offset *= bar_scale\n",
 "      lbl = \"today's prior\" if i == 0 else \"\"\n",
 "      ax2.fill_betweenx(domain, left+offset, left, alpha=0.3, lw=2, color=\"dodgerblue\", label=lbl)\n",
 "\n",
 "  # Likelihoods\n",
 "  # normalize for display without modifying the caller's likelihoods in place\n",
 "  data = likelihoods / np.sum(likelihoods, axis=-1, keepdims=True)\n",
 "  if flag_like:\n",
 "    for i in range(model.n_components):\n",
 "      domain = np.array([centers[i]+0.5*width, centers[i]+1.5*width])\n",
 "      left = np.array([t,t])\n",
 "      offset = np.array([data[t,i]]*2)\n",
 "      offset *= bar_scale\n",
 "      lbl = \"likelihood\" if i == 0 else \"\"\n",
 "      ax2.fill_betweenx(domain, left+offset, left, alpha=0.3, lw=2, color=\"crimson\", label=lbl)\n",
 "  # Posteriors\n",
 "  data = posterior_probs\n",
 "  if flag_post:\n",
 "    for i in range(model.n_components):\n",
 "      domain = np.array([centers[i]-0.5*width, centers[i]+0.5*width])\n",
 "      left = np.array([t,t])\n",
 "      offset = np.array([data[t,i]]*2)\n",
 "      offset *= bar_scale\n",
 "      lbl = \"posterior\" if i == 0 else \"\"\n",
 "      ax2.fill_betweenx(domain, left+offset, left, alpha=0.3, lw=2, color=\"grey\", label=lbl)\n",
 "  if t