{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "(parallel)=\n",
    "\n",
    "# Parallelization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "tags": [
     "hide-cell"
    ]
   },
   "outputs": [],
   "source": [
    "%config InlineBackend.figure_format = \"retina\"\n",
    "\n",
    "from matplotlib import rcParams\n",
    "\n",
    "rcParams[\"savefig.dpi\"] = 100\n",
    "rcParams[\"figure.dpi\"] = 100\n",
    "rcParams[\"font.size\"] = 20\n",
    "\n",
    "import multiprocessing\n",
    "\n",
    "multiprocessing.set_start_method(\"fork\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    ":::{note}\n",
    "Some builds of NumPy (including the one shipped with Anaconda) automatically parallelize some operations using a multithreaded linear algebra library such as MKL. This can cause problems when combined with the parallelization methods described here, so it can be a good idea to turn that off (by setting the environment variable `OMP_NUM_THREADS=1`, for example).\n",
    ":::"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "os.environ[\"OMP_NUM_THREADS\"] = \"1\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "With emcee, it's easy to make use of multiple CPUs to speed up slow sampling.\n",
    "Parallelization always introduces some computational overhead, so it is only beneficial when the model is expensive to evaluate, but this is often the case for real research problems.\n",
    "All parallelization techniques are accessed using the `pool` keyword argument of the {class}`EnsembleSampler` class but, depending on your system and your model, there are a few pool options that you can choose from.\n",
    "In general, a `pool` is any Python object with a `map` method that can be used to apply a function to a list of NumPy arrays.\n",
    "Below, we will discuss a few options."
   ]
  },
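  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the `map` requirement concrete, here is a minimal sketch of a pool-like object. The `SerialPool` class is hypothetical (it is not part of emcee) and simply wraps the built-in `map`, so it adds no parallelism at all; it only illustrates the interface that a pool is expected to provide.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "class SerialPool:\n",
    "    # Hypothetical pool-like object: the only thing it offers is a map method\n",
    "    def map(self, func, iterable):\n",
    "        # Apply func to every item in the current process, preserving order\n",
    "        return list(map(func, iterable))\n",
    "\n",
    "\n",
    "def log_prob(theta):\n",
    "    return -0.5 * np.sum(theta**2)\n",
    "\n",
    "\n",
    "pool = SerialPool()\n",
    "coords = [np.zeros(5), np.ones(5)]  # a list of NumPy arrays, one per walker\n",
    "results = pool.map(log_prob, coords)\n",
    "print(results)\n",
    "```\n",
    "\n",
    "Passing `pool=SerialPool()` to the sampler should then behave like the serial case; a real pool distributes the same `map` call across several processes instead."
   ]
  },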
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In all of the examples below, we'll test the code with the following contrived model:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def log_prob(theta):\n",
    "    # Busy-wait for a random 5-8 ms to mimic an expensive model evaluation\n",
    "    t = time.time() + np.random.uniform(0.005, 0.008)\n",
    "    while True:\n",
    "        if time.time() >= t:\n",
    "            break\n",
    "    # Log-density of a standard normal (up to an additive constant)\n",
    "    return -0.5 * np.sum(theta**2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This probability function busy-waits for a random 5 to 8 milliseconds every time it is called.\n",
    "This is meant to emulate a more realistic situation where the model is expensive to evaluate.\n",
    "\n",
    "To start, let's sample the usual (serial) way:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [00:21<00:00,  4.71it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Serial took 21.5 seconds\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "import emcee\n",
    "\n",
    "np.random.seed(42)\n",
    "initial = np.random.randn(32, 5)\n",
    "nwalkers, ndim = initial.shape\n",
    "nsteps = 100\n",
    "\n",
    "sampler = emcee.EnsembleSampler(nwalkers, ndim, log_prob)\n",
    "start = time.time()\n",
    "sampler.run_mcmc(initial, nsteps, progress=True)\n",
    "end = time.time()\n",
    "serial_time = end - start\n",
    "print(\"Serial took {0:.1f} seconds\".format(serial_time))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Multiprocessing\n",
    "\n",
    "The simplest method of parallelizing emcee is to use the [multiprocessing module from the standard library](https://docs.python.org/3/library/multiprocessing.html).\n",
    "To parallelize the above sampling, you could update the code as follows:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [00:06<00:00, 15.65it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Multiprocessing took 6.5 seconds\n",
      "3.3 times faster than serial\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "from multiprocessing import Pool\n",
    "\n",
    "with Pool() as pool:\n",
    "    sampler = emcee.EnsembleSampler(nwalkers, ndim, log_prob, pool=pool)\n",
    "    start = time.time()\n",
    "    sampler.run_mcmc(initial, nsteps, progress=True)\n",
    "    end = time.time()\n",
    "    multi_time = end - start\n",
    "    print(\"Multiprocessing took {0:.1f} seconds\".format(multi_time))\n",
    "    print(\"{0:.1f} times faster than serial\".format(serial_time / multi_time))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "I have 4 cores on the machine where this is being tested:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "4 CPUs\n"
     ]
    }
   ],
   "source": [
    "from multiprocessing import cpu_count\n",
    "\n",
    "ncpu = cpu_count()\n",
    "print(\"{0} CPUs\".format(ncpu))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We don't quite get the factor of 4 decrease in runtime that you might expect because there is some overhead in the parallelization, but we get pretty close in this example (a factor of about 3.3), and the speedup should get even closer to 4 for more expensive models."
   ]
  },
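  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "One way to quantify how close we get is the parallel efficiency, that is, the measured speedup divided by the number of CPUs. The sketch below uses the timings printed above as hard-coded assumptions; they will differ from machine to machine and from run to run.\n",
    "\n",
    "```python\n",
    "# Timings copied from the outputs above (in seconds); treat them as examples\n",
    "serial_time = 21.5\n",
    "multi_time = 6.5\n",
    "ncpu = 4\n",
    "\n",
    "speedup = serial_time / multi_time\n",
    "efficiency = speedup / ncpu  # 1.0 would mean perfect scaling\n",
    "print('speedup: {0:.1f}, efficiency: {1:.0%}'.format(speedup, efficiency))\n",
    "```\n",
    "\n",
    "An efficiency well below 1 suggests that the overhead is comparable to the cost of the model itself."
   ]
  },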
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## MPI\n",
    "\n",
    "Multiprocessing can only be used for distributing calculations across processors on one machine.\n",
    "If you want to take advantage of a bigger cluster, you'll need to use MPI.\n",
    "In that case, you need to execute the code using the `mpiexec` executable, so this demo is slightly more convoluted.\n",
    "For this example, we'll write the code to a file called `script.py` and then execute it using MPI, but when you really use the MPI pool, you'll probably just want to edit the script directly.\n",
    "To run this example, you'll first need to install [the schwimmbad library](https://github.com/adrn/schwimmbad) because emcee no longer includes its own `MPIPool`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "MPI took 8.9 seconds\n",
      "2.4 times faster than serial\n"
     ]
    }
   ],
   "source": [
    "with open(\"script.py\", \"w\") as f:\n",
    "    f.write(\n",
    "        \"\"\"\n",
    "import sys\n",
    "import time\n",
    "import emcee\n",
    "import numpy as np\n",
    "from schwimmbad import MPIPool\n",
    "\n",
    "def log_prob(theta):\n",
    "    t = time.time() + np.random.uniform(0.005, 0.008)\n",
    "    while True:\n",
    "        if time.time() >= t:\n",
    "            break\n",
    "    return -0.5*np.sum(theta**2)\n",
    "\n",
    "with MPIPool() as pool:\n",
    "    if not pool.is_master():\n",
    "        pool.wait()\n",
    "        sys.exit(0)\n",
    "        \n",
    "    np.random.seed(42)\n",
    "    initial = np.random.randn(32, 5)\n",
    "    nwalkers, ndim = initial.shape\n",
    "    nsteps = 100\n",
    "\n",
    "    sampler = emcee.EnsembleSampler(nwalkers, ndim, log_prob, pool=pool)\n",
    "    start = time.time()\n",
    "    sampler.run_mcmc(initial, nsteps)\n",
    "    end = time.time()\n",
    "    print(end - start)\n",
    "\"\"\"\n",
    "    )\n",
    "\n",
    "mpi_time = !mpiexec -n {ncpu} python script.py\n",
    "mpi_time = float(mpi_time[0])\n",
    "print(\"MPI took {0:.1f} seconds\".format(mpi_time))\n",
    "print(\"{0:.1f} times faster than serial\".format(serial_time / mpi_time))"
   ]
  },
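  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "There is often more overhead introduced by MPI than by multiprocessing, so we get less of a gain this time.\n",
    "Note also that the master process in this script runs the sampler and hands out work, so only the remaining processes evaluate `log_prob`.\n",
    "That being said, MPI is much more flexible and it can be used to scale to huge systems."
   ]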
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Pickling, data transfer & arguments\n",
    "\n",
    "All parallel Python implementations work by spinning up multiple `python` processes with identical environments and then passing information between the processes using `pickle`.\n",
    "This means that the probability function [must be picklable](https://docs.python.org/3/library/pickle.html#pickle-picklable).\n",
    "\n",
    "Some users might hit issues when they use `args` to pass data to their model.\n",
    "These args must be pickled and passed to the worker processes every time the model is called.\n",
    "This can be a problem if you have a large dataset, as you can see here:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [00:21<00:00,  4.70it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Serial took 21.5 seconds\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "def log_prob_data(theta, data):\n",
    "    a = data[0]  # Use the data somehow...\n",
    "    t = time.time() + np.random.uniform(0.005, 0.008)\n",
    "    while True:\n",
    "        if time.time() >= t:\n",
    "            break\n",
    "    return -0.5 * np.sum(theta**2)\n",
    "\n",
    "\n",
    "data = np.random.randn(5000, 200)\n",
    "\n",
    "sampler = emcee.EnsembleSampler(nwalkers, ndim, log_prob_data, args=(data,))\n",
    "start = time.time()\n",
    "sampler.run_mcmc(initial, nsteps, progress=True)\n",
    "end = time.time()\n",
    "serial_data_time = end - start\n",
    "print(\"Serial took {0:.1f} seconds\".format(serial_data_time))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We basically get no change in performance when we include the `data` argument here.\n",
    "Now let's try including this naively using multiprocessing:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [01:05<00:00,  1.52it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Multiprocessing took 66.0 seconds\n",
      "0.3 times faster(?) than serial\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "with Pool() as pool:\n",
    "    sampler = emcee.EnsembleSampler(\n",
    "        nwalkers, ndim, log_prob_data, pool=pool, args=(data,)\n",
    "    )\n",
    "    start = time.time()\n",
    "    sampler.run_mcmc(initial, nsteps, progress=True)\n",
    "    end = time.time()\n",
    "    multi_data_time = end - start\n",
    "    print(\"Multiprocessing took {0:.1f} seconds\".format(multi_data_time))\n",
    "    print(\n",
    "        \"{0:.1f} times faster(?) than serial\".format(\n",
    "            serial_data_time / multi_data_time\n",
    "        )\n",
    "    )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Brutal.\n",
    "\n",
    "We can do better than that though.\n",
    "It's a bit ugly, but if we just make `data` a global variable and use that variable within the model calculation, then we take no hit at all."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [00:06<00:00, 14.60it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Multiprocessing took 6.9 seconds\n",
      "3.1 times faster than serial\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "def log_prob_data_global(theta):\n",
    "    a = data[0]  # Use the data somehow...\n",
    "    t = time.time() + np.random.uniform(0.005, 0.008)\n",
    "    while True:\n",
    "        if time.time() >= t:\n",
    "            break\n",
    "    return -0.5 * np.sum(theta**2)\n",
    "\n",
    "\n",
    "with Pool() as pool:\n",
    "    sampler = emcee.EnsembleSampler(\n",
    "        nwalkers, ndim, log_prob_data_global, pool=pool\n",
    "    )\n",
    "    start = time.time()\n",
    "    sampler.run_mcmc(initial, nsteps, progress=True)\n",
    "    end = time.time()\n",
    "    multi_data_global_time = end - start\n",
    "    print(\n",
    "        \"Multiprocessing took {0:.1f} seconds\".format(multi_data_global_time)\n",
    "    )\n",
    "    print(\n",
    "        \"{0:.1f} times faster than serial\".format(\n",
    "            serial_data_time / multi_data_global_time\n",
    "        )\n",
    "    )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "That's better!\n",
    "This works because, in the global variable case, each worker process receives the dataset only once, when the pool is created (with the `fork` start method used here, the workers simply inherit it from the parent process), instead of having it pickled and passed for every model evaluation."
   ]
  }
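  ,
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To get a rough sense of why passing the dataset through `args` is so costly, the sketch below measures how many bytes `pickle` produces for an array with the same shape as `data`. This only illustrates the size of the payload; it does not measure the time spent transferring it between processes.\n",
    "\n",
    "```python\n",
    "import pickle\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "# Same shape as the dataset used above\n",
    "data = np.random.randn(5000, 200)\n",
    "payload = pickle.dumps(data)\n",
    "\n",
    "print('array size: {0:.1f} MB'.format(data.nbytes / 1e6))\n",
    "print('pickled size: {0:.1f} MB'.format(len(payload) / 1e6))\n",
    "```\n",
    "\n",
    "With `args=(data,)`, a payload of roughly this size has to be serialized and sent along with the model evaluations, which is consistent with the slowdown seen above."
   ]
  }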
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
