Minimal usage example
This notebook provides a minimal example of how to use the Haiku DimeNet++ model. For real-world applications of DimeNet++ in MD simulations, please refer to the DiffTRe repository.
from functools import partial
import warnings
from jax import random, numpy as jnp
from jax_md import space, partition
import numpy as onp
from jax_dimenet import dimenet, sparse_graph
warnings.filterwarnings('ignore') # disable warnings about float64 usage
Example molecular state
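We build a simple molecular snapshot as input to the DimeNet++ model: a 10 × 10 × 10 simple cubic lattice of particles in a periodic cubic box, together with a JAX MD neighbor list.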
r_cut = 0.5 # cut-off for DimeNet++ graph connectivity and neighbor list
side_length = 3.
particles_per_side = 10
box = jnp.ones(3) * side_length
positions = onp.stack([onp.array(r) for r in onp.ndindex(particles_per_side,
                                                          particles_per_side,
                                                          particles_per_side)])
positions = jnp.array(positions) * side_length / particles_per_side
displacement_fn, shift = space.periodic(box)
neighbor_fn = partition.neighbor_list(displacement_fn, box,
                                      r_cut, dr_threshold=0.05)
neighbors = neighbor_fn.allocate(positions)
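To see what these parameters mean for this particular lattice (spacing 0.3), the NumPy sketch below counts how many other particles lie within a given radius of each particle under the minimum-image convention. As we understand JAX MD's `neighbor_list`, the list is built with an enlarged cutoff `r_cut + dr_threshold` (a Verlet skin), so it holds more candidate pairs than the DimeNet++ graph, which presumably keeps only pairs within `r_cut`. The helper `count_neighbors` is hypothetical and not part of jax_dimenet or JAX MD.

```python
import numpy as onp

side_length = 3.0
particles_per_side = 10

# Same simple cubic lattice as above (lattice spacing 0.3), built with NumPy only
positions = onp.array(list(onp.ndindex(particles_per_side,
                                       particles_per_side,
                                       particles_per_side)), dtype=float)
positions *= side_length / particles_per_side


def count_neighbors(positions, radius, box_size):
    """Hypothetical helper: number of other particles within `radius` of each
    particle in a periodic cubic box of side `box_size`."""
    dR = positions[:, None, :] - positions[None, :, :]
    dR -= box_size * onp.round(dR / box_size)  # minimum-image convention
    dist = onp.linalg.norm(dR, axis=-1)
    within = (dist < radius) & (dist > 0.0)  # exclude self-pairs
    return within.sum(axis=1)


graph_counts = count_neighbors(positions, 0.5, side_length)             # r_cut
candidate_counts = count_neighbors(positions, 0.5 + 0.05, side_length)  # r_cut + dr_threshold
print(graph_counts.min(), graph_counts.max())          # 18 18
print(candidate_counts.min(), candidate_counts.max())  # 26 26
```

Within `r_cut = 0.5` each particle sees its 6 face and 12 edge neighbors (distances 0.3 and about 0.42); the 8 body-diagonal neighbors at about 0.52 only enter through the skin.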
DimeNet++ energy function
Next, we wrap the DimeNet++ model as a JAX MD energy_fn, e.g. to run MD simulations.
init_fn, dimenet_energy_fn = dimenet.energy_neighborlist(displacement_fn, r_cut)
init_params = init_fn(random.PRNGKey(0), positions, neighbor=neighbors)
energy_fn = partial(dimenet_energy_fn, init_params) # jax_md energy_fn interface
print('Predicted energy:', energy_fn(positions, neighbors))
Predicted energy: 24.967882
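For MD, the integrator needs forces, which are the negative gradient of the scalar energy with respect to positions. The pure-JAX sketch below illustrates this with a hypothetical toy pair energy standing in for DimeNet++; `toy_energy_fn`, its soft harmonic form, `SIGMA`, and the explicit `(2, P)` pair-index array are assumptions for illustration, not part of jax_dimenet.

```python
import jax
import jax.numpy as jnp

BOX_SIZE = 3.0  # assumed periodic box side, matching side_length above
SIGMA = 0.3     # hypothetical interaction range of the toy potential


def toy_energy_fn(positions, pairs):
    """Hypothetical stand-in for a learned energy_fn: soft harmonic repulsion.

    positions: (N, 3) array; pairs: (2, P) int array of (i, j) index pairs.
    """
    i, j = pairs
    dR = positions[i] - positions[j]
    dR = dR - BOX_SIZE * jnp.round(dR / BOX_SIZE)  # minimum-image convention
    dr = jnp.linalg.norm(dR, axis=-1)
    pair_energy = jnp.where(dr < SIGMA, 0.5 * (1.0 - dr / SIGMA) ** 2, 0.0)
    return jnp.sum(pair_energy)


# Forces = -dE/dR, differentiating w.r.t. the first argument (positions) only
force_fn = jax.grad(lambda R, pairs: -toy_energy_fn(R, pairs))

positions = jnp.array([[0.0, 0.0, 0.0],
                       [0.2, 0.0, 0.0],
                       [1.0, 1.0, 1.0]])
pairs = jnp.array([[0, 0, 1],
                   [1, 2, 2]])

print('Energy:', toy_energy_fn(positions, pairs))
forces = force_fn(positions, pairs)
print('Forces:\n', forces)
```

Only particles 0 and 1 are closer than `SIGMA`, so they are pushed apart with equal and opposite forces while particle 2 feels none. Since `energy_fn` above is also a scalar function of positions, `-jax.grad(energy_fn)(positions, neighbors)` (or JAX MD's `quantity.force`) should, in principle, yield DimeNet++ forces the same way.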
DimeNet++ property prediction
Finally, we can also predict global molecular properties. Here, the molecular graph can be pre-computed once from the neighbor list and then passed directly to the property predictor.
mol_graph, _ = sparse_graph.sparse_graph_from_neighborlist(
    displacement_fn, positions, neighbors, r_cut)
init_fn, property_predictor = dimenet.property_prediction(r_cut, n_targets=5)
init_params = init_fn(random.PRNGKey(0), mol_graph)
print('Predicted properties:', property_predictor(init_params, mol_graph))
Predicted properties: [ -6.040734 -112.69328 185.22845 54.96091 -210.01091 ]
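As a rough idea of what pre-computing such a graph involves, the NumPy sketch below flattens a dense neighbor list into a sparse edge list and then enumerates the angle triplets (k, j, i) that DimeNet-style directional message passing operates on. It assumes the JAX MD dense format, where row `i` lists the neighbors of particle `i` and empty slots are padded with the out-of-range index `N`. Both helpers are hypothetical; the actual `sparse_graph_from_neighborlist` presumably also applies the `r_cut` mask and computes distances and angles.

```python
import numpy as onp


def dense_to_edge_list(idx):
    """Hypothetical helper: turn a dense (N, max_neighbors) neighbor list,
    padded with N, into sparse (center, neighbor) index arrays."""
    n_particles, max_neighbors = idx.shape
    i_idx = onp.repeat(onp.arange(n_particles), max_neighbors)  # central particle
    j_idx = idx.reshape(-1)                                      # its neighbor
    valid = j_idx < n_particles  # drop padding entries
    return i_idx[valid], j_idx[valid]


def angle_triplets(i_idx, j_idx):
    """Hypothetical helper: for every message j -> i, collect all incoming
    messages k -> j with k != i, returned as (k, j, i) triplets."""
    triplets = []
    for i, j in zip(i_idx, j_idx):
        for center, k in zip(i_idx, j_idx):
            if center == j and k != i:
                triplets.append((int(k), int(j), int(i)))
    return triplets


# Toy 3-particle chain 0 - 1 - 2 (N = 3, so 3 marks an empty slot)
idx = onp.array([[1, 3],
                 [0, 2],
                 [1, 3]])
i_idx, j_idx = dense_to_edge_list(idx)
print(i_idx, j_idx)                   # [0 1 1 2] [1 0 2 1]
print(angle_triplets(i_idx, j_idx))   # [(2, 1, 0), (0, 1, 2)]
```

The quadratic loop is only for readability; a production version would vectorize the triplet search.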