# Copyright 2022 Multiscale Modeling of Fluid Materials, TU Munich
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Functions to extract the sparse directional graph representation of a
molecular state.
The :class:`SparseDirectionalGraph` is the input to
:class:`~jax_dimenet.neural_networks.DimeNetPP`.
"""
import inspect
from typing import Optional, Callable, Tuple
import chex
import numpy as onp
from jax import numpy as jnp, vmap, lax
from jax_md import space, partition, smap
from jax_dimenet import custom_space
@chex.dataclass
class SparseDirectionalGraph:
    """Sparse directional graph representation of a molecular state.

    Required arguments are necessary inputs for DimeNet++.
    If masks are not provided, all entities are assumed to be present.

    Attributes:
        species: A (N_particles,) array storing the atom type of each particle
        distance_ij: A (N_edges,) array storing for each edge the radial
            distance between particle i and j
        idx_i: A (N_edges,) array storing for each edge particle index i
        idx_j: A (N_edges,) array storing for each edge particle index j
        angles: A (N_triplets,) array storing for each triplet the angle formed
            by the 3 particles
        reduce_to_ji: A (N_triplets,) array storing for each triplet kji edge
            index j->i to aggregate messages via a segment_sum: each
            m_ji is a distinct segment containing all incoming m_kj.
        expand_to_kj: A (N_triplets,) array storing for each triplet kji edge
            index k->j to gather all incoming edges for message
            passing.
        species_mask: A (N_particles,) boolean array storing for each particle
            whether it exists. By default, all particles are considered.
        edge_mask: A (N_edges,) boolean array storing for each edge whether the
            edge exists. By default, all edges are considered.
        triplet_mask: A (N_triplets,) boolean array storing for each triplet
            whether the triplet exists. By default, all triplets are
            considered.
        n_edges: Number of non-masked edges in the graph. None assumes all
            edges are real.
        n_triplets: Number of non-masked triplets in the graph. None assumes
            all triplets are real.
        n_particles: Number of non-masked species in the graph.
    """
    species: jnp.ndarray
    distance_ij: jnp.ndarray
    idx_i: jnp.ndarray
    idx_j: jnp.ndarray
    angles: jnp.ndarray
    reduce_to_ji: jnp.ndarray
    expand_to_kj: jnp.ndarray
    species_mask: Optional[jnp.ndarray] = None
    edge_mask: Optional[jnp.ndarray] = None
    triplet_mask: Optional[jnp.ndarray] = None
    n_edges: Optional[int] = None
    n_triplets: Optional[int] = None

    def __post_init__(self):
        # default masks: without explicit masks, every particle, edge and
        # triplet is treated as real
        if self.species_mask is None:
            self.species_mask = jnp.ones_like(self.species, dtype=bool)
        if self.edge_mask is None:
            self.edge_mask = jnp.ones_like(self.distance_ij, dtype=bool)
        if self.triplet_mask is None:
            self.triplet_mask = jnp.ones_like(self.angles, dtype=bool)

    @property
    def n_particles(self):
        """Number of non-masked particles in the graph."""
        return jnp.sum(self.species_mask)

    def to_dict(self):
        """Returns the stored graph data as a dictionary of arrays.

        This format is often beneficial for dataloaders. Note that n_edges
        and n_triplets are not part of the returned dictionary.
        """
        return {
            'species': self.species,
            'distance_ij': self.distance_ij,
            'idx_i': self.idx_i,
            'idx_j': self.idx_j,
            'angles': self.angles,
            'reduce_to_ji': self.reduce_to_ji,
            'expand_to_kj': self.expand_to_kj,
            'species_mask': self.species_mask,
            'edge_mask': self.edge_mask,
            'triplet_mask': self.triplet_mask
        }

    @classmethod
    def from_dict(cls, graph_dict):
        """Initializes instance from dictionary containing all necessary keys
        for initialization.

        Surplus keys in graph_dict are silently ignored.
        """
        return cls(**{
            key: value for key, value in graph_dict.items()
            if key in inspect.signature(cls).parameters
        })

    def cap_exactly(self):
        """Deletes all non-existing edges and triplets from the stored graph.

        This is a non-pure function and hence not available in a jit-context.
        Returning the capped graph does not solve the problem when n_edges
        and n_triplets are computed within the jit-compiled function.
        """
        # edges are sorted, hence all non-existing edges are at the end
        self.species = self.species[:self.n_particles]
        self.species_mask = self.species_mask[:self.n_particles]
        self.distance_ij = self.distance_ij[:self.n_edges]
        self.idx_i = self.idx_i[:self.n_edges]
        self.idx_j = self.idx_j[:self.n_edges]
        self.edge_mask = self.edge_mask[:self.n_edges]
        self.angles = self.angles[:self.n_triplets]
        self.reduce_to_ji = self.reduce_to_ji[:self.n_triplets]
        self.expand_to_kj = self.expand_to_kj[:self.n_triplets]
        self.triplet_mask = self.triplet_mask[:self.n_triplets]
def angle(r_ij, r_kj):
    """Computes the angle (ij, kj) between vectors r_ij and r_kj,
    correctly selecting the quadrant.

    Based on
    :math:`\\tan(\\theta) = |(r_{ij} \\times r_{kj})| / (r_{ij} \\cdot r_{kj})`.
    Beware the non-differentiability of arctan2(0, 0).

    Args:
        r_ij: Vector pointing to position of particle i from particle j
        r_kj: Vector pointing to position of particle k from particle j

    Returns:
        Angle between vectors, in [0, pi].
    """
    # arctan2(|cross|, dot) selects the correct quadrant: the norm of the
    # cross product is always non-negative, so theta lies in [0, pi]
    cross = jnp.linalg.norm(jnp.cross(r_ij, r_kj))
    dot = jnp.dot(r_ij, r_kj)
    theta = jnp.arctan2(cross, dot)
    return theta
def safe_angle_mask(r_ji, r_kj, angle_mask):
    """Sets masked angles to pi/2 to ensure differentiability.

    Args:
        r_ji: Array (N_triplets, dim) of vectors pointing to position of
            particle i from particle j
        r_kj: Array (N_triplets, dim) of vectors pointing to position of
            particle k from particle j
        angle_mask: (N_triplets, ) or (N_triplets, 1) Boolean mask for each
            triplet, which is False for triplets that need to be masked.

    Returns:
        A tuple (r_ji_safe, r_kj_safe) of vectors r_ji and r_kj, where masked
        triplets are replaced such that the angle between them is pi/2.
    """
    if angle_mask.ndim == 1:  # expand for broadcasting, if necessary
        angle_mask = jnp.expand_dims(angle_mask, -1)
    # replace masked triplets by two orthogonal unit vectors (first and
    # second basis vectors), such that the angle is exactly pi/2 and the
    # backward pass of the angle computation stays well-defined.
    # Works for any dim >= 2 and preserves the input dtype.
    dim = r_ji.shape[-1]
    safe_ji = jnp.zeros(dim, dtype=r_ji.dtype).at[0].set(1.)
    safe_kj = jnp.zeros(dim, dtype=r_kj.dtype).at[1].set(1.)
    r_ji_safe = jnp.where(angle_mask, r_ji, safe_ji)
    r_kj_safe = jnp.where(angle_mask, r_kj, safe_kj)
    return r_ji_safe, r_kj_safe
def angle_triplets(positions, displacement_fn, angle_idxs, angle_mask):
    """Computes the angle for all triplets between 0 and pi.

    Masked angles are set to pi/2.

    Args:
        positions: Array of particle positions (N_particles x 3)
        displacement_fn: Jax_md displacement function
        angle_idxs: Array of particle indices that form a triplet
            (N_triples x 3)
        angle_mask: Boolean mask for each triplet, which is False for triplets
            that need to be masked.

    Returns:
        A (N_triples,) array with the angle for each triplet.
    """
    # central particle of each triplet is j (second column)
    r_i = positions[angle_idxs[:, 0]]
    r_j = positions[angle_idxs[:, 1]]
    r_k = positions[angle_idxs[:, 2]]
    # Note: The original DimeNet implementation uses R_ji, however r_ij is the
    #       correct vector to get the angle between both vectors. This is a
    #       known issue in DimeNet. We apply the correct angle definition.
    r_ij = vmap(displacement_fn)(r_i, r_j)  # r_i - r_j respecting periodic BCs
    r_kj = vmap(displacement_fn)(r_k, r_j)
    # we need to mask the case where r_ij is co-linear with r_kj.
    # Otherwise, this generates NaNs on the backward pass
    r_ij_safe, r_kj_safe = safe_angle_mask(r_ij, r_kj, angle_mask)
    angles = vmap(angle)(r_ij_safe, r_kj_safe)
    return angles
def _flatten_sort_and_capp(matrix, sorting_args, cap_size):
"""Helper function that takes a 2D array, flattens it, sorts it using the
args (usually provided via argsort) and capps the end of the resulting
vector. Used to delete non-existing edges and returns the capped vector.
"""
vect = jnp.ravel(matrix)
sorted_vect = vect[sorting_args]
capped_vect = sorted_vect[0:cap_size]
return capped_vect
def sparse_graph_from_neighborlist(displacement_fn: Callable,
                                   positions: jnp.ndarray,
                                   neighbor: partition.NeighborList,
                                   r_cutoff: jnp.ndarray,
                                   species: Optional[jnp.ndarray] = None,
                                   max_edges: Optional[int] = None,
                                   max_triplets: Optional[int] = None,
                                   species_mask: Optional[jnp.ndarray] = None,
                                   ) -> Tuple[SparseDirectionalGraph, bool]:
    """Constructs a sparse representation of graph edges and angles to save
    memory and computations over neighbor list.

    The speed-up over simply using the dense jax_md neighbor list is
    significant, particularly regarding triplets. To allow for a representation
    of constant size required by jit, we pad the resulting vectors.

    Args:
        displacement_fn: Jax_MD displacement function encoding box dimensions
        positions: (N_particles, dim) array of particle positions
        neighbor: Jax_MD neighbor list that is in sync with positions
        r_cutoff: Radial cutoff distance, below which 2 particles are considered
            to be connected by an edge.
        species: (N_particles,) array encoding atom types. If None, assumes type
            0 for all atoms.
        max_edges: Maximum number of edges storable in the graph. Can be used to
            reduce the number of padded edges, but should be used
            carefully, such that no existing edges are capped. Default
            None uses the maximum possible number of edges as given by
            the dense neighbor list.
        max_triplets: Maximum number of triplets storable in the graph. Can be
            used to reduce the number of padded triplets, but should be
            used carefully, such that no existing triplets are capped.
            Default None uses the maximum possible number of triplets as
            given by the dense neighbor list.
        species_mask: (N_particles,) boolean array masking non-existing
            particles. Default None assumes no masking necessary.

    Returns:
        Tuple (sparse_graph, too_many_edges_error_code) containing the
        SparseDirectionalGraph and whether max_edges or max_triplets overflowed.
    """
    # TODO might be worth updating this function to the new sparse-style
    #  neighborlist in jax_md
    assert neighbor.format.name == 'Dense', ('Currently only dense neighbor'
                                             ' lists supported.')
    n_particles, max_neighbors = neighbor.idx.shape
    species = _canonicalize_species(species, n_particles)
    neighbor_displacement_fn = space.map_neighbor(displacement_fn)

    # compute pairwise distances between each particle and its dense
    # neighbor-list candidates
    pos_neigh = positions[neighbor.idx]
    pair_displacement = neighbor_displacement_fn(positions, pos_neigh)
    pair_distances = space.distance(pair_displacement)

    # compute adjacency matrix via neighbor_list, then build sparse graph
    # representation to avoid part of padding overhead in dense neighborlist
    # adds all edges > cut-off to masked edges
    edge_idx_ji = jnp.where(pair_distances < r_cutoff, neighbor.idx,
                            n_particles)
    # neighbor.idx: an index j in row i encodes a directed edge from
    # particle j to particle i.
    # edge_idx[i, j]: j->i. if j == N: encodes masked edge.
    # Index N would index out-of-bounds, but in jax the last element is
    # returned instead

    # conservative estimates for initialization run;
    # use guess from initialization for tighter bound to save memory and
    # computations during production runs
    if max_edges is None:
        max_edges = n_particles * max_neighbors
    if max_triplets is None:
        max_triplets = max_edges * max_neighbors

    # sparse edge representation:
    # construct vectors from adjacency matrix and only keep existing edges.
    # Target node (i) and source (j) of edges
    pair_mask = edge_idx_ji != n_particles  # non-existing neighbor encoded as N
    # due to undirectedness, each edge is included twice
    n_edges = jnp.count_nonzero(pair_mask)
    pair_mask_flat = jnp.ravel(pair_mask)
    # non-existing edges are sorted to the end for capping
    sorting_idxs = jnp.argsort(~pair_mask_flat)
    # yy[i, j] == i: the row (= target particle) index of each entry of the
    # dense (n_particles, max_neighbors) adjacency layout
    _, yy = jnp.meshgrid(jnp.arange(max_neighbors), jnp.arange(n_particles))  # pylint: disable=unbalanced-tuple-unpacking
    idx_i = _flatten_sort_and_capp(yy, sorting_idxs, max_edges)
    idx_j = _flatten_sort_and_capp(edge_idx_ji, sorting_idxs, max_edges)
    d_ij = _flatten_sort_and_capp(pair_distances, sorting_idxs, max_edges)
    sparse_pair_mask = _flatten_sort_and_capp(pair_mask_flat, sorting_idxs,
                                              max_edges)

    # build sparse angle combinations from adjacency matrix:
    # angle defined for 3 particles with connections k->j and j->i;
    # directional message passing accumulates all k->j to update each m_ji
    idx3_i = jnp.repeat(idx_i, max_neighbors)
    idx3_j = jnp.repeat(idx_j, max_neighbors)
    # retrieves for each j in idx_j its neighbors k: stored in 2nd axis
    idx3_k_mat = edge_idx_ji[idx_j]
    idx3_k = idx3_k_mat.ravel()
    angle_idxs = jnp.column_stack([idx3_i, idx3_j, idx3_k])

    # masking:
    # k and j are different particles, by edge_idx_ji construction.
    # The same applies to j - i, except for masked ones
    mask_i_eq_k = idx3_i != idx3_k
    # mask for ij known a priori
    mask_ij = jnp.repeat(sparse_pair_mask, max_neighbors)
    mask_k = idx3_k != n_particles
    angle_mask = mask_ij * mask_k * mask_i_eq_k  # union of masks
    # top_k sorts existing triplets (True == 1) to the front while
    # simultaneously capping at max_triplets
    angle_mask, sorting_idx3 = lax.top_k(angle_mask, max_triplets)
    angle_idxs = angle_idxs[sorting_idx3]
    n_triplets = jnp.count_nonzero(angle_mask)
    angles = angle_triplets(positions, displacement_fn, angle_idxs, angle_mask)

    # retrieving edge_id m_ji from nodes i and j:
    # idx_i < N by construction, but idx_j can be N: will override
    # lookup[i, N-1], which is problematic if [i, N-1] is an existing edge.
    # Hence, the lookup table is extended by 1.
    edge_id_lookup = jnp.zeros([n_particles, n_particles + 1], dtype=jnp.int32)
    edge_id_lookup_direct = edge_id_lookup.at[(idx_i, idx_j)].set(
        jnp.arange(max_edges))

    # stores for each angle kji edge index j->i to aggregate messages via a
    # segment_sum: each m_ji is a distinct segment containing all incoming m_kj
    reduce_to_ji = edge_id_lookup_direct[(angle_idxs[:, 0], angle_idxs[:, 1])]
    # stores for each angle kji edge index k->j to gather all incoming edges
    # for message passing
    expand_to_kj = edge_id_lookup_direct[(angle_idxs[:, 1], angle_idxs[:, 2])]

    # error code is True if either capacity overflowed; the graph is then
    # missing real edges/triplets and the capacities need to be increased
    too_many_edges_error_code = lax.cond(
        jnp.bitwise_or(n_edges > max_edges, n_triplets > max_triplets),
        lambda _: True, lambda _: False, n_edges
    )

    sparse_graph = SparseDirectionalGraph(
        species=species, distance_ij=d_ij, idx_i=idx_i, idx_j=idx_j,
        angles=angles, reduce_to_ji=reduce_to_ji, expand_to_kj=expand_to_kj,
        edge_mask=sparse_pair_mask, triplet_mask=angle_mask, n_edges=n_edges,
        n_triplets=n_triplets, species_mask=species_mask
    )
    return sparse_graph, too_many_edges_error_code
def _pad_graph(final_size, quantities, connectivities):
"""Helper function that returns padded edges or triplets, while
differentiating between quantities (distances, angles) and adge / triplet
connectivity.
"""
# Everything can be padded with 0, because 0 corresponds to False
# and the edge/triplet will hence have no effect
padded_quantities, padded_connectivities = [], []
for (quantity, connectivity) in zip(quantities, connectivities):
pad_size = final_size - quantity.shape[0]
connectivity_pad = jnp.zeros((pad_size, 3), dtype=jnp.int32)
quantity_pad = jnp.zeros(pad_size, dtype=jnp.float32)
padded_connectivities.append(jnp.vstack((connectivity,
connectivity_pad)))
padded_quantities.append(jnp.concatenate((quantity, quantity_pad)))
return padded_quantities, padded_connectivities
def convert_dataset_to_graphs(r_cutoff, position_data, box, species,
                              padding=True):
    """Converts input consisting of particle positions and boxes to a dataset
    of sparse graph representations.

    Due to the high memory cost of saving padded graphs, this preprocessing
    step is only recommended for small datasets and only slightly changing
    number of particles per box.
    This function tackles the general case, where the number of particles and
    boxes vary across different snapshots, introducing some overhead if particle
    number and the box is fixed. Due to this general setting, this function is
    not jittable.

    Args:
        r_cutoff: Radial cut-off distance below which 2 particles form an edge
        position_data: Either a list of (N_particles, dim) arrays of particle
            positions in case N_particles is not constant across
            snapshots or a (N_snapshots, N_particles, dim) array.
            The positions need to be given in real (non-fractional)
            coordinates.
        box: Either a single 1 or 2-dimensional box (if the box is constant
            across snapshots) or an (N_snapshots, dim) or
            (N_snapshots, dim, dim) array of boxes.
        species: Either a list of (N_particles,) arrays of atom types in case
            N_particles is not constant across snapshots or a single
            (N_particles,) array.
        padding: If True, pads resulting edges and triplets to the maximum
            across the input data to allow for straightforward batching
            without re-compilation. If False, returns edges and triplets
            with varying shapes, but to-be-masked non-existing edges /
            triplets.

    Returns:
        With padding, a SparseDirectionalGraph pytree containing all graphs of
        the dataset, stacked along axis 0. Without padding, a dictionary
        containing the whole definitions of the sparse molecular graph, given
        as Lists. Refer to :class:`SparseDirectionalGraph` for respective
        definitions.
    """
    # canonicalize inputs to lists with one entry per snapshot
    if not isinstance(position_data, list):
        n_snapshots = position_data.shape[0]
        position_data = [position_data[i] for i in range(n_snapshots)]
    else:
        n_snapshots = len(position_data)
    # NOTE(review): assumes box is an array here; a list of boxes would fail
    # on .shape. Also misdetects a single box whose leading dim happens to
    # equal n_snapshots — TODO confirm against callers.
    if box.shape[0] == n_snapshots:  # array of boxes
        box = [box[i] for i in range(n_snapshots)]
    else:  # a single box
        box = [box for _ in range(n_snapshots)]
    if not isinstance(species, list):
        species = [species for _ in range(n_snapshots)]

    max_edges = 0
    max_triplets = 0
    dists, angles, edges, triplets = [], [], [], []
    for (positions, cur_box) in zip(position_data, box):
        box_tensor, scale_fn = custom_space.init_fractional_coordinates(cur_box)
        displacement_fn, _ = space.periodic_general(box_tensor)
        positions = scale_fn(positions)  # to fractional coordinates
        neighbor_fn = partition.neighbor_list(  # only required for 1 state
            displacement_fn, box_tensor, r_cutoff, dr_threshold=0.01,
            capacity_multiplier=1.01, fractional_coordinates=True
        )
        nbrs = neighbor_fn.allocate(positions)  # pylint: disable=not-callable
        graph, _ = sparse_graph_from_neighborlist(displacement_fn, positions,
                                                  nbrs, r_cutoff)
        # delete padded edges/triplets; the maxima across all snapshots
        # determine the common padded size below
        graph.cap_exactly()
        max_edges = max(max_edges, graph.n_edges)
        max_triplets = max(max_triplets, graph.n_triplets)

        # build arrays for edges and angles. Needs to be stored in lists due
        # to different edge and angle count across snapshots in general
        # Boolean mask arrays are converted to int32 1.
        dists.append(graph.distance_ij)
        angles.append(graph.angles)
        edges.append(jnp.stack((graph.idx_i, graph.idx_j, graph.edge_mask),
                               axis=-1))
        triplets.append(jnp.stack((graph.reduce_to_ji, graph.expand_to_kj,
                                   graph.triplet_mask), axis=-1))

    if padding:
        dists, edges = _pad_graph(max_edges, dists, edges)
        angles, triplets = _pad_graph(max_triplets, angles, triplets)
        species, species_mask = pad_per_atom_quantities(species)
    else:
        species_mask = [jnp.ones_like(species_arr) for species_arr in species]

    # save in dict for better transparency; masks are converted back to bool
    graph_rep = {
        'species': species,
        'distance_ij': dists,
        'idx_i': [edge[:, 0] for edge in edges],
        'idx_j': [edge[:, 1] for edge in edges],
        'angles': angles,
        'reduce_to_ji': [triplet[:, 0] for triplet in triplets],
        'expand_to_kj': [triplet[:, 1] for triplet in triplets],
        'species_mask': species_mask,
        'edge_mask': [onp.array(edge[:, 2], dtype=bool) for edge in edges],
        'triplet_mask': [onp.array(triplet[:, 2], dtype=bool)
                         for triplet in triplets]
    }

    if padding:  # when padded, we can return arrays instead of lists
        graph_rep = {key: onp.array(value) for key, value in graph_rep.items()}
        graph_rep = SparseDirectionalGraph(**graph_rep)
    return graph_rep
def pad_per_atom_quantities(per_atom_data):
    """Pads list arrays containing per-atom quantities (e.g. species).

    Allows for straightforward batching without re-compilations in case of
    non-constant number of particles across snapshots.

    Args:
        per_atom_data: List of (N_particles,) arrays containing a scalar
            quantity of each particle.

    Returns:
        A (N_snapshots, N_particles) array and corresponding mask array.
    """
    n_snapshots = len(per_atom_data)
    max_particles = max(quantity.size for quantity in per_atom_data)
    # zero-initialized buffers; padded slots stay 0 with a False mask
    padded = onp.zeros((n_snapshots, max_particles),
                       dtype=per_atom_data[0].dtype)
    mask = onp.zeros((n_snapshots, max_particles), dtype=bool)
    for row, quantity in enumerate(per_atom_data):
        n_real = quantity.size
        padded[row, :n_real] = quantity
        mask[row, :n_real] = True
    return padded, mask
def _canonicalize_species(species, n_particles):
"""Ensures species are integer and initializes species to 0 if species=None.
Args:
species: (N_particles,) array of atom types or None
n_particles: Number of particles
Returns:
Integer species array.
"""
if species is None:
species = jnp.zeros(n_particles, dtype=jnp.int32)
else:
smap._check_species_dtype(species) # assert species are int
return species