Sparse Graph

Functions to extract the sparse directional graph representation of a molecular state.

The SparseDirectionalGraph is the input to DimeNetPP.

Graph dataclass

class jax_dimenet.sparse_graph.SparseDirectionalGraph(species, distance_ij, idx_i, idx_j, angles, reduce_to_ji, expand_to_kj, species_mask=None, edge_mask=None, triplet_mask=None, n_edges=None, n_triplets=None)[source]

Sparse directional graph representation of a molecular state.

Required arguments are necessary inputs for DimeNet++. If masks are not provided, all entities are assumed to be present.

distance_ij

A (N_edges,) array storing for each edge the radial distance between particles i and j

Type

jax._src.numpy.ndarray.ndarray

idx_i

A (N_edges,) array storing for each edge particle index i

Type

jax._src.numpy.ndarray.ndarray

idx_j

A (N_edges,) array storing for each edge particle index j

Type

jax._src.numpy.ndarray.ndarray

angles

A (N_triplets,) array storing for each triplet the angle formed by the 3 particles

Type

jax._src.numpy.ndarray.ndarray

reduce_to_ji

A (N_triplets,) array storing for each triplet kji the index of edge j->i. It is used to aggregate messages via a segment_sum: each m_ji is a distinct segment containing all incoming m_kj.

Type

jax._src.numpy.ndarray.ndarray

expand_to_kj

A (N_triplets,) array storing for each triplet kji the index of edge k->j. It is used to gather all incoming edges for message passing.

Type

jax._src.numpy.ndarray.ndarray
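The interplay of expand_to_kj and reduce_to_ji can be sketched on a toy star graph. The edge ordering, the scalar messages and the exclusion of triplets with k = i are assumptions made for this illustration; jax.ops.segment_sum is the only library call involved.

```python
import jax
import jax.numpy as jnp

# Star graph: central particle 0 bonded to particles 1, 2 and 3.
# Directed edges j -> i (hypothetical ordering):
#   e0: 1->0, e1: 2->0, e2: 3->0, e3: 0->1, e4: 0->2, e5: 0->3
idx_i = jnp.array([0, 0, 0, 1, 2, 3])
idx_j = jnp.array([1, 2, 3, 0, 0, 0])

# Triplets k -> j -> i with k != i (assumed convention).
# Each entry pairs an incoming edge k->j with the outgoing edge j->i.
expand_to_kj = jnp.array([1, 2, 0, 2, 0, 1])
reduce_to_ji = jnp.array([3, 3, 4, 4, 5, 5])

# One scalar message per edge (illustrative values).
m = jnp.array([1.0, 2.0, 3.0, 10.0, 20.0, 30.0])

# Gather the incoming message m_kj of every triplet ...
m_kj = m[expand_to_kj]

# ... and sum all m_kj that belong to the same target edge j->i.
aggregated = jax.ops.segment_sum(m_kj, reduce_to_ji, num_segments=m.shape[0])
# aggregated -> [0., 0., 0., 5., 4., 3.]
```

Edges e0 to e2 receive nothing because their only candidate incoming edge would require k = i.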

edge_mask

A (N_edges,) boolean array storing for each edge whether the edge exists. By default, all edges are considered.

Type

Optional[jax._src.numpy.ndarray.ndarray]

triplet_mask

A (N_triplets,) boolean array storing for each triplet whether the triplet exists. By default, all triplets are considered.

Type

Optional[jax._src.numpy.ndarray.ndarray]

n_edges

Number of non-masked edges in the graph. None assumes all edges are real.

Type

Optional[int]

n_triplets

Number of non-masked triplets in the graph. None assumes all triplets are real.

Type

Optional[int]
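How a mask and the corresponding count relate can be illustrated with a hand-written padded edge array. The values and the per-edge function below are hypothetical; the point is only that padded slots must be masked before any reduction.

```python
import jax.numpy as jnp

# Hypothetical padded graph: 3 real edges followed by 2 padded slots.
distance_ij = jnp.array([1.0, 1.2, 0.9, 0.0, 0.0])
edge_mask = jnp.array([True, True, True, False, False])

# The number of non-masked edges follows directly from the mask.
n_edges = int(jnp.sum(edge_mask))

# Illustrative per-edge quantity; padded slots would each contribute exp(0) = 1.
per_edge_value = jnp.exp(-distance_ij)

# Mask before reducing so that padding does not leak into the result.
masked_total = jnp.sum(jnp.where(edge_mask, per_edge_value, 0.0))
unmasked_total = jnp.sum(per_edge_value)
```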

n_particles

Number of non-masked species in the graph.

cap_exactly()[source]

Deletes all non-existing edges and triplets from the stored graph.

This is a non-pure function and hence not available in a jit-context. Returning the capped graph does not solve the problem when n_edges and n_triplets are computed within the jit-compiled function.
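The jit restriction can be seen in a minimal sketch. The helper cap_edges below is hypothetical and not the library's implementation; it assumes that padded edges are stored after all real edges. The output shape depends on the value of the mask, which jit cannot resolve at trace time.

```python
import jax
import jax.numpy as jnp

def cap_edges(distance_ij, edge_mask):
    """Drops padded edges, assuming they are stored after all real edges."""
    n_edges = int(jnp.sum(edge_mask))  # requires a concrete value
    return distance_ij[:n_edges]

distance_ij = jnp.array([1.0, 1.2, 0.9, 0.0, 0.0])
edge_mask = jnp.array([True, True, True, False, False])

# Outside of jit the count is concrete and the slice has a known shape.
capped = cap_edges(distance_ij, edge_mask)

# Inside jit the count is a traced value, so the conversion to int fails.
try:
    jax.jit(cap_edges)(distance_ij, edge_mask)
    jit_failed = False
except Exception:  # JAX raises a concretization error here
    jit_failed = True
```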

classmethod from_dict(graph_dict)[source]

Initializes instance from dictionary containing all necessary keys for initialization.

to_dict()[source]

Returns the stored graph data as a dictionary of arrays. This format is often beneficial for dataloaders.

Graph building

Functions to extract SparseDirectionalGraph from molecular positions in a box.

jax_dimenet.sparse_graph.sparse_graph_from_neighborlist(displacement_fn, positions, neighbor, r_cutoff, species=None, max_edges=None, max_triplets=None, species_mask=None)[source]

Constructs a sparse representation of graph edges and angles to save memory and computations compared to the dense neighbor list.

The speed-up over simply using the dense jax_md neighbor list is significant, particularly regarding triplets. To allow for a representation of constant size required by jit, we pad the resulting vectors.

Parameters
  • displacement_fn (Callable) – Jax_MD displacement function encoding box dimensions

  • positions (ndarray) – (N_particles, dim) array of particle positions

  • neighbor (NeighborList) – Jax_MD neighbor list that is in sync with positions

  • r_cutoff (array) – Radial cutoff distance, below which 2 particles are considered to be connected by an edge.

  • species (Optional[array]) – (N_particles,) array encoding atom types. If None, assumes type 0 for all atoms.

  • max_edges (Optional[int]) – Maximum number of edges storable in the graph. Can be used to reduce the number of padded edges, but should be used carefully, such that no existing edges are capped. Default None uses the maximum possible number of edges as given by the dense neighbor list.

  • max_triplets (Optional[int]) – Maximum number of triplets storable in the graph. Can be used to reduce the number of padded triplets, but should be used carefully, such that no existing triplets are capped. Default None uses the maximum possible number of triplets as given by the dense neighbor list.

  • species_mask (Optional[array]) – (N_particles,) mask array marking which particles exist. Default None assumes no masking is necessary.

Return type

Tuple[SparseDirectionalGraph, bool]

Returns

Tuple (sparse_graph, too_many_edges_error_code) containing the SparseDirectionalGraph and whether max_edges or max_triplets overflowed.
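The idea of packing a dense neighbor list into padded, fixed-size edge arrays can be sketched without jax_md. The function dense_to_sparse_edges below is hypothetical and covers edges only. It assumes a non-periodic box and a dense index array in which entries equal to N_particles mark empty slots; it is not the library's implementation.

```python
import jax.numpy as jnp

def dense_to_sparse_edges(positions, neighbor_idx, r_cutoff, max_edges):
    """Packs edges of a dense (N_particles, max_neighbors) index array."""
    n_particles = neighbor_idx.shape[0]
    valid = neighbor_idx < n_particles            # assumed padding convention
    safe_idx = jnp.where(valid, neighbor_idx, 0)  # keep indexing in bounds

    # Plain displacement; a periodic box would need a displacement function.
    r_ij = positions[:, None, :] - positions[safe_idx]
    dist = jnp.linalg.norm(r_ij, axis=-1)
    is_edge = valid & (dist < r_cutoff)
    n_edges = jnp.sum(is_edge)

    # Fixed-size extraction: surplus slots are filled with index 0.
    rows, cols = jnp.where(is_edge, size=max_edges, fill_value=0)
    edge_mask = jnp.arange(max_edges) < n_edges

    idx_i = rows
    idx_j = safe_idx[rows, cols]
    distance_ij = jnp.where(edge_mask, dist[rows, cols], 0.0)
    overflow = n_edges > max_edges                # real edges were dropped
    return idx_i, idx_j, distance_ij, edge_mask, overflow

# Three particles on a line; only particles 0 and 1 are within the cutoff.
positions = jnp.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [3.0, 0.0, 0.0]])
neighbor_idx = jnp.array([[1, 2], [0, 2], [0, 1]])

idx_i, idx_j, distance_ij, edge_mask, overflow = dense_to_sparse_edges(
    positions, neighbor_idx, r_cutoff=1.5, max_edges=4)

# A buffer that is too small is reported instead of silently dropping edges.
*_, overflow_small = dense_to_sparse_edges(
    positions, neighbor_idx, r_cutoff=1.5, max_edges=1)
```

Because max_edges is a static Python integer, all output shapes are known in advance, which is what makes such a representation compatible with jit.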

jax_dimenet.sparse_graph.angle_triplets(positions, displacement_fn, angle_idxs, angle_mask)[source]

Computes the angle for all triplets between 0 and pi.

Masked angles are set to pi/2.

Parameters
  • positions – Array of particle positions (N_particles x 3)

  • displacement_fn – Jax_md displacement function

  • angle_idxs – Array of particle indices that form a triplet (N_triplets x 3)

  • angle_mask – Boolean mask for each triplet, which is False for triplets that need to be masked.

Returns

A (N_triplets,) array with the angle for each triplet.

jax_dimenet.sparse_graph.safe_angle_mask(r_ji, r_kj, angle_mask)[source]

Sets masked angles to pi/2 to ensure differentiability.

Parameters
  • r_ji – Array (N_triplets, dim) of vectors pointing to position of particle i from particle j

  • r_kj – Array (N_triplets, dim) of vectors pointing to position of particle k from particle j

  • angle_mask – (N_triplets, ) or (N_triplets, 1) Boolean mask for each triplet, which is False for triplets that need to be masked.

Returns

A tuple (r_ji_safe, r_kj_safe) of vectors r_ji and r_kj, where masked triplets are replaced such that the angle between them is pi/2.
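Why masked vectors are replaced before the angle is computed can be shown with a small sketch. The helpers safe_vectors and batched_angle are hypothetical stand-ins that assume 3-dimensional vectors and zero-filled padded triplets; they are not the library's implementation.

```python
import jax
import jax.numpy as jnp

def safe_vectors(r_ji, r_kj, angle_mask):
    """Replaces masked rows by two orthogonal unit vectors (angle pi/2)."""
    mask = jnp.reshape(angle_mask, (-1, 1))
    e_x = jnp.array([1.0, 0.0, 0.0])
    e_y = jnp.array([0.0, 1.0, 0.0])
    return jnp.where(mask, r_ji, e_x), jnp.where(mask, r_kj, e_y)

def batched_angle(r_ji, r_kj):
    """Angle per triplet via arctan2(|cross|, dot)."""
    cross = jnp.linalg.norm(jnp.cross(r_ji, r_kj), axis=-1)
    dot = jnp.sum(r_ji * r_kj, axis=-1)
    return jnp.arctan2(cross, dot)

# One real triplet and one padded triplet filled with zeros.
r_ji = jnp.array([[1.0, 0.0, 0.0], [0.0, 0.0, 0.0]])
r_kj = jnp.array([[1.0, 1.0, 0.0], [0.0, 0.0, 0.0]])
angle_mask = jnp.array([True, False])

angles = batched_angle(*safe_vectors(r_ji, r_kj, angle_mask))

# Without the replacement, the padded row yields NaN gradients.
grad_unsafe = jax.grad(lambda r: jnp.sum(batched_angle(r, r_kj)))(r_ji)

# With the replacement applied first, all gradients stay finite.
grad_safe = jax.grad(
    lambda r: jnp.sum(batched_angle(*safe_vectors(r, r_kj, angle_mask)))
)(r_ji)
```

The replacement has to happen on the inputs. Masking the angle afterwards would leave the NaN in the backward pass.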

jax_dimenet.sparse_graph.angle(r_ij, r_kj)[source]

Computes the angle (kj, ij) from vectors r_kj and r_ij, correctly selecting the quadrant.

Based on \(\tan(\theta) = |(r_{ij} \times r_{kj})| / (r_{ij} \cdot r_{kj})\). Beware the non-differentiability of arctan2(0,0).

Parameters
  • r_ij – Vector pointing to position of particle i from particle j

  • r_kj – Vector pointing to position of particle k from particle j

Returns

Angle between vectors
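A minimal sketch of the arctan2-based formula, written to be consistent with the description above rather than copied from the library, shows both the quadrant selection and the gradient problem at zero-length vectors. Vectors are assumed to be 3-dimensional.

```python
import jax
import jax.numpy as jnp

def angle(r_ij, r_kj):
    """Angle between two 3D vectors via arctan2(|cross|, dot)."""
    cross = jnp.linalg.norm(jnp.cross(r_ij, r_kj))
    dot = jnp.dot(r_ij, r_kj)
    return jnp.arctan2(cross, dot)

e_x = jnp.array([1.0, 0.0, 0.0])
e_y = jnp.array([0.0, 1.0, 0.0])

right = angle(e_x, e_y)                            # pi / 2
obtuse = angle(e_x, jnp.array([-1.0, 1.0, 0.0]))   # 3 pi / 4, negative dot product

# Both arguments of arctan2 vanish for zero vectors: the gradient is NaN.
grad_at_zero = jax.grad(angle)(jnp.zeros(3), jnp.zeros(3))
```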

Dataset preprocessing

For direct molecular property prediction tasks (i.e. predicting potential energy, band-gap, …), one can pre-compute the SparseDirectionalGraph for the whole dataset.

jax_dimenet.sparse_graph.convert_dataset_to_graphs(r_cutoff, position_data, box, species, padding=True)[source]

Converts input consisting of particle positions and boxes to a dataset of sparse graph representations.

Due to the high memory cost of saving padded graphs, this preprocessing step is only recommended for small datasets in which the number of particles per box varies only slightly.

This function tackles the general case, where the number of particles and the boxes vary across snapshots. This introduces some overhead if the particle number and the box are fixed. Due to this general setting, this function is not jittable.

Parameters
  • r_cutoff – Radial cut-off distance below which 2 particles form an edge

  • position_data – Either a list of (N_particles, dim) arrays of particle positions in case N_particles is not constant across snapshots or a (N_snapshots, N_particles, dim) array. The positions need to be given in real (non-fractional) coordinates.

  • box – Either a single 1 or 2-dimensional box (if the box is constant across snapshots) or an (N_snapshots, dim) or (N_snapshots, dim, dim) array of boxes.

  • species – Either a list of (N_particles,) arrays of atom types in case N_particles is not constant across snapshots or a single (N_particles,) array.

  • padding – If True, pads the resulting edges and triplets to the maximum across the input data to allow for straightforward batching without re-compilation. If False, returns edges and triplets whose shapes vary between snapshots; these may still contain non-existing edges / triplets that need to be masked.

Returns

With padding, a SparseDirectionalGraph pytree containing all graphs of the dataset, stacked along axis 0. Without padding, a dictionary containing the full definition of the sparse molecular graph, given as lists. Refer to SparseDirectionalGraph for the respective definitions.
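What "padded and stacked along axis 0" means can be sketched with plain dictionaries standing in for SparseDirectionalGraph. The helper pad_to, the two toy graphs and the zero fill value are assumptions for illustration; only the edge arrays are shown.

```python
import jax
import jax.numpy as jnp

def pad_to(arr, size):
    """Pads a 1D array with zeros at the end up to the given size."""
    return jnp.pad(arr, (0, size - arr.shape[0]))

# Two hypothetical snapshots with 2 and 4 edges, respectively.
graphs = [
    {"distance_ij": jnp.array([1.0, 1.0]),
     "idx_i": jnp.array([0, 1]),
     "idx_j": jnp.array([1, 0])},
    {"distance_ij": jnp.array([1.0, 1.0, 1.5, 1.5]),
     "idx_i": jnp.array([0, 1, 1, 2]),
     "idx_j": jnp.array([1, 0, 2, 1])},
]

# Pad every snapshot to the largest edge count and record which edges are real.
max_edges = max(g["idx_i"].shape[0] for g in graphs)
padded = []
for g in graphs:
    n_edges = g["idx_i"].shape[0]
    p = {key: pad_to(value, max_edges) for key, value in g.items()}
    p["edge_mask"] = jnp.arange(max_edges) < n_edges
    padded.append(p)

# Stack matching leaves along a new leading snapshot axis.
dataset = jax.tree_util.tree_map(lambda *leaves: jnp.stack(leaves), *padded)
```

All leaves of dataset now share the leading dimension N_snapshots, so a batch can be selected by indexing axis 0 without triggering re-compilation.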

jax_dimenet.sparse_graph.pad_per_atom_quantities(per_atom_data)[source]

Pads a list of arrays containing per-atom quantities (e.g. species).

Allows for straightforward batching without re-compilations in case of non-constant number of particles across snapshots.

Parameters

per_atom_data – List of (N_particles,) arrays containing a scalar quantity of each particle.

Returns

A (N_snapshots, N_particles) array and corresponding mask array.
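The padding scheme can be sketched as follows. The function pad_per_atom is a hypothetical stand-in, not the library's implementation, and the fill value 0 is an assumption. Since 0 may also be a valid species, the mask is needed to tell real particles from padding.

```python
import jax.numpy as jnp

def pad_per_atom(per_atom_data):
    """Pads a list of (N_particles,) arrays to a common length."""
    lengths = [a.shape[0] for a in per_atom_data]
    max_n = max(lengths)
    padded = jnp.stack(
        [jnp.pad(a, (0, max_n - a.shape[0])) for a in per_atom_data])
    # True for real particles, False for padded slots.
    mask = jnp.arange(max_n)[None, :] < jnp.array(lengths)[:, None]
    return padded, mask

# Two snapshots with 3 and 5 particles (illustrative species values).
species = [jnp.array([1, 8, 1]), jnp.array([6, 1, 1, 1, 1])]
padded_species, species_mask = pad_per_atom(species)
```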