Minimal usage example

This notebook provides a minimal example of how to use the Haiku DimeNet++ model. For real-world applications of DimeNet++ in MD simulations, please refer to the DiffTRe repository.

from functools import partial
import warnings

from jax import random, numpy as jnp
from jax_md import space, partition
import numpy as onp

from jax_dimenet import dimenet, sparse_graph
warnings.filterwarnings('ignore')  # disable warnings about float64 usage

Example molecular state

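We build a molecular snapshot as input to the DimeNet++ model: a simple cubic lattice of 10 x 10 x 10 particles with spacing 0.3 in a periodic cubic box of side length 3.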

r_cut = 0.5  # cut-off for DimeNet++ graph connectivity and neighbor list
side_length = 3.
particles_per_side = 10
box = jnp.ones(3) * side_length

positions = onp.stack([onp.array(r) for r in onp.ndindex(particles_per_side,
                                                         particles_per_side,
                                                         particles_per_side)]
                      )
positions = jnp.array(positions) * side_length / particles_per_side
displacement_fn, shift = space.periodic(box)
neighbor_fn = partition.neighbor_list(displacement_fn, box,
                                      r_cut, dr_threshold=0.05)
neighbors = neighbor_fn.allocate(positions)
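As a quick sanity check on the graph size, the NumPy sketch below rebuilds the same lattice and counts, for each particle, the neighbors closer than r_cut under the minimum-image convention. The helper count_neighbors is hypothetical and only for illustration. It assumes the DimeNet++ graph keeps exactly the pairs within r_cut; the neighbor list itself may hold a few more pairs, since dr_threshold adds a buffer.

```python
import numpy as onp

# Same parameters as in the notebook cell above.
r_cut = 0.5
side_length = 3.
particles_per_side = 10

# Simple cubic lattice: integer grid indices scaled by the lattice spacing (0.3).
grid = onp.indices((particles_per_side,) * 3).reshape(3, -1).T
positions = grid * side_length / particles_per_side  # shape (1000, 3)


def count_neighbors(positions, box_size, cutoff):
    """Count neighbors within `cutoff` under the minimum-image convention.

    Hypothetical helper: builds the full O(N^2) distance matrix, which is
    fine for 1000 particles but not for large systems.
    """
    dr = positions[:, None, :] - positions[None, :, :]
    dr -= box_size * onp.round(dr / box_size)  # minimum-image displacement
    dist = onp.linalg.norm(dr, axis=-1)
    onp.fill_diagonal(dist, onp.inf)  # exclude self-pairs
    return (dist < cutoff).sum(axis=1)


counts = count_neighbors(positions, side_length, r_cut)
print(positions.shape, counts.min(), counts.max())
```

With a spacing of 0.3, each particle sees 6 neighbors at 0.3 and 12 at 0.3*sqrt(2) (about 0.42), while the 8 body-diagonal neighbors at 0.3*sqrt(3) (about 0.52) lie just outside r_cut = 0.5. Every particle should therefore have 18 neighbors within the cut-off.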

DimeNet++ energy function

Next, we use the DimeNet++ model as a JAX MD energy_fn, e.g. to run MD simulations.

init_fn, dimenet_energy_fn = dimenet.energy_neighborlist(displacement_fn, r_cut)
init_params = init_fn(random.PRNGKey(0), positions, neighbor=neighbors)
energy_fn = partial(dimenet_energy_fn, init_params)  # bind params -> JAX MD energy_fn interface
print('Predicted energy:', energy_fn(positions, neighbors))
Predicted energy: 24.967882
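Because energy_fn is a scalar function of the positions, the forces needed by an MD integrator follow from automatic differentiation. The sketch below shows this pattern with a hypothetical toy pair potential standing in for DimeNet++. toy_energy_fn, its epsilon parameter and the small test lattice are made up for illustration and are not part of jax_dimenet. The parameters are bound with functools.partial, as in the cell above, and the forces are -dE/dR via jax.grad.

```python
from functools import partial

import jax
import jax.numpy as jnp


def toy_energy_fn(params, positions, box_size=1.2, cutoff=0.5):
    """Hypothetical stand-in for a learned model: a soft pair potential
    summed over all periodic pairs closer than `cutoff`. Not DimeNet++;
    it only shares the (params, positions) -> scalar energy shape.
    """
    dr = positions[:, None, :] - positions[None, :, :]
    # Minimum-image convention; the rounding term carries no gradient.
    dr = dr - box_size * jax.lax.stop_gradient(jnp.round(dr / box_size))
    dist_sq = jnp.sum(dr ** 2, axis=-1)
    n = positions.shape[0]
    mask = ~jnp.eye(n, dtype=bool) & (dist_sq < cutoff ** 2)
    # Substitute 1.0 on masked entries so sqrt never sees 0 (avoids NaN gradients).
    dist = jnp.sqrt(jnp.where(mask, dist_sq, 1.0))
    pair_energy = params['epsilon'] * (1.0 - dist / cutoff) ** 2
    # Factor 0.5 because every pair appears twice in the (i, j) matrix.
    return 0.5 * jnp.sum(jnp.where(mask, pair_energy, 0.0))


# Bind the parameters first, mirroring partial(dimenet_energy_fn, init_params).
energy_fn = partial(toy_energy_fn, {'epsilon': 1.0})

# Slightly perturbed 4x4x4 cubic lattice (spacing 0.3) in a periodic box of side 1.2.
grid = jnp.stack(jnp.meshgrid(*[jnp.arange(4)] * 3, indexing='ij'), axis=-1).reshape(-1, 3)
positions = 0.3 * grid + 0.02 * jax.random.normal(jax.random.PRNGKey(0), grid.shape)

energy = energy_fn(positions)
forces = -jax.grad(energy_fn)(positions)  # F = -dE/dR, shape (64, 3)
print(energy, forces.shape, forces.sum(axis=0))
```

Since the toy energy depends only on relative positions, the net force sums to (numerically) zero. With the DimeNet++ energy_fn above, the analogous call would presumably be -jax.grad(energy_fn)(positions, neighbors), which differentiates with respect to the first argument; JAX MD's quantity.force wraps the same idea.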

DimeNet++ property prediction

Finally, DimeNet++ can also predict global molecular properties (here 5 targets). In this case, we pre-compute the sparse molecular graph from the neighbor list once and pass it to the model directly.

mol_graph, _ = sparse_graph.sparse_graph_from_neighborlist(
    displacement_fn, positions, neighbors, r_cut)
init_fn, property_predictor = dimenet.property_prediction(r_cut, n_targets=5)
init_params = init_fn(random.PRNGKey(0), mol_graph)
print('Predicted properties:', property_predictor(init_params, mol_graph))
Predicted properties: [  -6.040734 -112.69328   185.22845    54.96091  -210.01091 ]
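The pre-computed mol_graph is a sparse representation of the neighbor list. The NumPy sketch below does not reproduce sparse_graph_from_neighborlist; it only illustrates the general idea behind such a conversion. It assumes JAX MD's dense neighbor format, where empty slots are padded with the particle count N. Valid neighbor slots become directed edges, and pairs of edges k->j, j->i with k != i form the angle triplets used by DimeNet-style directional message passing. Both helpers, edges_from_dense_neighbors and triplets_from_edges, are hypothetical and not part of jax_dimenet.

```python
import numpy as onp


def edges_from_dense_neighbors(idx):
    """Turn a dense (N, max_neighbors) index array into a sparse edge list.

    Assumes empty slots are padded with N. Edge e points from senders[e]
    (the neighbor) to receivers[e] (the particle owning the row).
    """
    n = idx.shape[0]
    receivers = onp.repeat(onp.arange(n), idx.shape[1])
    senders = idx.reshape(-1)
    valid = senders < n  # drop padding slots
    return senders[valid], receivers[valid]


def triplets_from_edges(senders, receivers):
    """Pair each edge k->j with every edge j->i where k != i.

    Returns an (T, 2) array of (edge_kj, edge_ji) indices. Plain O(E^2)
    loop for clarity, not for speed.
    """
    triplets = []
    for e_kj in range(len(senders)):
        for e_ji in range(len(senders)):
            continues_path = senders[e_ji] == receivers[e_kj]
            not_backtrack = senders[e_kj] != receivers[e_ji]
            if continues_path and not_backtrack:
                triplets.append((e_kj, e_ji))
    return onp.array(triplets, dtype=int).reshape(-1, 2)


# Linear chain 0 - 1 - 2 with N = 3, padded with 3 where a slot is empty.
idx = onp.array([[1, 3], [0, 2], [1, 3]])
senders, receivers = edges_from_dense_neighbors(idx)
triplets = triplets_from_edges(senders, receivers)
print(list(zip(senders.tolist(), receivers.tolist())), triplets.tolist())
```

For the chain, the four directed edges are 1->0, 0->1, 2->1 and 1->2, and only the two paths through the middle particle (0->1->2 and 2->1->0) form triplets.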