SparseObservations
phasic.svgd.SparseObservations()

Sparse representation of multivariate observations.
Replaces the dense NaN-padded format with parallel arrays that contain only valid observations. This avoids NaN propagation through JAX callbacks during gradient computation.
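To illustrate the idea of "parallel arrays containing only valid observations", here is a minimal sketch of the conversion. It assumes the dense input has shape (max_obs, n_features), with column j holding the observations for feature j padded with NaN. That layout is an assumption, and `dense_to_sparse_sketch` is a hypothetical helper, not the library's `dense_to_sparse`.

```python
import numpy as np
import jax.numpy as jnp


def dense_to_sparse_sketch(dense):
    """Hypothetical helper: NaN-padded (max_obs, n_features) array -> parallel arrays.

    Assumes column j holds the observations for feature j, padded with NaN.
    """
    dense = np.asarray(dense, dtype=float)
    n_rows, n_features = dense.shape
    # Walk the array column by column so the output is grouped (sorted) by feature.
    flat = dense.T.ravel()
    feature_idx = np.repeat(np.arange(n_features), n_rows)
    # Drop the NaN padding; values and features stay aligned element by element.
    keep = ~np.isnan(flat)
    return jnp.asarray(flat[keep]), jnp.asarray(feature_idx[keep]), n_features


dense = np.array([[1.0, 2.0, 3.0],
                  [1.5, np.nan, 3.5]])
values, features, n_features = dense_to_sparse_sketch(dense)
```

Because the output is grouped by feature index, it also satisfies the sorting requirement that applies when `slices` are supplied.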
Attributes
values : jnp.ndarray
Observation values, shape (n_obs,). Contains no NaN values.
features : jnp.ndarray
Integer feature index for each observation, shape (n_obs,).
n_features : int
Total number of features (used for rewards indexing).
slices : tuple of tuples, optional
Pre-computed (start, end) slices for each feature, which avoid dynamic boolean indexing in JIT-compiled code. If provided, observations must be sorted by feature index.
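When observations are sorted by feature index, per-feature slices can be derived with a binary search. The sketch below assumes half-open (start, end) pairs, as in Python slicing, and `compute_slices_sketch` is a hypothetical helper. Whether the class builds its slices this way is not stated here.

```python
import numpy as np


def compute_slices_sketch(features, n_features):
    """Hypothetical helper: half-open (start, end) slice for every feature index."""
    features = np.asarray(features)
    # Slices are only meaningful if each feature's observations are contiguous.
    if np.any(np.diff(features) < 0):
        raise ValueError("features must be sorted by feature index")
    ids = np.arange(n_features)
    starts = np.searchsorted(features, ids, side="left")
    ends = np.searchsorted(features, ids, side="right")
    # Plain Python ints keep the result hashable, so it can be a static JIT argument.
    return tuple((int(s), int(e)) for s, e in zip(starts, ends))


slices = compute_slices_sketch([0, 0, 1, 1, 1, 2], n_features=3)
```

A feature with no observations gets an empty slice (start == end) rather than being skipped, so `slices[i]` always corresponds to feature `i`.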
Examples
>>> import jax.numpy as jnp
>>> from phasic.svgd import SparseObservations
>>> # 10 observations for feature 0, 10 for feature 1, 10 for feature 2
>>> sparse = SparseObservations(
...     values=1.1 + 0.1 * jnp.arange(30),       # 1.1, 1.2, ..., 4.0
...     features=jnp.repeat(jnp.arange(3), 10),  # ten 0s, ten 1s, ten 2s
...     n_features=3,
... )

See Also
dense_to_sparse : Convert dense NaN-padded array to sparse format
Methods
| Name | Description |
|---|---|
| get_feature_values | Get observation values for a specific feature. |
get_feature_values
phasic.svgd.SparseObservations.get_feature_values(feature_idx)

Get observation values for a specific feature.
Uses the pre-computed slices when available, which is compatible with JAX JIT compilation. Otherwise it falls back to boolean indexing, which is not JIT compatible because the size of the result depends on the data.
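The difference between the two code paths can be shown in isolation. In the sketch below, `values_by_slice` and `values_by_mask` are hypothetical stand-ins for the two branches, not the method's actual code. The slice bounds are Python ints fixed at trace time, so the output shape is static. The boolean mask version produces a data-dependent shape and only works eagerly.

```python
from functools import partial

import jax
import jax.numpy as jnp

# Toy data sorted by feature: 2 obs for feature 0, 3 for feature 1, 1 for feature 2.
values = jnp.array([1.1, 1.2, 2.1, 2.2, 2.3, 3.1])
features = jnp.array([0, 0, 1, 1, 1, 2])
slices = ((0, 2), (2, 5), (5, 6))  # assumed half-open (start, end) per feature


@partial(jax.jit, static_argnums=(1, 2))
def values_by_slice(values, slices, feature_idx):
    # start/end are Python ints known at trace time, so the slice has a static shape.
    start, end = slices[feature_idx]
    return values[start:end]


def values_by_mask(values, features, feature_idx):
    # The result length depends on the contents of `features`; eager mode only.
    return values[features == feature_idx]


sliced = values_by_slice(values, slices, 1)
masked = values_by_mask(values, features, 1)
```

Wrapping `values_by_mask` in `jax.jit` fails with `jax.errors.NonConcreteBooleanIndexError`, because the mask is a traced value whose number of `True` entries is unknown at compile time.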
Parameters
feature_idx : int
Feature index.
Returns
jnp.ndarray
Observation values for this feature.