Layers

JAX / Haiku implementation of the layers used to build the DimeNet++ architecture.

The DimeNet++ Building Blocks take components of SparseDirectionalGraph as input. Please refer to this class for input descriptions.

Initializers

class jax_dimenet.layers.OrthogonalVarianceScalingInit(scale=2.0)[source]

Initializer that scales the variance of a uniformly distributed random orthogonal matrix.

Generates a weight matrix from a random (semi-)orthogonal matrix and rescales it so that its variance matches Glorot initialization. Neural networks are expected to learn better when features are decorrelated, as argued, e.g., in “Reducing overfitting in deep networks by decorrelating representations”.

The approach is adopted from the original DimeNet and the implementation is inspired by Haiku’s variance scaling initializer.

scale

Variance scaling factor

__init__(scale=2.0)[source]

Constructs the OrthogonalVarianceScaling Initializer.

Parameters

scale – Variance scaling factor
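
The following is a minimal sketch of the idea in plain JAX, not the class's actual code. It assumes the rescaling used in the original DimeNet, where the empirical variance of the matrix is set to scale / (fan_in + fan_out). The helper name orthogonal_variance_scaling and its key and shape arguments are hypothetical.

```python
import jax
import jax.numpy as jnp


def orthogonal_variance_scaling(key, shape, scale=2.0):
    """Hypothetical helper: (semi-)orthogonal matrix rescaled to Glorot variance."""
    fan_in, fan_out = shape
    # The QR decomposition of a Gaussian matrix yields orthonormal columns.
    gaussian = jax.random.normal(key, (max(shape), min(shape)))
    q, r = jnp.linalg.qr(gaussian)
    # Fix the column signs so that the orthogonal factor is uniformly distributed.
    q = q * jnp.sign(jnp.diag(r))
    weights = q if fan_in >= fan_out else q.T
    # Assumed rescaling: empirical variance becomes scale / (fan_in + fan_out).
    return weights * jnp.sqrt(scale / ((fan_in + fan_out) * jnp.var(weights)))


weights = orthogonal_variance_scaling(jax.random.PRNGKey(0), (64, 32))
```

With the default scale=2.0 the variance equals the Glorot value 2 / (fan_in + fan_out), while the columns stay mutually orthogonal.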

class jax_dimenet.layers.RBFFrequencyInitializer(*args, **kwds)[source]

Initializer of the frequencies of the RadialBesselLayer.

Initializes the frequency values to their canonical values.

DimeNet++ Layers

Basis Layers

class jax_dimenet.layers.SmoothingEnvelope(*args, **kwargs)[source]

Smoothing envelope function for radial edge embeddings.

Smoothing the cut-off makes the model output, including the potential energy, twice continuously differentiable. As defined in DimeNet, the envelope function equals 1 at 0 and has a root of multiplicity 3 at 1. It is applied to the scaled radial edge distances d_ij / cut_off ∈ [0, 1].

The implementation corresponds to the definition in the DimeNet paper. It differs from the original implementation of DimeNet / DimeNet++, which defines incorrect spherical basis layers as a result (a known issue).

__init__(p=6, name='Envelope')[source]

Initializes the SmoothingEnvelope layer.

Parameters
  • p – Power of the smoothing polynomial

  • name – Name of the layer

__call__(distances)[source]

Returns the envelope values.
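
As an illustration, the envelope polynomial from the DimeNet paper, u(d) = 1 - (p+1)(p+2)/2 d^p + p(p+2) d^(p+1) - p(p+1)/2 d^(p+2), can be sketched as a free function. The name smoothing_envelope is hypothetical, and returning zero for scaled distances of 1 or more is an assumption about the behavior beyond the cut-off.

```python
import jax.numpy as jnp


def smoothing_envelope(scaled_distances, p=6):
    """Hypothetical helper: envelope u(d) with u(0) = 1 and a triple root at d = 1."""
    d = scaled_distances
    a = -(p + 1) * (p + 2) / 2.0
    b = p * (p + 2.0)
    c = -p * (p + 1) / 2.0
    values = 1.0 + a * d**p + b * d**(p + 1) + c * d**(p + 2)
    # Assumption: distances beyond the cut-off contribute nothing.
    return jnp.where(d < 1.0, values, 0.0)


envelope_values = smoothing_envelope(jnp.array([0.0, 0.5, 1.0, 1.5]))
```

For p=6 the coefficients are -28, 48 and -21, so the polynomial as well as its first and second derivatives vanish at d = 1.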

class jax_dimenet.layers.RadialBesselLayer(*args, **kwargs)[source]

Radial Bessel Function (RBF) representation of pairwise distances.

freq_init

RBFFrequencyInitializer

__init__(cutoff, num_radial=16, envelope_p=6, name='BesselRadial')[source]

Initializes the RBF layer.

Parameters
  • cutoff – Radial cut-off distance

  • num_radial – Number of radial Bessel embedding functions

  • envelope_p – Power of envelope polynomial

  • name – Name of RBF layer

__call__(distances)[source]

Returns the RBF embeddings of edge distances.
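
A rough sketch of the computation, following the radial basis sqrt(2 / c) sin(n pi d / c) / d from the DimeNet paper multiplied by the smoothing envelope. The helper radial_bessel_sketch is hypothetical, and the frequencies are fixed to n pi here as an assumption, whereas the layer treats them as parameters set up by freq_init.

```python
import jax.numpy as jnp


def radial_bessel_sketch(distances, cutoff, num_radial=16, envelope_p=6):
    """Hypothetical helper returning embeddings of shape (n_edges, num_radial)."""
    p = envelope_p
    d = distances[:, None]
    scaled = d / cutoff
    # Assumed frequencies n * pi for n = 1..num_radial.
    frequencies = jnp.pi * jnp.arange(1, num_radial + 1)
    rbf = jnp.sqrt(2.0 / cutoff) * jnp.sin(frequencies * scaled) / d
    # Polynomial envelope from the DimeNet paper.
    envelope = (1.0 - (p + 1) * (p + 2) / 2.0 * scaled**p
                + p * (p + 2.0) * scaled**(p + 1)
                - p * (p + 1) / 2.0 * scaled**(p + 2))
    # Assumption: edges longer than the cut-off get a zero embedding.
    return jnp.where(scaled < 1.0, envelope * rbf, 0.0)


rbf_embeddings = radial_bessel_sketch(jnp.array([0.1, 0.25, 0.6]), cutoff=0.5)
```

Each row holds the num_radial basis values of one edge; the third distance exceeds the cut-off, so its row is zero.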

class jax_dimenet.layers.SphericalBesselLayer(*args, **kwargs)[source]

Spherical Bessel Function (SBF) representation of angular triplets.

__init__(r_cutoff, num_spherical, num_radial, envelope_p=6, name='BesselSpherical')[source]

Initializes the SBF layer.

Parameters
  • r_cutoff – Radial cut-off

  • num_spherical – Number of spherical Bessel embedding functions

  • num_radial – Number of radial Bessel embedding functions

  • envelope_p – Power of envelope polynomial

  • name – Name of SBF layer

__call__(pair_distances, angles, angle_mask, expand_to_kj)[source]

Returns the SBF embeddings of angular triplets.

DimeNet++ Building Blocks

class jax_dimenet.layers.ResidualLayer(*args, **kwargs)[source]

Residual Layer: two activated Linear layers and a skip connection.

__init__(layer_size, activation=<CompiledFunction of <function silu>>, init_kwargs=None, name='ResLayer')[source]

Initializes the Residual layer.

Parameters
  • layer_size – Output size of the Linear layers

  • activation – Activation function

  • init_kwargs – Dict of initialization kwargs for Linear layers

  • name – Name of the Residual layer

__call__(inputs)[source]

Returns the output of the Residual layer.
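
The structure of such a layer can be sketched in plain JAX with explicit parameters instead of Haiku modules. The helper residual_layer_sketch and its params layout are hypothetical, and the use of bias terms is an assumption.

```python
import jax
import jax.numpy as jnp


def residual_layer_sketch(inputs, params, activation=jax.nn.silu):
    """Hypothetical helper: inputs plus the output of two activated linear layers."""
    (w1, b1), (w2, b2) = params
    hidden = activation(inputs @ w1 + b1)
    hidden = activation(hidden @ w2 + b2)
    # The skip connection requires layer_size to match the input feature size.
    return inputs + hidden


layer_size = 8
key_w1, key_w2, key_x = jax.random.split(jax.random.PRNGKey(0), 3)
params = [
    (0.1 * jax.random.normal(key_w1, (layer_size, layer_size)), jnp.zeros(layer_size)),
    (0.1 * jax.random.normal(key_w2, (layer_size, layer_size)), jnp.zeros(layer_size)),
]
inputs = jax.random.normal(key_x, (5, layer_size))
outputs = residual_layer_sketch(inputs, params)
```

Because silu(0) = 0, all-zero weights and biases reduce the layer to the identity, which is what the skip connection is meant to preserve.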

class jax_dimenet.layers.EmbeddingBlock(*args, **kwargs)[source]

Embedding block of DimeNet.

Embeds edges by concatenating RBF embeddings with atom type embeddings of both connected atoms.

__init__(embed_size, n_species, type_embed_size=None, activation=<CompiledFunction of <function silu>>, init_kwargs=None, name='Embedding')[source]

Initializes an Embedding block.

Parameters
  • embed_size – Size of the edge embedding.

  • n_species – Number of different atom species the network is supposed to process.

  • type_embed_size – Embedding size of atom type embedding. Default None results in embed_size / 2.

  • activation – Activation function

  • init_kwargs – Dict of initialization kwargs for Linear layers

  • name – Name of Embedding block

__call__(rbf, species, idx_i, idx_j)[source]

Returns output of the Embedding block.
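
The gather-and-concatenate step described above might look as follows. This is only a sketch with random stand-in values: the variable names are hypothetical, and the Linear layers that, as in DimeNet, are assumed to process the RBF embeddings before and the concatenated features after this step are omitted.

```python
import jax
import jax.numpy as jnp

n_species, type_embed_size, num_radial = 3, 4, 6
key_table, key_rbf = jax.random.split(jax.random.PRNGKey(0))

# One embedding vector per atom species (random stand-in for learned values).
type_table = jax.random.normal(key_table, (n_species, type_embed_size))
species = jnp.array([0, 2, 1, 0])   # species of 4 atoms
idx_i = jnp.array([0, 1, 2, 3, 0])  # atom i of each of the 5 edges
idx_j = jnp.array([1, 0, 3, 2, 2])  # atom j of each edge
rbf = jax.random.normal(key_rbf, (5, num_radial))  # stand-in RBF embeddings

# Look up the type embedding of every atom, then gather it per edge.
atom_embeddings = type_table[species]
edge_features = jnp.concatenate(
    [atom_embeddings[idx_i], atom_embeddings[idx_j], rbf], axis=-1)
```

Each edge ends up with 2 * type_embed_size + num_radial features before any further projection to embed_size.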

class jax_dimenet.layers.OutputBlock(*args, **kwargs)[source]

DimeNet++ Output block.

Predicts per-atom quantities given RBF embeddings and messages.

__init__(embed_size, out_embed_size=None, num_dense=3, num_targets=1, activation=<CompiledFunction of <function silu>>, init_kwargs=None, name='Output')[source]

Initializes an Output block.

Parameters
  • embed_size – Size of the edge embedding.

  • out_embed_size – Output size of Linear layers after upsampling

  • num_dense – Number of dense layers

  • num_targets – Number of target quantities to be predicted

  • activation – Activation function

  • init_kwargs – Dict of initialization kwargs for Linear layers

  • name – Name of Output block

__call__(messages, rbf, idx_i, n_particles)[source]

Returns predicted per-atom quantities.
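
The step from per-edge messages to per-atom quantities can be illustrated with jax.ops.segment_sum. The snippet assumes, as in DimeNet++, that messages are weighted by linearly projected RBF embeddings before being summed over the edges of each atom; the subsequent dense layers are left out and all values are stand-ins.

```python
import jax
import jax.numpy as jnp

n_particles, embed_size = 3, 4
idx_i = jnp.array([0, 0, 1, 2, 2])            # atom i of each of the 5 edges
messages = jnp.ones((5, embed_size))          # stand-in edge messages
rbf_weights = jnp.full((5, embed_size), 0.5)  # stand-in for projected RBF embeddings

# Weight each message by its radial features, then sum over the edges of each atom.
per_atom = jax.ops.segment_sum(
    rbf_weights * messages, idx_i, num_segments=n_particles)
```

Passing n_particles as num_segments keeps the output shape static, which JAX requires for jit compilation.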

class jax_dimenet.layers.InteractionBlock(*args, **kwargs)[source]

DimeNet++ Interaction block.

Performs directional message-passing based on RBF and SBF embeddings as well as messages from the previous message-passing iteration. Updated messages are used in the subsequent Output block.

__init__(embed_size, num_res_before_skip, num_res_after_skip, activation=<CompiledFunction of <function silu>>, init_kwargs=None, angle_int_embed_size=None, basis_int_embed_size=8, name='Interaction')[source]

Initializes an Interaction block.

Parameters
  • embed_size – Size of the edge embedding.

  • num_res_before_skip – Number of Residual blocks before skip

  • num_res_after_skip – Number of Residual blocks after skip

  • activation – Activation function

  • init_kwargs – Dict of initialization kwargs for Linear layers

  • angle_int_embed_size – Embedding size of Linear layers for down-projected triplet interaction

  • basis_int_embed_size – Embedding size of Linear layers for interaction of RBF / SBF basis

  • name – Name of Interaction block

__call__(m_input, rbf, sbf, reduce_to_ji, expand_to_kj)[source]

Returns messages after interaction via message-passing.
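
The role of the two index arrays can be sketched as follows, assuming they behave like the gather and reduce indices of the original DimeNet: expand_to_kj selects the incoming edge message of every triplet and reduce_to_ji names the edge that receives the triplet's contribution. Linear layers, residual layers and the RBF interaction are omitted, and all values are stand-ins.

```python
import jax
import jax.numpy as jnp

n_edges, embed_size = 4, 2
messages = jnp.arange(n_edges * embed_size, dtype=jnp.float32).reshape(
    n_edges, embed_size)
expand_to_kj = jnp.array([1, 2, 0, 3, 0])  # edge kj of each of the 5 triplets
reduce_to_ji = jnp.array([0, 0, 1, 2, 3])  # edge ji of each triplet
sbf_weights = jnp.ones((5, embed_size))    # stand-in for projected SBF embeddings

# Gather the incoming message of every triplet and modulate it by angular features.
triplet_messages = messages[expand_to_kj] * sbf_weights
# Sum all triplet contributions that belong to the same target edge ji.
aggregated = jax.ops.segment_sum(
    triplet_messages, reduce_to_ji, num_segments=n_edges)
```

In this toy setup edge 0 receives the messages of edges 1 and 2, while every other edge receives exactly one message.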

Utility Functions

high_precision_segment_sum(data, segment_ids)

Equivalent to jax.ops.segment_sum, but casts the input to float64 before summation and casts the result back to a target output type (float32 by default).
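
A minimal sketch of this pattern is given below. The function name and the num_segments and out_dtype arguments are hypothetical and not taken from the signature above. Note that JAX only provides float64 when the jax_enable_x64 flag is set; without it the cast silently falls back to float32.

```python
import jax

# float64 is only available in JAX when 64-bit mode is enabled.
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp


def segment_sum_float64_sketch(data, segment_ids, num_segments=None,
                               out_dtype=jnp.float32):
    """Hypothetical helper: sum in float64, return the result in out_dtype."""
    summed = jax.ops.segment_sum(
        data.astype(jnp.float64), segment_ids, num_segments=num_segments)
    return summed.astype(out_dtype)


data = jnp.array([1.0, 2.0, 3.0, 4.0], dtype=jnp.float32)
segment_ids = jnp.array([0, 0, 1, 1])
summed = segment_sum_float64_sketch(data, segment_ids, num_segments=2)
```

Accumulating in float64 reduces round-off error when many small contributions are summed into one segment, while the float32 output keeps the rest of the model in single precision.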