SparseObservations

phasic.svgd.SparseObservations()

Sparse representation of multivariate observations.

Replaces the dense, NaN-padded format with parallel arrays that contain only the valid (non-NaN) observations. This keeps NaN values from propagating through JAX callbacks during gradient computation.
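The conversion from the dense layout can be sketched in NumPy as follows. This assumes the dense array has shape (n_rows, n_features), with one column per feature and NaN padding where a feature has fewer observations. `to_sparse_arrays` is a hypothetical helper for illustration, not the library's `dense_to_sparse`, whose exact signature is not documented here.

```python
import numpy as np


def to_sparse_arrays(dense):
    """Hypothetical helper: flatten a NaN-padded (n_rows, n_features) array.

    Returns (values, features, n_features) with observations ordered by
    feature index, so the result is already sorted as `slices` requires.
    """
    dense = np.asarray(dense, dtype=float)
    n_features = dense.shape[1]
    # Transpose so C-order traversal visits all of feature 0, then feature 1, ...
    by_feature = dense.T
    mask = ~np.isnan(by_feature)
    values = by_feature[mask]          # valid observations only
    features = np.nonzero(mask)[0]     # row of by_feature == feature index
    return values, features, n_features


dense = np.array([
    [1.1, 2.1],
    [1.2, np.nan],      # feature 1 has fewer observations, so it is padded
    [np.nan, np.nan],
])
values, features, n_features = to_sparse_arrays(dense)
print(values)    # [1.1 1.2 2.1]
print(features)  # [0 0 1]
```

The resulting arrays could be wrapped with `jnp.asarray` and passed as `values`, `features` and `n_features`.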

Attributes

values : jnp.ndarray

Observation values, shape (n_obs,); contains no NaN values.

features : jnp.ndarray

Integer feature index of each observation, shape (n_obs,).

n_features : int

Total number of features (for rewards indexing)

slices : tuple of tuples, optional

Pre-computed (start, end) slices for each feature, used to avoid dynamic boolean indexing in JIT-compiled code. If provided, observations must be sorted by feature index so that each feature's observations form one contiguous block.
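One way to build the `slices` tuple from sorted feature indices is a pair of binary searches per feature, sketched below in NumPy. `compute_slices` is a hypothetical helper, not part of the documented API; a feature with no observations gets an empty (k, k) slice.

```python
import numpy as np


def compute_slices(features, n_features):
    """Hypothetical helper: (start, end) slice per feature for sorted `features`."""
    features = np.asarray(features)
    # The input must be sorted, otherwise a feature's rows are not contiguous.
    if np.any(np.diff(features) < 0):
        raise ValueError("features must be sorted by feature index")
    ids = np.arange(n_features)
    starts = np.searchsorted(features, ids, side="left")
    ends = np.searchsorted(features, ids, side="right")
    # Plain Python ints keep the slices static (hashable, known at trace time).
    return tuple((int(s), int(e)) for s, e in zip(starts, ends))


print(compute_slices([0, 0, 0, 1, 1, 2], n_features=3))
# ((0, 3), (3, 5), (5, 6))
```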

Examples

>>> import jax.numpy as jnp
>>> from phasic.svgd import SparseObservations
>>> # 10 observations each for features 0, 1 and 2, sorted by feature
>>> sparse = SparseObservations(
...     values=jnp.concatenate([
...         f + 0.1 * jnp.arange(1, 11) for f in (1.0, 2.0, 3.0)
...     ]),
...     features=jnp.repeat(jnp.arange(3), 10),
...     n_features=3,
... )

See Also

dense_to_sparse : Convert dense NaN-padded array to sparse format

Methods

Name                  Description
get_feature_values    Get observation values for a specific feature.

get_feature_values

phasic.svgd.SparseObservations.get_feature_values(feature_idx)

Get observation values for a specific feature.

Uses the pre-computed slices if available (compatible with JAX JIT); otherwise falls back to boolean indexing, which is not JIT compatible because the size of the result depends on the data.
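The dispatch described above can be sketched in NumPy. `feature_values_sketch` is a hypothetical free function, not the method itself, and it assumes `slices` holds Python-int pairs and `feature_idx` is a Python int. Both branches return the same values; the difference is that slice bounds fix the output length at trace time, whereas a boolean mask yields a data-dependent shape, which `jax.jit` rejects.

```python
import numpy as np


def feature_values_sketch(values, features, feature_idx, slices=None):
    """Hypothetical stand-in for the documented get_feature_values logic."""
    if slices is not None:
        # Static bounds: the output length is known without inspecting data,
        # which is what jax.jit needs to trace the computation.
        start, end = slices[feature_idx]
        return values[start:end]
    # Fallback: the output length depends on the contents of `features`,
    # so under jax.jit this boolean mask would raise an indexing error.
    return values[features == feature_idx]


values = np.array([1.1, 1.2, 2.1, 3.1, 3.2])
features = np.array([0, 0, 1, 2, 2])
slices = ((0, 2), (2, 3), (3, 5))
print(feature_values_sketch(values, features, 2, slices))  # [3.1 3.2]
print(feature_values_sketch(values, features, 2))          # [3.1 3.2]
```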

Parameters

feature_idx : int

Feature index

Returns

jnp.ndarray

Observation values for this feature