DimeNet++
Neural network models for potential energy and molecular property prediction.
DimeNetPP directly takes a SparseDirectionalGraph as input and predicts per-atom quantities. Because DimeNetPP is a Haiku Module, it must be wrapped inside hk.transform() before it can be applied.
We provide two interfaces to DimeNet++. The function energy_neighborlist() serves as an interface to JAX, M.D.: the resulting apply function can be used directly as a jax_md energy_fn, e.g. to run molecular dynamics simulations. For direct prediction of global molecular properties, property_prediction() can be used.
Haiku model
- class jax_dimenet.dimenet.DimeNetPP(r_cutoff, n_species, num_targets, embed_size=128, n_interaction_blocks=4, num_residual_before_skip=1, num_residual_after_skip=2, out_embed_size=None, type_embed_size=None, angle_int_embed_size=None, basis_int_embed_size=8, num_dense_out=3, num_rbf=6, num_sbf=7, activation=<CompiledFunction of <function silu>>, envelope_p=6, init_kwargs=None, name='DimeNetPP')[source]
DimeNet++ for molecular property prediction.
This model takes as input a sparse representation of a molecular graph - consisting of pairwise distances and angular triplets - and predicts per-atom properties. Global properties can be obtained by summing over per-atom predictions.
The default values correspond to the original values of DimeNet++.
This custom implementation follows the original DimeNet / DimeNet++ (https://arxiv.org/abs/2011.14115), while correcting for known issues (see https://github.com/klicperajo/dimenet).
- __init__(r_cutoff, n_species, num_targets, embed_size=128, n_interaction_blocks=4, num_residual_before_skip=1, num_residual_after_skip=2, out_embed_size=None, type_embed_size=None, angle_int_embed_size=None, basis_int_embed_size=8, num_dense_out=3, num_rbf=6, num_sbf=7, activation=<CompiledFunction of <function silu>>, envelope_p=6, init_kwargs=None, name='DimeNetPP')[source]
Initializes the DimeNet++ model.
The default values correspond to the original values of DimeNet++.
- Parameters
r_cutoff (float) – Radial cut-off distance of edges.
n_species (int) – Number of different atom species the network is supposed to process.
num_targets (int) – Number of different atomic properties to predict.
embed_size (int) – Size of message embeddings. The interaction and output embedding sizes are scaled accordingly unless specified explicitly.
n_interaction_blocks (int) – Number of interaction blocks.
num_residual_before_skip (int) – Number of residual blocks before the skip connection in the interaction block.
num_residual_after_skip (int) – Number of residual blocks after the skip connection in the interaction block.
out_embed_size (Optional[int]) – Embedding size of the output block. If None, it is set to 2 * embed_size.
type_embed_size (Optional[int]) – Embedding size of atom type embeddings. If None, it is set to 0.5 * embed_size.
angle_int_embed_size (Optional[int]) – Embedding size of Linear layers for the down-projected triplet interaction. If None, it is set to 0.5 * embed_size.
basis_int_embed_size (int) – Embedding size of Linear layers for the interaction of the RBF/SBF basis in the interaction block.
num_dense_out (int) – Number of final Linear layers in the output block.
num_rbf (int) – Number of radial Bessel embedding functions.
num_sbf (int) – Number of spherical Bessel embedding functions.
activation (Callable) – Activation function.
envelope_p (int) – Power of the envelope polynomial.
init_kwargs (Optional[Dict[str, Any]]) – Kwargs for initialization of Linear layers.
name (str) – Name of the DimeNet++ model.
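To make r_cutoff, num_rbf and envelope_p more concrete, the following NumPy sketch evaluates the radial Bessel basis and the polynomial envelope as they are defined in the original DimeNet paper. It is an illustration of those formulas only: the function names are hypothetical, and the implementation documented here may use a different but equivalent form (for example, folding the 1/d factor into the envelope).

```python
import numpy as np


def envelope(d_scaled, p=6):
    """Polynomial envelope u(d) for distances scaled to [0, 1].

    u(0) = 1 and u(1) = 0, so contributions vanish smoothly at the cut-off.
    """
    a = -(p + 1) * (p + 2) / 2
    b = p * (p + 2)
    c = -p * (p + 1) / 2
    return 1 + a * d_scaled**p + b * d_scaled**(p + 1) + c * d_scaled**(p + 2)


def radial_bessel_basis(d, r_cutoff, num_rbf=6, p=6):
    """Radial basis sqrt(2 / c) * sin(n * pi * d / c) / d, damped by the envelope."""
    n = np.arange(1, num_rbf + 1)                 # frequencies 1..num_rbf
    d = np.asarray(d, dtype=float)[..., None]     # add an axis for the basis index
    rbf = np.sqrt(2.0 / r_cutoff) * np.sin(n * np.pi * d / r_cutoff) / d
    return envelope(d / r_cutoff, p) * rbf


# Example distances in the same (arbitrary) length unit as r_cutoff.
distances = np.array([0.1, 0.25, 0.5])
basis = radial_bessel_basis(distances, r_cutoff=0.5)  # shape (3, num_rbf)
```

A pair exactly at the cut-off distance receives a zero embedding, which is what allows edges to enter and leave the graph without discontinuities in the prediction.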
- __call__(graph)[source]
Predicts per-atom quantities for a given molecular graph.
- Parameters
graph (SparseDirectionalGraph) – An instance of sparse_graph.SparseDirectionalGraph defining the molecular graph connectivity.
- Returns
An (n_particles, num_targets) array of predicted per-atom quantities.
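As noted in the class description, global properties follow from summing the per-atom output. A minimal NumPy sketch of that reduction is given below. The numbers are made up, and the atom_mask for padded atoms is an assumption for illustration, not a documented field of the model output.

```python
import numpy as np

# Made-up per-atom predictions of shape (n_particles, num_targets).
per_atom = np.array([[0.5], [1.0], [-0.25], [9.9]])

# Hypothetical mask: the last row is a padding atom and must not contribute.
atom_mask = np.array([True, True, True, False])

# Zero out padded atoms, then sum over the particle axis.
global_prediction = np.sum(per_atom * atom_mask[:, None], axis=0)  # shape (num_targets,)
```

Summation keeps the prediction invariant to the ordering of atoms, which is the reason it is a common readout for quantities such as the potential energy.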
Energy function
- jax_dimenet.dimenet.energy_neighborlist(displacement, r_cutoff, n_species=10, positions_test=None, neighbor_test=None, max_triplet_multiplier=1.25, max_edge_multiplier=1.25, **dimenetpp_kwargs)[source]
DimeNet++ energy function for JAX, M.D.
This function provides an interface for the DimeNet++ haiku model to be used as a jax_md energy_fn. Analogous to jax_md energy_fns, the initialized DimeNet++ energy_fn requires particle positions and a dense neighbor list as input - plus an array for species or other dynamic kwargs, if applicable.
The sparse graph representation, consisting of edges and angle triplets, is computed from the particle positions and the neighbor list. Because jit requires arrays of constant shape, the jax_md neighbor list contains many masked edges, i.e. pairwise interactions that only “fill” the neighbor list but are set to 0 during computation. These translate to masked edges and triplets in the sparse graph representation.
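The masking of fill entries can be illustrated with a small NumPy sketch. It assumes the dense neighbor list convention in which unused slots hold the index N (the number of particles), and it uses free-space displacements instead of a jax_md displacement function. The helper name masked_edges is hypothetical.

```python
import numpy as np


def masked_edges(positions, neighbor_idx, r_cutoff):
    """Return pairwise distances with fill entries set to 0, plus the edge mask."""
    n = positions.shape[0]
    # Append one dummy row so that the padding index n can be looked up safely.
    padded = np.concatenate([positions, np.zeros((1, positions.shape[1]))])
    disp = padded[neighbor_idx] - positions[:, None, :]   # (n, max_neighbors, dim)
    dist = np.linalg.norm(disp, axis=-1)
    # An edge is real if the slot is not padding and the pair is within the cut-off.
    mask = (neighbor_idx < n) & (dist < r_cutoff)
    return np.where(mask, dist, 0.0), mask


positions = np.array([[0.0, 0.0], [0.3, 0.0], [2.0, 0.0]])
# Fixed-shape neighbor list with 2 slots per particle; index 3 == n marks padding.
neighbor_idx = np.array([[1, 3], [0, 3], [3, 3]])

dist, mask = masked_edges(positions, neighbor_idx, r_cutoff=0.5)
```

Here only two of the six slots hold real edges. The remaining four still occupy memory and compute, which is the overhead that capping is meant to reduce.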
For improved computational efficiency during jax_md simulations, the maximum number of edges and triplets can be estimated during model initialization. Edges and triplets beyond this maximum estimate can be capped to reduce computational and memory requirements. Capping is enabled by providing sample inputs (positions_test and neighbor_test) at initialization time. However, an overflow of max_edges and max_angles is currently not caught, because catching it would require passing an error code through jax_md simulators, analogous to the overflow detection in jax_md neighbor lists. If in doubt, increase the max edges/angles multipliers or disable capping.
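The role of the multipliers and the risk of an uncaught overflow can be sketched as follows. This is plain NumPy under stated assumptions: the helpers estimate_cap and cap_edges are hypothetical, and the exact way the library derives its estimate from the sample inputs is not shown in this documentation.

```python
import math

import numpy as np


def estimate_cap(n_sample, multiplier=1.25):
    """Fixed capacity derived from a count observed on the sample inputs."""
    return int(math.ceil(n_sample * multiplier))


def cap_edges(edge_mask, max_edges):
    """Keep at most max_edges real edges and report whether any were dropped."""
    real = np.flatnonzero(edge_mask.ravel())   # flat indices of real edges
    overflow = real.size > max_edges           # True if the cap was too small
    return real[:max_edges], overflow


# Sample state with 4 real edges gives a capacity of ceil(4 * 1.25) = 5.
sample_mask = np.array([[True, True, False], [True, False, False], [True, False, False]])
max_edges = estimate_cap(int(sample_mask.sum()))

# A denser state encountered later has 6 real edges and exceeds the capacity.
later_mask = np.array([[True, True, False], [True, True, False], [True, True, False]])
kept, overflow = cap_edges(later_mask, max_edges)
```

In this sketch the overflow flag is available to the caller. In the setting described above no such flag is propagated, so a cap that is too small silently drops interactions, which is why a larger multiplier is the safer choice when the density of the system may change.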
- Parameters
displacement (Callable[[ndarray, ndarray], ndarray]) – jax_md displacement function.
r_cutoff (float) – Radial cut-off distance of DimeNetPP and the neighbor list.
n_species (int) – Number of different atom species the network is supposed to process.
positions_test (Optional[ndarray]) – Sample positions to estimate max_edges / max_angles. Needs to be provided to enable capping.
neighbor_test (Optional[NeighborList]) – Sample neighbor list to estimate max_edges / max_angles. Needs to be provided to enable capping.
max_edge_multiplier (float) – Multiplier for the initial estimate of maximum edges.
max_triplet_multiplier (float) – Multiplier for the initial estimate of maximum triplets.
dimenetpp_kwargs – Kwargs to change the default structure of DimeNet++. For the definition of the kwargs, see DimeNetPP.
- Returns
An init_fn that initializes the model parameters and an energy function that computes the energy for a particular state given the model parameters. The energy function requires the same input as other energy functions with neighbor lists in jax_md.energy.
- Return type
A tuple of 2 functions
Property prediction
- jax_dimenet.dimenet.property_prediction(r_cutoff, n_targets=1, n_species=100, **model_kwargs)[source]
Initializes a model that predicts global molecular properties.
- Parameters
r_cutoff (float) – Radial cut-off distance of DimeNetPP and the neighbor list.
n_targets (int) – Number of different molecular properties to predict.
n_species (int) – Number of different atom species the network is supposed to process.
**model_kwargs – Kwargs to change the default structure of DimeNet++.
- Returns
An init_fn that initializes the model parameters and an apply_function that predicts global molecular properties.
- Return type
A tuple of 2 functions