Lightweight Multiprocessing

Download: example_16_multiproc_context.py

This example shows you how to use a MultiprocContext.

import logging
import multiprocessing as mp
import os

from pypet import MultiprocContext, Trajectory


def manipulate_multiproc_safe(traj):
    """Record the executing process in the trajectory and store it.

    The name of the process running this function is written to
    ``traj.last_process_name``; afterwards the trajectory's results are
    stored, **replacing** whatever was stored before.

    :param traj:

        Trajectory container with multiprocessing safe storage service

    """
    # Find out which process is executing this call
    process_name = mp.current_process().name
    # Update the data held by the trajectory
    traj.last_process_name = process_name
    # Write the results back, overwriting the data on disk. Overwriting is
    # not recommended and is done here only for demonstration purposes :-)
    results = traj.results
    results.f_store(store_data=3)


def main():
    """Run the multiprocessing example from start to finish.

    Creates a trajectory, lets a pool of worker processes manipulate it via
    ``manipulate_multiproc_safe`` inside a ``MultiprocContext``, then
    reloads the results from disk and prints the name of the last process
    that stored them.

    :raises Exception:

        An exception raised inside a worker process is re-raised here
        instead of being silently dropped.

    """
    # We don't use an environment so we enable logging manually
    logging.basicConfig(level=logging.INFO)

    filename = os.path.join("hdf5", "example_16.hdf5")
    traj = Trajectory(filename=filename, overwrite_file=True)

    # The result that will be manipulated
    traj.f_add_result(
        "last_process_name",
        "N/A",
        comment="Name of the last process that manipulated the trajectory",
    )

    n_runs = 50  # How often the target function is executed
    n_workers = 4  # Size of the process pool

    with MultiprocContext(trajectory=traj, wrap_mode="LOCK"):
        # The multiprocessing context manager wraps the storage service of the
        # trajectory and passes the wrapped service to the trajectory.
        # It also restores the original storage service in the end.
        # The trajectory, including the lock, is pickled and sent to the pool
        # for every function execution.
        # NOTE(review): this comment used to name a `MANAGER_LOCK` wrap mode,
        # but the code passes `wrap_mode="LOCK"` -- confirm against the pypet
        # documentation which one is meant.

        # Start a pool of processes manipulating the trajectory
        pool = mp.Pool(processes=n_workers)
        try:
            # Pass the function and `n_runs` references to the trajectory to
            # the pool. `map` blocks until every task has finished and
            # re-raises an exception from a worker, so failures are not lost.
            pool.map(manipulate_multiproc_safe, [traj] * n_runs)
        finally:
            # Always shut the pool down, even if a task failed
            pool.close()
            # Wait for all processes to join
            pool.join()

    # Reload the data from disk and overwrite the existing result in RAM
    traj.results.f_load(load_data=3)
    # Print the name of the last process the trajectory was manipulated by
    print("The last process to manipulate the trajectory was: `%s`" % traj.last_process_name)


# Run the example only when this file is executed as a script. The guard also
# keeps worker processes that re-import this module (e.g. with the `spawn`
# start method) from running the example again.
if __name__ == "__main__":
    main()