Source code for flox.aggregations

from __future__ import annotations

import copy
import logging
import warnings
from dataclasses import dataclass
from functools import cached_property, partial
from typing import TYPE_CHECKING, Any, Callable, Literal, TypedDict

import numpy as np
from numpy.typing import ArrayLike, DTypeLike

from . import aggregate_flox, aggregate_npg, xrutils
from . import xrdtypes as dtypes

if TYPE_CHECKING:
    FuncTuple = tuple[Callable | str, ...]
    OptionalFuncTuple = tuple[Callable | str | None, ...]


logger = logging.getLogger("flox")


def _is_arg_reduction(func: str | Aggregation) -> bool:
    if isinstance(func, str) and func in ["argmin", "argmax", "nanargmax", "nanargmin"]:
        return True
    if isinstance(func, Aggregation) and func.reduction_type == "argreduce":
        return True
    return False


class AggDtypeInit(TypedDict):
    final: DTypeLike | None
    intermediate: tuple[DTypeLike, ...]


class AggDtype(TypedDict):
    user: DTypeLike | None
    final: np.dtype
    numpy: tuple[np.dtype | type[np.intp], ...]
    intermediate: tuple[np.dtype | type[np.intp], ...]


def get_npg_aggregation(func, *, engine):
    try:
        method_ = getattr(aggregate_npg, func)
        method = partial(method_, engine=engine)
    except AttributeError:
        aggregate = aggregate_npg._get_aggregate(engine).aggregate
        method = partial(aggregate, func=func)
    return method


def generic_aggregate(
    group_idx,
    array,
    *,
    engine: str,
    func: str,
    axis=-1,
    size=None,
    fill_value=None,
    dtype=None,
    **kwargs,
):
    if engine == "flox":
        try:
            method = getattr(aggregate_flox, func)
        except AttributeError:
            logger.debug(f"Couldn't find {func} for engine='flox'. Falling back to numpy")
            method = get_npg_aggregation(func, engine="numpy")

    elif engine == "numbagg":
        from . import aggregate_numbagg

        try:
            if "var" in func or "std" in func:
                ddof = kwargs.get("ddof", 0)
                if aggregate_numbagg.NUMBAGG_SUPPORTS_DDOF or (ddof == 0):
                    method = getattr(aggregate_numbagg, func)
                else:
                    logger.debug(f"numbagg too old for ddof={ddof}. Falling back to numpy")
                    method = get_npg_aggregation(func, engine="numpy")
            else:
                method = getattr(aggregate_numbagg, func)

        except AttributeError:
            logger.debug(f"Couldn't find {func} for engine='numbagg'. Falling back to numpy")
            method = get_npg_aggregation(func, engine="numpy")

    elif engine in ["numpy", "numba"]:
        method = get_npg_aggregation(func, engine=engine)

    else:
        raise ValueError(
            f"Expected engine to be one of ['flox', 'numpy', 'numba', 'numbagg']. Received {engine} instead."
        )

    group_idx = np.asarray(group_idx, like=array)

    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", r"All-NaN (slice|axis) encountered")
        result = method(
            group_idx, array, axis=axis, size=size, fill_value=fill_value, dtype=dtype, **kwargs
        )
    return result

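# Illustrative usage of ``generic_aggregate`` (a sketch, not the public flox
# API; engine="numpy" assumes numpy_groupies is installed):
#
#   >>> generic_aggregate(
#   ...     np.array([0, 0, 1]),
#   ...     np.array([1.0, 2.0, 4.0]),
#   ...     engine="numpy",
#   ...     func="sum",
#   ...     size=2,
#   ...     fill_value=0,
#   ... )
#   array([3., 4.])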

def _normalize_dtype(dtype: DTypeLike, array_dtype: np.dtype, fill_value=None) -> np.dtype:
    if dtype is None:
        dtype = array_dtype
    if dtype is np.floating:
        # mean, std, var always result in floating
        # but we preserve the array's dtype if it is floating
        if array_dtype.kind in "fcmM":
            dtype = array_dtype
        else:
            dtype = np.dtype("float64")
    elif not isinstance(dtype, np.dtype):
        dtype = np.dtype(dtype)
    if fill_value not in [None, dtypes.INF, dtypes.NINF, dtypes.NA]:
        dtype = np.result_type(dtype, fill_value)
    return dtype

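# For example (illustrative): mean/var/std use the ``np.floating`` sentinel as
# their ``final_dtype`` so integer inputs promote to float64 while floating
# inputs pass through unchanged:
#
#   >>> _normalize_dtype(np.floating, np.dtype("int64"))
#   dtype('float64')
#   >>> _normalize_dtype(np.floating, np.dtype("float32"))
#   dtype('float32')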

def _get_fill_value(dtype, fill_value):
    """Returns dtype appropriate infinity. Returns +Inf equivalent for None."""
    if fill_value == dtypes.INF or fill_value is None:
        return dtypes.get_pos_infinity(dtype, max_for_int=True)
    if fill_value == dtypes.NINF:
        return dtypes.get_neg_infinity(dtype, min_for_int=True)
    if fill_value == dtypes.NA:
        if np.issubdtype(dtype, np.floating):
            return np.nan
        # This is madness, but npg checks that fill_value is compatible
        # with array dtype even if the fill_value is never used.
        elif np.issubdtype(dtype, np.integer):
            return dtypes.get_neg_infinity(dtype, min_for_int=True)
        else:
            return None
    return fill_value

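# For example (illustrative; exact values come from the xrdtypes helpers):
# sentinels are replaced by dtype-appropriate values, with integer dtypes
# clamped to their min/max rather than +/-Inf:
#
#   >>> _get_fill_value(np.dtype("float64"), dtypes.NA)
#   nan
#   >>> _get_fill_value(np.dtype("int32"), dtypes.INF)
#   2147483647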

def _atleast_1d(inp, min_length: int = 1):
    if xrutils.is_scalar(inp):
        inp = (inp,) * min_length
    assert len(inp) >= min_length
    return inp


def returns_empty_tuple(*args, **kwargs):
    return ()


@dataclass
class Dim:
    values: ArrayLike
    name: str | None

    @cached_property
    def is_scalar(self) -> bool:
        return xrutils.is_scalar(self.values)

    @cached_property
    def size(self) -> int:
        return 0 if self.is_scalar else len(self.values)  # type: ignore[arg-type]

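# For example (illustrative): quantile adds a vector "quantile" dimension,
# while a scalar q yields a scalar Dim with size 0:
#
#   >>> Dim(values=(0.5, 0.85), name="quantile").size
#   2
#   >>> Dim(values=0.5, name="quantile").is_scalar
#   True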

class Aggregation:
    def __init__(
        self,
        name,
        *,
        numpy: str | FuncTuple | None = None,
        chunk: str | FuncTuple | None,
        combine: str | FuncTuple | None,
        preprocess: Callable | None = None,
        finalize: Callable | None = None,
        fill_value=None,
        final_fill_value=dtypes.NA,
        dtypes=None,
        final_dtype: DTypeLike | None = None,
        reduction_type: Literal["reduce", "argreduce"] = "reduce",
        new_dims_func: Callable | None = None,
    ):
        """
        Blueprint for computing grouped aggregations.

        See aggregations.py for examples on how to specify reductions.

        Attributes
        ----------
        name : str
            Name of reduction.
        numpy : str or callable, optional
            Reduction function applied to numpy inputs. This function
            should compute the grouped reduction and must have a specific signature.
            If a string, it must be a "native" reduction implemented by the backend
            engines (numpy_groupies, flox, numbagg). If None, will be set to ``name``.
        chunk : None or str or tuple of str or callable or tuple of callable
            For dask inputs only. Either a single function or a list of
            functions to be applied blockwise on the input dask array. If None, will
            raise an error for dask inputs.
        combine : None or str or tuple of str or callable or tuple of callable
            For dask inputs only. Functions applied when combining intermediate
            results from the blockwise stage (see ``chunk``). If None, will raise an
            error for dask inputs.
        finalize : callable
            For dask inputs only. Function that combines intermediate results to
            compute the final result.
        preprocess : callable
            For dask inputs only. Preprocess inputs before the ``chunk`` stage.
        reduction_type : {"reduce", "argreduce"}
            Type of reduction.
        fill_value : number or tuple(number), optional
            Value to use when a group has no members. If a single value, it will be
            converted to a tuple of the same length as ``chunk``. If appropriate,
            provide a different fill_value per reduction in ``chunk`` as a tuple.
        final_fill_value : optional
            fill_value for the final result.
        dtypes : DType or tuple(DType), optional
            dtypes for intermediate results. If a single value, it will be converted
            to a tuple of the same length as ``chunk``. If appropriate, provide a
            different dtype per reduction in ``chunk`` as a tuple.
        final_dtype : DType, optional
            DType for output. By default, uses dtype of array being reduced.
        new_dims_func : Callable
            Function that receives ``finalize_kwargs`` and returns a tuple of ``Dim``
            objects describing any new dimensions added by the reduction. For example,
            quantile with q=(0.5, 0.85) adds a new dimension of size 2.
        """
        self.name = name
        # preprocess before blockwise
        self.preprocess = preprocess
        # Use "chunk_reduce" or "chunk_argreduce"
        self.reduction_type = reduction_type
        self.numpy: FuncTuple = (numpy,) if numpy else (self.name,)
        # initialize blockwise reduction
        self.chunk: OptionalFuncTuple = _atleast_1d(chunk)
        # how to aggregate results after first round of reduction
        self.combine: OptionalFuncTuple = _atleast_1d(combine)
        # simpler reductions used with the "simple combine" algorithm
        self.simple_combine: OptionalFuncTuple = ()
        # finalize results (see mean)
        self.finalize: Callable | None = finalize

        self.fill_value = {}
        # This is used for the final reindexing
        self.fill_value[name] = final_fill_value
        # Aggregation.fill_value is used to reindex to group labels
        # at the *intermediate* step.
        # They should make sense when aggregated together with results from other blocks
        self.fill_value["intermediate"] = self._normalize_dtype_fill_value(fill_value, "fill_value")

        self.dtype_init: AggDtypeInit = {
            "final": final_dtype,
            "intermediate": self._normalize_dtype_fill_value(dtypes, "dtype"),
        }
        self.dtype: AggDtype = None  # type: ignore[assignment]

        # The following are set by _initialize_aggregation
        self.finalize_kwargs: dict[Any, Any] = {}
        self.min_count: int = 0
        self.new_dims_func: Callable = (
            returns_empty_tuple if new_dims_func is None else new_dims_func
        )

    @cached_property
    def new_dims(self) -> tuple[Dim]:
        return self.new_dims_func(**self.finalize_kwargs)

    @cached_property
    def num_new_vector_dims(self) -> int:
        return len(tuple(dim for dim in self.new_dims if not dim.is_scalar))

    def _normalize_dtype_fill_value(self, value, name):
        value = _atleast_1d(value)
        if len(value) == 1 and len(value) < len(self.chunk):
            value = value * len(self.chunk)
        if len(value) != len(self.chunk):
            raise ValueError(f"Bad {name} specified for Aggregation {self.name}.")
        return value

    def __dask_tokenize__(self):
        return (
            Aggregation,
            self.name,
            self.preprocess,
            self.reduction_type,
            self.numpy,
            self.chunk,
            self.combine,
            self.finalize,
            self.fill_value,
            self.dtype,
        )

    def __repr__(self) -> str:
        return "\n".join(
            (
                f"{self.name!r}, fill: {self.fill_value.values()!r}, dtype: {self.dtype}",
                f"chunk: {self.chunk!r}",
                f"combine: {self.combine!r}",
                f"finalize: {self.finalize!r}",
                f"min_count: {self.min_count!r}",
            )
        )

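# An Aggregation is specified as blockwise ``chunk`` reductions, ``combine``
# functions for the tree reduction, and an optional ``finalize`` step, as the
# built-in blueprints below demonstrate. A hypothetical grouped peak-to-peak
# reduction could be sketched the same way (illustrative only, not part of flox):
#
#   ptp_ = Aggregation(
#       "ptp",
#       chunk=("max", "min"),  # track per-block max and min
#       combine=("max", "min"),  # reduce intermediates across blocks
#       finalize=lambda mx, mn: mx - mn,  # peak-to-peak = max - min
#       fill_value=(dtypes.NINF, dtypes.INF),  # identities for max and min
#   )
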
count = Aggregation(
    "count",
    numpy="nanlen",
    chunk="nanlen",
    combine="sum",
    fill_value=0,
    final_fill_value=0,
    dtypes=np.intp,
    final_dtype=np.intp,
)

# note that the fill values are the result of np.func([np.nan, np.nan])
# final_fill_value is used for groups that don't exist. This is usually np.nan
sum_ = Aggregation("sum", chunk="sum", combine="sum", fill_value=0)
nansum = Aggregation("nansum", chunk="nansum", combine="sum", fill_value=0)
prod = Aggregation("prod", chunk="prod", combine="prod", fill_value=1, final_fill_value=1)
nanprod = Aggregation("nanprod", chunk="nanprod", combine="prod", fill_value=1)


def _mean_finalize(sum_, count):
    with np.errstate(invalid="ignore", divide="ignore"):
        return sum_ / count


mean = Aggregation(
    "mean",
    chunk=("sum", "nanlen"),
    combine=("sum", "sum"),
    finalize=_mean_finalize,
    fill_value=(0, 0),
    dtypes=(None, np.intp),
    final_dtype=np.floating,
)
nanmean = Aggregation(
    "nanmean",
    chunk=("nansum", "nanlen"),
    combine=("sum", "sum"),
    finalize=_mean_finalize,
    fill_value=(0, 0),
    dtypes=(None, np.intp),
    final_dtype=np.floating,
)


# TODO: fix this for complex numbers
def _var_finalize(sumsq, sum_, count, ddof=0):
    with np.errstate(invalid="ignore", divide="ignore"):
        result = (sumsq - (sum_**2 / count)) / (count - ddof)
        result[count <= ddof] = np.nan
    return result


def _std_finalize(sumsq, sum_, count, ddof=0):
    return np.sqrt(_var_finalize(sumsq, sum_, count, ddof))


# var, std always promote to float, so we set nan
var = Aggregation(
    "var",
    chunk=("sum_of_squares", "sum", "nanlen"),
    combine=("sum", "sum", "sum"),
    finalize=_var_finalize,
    fill_value=0,
    final_fill_value=np.nan,
    dtypes=(None, None, np.intp),
    final_dtype=np.floating,
)
nanvar = Aggregation(
    "nanvar",
    chunk=("nansum_of_squares", "nansum", "nanlen"),
    combine=("sum", "sum", "sum"),
    finalize=_var_finalize,
    fill_value=0,
    final_fill_value=np.nan,
    dtypes=(None, None, np.intp),
    final_dtype=np.floating,
)
std = Aggregation(
    "std",
    chunk=("sum_of_squares", "sum", "nanlen"),
    combine=("sum", "sum", "sum"),
    finalize=_std_finalize,
    fill_value=0,
    final_fill_value=np.nan,
    dtypes=(None, None, np.intp),
    final_dtype=np.floating,
)
nanstd = Aggregation(
    "nanstd",
    chunk=("nansum_of_squares", "nansum", "nanlen"),
    combine=("sum", "sum", "sum"),
    finalize=_std_finalize,
    fill_value=0,
    final_fill_value=np.nan,
    dtypes=(None, None, np.intp),
    final_dtype=np.floating,
)


min_ = Aggregation("min", chunk="min", combine="min", fill_value=dtypes.INF)
nanmin = Aggregation("nanmin", chunk="nanmin", combine="nanmin", fill_value=np.nan)
max_ = Aggregation("max", chunk="max", combine="max", fill_value=dtypes.NINF)
nanmax = Aggregation("nanmax", chunk="nanmax", combine="nanmax", fill_value=np.nan)

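# The chunk/combine/finalize split above can be traced by hand; a minimal
# sketch of how ``mean`` composes (illustrative, using the helper defined above):
#
#   >>> sums = np.array([3.0, 4.0])  # combined per-group "sum" intermediates
#   >>> counts = np.array([2, 1])    # combined per-group "nanlen" intermediates
#   >>> _mean_finalize(sums, counts)
#   array([1.5, 4. ])
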
def argreduce_preprocess(array, axis):
    """Returns a tuple of array, index along axis.

    Copied from dask.array.chunk.argtopk_preprocess
    """
    import dask.array
    import numpy as np

    # TODO: arg reductions along multiple axes seems weird.
    assert len(axis) == 1
    axis = axis[0]

    idx = dask.array.arange(array.shape[axis], chunks=array.chunks[axis], dtype=np.intp)
    # broadcast (TODO: is this needed?)
    idx = idx[tuple(slice(None) if i == axis else np.newaxis for i in range(array.ndim))]

    def _zip_index(array_, idx_):
        return (array_, idx_)

    return dask.array.map_blocks(
        _zip_index,
        array,
        idx,
        dtype=array.dtype,
        meta=array._meta,
        name="groupby-argreduce-preprocess",
    )


def _pick_second(*x):
    return x[1]


argmax = Aggregation(
    "argmax",
    preprocess=argreduce_preprocess,
    chunk=("max", "argmax"),  # order is important
    combine=("max", "argmax"),
    reduction_type="argreduce",
    fill_value=(dtypes.NINF, 0),
    final_fill_value=-1,
    finalize=_pick_second,
    dtypes=(None, np.intp),
    final_dtype=np.intp,
)

argmin = Aggregation(
    "argmin",
    preprocess=argreduce_preprocess,
    chunk=("min", "argmin"),  # order is important
    combine=("min", "argmin"),
    reduction_type="argreduce",
    fill_value=(dtypes.INF, 0),
    final_fill_value=-1,
    finalize=_pick_second,
    dtypes=(None, np.intp),
    final_dtype=np.intp,
)

nanargmax = Aggregation(
    "nanargmax",
    preprocess=argreduce_preprocess,
    chunk=("nanmax", "nanargmax"),  # order is important
    combine=("max", "argmax"),
    reduction_type="argreduce",
    fill_value=(dtypes.NINF, 0),
    final_fill_value=-1,
    finalize=_pick_second,
    dtypes=(None, np.intp),
    final_dtype=np.intp,
)

nanargmin = Aggregation(
    "nanargmin",
    preprocess=argreduce_preprocess,
    chunk=("nanmin", "nanargmin"),  # order is important
    combine=("min", "argmin"),
    reduction_type="argreduce",
    fill_value=(dtypes.INF, 0),
    final_fill_value=-1,
    finalize=_pick_second,
    dtypes=(None, np.intp),
    final_dtype=np.intp,
)

first = Aggregation("first", chunk=None, combine=None, fill_value=0)
last = Aggregation("last", chunk=None, combine=None, fill_value=0)
nanfirst = Aggregation("nanfirst", chunk="nanfirst", combine="nanfirst", fill_value=np.nan)
nanlast = Aggregation("nanlast", chunk="nanlast", combine="nanlast", fill_value=np.nan)

all_ = Aggregation(
    "all",
    chunk="all",
    combine="all",
    fill_value=True,
    final_fill_value=False,
    dtypes=bool,
    final_dtype=bool,
)
any_ = Aggregation(
    "any",
    chunk="any",
    combine="any",
    fill_value=False,
    final_fill_value=False,
    dtypes=bool,
    final_dtype=bool,
)

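# Argreductions carry a (value, index) pair through the tree reduction so that
# block offsets resolve correctly; only the index survives ``finalize``. A
# minimal sketch of that last step (illustrative):
#
#   >>> _pick_second(np.array([7.0]), np.array([3]))
#   array([3])
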
# Support statistical quantities only blockwise
# The parallel versions will be approximate and are hard to implement!
median = Aggregation(
    name="median", fill_value=dtypes.NA, chunk=None, combine=None, final_dtype=np.float64
)
nanmedian = Aggregation(
    name="nanmedian", fill_value=dtypes.NA, chunk=None, combine=None, final_dtype=np.float64
)


def quantile_new_dims_func(q) -> tuple[Dim]:
    return (Dim(name="quantile", values=q),)


quantile = Aggregation(
    name="quantile",
    fill_value=dtypes.NA,
    chunk=None,
    combine=None,
    final_dtype=np.float64,
    new_dims_func=quantile_new_dims_func,
)
nanquantile = Aggregation(
    name="nanquantile",
    fill_value=dtypes.NA,
    chunk=None,
    combine=None,
    final_dtype=np.float64,
    new_dims_func=quantile_new_dims_func,
)
mode = Aggregation(name="mode", fill_value=dtypes.NA, chunk=None, combine=None)
nanmode = Aggregation(name="nanmode", fill_value=dtypes.NA, chunk=None, combine=None)

aggregations = {
    "any": any_,
    "all": all_,
    "count": count,
    "sum": sum_,
    "nansum": nansum,
    "prod": prod,
    "nanprod": nanprod,
    "mean": mean,
    "nanmean": nanmean,
    "var": var,
    "nanvar": nanvar,
    "std": std,
    "nanstd": nanstd,
    "max": max_,
    "nanmax": nanmax,
    "min": min_,
    "nanmin": nanmin,
    "argmax": argmax,
    "nanargmax": nanargmax,
    "argmin": argmin,
    "nanargmin": nanargmin,
    "first": first,
    "nanfirst": nanfirst,
    "last": last,
    "nanlast": nanlast,
    "median": median,
    "nanmedian": nanmedian,
    "quantile": quantile,
    "nanquantile": nanquantile,
    "mode": mode,
    "nanmode": nanmode,
}


def _initialize_aggregation(
    func: str | Aggregation,
    dtype,
    array_dtype,
    fill_value,
    min_count: int,
    finalize_kwargs: dict[Any, Any] | None,
) -> Aggregation:
    if not isinstance(func, Aggregation):
        try:
            # TODO: need better interface
            # we set dtype, fillvalue on reduction later. so deepcopy now
            agg = copy.deepcopy(aggregations[func])
        except KeyError:
            raise NotImplementedError(f"Reduction {func!r} not implemented yet")
    elif isinstance(func, Aggregation):
        # TODO: test that func is a valid Aggregation
        agg = copy.deepcopy(func)
        func = agg.name
    else:
        raise ValueError("Bad type for func. Expected str or Aggregation")

    # np.dtype(None) == np.dtype("float64")!!!
    # so check for not None
    dtype_: np.dtype | None = (
        np.dtype(dtype) if dtype is not None and not isinstance(dtype, np.dtype) else dtype
    )

    final_dtype = _normalize_dtype(dtype_ or agg.dtype_init["final"], array_dtype, fill_value)
    agg.dtype = {
        "user": dtype,  # Save to automatically choose an engine
        "final": final_dtype,
        "numpy": (final_dtype,),
        "intermediate": tuple(
            (
                _normalize_dtype(int_dtype, np.result_type(array_dtype, final_dtype), int_fv)
                if int_dtype is None
                else np.dtype(int_dtype)
            )
            for int_dtype, int_fv in zip(
                agg.dtype_init["intermediate"], agg.fill_value["intermediate"]
            )
        ),
    }

    # Replace sentinel fill values according to dtype
    agg.fill_value["user"] = fill_value
    agg.fill_value["intermediate"] = tuple(
        _get_fill_value(dt, fv)
        for dt, fv in zip(agg.dtype["intermediate"], agg.fill_value["intermediate"])
    )
    agg.fill_value[func] = _get_fill_value(agg.dtype["final"], agg.fill_value[func])

    fv = fill_value if fill_value is not None else agg.fill_value[agg.name]
    if _is_arg_reduction(agg):
        # this allows us to unravel_index easily. we have to do that nearly every time.
        agg.fill_value["numpy"] = (0,)
    else:
        agg.fill_value["numpy"] = (fv,)

    if finalize_kwargs is not None:
        assert isinstance(finalize_kwargs, dict)
        agg.finalize_kwargs = finalize_kwargs
    # This is needed for the dask pathway.
    # Because we use intermediate fill_value since a group could be
    # absent in one block, but present in another block
    # We set it for numpy to get nansum, nanprod tests to pass
    # where the identity element is 0, 1
    if min_count > 0:
        agg.min_count = min_count
        agg.numpy += ("nanlen",)
        if agg.chunk != (None,):
            agg.chunk += ("nanlen",)
            agg.combine += ("sum",)
            agg.fill_value["intermediate"] += (0,)
        agg.fill_value["numpy"] += (0,)
        agg.dtype["intermediate"] += (np.intp,)
        agg.dtype["numpy"] += (np.intp,)
    else:
        agg.min_count = 0

    simple_combine: list[Callable | None] = []
    for combine in agg.combine:
        if isinstance(combine, str):
            if combine in ["nanfirst", "nanlast"]:
                simple_combine.append(getattr(xrutils, combine))
            else:
                simple_combine.append(getattr(np, combine))
        else:
            simple_combine.append(combine)

    agg.simple_combine = tuple(simple_combine)

    return agg
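
# Illustrative use of ``_initialize_aggregation`` (a sketch; the exact dtypes
# shown assume a 64-bit platform):
#
#   >>> agg = _initialize_aggregation(
#   ...     "mean",
#   ...     dtype=None,
#   ...     array_dtype=np.dtype("int64"),
#   ...     fill_value=None,
#   ...     min_count=0,
#   ...     finalize_kwargs=None,
#   ... )
#   >>> agg.dtype["final"], agg.dtype["intermediate"]
#   (dtype('float64'), (dtype('float64'), dtype('int64')))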