jax_dimenet.layers.high_precision_segment_sum

jax_dimenet.layers.high_precision_segment_sum(data, segment_ids, num_segments=None, out_type=jax.numpy.float32, indices_are_sorted=False, unique_indices=False, bucket_size=None)

Implements jax.ops.segment_sum, but casts the input to float64 before summation and casts the result back to a target output type afterwards (float32 by default). Used to improve the numerical accuracy of the summation.
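The cast-sum-cast pattern described above can be sketched as follows. This is a minimal illustration, not the library's actual source: the function name `high_precision_segment_sum_sketch` is hypothetical, and it assumes JAX's x64 mode is enabled via `jax_enable_x64`, since without that flag JAX does not create float64 arrays and the upcast silently stays float32. The remaining arguments are forwarded unchanged to `jax.ops.segment_sum`.

```python
import jax

# Assumption: float64 is only honoured when JAX's x64 mode is enabled;
# without this flag the float64 cast below silently falls back to float32.
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp


def high_precision_segment_sum_sketch(
    data,
    segment_ids,
    num_segments=None,
    out_type=jnp.float32,
    indices_are_sorted=False,
    unique_indices=False,
    bucket_size=None,
):
    """Hypothetical sketch: sum segments in float64, then cast to out_type."""
    data64 = jnp.asarray(data).astype(jnp.float64)  # upcast before accumulation
    summed = jax.ops.segment_sum(
        data64,
        segment_ids,
        num_segments=num_segments,
        indices_are_sorted=indices_are_sorted,
        unique_indices=unique_indices,
        bucket_size=bucket_size,
    )
    return summed.astype(out_type)  # downcast the accumulated result


# Usage: two segments, float32 input and float32 output.
data = jnp.array([1.0, 2.0, 3.0, 4.0], dtype=jnp.float32)
ids = jnp.array([0, 0, 1, 1])
out = high_precision_segment_sum_sketch(data, ids, num_segments=2)

# A large value plus small increments: 2**24 + 1 + 1 is exact in float64,
# and 16777218 is representable in float32, so the cast back is exact too.
big = jnp.array([2.0**24, 1.0, 1.0], dtype=jnp.float32)
big_sum = high_precision_segment_sum_sketch(
    big, jnp.array([0, 0, 0]), num_segments=1
)
```

The second call shows why the upcast can matter. In float32 the spacing between representable values at 2**24 is 2, so accumulating the `1.0` increments one at a time in float32 can round them away. Accumulating in float64 keeps them, and only the final result is rounded to `out_type`.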