Metadata-Version: 2.1
Name: IMNN
Version: 0.3.1
Summary: Using neural networks to extract sufficient statistics from         data by maximising the Fisher information
Home-page: https://bitbucket.org/tomcharnock/imnn.git
Author: Tom Charnock
Author-email: charnock@iap.fr
License: UNKNOWN
Platform: UNKNOWN
Classifier: Programming Language :: Python :: 3
Classifier: License :: OSI Approved :: MIT License
Classifier: Operating System :: OS Independent
Requires-Python: >=3.6
Description-Content-Type: text/x-rst
License-File: LICENSE

IMNN: Information maximising neural networks
============================================

|doc| |pypi| |bit| |git| |doi| |zen|

The IMNN is a statistical method for transformation and compression of data from complex, high-dimensional distributions down to the number of physical parameters in the model which generates that data. It is asymptotically lossless in terms of information about the physical parameters. The method uses neural networks as a backbone for the transformation although any parameterised transformation with enough flexibility to map the data could be used.

Using simulations generated from the model the Fisher information (for a Gaussian distribution with parameter independent covariance) is calculated from the output of the transformation and its log determinant is maximised under the condition that the covariance of the outputted transformed simulations are approximately constant and approach the identity matrix.

Since the Fisher information only needs to be evaluated at a single fiducial parameter value it is exceptionally cheap to fit in comparison to other types of neural networks at the expense that the information is extracted from the data optimally about that fiducial parameter value and that the gradients of the simulations are needed (although this can be done numerically).

The ideal situation for performing inference is to fit an IMNN at a fiducial choice of parameters, evaluate the quasi-maximum-likelihood estimate of the parameters for the target data of choice and retrain a new IMNN at this estimate (still very cheap). This can be repeated iteratively until the IMNN is being trained at the maximum likelihood values of the parameters which will be most sensitive to the features in the data at that point and allow for the most precise constraints.

Note that the parameter estimates from the IMNN are not intended to be indicative of unbiased parameter estimates if the fiducial parameter choice is far from the maximum likelihood parameter values from some target. Instead they are meant to be used in a likelihood-free (simulation-based) inference scenario.



Using the IMNN
--------------

Implemented in `JAX <https://jax.readthedocs.io/en/latest/>`_, the IMNN module provides a way to setup the fitting algorithm and fit an IMNN and then use it to make parameter estimates. There are several scenarios for the types of inputs that can be used.

SimulatorIMNN
_____________

Using a simulator (written in JAX or at least XLA compilable) for the generation of data provides the best results since generation of data is done on-the-fly. Since the focus of the IMNN is to maximise the information about the parameters from the data from the model, having a limited dataset means that spurious accidental correlations will almost certainly be learned about, overestimating the amount of information that can be extracted. For this reason, if using a limited sized training set like with `GradientIMNN`_ or `NumericalGradientIMNN`_, then it is important to use a validation set and use early stopping based on the information extracted from this validation set. All these problems are side-stepped if simulations are generated on the fly since the same data is never seen multiple times. Of course, if the simulator is expensive then this may not be feasible.

The simulator needs to be of the form

.. code-block:: python

    def simulator(rng, parameters):
        """ Simulate a realisation of the data

        Parameters
        ----------
        rng : int(2,)
            A jax stateless random number generator
        parameters : float(n_params,)
            parameter values at which to generate the simulations at

        Returns
        -------
        float(input_shape):
            a simulation generated at parameter values `parameters`
        """
        # Do stuff
        return simulation

GradientIMNN
____________

If simulations are expensive but a set of simulations with their gradients can be precalculated then it is possible to use these instead to train the IMNN. In this case the simulations are passed through the network and the Jacobian of the network outputs are calculated with respect to their inputs which using the chain rule can be combined with the gradient of the simulation with respect to the model parameters to get the Jacobian of the network outputs with respect to the physical model parameters, used for calculating the Fisher information. Although possible to not use a validation set when fitting it is highly recommended to use early stopping with a validation set to avoid overestimation of the amount of information extracted.

The simulations and their gradients should be of the form

.. code-block:: python

    import jax
    from imnn.utils import value_and_jacrev

    rng = jax.random.PRNGKey(0)


    fiducial_parameters = # fiducial parameter values as a float(n_params,)
    n_s = # number of simulations used to estimate covariance of network outputs

    rng, *keys = jax.random.split(rng, num=n_s + 1)
    fiducial, derivative = jax.vmap(
        lambda key: value_and_jacrev(
            simulator,
            argnums=1)(key, fiducial_parameters))(
        np.array(keys))

    fiducial.shape
    >>> (n_s, input_shape)

    derivative.shape
    >>> (n_s, input_shape, n_params)

Note that ``n_s`` derivatives are necessarily needed since only the mean of the derivatives is calculated which is more stable than the covariance. Therefore only ``n_d`` < ``n_s`` are required, although most stable optimisation is achieved using ``n_d = n_s``.

NumericalGradientIMNN
_____________________

If the gradient of the simulations with respect to the physical model parameters is not possible then numerical derivatives can be done. In this case simulations are made at the fiducial parameter value and then varied slightly with respect to each parameter independently with each of these simulations made at the same seed. Theses varied simulations are passed through the network and the outputs are used to make a numerical estimate via finite differences. There is quite a lot of fitting optimisation sensitivity to the choice of the finite difference size. Note that, again, it is VERY highly recommended to use a validation set for early stopping to prevent overestimation of the amount of information that can be extracted and the extraction of information from spurious features only existing in the limited dataset.

The simulations and their numerical derivatives should be made something like:

.. code-block:: python

    import jax
    import jax.numpy as np

    rng = jax.random.PRNGKey(0)


    fiducial_parameters = # fiducial parameter values as a float(n_params,)
    parameter_differences = # differences between varied parameter values for
                            # finite differences as a float(n_params,)
    n_s = # number of simulations used to estimate covariance of network outputs
    n_d = # number of simulations used to estimate the numerical derivative of
          # the mean of the network outputs

    rng, *keys = jax.random.split(rng, num=n_s + 1)

    fiducial = jax.vmap(
        lambda key: simulator(key, fiducial_parameters))(
        np.array(keys))

    varied_below = (fiducial_parameters - np.diag(parameter_differences) / 2)
    varied_above = (fiducial_parameters + np.diag(parameter_differences) / 2)

    below_fiducial = jax.vmap(
      lambda key: jax.vmap(
          lambda param: simulator(key, param))(varied_below))(
      np.array(keys)[:n_d])
    above_fiducial = jax.vmap(
      lambda key: jax.vmap(
          lambda param: simulator(key, param))(varied_above))(
      np.array(keys)[:n_d])

    derivative = np.stack([below_fiducial, above_fiducial], 1)

    fiducial.shape
    >>> (n_s, input_shape)

    derivative.shape
    >>> (n_s, 2, n_params, input_shape)

Matching seeds across pairs of varied parameters is fairly important for efficient training - stateless simulating like above makes this much easier.

Aggregation
___________

If simulations or networks are very large then it can be difficult to fit an IMNN since the Fisher information requires the covariance to be well approximated to be able to maximise it. This means that all of the simulations must be passed through the network before doing a backpropagation step. To help with this, aggregation of computation and accumulated gradients are implemented. In this framework a list of XLA devices is passed to the IMNN class and data is passed to each device (via TensorFlow dataset iteration) to calculate the network outputs (and their derivatives using any of the `SimulatorIMNN`_, `GradientIMNN`_ or `NumericalGradientIMNN`_). These outputs are relatively small in size and so the gradient of the loss function (covariance regularised log determinant of the Fisher information) can be calculated easily. All of the data is then passed through the network again (a small number [``n_per_device``] at a time) and the Jacobian of the network outputs with respect to the neural network parameters is calculated. The chain rule is then used to combine these with the gradient of the loss function with respect to the network outputs to get the gradient of the loss function with respect to the network parameters. These gradients are summed together ``n_per_device`` at a time until a single gradient pytree for each parameter in the network is obtained which is then passed to the optimiser to implement the backpropagation. This requires two passes of the data through the network per iteration which is expensive, but is currently the only way to implement this for large data inputs which do not fit into memory. If the whole computation does fit in memory then there will be orders of magnitudes speed up compared to aggregation. However, aggregation can be done over as many XLA devices are available which should help things a bit. It is recommended to process as many simulations as possible at once by setting ``n_per_device`` to as large a value as can be handled. All central operations are computed on a host device which should be easily accessible (in terms of I/O) from all the other devices.

.. code-block:: python

    import jax

    host = jax.devices("cpu")[0]
    devices = jax.devices("gpu")
    n_per_device = # number as high as makes sense for the size of data

TensorFlow Datasets
___________________

If memory is really tight and data needs to be loaded from disk then it is possible to use TensorFlow Datasets to do this, but the datasets must be EXTREMELY specifically made. There are examples in the ``examples`` directly, but shortly there are two different variants, the ``DatasetGradientIMNN`` and the ``DatasetNumericalGradientIMNN``. For the ``DatasetNumericalGradientIMNN`` the datasets must be of the form

.. code-block:: python

    import tensorflow as tf

    fiducial = [
            tf.data.TFRecordDataset(
                    sorted(glob.glob("fiducial_*.tfrecords")),
                    num_parallel_reads=1
                ).map(writer.parser
                ).skip(i * n_s // n_devices
                ).take(n_s // n_devices
                ).batch(n_per_device
                ).repeat(
                ).as_numpy_iterator()
            for i in range(n_devices)]

    derivative = [
        tf.data.TFRecordDataset(
                sorted(glob.glob("derivative_*.tfrecords")),
                num_parallel_reads=1
            ).map(writer.parser
            ).skip(i * 2 * n_params * n_d // n_devices
            ).take(2 * n_params * n_d // n_devices
            ).batch(n_per_device
            ).repeat(
            ).as_numpy_iterator()
        for i in range(n_devices)]

Here the ``tfrecords`` contains the simulations which are parsed by the ``writer.parser`` (there is a demonstration in ``imnn.TFRecords``). The simulations are split into ``n_devices`` different datasets each which contain ``n_s // n_devices`` simulations which are passed to the network ``n_per_device`` at a time and repeated and not shuffled. For derivative, because there are multiple simulations at each seed for the finite differences then ``2 * n_params * n_d // n_devices`` need to be available to each device before passing ``n_per_device`` to the network on each device.

For the ``DatasetGradientIMNN`` the loops are made quicker by separating the derivative and simulation calculation from the simulation only calculations (the difference between ``n_s`` and ``n_d``). In this case the datasets must be constructed like:

.. code-block:: python

    fiducial = [
        tf.data.TFRecordDataset(
                sorted(glob.glob("fiducial_*.tfrecords")),
                num_parallel_reads=1
            ).map(writer.parser
            ).skip(i * n_s // n_devices
            ).take(n_s // n_devices)
        for i in range(n_devices)]

    main = [
        tf.data.Dataset.zip((
            fiducial[i],
            tf.data.TFRecordDataset(
                sorted(glob.glob("derivative_*.tfrecords")),
                num_parallel_reads=1).map(
                    lambda example: writer.derivative_parser(
                        example, n_params=n_params)))
            ).take(n_d // n_devices
            ).batch(n_per_device
            ).repeat(
            ).as_numpy_iterator()
        for i in range(n_devices)]

    remaining = [
        fiducial[i].skip(n_d // n_devices
            ).take((n_s - n_d) // n_devices
            ).batch(n_per_device
            ).repeat(
            ).as_numpy_iterator()
        for i in range(n_devices)]

Note that using datasets can be pretty tricky, aggregated versions of `GradientIMNN`_ and `NumericalGradientIMNN`_ does all the hard work as long as the data can be fit in memory.

Neural models
_____________

The IMNN is designed with `stax <https://github.com/google/jax/blob/master/jax/experimental/README.md#neural-net-building-with-stax>`_-like models and `jax optimisers <https://github.com/google/jax/blob/master/jax/experimental/README.md#First-order-optimization>`_ which are very flexible and designed to be quickly developed. Note that these modules don't need to be used exactly, but they should look like them. Models should contain

.. code-block:: python

    def initialiser(rng, input_shape):
        """ Initialise the parameters of the model

        Parameters
        ----------
        rng : int(2,)
            A jax stateless random number generator
        input_shape : tuple
            The shape of the input to the network

        Returns
        -------
        tuple:
            The shape of the output of the network
        pytree (list or tuple):
            The values of the initialised parameters of the network
        """
        # Do stuff
        return output_shape, initialised_parameters

    def apply_model(parameters, inputs):
        """ Passes inputs through the network

        Parameters
        ----------
        parameters : pytree (list or tuple)
            The values of the parameters of the network
        inputs : float(input_shape)
            The data to put through the network

        Returns
        -------
        float(output_shape):
            The output of the network
        """
        # Do neural networky stuff
        return output

    model = (initialiser, apply_model)

The optimiser also doesn't specifically need to be a ``jax.experimental.optimizer``, but it must contain

.. code-block:: python

    def initialiser(initial_parameters):
        """ Initialise the state of the optimiser

        Parameters
        ----------
        parameters : pytree (list or tuple)
            The initial values of the parameters of the network

        Returns
        -------
        pytree (list or tuple) or object:
            The initial state of the optimiser containing everything needed to
            update the state, i.e. current state, the running mean of the
            weights for momentum-like optimisers, any decay rates, etc.
        """
        # Do stuff
        return state

    def updater(it, gradient, state):
        """ Updates state based on current iteration and calculated gradient

        Parameters
        ----------
        it : int
            A counter for the number of iterations
        gradient : pytree (list or tuple)
            The gradients of the parameters to update
        state : pytree (list or tuple) or object
            The state of the optimiser containing everything needed to update
            the state, i.e. current state, the running mean of the weights for
            momentum-like optimisers, any decay rates, etc.

        Returns
        -------
        pytree (list or tuple) or object:
            The updated state of the optimiser containing everything needed to
            update the state, i.e. current state, the running mean of the
            weights for momentum-like optimisers, any decay rates, etc.
        """
        # Do updating stuff
        return updated_state

    def get_parameters(state):
        """ Returns the values of the parameters at the current state

        Parameters
        ----------
        state : pytree (list or tuple) or object
            The current state of the optimiser containing everything needed to
            update the state, i.e. current state, the running mean of the
            weights for momentum-like optimisers, any decay rates, etc.

        Returns
        -------
        pytree (list or tuple):
            The current values of the parameters of the network
        """
        # Get parameters
        return current_parameters

    optimiser = (initialiser, updater, get_parameters)

IMNN
____

Because there are many different cases where we might want to use different types of IMNN subclasses. i.e. with a simulator, aggregated over GPUs, using numerical derivatives, etc. then there is a handy single function will try and return the intended subclass. This is

.. code-block:: python

    import imnn

    IMNN = imnn.IMNN(
      n_s,                        # number of simulations for covariance
      n_d,                        # number of simulations for derivative mean
      n_params,                   # number of parameters in physical model
      n_summaries,                # number of outputs from the network
      input_shape,                # the shape a single input simulation
      θ_fid,                      # the fiducial parameter values for the sims
      model,                      # the stax-like model
      optimiser,                  # the jax optimizers-like optimiser
      key_or_state,               # either a random number generator or a state
      simulator=None,             # SimulatorIMNN simulations on-the-fly
      fiducial=None,              # GradientIMNN or NumericalGradientIMNN sims
      derivative=None,            # GradientIMNN or NumericalGradientIMNN ders
      main=None,                  # DatasetGradientIMNN sims and derivatives
      remaining=None,             # DatasetGradientIMNN simulations
      δθ=None,                    # NumericalGradientIMNN finite differences
      validation_fiducial=None,   # GradientIMNN or NumericalGradientIMNN sims
      validation_derivative=None, # GradientIMNN or NumericalGradientIMNN ders
      validation_main=None,       # DatasetGradientIMNN sims and derivatives
      validation_remaining=None,  # DatasetGradientIMNN simulations
      host=None,                  # Aggregated.. host computational device
      devices=None,               # Aggregated.. devices for running network
      n_per_device=None,          # Aggregated.. amount of data to pass at once
      cache=None,                 # Aggregated.. whether to cache simulations
      prefetch=None,)             # Aggregated.. whether to prefetch sims

So for example to initialise an ``AggregatedSimulatorIMNN`` and train it we can do

.. code-block:: python

    rng, key = jax.random.split(rng)
    IMNN = imnn.IMNN(n_s, n_d, n_params, n_summaries, input_shape,
                     fiducial_parameters, model, optimiser, key,
                     simulator=simulator, host=host, devices=devices,
                     n_per_device=n_per_device)

    rng, key = jax.random.split(rng)
    IMNN.fit(λ=10., ϵ=0.1, rng=key)
    IMNN.plot(expected_detF=50.)

.. image:: https://bitbucket.org/tomcharnock/imnn/raw/master/docs/_images/history_plot.png

Or for a`` NumericalGradientIMNN``

.. code-block:: python

    rng, key = jax.random.split(rng)
    IMNN = imnn.IMNN(n_s, n_d, n_params, n_summaries, input_shape,
                     fiducial_parameters, model, optimiser, key,
                     fiducial=fiducial, derivative=derivative,
                     δθ=parameter_differences,
                     validation_fiducial=validation_fiducial,
                     validation_derivative=validation_derivative)

    IMNN.fit(λ=10., ϵ=0.1)
    IMNN.plot(expected_detF=50.)

``λ`` and ``ϵ`` control the strength of regularisation and should help with speed of convergence but not really impact the final results.

Doing likelihood-free inference
-------------------------------

With a trained IMNN it is possible to get an estimate of some data using

.. code-block:: python

    estimate = IMNN.get_estimate(target_data)

Along with the Fisher information from the network, we can use this to make a Gaussian approximation of the posterior under the assumption that the fiducial parameter values used to calculate the Fisher information coincide with the parameter estimate. This posterior can be calculated using the ``imnn.lfi`` module. For all of the available functions in the lfi module a `TensorFlow Probability <https://www.tensorflow.org/probability/>`_-like distribution is used for the prior, e.g. for a uniform distribution for two parameters between 0 and 10 each we could write

.. code-block:: python

    import tensorflow_probability
    tfp = tensorflow_probability.substrates.jax

    prior = tfp.distributions.Blockwise(
    [tfp.distributions.Uniform(low=low, high=high)
     for low, high in zip([0., 0.], [10., 10.])])

     prior.low = np.array([0., 0.])
     prior.high = np.array([10., 10.])

We set the values of ``prior.low`` and ``prior.high`` since they are used to define the plotting ranges. Note that ``prior.event_shape`` should be equal to ``n_params``, i.e. the number of parameters in the physical model.

GaussianApproximation
_____________________

The ``GaussianApproximation`` simply evaluates a multivariate Gaussian with mean at ``estimate`` and covariance given by ``np.linalg.inv(IMNN.F)`` on a grid defined by the prior ranges.

.. code-block:: python

    GA = imnn.lfi.GaussianApproximation(
      parameter_estimates=estimate,
      invF=np.linalg.inv(IMNN.F),
      prior=prior,
      gridsize=100)

And corner plots of the Gaussian approximation can be made using

.. code-block:: python

    GA.marginal_plot(
      ax=None,                   # Axes object to plot (constructs new if None)
      ranges=None,               # Ranges for each parameter (None=preset)
      marginals=None,            # Marginal distributions to plot (None=preset)
      known=None,                # Plots known parameter values if not None
      label=None,                # Adds legend element if not None
      axis_labels=None,          # Adds labels to the axes if not None
      levels=None,               # Plot specified approx significance contours
      linestyle="solid",         # Linestyle for the contours
      colours=None,              # Colour for the contours
      target=None,               # If multiple target data, which index to plot
      format=False,              # Whether to set up the plot decoration
      ncol=2,                    # Number of columns in the legend
      bbox_to_anchor=(1.0, 1.0)) # Where to place the legend

Note that this approximation shouldn't be necessarily a good estimate of the true posterior, for that actual LFI methods should be used.

ApproximateBayesianComputation
______________________________

To generate simulations and accept or reject these simulations based on a distance based criterion from some target data we can use

.. code-block:: python

    ABC = imnn.lfi.ApproximateBayesianComputation(
      target_data=target_data,
      prior=prior,
      simulator=simulator,
      compressor=IMNN.get_estimate,
      gridsize=100,
      F=IMNN.F,
      distance_measure=None)

This takes in the target data and compresses it using the provided compressor (like ``IMNN.get_estimate``). The Fisher information matrix can be provided to rescale the parameter directions to make meaningful distance measurements as long as summaries are parameter estimates. If a different distance measure is better for the specific problem this can be passed as a function. Note that if simulations have already been done for the ABC and only the plotting and the acceptance and rejection is needed then ``simulator`` can be set to ``None``. The ABC can actually be run by calling the module

.. code-block:: python

    parameters, summaries, distances = ABC(
        ϵ=None,             # The size of the epsilon ball to accept summaries
        rng=None,           # Random number generator for params and simulation
        n_samples=None,     # The number of samples to run (at one time)
        parameters=None,    # Values of parameters with premade compressed sims
        summaries=None,     # Premade compressed sims to avoid running new sims
        min_accepted=None,  # Num of required sims in epsilon ball (iterative)
        max_iterations=10,  # Max num of iterations to try and get min_accepted
        smoothing=None,     # Amount of smoothing on the histogrammed marginals
        replace=False)      # Whether to remove all previous run summaries

If not run and no premade simulations have been made then ``n_samples`` and ``rng`` must be passed. Note that if ``ϵ`` is too large then the accepted samples should not be considered to be drawn from the posterior but rather some partially marginalised part of the joint distribution of summaries and parameters, and hence it can be very misleading - ``ϵ`` should be a small as possible! Like with the ``GaussianApproximation`` there is a ``ABC.marginal_plot(...)`` but the parameter samples can also be plotted as a scatter plot on the corner plot

.. code-block:: python

    ABC.scatter_plot(
        ax=None,                 # Axes object to plot (constructs new if None)
        ranges=None,             # Ranges for each parameter (None=preset)
        points=None,             # Parameter values to scatter (None=preset)
        label=None,              # Adds legend element if not None
        axis_labels=None,        # Adds labels to the axes if not None
        colours=None,            # Colour for the scatter points (and hists)
        hist=True,               # Whether to plot 1D histograms of points
        s=5,                     # Marker size for points
        alpha=1.,                # Amount of transparency for the points
        figsize=(10, 10),        # Size of the figure if not premade
        linestyle="solid",       # Linestyle for the histograms
        target=None,             # If multiple target data, which index to plot
        ncol=2,                  # Number of columns in the legend
        bbox_to_anchor=(0., 1.)) # Where to place the legend

And the summaries can also be plotted on a corner plot with exactly the same parameters as ``scatter_plot`` (apart from ``gridsize`` being added) but if ``points`` is left ``None`` then ``ABC.summaries.accepted`` is used instead and the ranges calculated from these values. If points is supplied but ``ranges`` is None then the ranges are calculated from the minimum and maximum values of the points are used as the edges.

.. code-block:: python

    ABC.scatter_summaries(
        ax=None,
        ranges=None,
        points=None,
        label=None,
        axis_labels=None,
        colours=None,
        hist=True,
        s=5,
        alpha=1.,
        figsize=(10, 10),
        linestyle="solid",
        gridsize=100,
        target=None,
        format=False,
        ncol=2,
        bbox_to_anchor=(0.0, 1.0))

PopulationMonteCarlo
____________________

To more efficiently accept samples than using a simple ABC where samples are drawn from the prior, the `PopulationMonteCarlo`_ provides a JAX accelerated iterative acceptance and rejection scheme where each iteration the population of samples with summaries closest to the summary of the desired target defines a new proposal distribution to force a fixed population to converge towards the posterior without setting an explicit size for the epsilon ball of normal ABC. The PMC is stopped using a criterion on the number of accepted proposals compared to the number of total draws from the proposal. When this gets very small it suggests the distribution is stationary and that the proposal has been reached. It works similarly to `ApproximateBayesianComputation`_.

.. code-block:: python

    PMC = imnn.lfi.PopulationMonteCarlo(
      target_data=target_data,
      prior=prior,
      simulator=simulator,
      compressor=IMNN.get_estimate,
      gridsize=100,
      F=IMNN.F,
      distance_measure=None)

And it can be run using

.. code-block:: python

    parameters, summaries, distances = PMC(
        rng,                     # Random number generator for params and sims
        n_points,                # Number of points from the final distribution
        percentile=None,         # Percentage of points making the population
        acceptance_ratio=0.1,    # Fraction of accepted draws vs total draws
        max_iteration=10,        # Maximum number of iterations of the PMC
        max_acceptance=1,        # Maximum number of tries to get an accepted
        max_samples=int(1e5),    # Maximum number of attempts to get parameter
        n_initial_points=None,   # Number of points in the initial ABC step
        n_parallel_simulations=None, # Number of simulations to do in parallel
        proposed=None,           # Prerun parameter values for the initial ABC
        summaries=None,          # Premade compressed simulations for ABC
        distances=None,          # Precalculated distances for the initial ABC
        smoothing=None,          # Amount of smoothing on histogrammed marginal
        replace=False)           # Whether to remove all previous run summaries

The same plotting functions as `ApproximateBayesianComputation`_ are also available in the PMC

Note
....

There seems to be a bug in PopulationMonteCarlo and the parallel sampler is turned off

Installation
------------

The IMNN can be install by cloning the repository and installing via python or by pip installing, i.e.

.. code-block::

    git clone https://bitbucket.org/tomcharnock/imnn.git
    cd imnn
    python setup.py install

or

.. code-block::

    pip install IMNN

Notes on installation
_____________________

The IMNN was quite an early adopter of JAX and as such it uses some experimental features. It is known to be working with ``jax>=0.2.10,<=0.2.12`` and should be fine with newer versions for a while. One of the main limitations is with the use of TensorFlow Probability in the LFI module which also depends on JAX but is also dealing with the development nature of this language. The TensorFlow Datasets also requires TensorFlow>=2.1.0, but this requirement is not explicitly set so that python3.9 users can install a newer compatible version of TensorFlow without failing.

During the development of this code I implemented the value_and_jac* functions in JAX, which saves a huge amount of time for the IMNN, but these had not yet been pulled into the JAX api and as such there is a copy of these functions in ``imnn.utils.jac`` but they depend on ``jax.api`` and other functions which may change with jax development. If this becomes a problem then it will be necessary to install jax and jaxlib first, i.e. via

.. code-block::

    pip install jax==0.2.11 jaxlib==0.1.64

or whichever CUDA enabled version suits you.

The previous version of the IMNN is still available (and works well) built on a TensorFlow backend. If you want to use keras models, etc. it will probably be easier to use that. It is not as complete as this module, but is likely to be a bit more stable due to not depending on JAXs development as heavily. This can be installed via either

.. code-block::

    git clone https://bitbucket.org/tomcharnock/imnn-tf.git
    cd imnn-tf
    python setup.py install

or

.. code-block::

    pip install imnn-tf

Note that that code isn't as well documented, but there are plenty of examples still.


References
----------

If you use this code please cite

.. code-block::

    @article{charnock2018,
      author={Charnock, Tom and Lavaux, Guilhem and Wandelt, Benjamin D.},
      title={Automatic physical inference with information maximizing neural networks},
      volume={97},
      ISSN={2470-0029},
      url={http://dx.doi.org/10.1103/PhysRevD.97.083004},
      DOI={10.1103/physrevd.97.083004},
      number={8},
      journal={Physical Review D},
      publisher={American Physical Society (APS)},
      year={2018},
      month={Apr}
    }

and maybe also

.. code-block::

    @software{imnn2021,
      author = {Tom Charnock},
      title = {{IMNN}: Information maximising neural networks},
      url = {http://bitbucket.org/tomcharnock/imnn},
      version = {0.3.0},
      year = {2021},
    }

.. |doc| image:: https://bitbucket.org/tomcharnock/imnn/raw/master/docs/_images/doc.svg
    :target: https://www.aquila-consortium.org/doc/imnn/

.. |pypi| image:: https://bitbucket.org/tomcharnock/imnn/raw/master/docs/_images/pypi.svg
    :target: https://pypi.org/project/IMNN/

.. |bit| image:: https://bitbucket.org/tomcharnock/imnn/raw/master/docs/_images/bit.svg
    :target: https://bitbucket.org/tomcharnock/imnn

.. |git| image:: https://bitbucket.org/tomcharnock/imnn/raw/master/docs/_images/git.svg
    :target: https://github.com/tomcharnock/imnn

.. |doi| image:: https://zenodo.org/badge/DOI/10.1103/PhysRevD.97.083004.svg
   :target: https://doi.org/10.1103/PhysRevD.97.083004

.. |zen| image:: https://zenodo.org/badge/DOI/10.5281/zenodo.1175196.svg
   :target: https://doi.org/10.5281/zenodo.1175196


