The Scala programming language is unusual among FP languages in
that it lets you choose your own balance between pure FP and filthy
imperative style. This is both good and bad, but one benefit is that when you
have an algorithm that doesn't immediately translate well into a functional
style, you can implement it imperatively and refactor it later to be
more functional.

This example is from about a year ago, and I'll take you through a few
steps of my refactor. I'm still very much a Scala beginner, so keep in
mind, there may be some mistakes.

The Problem, and the Algorithm

I was working on a proof-of-concept middleware TCP load balancer using
Netty and Zookeeper. One feature that was going to be important was the
ability to do "weighted round-robin"
load balancing. At my job, we deploy code to new servers rather than
upgrading code on existing servers. Then, we move new servers into the load
balancer and old ones out. With our current load balancing solution, there
wasn't an easy way to shift live traffic incrementally, so that was the motivation
behind weighted round robin.

With a little bit of research on Wikipedia, I found a pretty straightforward
algorithm used in packet switching. Then, I found the following C/pseudo-code implementation
from this site.

First crack at it

My first Scala implementation was almost exactly like this. I used vars for mutable
state and a recursive call to an inner function for the while loop.

    import scala.reflect.ClassTag
    import scala.annotation.tailrec

    type Weight = Int

    class WeightedRRIterator[A: ClassTag](val itemsToWeights: Map[A, Weight]) {
      val items: Array[A] = itemsToWeights.keys.toArray
      val total = items.length

      // Integer does not have gcd but BigInt does.
      lazy val gcdWeights = itemsToWeights.values.map(BigInt(_)).reduce(_.gcd(_)).toInt
      lazy val maxWeight = itemsToWeights.values.max

      private var i = -1
      private var currentWeight = 0

      def hasNext = total != 0

      def next: A = {
        if (!hasNext) throw new NoSuchElementException("Called next on empty iterator")

        // this inner function will get called at worst case, `items.size`
        // times for each next() call
        @tailrec
        def doNext(): A = {
          i = (i + 1) % total
          if (i == 0) {
            currentWeight -= gcdWeights
            if (currentWeight <= 0) currentWeight = maxWeight
          }
          if (itemsToWeights(items(i)) >= currentWeight) items(i)
          else doNext()
        }
        doNext()
      }
    }

    val iter = new WeightedRRIterator(Map("a" -> 5, "b" -> 3))
    for (i <- 1 to 20) { print(iter.next) }
    // prints "aaabababaaabababaaab"

Note

A few things unrelated to the algorithm: the [A: ClassTag] is basically saying that generic type A
must also pass along a ClassTag. This is an unfortunate side effect of Java
type erasure. It's needed because I'm creating an Array of
type A, which is of an unknown type at runtime. In Scala, this is all you need to do, but when
trying to use this class as a library from Java, I wasn't able to figure out how to pass the ClassTag
and the generated type signatures of the functions were pretty dense. Some other data structures don't have
this problem, but Array was chosen for constant time access since we're indexing by a counter.
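To make the ClassTag point concrete, here's a minimal sketch. The helper names asArray and asVector are mine, purely for illustration. Building an Array[A] needs the context bound, while a Vector[A] does not and still gives effectively constant-time indexing:

```scala
import scala.reflect.ClassTag

// Needs a ClassTag: the JVM has to know the element class to allocate the array.
// Dropping the context bound makes this fail to compile.
def asArray[A: ClassTag](items: Iterable[A]): Array[A] = items.toArray

// No ClassTag needed: Vector does not depend on the runtime element class.
def asVector[A](items: Iterable[A]): Vector[A] = items.toVector

val arr = asArray(List("a", "b"))
val vec = asVector(List("a", "b"))
```

Both arr(1) and vec(1) give "b"; only the first signature leaks a ClassTag to callers.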

The this.synchronized is the equivalent of the Java synchronized keyword and only important
if you plan on sharing the iterator between multiple threads (not recommended!).

The @tailrec annotation does not change the function's behavior; it only gives you a compile-time
error if the recursive call is not in tail position. It's not clear the function is any better than a while-loop,
but it does help us to avoid needing an explicit return.
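As a small illustration of what @tailrec checks, here's a sketch using a hand-rolled gcd (the names gcdOf, loop and sum are hypothetical, not from my iterator). The inner loop compiles because the recursive call is the last thing it does:

```scala
import scala.annotation.tailrec

def gcdOf(a: Int, b: Int): Int = {
  // The recursive call is in tail position, so the compiler turns it into a loop.
  @tailrec
  def loop(x: Int, y: Int): Int =
    if (y == 0) x else loop(y, x % y)
  loop(a, b)
}

// Annotating this one with @tailrec would be a compile error, because the
// addition still has to happen after the recursive call returns.
def sum(n: Int): Int = if (n == 0) 0 else n + sum(n - 1)

val g = gcdOf(20, gcdOf(10, 5)) // 5
```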

The idea of this algorithm is fairly simple, yet it's difficult to work out from the code.
One key to refactoring was to get a much more intuitive understanding of what the code
is actually doing. It becomes very clear when illustrated. Imagine we have weights of
[10, 5, 20] for hosts a, b, c, respectively. The gcd is 5, the max weight is 20,
and total (the number of items) is 3.

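| i | currentWeight | selected |
|---|---------------|----------|
| 0 | 20            | --       |
| 1 | 20            | --       |
| 2 | 20            | c        |
| 0 | 15            | --       |
| 1 | 15            | --       |
| 2 | 15            | c        |
| 0 | 10            | a        |
| 1 | 10            | --       |
| 2 | 10            | c        |
| 0 | 5             | a        |
| 1 | 5             | b        |
| 2 | 5             | c        |
| 0 | 20            | --       |
| 1 | 20            | --       |
| 2 | 20            | c        |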

... and so on.

So the gist of it is that i is an index that cycles through our array, while
currentWeight cycles from maxWeight down toward 0, decrementing by gcd. For each
value of currentWeight, we yield an item only if the weight of that item
is high enough. So for every full cycle of currentWeight (4 passes over the array),
a is selected 2 times, b is selected 1 time, and c is selected 4 times. This is what
our weights of 10, 5 and 20 reduce to, so it works!
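That claim is easy to check mechanically. Here's a minimal sketch, assuming the same hosts and weights as above (the names hosts, picks, counts and reduced are just for illustration), that walks one full cycle of currentWeight and compares the pick counts to the weights divided by the gcd:

```scala
val hosts = Vector("a" -> 10, "b" -> 5, "c" -> 20)
val gcdWeights = hosts.map { case (_, w) => BigInt(w) }.reduce(_ gcd _).toInt // 5
val maxWeight = hosts.map(_._2).max                                           // 20

// One full cycle of currentWeight: 20, 15, 10, 5. For each value,
// keep the hosts whose weight is high enough, in array order.
val picks = (maxWeight to gcdWeights by -gcdWeights).flatMap { currentWeight =>
  hosts.collect { case (host, weight) if weight >= currentWeight => host }
}
// picks: c, c, a, c, a, b, c (the "selected" column above, minus the misses)

val counts  = picks.groupBy(identity).map { case (host, hs) => host -> hs.size }
val reduced = hosts.map { case (host, w) => host -> w / gcdWeights }.toMap
// counts == reduced == Map(a -> 2, b -> 1, c -> 4)
```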

There are a few things that make this difficult to read from the code
(for me, anyway). One thing is we're indexing an array rather than using some kind
of iterator. This is a caveman style of iterating through arrays. It's error-prone
and low-level. The other problem, from an FP perspective, is those pesky vars.
Mutable data is shunned in functional programming.

Finally, the intention of the programmer is lost in the details of the implementation.
All the arithmetic and updating counters is a lot to keep track of. The code is not
very expressive.

Let's Get Functional

It turns out Scala for-comprehensions are a really nice way to express this.
The following is analogous to a nested for-loop in an imperative language:
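Here's a minimal sketch of such a comprehension, assuming the [10, 5, 20] weights from the table and an ascending range of weights. The names weights and onePass are illustrative, and I'm using a Seq of pairs so the order is predictable:

```scala
val weights = Seq("a" -> 10, "b" -> 5, "c" -> 20)
val gcdWeights = weights.map { case (_, w) => BigInt(w) }.reduce(_ gcd _).toInt // 5
val maxWeight = weights.map(_._2).max                                           // 20

val onePass = for {
  currentWeight  <- gcdWeights to maxWeight by gcdWeights // 5, 10, 15, 20
  (item, weight) <- weights                               // a, b, c for each weight
  if weight >= currentWeight                              // skip items that are too light
} yield item

val asString = onePass.mkString // "abcaccc"
```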

We're iterating over two sequences here: the outer one is a
scala.collection.immutable.Range(5, 10, 15, 20), and for each element of that,
we're iterating through ("a", "b", "c"). The if expression
filters out any iteration for which the item's weight isn't high enough. As you can
see, it's a much more natural translation of the above table to code.

There's one catch: the iterator runs once, but we want it to run forever. I was
actually stuck on this for a pretty long time, having several ideas, but
none of them very elegant. Finally, when playing with recursive Streams, I discovered
two great features of Scala's Iterator class:
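A sketch of how this can look, assuming the two features in question are Iterator.continually and flatten (only continually is named in the text; flatten is my assumption here). The sample weights are the ones from the table:

```scala
val weights = Seq("a" -> 10, "b" -> 5, "c" -> 20)
val gcdWeights = weights.map { case (_, w) => BigInt(w) }.reduce(_ gcd _).toInt
val maxWeight = weights.map(_._2).max

// One finite pass, computed once and held in a temporary.
val permutations = for {
  currentWeight  <- gcdWeights to maxWeight by gcdWeights
  (item, weight) <- weights
  if weight >= currentWeight
} yield item

// continually repeats the pass forever; flatten splices the passes
// together into a single endless Iterator[String].
val forever: Iterator[String] = Iterator.continually(permutations).flatten

val firstTen = forever.take(10).mkString // "abcacccabc"
```

With that in place, hasNext and next can simply delegate to the underlying iterator.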

We're using the temporary variable permutations because Iterator.continually
takes its parameter by name, which basically means anything inside of the parentheses
is re-evaluated every time the iterator is asked for its next element.
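You can watch the by-name behavior with a counter. This is just a sketch with a made-up buildPass function standing in for the for-comprehension:

```scala
var evaluations = 0
def buildPass(): Seq[String] = { evaluations += 1; Seq("a", "b") }

// Expression passed directly: it is re-run each time a new pass is needed.
val inlined = Iterator.continually(buildPass()).flatten.take(6).toList
val afterInlined = evaluations // at least 3: six elements, two per pass

// Bound to a val first: built once, then the same value is reused.
val permutations = buildPass()
val cached = Iterator.continually(permutations).flatten.take(6).toList
val extraForCached = evaluations - afterInlined // exactly 1
```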

We could actually drop the hasNext and next functions and just make this class an Iterable rather
than an Iterator. (Right now it isn't declared as either; in my actual code it extends a base class
that has other implementations like WeightedProbabilisticIterator and extends Iterator.)

Fixing the Distribution

One problem a colleague of mine was quick to point out was the poor distribution resulting
from this approach. Imagine you had weights like 100 and 7 for o and x. What you would get would be:
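Here's a quick sketch that computes one pass for those weights, assuming the ascending-weight comprehension described earlier (names are illustrative). Since the gcd of 100 and 7 is 1, one pass is 107 items long:

```scala
val weights = Seq("o" -> 100, "x" -> 7)
val gcdWeights = weights.map { case (_, w) => BigInt(w) }.reduce(_ gcd _).toInt // 1
val maxWeight = weights.map(_._2).max                                           // 100

val onePass = (for {
  currentWeight  <- gcdWeights to maxWeight by gcdWeights
  (item, weight) <- weights
  if weight >= currentWeight
} yield item).mkString

// onePass is "oxoxoxoxoxoxox" followed by 93 consecutive "o"s:
// every x is bunched into the first 14 picks of the pass.
```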

The differences here are subtle, but I'd argue the first one will produce better distributions
because it makes it impossible for lower-weighted items to ever get selected twice before
higher-weighted items.

Finally, if you want the iterator to be continuously shuffled, that is, reshuffled every time we get to the end,
you can take advantage of Iterator.continually taking its parameter by name.
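A minimal sketch of the difference, assuming scala.util.Random.shuffle over a whole pass (shuffling the whole pass is my assumption; the hard-coded pass is the one for weights 10, 5, 20):

```scala
import scala.util.Random

val permutations = Vector("a", "b", "c", "a", "c", "c", "c")

// Shuffled once: the val is evaluated a single time, so the same
// random order repeats on every pass.
val shuffledOnce = Random.shuffle(permutations)
val fixedOrder = Iterator.continually(shuffledOnce).flatten

// Continuously shuffled: the by-name argument is re-run for each pass,
// so every pass gets a fresh shuffle.
val reshuffled = Iterator.continually(Random.shuffle(permutations)).flatten
```

Either way each pass still contains a twice, b once and c four times; only the order within a pass changes.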

Since I was being indecisive, I implemented each of these "strategies" as traits
(SequentialOrdering, Shuffled, ContinuouslyShuffled), so when you instantiate
a new WeightedRRIterator you must select one by mixing it in, such as

    val wi = new WeightedRRIterator(Map("pizza" -> 2, "taco" -> 4)) with Shuffled

I'm not sure if it was the right approach or not yet, because the compiler won't stop
you from meaninglessly mixing in multiple strategies.
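For the curious, here's one hypothetical way the mix-in idea can be laid out. This is a sketch and not my actual code: the OrderingStrategy base trait and the ordered method are invented for illustration, and it assumes a non-empty map of positive weights.

```scala
import scala.util.Random

// Hypothetical hook: a strategy turns one weighted pass into an endless iterator.
trait OrderingStrategy {
  protected def ordered[B](pass: IndexedSeq[B]): Iterator[B]
}

trait SequentialOrdering extends OrderingStrategy {
  protected def ordered[B](pass: IndexedSeq[B]): Iterator[B] =
    Iterator.continually(pass).flatten
}

trait Shuffled extends OrderingStrategy {
  protected def ordered[B](pass: IndexedSeq[B]): Iterator[B] = {
    val shuffledOnce = Random.shuffle(pass) // one shuffle, reused forever
    Iterator.continually(shuffledOnce).flatten
  }
}

trait ContinuouslyShuffled extends OrderingStrategy {
  protected def ordered[B](pass: IndexedSeq[B]): Iterator[B] =
    Iterator.continually(Random.shuffle(pass)).flatten // fresh shuffle per pass
}

// Abstract, so it cannot be instantiated without mixing in a strategy.
abstract class WeightedRRIterator[A](weights: Map[A, Int])
    extends Iterator[A] with OrderingStrategy {
  private val gcdWeights = weights.values.map(BigInt(_)).reduce(_ gcd _).toInt
  private val maxWeight = weights.values.max
  private val pass: IndexedSeq[A] = for {
    currentWeight  <- gcdWeights to maxWeight by gcdWeights
    (item, weight) <- weights.toSeq
    if weight >= currentWeight
  } yield item

  // lazy, so the strategy is only consulted once the object is fully built
  private lazy val underlying: Iterator[A] = ordered(pass)
  def hasNext: Boolean = underlying.hasNext
  def next(): A = underlying.next()
}

val wi = new WeightedRRIterator(Map("pizza" -> 2, "taco" -> 4)) with Shuffled
```

In this particular layout, as far as I can tell, mixing in two strategies at once should be rejected by the compiler, because both traits supply a concrete ordered without an override. That is a property of the sketch rather than of the design described above.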

In conclusion

For comprehensions are very expressive, and the Scala Iterator has some pretty
neat tricks. I'm not sure how to do something similar in Java, but I am sure
any approximation would just make me sad.

If there's any interest, I can open source the code discussed here. I made some
interesting OO design choices that I'd probably want to revisit before letting
other people see.