The jax-smfsb tutorial
----------------------

This tutorial assumes that the package has already been installed, following the instructions in the `package readme `__.

We begin with non-spatial stochastic simulation.

Non-spatial simulation
----------------------

Standard algorithms for simulating the (stochastic) dynamics of biochemical networks assume that the system is well-mixed, and that spatial effects can be reasonably ignored.

Using a model built-in to the library
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

First, let's see how to simulate a built-in (Lotka-Volterra predator-prey) model:

.. code:: python

   import jax
   import jsmfsb

   lvmod = jsmfsb.models.lv()
   step = lvmod.step_gillespie()  # transition kernel based on the Gillespie algorithm
   k0 = jax.random.key(42)
   # simulate from time 0 to 30, recording the state every 0.1 time units
   out = jsmfsb.sim_time_series(k0, lvmod.m, 0, 30, 0.1, step)
   assert out.shape == (300, 2)

Here we used the ``lv`` model. Other built-in models include ``id`` (immigration-death), ``bd`` (birth-death), ``dimer`` (dimerisation kinetics), ``mm`` (Michaelis-Menten enzyme kinetics) and ``sir`` (SIR epidemic model). The models are of class ``Spn`` (stochastic Petri net), the main data type used in the package. Note the use of the ``step_gillespie`` method, defined on all ``Spn`` models, which returns a function for simulating from the transition kernel of the model, using the Gillespie algorithm. This function can be used with the ``sim_time_series`` function for simulating model trajectories on a regular time grid.

Note that all stochastic simulation functions in this package take a `JAX random number key `__ as their first argument. JAX uses an explicit, splittable random number generator.

Alternative simulation algorithms include ``step_poisson`` (Poisson time-stepping), ``step_cle`` (Euler-Maruyama simulation from the associated chemical Langevin equation) and ``step_euler`` (Euler simulation from the continuous deterministic approximation to the model).

If you have ``matplotlib`` installed (``pip install matplotlib``), then you can also plot the results with:

.. code:: python

   import matplotlib.pyplot as plt

   fig, axis = plt.subplots()
   for i in range(2):
       axis.plot(range(out.shape[0]), out[:,i])

   axis.legend(lvmod.n)  # species names
   fig.savefig("lv.pdf")

Standard Python docstring documentation is available. Usage information can be obtained from the Python REPL with commands like ``help(jsmfsb.Spn)``, ``help(jsmfsb.Spn.step_gillespie)`` or ``help(jsmfsb.sim_time_series)``. This documentation is also available on `ReadTheDocs `__. The API documentation contains minimal usage examples.

Creating and simulating a model
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Next, let's create and simulate our own (SIR epidemic) model by specifying a stochastic Petri net ``Spn`` object explicitly. We must provide species and reaction names, stoichiometry matrices, reaction rates and initial conditions. This time we use approximate Poisson simulation rather than exact simulation via the Gillespie algorithm.

.. code:: python

   import jax.numpy as jnp

   sir = jsmfsb.Spn(
       ["S", "I", "R"],                # species names
       ["S->I", "I->R"],               # reaction names
       [[1,1,0], [0,1,0]],             # reactant stoichiometries
       [[0,2,0], [0,0,1]],             # product stoichiometries
       lambda x, t: jnp.array([0.3*x[0]*x[1]/200, 0.1*x[1]]),  # reaction rates
       [197.0, 3, 0])                  # initial conditions
   step_sir = sir.step_poisson()
   sample = jsmfsb.sim_sample(k0, 500, sir.m, 0, 20, step_sir)

   fig, axis = plt.subplots()
   axis.hist(sample[:,1], 30)
   axis.set_title("Infected at time 20")
   fig.savefig("sIr.pdf")

Here, rather than simulating a time series trajectory, we instead simulate a sample of 500 values from the transition kernel at time 20 using ``sim_sample``.

Reading and parsing models in SBML and SBML-shorthand
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Note that you can read in `SBML `__ or `SBML-shorthand `__ models that have been designed for discrete stochastic simulation into a stochastic Petri net directly. To read and parse an SBML model, use

.. code:: python

   m = jsmfsb.file_to_spn("myModel.xml")

Note that if you are working with SBML models in Python using `libsbml `__, then there is also a function ``model_to_spn`` which takes a libsbml model object.

To read and parse an SBML-shorthand model, use

.. code:: python

   m = jsmfsb.mod_to_spn("myModel.mod")

There is also a function ``shorthand_to_spn`` which expects a Python string containing a shorthand model. This is convenient for embedding shorthand models inside Python scripts, and is particularly convenient when working with things like Jupyter notebooks. Below follows a complete session to illustrate the idea by creating and simulating a realisation from a discrete stochastic SEIR model.

.. code:: python

   import jax
   import jsmfsb
   import jax.numpy as jnp

   seir_sh = """
   @model:3.1.1=SEIR "SEIR Epidemic model"
    s=item, t=second, v=litre, e=item
   @compartments
    Pop
   @species
    Pop:S=100 s
    Pop:E=0 s
    Pop:I=5 s
    Pop:R=0 s
   @reactions
   @r=Infection
    S + I -> E + I
    beta*S*I : beta=0.1
   @r=Transition
    E -> I
    sigma*E : sigma=0.2
   @r=Removal
    I -> R
    gamma*I : gamma=0.5
   """

   seir = jsmfsb.shorthand_to_spn(seir_sh)
   step_seir = seir.step_gillespie()
   k0 = jax.random.key(42)
   out = jsmfsb.sim_time_series(k0, seir.m, 0, 40, 0.05, step_seir)

   import matplotlib.pyplot as plt
   fig, axis = plt.subplots()
   for i in range(len(seir.m)):
       axis.plot(jnp.arange(0, 40, 0.05), out[:,i])

   axis.legend(seir.n)
   fig.savefig("seir.pdf")

A `collection of appropriate models `__ is associated with the book.

Spatial simulation
------------------

In addition to methods such as ``step_gillespie`` and ``step_cle`` for well-mixed simulation, ``Spn`` objects also have methods such as ``step_gillespie_1d`` and ``step_cle_2d`` for 1d and 2d spatially explicit simulation of reaction-diffusion processes on a regular grid. These functions expect to be passed an array containing the diffusion coefficient for each species.
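For intuition about what a per-species diffusion coefficient does, the following is a minimal pure-Python sketch of one deterministic diffusion step on a 1d grid. It is not the library's implementation: the spatial kernels in the package are stochastic and also apply reactions, and the conventions used here (periodic boundaries, a coefficient ``d`` acting as the rate of movement to each neighbouring cell, and the helper name ``diffuse_1d``) are assumptions made for illustration only.

```python
def diffuse_1d(state, d, dt):
    """One explicit Euler step of discrete diffusion on a periodic 1d grid.

    state: list of rows, one row of grid values per species (hypothetical layout)
    d:     list of diffusion coefficients, one per species
    dt:    time step
    """
    new_state = []
    for row, coef in zip(state, d):
        n = len(row)
        # Each cell gains from its two neighbours and loses to both of them
        new_row = [
            row[j] + coef * dt * (row[(j - 1) % n] + row[(j + 1) % n] - 2 * row[j])
            for j in range(n)
        ]
        new_state.append(new_row)
    return new_state


N = 20
# Two species, all mass initially in the central cell
x0 = [[0.0] * N for _ in range(2)]
x0[0][N // 2] = 50.0
x0[1][N // 2] = 100.0

# Species 0 diffuses, species 1 does not (coefficient zero)
x1 = diffuse_1d(x0, [0.6, 0.0], 0.1)
print(x1[0][N // 2 - 1 : N // 2 + 2])
```

Setting a coefficient to zero leaves that species fixed in space, while a positive coefficient spreads it into neighbouring cells without changing its total amount.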
1d simulation
~~~~~~~~~~~~~

For 1d simulation, the state is a matrix with rows representing the levels of a given species on a 1d grid. The 1d transition kernels will update such a state. The function ``sim_time_series_1d`` will return a 3d array, with 2d slices representing the state at each time point. Slicing on the first index shows the spatio-temporal evolution of a given species.

.. code:: python

   import jsmfsb
   import jax
   import jax.numpy as jnp
   import matplotlib.pyplot as plt
   import jsmfsb.models

   N = 20  # number of grid cells
   T = 30  # final time
   x0 = jnp.zeros((2, N))
   lv = jsmfsb.models.lv()
   # place the initial population in the central cell
   x0 = x0.at[:, int(N / 2)].set(lv.m)
   k0 = jax.random.key(42)
   step_lv_1d = lv.step_gillespie_1d(jnp.array([0.6, 0.6]))
   x1 = step_lv_1d(k0, x0, 0, 1)
   print(x1)
   out = jsmfsb.sim_time_series_1d(k0, x0, 0, T, 1, step_lv_1d, True)

   fig, axis = plt.subplots()
   for i in range(2):
       axis.imshow(out[i, :, :])
       axis.set_title(lv.n[i])
       fig.savefig(f"step_gillespie_1d{i}.pdf")

2d simulation
~~~~~~~~~~~~~

For 2d simulation, the state is a 3d array containing the levels of each species on a 2d grid. The 2d transition kernels will update such a state. Slicing on the first index will show the 2d spatial distribution of a given species.

.. code:: python

   import jsmfsb
   import jax
   import jax.numpy as jnp
   import matplotlib.pyplot as plt
   import jsmfsb.models

   M = 200  # grid rows
   N = 250  # grid columns
   T = 30   # final time
   x0 = jnp.zeros((2, M, N))
   lv = jsmfsb.models.lv()
   # place the initial population in the central cell
   x0 = x0.at[:, int(M / 2), int(N / 2)].set(lv.m)
   step_lv_2d = lv.step_cle_2d(jnp.array([0.6, 0.6]), 0.1)
   k0 = jax.random.key(42)
   x1 = step_lv_2d(k0, x0, 0, T)

   fig, axis = plt.subplots()
   for i in range(2):
       axis.imshow(x1[i, :, :])
       axis.set_title(lv.n[i])
       fig.savefig(f"step_cle_2df{i}.pdf")

Note that on fine 2d grids, approximate simulation using ``step_cle_2d`` is typically much faster than exact simulation from the reaction-diffusion master equation (RDME) using ``step_gillespie_2d``.
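One way to see why exact simulation is slower is to count the work involved. The Gillespie algorithm simulates every individual reaction event, whereas a time-stepping scheme such as the CLE takes a fixed number of steps regardless of how many events occur. The sketch below is a plain-Python Gillespie direct method for a single well-mixed Lotka-Volterra cell, written only to count events. It is not the library's code: the function name ``count_gillespie_events`` is hypothetical, and the rate constants ``(1.0, 0.005, 0.6)`` and initial state ``(50, 100)`` are assumed here, taken from values that appear elsewhere in this tutorial.

```python
import random


def count_gillespie_events(x, t_end, th, rng):
    """Run the Gillespie direct method for a Lotka-Volterra model and
    return the number of reaction events simulated before t_end."""
    prey, pred = x
    t = 0.0
    events = 0
    while True:
        # Hazards: prey birth, predation, predator death
        h = [th[0] * prey, th[1] * prey * pred, th[2] * pred]
        h0 = sum(h)
        if h0 <= 0.0:
            break  # both species extinct, nothing more can happen
        t += rng.expovariate(h0)  # exponential waiting time to next event
        if t >= t_end:
            break
        # Choose a reaction with probability proportional to its hazard
        u = rng.random() * h0
        if u < h[0]:
            prey += 1
        elif u < h[0] + h[1]:
            prey -= 1
            pred += 1
        else:
            pred -= 1
        events += 1
    return events


rng = random.Random(42)
th = (1.0, 0.005, 0.6)  # assumed rate constants
events = count_gillespie_events((50, 100), 1.0, th, rng)
cle_steps = round(1.0 / 0.1)  # fixed-step scheme with dt = 0.1 over the same interval
print(events, cle_steps)
```

With these assumed values the total hazard at the initial state is 135 per unit time, so the exact algorithm processes on the order of a hundred events where the fixed-step scheme takes ten steps. On a spatial grid this cost applies to every occupied cell, and diffusion events add further work.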
Bayesian parameter inference
----------------------------

In addition to providing tools for forward-simulation from stochastic kinetic models, the library also provides tools for conducting Bayesian parameter inference for stochastic kinetic models based on observed time course data. For example, given an observed (noisy) trajectory of one or more species from a given model, we may want to find rate constants that are most consistent with the observed data. The methods provided are simulation-based, or likelihood-free, based on either `approximate Bayesian computation `__ (ABC) or (bootstrap) `particle marginal Metropolis-Hastings `__ (PMMH) particle MCMC.

ABC
~~~

In a very basic version of ABC, a candidate parameter vector is drawn from a prior distribution. This parameter vector is used in conjunction with a forward-simulation algorithm for the model of interest in order to generate a synthetic data set. This synthetic data set is compared against the real data set. If they are sufficiently "close", the originally sampled parameter vector will be kept as a sample from the posterior distribution, otherwise it will be rejected, and the process will start again. The function ``abc_run`` helps to scaffold this process. A complete example using simple Euclidean distance between the real and synthetic trajectories is presented below.

.. code:: python

   import jsmfsb
   import jax
   import jax.numpy as jnp
   import matplotlib.pyplot as plt

   data = jsmfsb.data.lv_perfect[:, 1:3]

   # sample a parameter vector from the prior (uniform on the log scale)
   def rpr(k):
       k1, k2, k3 = jax.random.split(k, 3)
       return jnp.exp(
           jnp.array(
               [
                   jax.random.uniform(k1, minval=-3, maxval=3),
                   jax.random.uniform(k2, minval=-8, maxval=-2),
                   jax.random.uniform(k3, minval=-4, maxval=2),
               ]
           )
       )

   # simulate a synthetic data set for a given parameter vector
   def rmod(k, th):
       return jsmfsb.sim_time_series(
           k, jnp.array([50.0, 100.0]), 0, 30, 2, jsmfsb.models.lv(th).step_cle(0.1)
       )

   def sum_stats(dat):
       return dat

   ssd = sum_stats(data)

   # Euclidean distance from the observed data
   def dist(ss):
       diff = ss - ssd
       return jnp.sqrt(jnp.sum(diff * diff))

   def rdis(k, th):
       return dist(sum_stats(rmod(k, th)))

   k0 = jax.random.key(42)
   p, d = jsmfsb.abc_run(k0, 1000000, rpr, rdis, batch_size=100000, verb=False)

   # keep the 1% of parameter vectors with the smallest distances
   q = jnp.nanquantile(d, 0.01)
   prmat = jnp.vstack(p)
   postmat = prmat[d < q, :]
   its, var = postmat.shape
   print(its, var)

   postmat = jnp.log(postmat)  # look at posterior on log scale

   fig, axes = plt.subplots(2, 3)
   axes[0, 0].scatter(postmat[:, 0], postmat[:, 1], s=0.5)
   axes[0, 1].scatter(postmat[:, 0], postmat[:, 2], s=0.5)
   axes[0, 2].scatter(postmat[:, 1], postmat[:, 2], s=0.5)
   axes[1, 0].hist(postmat[:, 0], bins=30)
   axes[1, 1].hist(postmat[:, 1], bins=30)
   axes[1, 2].hist(postmat[:, 2], bins=30)
   fig.savefig("abc.pdf")

Using simple Euclidean distance between the trajectories is probably not a great idea. See the file ``abc-cal.py`` in the `demo directory `__ for an example using more sophisticated summary statistics, calibrated via a pilot run to be on a consistent scale.

ABC-SMC
~~~~~~~

Even using well-tuned summary statistics, naive rejection-based ABC is a rather inefficient algorithm. By combining ideas of ABC with those of `sequential Monte Carlo `__ (SMC) one can develop an ABC-SMC algorithm which gradually "zooms in" on promising parts of the parameter space using a sequence of updates in conjunction with a parameter perturbation kernel.
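
The inefficiency of plain rejection ABC is easy to demonstrate on a toy problem. The sketch below uses only the Python standard library and is unrelated to the package's own ``abc_run``: it infers the mean of a normal distribution with known standard deviation, using the sample mean as the summary statistic. The model, the prior range, the tolerance ``eps`` and all function names are assumptions chosen for illustration.

```python
import random
import statistics

rng = random.Random(42)

# Toy "observed" data: 20 draws from a normal with (unknown) mean 3
data = [rng.gauss(3.0, 1.0) for _ in range(20)]
obs_mean = statistics.fmean(data)


def rpr():
    """Draw a candidate mean from a flat prior on [-10, 10]."""
    return rng.uniform(-10.0, 10.0)


def rdis(mu):
    """Simulate a synthetic data set and return its distance from the data."""
    synth = [rng.gauss(mu, 1.0) for _ in range(20)]
    return abs(statistics.fmean(synth) - obs_mean)


n = 20000
draws = [rpr() for _ in range(n)]
dists = [rdis(mu) for mu in draws]

# Keep only the candidates whose synthetic data landed close to the real data
eps = 0.1
accepted = [mu for mu, dd in zip(draws, dists) if dd < eps]
acc_rate = len(accepted) / n
print(len(accepted), acc_rate, statistics.fmean(accepted))
```

Only a small fraction of the prior draws survive, because most candidates come from regions of parameter space that are nowhere near the data. ABC-SMC reduces this waste by proposing new candidates near previously accepted ones.
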
The precise details are beyond the scope of this tutorial, but below is a complete example, using calibrated summary statistics from a pilot run. The function ``abc_smc`` performs the Bayesian update.

.. code:: python

   import jsmfsb
   import jax
   import jax.numpy as jnp
   import jax.scipy as jsp
   import matplotlib.pyplot as plt

   data = jsmfsb.data.lv_perfect[:, 1:3]

   # sample from the prior (parameters are on the log scale throughout)
   def rpr(k):
       k1, k2, k3 = jax.random.split(k, 3)
       return jnp.array(
           [
               jax.random.uniform(k1, minval=-2, maxval=2),
               jax.random.uniform(k2, minval=-7, maxval=-3),
               jax.random.uniform(k3, minval=-3, maxval=1),
           ]
       )

   # log density of the prior
   def dpr(th):
       return jnp.sum(
           jnp.log(
               jnp.array(
                   [
                       ((th[0] > -2) & (th[0] < 2)) / 4,
                       ((th[1] > -7) & (th[1] < -3)) / 4,
                       ((th[2] > -3) & (th[2] < 1)) / 4,
                   ]
               )
           )
       )

   def rmod(k, th):
       return jsmfsb.sim_time_series(
           k, jnp.array([50.0, 100]), 0, 30, 2, jsmfsb.models.lv(jnp.exp(th)).step_cle(0.1)
       )

   print("Pilot run...")

   # summary statistics for a single time series
   def ss1d(vec):
       n = len(vec)
       mean = jnp.nanmean(vec)
       v0 = vec - mean
       var = jnp.nanvar(v0)
       acs = [
           jnp.corrcoef(v0[0 : (n - 1)], v0[1:n])[0, 1],
           jnp.corrcoef(v0[0 : (n - 2)], v0[2:n])[0, 1],
           jnp.corrcoef(v0[0 : (n - 3)], v0[3:n])[0, 1],
       ]
       return jnp.array([jnp.log(mean + 1), jnp.log(var + 1), acs[0], acs[1], acs[2]])

   # summary statistics for the bivariate time series
   def ssi(ts):
       return jnp.concatenate(
           (
               ss1d(ts[:, 0]),
               ss1d(ts[:, 1]),
               jnp.array([jnp.corrcoef(ts[:, 0], ts[:, 1])[0, 1]]),
           )
       )

   key = jax.random.key(42)
   p, d = jsmfsb.abc_run(key, 20000, rpr, lambda k, th: ssi(rmod(k, th)), verb=False)
   prmat = jnp.vstack(p)
   dmat = jnp.vstack(d)
   print(prmat.shape)
   print(dmat.shape)
   dmat = dmat.at[dmat == jnp.inf].set(jnp.nan)
   sds = jnp.nanstd(dmat, 0)
   print(sds)

   # rescale the summary statistics using the pilot run
   def sum_stats(dat):
       return ssi(dat) / sds

   ssd = sum_stats(data)

   print("Main ABC-SMC run")

   def dist(ss):
       diff = ss - ssd
       return jnp.sqrt(jnp.sum(diff * diff))

   def rdis(k, th):
       return dist(sum_stats(rmod(k, th)))

   # perturbation kernel: sampler and log density
   def rper(k, th):
       return th + jax.random.normal(k, (3,)) * 0.5

   def dper(ne, ol):
       return jnp.sum(jsp.stats.norm.logpdf(ne, ol, 0.5))

   postmat = jsmfsb.abc_smc(
       key, 10000, rpr, dpr, rdis, rper, dper, factor=5, steps=8, verb=True
   )

   its, var = postmat.shape
   print(its, var)

   fig, axes = plt.subplots(2, 3)
   axes[0, 0].scatter(postmat[:, 0], postmat[:, 1], s=0.5)
   axes[0, 1].scatter(postmat[:, 0], postmat[:, 2], s=0.5)
   axes[0, 2].scatter(postmat[:, 1], postmat[:, 2], s=0.5)
   axes[1, 0].hist(postmat[:, 0], bins=30)
   axes[1, 1].hist(postmat[:, 1], bins=30)
   axes[1, 2].hist(postmat[:, 2], bins=30)
   fig.savefig("abc_smc.pdf")

PMMH particle MCMC
~~~~~~~~~~~~~~~~~~

PMMH is in many ways the "gold standard" likelihood-free inference strategy (at least in the case of noisy observations). By combining an unbiased estimate of the model's marginal likelihood (computed using a particle filter) with a Metropolis-Hastings MCMC algorithm, it is possible to generate a Markov chain with equilibrium distribution equal to the exact posterior distribution of the parameters given the observations. Again, the technical details are beyond the scope of this tutorial, but a complete example is given below. The key functions are ``pf_marginal_ll`` and ``metropolis_hastings``.

.. code:: python

   import jsmfsb
   import mcmc  # extra functions in the demo directory
   import jax
   import jax.scipy as jsp
   import jax.numpy as jnp

   # log-likelihood of an observation y given the hidden state x
   def obsll(x, t, y, th):
       return jnp.sum(jsp.stats.norm.logpdf(y - x, scale=10))

   # sample an initial state
   def sim_x(k, t0, th):
       k1, k2 = jax.random.split(k)
       return jnp.array([jax.random.poisson(k1, 50), jax.random.poisson(k2, 100)]).astype(
           jnp.float32
       )

   # advance the state using the CLE for the given parameters
   def step(k, x, t, dt, th):
       sf = jsmfsb.models.lv(th).step_cle(0.1)
       return sf(k, x, t, dt)

   mll = jsmfsb.pf_marginal_ll(100, sim_x, 0, step, obsll, jsmfsb.data.lv_noise_10)

   k0 = jax.random.key(42)
   k1, k2, k3 = jax.random.split(k0, 3)

   # multiplicative random walk proposal
   def prop(k, th, tune=0.01):
       return jnp.exp(jax.random.normal(k, shape=(3,)) * tune) * th

   thmat = jsmfsb.metropolis_hastings(
       k3, jnp.array([1, 0.005, 0.6]), mll, prop, iters=5000, thin=1, verb=False
   )

   mcmc.mcmc_summary(thmat, "pmmh.pdf")

Note that the summary stats and plots are produced using some additional functions defined in the file ``mcmc.py`` in the demo directory.

Converting from the ``smfsb`` python package
--------------------------------------------

The API for this package is very similar to that of the ``smfsb`` package. The main difference is that non-deterministic (random) functions have an extra argument (typically the first argument) that corresponds to a JAX random number key. See the `relevant section `__ of the JAX documentation for further information regarding random numbers in JAX code.

Further information
-------------------

For further information, see the `demo directory `__ and the `API documentation `__.