Custom Aggregations

This notebook is motivated by a post on the Pangeo discourse forum.

The post asks: "Even better would be a command that lets me simply do the following."

A = da.groupby(['lon_bins', 'lat_bins']).mode()

This notebook will describe how to accomplish this using a custom Aggregation.

Tip

flox now supports mode, nanmode, quantile, nanquantile, median, and nanmedian, using exactly the same approach as shown below.

import numpy as np
import numpy_groupies as npg
import xarray as xr

import flox.xarray
from flox import Aggregation
from flox.aggregations import mean

# define latitude and longitude bins
binsize = 1.0  # 1°x1° bins
lon_min, lon_max, lat_min, lat_max = [-180, 180, -65, 65]
lon_bins = np.arange(lon_min, lon_max, binsize)
lat_bins = np.arange(lat_min, lat_max, binsize)
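These arrays hold bin edges, not bins. `np.arange` excludes its stop value, so the latitude edges run from -65 to 64 and the longitude edges from -180 to 179, and N edges bound N - 1 bins. That gives 129 × 359 bins (the shape of the median result at the end of this notebook), and points with latitude above 64 or longitude above 179 fall outside every bin. The `lat_edges_full` line is only an illustration of how one might include the upper limit as an edge.

```python
import numpy as np

binsize = 1.0
lon_min, lon_max, lat_min, lat_max = [-180, 180, -65, 65]

# np.arange stops *before* the stop value, so the last edge is max - binsize
lat_edges = np.arange(lat_min, lat_max, binsize)
lon_edges = np.arange(lon_min, lon_max, binsize)

# N edges define N - 1 intervals
n_lat_bins = len(lat_edges) - 1
n_lon_bins = len(lon_edges) - 1
print(lat_edges[0], lat_edges[-1], n_lat_bins)  # -65.0 64.0 129
print(lon_edges[0], lon_edges[-1], n_lon_bins)  # -180.0 179.0 359

# Illustration only: adding one more step makes lat_max itself an edge
lat_edges_full = np.arange(lat_min, lat_max + binsize, binsize)
print(len(lat_edges_full) - 1)  # 130
```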

size = 28397


da = xr.DataArray(
    np.random.randint(0, 7, size=size),
    dims="profile",
    coords={
        "lat": (
            "profile",
            (np.random.random(size) - 0.5) * (lat_max - lat_min),
        ),
        "lon": (
            "profile",
            (np.random.random(size) - 0.5) * (lon_max - lon_min),
        ),
    },
    name="label",
)
da
<xarray.DataArray 'label' (profile: 28397)> Size: 227kB
array([2, 2, 2, ..., 5, 2, 2])
Coordinates:
    lat      (profile) float64 227kB 38.18 41.5 52.25 ... -21.61 -52.45 -53.08
    lon      (profile) float64 227kB -128.7 -14.27 -15.08 ... -20.95 -143.9
Dimensions without coordinates: profile

A built-in reduction

First, a simple example of lat-lon binning using a built-in reduction: mean.

binned_mean = flox.xarray.xarray_reduce(
    da,
    da.lat,
    da.lon,
    func="mean",  # built-in
    expected_groups=(lat_bins, lon_bins),
    isbin=(True, True),
)
binned_mean.plot()
[Figure: output of binned_mean.plot(), the mean label in each 1°×1° lat-lon bin]
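To make the binning concrete, here is a rough pure-NumPy sketch of what a binned mean over two coordinates computes. It is not how flox is implemented, and `binned_mean_numpy` is a hypothetical helper. It assumes right-closed intervals `(left, right]`, matching interval labels such as `(-65.0, -64.0]` in the median output at the end of this notebook, and simply drops points that fall outside the edges.

```python
import numpy as np


def binned_mean_numpy(values, y, x, y_edges, x_edges):
    """Hypothetical helper: mean of `values` in each (y, x) bin, NaN where a bin is empty."""
    values, y, x = np.asarray(values, dtype=float), np.asarray(y), np.asarray(x)
    ny, nx = len(y_edges) - 1, len(x_edges) - 1
    # side="left" maps v to bin i when edges[i] < v <= edges[i + 1] (right-closed bins)
    iy = np.searchsorted(y_edges, y, side="left") - 1
    ix = np.searchsorted(x_edges, x, side="left") - 1
    # drop points outside the outermost edges in either dimension
    valid = (iy >= 0) & (iy < ny) & (ix >= 0) & (ix < nx)
    # flatten the (y-bin, x-bin) pair into one integer group code
    codes = np.ravel_multi_index((iy[valid], ix[valid]), (ny, nx))
    sums = np.bincount(codes, weights=values[valid], minlength=ny * nx)
    counts = np.bincount(codes, minlength=ny * nx)
    means = np.full(ny * nx, np.nan)
    nonempty = counts > 0
    means[nonempty] = sums[nonempty] / counts[nonempty]
    return means.reshape(ny, nx)


# Two latitude bins (0, 1] and (1, 2], one longitude bin (0, 1]
lat = np.array([0.5, 0.2, 1.5, 5.0])  # the last point lies outside the latitude edges
lon = np.array([0.5, 0.7, 0.5, 0.5])
result = binned_mean_numpy([1.0, 3.0, 5.0, 7.0], lat, lon, [0.0, 1.0, 2.0], [0.0, 1.0])
print(result.tolist())  # [[2.0], [5.0]]
```

Flattening the two bin indices into a single code with `np.ravel_multi_index` is what lets one 1-D grouped reduction (here `np.bincount`) handle a 2-D grid of bins.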

Aggregations

flox knows how to interpret func="mean" because it has been implemented in aggregations.py as an Aggregation.

An Aggregation is a blueprint for computing an aggregation on both numpy and dask data.

print(type(mean))
mean
<class 'flox.aggregations.Aggregation'>
'mean', fill: dict_values([<NA>, (0, 0)]), dtype: None
chunk: ('sum', 'nanlen')
combine: ('sum', 'sum')
finalize: <function _mean_finalize at 0x7f093d17d080>
min_count: 0

Here is how the mean Aggregation is created:

mean = Aggregation(
    name="mean",

    # strings in the following are built-in grouped reductions
    # implemented by the underlying "engine": flox, numpy_groupies, or numbagg

    # for pure numpy inputs
    numpy="mean",

    # The next three are for dask inputs and describe how to reduce
    # the data in parallel
    chunk=("sum", "nanlen"),  # first compute these blockwise: (grouped_sum, grouped_count)
    combine=("sum", "sum"),  # reduce intermediate results (sum the sums, sum the counts)
    finalize=lambda sum_, count: sum_ / count,  # final mean value (divide sum by count)

    fill_value=(0, 0),  # fill value for intermediate sums and counts when groups have no members
    dtypes=(None, np.intp),  # optional dtypes for intermediates
    final_dtype=np.floating,  # final dtype for output
)
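To see why the chunk, combine, and finalize recipe gives the same answer as a direct mean, it can be mimicked in plain NumPy. The sketch below splits a 1-D array into pieces (standing in for dask chunks), computes a grouped sum and count per piece with `np.bincount` (the chunk step), adds those intermediates together (combine), and divides (finalize). It illustrates the idea only, not flox's dask code path, and `chunked_grouped_mean` is a hypothetical name.

```python
import numpy as np


def chunked_grouped_mean(values, codes, size, nchunks):
    """Hypothetical helper: grouped mean via chunk -> combine -> finalize."""
    # chunk: blockwise grouped sum and grouped count for every piece
    partials = [
        (np.bincount(c, weights=v, minlength=size), np.bincount(c, minlength=size))
        for v, c in zip(np.array_split(values, nchunks), np.array_split(codes, nchunks))
    ]
    # combine: sum the sums and sum the counts across pieces
    total_sum = sum(s for s, _ in partials)
    total_count = sum(n for _, n in partials)
    # finalize: divide sum by count; groups with no members anywhere stay NaN
    out = np.full(size, np.nan)
    has_data = total_count > 0
    out[has_data] = total_sum[has_data] / total_count[has_data]
    return out


rng = np.random.default_rng(0)
values = rng.random(1000)
codes = rng.integers(0, 5, size=1000)  # group 5 deliberately has no members
chunked = chunked_grouped_mean(values, codes, size=6, nchunks=7)
direct = np.array([values[codes == g].mean() for g in range(5)])
print(np.allclose(chunked[:5], direct), np.isnan(chunked[5]))  # True True
```

Carrying (sum, count) pairs instead of per-chunk means is what keeps the combine step exact: chunks hold different numbers of members per group, so averaging chunk means would weight them incorrectly.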

Defining a custom aggregation

First we’ll need a function that executes the grouped reduction given numpy inputs.

Custom functions are required to have this signature (copied from numpy_groupies):


def custom_grouped_reduction(
    group_idx, array, *, axis=-1, size=None, fill_value=None, dtype=None
):
    """
    Parameters
    ----------

    group_idx : np.ndarray, 1D
        integer codes for group labels (1D)
    array : np.ndarray, nD
        values to reduce (nD)
    axis : int
        axis of array along which to reduce. Requires array.shape[axis] == len(group_idx)
    size : int, optional
        expected number of groups. If None, output.shape[-1] == number of uniques in group_idx
    fill_value : optional
        fill value for groups with no members, i.e. when the number of groups in group_idx is less than size
    dtype : optional
        dtype of output

    Returns
    -------

    np.ndarray with output.shape[-1] == size, containing a single value per group
    """
    pass
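To make this contract concrete, here is a sketch of the mode the forum post asked for, written against the signature in plain NumPy. The function is hypothetical and narrower than the contract: it assumes `array` is 1-D and holds small non-negative integers (like the `label` values 0 to 6 in this notebook), so `axis` has nothing to choose between, and ties go to the smallest value because `argmax` returns the first maximum.

```python
import numpy as np


def grouped_mode(group_idx, array, *, axis=-1, size=None, fill_value=None, dtype=None):
    """Hypothetical mode reduction following the custom-function signature.

    Assumes a 1-D `array` of small non-negative integers; ties go to the smallest value.
    With fill_value=None, empty groups report 0 (the argmax of an all-zero row).
    """
    group_idx = np.asarray(group_idx)
    array = np.asarray(array)
    if array.ndim != 1:
        raise NotImplementedError("this sketch only handles 1-D input")
    if size is None:
        size = int(group_idx.max()) + 1
    nvals = int(array.max()) + 1
    # count every (group, value) pair with a single flat bincount
    counts = np.bincount(group_idx * nvals + array, minlength=size * nvals)
    counts = counts.reshape(size, nvals)
    modes = counts.argmax(axis=1)
    empty = counts.sum(axis=1) == 0
    if fill_value is not None and empty.any():
        # promote (e.g. int -> float) so a fill value such as NaN fits
        modes = modes.astype(np.result_type(modes.dtype, np.asarray(fill_value).dtype))
        modes[empty] = fill_value
    return modes.astype(dtype) if dtype is not None else modes


codes = np.array([0, 0, 0, 2, 2, 2, 2])
labels = np.array([3, 3, 1, 5, 4, 4, 5])  # group 2 has a 4/5 tie
print(grouped_mode(codes, labels, size=4, fill_value=-1).tolist())  # [3, -1, 4, -1]
```

Counting every (group, value) pair at once avoids a Python loop over groups, but memory grows with `size` times the number of distinct values, so this approach only suits small integer label sets.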

Since numpy_groupies does not implement a median, we’ll write one ourselves by passing np.median to numpy_groupies.aggregate_numpy.aggregate. This loops over the groups and calls np.median on each group's members in serial. It is not fast, but it is convenient.

def grouped_median(group_idx, array, *, axis=-1, size=None, fill_value=None, dtype=None):
    return npg.aggregate_numpy.aggregate(
        group_idx,
        array,
        func=np.median,
        axis=axis,
        size=size,
        fill_value=fill_value,
        dtype=dtype,
    )
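Without numpy_groupies at hand, the same serial loop can be sketched in plain NumPy. The hypothetical `grouped_median_loop` below is not numpy_groupies' implementation: it moves the reduction axis to the end, calls `np.median` once per group present in `group_idx`, and leaves `fill_value` (NaN when none is given, an assumption of this sketch) in groups with no members.

```python
import numpy as np


def grouped_median_loop(group_idx, array, *, axis=-1, size=None, fill_value=None, dtype=None):
    """Hypothetical pure-NumPy analogue of grouped_median: one np.median call per group."""
    group_idx = np.asarray(group_idx)
    # put the axis being reduced last so each group is a selection along -1
    array = np.moveaxis(np.asarray(array), axis, -1)
    if size is None:
        size = int(group_idx.max()) + 1
    if fill_value is None:
        fill_value = np.nan
    out_dtype = dtype if dtype is not None else np.float64
    out = np.full(array.shape[:-1] + (size,), fill_value, dtype=out_dtype)
    for g in np.unique(group_idx):
        # serial loop: reduce the members of one group at a time
        out[..., g] = np.median(array[..., group_idx == g], axis=-1)
    return out


codes = np.array([0, 0, 1, 1, 1, 3])  # group 2 has no members
values = np.array([[1.0, 2.0, 10.0, 20.0, 30.0, 7.0],
                   [4.0, 6.0, 0.0, 5.0, 1.0, 9.0]])
print(grouped_median_loop(codes, values, size=4).tolist())
# [[1.5, 20.0, nan, 7.0], [5.0, 1.0, nan, 9.0]]
```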

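Now we create the Aggregation. Only the numpy path is defined; chunk and combine are None because we have not described how to compute a median in parallel across dask chunks.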

agg_median = Aggregation(
    name="median",
    numpy=grouped_median,
    fill_value=-1,
    chunk=None,
    combine=None,
)
agg_median
'median', fill: dict_values([<NA>, (-1,)]), dtype: None
chunk: (None,)
combine: (None,)
finalize: None
min_count: 0

And apply it!

flox.xarray.xarray_reduce(
    da,
    da.lat,
    da.lon,
    func=agg_median,
    expected_groups=(lat_bins, lon_bins),
    isbin=(True, True),
    fill_value=np.nan,
)
<xarray.DataArray 'label' (lat_bins: 129, lon_bins: 359)> Size: 370kB
array([[0. , 5. , 1. , ..., nan, 3.5, nan],
       [nan, nan, nan, ..., nan, 3.5, nan],
       [4. , nan, nan, ..., nan, 3. , nan],
       ...,
       [4. , 4. , nan, ..., nan, 4. , nan],
       [nan, nan, nan, ..., nan, 6. , nan],
       [nan, nan, 0. , ..., nan, nan, nan]])
Coordinates:
  * lat_bins  (lat_bins) object 1kB (-65.0, -64.0] ... (63.0, 64.0]
  * lon_bins  (lon_bins) object 3kB (-180.0, -179.0] ... (178.0, 179.0]