Jax-DimeNet++


Haiku implementation of the DimeNet++ architecture.

Getting started

This repository provides two interfaces to the DimeNet++ model. The first exposes DimeNet++ as a Jax M.D. energy_fn for running molecular dynamics simulations. The second predicts global molecular properties.

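from functools import partial

from jax import random
from jax_dimenet import dimenet, sparse_graph

# displacement_fn, positions, neighbors and r_cut are assumed to be defined beforehand.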

# Jax M.D. energy function:
init_fn, dimenet_energy_fn = dimenet.energy_neighborlist(displacement_fn, r_cut)
init_params = init_fn(random.PRNGKey(0), positions, neighbor=neighbors)
energy_fn = partial(dimenet_energy_fn, init_params)  # jax_md energy_fn interface
print('Predicted energy:', energy_fn(positions, neighbors))

# Molecular property prediction:
mol_graph, _ = sparse_graph.sparse_graph_from_neighborlist(
    displacement_fn, positions, neighbors, r_cut)
init_fn, property_predictor = dimenet.property_prediction(r_cut, n_targets=5)
init_params = init_fn(random.PRNGKey(0), mol_graph)
print('Predicted properties:', property_predictor(init_params, mol_graph))

A minimal usage example is available in the notebooks folder. For more realistic applications of the DimeNet++ model in MD simulations, please refer to the DiffTRe repository.
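To illustrate what the energy_fn interface provides in an MD setting, here is a minimal sketch in plain JAX. It does not use DimeNet++ or Jax M.D.: `toy_energy_fn` is a hypothetical stand-in pair potential with a cutoff, chosen only to show that an energy function of the positions yields forces as the negative gradient of the energy.

```python
import jax
import jax.numpy as jnp


def toy_energy_fn(positions, r_cut=1.5):
    """Hypothetical stand-in for a learned potential (not DimeNet++).

    Sums a smooth pair term over all particle pairs closer than r_cut.
    """
    n = positions.shape[0]
    disp = positions[:, None, :] - positions[None, :, :]
    dist_sq = jnp.sum(disp ** 2, axis=-1)
    off_diag = ~jnp.eye(n, dtype=bool)
    # Replace the zero self-distances by 1 so the sqrt gradient stays finite.
    dist = jnp.sqrt(jnp.where(off_diag, dist_sq, 1.0))
    within_cutoff = off_diag & (dist < r_cut)
    pair_energy = jnp.where(within_cutoff, (dist - r_cut) ** 2, 0.0)
    # Each pair appears twice in the matrix, hence the factor 0.5.
    return 0.5 * jnp.sum(pair_energy)


positions = jnp.array([[0.0, 0.0, 0.0],
                       [1.0, 0.0, 0.0],
                       [0.0, 1.2, 0.0]])

energy = toy_energy_fn(positions)
# Forces are the negative gradient of the energy w.r.t. the positions.
forces = -jax.grad(toy_energy_fn)(positions)
print('Energy:', energy)
print('Forces:', forces)
```

An energy function built from the DimeNet++ interface above should be usable in the same way, with the neighbor list passed as an additional argument.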

Installation

You can install Jax-DimeNet via pip:

pip install jax-dimenet

Requirements

The repository uses the following packages:

    jax>=0.2.12
    jaxlib>=0.1.65
    jax-md>=0.1.13
    dm-haiku>=0.0.4
    sympy
    chex

The code was run with Python 3.8.
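The minimum versions listed above can be checked against the current environment with a short script. This is a sketch using only the Python standard library; the helper names `parse_version` and `check_requirements` are illustrative and not part of this package, and the version parsing is deliberately simple (it ignores suffixes such as `rc1` or `dev`).

```python
from importlib import metadata

# Minimum versions as listed in the requirements above.
MINIMUM_VERSIONS = {
    "jax": "0.2.12",
    "jaxlib": "0.1.65",
    "jax-md": "0.1.13",
    "dm-haiku": "0.0.4",
}


def parse_version(version):
    """Turn '0.2.12' into (0, 2, 12); non-numeric suffixes are ignored."""
    parts = []
    for chunk in version.split("."):
        digits = ""
        for char in chunk:
            if not char.isdigit():
                break
            digits += char
        if not digits:
            break
        parts.append(int(digits))
    return tuple(parts)


def check_requirements(minimums):
    """Return a dict mapping each package name to 'ok', 'missing' or 'too old'."""
    report = {}
    for name, minimum in minimums.items():
        try:
            installed = metadata.version(name)
        except metadata.PackageNotFoundError:
            report[name] = "missing"
            continue
        if parse_version(installed) >= parse_version(minimum):
            report[name] = "ok"
        else:
            report[name] = f"too old ({installed} < {minimum})"
    return report


for name, status in check_requirements(MINIMUM_VERSIONS).items():
    print(f"{name}: {status}")
```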

Contribution

Contributions are always welcome! Please open a pull request to discuss the code additions.

Contact

For questions, please contact stephan.thaler@tum.de.

Citation

If you use this code in your own work, please consider citing the paper that introduced this DimeNet++ implementation:

@article{thaler_difftre_2021,
  title = {Learning neural network potentials from experimental data via Differentiable Trajectory Reweighting},
  author = {Thaler, Stephan and Zavadlav, Julija},
  journal={Nature Communications},
  volume={12},
  pages={6884},
  doi={10.1038/s41467-021-27241-4},
  year = {2021}
}

as well as the original DimeNet++ paper:

@inproceedings{klicpera_dimenetpp_2020,
  title = {Fast and Uncertainty-Aware Directional Message Passing for Non-Equilibrium Molecules},
  author = {Klicpera, Johannes and Giri, Shankari and Margraf, Johannes T. and G{\"u}nnemann, Stephan},
  booktitle={NeurIPS-W},
  year = {2020}
}
