Strategies for climatology calculations

This notebook is motivated by this post on the Pangeo Discourse forum.

import dask.array
import pandas as pd
import xarray as xr

import flox
import flox.xarray

Let’s first create an example Xarray DataArray representing the OISST dataset, with chunk sizes matching those in the post.

oisst = xr.DataArray(
    dask.array.ones((14532, 720, 1440), chunks=(20, -1, -1)),
    dims=("time", "lat", "lon"),
    coords={"time": pd.date_range("1981-09-01 12:00", "2021-06-14 12:00", freq="D")},
    name="sst",
)
oisst
<xarray.DataArray 'sst' (time: 14532, lat: 720, lon: 1440)> Size: 121GB
dask.array<ones_like, shape=(14532, 720, 1440), dtype=float64, chunksize=(20, 720, 1440), chunktype=numpy.ndarray>
Coordinates:
  * time     (time) datetime64[ns] 116kB 1981-09-01T12:00:00 ... 2021-06-14T1...
Dimensions without coordinates: lat, lon

To account for Feb-29 being present in some years, we’ll construct the time vector to group by as a “mmm-dd” string.

See also

For more options, see this great website.

# "%b" is the (portable) abbreviated month name, so labels look like "Sep-01"
day = oisst.time.dt.strftime("%b-%d").rename("day")
day
<xarray.DataArray 'day' (time: 14532)> Size: 116kB
array(['Sep-01', 'Sep-02', 'Sep-03', ..., 'Jun-12', 'Jun-13', 'Jun-14'],
      dtype=object)
Coordinates:
  * time     (time) datetime64[ns] 116kB 1981-09-01T12:00:00 ... 2021-06-14T1...
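Why not group by the integer day of year instead? A quick pandas check (independent of flox, using the same date range as the example dataset; variable names are illustrative) shows that the integer day of year shifts by one after Feb-28 in leap years, whereas a “mmm-dd” label always refers to the same calendar day:

```python
import pandas as pd

# March 1st in a non-leap year and in a leap year
mar1_2019 = pd.Timestamp("2019-03-01")
mar1_2020 = pd.Timestamp("2020-03-01")

# Integer day-of-year shifts by one after Feb-28 in leap years...
doy = (mar1_2019.dayofyear, mar1_2020.dayofyear)  # (60, 61)
# ...while a "mmm-dd" label names the same calendar day in every year
labels = (mar1_2019.strftime("%b-%d"), mar1_2020.strftime("%b-%d"))  # ('Mar-01', 'Mar-01')
print(doy, labels)

# Over the full OISST time range there are 366 distinct labels, Feb-29 included
time = pd.date_range("1981-09-01 12:00", "2021-06-14 12:00", freq="D")
print(time.strftime("%b-%d").nunique())  # 366
```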

First, method="map-reduce"

method="map-reduce" doesn’t work well here: all 366 days are aggregated into a single ~3GB output chunk.

For this to work well, we’d want smaller chunks in space and bigger chunks in time.

flox.xarray.xarray_reduce(
    oisst,
    day,
    func="mean",
    method="map-reduce",
)
<xarray.DataArray 'sst' (day: 366, lat: 720, lon: 1440)> Size: 3GB
dask.array<transpose, shape=(366, 720, 1440), dtype=float64, chunksize=(366, 720, 1440), chunktype=numpy.ndarray>
Coordinates:
  * day      (day) object 3kB 'Apr-01' 'Apr-02' 'Apr-03' ... 'Sep-29' 'Sep-30'
Dimensions without coordinates: lat, lon
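The ~3GB figure follows from simple arithmetic on the array shape, assuming float64 data (8 bytes per element) as in the repr above. The variable names are illustrative:

```python
nt, nlat, nlon = 14532, 720, 1440
itemsize = 8  # bytes per float64 element

total = nt * nlat * nlon * itemsize        # whole input array
in_chunk = 20 * nlat * nlon * itemsize     # one input chunk: (20, 720, 1440)
out_chunk = 366 * nlat * nlon * itemsize   # single output chunk: (366, 720, 1440)

print(f"{total / 1e9:.0f} GB, {in_chunk / 1e6:.0f} MB, {out_chunk / 1e9:.2f} GB")
# 121 GB, 166 MB, 3.04 GB
```

The output chunk is roughly 18 times larger than an input chunk, which is why shrinking the spatial chunk size helps.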

Rechunking for map-reduce

We can split each chunk along the lon dimension (keeping lat whole) so that the output chunk sizes are more reasonable.

flox.xarray.xarray_reduce(
    oisst.chunk({"lat": -1, "lon": 120}),
    day,
    func="mean",
    method="map-reduce",
)
<xarray.DataArray 'sst' (day: 366, lat: 720, lon: 1440)> Size: 3GB
dask.array<transpose, shape=(366, 720, 1440), dtype=float64, chunksize=(366, 720, 120), chunktype=numpy.ndarray>
Coordinates:
  * day      (day) object 3kB 'Apr-01' 'Apr-02' 'Apr-03' ... 'Sep-29' 'Sep-30'
Dimensions without coordinates: lat, lon
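A back-of-the-envelope count, assuming the chunking shown above (20 steps in time, lat whole, lon split into 120-wide pieces), shows both the gain and the cost of this rechunk. The variable names are illustrative:

```python
import math

nt, nlat, nlon = 14532, 720, 1440
itemsize = 8  # bytes per float64 element

# Original chunking (20, 720, 1440): chunks only along time
n_chunks_before = math.ceil(nt / 20)
# Rechunked (20, 720, 120): every time chunk is also split into 1440 / 120 pieces
n_chunks_after = math.ceil(nt / 20) * math.ceil(nlon / 120)

# Output chunk is now (366, 720, 120)
out_chunk_bytes = 366 * nlat * 120 * itemsize

print(n_chunks_before, n_chunks_after, n_chunks_after // n_chunks_before)  # 727 8724 12
print(f"{out_chunk_bytes / 1e6:.0f} MB")  # 253 MB
```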

But what if we didn’t want to rechunk the dataset so drastically (note the roughly 10x increase in tasks: each time chunk is now split into 1440 / 120 = 12 pieces along lon)? For that, let’s try method="cohorts".

method="cohorts"

We can take advantage of patterns in the groups here (“day of year”). Specifically:

  1. The groups occur at an approximately periodic interval: 365 or 366 days.

  2. The chunk size of 20 is smaller than the period of 365 or 366 days. This means that to construct the mean for days 1-20, we only need the chunks that contain days 1-20.

This strategy is implemented as method="cohorts".
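To make the idea concrete, here is a deliberately naive sketch in plain NumPy. The helper `naive_cohorts` is hypothetical and is not flox’s `find_group_cohorts`, which is more involved; it only records which chunks each label occurs in and treats labels with identical chunk sets as one cohort. With a toy period of 10, a chunk size that divides the period gives a few large cohorts, while a chunk size of 3, which does not, leaves every label in its own cohort:

```python
from collections import defaultdict

import numpy as np


def naive_cohorts(labels, chunksize):
    """Hypothetical helper: group labels by the exact set of chunks they occur in."""
    labels = np.asarray(labels)
    # Chunk index of every element along the reduced axis
    chunk_ids = np.arange(labels.size) // chunksize
    chunks_per_label = defaultdict(set)
    for label, chunk in zip(labels.tolist(), chunk_ids.tolist()):
        chunks_per_label[label].add(chunk)
    # Labels that live in exactly the same chunks form one cohort
    cohorts = defaultdict(list)
    for label in sorted(chunks_per_label):
        cohorts[frozenset(chunks_per_label[label])].append(label)
    return list(cohorts.values())


# Three "years" of a toy period of 10 steps
labels = np.tile(np.arange(10), 3)

print(naive_cohorts(labels, chunksize=5))  # [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]
print(naive_cohorts(labels, chunksize=3))  # [[0], [1], ..., [9]]: out of phase
```

The second case is a toy analogue of the many single-entry cohorts in the OISST example.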

flox.xarray.xarray_reduce(
    oisst,
    day,
    func="mean",
    method="cohorts",
)
<xarray.DataArray 'sst' (day: 366, lat: 720, lon: 1440)> Size: 3GB
dask.array<transpose, shape=(366, 720, 1440), dtype=float64, chunksize=(11, 720, 1440), chunktype=numpy.ndarray>
Coordinates:
  * day      (day) object 3kB 'Apr-01' 'Apr-02' 'Apr-03' ... 'Sep-29' 'Sep-30'
Dimensions without coordinates: lat, lon

Out of the box, "cohorts" doesn’t work well for this problem either, because the period isn’t regular (365 vs. 366 days) and isn’t divisible by the chunk size. So the groups end up being “out of phase” (for a visual illustration, click here). Now we have the opposite problem: the output is split into chunks that are too small along the day dimension.
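To see the drift concretely, we can track where each Jan-01 falls inside its 20-day chunk, using only pandas and NumPy on the same date range (variable names are illustrative). Because 365 % 20 == 5 and 366 % 20 == 6, the offset moves by 5 or 6 days every year:

```python
import numpy as np
import pandas as pd

time = pd.date_range("1981-09-01 12:00", "2021-06-14 12:00", freq="D")
chunksize = 20

# Integer positions of every Jan-01 along the time axis (1982 through 2021)
jan1 = np.flatnonzero((time.month == 1) & (time.day == 1))
# Position of Jan-01 inside its 20-day chunk, year by year
offsets = jan1 % chunksize

print(offsets[:5])  # first five offsets: 2, 7, 12, 18, 3
```

Since Jan-01 lands at a different place within a chunk each year, it shares chunks with a different set of neighbouring days each year, so few days end up in exactly the same set of chunks.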

Let us inspect the cohorts:

# integer codes for each "day"
codes, _ = pd.factorize(day.data)
preferred_method, cohorts = flox.core.find_group_cohorts(
    labels=codes,
    chunks=(oisst.chunksizes["time"],),
)
print(len(cohorts))
158

Looking more closely, we can see many cohorts with a single entry.

cohorts.values()
dict_values([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 364], [10], [11], [12], [13], [14], [15], [16], [17], [18], [19], [20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30], [31], [32], [33], [34], [35], [36], [37], [38], [39], [40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50], [51], [52], [53], [54], [55], [56], [57], [58], [59], [60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70], [71], [72], [73], [74], [75], [76], [77], [78], [79], [80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90], [91], [92], [93], [94], [95], [96], [97], [98], [99], [100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110], [111], [112], [113], [114], [115], [116], [117], [118], [119], [120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130], [131], [132], [133], [134], [135], [136], [137], [138], [139], [140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150], [151], [152], [153], [154], [155], [156], [157], [158], [159], [160, 161, 162, 163], [164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 365], [175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185], [186, 187, 188, 189, 190, 191, 192], [193, 194], [195], [196], [197], [198], [199], [200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210], [211, 212], [213, 214], [215], [216], [217], [218], [219], [220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230], [231, 232, 233], [234], [235], [236], [237], [238], [239], [240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250], [251, 252, 253], [254], [255], [256], [257], [258], [259], [260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270], [271, 272, 273, 274], [275], [276], [277], [278], [279], [280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290], [291], [292, 293, 294], [295], [296], [297], [298], [299], [300, 301, 302], [303], [304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314], [315], [316], [317], [318], [319], [320, 321], [322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332], [333, 334], [335], [336], [337], [338], [339], [340, 341], [342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352], [353, 354], [355], [356], 
[357], [358], [359], [360], [361], [362], [363]])

Rechunking data for cohorts

Can we fix the “out of phase” problem by rechunking along time?

First, let’s look at the current chunk sizes along time:

oisst.chunksizes["time"][:10]
(20, 20, 20, 20, 20, 20, 20, 20, 20, 20)

We’ll rechunk so that each chunk holds exactly one calendar month. This is not too different from the current chunking (20 days), but it fixes our periodicity problem: chunk boundaries now fall on the same calendar dates every year.

# Number of days in each calendar month ("ME" = month-end frequency; "M" is deprecated)
newchunks = xr.ones_like(day).astype(int).resample(time="ME").count()
rechunked = oisst.chunk(time=tuple(newchunks.data))
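If you prefer not to go through xarray’s resample, the same month lengths can be computed with pandas alone. This is an alternative sketch, not what the notebook runs, and it assumes the same daily time axis:

```python
import pandas as pd

time = pd.date_range("1981-09-01 12:00", "2021-06-14 12:00", freq="D")

# Count the days in each calendar month ("M" is the period alias for months)
month_lengths = time.to_period("M").value_counts(sort=False).sort_index()
chunks = tuple(int(n) for n in month_lengths)

print(len(chunks), chunks[:3], chunks[-1])  # 478 (30, 31, 30) 14
```

Passing such a tuple as `oisst.chunk(time=chunks)` would give one chunk per calendar month; the first (Sep 1981) and last (Jun 1-14, 2021) months are partial or complete exactly as they appear in the data.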

Now each cohort contains more than one group, and the number of cohorts drops substantially, from 158 to 12.

preferred_method, new_cohorts = flox.core.find_group_cohorts(
    labels=codes,
    chunks=(rechunked.chunksizes["time"],),
)
# one cohort per month!
len(new_cohorts)
12
preferred_method
'cohorts'
new_cohorts.values()
dict_values([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29], [30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60], [61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90], [91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121], [122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152], [153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 365], [181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211], [212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, 240, 241], [242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272], [273, 274, 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, 296, 297, 298, 299, 300, 301, 302], [303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 316, 317, 318, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333], [334, 335, 336, 337, 338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353, 354, 355, 356, 357, 358, 359, 360, 361, 362, 363, 364]])

Now the groupby reduction looks OK in terms of number of tasks, but remember that rechunking to get to this point involves some communication overhead.

flox.xarray.xarray_reduce(rechunked, day, func="mean", method="cohorts")
<xarray.DataArray 'sst' (day: 366, lat: 720, lon: 1440)> Size: 3GB
dask.array<transpose, shape=(366, 720, 1440), dtype=float64, chunksize=(31, 720, 1440), chunktype=numpy.ndarray>
Coordinates:
  * day      (day) object 3kB 'Apr-01' 'Apr-02' 'Apr-03' ... 'Sep-29' 'Sep-30'
Dimensions without coordinates: lat, lon

flox’s heuristics will choose "cohorts" automatically!

flox.xarray.xarray_reduce(rechunked, day, func="mean")
<xarray.DataArray 'sst' (day: 366, lat: 720, lon: 1440)> Size: 3GB
dask.array<transpose, shape=(366, 720, 1440), dtype=float64, chunksize=(31, 720, 1440), chunktype=numpy.ndarray>
Coordinates:
  * day      (day) object 3kB 'Apr-01' 'Apr-02' 'Apr-03' ... 'Sep-29' 'Sep-30'
Dimensions without coordinates: lat, lon

How about other climatologies?

Let’s try a monthly climatology:

flox.xarray.xarray_reduce(oisst, oisst.time.dt.month, func="mean")
<xarray.DataArray 'sst' (month: 12, lat: 720, lon: 1440)> Size: 100MB
dask.array<transpose, shape=(12, 720, 1440), dtype=float64, chunksize=(1, 720, 1440), chunktype=numpy.ndarray>
Coordinates:
  * month    (month) int64 96B 1 2 3 4 5 6 7 8 9 10 11 12
Dimensions without coordinates: lat, lon

This looks great. Why?

It’s because each chunk (size 20) is smaller than the number of days in any month (28 to 31). flox first applies the groupby-reduction blockwise. A 20-day chunk can span at most 2 months, so this initial blockwise reduction is quite effective: at least a 10x reduction in size, from 20 elements in time to at most 2.

For this kind of problem, "map-reduce" works quite well.
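The blockwise-reduction argument can be checked directly by counting how many distinct groups each 20-day chunk contains, again with plain pandas and NumPy on the same date range (variable names are illustrative). Monthly labels give at most 2 groups per chunk, while “mmm-dd” labels give one group per element, so for the daily climatology the blockwise step does not shrink anything along time:

```python
import numpy as np
import pandas as pd

time = pd.date_range("1981-09-01 12:00", "2021-06-14 12:00", freq="D")
chunksize = 20
months = time.month.to_numpy()
days = np.asarray(time.strftime("%b-%d"))

starts = range(0, time.size, chunksize)
# Distinct groups each chunk contributes to the blockwise step
month_groups = [np.unique(months[i : i + chunksize]).size for i in starts]
day_groups = [np.unique(days[i : i + chunksize]).size for i in starts]
# Number of elements in each chunk (the last one is shorter)
chunk_lengths = [min(chunksize, time.size - i) for i in starts]

print(max(month_groups))            # 2: at most two months per 20-day chunk
print(day_groups == chunk_lengths)  # True: every element is its own "mmm-dd" group
```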