jsmfsb package
Submodules
jsmfsb.models module
- jsmfsb.models.bd(th=[1, 1.1])
Create a birth-death model
Create and return a Spn object representing a discrete stochastic birth-death model.
Parameters
- th: array
array of length 2 containing the birth and death rates
Returns
Spn model object with rates th
Examples
>>> import jsmfsb
>>> import jax
>>> bd = jsmfsb.models.bd()
>>> step = bd.step_gillespie()
>>> k = jax.random.key(42)
>>> jsmfsb.sim_time_series(k, bd.m, 0, 50, 0.1, step)
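For intuition about the default rates, assume (as is standard for this model) linear hazards λx for birth and μx for death, with th = [λ, μ]. The expected population size then satisfies a linear ODE, which decays when μ > λ:

```latex
\frac{d}{dt}\,\mathbb{E}[X_t] = (\lambda - \mu)\,\mathbb{E}[X_t]
\quad\Longrightarrow\quad
\mathbb{E}[X_t] = x_0\, e^{(\lambda - \mu)t}
```

With the defaults λ = 1 and μ = 1.1, the mean decays like e^{-0.1t}, so by t = 50 it is about e^{-5} ≈ 0.0067 times its initial value; individual realisations may go extinct well before then.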
- jsmfsb.models.dimer(th=[0.00166, 0.2])
Create a dimerisation kinetics model
Create and return a Spn object representing a discrete stochastic dimerisation kinetics model.
Parameters
- th: array
array of length 2 containing the rates of the bind and unbind reactions
Returns
Spn model object with rates th
Examples
>>> import jsmfsb
>>> import jax
>>> dimer = jsmfsb.models.dimer()
>>> step = dimer.step_gillespie()
>>> k = jax.random.key(42)
>>> jsmfsb.sim_time_series(k, dimer.m, 0, 50, 0.1, step)
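Assuming the usual mass-action form for this model (binding 2P → P2 with rate constant c1, unbinding P2 → 2P with rate constant c2), the binding hazard counts unordered pairs of P molecules, c1·x(x−1)/2, rather than c1·x². The plain-Python sketch below shows these hazards; `dimer_hazards` and the species ordering (P, P2) are hypothetical and not taken from jsmfsb's source.

```python
def dimer_hazards(x, th=(0.00166, 0.2)):
    """Mass-action hazards for 2P -> P2 (bind) and P2 -> 2P (unbind).

    x is (number of P, number of P2) and th is (c1, c2).
    Hypothetical helper for illustration; not jsmfsb's internal code.
    """
    p, p2 = x
    c1, c2 = th
    bind = c1 * p * (p - 1) / 2   # c1 times the number of distinct pairs of P
    unbind = c2 * p2              # each dimer dissociates independently
    return [bind, unbind]

# e.g. 301 monomers and no dimers
h = dimer_hazards((301, 0))
```

With a single monomer the binding hazard is exactly zero, which is why the pair count rather than x² is used.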
- jsmfsb.models.id(th=[1, 0.1])
Create an immigration-death model
Create and return a Spn object representing a discrete stochastic immigration-death model.
Parameters
- th: array
array of length 2 containing the immigration and death rates
Returns
Spn model object with rates th
Examples
>>> import jsmfsb
>>> import jax
>>> id = jsmfsb.models.id()
>>> step = id.step_gillespie()
>>> k = jax.random.key(42)
>>> jsmfsb.sim_time_series(k, id.m, 0, 50, 0.1, step)
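Assuming a constant immigration hazard λ and a linear death hazard μx (th = [λ, μ]), the immigration-death process has a Poisson stationary distribution, which gives a quick sanity check on long simulations:

```latex
\frac{d}{dt}\,\mathbb{E}[X_t] = \lambda - \mu\,\mathbb{E}[X_t],
\qquad
X_\infty \sim \mathrm{Poisson}\!\left(\frac{\lambda}{\mu}\right)
```

With the defaults λ = 1 and μ = 0.1, the stationary mean and variance are both λ/μ = 10.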
- jsmfsb.models.lv(th=[1, 0.005, 0.6])
Create a Lotka-Volterra model
Create and return a Spn object representing a discrete stochastic Lotka-Volterra model.
Parameters
- th: array
array of length 3 containing the rates of the three governing reactions: prey reproduction, predator-prey interaction and predator death
Returns
Spn model object with rates th
Examples
>>> import jsmfsb
>>> import jax
>>> lv = jsmfsb.models.lv()
>>> step = lv.step_gillespie()
>>> k = jax.random.key(42)
>>> jsmfsb.sim_time_series(k, lv.m, 0, 50, 0.1, step)
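The three reactions can be written as a stochastic Petri net with pre- and post-matrices and mass-action hazards. The NumPy sketch below assumes the standard formulation (X1 → 2X1 at c1·x1, X1 + X2 → 2X2 at c2·x1·x2, X2 → ∅ at c3·x2) with species ordered (prey, predator); `pre`, `post` and `lv_hazards` are illustrative names, not jsmfsb internals. It also checks that the deterministic drift S h(x) vanishes at the fixed point (c3/c2, c1/c2), which is (120, 200) for the default rates.

```python
import numpy as np

# Rows are reactions, columns are species (prey, predator).
pre = np.array([[1, 0],    # prey reproduction: X1 -> 2 X1
                [1, 1],    # predation:         X1 + X2 -> 2 X2
                [0, 1]])   # predator death:    X2 -> 0
post = np.array([[2, 0],
                 [0, 2],
                 [0, 0]])
# Stoichiometry matrix with shape (species, reactions).
S = (post - pre).T

def lv_hazards(x, th=(1.0, 0.005, 0.6)):
    """Mass-action hazards for the three LV reactions (illustrative)."""
    x1, x2 = x
    c1, c2, c3 = th
    return np.array([c1 * x1, c2 * x1 * x2, c3 * x2])

# Deterministic drift S h(x) at the ODE fixed point (c3/c2, c1/c2).
x_star = np.array([0.6 / 0.005, 1.0 / 0.005])
drift = S @ lv_hazards(x_star)
```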
- jsmfsb.models.mm(th=[0.00166, 0.0001, 0.1])
Create a Michaelis-Menten enzyme kinetic model
Create and return a Spn object representing a discrete stochastic Michaelis-Menten enzyme kinetic model.
Parameters
- th: array
array of length 3 containing the binding, unbinding and production rates
Returns
Spn model object with rates th
Examples
>>> import jsmfsb
>>> import jax
>>> mm = jsmfsb.models.mm()
>>> step = mm.step_gillespie()
>>> k = jax.random.key(42)
>>> jsmfsb.sim_time_series(k, mm.m, 0, 50, 0.1, step)
- jsmfsb.models.sir(th=[0.0015, 0.1])
Create a basic SIR compartmental epidemic model
Create and return a Spn object representing a discrete stochastic SIR model.
Parameters
- th: array
array of length 2 containing the rates of the two governing transitions (infection and removal)
Returns
Spn model object with rates th
Examples
>>> import jsmfsb
>>> import jax
>>> sir = jsmfsb.models.sir()
>>> step = sir.step_gillespie()
>>> k = jax.random.key(42)
>>> jsmfsb.sim_time_series(k, sir.m, 0, 50, 0.1, step)
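Assuming the standard mass-action form, with infection S + I → 2I at hazard β·s·i and removal I → R at hazard γ·i (th = [β, γ]), the expected number of infectives initially increases only when infection outpaces removal:

```latex
\left.\frac{d}{dt}\,\mathbb{E}[I_t]\right|_{t=0} = \beta s_0 i_0 - \gamma i_0 = (\beta s_0 - \gamma)\, i_0,
\qquad
\beta s_0 > \gamma \iff s_0 > \frac{\gamma}{\beta} = \frac{0.1}{0.0015} \approx 66.7
```

So with the default rates, the expected number of infectives grows at first only if the initial number of susceptibles exceeds about 67; an individual stochastic realisation can still die out early.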
jsmfsb.sim module
- jsmfsb.sim.sim_sample(key, n, x0, t0, deltat, step_fun, batch_size=None)
Simulate many realisations of a model at a given fixed time in the future given an initial time and state, using a function (closure) for advancing the state of the model
This function simulates many realisations of a model at a given fixed time in the future given an initial time and state, using a function (closure) for advancing the state of the model, such as created by ‘step_gillespie’ or ‘step_cle’.
Parameters
- key: JAX random number key
A fresh (not previously used) JAX random number key.
- n: int
The number of samples required.
- x0: array of numbers
The initial state of the system at time t0.
- t0: float
The initial time to be associated with the initial state.
- deltat: float
The amount of time in the future of t0 at which samples of the system state are required.
- step_fun: function
A function (closure) for advancing the state of the process, such as produced by ‘step_gillespie’ or ‘step_cle’.
- batch_size: int
A batch size for ‘jax.lax.map’. If provided, the simulations are computed in vectorised batches of this size.
Returns
A matrix with rows representing simulated states at time t0+deltat.
Examples
>>> import jax
>>> import jsmfsb.models
>>> lv = jsmfsb.models.lv()
>>> stepLv = lv.step_gillespie()
>>> jsmfsb.sim_sample(jax.random.key(42), 10, lv.m, 0, 30, stepLv)
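Conceptually, sim_sample makes n independent calls to the step function from the same starting point and stacks the results as rows. The NumPy sketch below mirrors that contract using `numpy.random.Generator` objects in place of JAX keys; `sim_sample_sketch` and `decay_step` are hypothetical names, and the toy step function (an exact pure-death step in which each molecule survives with probability e^{-rate·deltat}) is only there to make the example runnable.

```python
import numpy as np

def decay_step(rng, x0, t0, deltat, rate=0.5):
    """Hypothetical step function: exact pure-death process.

    Each molecule independently survives the interval with
    probability exp(-rate * deltat).
    """
    return rng.binomial(x0, np.exp(-rate * deltat))

def sim_sample_sketch(seed, n, x0, t0, deltat, step_fun):
    """Return an (n, len(x0)) array of independent states at t0 + deltat."""
    # One independent generator per realisation, analogous to splitting a JAX key.
    children = np.random.SeedSequence(seed).spawn(n)
    rows = [step_fun(np.random.default_rng(s), np.asarray(x0), t0, deltat)
            for s in children]
    return np.stack(rows)

samples = sim_sample_sketch(42, 1000, [100], 0.0, 2.0, decay_step)
```

The sample mean should be close to 100·e^{-1} ≈ 36.8, the exact expectation for this toy process.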
- jsmfsb.sim.sim_time_series(key, x0, t0, tt, dt, step_fun)
Simulate a model on a regular grid of times, using a function (closure) for advancing the state of the model
This function simulates a single realisation of a model on a regular grid of times using a function (closure) for advancing the state of the model, such as created by ‘step_gillespie’ or ‘step_cle’.
Parameters
- key: JAX random number key
A fresh (not previously used) JAX random number key.
- x0: array of numbers
The initial state of the system at time t0.
- t0: float
The initial time to be associated with the initial state.
- tt: float
The terminal time of the simulation.
- dt: float
The time step of the output. Note that this time step relates only to the recorded output, and has no bearing on the accuracy of the simulation process.
- step_fun: function
A function (closure) for advancing the state of the process, such as produced by ‘step_gillespie’ or ‘step_cle’.
Returns
A matrix with rows representing the state of the system at successive times.
Examples
>>> import jax
>>> import jsmfsb.models
>>> lv = jsmfsb.models.lv()
>>> stepLv = lv.step_gillespie()
>>> jsmfsb.sim_time_series(jax.random.key(42), lv.m, 0, 100, 0.1, stepLv)
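The output grid can be pictured as a loop that records the state, then advances it by dt using the step function; dt only controls how often the state is recorded. This plain-Python sketch uses a hypothetical deterministic step function (exact exponential decay) so the result is easy to check. Whether jsmfsb includes the initial state as the first row, and exactly how it rounds the number of rows, are assumptions here rather than facts taken from the source.

```python
import math

def decay_step(key, x0, t0, deltat, rate=0.5):
    """Hypothetical deterministic step: exact solution of dx/dt = -rate * x."""
    return [xi * math.exp(-rate * deltat) for xi in x0]

def sim_time_series_sketch(key, x0, t0, tt, dt, step_fun):
    """Record the state at t0, t0 + dt, ..., up to (about) tt."""
    n = int(round((tt - t0) / dt)) + 1   # number of output times
    x, t = list(x0), t0
    rows = [x]
    for _ in range(n - 1):
        # A stochastic step_fun would be given a fresh key at each call.
        x = step_fun(key, x, t, dt)
        t += dt
        rows.append(x)
    return rows

out = sim_time_series_sketch(None, [100.0], 0.0, 10.0, 0.5, decay_step)
```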
jsmfsb.smfsb_sbml module
- jsmfsb.smfsb_sbml.file_to_spn(filename, verb=False)
Convert an SBML model into an Spn object
Read a file containing a model in SBML and convert it into an Spn object for simulation and analysis.
Parameters
- filename: string
String name of file containing the model
- verb: boolean
Output some debugging info
Returns
An Spn object
Examples
>>> import jsmfsb
>>> myMod = jsmfsb.file_to_spn("myModel.xml")
>>> step = myMod.step_gillespie()
- jsmfsb.smfsb_sbml.mod_to_spn(filename, verb=False)
Convert an SBML-shorthand model into an Spn object
Read a file containing a model in SBML-shorthand and convert it into an Spn object for simulation and analysis.
Parameters
- filename: string
String name of file containing the model
- verb: boolean
Output some debugging info
Returns
An Spn object
Examples
>>> import jsmfsb
>>> myMod = jsmfsb.mod_to_spn("myModel.mod")
>>> step = myMod.step_gillespie()
- jsmfsb.smfsb_sbml.model_to_spn(m, verb=False)
Convert a libSBML model into an Spn object
Convert a libSBML model into an Spn object for simulation and analysis.
Parameters
- m: model
A libsbml model (not document) object
- verb: boolean
Output some debugging info
Returns
An Spn object
Examples
>>> import jsmfsb
>>> import libsbml
>>> d = libsbml.readSBML("myModel.xml")
>>> m = d.getModel()
>>> myMod = jsmfsb.model_to_spn(m)
>>> step = myMod.step_gillespie()
- jsmfsb.smfsb_sbml.shorthand_to_spn(sh_string, verb=False)
Convert an SBML-shorthand model string into an Spn object
Parse a string containing a model in SBML-shorthand and convert it into an Spn object for simulation and analysis.
Parameters
- sh_string: string
String containing the model
- verb: boolean
Output some debugging info
Returns
An Spn object
Examples
>>> import jsmfsb
>>> with open('myModel.mod', 'r') as file:
...     myModStr = file.read()
...
>>> myMod = jsmfsb.shorthand_to_spn(myModStr)
>>> step = myMod.step_gillespie()
jsmfsb.spn module
- class jsmfsb.spn.Spn(n, t, pre, post, h, m)
Bases: object
Class for stochastic Petri net models.
- step_cle(dt=0.01)
Create a function for advancing the state of an SPN by using a simple Euler-Maruyama integration method for the associated CLE
This method returns a function for advancing the state of an SPN model using a simple Euler-Maruyama integration method for the chemical Langevin equation (CLE) form of the model. The resulting function (closure) can be used in conjunction with other functions (such as sim_time_series) for simulating realisations of SPN models.
Parameters
- dt: float
The time step for the time-stepping integration method. Defaults to 0.01.
Returns
A function which can be used to advance the state of the SPN model by using an Euler-Maruyama method with step size ‘dt’. The function closure has interface ‘function(key, x0, t0, deltat)’, where ‘key’ is a JAX random key, ‘x0’ and ‘t0’ represent the initial state and time, and ‘deltat’ represents the amount of time by which the process should be advanced. The function closure returns a vector representing the simulated state of the system at the new time.
Examples
>>> import jsmfsb.models
>>> import jax
>>> lv = jsmfsb.models.lv()
>>> stepLv = lv.step_cle(0.001)
>>> stepLv(jax.random.key(42), lv.m, 0, 1)
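For reference, one Euler-Maruyama step of the chemical Langevin equation updates the state as x ← x + S h(x) dt + S diag(√h(x)) ΔW, with ΔW ~ N(0, dt·I) and one noise term per reaction. The NumPy sketch below implements that update for a generic SPN given a stoichiometry matrix S (species × reactions) and a hazard function h(x, t). `cle_step_sketch` is a hypothetical name, and clipping hazards at zero (to keep the square root defined) is this sketch's choice, not necessarily what jsmfsb does.

```python
import numpy as np

def cle_step_sketch(rng, x0, t0, deltat, S, h, dt=0.01):
    """Advance x0 from t0 to t0 + deltat with Euler-Maruyama steps for the CLE."""
    x = np.asarray(x0, dtype=float)
    t, t_end = t0, t0 + deltat
    while t < t_end - 1e-12:
        step = min(dt, t_end - t)               # do not overshoot the target time
        haz = np.maximum(h(x, t), 0.0)          # clip so the square root is defined
        dW = rng.normal(0.0, np.sqrt(step), size=haz.shape)
        x = x + S @ (haz * step) + S @ (np.sqrt(haz) * dW)
        t += step
    return x

# Immigration-death example: 0 -> X at rate 1, X -> 0 at rate 0.1 * x.
S = np.array([[1, -1]])
h = lambda x, t: np.array([1.0, 0.1 * x[0]])
x1 = cle_step_sketch(np.random.default_rng(1), [10.0], 0.0, 1.0, S, h)
```

Because the CLE state is continuous, it can take non-integer (and, near zero, occasionally negative) values.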
- step_cle_1d(d, dt=0.01)
Create a function for advancing the state of an SPN by using a simple Euler-Maruyama discretisation of the CLE on a 1D regular grid
This method creates a function for advancing the state of an SPN model using a simple Euler-Maruyama discretisation of the CLE on a 1D regular grid. The resulting function (closure) can be used in conjunction with other functions (such as sim_time_series_1d) for simulating realisations of SPN models in space and time.
Parameters
- d: array
A vector of diffusion coefficients, one for each reacting species, in order. Each coefficient is the rate at which a molecule moves into a given adjacent compartment, so the hazard for a given molecule leaving its compartment is twice this value (as it can leave to the left or the right).
- dt: float
Time step for the Euler-Maruyama discretisation.
Returns
A function which can be used to advance the state of the SPN model by using a simple Euler-Maruyama algorithm. The function closure has parameters key, x0, t0, deltat, where key is a JAX random number key, x0 is a matrix with rows corresponding to species and columns corresponding to voxels, representing the initial condition, t0 represents the initial time, and deltat represents the amount of time by which the process should be advanced. The function closure returns a matrix representing the simulated state of the system at the new time.
Examples
>>> import jsmfsb.models
>>> import jax
>>> import jax.numpy as jnp
>>> lv = jsmfsb.models.lv()
>>> stepLv1d = lv.step_cle_1d(jnp.array([0.6,0.6]))
>>> N = 20
>>> x0 = jnp.zeros((2,N))
>>> x0 = x0.at[:,int(N/2)].set(lv.m)
>>> k0 = jax.random.key(42)
>>> stepLv1d(k0, x0, 0, 1)
- step_cle_2d(d, dt=0.01)
Create a function for advancing the state of an SPN by using a simple Euler-Maruyama discretisation of the CLE on a 2D regular grid
This method creates a function for advancing the state of an SPN model using a simple Euler-Maruyama discretisation of the CLE on a 2D regular grid. The resulting function (closure) can be used in conjunction with other functions (such as sim_time_series_2d) for simulating realisations of SPN models in space and time.
Parameters
- d: array
A vector of diffusion coefficients, one for each reacting species, in order. Each coefficient is the rate at which a molecule moves into a given adjacent compartment, so the hazard for a given molecule leaving its compartment is four times this value (as it can leave in one of 4 directions).
- dt: float
Time step for the Euler-Maruyama discretisation.
Returns
A function which can be used to advance the state of the SPN model by using a simple Euler-Maruyama algorithm. The function closure has parameters key, x0, t0, deltat, where key is a JAX random number key, x0 is a 3d array with indices species, then rows and columns corresponding to voxels, representing the initial condition, t0 represents the initial time, and deltat represents the amount of time by which the process should be advanced. The function closure returns a 3d array representing the simulated state of the system at the new time.
Examples
>>> import jsmfsb.models
>>> import jax
>>> import jax.numpy as jnp
>>> lv = jsmfsb.models.lv()
>>> stepLv2d = lv.step_cle_2d(jnp.array([0.6,0.6]))
>>> M = 15
>>> N = 20
>>> x0 = jnp.zeros((2,M,N))
>>> x0 = x0.at[:,int(M/2),int(N/2)].set(lv.m)
>>> k0 = jax.random.key(42)
>>> stepLv2d(k0, x0, 0, 1)
- step_euler(dt=0.01)
Create a function for advancing the state of an SPN by using a simple continuous deterministic Euler integration method
This method returns a function for advancing the state of an SPN model using a simple continuous deterministic Euler integration method. The resulting function (closure) can be used in conjunction with other functions (such as ‘sim_time_series’) for simulating realisations of SPN models.
Parameters
- dt: float
The time step for the time-stepping integration method. Defaults to 0.01.
Returns
A function which can be used to advance the state of the SPN model by using an Euler method with step size ‘dt’. The function closure has interface ‘function(key, x0, t0, deltat)’, where ‘x0’ and ‘t0’ represent the initial state and time, and ‘deltat’ represents the amount of time by which the process should be advanced. The random key, key, is ignored. The function closure returns a vector representing the simulated state of the system at the new time.
Examples
>>> import jsmfsb.models
>>> import jax
>>> lv = jsmfsb.models.lv()
>>> stepLv = lv.step_euler(0.001)
>>> k = jax.random.key(42)
>>> stepLv(k, lv.m, 0, 1)
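The deterministic Euler method drops the noise term and integrates the reaction rate equations dx/dt = S h(x, t). The NumPy sketch below shows that update under the same conventions (S is species × reactions); `euler_step_sketch` is a hypothetical name, and it assumes deltat is a whole multiple of dt. The key argument is accepted and ignored, matching the interface described above.

```python
import numpy as np

def euler_step_sketch(key, x0, t0, deltat, S, h, dt=0.01):
    """Forward Euler for dx/dt = S h(x, t); key is ignored (deterministic).

    Assumes deltat is (close to) a whole multiple of dt.
    """
    x = np.asarray(x0, dtype=float)
    n_steps = int(round(deltat / dt))
    t = t0
    for _ in range(n_steps):
        x = x + dt * (S @ h(x, t))
        t += dt
    return x

# Pure death X -> 0 with hazard x, so the exact solution is x0 * exp(-t).
S = np.array([[-1]])
h = lambda x, t: np.array([x[0]])
x1 = euler_step_sketch(None, [1.0], 0.0, 1.0, S, h, dt=0.001)
```

Here the result is (1 − 0.001)^1000, slightly below e^{-1}, illustrating the first-order error of the method.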
- step_euler_1d(d, dt=0.01)
Create a function for advancing the state of an SPN by using a simple forward Euler discretisation of the reaction-diffusion equations on a 1D regular grid
This method creates a function for advancing the state of an SPN model using a simple forward Euler discretisation of the reaction-diffusion equations on a 1D regular grid. The resulting function (closure) can be used in conjunction with other functions (such as sim_time_series_1d) for simulating realisations of SPN models in space and time.
Parameters
- d: array
A vector of diffusion coefficients, one for each reacting species, in order. Each coefficient is the rate at which a molecule moves into a given adjacent compartment, so the hazard for a given molecule leaving its compartment is twice this value (as it can leave to the left or the right).
- dt: float
Time step for the forward Euler discretisation.
Returns
A function which can be used to advance the state of the SPN model by using a simple forward Euler algorithm. The function closure has parameters key, x0, t0, deltat, where key is a JAX random number key (which is ignored), x0 is a matrix with rows corresponding to species and columns corresponding to voxels, representing the initial condition, t0 represents the initial time, and deltat represents the amount of time by which the process should be advanced. The function closure returns a matrix representing the simulated state of the system at the new time.
Examples
>>> import jsmfsb.models
>>> import jax
>>> import jax.numpy as jnp
>>> lv = jsmfsb.models.lv()
>>> stepLv1d = lv.step_euler_1d(jnp.array([0.6,0.6]))
>>> N = 20
>>> x0 = jnp.zeros((2,N))
>>> x0 = x0.at[:,int(N/2)].set(lv.m)
>>> k0 = jax.random.key(42)
>>> stepLv1d(k0, x0, 0, 1)
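The diffusion part of the 1D scheme can be illustrated with a discrete Laplacian: each molecule moves to each neighbouring voxel at rate d, so it leaves its voxel at total rate 2d, consistent with the parameter description above. The NumPy sketch below performs one forward Euler step of diffusion only (reactions omitted). `diffusion_euler_step` is a hypothetical name, and the periodic boundaries implied by `np.roll` are an assumption of this sketch; check jsmfsb's source for the boundary condition it actually uses.

```python
import numpy as np

def diffusion_euler_step(x, d, dt):
    """One forward Euler step of 1D diffusion on a periodic grid.

    x has shape (species, voxels); d holds one coefficient per species.
    """
    d = np.asarray(d, dtype=float)[:, None]    # broadcast over voxels
    left = np.roll(x, 1, axis=1)               # neighbour to the left
    right = np.roll(x, -1, axis=1)             # neighbour to the right
    laplacian = left + right - 2 * x           # inflow from both sides minus outflow
    return x + dt * d * laplacian

N = 20
x0 = np.zeros((2, N))
x0[:, N // 2] = [50.0, 100.0]                  # all mass in the middle voxel
x1 = diffusion_euler_step(x0, [0.6, 0.6], 0.01)
```

The total amount of each species is conserved, and the middle voxel loses a fraction 2·d·dt = 0.012 of its contents, split equally between its two neighbours.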
- step_euler_2d(d, dt=0.01)
Create a function for advancing the state of an SPN by using a simple forward Euler discretisation of the reaction-diffusion equations on a 2D regular grid
This method creates a function for advancing the state of an SPN model using a simple forward Euler discretisation of the reaction-diffusion equations on a 2D regular grid. The resulting function (closure) can be used in conjunction with other functions (such as sim_time_series_2d) for simulating realisations of SPN models in space and time.
Parameters
- d: array
A vector of diffusion coefficients, one for each reacting species, in order. Each coefficient is the rate at which a molecule moves into a given adjacent compartment, so the hazard for a given molecule leaving its compartment is four times this value (as it can leave in one of 4 directions).
- dt: float
Time step for the forward Euler discretisation.
Returns
A function which can be used to advance the state of the SPN model by using a simple forward Euler algorithm. The function closure has parameters key (ignored), x0, t0, deltat, where x0 is a 3d array with indices species, then rows and columns corresponding to voxels, representing the initial condition, t0 represents the initial time, and deltat represents the amount of time by which the process should be advanced. The function closure returns a 3d array representing the simulated state of the system at the new time.
Examples
>>> import jsmfsb.models
>>> import jax
>>> import jax.numpy as jnp
>>> lv = jsmfsb.models.lv()
>>> stepLv2d = lv.step_euler_2d(jnp.array([0.6,0.6]))
>>> M = 15
>>> N = 20
>>> x0 = jnp.zeros((2,M,N))
>>> x0 = x0.at[:,int(M/2),int(N/2)].set(lv.m)
>>> k0 = jax.random.key(42)
>>> stepLv2d(k0, x0, 0, 1)
- step_gillespie(min_haz=1e-10, max_haz=10000000.0)
Create a function for advancing the state of an SPN by using the Gillespie algorithm
This method returns a function for advancing the state of an SPN model using the Gillespie algorithm. The resulting function (closure) can be used in conjunction with other functions (such as sim_time_series) for simulating realisations of SPN models.
Parameters
- min_haz: float
Minimum hazard to consider before assuming 0. Defaults to 1e-10.
- max_haz: float
Maximum hazard to consider before assuming an explosion and bailing out. Defaults to 1e07.
Returns
A function which can be used to advance the state of the SPN model by using the Gillespie algorithm. The function closure has interface function(key, x0, t0, deltat), where key is a fresh (not previously used) JAX random key, x0 and t0 represent the initial state and time, and deltat represents the amount of time by which the process should be advanced. The function closure returns a vector representing the simulated state of the system at the new time.
Examples
>>> import jsmfsb.models
>>> import jax
>>> lv = jsmfsb.models.lv()
>>> stepLv = lv.step_gillespie()
>>> stepLv(jax.random.key(42), lv.m, 0, 1)
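The Gillespie direct method repeatedly draws an exponential waiting time from the total hazard and then picks which reaction fires with probability proportional to its hazard. The NumPy sketch below follows that recipe for a generic SPN and mimics the min_haz / max_haz behaviour described above. `gillespie_step_sketch` is a hypothetical name, and raising an exception when the hazard explodes is this sketch's choice, not necessarily what jsmfsb does.

```python
import numpy as np

def gillespie_step_sketch(rng, x0, t0, deltat, S, h, min_haz=1e-10, max_haz=1e7):
    """Exact simulation of an SPN from t0 to t0 + deltat (direct method)."""
    x = np.array(x0, dtype=int)
    t, t_end = t0, t0 + deltat
    while True:
        haz = h(x, t)
        h0 = haz.sum()
        if h0 < min_haz:                      # nothing can happen: state is frozen
            return x
        if h0 > max_haz:
            raise RuntimeError("total hazard exceeds max_haz; process is exploding")
        t += rng.exponential(1.0 / h0)        # waiting time to the next event
        if t >= t_end:
            return x
        j = rng.choice(len(haz), p=haz / h0)  # which reaction fires
        x = x + S[:, j]

# Immigration-death: 0 -> X at rate 1, X -> 0 at rate 0.1 * x.
S = np.array([[1, -1]])
h = lambda x, t: np.array([1.0, 0.1 * x[0]])
x1 = gillespie_step_sketch(np.random.default_rng(42), [0], 0.0, 50.0, S, h)
```

Because the waiting times are exact, the returned state is a draw from the true process at t0 + deltat, and the state stays integer-valued.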
- step_gillespie_1d(d, min_haz=1e-10, max_haz=10000000.0)
Create a function for advancing the state of an SPN by using the Gillespie algorithm on a 1D regular grid
This method creates a function for advancing the state of an SPN model using the Gillespie algorithm. The resulting function (closure) can be used in conjunction with other functions (such as sim_time_series_1d) for simulating realisations of SPN models in space and time.
Parameters
- d: array
A vector of diffusion coefficients, one for each reacting species, in order. Each coefficient is the rate at which a molecule moves into a given adjacent compartment, so the hazard for a given molecule leaving its compartment is twice this value (as it can leave to the left or the right).
- min_haz: float
Minimum hazard to consider before assuming 0. Defaults to 1e-10.
- max_haz: float
Maximum hazard to consider before assuming an explosion and bailing out. Defaults to 1e07.
Returns
A function which can be used to advance the state of the SPN model by using the Gillespie algorithm. The function closure has arguments key, x0, t0, deltat, where key is a JAX random key, x0 is a matrix with rows corresponding to species and columns corresponding to voxels, representing the initial condition, t0 represents the initial time, and deltat represents the amount of time by which the process should be advanced. The function closure returns a matrix representing the simulated state of the system at the new time.
Examples
>>> import jsmfsb.models
>>> import jax
>>> import jax.numpy as jnp
>>> lv = jsmfsb.models.lv()
>>> stepLv1d = lv.step_gillespie_1d(jnp.array([0.6, 0.6]))
>>> N = 20
>>> x0 = jnp.zeros((2,N))
>>> x0 = x0.at[:,int(N/2)].set(lv.m)
>>> k0 = jax.random.key(42)
>>> stepLv1d(k0, x0, 0, 1)
- step_gillespie_2d(d, min_haz=1e-10, max_haz=10000000.0)
Create a function for advancing the state of an SPN by using the Gillespie algorithm on a 2D regular grid
This method creates a function for advancing the state of an SPN model using the Gillespie algorithm. The resulting function (closure) can be used in conjunction with other functions (such as sim_time_series_2d) for simulating realisations of SPN models in space and time.
Parameters
- d: array
A vector of diffusion coefficients, one for each reacting species, in order. Each coefficient is the rate at which a molecule moves into a given adjacent compartment, so the hazard for a given molecule leaving its compartment is four times this value (as it can leave in one of 4 directions).
- min_haz: float
Minimum hazard to consider before assuming 0. Defaults to 1e-10.
- max_haz: float
Maximum hazard to consider before assuming an explosion and bailing out. Defaults to 1e07.
Returns
A function which can be used to advance the state of the SPN model by using the Gillespie algorithm. The function closure has arguments key, x0, t0, deltat, where key is a JAX random key, x0 is a 3d array with dimensions corresponding to species then two spatial dimensions, representing the initial condition, t0 represents the time of the initial state, and deltat represents the amount of time by which the process should be advanced. The function closure returns an array representing the simulated state of the system at the new time.
Examples
>>> import jsmfsb.models
>>> import jax
>>> import jax.numpy as jnp
>>> lv = jsmfsb.models.lv()
>>> stepLv2d = lv.step_gillespie_2d(jnp.array([0.6, 0.6]))
>>> N = 20
>>> x0 = jnp.zeros((2, N, N))
>>> x0 = x0.at[:, int(N/2), int(N/2)].set(lv.m)
>>> k0 = jax.random.key(42)
>>> stepLv2d(k0, x0, 0, 1)
- step_poisson(dt=0.01)
Create a function for advancing the state of an SPN by using a simple approximate Poisson time stepping method
This method returns a function for advancing the state of an SPN model using a simple approximate Poisson time stepping method. The resulting function (closure) can be used in conjunction with other functions (such as ‘sim_time_series’) for simulating realisations of SPN models.
Parameters
- dt: float
The time step for the time-stepping integration method. Defaults to 0.01.
Returns
A function which can be used to advance the state of the SPN model by using a Poisson time stepping method with step size ‘dt’. The function closure has interface ‘function(key, x0, t0, deltat)’, where ‘x0’ and ‘t0’ represent the initial state and time, and ‘deltat’ represents the amount of time by which the process should be advanced. The function closure returns a vector representing the simulated state of the system at the new time.
Examples
>>> import jsmfsb.models
>>> import jax
>>> lv = jsmfsb.models.lv()
>>> stepLv = lv.step_poisson(0.001)
>>> k = jax.random.key(42)
>>> stepLv(k, lv.m, 0, 1)
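An approximate Poisson time-stepping method (often called tau-leaping) freezes the hazards over each interval of length dt and fires each reaction a Poisson(h_j·dt) number of times. The NumPy sketch below uses the same conventions as the other sketches; `poisson_step_sketch` is a hypothetical name, and its handling of a final partial step is its own choice. Unlike the Gillespie algorithm, nothing in this sketch prevents counts from going negative when dt is large relative to the hazards.

```python
import numpy as np

def poisson_step_sketch(rng, x0, t0, deltat, S, h, dt=0.01):
    """Approximate simulation: Poisson numbers of reaction events per step."""
    x = np.array(x0, dtype=int)
    t, t_end = t0, t0 + deltat
    while t < t_end - 1e-12:
        step = min(dt, t_end - t)
        events = rng.poisson(h(x, t) * step)   # one event count per reaction
        x = x + S @ events
        t += step
    return x

# Immigration-death: 0 -> X at rate 1, X -> 0 at rate 0.1 * x.
S = np.array([[1, -1]])
h = lambda x, t: np.array([1.0, 0.1 * x[0]])
x1 = poisson_step_sketch(np.random.default_rng(7), [10], 0.0, 1.0, S, h)
```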
jsmfsb.spatial module
- jsmfsb.spatial.sim_time_series_1d(key, x0, t0, tt, dt, step_fun, verb=False)
Simulate a model on a regular grid of times, using a function (closure) for advancing the state of the model
This function simulates a single realisation of a model on a 1D regular spatial grid and regular grid of times using a function (closure) for advancing the state of the model, such as created by step_gillespie_1d.
Parameters
- key: JAX random number key
Initial random number key to seed the simulation.
- x0: array
The initial state of the process at time t0, a matrix with rows corresponding to reacting species and columns corresponding to spatial location.
- t0: float
The initial time to be associated with the initial state x0.
- tt: float
The terminal time of the simulation.
- dt: float
The time step of the output. Note that this time step relates only to the recorded output, and has no bearing on the accuracy of the simulation process.
- step_fun: function
A function (closure) for advancing the state of the process, such as produced by step_gillespie_1d.
- verb: boolean
Output progress to the console (this function can be very slow).
Returns
A 3d array representing the simulated process. The dimensions are species, space, and time.
Examples
>>> import jsmfsb.models
>>> import jax
>>> import jax.numpy as jnp
>>> lv = jsmfsb.models.lv()
>>> stepLv1d = lv.step_gillespie_1d(jnp.array([0.6,0.6]))
>>> N = 10
>>> T = 5
>>> x0 = jnp.zeros((2,N))
>>> x0 = x0.at[:,int(N/2)].set(lv.m)
>>> k0 = jax.random.key(42)
>>> jsmfsb.sim_time_series_1d(k0, x0, 0, T, 1, stepLv1d, True)
- jsmfsb.spatial.sim_time_series_2d(key, x0, t0, tt, dt, step_fun, verb=False)
Simulate a model on a regular grid of times, using a function (closure) for advancing the state of the model
This function simulates a single realisation of a model on a 2D regular spatial grid and regular grid of times using a function (closure) for advancing the state of the model, such as created by step_gillespie_2d.
Parameters
- key: JAX random number key
Random key to seed the simulation.
- x0: array
The initial state of the process at time t0, a 3d array whose first dimension corresponds to reacting species and whose other two dimensions correspond to spatial location.
- t0: float
The initial time to be associated with the initial state x0.
- tt: float
The terminal time of the simulation.
- dt: float
The time step of the output. Note that this time step relates only to the recorded output, and has no bearing on the accuracy of the simulation process.
- step_fun: function
A function (closure) for advancing the state of the process, such as produced by step_gillespie_2d.
- verb: boolean
Output progress to the console (this function can be very slow).
Returns
A 4d array representing the simulated process. The dimensions are species, two spatial dimensions, and time.
Examples
>>> import jsmfsb.models
>>> import jax
>>> import jax.numpy as jnp
>>> lv = jsmfsb.models.lv()
>>> stepLv2d = lv.step_gillespie_2d(jnp.array([0.6,0.6]))
>>> M = 10
>>> N = 15
>>> T = 5
>>> x0 = jnp.zeros((2,M,N))
>>> x0 = x0.at[:,int(M/2),int(N/2)].set(lv.m)
>>> k0 = jax.random.key(42)
>>> jsmfsb.sim_time_series_2d(k0, x0, 0, T, 1, stepLv2d, True)
jsmfsb.inference module
- jsmfsb.inference.abc_run(key, n, rprior, rdist, batch_size=None, verb=False)
Run a set of simulations initialised with parameters sampled from a given prior distribution, and compute statistics required for an ABC analysis
Run a set of simulations initialised with parameters sampled from a given prior distribution, and compute statistics required for an ABC analysis. Typically used to calculate “distances” of simulated synthetic data from observed data.
Parameters
- key: JAX random number key
Key to initialise the ABC simulation.
- n: int
An integer representing the number of simulations to run.
- rprior: function
A function with one argument, a JAX random key, generating a single parameter (vector) from a prior distribution.
- rdist: function
A function with two arguments: a JAX random key and a parameter (vector). It returns the required statistic of interest. This will typically be computed by first using the parameter to run a forward model, then computing required summary statistics, then computing a distance. See the example for details.
- batch_size: int
Batch size to use in the call to jax.lax.map for parallelisation. Defaults to None.
- verb: boolean
Whether to print progress information to the console. Defaults to False.
Returns
A tuple with first component a matrix of parameters (in rows) and second component a vector of corresponding distances.
Examples
>>> import jsmfsb
>>> import jax
>>> import jax.numpy as jnp
>>> k0 = jax.random.key(42)
>>> k1, k2 = jax.random.split(k0)
>>> data = jax.random.normal(k1, (250,))*2 + 5
>>> def rpr(k):
...     return jnp.exp(jax.random.uniform(k, (2,), minval=-3, maxval=3))
...
>>> def rmod(k, th):
...     return jax.random.normal(k, (250,))*th[1] + th[0]
...
>>> def sumStats(dat):
...     return jnp.array([jnp.mean(dat), jnp.std(dat)])
...
>>> ssd = sumStats(data)
>>> def dist(ss):
...     diff = ss - ssd
...     return jnp.sqrt(jnp.sum(diff*diff))
...
>>> def rdis(k, th):
...     return dist(sumStats(rmod(k, th)))
...
>>> jsmfsb.abc_run(k2, 100, rpr, rdis)
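abc_run only produces (parameter, distance) pairs; a rejection-ABC posterior sample is then typically obtained by keeping the parameters with the smallest distances. The NumPy sketch below shows that post-processing step on arrays shaped like abc_run's output (a matrix of parameters in rows and a vector of distances). `abc_reject` is a hypothetical helper, the toy parameters and distances stand in for real abc_run output, and keeping the closest 10% is an arbitrary illustrative choice.

```python
import numpy as np

def abc_reject(params, dists, keep=0.1):
    """Keep the rows of params whose distance is within the keep-quantile."""
    params = np.asarray(params)
    dists = np.asarray(dists)
    cutoff = np.quantile(dists, keep)
    return params[dists <= cutoff]

# Toy stand-in for abc_run output: 1000 two-dimensional parameters.
rng = np.random.default_rng(0)
params = rng.uniform(0, 10, size=(1000, 2))
dists = np.linalg.norm(params - np.array([5.0, 2.0]), axis=1)
post = abc_reject(params, dists)
```

Smaller values of `keep` give a closer approximation to the posterior at the cost of fewer retained samples.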
- jsmfsb.inference.abc_smc(key, n, rprior, dprior, rdist, rperturb, dperturb, factor=10, steps=15, verb=False, debug=False)
Run an ABC-SMC algorithm for inferring the parameters of a forward model
Run an ABC-SMC algorithm for inferring the parameters of a forward model. This sequential Monte Carlo algorithm often performs better than simple rejection-ABC in practice.
Parameters
- key: JAX random key
A key to initialise the simulation.
- n: int
An integer representing the number of simulations to pass on at each stage of the SMC algorithm. Note that the TOTAL number of forward simulations required by the algorithm will be (roughly) ‘n*steps*factor’.
- rprior: function
A function with a single argument, a JAX random key, which generates a single parameter (vector) from the prior.
- dprior: function
A function taking a parameter vector as argument and returning the log of the prior density.
- rdist: function
A function with two arguments: a JAX random key and a parameter vector. It should return a scalar “distance” measuring how far data simulated with the chosen parameter lie from the observed data (smaller is better). This will typically be computed by first using the parameter to run a forward model, then computing required summary statistics, then computing a distance. See the example for details.
- rperturb: function
A function with two arguments: a JAX random key and a parameter vector. It should return a perturbed parameter from an appropriate kernel.
- dperturb: function
A function which takes a pair of parameters as its first two arguments (new first and old second), and returns the log of the density associated with this perturbation kernel.
- factor: int
At each step of the algorithm, ‘n*factor’ proposals are generated and the best ‘n’ of these are weighted and passed on to the next stage. Note that the effective sample size of the parameters passed on to the next step may be (much) smaller than ‘n’, since some of the particles may be assigned small (or zero) weight. Defaults to 10.
- steps: int
The number of steps of the ABC-SMC algorithm. Typically, somewhere between 5 and 100 steps seems to be used in practice. Defaults to 15.
- verb: boolean
Boolean indicating whether some progress should be printed to the console. Defaults to False.
Returns
A matrix with rows representing samples from the approximate posterior distribution.
Examples
>>> import jsmfsb
>>> import jax
>>> import jax.numpy as jnp
>>> import jax.scipy as jsp
>>> k0 = jax.random.key(42)
>>> k1, k2 = jax.random.split(k0)
>>> data = jax.random.normal(k1, (250,))*2 + 5
>>> def rpr(k):
...     return jax.random.uniform(k, (2,), minval=-3, maxval=3)
...
>>> def rmod(k, th):
...     return jax.random.normal(k, (250,))*jnp.exp(th[1]) + jnp.exp(th[0])
...
>>> def sumStats(dat):
...     return jnp.array([jnp.mean(dat), jnp.std(dat)])
...
>>> ssd = sumStats(data)
>>> def dist(ss):
...     diff = ss - ssd
...     return jnp.sqrt(jnp.sum(diff*diff))
...
>>> def rdis(k, th):
...     return dist(sumStats(rmod(k, th)))
...
>>> jsmfsb.abc_smc(k2, 100, rpr,
...     lambda x: jnp.sum(jnp.log(((x<3)&(x>-3))/6)),
...     rdis,
...     lambda k, x: jax.random.normal(k, x.shape)*0.1 + x,
...     lambda x, y: jnp.sum(jsp.stats.norm.logpdf(y, x, 0.1)))
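The roles of dprior and dperturb can be seen in the standard ABC-SMC importance weight: each accepted particle is reweighted by its prior density divided by the weighted mixture of perturbation kernels from the previous population. In log form, with dperturb taking the new parameter first and the old second as described above, a standard version is shown below; whether jsmfsb computes its weights in exactly this way is an assumption, not something stated here.

```latex
\log w_i^{(t)} = \mathrm{dprior}\big(\theta_i^{(t)}\big)
  - \log \sum_{j=1}^{n} w_j^{(t-1)} \exp\!\Big(\mathrm{dperturb}\big(\theta_i^{(t)}, \theta_j^{(t-1)}\big)\Big)
```

The weights are then normalised to sum to one, so particles with zero prior density (log density of −∞) receive zero weight.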
- jsmfsb.inference.abc_smc_step(key, dprior, prior_sample, prior_lw, rdist, rperturb, dperturb, factor)
Carry out one step of an ABC-SMC algorithm
Not meant to be directly called by users. See abc_smc.
- jsmfsb.inference.metropolis_hastings(key, init, log_lik, rprop, ldprop=<function <lambda>>, ldprior=<function <lambda>>, iters=10000, thin=10, verb=True)
Run a Metropolis-Hastings MCMC algorithm for the parameters of a Bayesian posterior distribution
Run a Metropolis-Hastings MCMC algorithm for the parameters of a Bayesian posterior distribution. Note that the algorithm carries over the old likelihood from the previous iteration, making it suitable for problems with expensive likelihoods, and also for “exact approximate” pseudo-marginal or particle marginal MH algorithms.
Parameters
- key: JAX random number key
A key to seed the simulation.
- initvector
A parameter vector with which to initialise the MCMC algorithm.
- log_lik: (stochastic) function
A function which takes two arguments: a JAX random key and a parameter (of the same type as init). It should return the log-likelihood of the data. It is fine for this to return the log of an unbiased estimate of the likelihood, in which case the algorithm will be an “exact approximate” pseudo-marginal MH algorithm; this is why the function must accept a JAX random key. In the “vanilla” case, where the log-likelihood is deterministic, the function should simply ignore the key that is passed in.
- rprop: stochastic function
A function which takes a random key and a current parameter as its two required arguments and returns a single sample from a proposal distribution.
- ldprop: function
A function which takes a new and an old parameter as its first two required arguments and returns the log density of the new value conditional on the old. Defaults to a flat function, which causes this term to drop out of the acceptance probability. The default is appropriate for any symmetric proposal, since the term also drops out in that case.
- ldprior: function
A function which takes a parameter as its only required argument and returns the log density of the parameter value under the prior. Defaults to a flat function, which causes this term to drop out of the acceptance probability. A flat prior is often used in an attempt to be “uninformative” or “objective”, but this is somewhat naive: what counts as “flat” depends on the parametrisation of the model.
- iters: int
The number of MCMC iterations required (after thinning). Defaults to 10000.
- thin: int
The thinning factor, i.e. only every thin-th iteration is stored. Defaults to 10.
- verb: boolean
Boolean indicating whether some progress information should be printed to the console. Defaults to True.
Returns
A matrix with rows representing samples from the posterior distribution.
Examples
>>> import jsmfsb
>>> import jax
>>> import jax.numpy as jnp
>>> import jax.scipy as jsp
>>> k0 = jax.random.key(42)
>>> k1, k2 = jax.random.split(k0)
>>> data = jax.random.normal(k1, (250,))*2 + 5
>>> llik = lambda k, x: jnp.sum(jsp.stats.norm.logpdf(data, x[0], x[1]))
>>> prop = lambda k, x: jax.random.normal(k, (2,))*0.1 + x
>>> jsmfsb.metropolis_hastings(k2, jnp.array([1.0,1.0]), llik, prop)
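The following minimal JAX sketch shows what carrying over the likelihood means in practice, using a Metropolis-Hastings loop with thinning. The function name mh_sketch is hypothetical and this is not the jsmfsb source. It assumes the argument conventions documented above: log_lik(key, x), rprop(key, x), ldprop(new, old) and ldprior(x). It also omits the progress printing controlled by verb.

```python
import jax
import jax.numpy as jnp


def mh_sketch(key, init, log_lik, rprop, ldprop=lambda new, old: 0.0,
              ldprior=lambda x: 0.0, iters=1000, thin=10):
    """Hypothetical MH sampler that reuses the stored log-likelihood."""
    def one_step(state, k):
        x, ll = state
        k_prop, k_lik, k_acc = jax.random.split(k, 3)
        prop = rprop(k_prop, x)
        # Only the proposal's log-likelihood is evaluated; ll is carried over
        prop_ll = log_lik(k_lik, prop)
        log_a = (prop_ll - ll + ldprior(prop) - ldprior(x)
                 + ldprop(x, prop) - ldprop(prop, x))
        accept = jnp.log(jax.random.uniform(k_acc)) < log_a
        x = jnp.where(accept, prop, x)
        ll = jnp.where(accept, prop_ll, ll)
        return (x, ll), None

    def thinned(state, k):
        # Run `thin` MH steps and record only the final parameter value
        state, _ = jax.lax.scan(one_step, state, jax.random.split(k, thin))
        return state, state[0]

    k_init, k_run = jax.random.split(key)
    state0 = (init, log_lik(k_init, init))
    _, samples = jax.lax.scan(thinned, state0, jax.random.split(k_run, iters))
    return samples
```

Because the stored log-likelihood of the current state is never re-estimated, a noisy but unbiased log_lik (such as one built with pf_marginal_ll) leaves the exact posterior invariant. This property is what makes the pseudo-marginal scheme work.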
- jsmfsb.inference.pf_marginal_ll(n, sim_x0, t0, step_fun, data_ll, data, debug=False)
Create a function for computing the log of an unbiased estimate of the marginal likelihood of a time course data set
Create a function for computing the log of an unbiased estimate of the marginal likelihood of a time course data set, using a simple bootstrap particle filter.
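The estimate produced by a bootstrap particle filter has the standard form below. Here x_t^(i) (i = 1, …, n) are the particles after propagation to the t-th observation time and before resampling, and π(y_t | x_t^(i), θ) is the observation density given by data_ll. The notation is introduced here for illustration.

```latex
\log \hat{\pi}(y_{1:T} \mid \theta)
  = \sum_{t=1}^{T} \log\left( \frac{1}{n} \sum_{i=1}^{n}
      \pi\big(y_t \mid x_t^{(i)}, \theta\big) \right)
```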
Parameters
- n: int
An integer representing the number of particles to use in the particle filter.
- sim_x0: function
A function with arguments key, t0 and th, where t0 is the time at which to sample from the initial distribution for the state of the particle filter, and th is a vector of parameters. The return value should be a state vector randomly sampled from this distribution. The function therefore represents a prior distribution on the initial state of the Markov process.
- t0: float
The time corresponding to the starting point of the Markov process. Must be no greater than the smallest observation time.
- step_fun: function
A function for advancing the state of the Markov process, with arguments key, x, t0, deltat and th, with th representing a vector of parameters.
- data_ll: function
A function with arguments x, t, y, th, where x and t represent the true state and time of the process, y is the observed data, and th is a parameter vector. The return value should be the log of the likelihood of the observation. The function therefore represents the observation model.
- data: matrix
A matrix whose first column is an increasing sequence of observation times; the remaining columns contain the observed values of y at those times.
Returns
A function with arguments key and th (a parameter vector), which evaluates to the log of the particle filter's unbiased estimate of the marginal likelihood of the data for parameter th.
Examples
>>> import jax
>>> import jax.numpy as jnp
>>> import jax.scipy as jsp
>>> import jsmfsb
>>> def obsll(x, t, y, th):
...     return jnp.sum(jsp.stats.norm.logpdf(y-x, scale=10))
...
>>> def simX(key, t0, th):
...     k1, k2 = jax.random.split(key)
...     return jnp.array([jax.random.poisson(k1, 50),
...                       jax.random.poisson(k2, 100)]).astype(jnp.float32)
...
>>> def step(key, x, t, dt, th):
...     sf = jsmfsb.models.lv(th).step_gillespie()
...     return sf(key, x, t, dt)
...
>>> mll = jsmfsb.pf_marginal_ll(80, simX, 0, step, obsll, jsmfsb.data.lv_noise_10)
>>> k0 = jax.random.key(42)
>>> mll(k0, jnp.array([1, 0.005, 0.6]))
>>> mll(k0, jnp.array([2, 0.005, 0.6]))
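Below is a minimal JAX sketch of the bootstrap particle filter behind this kind of estimate. The function pf_ll_sketch is hypothetical and is not the jsmfsb implementation. It differs from pf_marginal_ll in three ways: it takes key and th directly rather than returning a closure, it resamples multinomially at every observation, and it ignores debug. It assumes the conventions for sim_x0, step_fun, data_ll and data documented above.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def pf_ll_sketch(key, th, n, sim_x0, t0, step_fun, data_ll, data):
    """Hypothetical bootstrap particle filter: log marginal likelihood estimate."""
    times, obs = data[:, 0], data[:, 1:]
    t_start = jnp.asarray(t0, dtype=times.dtype)
    # Time increments between consecutive observations, starting from t0
    deltas = jnp.diff(jnp.concatenate([t_start[None], times]))
    k_init, k_run = jax.random.split(key)
    # Draw n particles from the prior on the initial state
    x = jax.vmap(lambda k: sim_x0(k, t0, th))(jax.random.split(k_init, n))

    def update(carry, inputs):
        x, t, ll = carry
        k, dt, y = inputs
        k_step, k_res = jax.random.split(k)
        # Propagate each particle forward to the next observation time
        x = jax.vmap(lambda kk, xi: step_fun(kk, xi, t, dt, th))(
            jax.random.split(k_step, n), x)
        t = t + dt
        # Log-weight each particle using the observation model
        lw = jax.vmap(lambda xi: data_ll(xi, t, y, th))(x)
        # Accumulate the log of the average (unnormalised) weight
        ll = ll + logsumexp(lw) - jnp.log(n)
        # Multinomial resampling in proportion to the weights
        idx = jax.random.categorical(k_res, lw, shape=(n,))
        return (x[idx], t, ll), None

    keys = jax.random.split(k_run, times.shape[0])
    init = (x, t_start, jnp.zeros((), dtype=times.dtype))
    (_, _, ll), _ = jax.lax.scan(update, init, (keys, deltas, obs))
    return ll
```

Returning the log of the averaged weights, rather than the average of the log weights, is what keeps the likelihood estimate itself unbiased.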