dense_to_sparse

phasic.svgd.dense_to_sparse(data)

Convert a dense, NaN-padded array to the sparse observation format.

Parameters

data : jnp.ndarray

Dense 2D array of shape (n_times, n_features) where NaN indicates missing observations.

Returns

SparseObservations

Sparse representation that holds only the valid (non-NaN) observations, sorted by feature, with pre-computed per-feature slices for compatibility with JAX JIT compilation.

Examples

>>> import numpy as np
>>> import jax.numpy as jnp
>>> from phasic.svgd import dense_to_sparse
>>> dense = jnp.array([
...     [1.0, np.nan, 3.0],
...     [np.nan, 2.0, np.nan],
...     [1.5, 2.5, 3.5]
... ])
>>> sparse = dense_to_sparse(dense)
>>> print(sparse.values)   # [1.0, 1.5, 2.0, 2.5, 3.0, 3.5]
>>> print(sparse.features) # [0, 0, 1, 1, 2, 2]
>>> print(sparse.n_features)  # 3
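The conversion described above can be sketched in plain NumPy. This is a minimal illustration under stated assumptions, not phasic's implementation: SparseObsSketch, dense_to_sparse_sketch and the slices field are hypothetical names. The values, features and n_features fields mirror the example. The sketch assumes that "pre-computed slices" means one (start, stop) pair per feature, stored as Python ints.

```python
from typing import NamedTuple, Tuple

import numpy as np


class SparseObsSketch(NamedTuple):
    """Hypothetical stand-in for SparseObservations (field layout assumed)."""
    values: np.ndarray                    # valid observations, grouped by feature
    features: np.ndarray                  # column (feature) index of each value
    n_features: int                       # number of columns in the dense input
    slices: Tuple[Tuple[int, int], ...]   # hypothetical: (start, stop) per feature


def dense_to_sparse_sketch(data) -> SparseObsSketch:
    """Drop NaN entries and group the remaining values by feature (a sketch)."""
    dense = np.asarray(data, dtype=float)
    if dense.ndim != 2:
        raise ValueError("expected a 2D array of shape (n_times, n_features)")
    n_features = dense.shape[1]

    # Transpose so a row-major scan visits all of feature 0, then feature 1, ...
    # This groups values by feature and keeps time order within each feature.
    cols = dense.T
    mask = ~np.isnan(cols)
    values = cols[mask]                 # boolean indexing scans in row-major order
    features = np.nonzero(mask)[0]      # row index of cols == feature index

    # Valid entries per feature -> cumulative offsets -> static slice bounds.
    counts = mask.sum(axis=1)
    offsets = np.concatenate([[0], np.cumsum(counts)])
    slices = tuple(
        (int(offsets[i]), int(offsets[i + 1])) for i in range(n_features)
    )
    return SparseObsSketch(values, features.astype(np.int32), n_features, slices)
```

With the dense array from the example, this sketch gives values [1.0, 1.5, 2.0, 2.5, 3.0, 3.5], features [0, 0, 1, 1, 2, 2] and slices ((0, 2), (2, 4), (4, 6)). Feature i's observations are values[start:stop] for its pair. Because the bounds are plain Python ints, the slice shapes are known at trace time, which is one plausible reading of "JAX JIT compatibility". A feature with no observations gets an empty (k, k) slice.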