jsmfsb package

Submodules

jsmfsb.models module

jsmfsb.models.bd(th=[1, 1.1])

Create a birth-death model

Create and return a Spn object representing a discrete stochastic birth-death model.

Parameters

th: array

array of length 2 containing the birth and death rates

Returns

Spn model object with rates th

Examples

>>> import jsmfsb
>>> import jax
>>> bd = jsmfsb.models.bd()
>>> step = bd.step_gillespie()
>>> k = jax.random.key(42)
>>> jsmfsb.sim_time_series(k, bd.m, 0, 50, 0.1, step)
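An Spn couples reaction structure (pre and post matrices) with a hazard function and an initial marking m, matching the Spn(n, t, pre, post, h, m) constructor documented in the spn module. The plain-Python sketch below writes those pieces out by hand for the birth-death model and derives the net-effect (stoichiometry) matrix. It is a hedged illustration, not a dump of jsmfsb.models.bd(): the names species, reactions and hazard are hypothetical, and the reaction-by-species layout of pre and post and the mass-action hazards are assumptions.

```python
# Plain-Python sketch of the ingredients of a birth-death SPN (not jsmfsb's code).
# Assumed layout: one row per reaction, one column per species.
species = ["X"]
reactions = ["Birth", "Death"]
pre = [[1],    # Birth (X -> 2X) consumes one X ...
       [1]]    # Death (X -> 0) consumes one X
post = [[2],   # ... and produces two X
        [0]]   # Death produces nothing


def hazard(x, th=(1.0, 1.1)):
    """Assumed mass-action hazards: birth th[0]*X and death th[1]*X."""
    return [th[0] * x[0], th[1] * x[0]]


# Stoichiometry matrix S = (post - pre)^T: species in rows, reactions in columns.
S = [[post[r][s] - pre[r][s] for r in range(len(reactions))]
     for s in range(len(species))]
print(S)  # [[1, -1]]: a birth adds one X, a death removes one X
```

Each column of S is the change in state when the corresponding reaction fires once, which is what every stepping method in this package needs alongside the hazards.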
jsmfsb.models.dimer(th=[0.00166, 0.2])

Create a dimerisation kinetics model

Create and return a Spn object representing a discrete stochastic dimerisation kinetics model.

Parameters

th: array

array of length 2 containing the rates of the bind and unbind reactions

Returns

Spn model object with rates th

Examples

>>> import jsmfsb
>>> import jax
>>> dimer = jsmfsb.models.dimer()
>>> step = dimer.step_gillespie()
>>> k = jax.random.key(42)
>>> jsmfsb.sim_time_series(k, dimer.m, 0, 50, 0.1, step)
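For the bind reaction 2P -> P2, the usual stochastic mass-action hazard counts unordered pairs of P molecules, giving c1*P*(P-1)/2 rather than c1*P^2. The sketch below assumes the dimerisation model uses this standard form, with hypothetical species names P and P2 and the rates taken from th; it illustrates the idea and is not jsmfsb's source.

```python
def dimer_hazards(p, p2, th=(0.00166, 0.2)):
    """Assumed mass-action hazards for 2P -> P2 (bind) and P2 -> 2P (unbind)."""
    bind = th[0] * p * (p - 1) / 2   # c1 times the number of distinct P pairs
    unbind = th[1] * p2              # c2 times the number of dimers
    return bind, unbind


# With a single P molecule no pair can form, so binding is impossible.
print(dimer_hazards(1, 0))  # (0.0, 0.0)
```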
jsmfsb.models.id(th=[1, 0.1])

Create an immigration-death model

Create and return a Spn object representing a discrete stochastic immigration-death model.

Parameters

th: array

array of length 2 containing the immigration and death rates

Returns

Spn model object with rates th

Examples

>>> import jsmfsb
>>> import jax
>>> id = jsmfsb.models.id()
>>> step = id.step_gillespie()
>>> k = jax.random.key(42)
>>> jsmfsb.sim_time_series(k, id.m, 0, 50, 0.1, step)
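Immigration-death is one of the few models here with a closed-form answer, which makes it handy for sanity-checking a simulator. With immigration rate lambda = th[0] and per-molecule death rate mu = th[1], the stationary distribution is Poisson with mean lambda/mu (10 for the defaults), and starting from zero molecules the mean at time t is (lambda/mu)(1 - exp(-mu*t)). The plain-Python Gillespie sketch below (standard library only, not the jsmfsb stepper; simulate_id is a hypothetical name and the empty initial state is an arbitrary choice) checks this at t = 50.

```python
import math
import random


def simulate_id(rng, x, t_end, lam=1.0, mu=0.1):
    """Gillespie simulation of immigration (0 -> X) and death (X -> 0) from time 0."""
    t = 0.0
    while True:
        h_imm, h_death = lam, mu * x
        h0 = h_imm + h_death          # always >= lam > 0, so some event can happen
        t += rng.expovariate(h0)      # exponential waiting time with rate h0
        if t > t_end:
            return x                  # state at t_end is the last state before it
        if rng.random() * h0 < h_imm:
            x += 1                    # immigration
        else:
            x -= 1                    # death


rng = random.Random(42)
samples = [simulate_id(rng, 0, 50.0) for _ in range(1000)]
mean = sum(samples) / len(samples)
print(mean, 10 * (1 - math.exp(-5)))  # sample mean vs exact mean at t = 50
```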
jsmfsb.models.lv(th=[1, 0.005, 0.6])

Create a Lotka-Volterra model

Create and return a Spn object representing a discrete stochastic Lotka-Volterra model.

Parameters

th: array

array of length 3 containing the rates of the three governing reactions, prey reproduction, predator-prey interaction, and predator death

Returns

Spn model object with rates th

Examples

>>> import jsmfsb
>>> import jax
>>> lv = jsmfsb.models.lv()
>>> step = lv.step_gillespie()
>>> k = jax.random.key(42)
>>> jsmfsb.sim_time_series(k, lv.m, 0, 50, 0.1, step)
jsmfsb.models.mm(th=[0.00166, 0.0001, 0.1])

Create a Michaelis-Menten enzyme kinetic model

Create and return a Spn object representing a discrete stochastic Michaelis-Menten enzyme kinetic model.

Parameters

th: array

array of length 3 containing the binding, unbinding and production rates

Returns

Spn model object with rates th

Examples

>>> import jsmfsb
>>> import jax
>>> mm = jsmfsb.models.mm()
>>> step = mm.step_gillespie()
>>> k = jax.random.key(42)
>>> jsmfsb.sim_time_series(k, mm.m, 0, 50, 0.1, step)
jsmfsb.models.sir(th=[0.0015, 0.1])

Create a basic SIR compartmental epidemic model

Create and return a Spn object representing a discrete stochastic SIR model.

Parameters

th: array

array of length 2 containing the rates of the two governing transitions

Returns

Spn model object with rates th

Examples

>>> import jsmfsb
>>> import jax
>>> sir = jsmfsb.models.sir()
>>> step = sir.step_gillespie()
>>> k = jax.random.key(42)
>>> jsmfsb.sim_time_series(k, sir.m, 0, 50, 0.1, step)

jsmfsb.sim module

jsmfsb.sim.sim_sample(key, n, x0, t0, deltat, step_fun, batch_size=None)

Simulate many realisations of a model at a given fixed time in the future given an initial time and state, using a function (closure) for advancing the state of the model

This function simulates many realisations of a model at a given fixed time in the future given an initial time and state, using a function (closure) for advancing the state of the model, such as created by ‘step_gillespie’ or ‘step_cle’.

Parameters

key: JAX random number key

A fresh (previously unused) JAX random number key.

n: int

The number of samples required.

x0: array of numbers

The initial state of the system at time t0.

t0: float

The initial time to be associated with the initial state.

deltat: float

The amount of time in the future of t0 at which samples of the system state are required.

step_fun: function

A function (closure) for advancing the state of the process, such as produced by ‘step_gillespie’ or ‘step_cle’.

batch_size: int

A batch size for “jax.lax.map”. If provided, the samples are computed in parallel batches of this size.

Returns

A matrix with rows representing simulated states at time t0+deltat.

Examples

>>> import jax
>>> import jsmfsb.models
>>> lv = jsmfsb.models.lv()
>>> stepLv = lv.step_gillespie()
>>> jsmfsb.sim_sample(jax.random.key(42), 10, lv.m, 0, 30, stepLv)
jsmfsb.sim.sim_time_series(key, x0, t0, tt, dt, step_fun)

Simulate a model on a regular grid of times, using a function (closure) for advancing the state of the model

This function simulates a single realisation of a model on a regular grid of times using a function (closure) for advancing the state of the model, such as created by ‘step_gillespie’ or ‘step_cle’.

Parameters

key: JAX random number key

A fresh (previously unused) JAX random number key.

x0: array of numbers

The initial state of the system at time t0.

t0: float

The initial time to be associated with the initial state.

tt: float

The terminal time of the simulation.

dt: float

The time step of the output. Note that this time step relates only to the recorded output, and has no bearing on the accuracy of the simulation process.

step_fun: function

A function (closure) for advancing the state of the process, such as produced by ‘step_gillespie’ or ‘step_cle’.

Returns

A matrix with rows representing the state of the system at successive times.

Examples

>>> import jax
>>> import jsmfsb.models
>>> lv = jsmfsb.models.lv()
>>> stepLv = lv.step_gillespie()
>>> jsmfsb.sim_time_series(jax.random.key(42), lv.m, 0, 100, 0.1, stepLv)
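The note that dt has no bearing on accuracy follows from the shape of the driver loop: the closure is asked to advance the state by dt at a time and takes whatever internal steps it needs, so dt only decides which states are recorded. A plain-Python sketch of such a loop follows; sim_time_series_sketch, the toy stepping function decay_step, and the use of random.Random in place of JAX keys are all illustrative assumptions rather than jsmfsb code.

```python
import random


def sim_time_series_sketch(rng, x0, t0, tt, dt, step_fun):
    """Record the state at t0, t0 + dt, ..., tt by repeatedly calling step_fun."""
    # Small tolerance guards against (tt - t0) / dt rounding just below an integer.
    n = int((tt - t0) / dt + 1e-9) + 1
    x, t = list(x0), t0
    out = [list(x)]
    for _ in range(n - 1):
        x = step_fun(rng, x, t, dt)   # the closure chooses its own internal steps
        t += dt
        out.append(list(x))
    return out                        # rows are states at successive output times


def decay_step(rng, x, t, deltat, h=0.001):
    """Toy stepping function: Euler steps of dx/dt = -x with internal step h."""
    for _ in range(int(round(deltat / h))):
        x = [xi - h * xi for xi in x]
    return x


fine = sim_time_series_sketch(random.Random(1), [1.0], 0, 1, 0.1, decay_step)
coarse = sim_time_series_sketch(random.Random(1), [1.0], 0, 1, 0.5, decay_step)
print(len(fine), len(coarse))         # 11 and 3 recorded states
print(fine[-1] == coarse[-1])         # same final state: dt only affects recording
```

Recording every 0.1 or every 0.5 time units gives the same final state here, because the toy stepper always integrates with its own internal step h.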

jsmfsb.smfsb_sbml module

jsmfsb.smfsb_sbml.file_to_spn(filename, verb=False)

Convert an SBML model into a Spn object

Read a file containing a model in SBML and convert it into an Spn object for simulation and analysis.

Parameters

filename: string

String name of file containing the model

verb: boolean

Output some debugging info

Returns

An Spn object

Examples

>>> import jsmfsb
>>> myMod = jsmfsb.file_to_spn("myModel.xml")
>>> step = myMod.step_gillespie()
jsmfsb.smfsb_sbml.mod_to_spn(filename, verb=False)

Convert an SBML-shorthand model into a Spn object

Read a file containing a model in SBML-shorthand and convert it into an Spn object for simulation and analysis.

Parameters

filename: string

String name of file containing the model

verb: boolean

Output some debugging info

Returns

An Spn object

Examples

>>> import jsmfsb
>>> myMod = jsmfsb.mod_to_spn("myModel.mod")
>>> step = myMod.step_gillespie()
jsmfsb.smfsb_sbml.model_to_spn(m, verb=False)

Convert a libSBML model into a Spn object

Convert a libSBML model into a Spn object for simulation and analysis.

Parameters

m: model

A libsbml model (not document) object

verb: boolean

Output some debugging info

Returns

An Spn object

Examples

>>> import jsmfsb
>>> import libsbml
>>> d = libsbml.readSBML("myModel.xml")
>>> m = d.getModel()
>>> myMod = jsmfsb.model_to_spn(m)
>>> step = myMod.step_gillespie()
jsmfsb.smfsb_sbml.shorthand_to_spn(sh_string, verb=False)

Convert an SBML-shorthand model string into a Spn object

Parse a string containing a model in SBML-shorthand and convert it into an Spn object for simulation and analysis.

Parameters

sh_string: string

String containing the model

verb: boolean

Output some debugging info

Returns

An Spn object

Examples

>>> import jsmfsb
>>> with open('myModel.mod', 'r') as f:
...     myModStr = f.read()
...
>>> myMod = jsmfsb.shorthand_to_spn(myModStr)
>>> step = myMod.step_gillespie()

jsmfsb.spn module

class jsmfsb.spn.Spn(n, t, pre, post, h, m)

Bases: object

Class for stochastic Petri net models.

step_cle(dt=0.01)

Create a function for advancing the state of an SPN by using a simple Euler-Maruyama integration method for the associated CLE

This method returns a function for advancing the state of an SPN model using a simple Euler-Maruyama integration method for the chemical Langevin equation (CLE) form of the model. The resulting function (closure) can be used in conjunction with other functions (such as sim_time_series) for simulating realisations of SPN models.

Parameters

dt: float

The time step for the time-stepping integration method. Defaults to 0.01.

Returns

A function which can be used to advance the state of the SPN model by using an Euler-Maruyama method with step size ‘dt’. The function closure has interface ‘function(key, x0, t0, deltat)’, where ‘x0’ and ‘t0’ represent the initial state and time, and ‘deltat’ represents the amount of time by which the process should be advanced. The function closure returns a vector representing the simulated state of the system at the new time.

Examples

>>> import jsmfsb.models
>>> import jax
>>> lv = jsmfsb.models.lv()
>>> stepLv = lv.step_cle(0.001)
>>> stepLv(jax.random.key(42), lv.m, 0, 1)
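For reference, the chemical Langevin equation for a network with stoichiometry matrix S and hazard vector h(x) is dX = S h(X) dt + S diag(sqrt(h(X))) dW, and the Euler-Maruyama scheme replaces dW by independent N(0, dt) increments, one per reaction. The plain-Python sketch below applies such steps to the Lotka-Volterra network. The mass-action hazards, the reaction ordering, the starting state and the clipping of populations at zero (so that square roots stay real) are assumptions made for the illustration, not details taken from jsmfsb.

```python
import math
import random

# Lotka-Volterra stoichiometry: rows are species (prey, predator),
# columns are reactions (prey birth, predation, predator death).
S = [[1, -1, 0],
     [0, 1, -1]]


def lv_hazards(x, th=(1.0, 0.005, 0.6)):
    """Assumed mass-action hazards, with populations clipped at zero."""
    prey, pred = max(x[0], 0.0), max(x[1], 0.0)
    return [th[0] * prey, th[1] * prey * pred, th[2] * pred]


def cle_em_step(rng, x, dt):
    """One Euler-Maruyama step of dX = S h dt + S diag(sqrt(h)) dW."""
    h = lv_hazards(x)
    # One independent N(0, dt) Brownian increment per reaction channel.
    dw = [rng.gauss(0.0, math.sqrt(dt)) for _ in h]
    return [x[i] + sum(S[i][j] * (h[j] * dt + math.sqrt(h[j]) * dw[j])
                       for j in range(len(h)))
            for i in range(len(x))]


rng = random.Random(42)
x = [50.0, 100.0]                 # arbitrary starting state
for _ in range(1000):             # advance by one time unit with dt = 0.001
    x = cle_em_step(rng, x, 0.001)
print(x)
```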
step_cle_1d(d, dt=0.01)

Create a function for advancing the state of an SPN by using a simple Euler-Maruyama discretisation of the CLE on a 1D regular grid

This method creates a function for advancing the state of an SPN model using a simple Euler-Maruyama discretisation of the CLE on a 1D regular grid. The resulting function (closure) can be used in conjunction with other functions (such as sim_time_series_1d) for simulating realisations of SPN models in space and time.

Parameters

d: array

A vector of diffusion coefficients, one for each reacting species, in order. Each coefficient is the reaction rate for a single molecule moving into a given adjacent compartment, so the hazard for a given molecule leaving its compartment is twice this value (as it can leave to the left or the right).

dt: float

Time step for the Euler-Maruyama discretisation.

Returns

A function which can be used to advance the state of the SPN model by using a simple Euler-Maruyama algorithm. The function closure has parameters key, x0, t0, deltat, where key is a JAX random number key, x0 is a matrix with rows corresponding to species and columns corresponding to voxels, representing the initial condition, t0 represents the initial time, and deltat represents the amount of time by which the process should be advanced. The function closure returns a matrix representing the simulated state of the system at the new time.

Examples

>>> import jsmfsb.models
>>> import jax
>>> import jax.numpy as jnp
>>> lv = jsmfsb.models.lv()
>>> stepLv1d = lv.step_cle_1d(jnp.array([0.6,0.6]))
>>> N = 20
>>> x0 = jnp.zeros((2,N))
>>> x0 = x0.at[:,int(N/2)].set(lv.m)
>>> k0 = jax.random.key(42)
>>> stepLv1d(k0, x0, 0, 1)
step_cle_2d(d, dt=0.01)

Create a function for advancing the state of an SPN by using a simple Euler-Maruyama discretisation of the CLE on a 2D regular grid

This method creates a function for advancing the state of an SPN model using a simple Euler-Maruyama discretisation of the CLE on a 2D regular grid. The resulting function (closure) can be used in conjunction with other functions (such as sim_time_series_2d) for simulating realisations of SPN models in space and time.

Parameters

d: array

A vector of diffusion coefficients, one for each reacting species, in order. Each coefficient is the reaction rate for a single molecule moving into a given adjacent compartment, so the hazard for a given molecule leaving its compartment is four times this value (as it can leave in one of 4 directions).

dt: float

Time step for the Euler-Maruyama discretisation.

Returns

A function which can be used to advance the state of the SPN model by using a simple Euler-Maruyama algorithm. The function closure has parameters key, x0, t0, deltat, where key is a JAX random number key, x0 is a 3d array with indices species, then rows and columns corresponding to voxels, representing the initial condition, t0 represents the initial time, and deltat represents the amount of time by which the process should be advanced. The function closure returns a 3d array representing the simulated state of the system at the new time.

Examples

>>> import jsmfsb.models
>>> import jax
>>> import jax.numpy as jnp
>>> lv = jsmfsb.models.lv()
>>> stepLv2d = lv.step_cle_2d(jnp.array([0.6,0.6]))
>>> M = 15
>>> N = 20
>>> x0 = jnp.zeros((2,M,N))
>>> x0 = x0.at[:,int(M/2),int(N/2)].set(lv.m)
>>> k0 = jax.random.key(42)
>>> stepLv2d(k0, x0, 0, 1)
step_euler(dt=0.01)

Create a function for advancing the state of an SPN by using a simple continuous deterministic Euler integration method

This method returns a function for advancing the state of an SPN model using a simple continuous deterministic Euler integration method. The resulting function (closure) can be used in conjunction with other functions (such as ‘sim_time_series’) for simulating realisations of SPN models.

Parameters

dt: float

The time step for the time-stepping integration method. Defaults to 0.01.

Returns

A function which can be used to advance the state of the SPN model by using an Euler method with step size ‘dt’. The function closure has interface ‘function(key, x0, t0, deltat)’, where ‘x0’ and ‘t0’ represent the initial state and time, and ‘deltat’ represents the amount of time by which the process should be advanced. The random key, key, is ignored. The function closure returns a vector representing the simulated state of the system at the new time.

Examples

>>> import jsmfsb.models
>>> import jax
>>> lv = jsmfsb.models.lv()
>>> stepLv = lv.step_euler(0.001)
>>> k = jax.random.key(42)
>>> stepLv(k, lv.m, 0, 1)
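Dropping the noise term from the CLE leaves the reaction rate equations dx/dt = S h(x), and the deterministic Euler method iterates x <- x + dt * S h(x). A minimal plain-Python sketch on the Lotka-Volterra network follows; the hazards, stoichiometry and the name euler_advance are assumptions for illustration, and like the closure described above it uses no randomness at all.

```python
# Lotka-Volterra stoichiometry: rows are species (prey, predator),
# columns are reactions (prey birth, predation, predator death).
S = [[1, -1, 0],
     [0, 1, -1]]


def lv_hazards(x, th=(1.0, 0.005, 0.6)):
    """Assumed mass-action hazards for the three reactions."""
    return [th[0] * x[0], th[1] * x[0] * x[1], th[2] * x[1]]


def euler_advance(x, t0, deltat, dt=0.01):
    """Advance x from time t0 by deltat with forward Euler steps of size dt.

    Assumes deltat is a multiple of dt; t0 is unused because these hazards
    do not depend on time.
    """
    for _ in range(int(round(deltat / dt))):
        h = lv_hazards(x)
        x = [x[i] + dt * sum(S[i][j] * h[j] for j in range(len(h)))
             for i in range(len(x))]
    return x


print(euler_advance([50.0, 100.0], 0.0, 1.0, dt=0.001))
```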
step_euler_1d(d, dt=0.01)

Create a function for advancing the state of an SPN by using a simple forward Euler discretisation of the reaction-diffusion system on a 1D regular grid

This method creates a function for advancing the state of an SPN model using a simple Euler discretisation of the reaction-diffusion system on a 1D regular grid. The resulting function (closure) can be used in conjunction with other functions (such as sim_time_series_1d) for simulating realisations of SPN models in space and time.

Parameters

d: array

A vector of diffusion coefficients, one for each reacting species, in order. Each coefficient is the reaction rate for a single molecule moving into a given adjacent compartment, so the hazard for a given molecule leaving its compartment is twice this value (as it can leave to the left or the right).

dt: float

Time step for the Euler discretisation.

Returns

A function which can be used to advance the state of the SPN model by using a simple forward Euler algorithm. The function closure has parameters key, x0, t0, deltat, where key is a JAX random number key (which is ignored), x0 is a matrix with rows corresponding to species and columns corresponding to voxels, representing the initial condition, t0 represents the initial time, and deltat represents the amount of time by which the process should be advanced. The function closure returns a matrix representing the simulated state of the system at the new time.

Examples

>>> import jsmfsb.models
>>> import jax
>>> import jax.numpy as jnp
>>> lv = jsmfsb.models.lv()
>>> stepLv1d = lv.step_euler_1d(jnp.array([0.6,0.6]))
>>> N = 20
>>> x0 = jnp.zeros((2,N))
>>> x0 = x0.at[:,int(N/2)].set(lv.m)
>>> k0 = jax.random.key(42)
>>> stepLv1d(k0, x0, 0, 1)
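The statement that the leaving hazard is twice the coefficient corresponds to the familiar discrete Laplacian: each voxel loses 2*d*x_i per unit time and gains d times the content of each neighbour. The sketch below shows only the diffusion part of one forward Euler step on a 1D grid, in plain Python; the periodic (wrap-around) boundary and the name diffuse_1d are assumptions made for simplicity, and a full reaction-diffusion step would add the reaction term S h(x) dt in each voxel.

```python
def diffuse_1d(x, d, dt):
    """One forward Euler diffusion step on a periodic 1D grid.

    x: list of species rows, each a list of voxel values.
    d: one diffusion coefficient per species.
    """
    out = []
    for row, di in zip(x, d):
        n = len(row)
        # Inflow d*x from each neighbour, outflow 2*d*x (to the left or right).
        out.append([row[i] + dt * di * (row[i - 1] - 2 * row[i] + row[(i + 1) % n])
                    for i in range(n)])
    return out


x0 = [[0.0] * 5 for _ in range(2)]
x0[0][2], x0[1][2] = 50.0, 100.0   # all molecules start in the middle voxel
x1 = diffuse_1d(x0, [0.6, 0.6], 0.1)
print(x1[0])                        # mass has spread into voxels 1 and 3
```

Because every unit that leaves one voxel arrives in a neighbour, the total amount of each species is conserved by the diffusion step.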
step_euler_2d(d, dt=0.01)

Create a function for advancing the state of an SPN by using a simple forward Euler discretisation of the reaction-diffusion system on a 2D regular grid

This method creates a function for advancing the state of an SPN model using a simple Euler discretisation of the reaction-diffusion system on a 2D regular grid. The resulting function (closure) can be used in conjunction with other functions (such as sim_time_series_2d) for simulating realisations of SPN models in space and time.

Parameters

d: array

A vector of diffusion coefficients, one for each reacting species, in order. Each coefficient is the reaction rate for a single molecule moving into a given adjacent compartment, so the hazard for a given molecule leaving its compartment is four times this value (as it can leave in one of 4 directions).

dt: float

Time step for the forward Euler discretisation.

Returns

A function which can be used to advance the state of the SPN model by using a simple forward Euler algorithm. The function closure has parameters key (ignored), x0, t0, deltat, where x0 is a 3d array with indices species, then rows and columns corresponding to voxels, representing the initial condition, t0 represents the initial time, and deltat represents the amount of time by which the process should be advanced. The function closure returns a 3d array representing the simulated state of the system at the new time.

Examples

>>> import jsmfsb.models
>>> import jax
>>> import jax.numpy as jnp
>>> lv = jsmfsb.models.lv()
>>> stepLv2d = lv.step_euler_2d(jnp.array([0.6,0.6]))
>>> M = 15
>>> N = 20
>>> x0 = jnp.zeros((2,M,N))
>>> x0 = x0.at[:,int(M/2),int(N/2)].set(lv.m)
>>> k0 = jax.random.key(42)
>>> stepLv2d(k0, x0, 0, 1)
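In two dimensions the same bookkeeping gives the five-point stencil: a voxel loses 4*d*x_ij per unit time (one share per direction) and gains d times the content of each of its four neighbours. The plain-Python sketch below handles a single species on an M x N grid with periodic boundaries; the boundary condition and the name diffuse_2d are assumptions, and the reaction term is omitted.

```python
def diffuse_2d(grid, d, dt):
    """One forward Euler diffusion step for one species on a periodic M x N grid."""
    m, n = len(grid), len(grid[0])
    return [[grid[i][j] + dt * d * (grid[i - 1][j] + grid[(i + 1) % m][j]
                                    + grid[i][j - 1] + grid[i][(j + 1) % n]
                                    - 4 * grid[i][j])   # outflow in 4 directions
             for j in range(n)]
            for i in range(m)]


g0 = [[0.0] * 5 for _ in range(4)]
g0[2][2] = 100.0                    # all molecules start in one voxel
g1 = diffuse_2d(g0, 0.6, 0.1)
print(g1[2][2], g1[1][2])           # centre voxel and one of its neighbours
```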
step_gillespie(min_haz=1e-10, max_haz=10000000.0)

Create a function for advancing the state of an SPN by using the Gillespie algorithm

This method returns a function for advancing the state of an SPN model using the Gillespie algorithm. The resulting function (closure) can be used in conjunction with other functions (such as sim_time_series) for simulating realisations of SPN models.

Parameters

min_haz: float

Minimum hazard to consider before assuming 0. Defaults to 1e-10.

max_haz: float

Maximum hazard to consider before assuming an explosion and bailing out. Defaults to 1e07.

Returns

A function which can be used to advance the state of the SPN model by using the Gillespie algorithm. The function closure has interface function(key, x0, t0, deltat), where key is a fresh (previously unused) JAX random key, x0 and t0 represent the initial state and time, and deltat represents the amount of time by which the process should be advanced. The function closure returns a vector representing the simulated state of the system at the new time.

Examples

>>> import jsmfsb.models
>>> import jax
>>> lv = jsmfsb.models.lv()
>>> stepLv = lv.step_gillespie()
>>> stepLv(jax.random.key(42), lv.m, 0, 1)
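The sketch below spells out the Gillespie direct method for the birth-death network, together with one plausible reading of min_haz and max_haz: a total hazard below min_haz is treated as zero, so nothing further happens and the current state is returned, while a total hazard above max_haz is treated as an explosion. How jsmfsb actually signals an explosion is not stated here, so raising an exception is a choice made for this plain-Python illustration; the hazards and the name gillespie_advance are likewise assumptions.

```python
import random

# Birth-death network: one species, reactions X -> 2X and X -> 0.
S = [[1, -1]]


def bd_hazards(x, th=(1.0, 1.1)):
    """Assumed mass-action hazards for birth and death."""
    return [th[0] * x[0], th[1] * x[0]]


def gillespie_advance(rng, x0, t0, deltat, min_haz=1e-10, max_haz=1e7):
    """Advance x0 from t0 to t0 + deltat with the Gillespie direct method."""
    x, t, t_end = list(x0), t0, t0 + deltat
    while True:
        h = bd_hazards(x)
        h0 = sum(h)
        if h0 < min_haz:              # treat as zero: nothing more can happen
            return x
        if h0 > max_haz:              # runaway hazard: bail out (this sketch's choice)
            raise RuntimeError("total hazard exceeded max_haz")
        t += rng.expovariate(h0)      # exponential waiting time with rate h0
        if t >= t_end:                # next event falls after the target time
            return x
        j = rng.choices(range(len(h)), weights=h)[0]   # P(j) = h[j] / h0
        x = [x[i] + S[i][j] for i in range(len(x))]


print(gillespie_advance(random.Random(42), [100], 0.0, 1.0))
```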
step_gillespie_1d(d, min_haz=1e-10, max_haz=10000000.0)

Create a function for advancing the state of an SPN by using the Gillespie algorithm on a 1D regular grid

This method creates a function for advancing the state of an SPN model using the Gillespie algorithm. The resulting function (closure) can be used in conjunction with other functions (such as sim_time_series_1d) for simulating realisations of SPN models in space and time.

Parameters

d: array

A vector of diffusion coefficients, one for each reacting species, in order. Each coefficient is the reaction rate for a single molecule moving into a given adjacent compartment, so the hazard for a given molecule leaving its compartment is twice this value (as it can leave to the left or the right).

min_haz: float

Minimum hazard to consider before assuming 0. Defaults to 1e-10.

max_haz: float

Maximum hazard to consider before assuming an explosion and bailing out. Defaults to 1e07.

Returns

A function which can be used to advance the state of the SPN model by using the Gillespie algorithm. The function closure has arguments key, x0, t0, deltat, where key is a JAX random key, x0 is a matrix with rows corresponding to species and columns corresponding to voxels, representing the initial condition, t0 represents the initial time, and deltat represents the amount of time by which the process should be advanced. The function closure returns a matrix representing the simulated state of the system at the new time.

Examples

>>> import jsmfsb.models
>>> import jax
>>> import jax.numpy as jnp
>>> lv = jsmfsb.models.lv()
>>> stepLv1d = lv.step_gillespie_1d(jnp.array([0.6, 0.6]))
>>> N = 20
>>> x0 = jnp.zeros((2,N))
>>> x0 = x0.at[:,int(N/2)].set(lv.m)
>>> k0 = jax.random.key(42)
>>> stepLv1d(k0, x0, 0, 1)
step_gillespie_2d(d, min_haz=1e-10, max_haz=10000000.0)

Create a function for advancing the state of an SPN by using the Gillespie algorithm on a 2D regular grid

This method creates a function for advancing the state of an SPN model using the Gillespie algorithm. The resulting function (closure) can be used in conjunction with other functions (such as sim_time_series_2d) for simulating realisations of SPN models in space and time.

Parameters

d: array

A vector of diffusion coefficients, one for each reacting species, in order. Each coefficient is the reaction rate for a single molecule moving into a given adjacent compartment, so the hazard for a given molecule leaving its compartment is four times this value (as it can leave in one of 4 directions).

min_haz: float

Minimum hazard to consider before assuming 0. Defaults to 1e-10.

max_haz: float

Maximum hazard to consider before assuming an explosion and bailing out. Defaults to 1e07.

Returns

A function which can be used to advance the state of the SPN model by using the Gillespie algorithm. The function closure has arguments key, x0, t0, deltat, where key is a JAX random key, x0 is a 3d array with dimensions corresponding to species then two spatial dimensions, representing the initial condition, t0 represents the time of the initial state, and deltat represents the amount of time by which the process should be advanced. The function closure returns an array representing the simulated state of the system at the new time.

Examples

>>> import jsmfsb.models
>>> import jax
>>> import jax.numpy as jnp
>>> lv = jsmfsb.models.lv()
>>> stepLv2d = lv.step_gillespie_2d(jnp.array([0.6, 0.6]))
>>> N = 20
>>> x0 = jnp.zeros((2, N, N))
>>> x0 = x0.at[:, int(N/2), int(N/2)].set(lv.m)
>>> k0 = jax.random.key(42)
>>> stepLv2d(k0, x0, 0, 1)
step_poisson(dt=0.01)

Create a function for advancing the state of an SPN by using a simple approximate Poisson time stepping method

This method returns a function for advancing the state of an SPN model using a simple approximate Poisson time stepping method. The resulting function (closure) can be used in conjunction with other functions (such as ‘sim_time_series’) for simulating realisations of SPN models.

Parameters

dt: float

The time step for the time-stepping integration method. Defaults to 0.01.

Returns

A function which can be used to advance the state of the SPN model by using a Poisson time stepping method with step size ‘dt’. The function closure has interface ‘function(key, x0, t0, deltat)’, where ‘x0’ and ‘t0’ represent the initial state and time, and ‘deltat’ represents the amount of time by which the process should be advanced. The function closure returns a vector representing the simulated state of the system at the new time.

Examples

>>> import jsmfsb.models
>>> import jax
>>> lv = jsmfsb.models.lv()
>>> stepLv = lv.step_poisson(0.001)
>>> k = jax.random.key(42)
>>> stepLv(k, lv.m, 0, 1)
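The simple Poisson time-stepping method is usually known as tau-leaping: over each step of length dt the hazards are held fixed, the number of times reaction j fires is drawn as Poisson(h_j * dt), and the state moves by S times those counts. The plain-Python sketch below uses Knuth's multiplication method for the Poisson draws (adequate for the small means that arise when h*dt is small) and the same assumed Lotka-Volterra hazards as elsewhere; poisson and poisson_advance are hypothetical names. Unlike the Gillespie algorithm, a leap can in principle overshoot and drive a count negative, and the sketch does not guard against that.

```python
import math
import random

# Lotka-Volterra stoichiometry: rows are species, columns are reactions.
S = [[1, -1, 0],
     [0, 1, -1]]


def lv_hazards(x, th=(1.0, 0.005, 0.6)):
    """Assumed mass-action hazards: prey birth, predation, predator death."""
    return [th[0] * x[0], th[1] * x[0] * x[1], th[2] * x[1]]


def poisson(rng, mean):
    """Knuth's method: count uniforms until their product drops below exp(-mean)."""
    limit, k, p = math.exp(-mean), 0, rng.random()
    while p > limit:
        k += 1
        p *= rng.random()
    return k


def poisson_advance(rng, x, t0, deltat, dt=0.01):
    """Advance x by deltat using fixed tau-leaps of size dt (deltat a multiple of dt)."""
    for _ in range(int(round(deltat / dt))):
        h = lv_hazards(x)
        r = [poisson(rng, hj * dt) for hj in h]   # firings of each reaction this leap
        x = [x[i] + sum(S[i][j] * r[j] for j in range(len(r)))
             for i in range(len(x))]
    return x


print(poisson_advance(random.Random(42), [50, 100], 0.0, 1.0, dt=0.001))
```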

jsmfsb.spatial module

jsmfsb.spatial.sim_time_series_1d(key, x0, t0, tt, dt, step_fun, verb=False)

Simulate a model on a regular grid of times, using a function (closure) for advancing the state of the model

This function simulates a single realisation of a model on a 1D regular spatial grid and a regular grid of times using a function (closure) for advancing the state of the model, such as created by step_gillespie_1d.

Parameters

key: JAX random number key

Initial random number key to seed the simulation.

x0: array

The initial state of the process at time t0, a matrix with rows corresponding to reacting species and columns corresponding to spatial location.

t0: float

The initial time to be associated with the initial state x0.

tt: float

The terminal time of the simulation.

dt: float

The time step of the output. Note that this time step relates only to the recorded output, and has no bearing on the accuracy of the simulation process.

step_fun: function

A function (closure) for advancing the state of the process, such as produced by step_gillespie_1d.

verb: boolean

Output progress to the console (this function can be very slow).

Returns

A 3d array representing the simulated process. The dimensions are species, space, and time.

Examples

>>> import jsmfsb.models
>>> import jax
>>> import jax.numpy as jnp
>>> lv = jsmfsb.models.lv()
>>> stepLv1d = lv.step_gillespie_1d(jnp.array([0.6,0.6]))
>>> N = 10
>>> T = 5
>>> x0 = jnp.zeros((2,N))
>>> x0 = x0.at[:,int(N/2)].set(lv.m)
>>> k0 = jax.random.key(42)
>>> jsmfsb.sim_time_series_1d(k0, x0, 0, T, 1, stepLv1d, True)
jsmfsb.spatial.sim_time_series_2d(key, x0, t0, tt, dt, step_fun, verb=False)

Simulate a model on a regular grid of times, using a function (closure) for advancing the state of the model

This function simulates a single realisation of a model on a 2D regular spatial grid and a regular grid of times using a function (closure) for advancing the state of the model, such as created by step_gillespie_2d.

Parameters

key: JAX random number key

Random key to seed the simulation.

x0: array

The initial state of the process at time t0, a 3d array with dimensions corresponding to reacting species and then two corresponding to spatial location.

t0: float

The initial time to be associated with the initial state x0.

tt: float

The terminal time of the simulation.

dt: float

The time step of the output. Note that this time step relates only to the recorded output, and has no bearing on the accuracy of the simulation process.

step_fun: function

A function (closure) for advancing the state of the process, such as produced by step_gillespie_2d.

verb: boolean

Output progress to the console (this function can be very slow).

Returns

A 4d array representing the simulated process. The dimensions are species, two spatial dimensions, and time.

Examples

>>> import jsmfsb.models
>>> import jax
>>> import jax.numpy as jnp
>>> lv = jsmfsb.models.lv()
>>> stepLv2d = lv.step_gillespie_2d(jnp.array([0.6,0.6]))
>>> M = 10
>>> N = 15
>>> T = 5
>>> x0 = jnp.zeros((2,M,N))
>>> x0 = x0.at[:,int(M/2),int(N/2)].set(lv.m)
>>> k0 = jax.random.key(42)
>>> jsmfsb.sim_time_series_2d(k0, x0, 0, T, 1, stepLv2d, True)

jsmfsb.inference module

jsmfsb.inference.abc_run(key, n, rprior, rdist, batch_size=None, verb=False)

Run a set of simulations initialised with parameters sampled from a given prior distribution, and compute statistics required for an ABC analysis

Run a set of simulations initialised with parameters sampled from a given prior distribution, and compute statistics required for an ABC analysis. Typically used to calculate “distances” of simulated synthetic data from observed data.

Parameters

key: JAX random number key

Key to initialise the ABC simulation.

n: int

An integer representing the number of simulations to run.

rprior: function

A function with one argument, a JAX random key, generating a single parameter (vector) from a prior distribution.

rdist: function

A function with two arguments: a JAX random key and a parameter (vector). It returns the required statistic of interest. This will typically be computed by first using the parameter to run a forward model, then computing required summary statistics, then computing a distance. See the example for details.

batch_size: int

Batch size to use in the call to jax.lax.map for parallelisation. Defaults to None.

verb: boolean

Whether to print progress information to the console. Defaults to False.

Returns

A tuple with first component a matrix of parameters (in rows) and second component a vector of corresponding distances.

Examples

>>> import jsmfsb
>>> import jax
>>> import jax.numpy as jnp
>>> k0 = jax.random.key(42)
>>> k1, k2 = jax.random.split(k0)
>>> data = jax.random.normal(k1, (250,))*2 + 5
>>> def rpr(k):
...     return jnp.exp(jax.random.uniform(k, (2,), minval=-3, maxval=3))
...
>>> def rmod(k, th):
...     return jax.random.normal(k, (250,))*th[1] + th[0]
...
>>> def sumStats(dat):
...     return jnp.array([jnp.mean(dat), jnp.std(dat)])
...
>>> ssd = sumStats(data)
>>> def dist(ss):
...     diff = ss - ssd
...     return jnp.sqrt(jnp.sum(diff*diff))
...
>>> def rdis(k, th):
...     return dist(sumStats(rmod(k, th)))
...
>>> jsmfsb.abc_run(k2, 100, rpr, rdis)
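abc_run returns parameters and distances but leaves the construction of an approximate posterior to the caller. A common choice is rejection ABC: keep the parameter vectors whose distances fall in a small lower quantile. A plain-Python sketch of that post-processing step follows; the name abc_reject, the default 10% acceptance fraction and the toy input values are illustrative assumptions.

```python
def abc_reject(params, dists, frac=0.1):
    """Keep the parameter rows whose distance is among the smallest `frac` fraction."""
    n_keep = max(1, int(len(dists) * frac))
    order = sorted(range(len(dists)), key=lambda i: dists[i])
    return [params[i] for i in order[:n_keep]]


# Toy input: four parameter vectors and their distances from the observed data.
params = [[5.1, 2.0], [9.0, 0.5], [4.8, 2.2], [1.0, 7.0]]
dists = [0.3, 4.1, 0.2, 6.5]
print(abc_reject(params, dists, frac=0.5))  # [[4.8, 2.2], [5.1, 2.0]]
```

With the example above, the returned tuple could be unpacked as `th, d = jsmfsb.abc_run(k2, 100, rpr, rdis)` and converted with `th.tolist()` and `d.tolist()` before being passed to a helper like this one.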
jsmfsb.inference.abc_smc(key, n, rprior, dprior, rdist, rperturb, dperturb, factor=10, steps=15, verb=False, debug=False)

Run an ABC-SMC algorithm for inferring the parameters of a forward model

Run an ABC-SMC algorithm for inferring the parameters of a forward model. This sequential Monte Carlo algorithm often performs better than simple rejection-ABC in practice.

Parameters

key: JAX random key

A key to initialise the simulation.

n: int

An integer representing the number of simulations to pass on at each stage of the SMC algorithm. Note that the TOTAL number of forward simulations required by the algorithm will be (roughly) ‘n*steps*factor’.

rprior: function

A function with a single argument, a JAX random key, which generates a single parameter (vector) from the prior.

dprior: function

A function taking a parameter vector as argument and returning the log of the prior density.

rdist: function

A function with two arguments: a JAX random key and a parameter vector. It should return a scalar “distance” representing a measure of how good the chosen parameter is. This will typically be computed by first using the parameter to run a forward model, then computing required summary statistics, then computing a distance. See the example for details.

rperturb: function

A function with two arguments: a JAX random key and a parameter vector. It should return a perturbed parameter from an appropriate kernel.

dperturb: function

A function which takes a pair of parameters as its first two arguments (new first and old second), and returns the log of the density associated with this perturbation kernel.

factor: int

At each step of the algorithm, ‘n*factor’ proposals are generated and the best ‘n’ of these are weighted and passed on to the next stage. Note that the effective sample size of the parameters passed on to the next step may be (much) smaller than ‘n’, since some of the particles may be assigned small (or zero) weight. Defaults to 10.

steps: int

The number of steps of the ABC-SMC algorithm. Typically, somewhere between 5 and 100 steps seems to be used in practice. Defaults to 15.

verb: boolean

Boolean indicating whether some progress should be printed to the console. Defaults to False.

Returns

A matrix with rows representing samples from the approximate posterior distribution.

Examples

>>> import jsmfsb
>>> import jax
>>> import jax.numpy as jnp
>>> import jax.scipy as jsp
>>> k0 = jax.random.key(42)
>>> k1, k2 = jax.random.split(k0)
>>> data = jax.random.normal(k1, (250,))*2 + 5
>>> def rpr(k):
...     return jax.random.uniform(k, (2,), minval=-3, maxval=3)
...
>>> def rmod(k, th):
...     return jax.random.normal(k, (250,))*jnp.exp(th[1]) + jnp.exp(th[0])
...
>>> def sumStats(dat):
...     return jnp.array([jnp.mean(dat), jnp.std(dat)])
...
>>> ssd = sumStats(data)
>>> def dist(ss):
...     diff = ss - ssd
...     return jnp.sqrt(jnp.sum(diff*diff))
...
>>> def rdis(k, th):
...     return dist(sumStats(rmod(k, th)))
...
>>> jsmfsb.abc_smc(k2, 100, rpr,
...                lambda x: jnp.sum(jnp.log(((x<3)&(x>-3))/6)),
...                rdis,
...                lambda k,x: jax.random.normal(k, x.shape)*0.1 + x,
...                lambda x,y: jnp.sum(jsp.stats.norm.logpdf(y, x, 0.1)))
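The roles of dprior and dperturb are easiest to see in the importance weight given to each new particle. In the usual ABC-SMC weighting scheme, a new particle theta gets log-weight log prior(theta) - log sum_j w_j K(theta | theta_j), where K is the perturbation kernel and w_j are the previous normalised weights. The plain-Python sketch below computes these weights in log space with a log-sum-exp for numerical stability. It illustrates the general technique and is not claimed to match jsmfsb's abc_smc_step line for line; the scalar parameters, the uniform prior on (-3, 3) and the N(old, 0.1^2) kernel are assumptions chosen to mirror the example above.

```python
import math


def log_sum_exp(vals):
    """Numerically stable log(sum(exp(v) for v in vals))."""
    m = max(vals)
    if m == -math.inf:
        return -math.inf
    return m + math.log(sum(math.exp(v - m) for v in vals))


def smc_log_weights(new, old, old_lw, dprior, dperturb):
    """Unnormalised log-weights for new particles given the previous population.

    old_lw holds normalised log-weights; dperturb(new, old) is the log kernel density.
    """
    lw = []
    for th in new:
        mix = log_sum_exp([w + dperturb(th, o) for o, w in zip(old, old_lw)])
        lw.append(dprior(th) - mix)
    return lw


def normalise(lw):
    """Shift log-weights so that their exponentials sum to one."""
    total = log_sum_exp(lw)
    return [w - total for w in lw]


def dprior(th):            # log density of a uniform prior on (-3, 3)
    return -math.log(6) if -3 < th < 3 else -math.inf


def dperturb(new, old):    # log density of a N(old, 0.1^2) kernel
    return -0.5 * ((new - old) / 0.1) ** 2 - math.log(0.1 * math.sqrt(2 * math.pi))


old = [0.0, 1.0]
old_lw = normalise([0.0, 0.0])    # equal weights in the previous population
lw = normalise(smc_log_weights([0.0, 1.0, 5.0], old, old_lw, dprior, dperturb))
ws = [math.exp(w) for w in lw]
print(ws)  # the particle at 5.0 lies outside the prior and gets zero weight
```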
jsmfsb.inference.abc_smc_step(key, dprior, prior_sample, prior_lw, rdist, rperturb, dperturb, factor)

Carry out one step of an ABC-SMC algorithm

Not meant to be directly called by users. See abc_smc.

jsmfsb.inference.metropolis_hastings(key, init, log_lik, rprop, ldprop=<function <lambda>>, ldprior=<function <lambda>>, iters=10000, thin=10, verb=True)

Run a Metropolis-Hastings MCMC algorithm for the parameters of a Bayesian posterior distribution

Run a Metropolis-Hastings MCMC algorithm for the parameters of a Bayesian posterior distribution. Note that the algorithm carries over the old likelihood from the previous iteration, making it suitable for problems with expensive likelihoods, and also for “exact approximate” pseudo-marginal or particle marginal MH algorithms.

Parameters

key: JAX random number key

A key to seed the simulation.

init: vector

A parameter vector with which to initialise the MCMC algorithm.

log_lik: (stochastic) function

A function which takes two arguments: a JAX random key and a parameter (of the same type as init). It should return the log-likelihood of the data. Note that it is fine for this to return the log of an unbiased estimate of the likelihood, in which case the algorithm will be an “exact approximate” pseudo-marginal MH algorithm. This is the reason why the function should accept a JAX random key. In the “vanilla” case, where the log-likelihood is deterministic, the function should simply ignore the key that is passed in.

rprop: stochastic function

A function which takes a random key and a current parameter as its two required arguments and returns a single sample from a proposal distribution.

ldprop: function

A function which takes a new and old parameter as its first two required arguments and returns the log density of the new value conditional on the old. Defaults to a flat function which causes this term to drop out of the acceptance probability. The default is appropriate for any symmetric proposal, since for a symmetric proposal this term cancels out of the acceptance probability anyway.

ldprior: function

A function which takes a parameter as its only required argument and returns the log density of the parameter value under the prior. Defaults to a flat function which causes this term to drop out of the acceptance probability. People often use a flat prior when they are trying to be “uninformative” or “objective”, but this is slightly naive. In particular, what is “flat” is clearly dependent on the parametrisation of the model.

iters: int

The number of MCMC iterations required (after thinning).

thin: int

The required thinning factor, i.e. only every thin-th iteration is stored.

verb: boolean

Boolean indicating whether some progress information should be printed to the console. Defaults to True.
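The remark in the ldprior description that flatness depends on parametrisation follows from the change-of-variables formula. For example, a prior that is flat for the log-transformed parameter φ = log θ corresponds, for θ > 0, to a density proportional to 1/θ, which is far from flat:

```latex
\phi = \log\theta, \qquad p_\phi(\phi) \propto 1
\quad\Longrightarrow\quad
p_\theta(\theta) = p_\phi(\log\theta)\,\left|\frac{d\log\theta}{d\theta}\right| \propto \frac{1}{\theta},
\qquad \theta > 0.
```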

Returns

A matrix with rows representing samples from the posterior distribution.

Examples

>>> import jsmfsb
>>> import jax
>>> import jax.numpy as jnp
>>> import jax.scipy as jsp
>>> k0 = jax.random.key(42)
>>> k1, k2 = jax.random.split(k0)
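>>> data = jax.random.normal(k1, (250,))*2 + 5
>>> # deterministic log-likelihood for (mean, sd); the key is ignored
>>> llik = lambda k, x: jnp.sum(jsp.stats.norm.logpdf(data, x[0], x[1]))
>>> # symmetric Gaussian random-walk proposal, so the default ldprop suffices
>>> prop = lambda k, x: jax.random.normal(k, (2,))*0.1 + x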
>>> jsmfsb.metropolis_hastings(k2, jnp.array([1.0,1.0]), llik, prop)
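The Gaussian random-walk proposal in the example above is symmetric, so the default ldprop is fine. For an asymmetric proposal the ldprop term does not cancel and must be supplied. A minimal sketch, assuming a hypothetical multiplicative random walk on the log scale with step size SD (a name introduced here for illustration), of a matching rprop/ldprop pair in the argument order described above:

```python
import jax
import jax.numpy as jnp
import jax.scipy as jsp

SD = 0.1  # hypothetical proposal scale on the log scale

def rprop(k, x):
    # Multiplicative random walk: log(new) ~ N(log(old), SD^2), so new > 0
    return x * jnp.exp(SD * jax.random.normal(k, x.shape))

def ldprop(new, old):
    # Log-normal log density of new given old: the normal log density on
    # the log scale plus the Jacobian term -log(new)
    return jnp.sum(jsp.stats.norm.logpdf(jnp.log(new), jnp.log(old), SD)
                   - jnp.log(new))
```

A proposal like this can be convenient for strictly positive parameters such as reaction rates, since it never proposes a negative value; because it is not symmetric, ldprop(new, old) - ldprop(old, new) reduces to sum(log(old)) - sum(log(new)) rather than zero.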
jsmfsb.inference.pf_marginal_ll(n, sim_x0, t0, step_fun, data_ll, data, debug=False)

Create a function for computing the log of an unbiased estimate of marginal likelihood of a time course data set

Create a function for computing the log of an unbiased estimate of marginal likelihood of a time course data set using a simple bootstrap particle filter.
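For orientation, here is a minimal sketch of a bootstrap particle filter estimate of the log marginal likelihood, written against the argument conventions described below. It is a hypothetical pf_marginal_ll_sketch, not jsmfsb's implementation, and it omits debug. It propagates n particles between observation times with step_fun, weights them with data_ll, accumulates the log of the mean weight at each observation, and resamples multinomially. The likelihood estimate itself (the exponential of the returned value) is unbiased; its log is not.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp

def pf_marginal_ll_sketch(n, sim_x0, t0, step_fun, data_ll, data):
    """Sketch of a bootstrap particle filter log marginal likelihood.

    Argument conventions follow the pf_marginal_ll parameter descriptions;
    this is an illustration, not jsmfsb's implementation.
    """
    times = data[:, 0]   # observation times
    obs = data[:, 1:]    # observed values y at those times

    def mll(key, th):
        k0, key = jax.random.split(key)
        # Draw n initial particles from the initial-state distribution
        x = jax.vmap(lambda k: sim_x0(k, t0, th))(jax.random.split(k0, n))

        def update(carry, time_and_obs):
            key, x, t, ll = carry
            t_new, y = time_and_obs
            key, k_step, k_res = jax.random.split(key, 3)
            # Propagate each particle forward to the next observation time
            x = jax.vmap(lambda k, xi: step_fun(k, xi, t, t_new - t, th))(
                jax.random.split(k_step, n), x)
            # Log-weight each particle with the observation model
            lw = jax.vmap(lambda xi: data_ll(xi, t_new, y, th))(x)
            # The log of the mean weight estimates log p(y | earlier data)
            ll = ll + logsumexp(lw) - jnp.log(n)
            # Multinomial resampling in proportion to the weights
            idx = jax.random.categorical(k_res, lw, shape=(n,))
            return (key, x[idx], t_new, ll), None

        init = (key, x, jnp.asarray(t0, dtype=times.dtype),
                jnp.zeros((), dtype=times.dtype))
        (_, _, _, ll), _ = jax.lax.scan(update, init, (times, obs))
        return ll

    return mll
```

Because every particle is reweighted and resampled at each observation, the variance of the estimate grows with the number of observations unless n is increased accordingly.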

Parameters

n: int

An integer representing the number of particles to use in the particle filter.

sim_x0: function

A function with arguments key, t0 and th, where ‘t0’ is a time at which to simulate from an initial distribution for the state of the particle filter and th is a vector of parameters. The return value should be a state vector randomly sampled from the prior distribution. The function therefore represents a prior distribution on the initial state of the Markov process.

t0: float

The time corresponding to the starting point of the Markov process. Must be no later than the earliest observation time.

step_fun: function

A function for advancing the state of the Markov process, with arguments key, x, t0, deltat and th, with th representing a vector of parameters.

data_ll: function

A function with arguments x, t, y, th, where x and t represent the true state and time of the process, y is the observed data, and th is a parameter vector. The return value should be the log of the likelihood of the observation. The function therefore represents the observation model.

data: matrix

A matrix whose first column is an increasing sequence of observation times; the remaining columns contain the observed values of y at those times.
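As an illustration of this layout only, two observed species at three times would be arranged as follows. The times and values here are made up for the example and are not taken from jsmfsb.data:

```python
import jax.numpy as jnp

# Hypothetical observation times (increasing) and observed values of two species
times = jnp.array([0.0, 2.0, 4.0])
y = jnp.array([[50.0, 100.0],
               [60.0, 80.0],
               [70.0, 55.0]])

# Column 0 holds the times; columns 1 onwards hold the observations
data = jnp.column_stack([times, y])
```

Each row then pairs one observation time with the y values observed at that time.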

Returns

A function with arguments key and th, representing a parameter vector, which evaluates to the log of the particle filter's unbiased estimate of the marginal likelihood of the data (for parameter th).

Examples

>>> import jax
>>> import jax.numpy as jnp
>>> import jax.scipy as jsp
>>> import jsmfsb
>>> def obsll(x, t, y, th):
...     return jnp.sum(jsp.stats.norm.logpdf(y-x, scale=10))
...
>>> def simX(key, t0, th):
...     k1, k2 = jax.random.split(key)
...     return jnp.array([jax.random.poisson(k1, 50),
...              jax.random.poisson(k2, 100)]).astype(jnp.float32)
...
>>> def step(key, x, t, dt, th):
...     sf = jsmfsb.models.lv(th).step_gillespie()
...     return sf(key, x, t, dt)
...
>>> mll = jsmfsb.pf_marginal_ll(80, simX, 0, step, obsll, jsmfsb.data.lv_noise_10)
>>> k0 = jax.random.key(42)
>>> mll(k0, jnp.array([1, 0.005, 0.6]))
>>> mll(k0, jnp.array([2, 0.005, 0.6]))

jsmfsb.data module

Module contents