from __future__ import annotations

from collections.abc import Hashable, Iterable, Sequence
from typing import TYPE_CHECKING, Any, Union

import numpy as np
import pandas as pd
import xarray as xr
from packaging.version import Version
from xarray.core.duck_array_ops import _datetime_nanmin

from .aggregations import Aggregation, Dim, _atleast_1d, quantile_new_dims_func
from .core import (
    _convert_expected_groups_to_index,
    _get_expected_groups,
    _validate_expected_groups,
    groupby_reduce,
)
from .core import rechunk_for_blockwise as rechunk_array_for_blockwise
from .core import rechunk_for_cohorts as rechunk_array_for_cohorts
from .xrutils import _contains_cftime_datetimes, _to_pytimedelta, datetime_to_numeric

if TYPE_CHECKING:
    from xarray.core.types import T_DataArray, T_Dataset

    from .core import T_ExpectedGroupsOpt, T_ExpectIndex, T_ExpectOpt

    Dims = Union[str, Iterable[Hashable], None]

def _restore_dim_order(result, obj, by, no_groupby_reorder=False):
    """Transpose ``result`` so its dimensions follow the order they have in ``obj``.

    Parameters
    ----------
    result : xarray object
        Object whose dimensions are re-ordered; must provide ``dims`` and
        ``transpose``.
    obj : xarray object
        Template providing the reference order through ``dims`` and
        ``get_axis_num``.
    by : DataArray
        Grouping variable. When it is 1D, the result dimension named
        ``by.name`` is positioned where the single dimension of ``by`` sits
        in ``obj``.
    no_groupby_reorder : bool, default: False
        If True, the dimension named after a 1D ``by`` is placed first instead
        of at the position of the dimension it replaced.

    Returns
    -------
    The value of ``result.transpose(*new_order)``.
    """

    def lookup_order(dimension):
        # The group dimension stands in for the (single) dimension of ``by``.
        if dimension == by.name and by.ndim == 1:
            (dimension,) = by.dims
            if no_groupby_reorder:
                return -1e6  # some arbitrarily low value
        if dimension in obj.dims:
            axis = obj.get_axis_num(dimension)
        else:
            # Dimensions unknown to ``obj`` (i.e. new ones) are sorted to the end.
            axis = 1e6  # some arbitrarily high value
        return axis

    new_order = sorted(result.dims, key=lookup_order)
    return result.transpose(*new_order)

def _broadcast_size_one_dims(*arrays, core_dims):
    """Broadcast by adding size-1 dimensions in the right place.

    Workaround because apply_ufunc doesn't support this yet.

    Specialized to the groupby problem.

    Parameters
    ----------
    *arrays : array-like
        The first array is returned untouched. Every following array is
        transposed to the dimension order of the first one and gets a size-1
        axis inserted for each dimension of the first array it lacks.
    core_dims : sequence of sequence of hashable
        Dimension names for each entry of ``arrays``, in the same order.
        ``core_dims[0]`` defines the target dimension order.

    Returns
    -------
    list
        ``arrays[0]`` followed by the transposed and expanded remaining arrays.
    """
    array_dims = set(core_dims[0])
    broadcasted = [arrays[0]]
    for dims, array in zip(core_dims[1:], arrays[1:]):
        assert set(dims).issubset(array_dims)
        # Re-order the axes of ``array`` to follow the order in core_dims[0] ...
        order = [dims.index(d) for d in core_dims[0] if d in dims]
        array = array.transpose(*order)
        # ... then insert size-1 axes at the positions of the missing dimensions.
        axis = [core_dims[0].index(d) for d in core_dims[0] if d not in dims]
        broadcasted.append(np.expand_dims(array, axis))

    return broadcasted

def xarray_reduce(
    obj: T_Dataset | T_DataArray,
    *by: T_DataArray | Hashable,
    func: str | Aggregation,
    expected_groups: T_ExpectedGroupsOpt = None,
    isbin: bool | Sequence[bool] = False,
    sort: bool = True,
    dim: Dims | ellipsis = None,
    fill_value=None,
    dtype: np.typing.DTypeLike = None,
    method: str | None = None,
    engine: str | None = None,
    keep_attrs: bool | None = True,
    skipna: bool | None = None,
    min_count: int | None = None,
    reindex: bool | None = None,
    **finalize_kwargs,
):
    """GroupBy reduce operations on xarray objects using numpy-groupies

    Parameters
    ----------
    obj : DataArray or Dataset
        Xarray object to reduce
    *by : DataArray or iterable of str or iterable of DataArray
        Variables with which to group by ``obj``
    func : {"all", "any", "count", "sum", "nansum", "mean", "nanmean", \
            "max", "nanmax", "min", "nanmin", "argmax", "nanargmax", "argmin", "nanargmin", \
            "quantile", "nanquantile", "median", "nanmedian", "mode", "nanmode", \
            "first", "nanfirst", "last", "nanlast"} or Aggregation
        Single function name or an Aggregation instance
    expected_groups : str or sequence
        expected group labels corresponding to each `by` variable
    isbin : iterable of bool
        If True, corresponding entry in ``expected_groups`` are bin edges.
        If False, the entry in ``expected_groups`` is treated as a simple label.
    sort : (optional), bool
        Whether groups should be returned in sorted order. Only applies for dask
        reductions when ``method`` is not ``"map-reduce"``. For ``"map-reduce"``, the groups
        are always sorted.
    dim : hashable
        dimension name along which to reduce. If None, reduces across all
        dimensions of `by`
    fill_value
        Value used for missing groups in the output i.e. when one of the labels
        in ``expected_groups`` is not actually present in ``by``.
    dtype : data-type, optional
        DType for the output. Can be anything accepted by ``np.dtype``.
    method : {"map-reduce", "blockwise", "cohorts"}, optional
        Strategy for reduction of dask arrays only:
          * ``"map-reduce"``:
            First apply the reduction blockwise on ``array``, then
            combine a few neighbouring blocks, apply the reduction.
            Continue until finalizing. Usually, ``func`` will need
            to be an Aggregation instance for this method to work.
            Common aggregations are implemented.
          * ``"blockwise"``:
            Only reduce using blockwise and avoid aggregating blocks
            together. Useful for resampling-style reductions where group
            members are always together. If `by` is 1D,  `array` is automatically
            rechunked so that chunk boundaries line up with group boundaries
            i.e. each block contains all members of any group present
            in that block. For nD `by`, you must make sure that all members of a group
            are present in a single block.
          * ``"cohorts"``:
            Finds group labels that tend to occur together ("cohorts"),
            indexes out cohorts and reduces that subset using "map-reduce",
            repeat for all cohorts. This works well for many time groupings
            where the group labels repeat at regular intervals like 'hour',
            'month', 'dayofyear' etc. Optimize chunking ``array`` for this
            method by first rechunking using ``rechunk_for_cohorts``
            (for 1D ``by`` only).
    engine : {"flox", "numpy", "numba", "numbagg"}, optional
        Algorithm to compute the groupby reduction on non-dask arrays and on each dask chunk:
          * ``"numpy"``:
            Use the vectorized implementations in ``numpy_groupies.aggregate_numpy``.
            This is the default choice because it works for other array types.
          * ``"flox"``:
            Use an internal implementation where the data is sorted so that
            all members of a group occur sequentially, and then numpy.ufunc.reduceat
            is used for the reduction. This will fall back to ``numpy_groupies.aggregate_numpy``
            for a reduction that is not yet implemented.
          * ``"numba"``:
            Use the implementations in ``numpy_groupies.aggregate_numba``.
          * ``"numbagg"``:
            Use the reductions supported by ``numbagg.grouped``. This will fall back to
            ``numpy_groupies.aggregate_numpy`` for a reduction that is not yet implemented.
    keep_attrs : bool, optional
        Preserve attrs?
    skipna : bool, optional
        If True, skip missing values (as marked by NaN). By default, only
        skips missing values for float dtypes; other dtypes either do not
        have a sentinel missing value (int) or ``skipna=True`` has not been
        implemented (object, datetime64 or timedelta64).
    min_count : int, default: None
        The required number of valid values to perform the operation. If
        fewer than min_count non-NA values are present the result will be
        NA. Only used if skipna is set to True or defaults to True for the
        array's dtype.
    reindex : bool, optional
        Whether to "reindex" the blockwise results to `expected_groups` (possibly automatically detected).
        If True, the intermediate result of the blockwise groupby-reduction has a value for all expected groups,
        and the final result is a simple reduction of those intermediates. In nearly all cases, this is a significant
        boost in computation speed. For cases like time grouping, this may result in large intermediates relative to the
        original block size. Avoid that by using method="cohorts". By default, it is turned off for arg reductions.
    **finalize_kwargs
        kwargs passed to the finalize function, like ``ddof`` for var, std or ``q`` for quantile.

    Returns
    -------
    DataArray or Dataset
        Reduced object

    See Also
    --------
    flox.core.groupby_reduce

    Raises
    ------
    NotImplementedError
    ValueError

    Examples
    --------
    >>> import xarray as xr
    >>> from flox.xarray import xarray_reduce

    >>> # Create a group index:
    >>> labels = xr.DataArray(
    ...     [1, 2, 3, 1, 2, 3, 0, 0, 0],
    ...     dims="x",
    ...     name="label",
    ... )
    >>> # Create a DataArray to apply the group index on:
    >>> da = xr.ones_like(labels)
    >>> # Sum all values in da that matches the elements in the group index:
    >>> xarray_reduce(da, labels, func="sum")
    <xarray.DataArray 'label' (label: 4)> Size: 32B
    array([3, 2, 2, 2])
    Coordinates:
      * label    (label) int64 32B 0 1 2 3
    """

    if skipna is not None and isinstance(func, Aggregation):
        raise ValueError("skipna must be None when func is an Aggregation.")

    nby = len(by)
    for b in by:
        if isinstance(b, xr.DataArray) and b.name is None:
            raise ValueError("Cannot group by unnamed DataArrays.")

    # TODO: move to GroupBy._flox_reduce
    if keep_attrs is None:
        keep_attrs = True

    if isinstance(isbin, Sequence):
        isbins = isbin
    else:
        isbins = (isbin,) * nby

    expected_groups_valid = _validate_expected_groups(nby, expected_groups)

    if not sort:
        raise NotImplementedError("sort must be True for xarray_reduce")

    # eventually drop the variables we are grouping by
    maybe_drop = {b for b in by if isinstance(b, Hashable)}
    unindexed_dims = tuple(
        b
        for b, isbin_ in zip(by, isbins)
        if isinstance(b, Hashable) and not isbin_ and b in obj.dims and b not in obj.indexes
    )

    by_da = tuple(obj[g] if isinstance(g, Hashable) else g for g in by)

    # Ordered union of the dimensions of all grouping variables.
    grouper_dims = []
    for g in by_da:
        for d in g.dims:
            if d not in grouper_dims:
                grouper_dims.append(d)

    if isinstance(obj, xr.Dataset):
        ds = obj
    else:
        ds = obj._to_temp_dataset()

    try:
        from xarray.indexes import PandasMultiIndex
    except ImportError:
        PandasMultiIndex = tuple()  # type: ignore[assignment, misc]

    # When grouping by a MultiIndex, also drop its level / dimension coordinates.
    more_drop = set()
    for var in maybe_drop:
        maybe_midx = ds._indexes.get(var, None)
        if isinstance(maybe_midx, PandasMultiIndex):
            idx_coord_names = set(tuple(maybe_midx.index.names) + (maybe_midx.dim,))
            idx_other_names = idx_coord_names - set(maybe_drop)
            more_drop.update(idx_other_names)
    maybe_drop.update(more_drop)

    if dim is Ellipsis:
        if nby > 1:
            raise NotImplementedError("Multiple by are not allowed when dim is Ellipsis.")
        name_ = by_da[0].name
        if name_ in ds.dims and not isbins[0]:
            dim_tuple = tuple(d for d in obj.dims if d != name_)
        else:
            dim_tuple = tuple(obj.dims)
    elif dim is not None:
        dim_tuple = _atleast_1d(dim)
    else:
        dim_tuple = tuple(grouper_dims)

    # broadcast to make sure grouper dimensions are present in the array.
    exclude_dims = tuple(d for d in ds.dims if d not in grouper_dims and d not in dim_tuple)

    if any(d not in grouper_dims and d not in obj.dims for d in dim_tuple):
        raise ValueError(f"Cannot reduce over absent dimensions {dim}.")

    try:
        xr.align(ds, *by_da, join="exact", copy=False)
    except ValueError as e:
        raise ValueError(
            "Object being grouped must be exactly aligned with every array in `by`."
        ) from e

    needs_broadcast = any(
        not set(grouper_dims).issubset(set(variable.dims)) for variable in ds.data_vars.values()
    )
    if needs_broadcast:
        ds_broad = xr.broadcast(ds, *by_da, exclude=exclude_dims)[0]
    else:
        ds_broad = ds

    dims_not_in_groupers = tuple(d for d in dim_tuple if d not in grouper_dims)
    if dims_not_in_groupers == tuple(dim_tuple) and not any(isbins):
        # reducing along a dimension along which groups do not vary
        # This is really just a normal reduction.
        # This is not right when binning so we exclude.
        if isinstance(func, str) and func.startswith("nan"):
            raise ValueError(f"Specify func={func[3:]}, skipna=True instead of func={func}")
        elif isinstance(func, Aggregation):
            raise NotImplementedError(
                "func must be a string when reducing along a dimension not present in `by`"
            )
        # skipna is not supported for all reductions
        kwargs = {"skipna": skipna} if skipna is not None else {}
        kwargs.update(finalize_kwargs)
        result = getattr(ds_broad, func)(dim=dim_tuple, **kwargs)
        if isinstance(obj, xr.DataArray):
            return obj._from_temp_dataset(result)
        else:
            return result

    ds = ds.drop_vars([var for var in maybe_drop if var in ds.variables])

    axis = tuple(range(-len(dim_tuple), 0))

    # Set expected_groups and convert to index since we need coords, sizes
    # for output xarray objects
    expected_groups_valid_list: list[T_ExpectIndex] = []
    group_names: tuple[Any, ...] = ()
    group_sizes: dict[Any, int] = {}
    for idx, (b_, expect, isbin_) in enumerate(zip(by_da, expected_groups_valid, isbins)):
        group_name = (
            f"{b_.name}_bins" if isbin_ or isinstance(expect, pd.IntervalIndex) else b_.name
        )
        group_names += (group_name,)

        if isbin_ and isinstance(expect, int):
            raise NotImplementedError(
                "flox does not support binning into an integer number of bins yet."
            )

        expect1: T_ExpectOpt
        if expect is None:
            if isbin_:
                raise ValueError(
                    f"Please provided bin edges for group variable {idx} "
                    f"named {group_name} in expected_groups."
                )
            expect1 = _get_expected_groups(b_.data, sort=sort)
        else:
            expect1 = expect
        expect_index = _convert_expected_groups_to_index((expect1,), (isbin_,), sort=sort)[0]

        # The if-check is for type hinting mainly, it narrows down the return
        # type of _convert_expected_groups_to_index to pure pd.Index:
        if expect_index is not None:
            expected_groups_valid_list.append(expect_index)
            group_sizes[group_name] = len(expect_index)
        else:
            # This will never be reached
            raise ValueError("expect_index cannot be None")

    def wrapper(array, *by, func, skipna, core_dims, **kwargs):
        array, *by = _broadcast_size_one_dims(array, *by, core_dims=core_dims)

        # Handle skipna here because I need to know dtype to make a good default choice.
        # We cannot handle this easily for xarray Datasets in xarray_reduce
        if skipna and func in ["all", "any", "count"]:
            raise ValueError(f"skipna cannot be truthy for {func} reductions.")

        if skipna or (skipna is None and isinstance(func, str) and array.dtype.kind in "cfO"):
            if "nan" not in func and func not in ["all", "any", "count"]:
                func = f"nan{func}"

        # Flox's count works with non-numeric and its faster than converting.
        requires_numeric = func not in ["count", "any", "all"] or (
            func == "count" and kwargs["engine"] != "flox"
        )
        if requires_numeric:
            is_npdatetime = array.dtype.kind in "Mm"
            is_cftime = _contains_cftime_datetimes(array)
            if is_npdatetime:
                offset = _datetime_nanmin(array)
                # xarray always uses np.datetime64[ns] for np.datetime64 data
                dtype = "timedelta64[ns]"
                array = datetime_to_numeric(array, offset)
            elif is_cftime:
                offset = array.min()
                array = datetime_to_numeric(array, offset, datetime_unit="us")

        result, *groups = groupby_reduce(array, *by, func=func, **kwargs)

        # Transpose the new quantile dimension to the end. This is ugly.
        # but new core dimensions are expected at the end :/
        # but groupby_reduce inserts them at the beginning
        if func in ["quantile", "nanquantile"]:
            (newdim,) = quantile_new_dims_func(**finalize_kwargs)
            if not newdim.is_scalar:
                # NOTE: _restore_dim_order will move any new dims to the end anyway.
                # This transpose simply makes it easy to specify output_core_dims
                # output dim order: (*broadcast_dims, *group_dims, quantile_dim)
                result = np.moveaxis(result, 0, -1)

        # Output of count has an int dtype.
        if requires_numeric and func != "count":
            if is_npdatetime:
                return result.astype(dtype) + offset
            elif is_cftime:
                return _to_pytimedelta(result, unit="us") + offset

        return result

    # These data variables do not have any of the core dimension,
    # take them out to prevent errors.
    # apply_ufunc can handle non-dim coordinate variables without core dimensions
    missing_dim = {}
    if isinstance(obj, xr.Dataset):
        # broadcasting means the group dim gets added to ds, so we check the original obj
        for k, v in obj.data_vars.items():
            is_missing_dim = not (any(d in v.dims for d in dim_tuple))
            if is_missing_dim:
                missing_dim[k] = v

    # dim_tuple contains dimensions we are reducing over. These need to be the last
    # core dimensions to be synchronized with axis.
    input_core_dims = [[d for d in grouper_dims if d not in dim_tuple] + list(dim_tuple)]
    input_core_dims += [list(b.dims) for b in by_da]

    newdims: tuple[Dim, ...] = (
        quantile_new_dims_func(**finalize_kwargs) if func in ["quantile", "nanquantile"] else ()
    )

    output_core_dims = [d for d in input_core_dims[0] if d not in dim_tuple]
    output_core_dims.extend(group_names)
    vector_dims = [dim.name for dim in newdims if not dim.is_scalar]
    output_core_dims.extend(vector_dims)

    # NOTE: this aliases (and mutates) group_sizes, so sizes added here for new
    # dimensions are also visible where group_sizes is read further below.
    output_sizes = group_sizes
    output_sizes.update({dim.name: dim.size for dim in newdims if dim.size != 0})

    actual = xr.apply_ufunc(
        wrapper,
        ds_broad.drop_vars(tuple(missing_dim)).transpose(..., *grouper_dims),
        *by_da,
        input_core_dims=input_core_dims,
        # for xarray's test_groupby_duplicate_coordinate_labels
        exclude_dims=set(dim_tuple),
        output_core_dims=[output_core_dims],
        dask="allowed",
        dask_gufunc_kwargs=dict(
            output_sizes=output_sizes, output_dtypes=[dtype] if dtype is not None else None
        ),
        keep_attrs=keep_attrs,
        kwargs={
            "func": func,
            "axis": axis,
            "sort": sort,
            "fill_value": fill_value,
            "method": method,
            "min_count": min_count,
            "skipna": skipna,
            "engine": engine,
            "reindex": reindex,
            "expected_groups": tuple(expected_groups_valid_list),
            "isbin": isbins,
            "finalize_kwargs": finalize_kwargs,
            "dtype": dtype,
            "core_dims": input_core_dims,
        },
    )

    # restore non-dim coord variables without the core dimension
    # TODO: shouldn't apply_ufunc handle this?
    for var in set(ds_broad._coord_names) - set(ds_broad._indexes) - set(ds_broad.dims):
        if all(d not in ds_broad[var].dims for d in dim_tuple):
            actual[var] = ds_broad[var]

    for newdim in newdims:
        actual.coords[newdim.name] = newdim.values if newdim.is_scalar else np.array(newdim.values)

    expect3: T_ExpectIndex | np.ndarray
    for name, expect2, by_ in zip(group_names, expected_groups_valid_list, by_da):
        # Can't remove this until xarray handles IntervalIndex:
        if isinstance(expect2, pd.IntervalIndex):
            # TODO: Only place where expect3 is an ndarray, remove the type if xarray
            # starts supporting IntervalIndex.
            expect3 = expect2.to_numpy()
        else:
            expect3 = expect2
        if isinstance(actual, xr.Dataset) and name in actual:
            actual = actual.drop_vars(name)
        # When grouping by MultiIndex, expect is an pd.Index wrapping
        # an object array of tuples
        if (
            name in ds_broad.indexes
            and isinstance(ds_broad.indexes[name], pd.MultiIndex)
            and not isinstance(expect3, pd.RangeIndex)
        ):
            levelnames = ds_broad.indexes[name].names
            if isinstance(expect3, np.ndarray):
                # TODO: workaround for IntervalIndex issue.
                raise NotImplementedError
            expect3 = pd.MultiIndex.from_tuples(expect3.values, names=levelnames)
            actual[name] = expect3
            if Version(xr.__version__) > Version("2022.03.0"):
                actual = actual.set_coords(levelnames)
        else:
            actual[name] = expect3
        if keep_attrs:
            actual[name].attrs = by_.attrs

    if unindexed_dims:
        actual = actual.drop_vars(unindexed_dims)

    if nby == 1:
        for var in actual:
            if isinstance(obj, xr.Dataset):
                template = obj[var]
            else:
                template = obj

            if actual[var].ndim > 1 + len(vector_dims):
                no_groupby_reorder = isinstance(
                    obj, xr.Dataset
                )  # do not re-order dataarrays inside datasets
                actual[var] = _restore_dim_order(
                    actual[var], template, by_da[0], no_groupby_reorder=no_groupby_reorder
                )

    if missing_dim:
        for k, v in missing_dim.items():
            missing_group_dims = {d: size for d, size in group_sizes.items() if d not in v.dims}
            # The expand_dims is for backward compat with xarray's questionable behaviour
            if missing_group_dims:
                actual[k] = v.expand_dims(missing_group_dims).variable
            else:
                actual[k] = v.variable

    if isinstance(obj, xr.DataArray):
        return obj._from_temp_dataset(actual)
    else:
        return actual
def rechunk_for_cohorts(
    obj: T_DataArray | T_Dataset,
    dim: str,
    labels: T_DataArray,
    force_new_chunk_at,
    chunksize: int | None = None,
    ignore_old_chunks: bool = False,
    debug: bool = False,
):
    """
    Rechunks array so that each new chunk contains groups that always occur together.

    Parameters
    ----------
    obj : DataArray or Dataset
        array to rechunk
    dim : str
        Dimension to rechunk
    labels : DataArray
        1D Group labels to align chunks with. This routine works
        well when ``labels`` has repeating patterns: e.g.
        ``1, 2, 3, 1, 2, 3, 4, 1, 2, 3`` though there is no requirement
        that the pattern must contain sequences.
    force_new_chunk_at : Sequence
        Labels at which we always start a new
        chunk. For the example ``labels`` array, this would be `1`.
    chunksize : int, optional
        nominal chunk size. Chunk size is exceeded when the label
        in ``force_new_chunk_at`` is less than ``chunksize//2`` elements away.
        If None, uses median chunksize along ``dim``.
    ignore_old_chunks : bool, optional
        Forwarded unchanged to the array-level ``rechunk_for_cohorts``
        routine; see that function for its meaning.
    debug : bool, optional
        Forwarded unchanged to the array-level ``rechunk_for_cohorts``
        routine; see that function for its meaning.

    Returns
    -------
    DataArray or Dataset
        Xarray object with rechunked arrays.
    """
    return _rechunk(
        rechunk_array_for_cohorts,
        obj,
        dim,
        labels,
        force_new_chunk_at=force_new_chunk_at,
        chunksize=chunksize,
        ignore_old_chunks=ignore_old_chunks,
        debug=debug,
    )
def rechunk_for_blockwise(obj: T_DataArray | T_Dataset, dim: str, labels: T_DataArray):
    """
    Rechunks array so that group boundaries line up with chunk boundaries, allowing
    embarrassingly parallel group reductions.

    This only works when the groups are sequential
    (e.g. labels = ``[0,0,0,1,1,1,1,2,2]``).
    Such patterns occur when using ``.resample``.

    Parameters
    ----------
    obj : DataArray or Dataset
        Array to rechunk
    dim : str
        Name of dimension to rechunk
    labels : DataArray
        Group labels

    Returns
    -------
    DataArray or Dataset
        Xarray object with rechunked arrays.
    """
    return _rechunk(rechunk_array_for_blockwise, obj, dim, labels)
def _rechunk(func, obj, dim, labels, **kwargs):
    """Common logic for rechunking xarray objects.

    Parameters
    ----------
    func : callable
        Array-level rechunking routine. It is called as
        ``func(data, axis=..., labels=labels.data, **kwargs)`` and its return
        value replaces the data of the rechunked variable.
    obj : DataArray or Dataset
        Object to rechunk. It is deep-copied first; the input is not modified.
    dim : hashable
        Name of the dimension to rechunk along; converted to an axis number
        with ``get_axis_num``.
    labels : DataArray
        Group labels; only ``labels.data`` is passed on to ``func``.
    **kwargs
        Extra keyword arguments forwarded to ``func``.

    Returns
    -------
    DataArray or Dataset
        Copy of ``obj``. Variables whose ``chunks`` attribute is None are left
        as they are.
    """
    obj = obj.copy(deep=True)

    if isinstance(obj, xr.Dataset):
        for var in obj:
            if obj[var].chunks is not None:
                obj[var] = obj[var].copy(
                    data=func(
                        obj[var].data,
                        axis=obj[var].get_axis_num(dim),
                        labels=labels.data,
                        **kwargs,
                    )
                )
    else:
        if obj.chunks is not None:
            obj = obj.copy(
                data=func(obj.data, axis=obj.get_axis_num(dim), labels=labels.data, **kwargs)
            )

    return obj