
Grouping, sampling and batching - custom collectors in Java 8

Continuing the first article, this time we will write some more useful custom collectors: for grouping by given criteria, sampling input, batching, and sliding over input with a fixed-size window.

Grouping (counting occurrences, histogram)

Imagine you have a collection of some items and you want to calculate how many times each item (with respect to equals()) appears in this collection. This can be achieved using CollectionUtils.getCardinalityMap() from Apache Commons Collections. This method takes an Iterable<T> and returns Map<T, Integer>, counting how many times each item appeared in the collection. However, sometimes instead of using equals() we would like to group by an arbitrary attribute of input T. For example, say we have a list of Person objects and we would like to compute the number of males vs. females (i.e. Map<Sex, Integer>) or maybe an age distribution. There is a built-in collector, Collectors.groupingBy(Function<T, K> classifier); however, it returns a map from each key to the list of all items mapped to that key, not to their count.
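For illustration, here is a small self-contained sketch contrasting the single-argument groupingBy() with the three-argument overload that takes a map factory and the counting() downstream collector. The Person class, the Sex enum and the sample data are hypothetical stand-ins, not the article's own types. Note that counting() produces Long values, so the result is Map<Sex, Long> rather than Map<Sex, Integer>; grouping by Function.identity() gives a cardinality map similar to getCardinalityMap(), again with Long counts.

```java
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Function;
import java.util.stream.Collectors;

public class GroupingBySexExample {

    // Hypothetical domain types, standing in for the article's Person class.
    enum Sex { MALE, FEMALE }

    static final class Person {
        private final String name;
        private final Sex sex;

        Person(String name, Sex sex) {
            this.name = name;
            this.sex = sex;
        }

        String getName() { return name; }
        Sex getSex() { return sex; }
    }

    // Single-argument groupingBy(): each key maps to the List of all matching people.
    static Map<Sex, List<Person>> peopleBySex(List<Person> people) {
        return people.stream().collect(Collectors.groupingBy(Person::getSex));
    }

    // Three-argument groupingBy(): classifier, map factory and downstream collector.
    // counting() reduces each group to its size, so the values are Longs.
    static Map<Sex, Long> countBySex(List<Person> people) {
        return people.stream()
                .collect(Collectors.groupingBy(Person::getSex, HashMap::new, Collectors.counting()));
    }

    // Counting by equals() itself, similar in spirit to CollectionUtils.getCardinalityMap().
    static <T> Map<T, Long> cardinality(List<T> items) {
        return items.stream()
                .collect(Collectors.groupingBy(Function.identity(), Collectors.counting()));
    }

    public static void main(String[] args) {
        List<Person> people = Arrays.asList(
                new Person("Alice", Sex.FEMALE),
                new Person("Bob", Sex.MALE),
                new Person("Carol", Sex.FEMALE));
        System.out.println(peopleBySex(people).get(Sex.FEMALE).size());    // 2
        System.out.println(countBySex(people).get(Sex.FEMALE));            // 2
        System.out.println(cardinality(Arrays.asList("a", "b", "a")).get("a")); // 2
    }
}
```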

An overloaded version of groupingBy() takes three parameters. The first one is the key (classifier) function, as previously. The second argument creates a new map; we'll see shortly why it's useful. The third, counting(), is a nested (downstream) collector that takes all people with the same sex and combines them together, in our case simply counting them as they arrive. Being able to choose the map implementation is useful, e.g. when building an age histogram. We would like to know how many people we have at a given age, but the age values should be sorted.
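A minimal sketch of such a sorted age histogram, assuming the ages are supplied as plain integers (with a Person class you would classify by something like a hypothetical Person::getAge instead). Passing TreeMap::new as the map factory keeps the keys in ascending order, and counting() again supplies the per-key totals.

```java
import java.util.Arrays;
import java.util.List;
import java.util.TreeMap;
import java.util.stream.Collectors;

public class AgeHistogram {

    // Groups ages into a histogram; TreeMap::new as the map factory keeps the keys sorted.
    static TreeMap<Integer, Long> histogram(List<Integer> ages) {
        return ages.stream()
                .collect(Collectors.groupingBy(age -> age, TreeMap::new, Collectors.counting()));
    }

    public static void main(String[] args) {
        TreeMap<Integer, Long> hist = histogram(Arrays.asList(42, 18, 42, 30, 18, 42));
        System.out.println(hist); // {18=2, 30=1, 42=3}
    }
}
```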

Passing TreeMap::new as the map factory, we end up with a TreeMap from age (sorted) to the number of people of that age.

Sampling, batching and sliding window

The IterableLike.sliding() method in Scala allows viewing a collection through a sliding, fixed-size window. This window starts at the beginning and in each iteration moves by a given number of items. Such functionality, missing in Java 8, enables several useful operators, like computing a moving average, splitting a big collection into batches (compare with Lists.partition() in Guava) or sampling every n-th element. We will implement a collector for Java 8 providing similar behaviour. Let's start with unit tests, which should briefly describe what we want to achieve.

Using data-driven tests in Spock, I managed to write almost 40 test cases in no time, succinctly describing all requirements. I hope these are clear to you, even if you haven't seen this syntax before. I already assumed the existence of handy factory methods, sliding(N) and sliding(N, step).
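To make the intended behaviour concrete in Java terms, here is a minimal list-based sketch that slides a window of size elements, moving by step each time. It works on a fully materialized List, so it is not a collector, and its edge-case rules are assumptions rather than the article's specification: every full window is returned, and a shorter trailing window is returned only when it contains elements that no earlier window covered. The class and method names are hypothetical.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

public class SlidingOverList {

    // Returns windows of up to `size` elements, each starting `step` elements after the previous one.
    // Assumed edge-case rule: every full window is returned; a shorter trailing window is returned
    // only if it contains elements that no earlier window covered.
    static <T> List<List<T>> sliding(List<T> input, int size, int step) {
        if (size <= 0 || step <= 0) {
            throw new IllegalArgumentException("size and step must be positive");
        }
        List<List<T>> windows = new ArrayList<>();
        int coveredUpTo = 0; // exclusive end index of the last returned window
        for (int start = 0; start < input.size(); start += step) {
            int end = Math.min(start + size, input.size());
            boolean partial = end - start < size;
            if (partial && end <= coveredUpTo) {
                break; // a shorter window with nothing new in it: stop
            }
            windows.add(new ArrayList<>(input.subList(start, end)));
            coveredUpTo = end;
        }
        return windows;
    }

    public static void main(String[] args) {
        List<Integer> xs = Arrays.asList(1, 2, 3, 4, 5, 6, 7);
        System.out.println(sliding(xs, 3, 1)); // [[1, 2, 3], [2, 3, 4], [3, 4, 5], [4, 5, 6], [5, 6, 7]]
        System.out.println(sliding(xs, 3, 3)); // [[1, 2, 3], [4, 5, 6], [7]]
        System.out.println(sliding(xs, 1, 3)); // [[1], [4], [7]]
    }
}
```

With size 3 and step 1 this yields moving windows, with size equal to step it yields batches, and with size 1 it samples every step-th element.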

The fact that collectors receive items one after another makes our job harder. Of course, first collecting the whole list and sliding over it would have been easier, but sort of wasteful. Let's build the result iteratively. I am not even pretending this task can be parallelized in general, so I'll leave combiner() unimplemented.
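Below is a minimal sketch of such a collector under stated assumptions; it is not the author's original code. The class name SlidingCollector, its nested State type and the static sliding() factories are hypothetical. It keeps a buffer of pending items, emits a copy whenever the buffer holds size items, then discards the oldest min(step, size) items and, when step is larger than size, skips the gap before the next window. combiner() is left unsupported and simply throws UnsupportedOperationException. The assumed finisher() rule is that a trailing partial window is emitted only if it contains items no earlier window included.

```java
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Deque;
import java.util.List;
import java.util.Set;
import java.util.function.BiConsumer;
import java.util.function.BinaryOperator;
import java.util.function.Function;
import java.util.function.Supplier;
import java.util.stream.Collector;

// Hypothetical sliding-window collector; edge-case semantics are assumptions.
public class SlidingCollector<T> implements Collector<T, SlidingCollector.State<T>, List<List<T>>> {

    // Mutable accumulation state: the pending window plus all windows emitted so far.
    static final class State<T> {
        final List<List<T>> windows = new ArrayList<>();
        final Deque<T> buffer = new ArrayDeque<>();
        int toSkip = 0; // items to drop before the next window starts (only when step > size)
        int fresh = 0;  // buffered items not yet included in any emitted window
    }

    private final int size;
    private final int step;

    public SlidingCollector(int size, int step) {
        if (size <= 0 || step <= 0) throw new IllegalArgumentException("size and step must be positive");
        this.size = size;
        this.step = step;
    }

    // Hypothetical factories: sliding(size) is shorthand for sliding(size, 1).
    public static <E> SlidingCollector<E> sliding(int size) { return new SlidingCollector<>(size, 1); }
    public static <E> SlidingCollector<E> sliding(int size, int step) { return new SlidingCollector<>(size, step); }

    @Override
    public Supplier<State<T>> supplier() { return State::new; }

    @Override
    public BiConsumer<State<T>, T> accumulator() {
        return (state, item) -> {
            if (state.toSkip > 0) { state.toSkip--; return; } // inside a gap between windows
            state.buffer.addLast(item);
            state.fresh++;
            if (state.buffer.size() == size) {
                state.windows.add(new ArrayList<>(state.buffer)); // emit a full window
                state.fresh = 0;
                // Slide forward: discard the oldest min(step, size) items...
                for (int i = 0; i < Math.min(step, size); i++) state.buffer.removeFirst();
                // ...and, if step exceeds size, skip the items between windows.
                state.toSkip = Math.max(0, step - size);
            }
        };
    }

    @Override
    public BinaryOperator<State<T>> combiner() {
        // Windows depend on element order across chunks, so parallel merging is not supported.
        return (a, b) -> { throw new UnsupportedOperationException("parallel streams not supported"); };
    }

    @Override
    public Function<State<T>, List<List<T>>> finisher() {
        return state -> {
            // Emit a trailing partial window only if it holds items no earlier window covered.
            if (state.fresh > 0) state.windows.add(new ArrayList<>(state.buffer));
            return state.windows;
        };
    }

    @Override
    public Set<Characteristics> characteristics() { return Collections.emptySet(); }

    public static void main(String[] args) {
        List<Integer> xs = Arrays.asList(1, 2, 3, 4, 5, 6, 7);
        List<List<Integer>> moving = xs.stream().collect(sliding(3));
        List<List<Integer>> batches = xs.stream().collect(sliding(3, 3));
        List<List<Integer>> samples = xs.stream().collect(sliding(1, 3));
        System.out.println(moving);  // [[1, 2, 3], [2, 3, 4], [3, 4, 5], [4, 5, 6], [5, 6, 7]]
        System.out.println(batches); // [[1, 2, 3], [4, 5, 6], [7]]
        System.out.println(samples); // [[1], [4], [7]]
    }
}
```

Buffering only the current window keeps memory proportional to size rather than to the whole input, which is the point of building the result iteratively.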

I spent quite some time writing this implementation, especially a correct finisher(), so don't be frightened. The crucial part is a buffer that collects items until it can form one sliding window. Then the "oldest" items are discarded and the window slides forward by step. I am not particularly happy with this implementation, but the tests are passing. sliding(N) (a synonym for sliding(N, 1)) allows calculating a moving average of N items. sliding(N, N) splits input into batches of size N. sliding(1, N) takes every N-th element (samples). I hope you'll find this collector useful, enjoy!
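As an illustration of the moving-average use case, here is a minimal sketch that applies the same buffer idea directly, without a collector: it keeps the last N values in a deque together with a running sum, so each new element costs constant time instead of re-summing the whole window. The class and method names are hypothetical, and this version only produces an average once N values have been buffered.

```java
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Deque;
import java.util.List;

public class MovingAverage {

    // Moving average over windows of n consecutive values (window size n, step 1).
    static List<Double> movingAverage(List<? extends Number> values, int n) {
        if (n <= 0) throw new IllegalArgumentException("n must be positive");
        Deque<Double> buffer = new ArrayDeque<>();
        List<Double> averages = new ArrayList<>();
        double sum = 0;
        for (Number v : values) {
            double x = v.doubleValue();
            buffer.addLast(x);
            sum += x;
            if (buffer.size() > n) {
                sum -= buffer.removeFirst(); // discard the oldest value
            }
            if (buffer.size() == n) {
                averages.add(sum / n); // a full window is available
            }
        }
        return averages;
    }

    public static void main(String[] args) {
        System.out.println(movingAverage(Arrays.asList(1, 2, 3, 4, 5), 3)); // [2.0, 3.0, 4.0]
    }
}
```

For very long inputs a running floating-point sum slowly accumulates rounding error; summing each window from scratch avoids that at the cost of O(N) work per element.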