Layers
JAX / Haiku implementation of the layers used to build the DimeNet++ architecture.
The DimeNet++ Building Blocks take components of SparseDirectionalGraph as input. Please refer to this class for input descriptions.
Initializers
- class jax_dimenet.layers.OrthogonalVarianceScalingInit(scale=2.0)
Initializer that scales the variance of a uniformly distributed random orthogonal matrix.
Generates a weight matrix whose variance follows Glorot initialization, starting from a random (semi-)orthogonal matrix. Neural networks are expected to learn better when features are decorrelated, as argued e.g. in “Reducing overfitting in deep networks by decorrelating representations”.
The approach is adopted from the original DimeNet and the implementation is inspired by Haiku’s variance scaling initializer.
- scale
Variance scaling factor
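The recipe can be sketched in NumPy as follows. This is an illustration, not the package's Haiku initializer: it assumes the original DimeNet scheme of drawing a random orthogonal matrix via QR and rescaling it so that its empirical variance equals scale / (fan_in + fan_out). With scale=2.0 this gives the Glorot variance 2 / (fan_in + fan_out). The function name `orthogonal_variance_scaling` is hypothetical.

```python
import numpy as np


def orthogonal_variance_scaling(shape, scale=2.0, rng=None):
    """Sketch: random (semi-)orthogonal matrix rescaled to Var(W) = scale / (fan_in + fan_out)."""
    rng = np.random.default_rng() if rng is None else rng
    fan_in, fan_out = shape
    # Reduced QR of a Gaussian matrix yields a matrix with orthonormal columns.
    a = rng.standard_normal((max(shape), min(shape)))
    q, r = np.linalg.qr(a)
    # Sign correction makes q uniformly (Haar) distributed.
    q = q * np.sign(np.diag(r))
    if fan_in < fan_out:
        q = q.T  # wide matrix: rows instead of columns are orthonormal
    # Rescale so the empirical variance matches the target variance.
    target_var = scale / (fan_in + fan_out)
    return q * np.sqrt(target_var / np.var(q))


w = orthogonal_variance_scaling((64, 32), rng=np.random.default_rng(0))
```

Because rescaling multiplies every entry by the same constant, the columns of `w` stay mutually orthogonal (decorrelated) while the variance matches the Glorot target.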
DimeNet++ Layers
Basis Layers
- class jax_dimenet.layers.SmoothingEnvelope(*args, **kwargs)
Smoothing envelope function for radial edge embeddings.
Smoothing the cut-off makes the model output, including the potential energy, twice continuously differentiable. As defined in DimeNet, the envelope function is 1 at 0 and has a root of multiplicity 3 at 1. It is applied to the scaled radial edge distances d_ij / cut_off ∈ [0, 1].
The implementation follows the definition in the DimeNet paper. It differs from the original DimeNet / DimeNet++ implementations, whose envelope definition results in incorrect spherical basis layers (a known issue).
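For illustration, here is a NumPy sketch of the polynomial envelope from the DimeNet paper, u(x) = 1 − (p+1)(p+2)/2 · x^p + p(p+2) · x^(p+1) − p(p+1)/2 · x^(p+2), evaluated on x = d_ij / cut_off. The exponent p = 6 mirrors the `envelope_p=6` default of SphericalBesselLayer. Setting the envelope to zero beyond the cut-off is an assumption of this sketch, and `smoothing_envelope` is a hypothetical name, not the Haiku module's API.

```python
import numpy as np


def smoothing_envelope(x, p=6):
    """Polynomial envelope u(x) of the DimeNet paper for scaled distances x = d_ij / cut_off."""
    x = np.asarray(x, dtype=float)
    a = -(p + 1) * (p + 2) / 2
    b = p * (p + 2)
    c = -p * (p + 1) / 2
    u = 1.0 + a * x**p + b * x ** (p + 1) + c * x ** (p + 2)
    # Pairs at or beyond the cut-off do not contribute.
    return np.where(x < 1.0, u, 0.0)


vals = smoothing_envelope([0.0, 0.5, 0.999, 1.0, 1.5])
```

The coefficients are chosen so that u(1) = u'(1) = u''(1) = 0, which is the root of multiplicity 3 described above. Close to the cut-off the envelope therefore decays like (1 − x)^3.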
- class jax_dimenet.layers.RadialBesselLayer(*args, **kwargs)
Radial Bessel Function (RBF) representation of pairwise distances.
- freq_init
RBFFrequencyInitializer
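The radial basis can be sketched in NumPy using the DimeNet paper's definition e_RBF,n(d) = sqrt(2/c) · sin(nπ d/c) / d, multiplied by the smoothing envelope. The sketch assumes that the frequencies are initialized to nπ (as in DimeNet; here they are fixed rather than trainable) and that the envelope exponent is p = 6. The name `radial_bessel` is hypothetical.

```python
import numpy as np


def radial_bessel(d, r_cutoff, num_radial, p=6):
    """Sketch of DimeNet's radial Bessel basis times the smoothing envelope."""
    d = np.asarray(d, dtype=float)[:, None]        # (n_edges, 1)
    freq = np.pi * np.arange(1, num_radial + 1)    # assumed initial frequencies n * pi
    x = d / r_cutoff                               # scaled distances in [0, 1]
    # Envelope u(x); zero at and beyond the cut-off.
    env = (1.0 - (p + 1) * (p + 2) / 2 * x**p + p * (p + 2) * x ** (p + 1)
           - p * (p + 1) / 2 * x ** (p + 2))
    env = np.where(x < 1.0, env, 0.0)
    return env * np.sqrt(2.0 / r_cutoff) * np.sin(freq * x) / d  # (n_edges, num_radial)


rbf = radial_bessel([1.0, 2.5, 5.0], r_cutoff=5.0, num_radial=6)
```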
- class jax_dimenet.layers.SphericalBesselLayer(*args, **kwargs)
Spherical Bessel Function (SBF) representation of angular triplets.
- __init__(r_cutoff, num_spherical, num_radial, envelope_p=6, name='BesselSpherical')
Initializes the SBF layer.
- Parameters
r_cutoff – Radial cut-off
num_spherical – Number of spherical Bessel embedding functions
num_radial – Number of radial Bessel embedding functions
envelope_p – Power of envelope polynomial
name – Name of SBF layer
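For reference, this is our reading of the spherical basis as defined in the DimeNet paper, which the layer is documented to follow. Here c is `r_cutoff`, z_ln is the n-th root of the spherical Bessel function j_l, P_l is the Legendre polynomial, and u_p is the smoothing envelope with p = `envelope_p`. The assumed index ranges are l = 0, …, num_spherical − 1 and n = 1, …, num_radial, which give num_spherical · num_radial features per triplet.

```latex
\tilde{a}_{\mathrm{SBF},ln}(d_{kj}, \alpha_{kji})
  = \sqrt{\frac{2}{c^{3} \, j_{l+1}^{2}(z_{ln})}} \;
    j_{l}\!\left(\frac{z_{ln}}{c}\, d_{kj}\right) Y_{l}^{0}(\alpha_{kji}),
\qquad
Y_{l}^{0}(\alpha) = \sqrt{\frac{2l+1}{4\pi}} \, P_{l}(\cos\alpha),
\qquad
a_{\mathrm{SBF},ln} = u_{p}\!\left(\frac{d_{kj}}{c}\right) \tilde{a}_{\mathrm{SBF},ln}
```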
DimeNet++ Building Blocks
- class jax_dimenet.layers.ResidualLayer(*args, **kwargs)
Residual Layer: 2 activated Linear layers and a skip connection.
- __init__(layer_size, activation=<CompiledFunction of <function silu>>, init_kwargs=None, name='ResLayer')
Initializes the Residual layer.
- Parameters
layer_size – Output size of the Linear layers
activation – Activation function
init_kwargs – Dict of initialization kwargs for Linear layers
name – Name of the Residual layer
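Below is a NumPy sketch of the computation described above: two activated Linear layers followed by a skip connection. The explicit weight arguments stand in for the Haiku-managed parameters and are hypothetical. Because of the skip connection, `layer_size` must equal the input feature size.

```python
import numpy as np


def silu(x):
    """SiLU / swish activation: x * sigmoid(x)."""
    return x / (1.0 + np.exp(-x))


def residual_layer(x, w1, b1, w2, b2, activation=silu):
    """Two activated Linear layers plus a skip connection (NumPy sketch)."""
    h = activation(x @ w1 + b1)
    h = activation(h @ w2 + b2)
    return x + h


rng = np.random.default_rng(0)
size = 8
x = rng.standard_normal((4, size))
params = [rng.standard_normal(s) * 0.1 for s in [(size, size), (size,), (size, size), (size,)]]
y = residual_layer(x, *params)
```

If all weights are zero, then silu(0) = 0 and the layer reduces to the identity. This is one reason residual layers remain well behaved at initialization.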
- class jax_dimenet.layers.EmbeddingBlock(*args, **kwargs)
Embedding block of DimeNet.
Embeds edges by concatenating RBF embeddings with atom type embeddings of both connected atoms.
- __init__(embed_size, n_species, type_embed_size=None, activation=<CompiledFunction of <function silu>>, init_kwargs=None, name='Embedding')
Initializes an Embedding block.
- Parameters
embed_size – Size of the edge embedding.
n_species – Number of different atom species the network is supposed to process.
type_embed_size – Embedding size of atom type embedding. Default None results in embed_size / 2.
activation – Activation function
init_kwargs – Dict of initialization kwargs for Linear layers
name – Name of Embedding block
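The edge embedding can be sketched in NumPy as follows. The sketch assumes the DimeNet++ layout: the RBF is first projected to `embed_size` with an activated Linear layer, then concatenated with the type embeddings of both atoms, and finally passed through another activated Linear layer. The index arrays `idx_i` / `idx_j` and the parameter tuple are hypothetical stand-ins for the fields of SparseDirectionalGraph and the Haiku parameters.

```python
import numpy as np


def silu(x):
    return x / (1.0 + np.exp(-x))


def embedding_block(species, idx_i, idx_j, rbf, params, activation=silu):
    """Hypothetical sketch: edge embedding from two atom-type embeddings and the RBF."""
    type_table, w_rbf, b_rbf, w_out, b_out = params
    h = type_table[species]                        # (n_atoms, type_embed_size)
    rbf_proj = activation(rbf @ w_rbf + b_rbf)     # (n_edges, embed_size)
    cat = np.concatenate([h[idx_i], h[idx_j], rbf_proj], axis=-1)
    return activation(cat @ w_out + b_out)         # (n_edges, embed_size)


rng = np.random.default_rng(0)
n_species, embed_size, num_radial = 3, 16, 6
type_embed_size = embed_size // 2  # mirrors the documented default embed_size / 2
params = (
    rng.standard_normal((n_species, type_embed_size)),
    rng.standard_normal((num_radial, embed_size)),
    np.zeros(embed_size),
    rng.standard_normal((2 * type_embed_size + embed_size, embed_size)),
    np.zeros(embed_size),
)
species = np.array([0, 2, 1])
idx_i, idx_j = np.array([0, 1, 2, 0]), np.array([1, 0, 0, 2])
rbf = rng.standard_normal((4, num_radial))
edges = embedding_block(species, idx_i, idx_j, rbf, params)
```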
- class jax_dimenet.layers.OutputBlock(*args, **kwargs)
DimeNet++ Output block.
Predicts per-atom quantities given RBF embeddings and messages.
- __init__(embed_size, out_embed_size=None, num_dense=3, num_targets=1, activation=<CompiledFunction of <function silu>>, init_kwargs=None, name='Output')
Initializes an Output block.
- Parameters
embed_size – Size of the edge embedding.
out_embed_size – Output size of Linear layers after upsampling
num_dense – Number of dense layers
num_targets – Number of target quantities to be predicted
activation – Activation function
init_kwargs – Dict of initialization kwargs for Linear layers
name – Name of Output block
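A minimal NumPy sketch of the per-atom readout follows. It assumes the DimeNet++ ordering: edge messages are gated by a projection of the RBF, summed per receiving atom, upsampled to `out_embed_size` with a bias-free Linear layer, passed through `num_dense` activated Linear layers, and finally mapped to `num_targets`. All parameter names are hypothetical.

```python
import numpy as np


def silu(x):
    return x / (1.0 + np.exp(-x))


def output_block(messages, rbf, idx_i, n_atoms, w_rbf, w_up, dense, w_out, activation=silu):
    """Hypothetical sketch of the per-atom readout of a DimeNet++ output block."""
    x = messages * (rbf @ w_rbf)          # gate edge messages with the RBF
    per_atom = np.zeros((n_atoms, x.shape[1]))
    np.add.at(per_atom, idx_i, x)         # sum messages of all edges ending in atom i
    h = per_atom @ w_up                   # upsample to out_embed_size
    for w, b in dense:                    # num_dense activated Linear layers
        h = activation(h @ w + b)
    return h @ w_out                      # (n_atoms, num_targets)


rng = np.random.default_rng(0)
n_edges, n_atoms, embed, out_embed, num_radial, num_targets = 5, 3, 8, 16, 6, 1
messages = rng.standard_normal((n_edges, embed))
rbf = rng.standard_normal((n_edges, num_radial))
idx_i = np.array([0, 0, 1, 2, 2])
w_rbf = rng.standard_normal((num_radial, embed))
w_up = rng.standard_normal((embed, out_embed))
dense = [(rng.standard_normal((out_embed, out_embed)), np.zeros(out_embed)) for _ in range(3)]
w_out = rng.standard_normal((out_embed, num_targets))
per_atom_pred = output_block(messages, rbf, idx_i, n_atoms, w_rbf, w_up, dense, w_out)
```

Summing the per-atom predictions would then give a molecular quantity such as the potential energy.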
- class jax_dimenet.layers.InteractionBlock(*args, **kwargs)
DimeNet++ Interaction block.
Performs directional message-passing based on RBF and SBF embeddings as well as messages from the previous message-passing iteration. Updated messages are used in the subsequent Output block.
- __init__(embed_size, num_res_before_skip, num_res_after_skip, activation=<CompiledFunction of <function silu>>, init_kwargs=None, angle_int_embed_size=None, basis_int_embed_size=8, name='Interaction')
Initializes an Interaction block.
- Parameters
embed_size – Size of the edge embedding.
num_res_before_skip – Number of Residual blocks before skip
num_res_after_skip – Number of Residual blocks after skip
activation – Activation function
init_kwargs – Dict of initialization kwargs for Linear layers
angle_int_embed_size – Embedding size of Linear layers for the down-projected triplet interaction
basis_int_embed_size – Embedding size of Linear layers for the interaction of the RBF / SBF bases
name – Name of Interaction block
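The directional step at the core of the interaction block can be sketched in NumPy. For every triplet k → j → i, the (down-projected) message on edge kj is multiplied element-wise by the projected SBF embedding of that triplet, and the products are summed onto edge ji. The index arrays `idx_kj` / `idx_ji` are hypothetical stand-ins for the triplet indices in SparseDirectionalGraph. The surrounding down/up projections and residual layers are omitted.

```python
import numpy as np


def triplet_aggregation(m_kj, sbf_proj, idx_kj, idx_ji, n_edges):
    """Hypothetical sketch of directional message aggregation over triplets."""
    per_triplet = m_kj[idx_kj] * sbf_proj         # (n_triplets, angle_int_embed_size)
    out = np.zeros((n_edges, per_triplet.shape[1]))
    np.add.at(out, idx_ji, per_triplet)           # scatter-sum onto the target edges ji
    return out


m = np.array([[1.0], [2.0], [3.0]])               # one feature per edge
sbf = np.ones((2, 1))                             # two triplets, neutral angular weights
agg = triplet_aggregation(m, sbf, idx_kj=np.array([0, 1]), idx_ji=np.array([2, 2]), n_edges=3)
```

Here both triplets end on edge 2, so that edge receives the sum of the messages from edges 0 and 1, while the other edges receive nothing.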
Utility Functions
Implements jax.ops.segment_sum, but casts the input to float64 before summation and casts the result back to a target output type afterwards (float32 by default).
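The motivation for accumulating in float64 can be illustrated with a small NumPy sketch. The name `high_precision_segment_sum` is hypothetical, and `np.add.at` stands in for `jax.ops.segment_sum`. In float32, 1e8 + 1.0 rounds back to 1e8, so the small contribution is lost. A float64 accumulator keeps it before the result is cast back to float32. Note that in JAX a cast to float64 only takes effect when x64 mode is enabled (`jax.config.update("jax_enable_x64", True)`). Otherwise the dtype is truncated to float32.

```python
import numpy as np


def high_precision_segment_sum(data, segment_ids, num_segments, dtype=np.float32):
    """Segment sum accumulated in float64, cast back to `dtype` (NumPy sketch)."""
    acc = np.zeros((num_segments,) + np.shape(data)[1:], dtype=np.float64)
    np.add.at(acc, segment_ids, np.asarray(data, dtype=np.float64))
    return acc.astype(dtype)


data = np.array([1e8, 1.0, -1e8], dtype=np.float32)
ids = np.array([0, 0, 0])
result = high_precision_segment_sum(data, ids, num_segments=1)
```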