Comparing k-NN in Rust

In my voyages around the internet, I came across a pair of
blog posts which compare implementations of a
k-nearest neighbour (k-NN) classifier in F# and OCaml. I couldn’t
resist porting the code to Rust to see how it fared.

Code

The Rust code is a nearly-direct translation of the original F# code;
the only change was making distance compute the squared
distance, that is, a*a + b*b + ... (square root is strictly
increasing, yo).

```rust
use std::io::{File, BufferedReader};

struct LabelPixel {
    label: int,
    pixels: Vec<int>
}

fn slurp_file(file: &Path) -> Vec<LabelPixel> {
    BufferedReader::new(File::open(file).unwrap())
        .lines()
        .skip(1)
        .map(|line| {
            let line = line.unwrap();
            let mut iter = line.as_slice().trim().split(',').map(|x| from_str(x).unwrap());

            LabelPixel {
                label: iter.next().unwrap(),
                pixels: iter.collect()
            }
        })
        .collect()
}

fn distance_sqr(x: &[int], y: &[int]) -> int {
    // run through the two vectors, summing up the squares of the differences
    x.iter()
        .zip(y.iter())
        .fold(0, |s, (&a, &b)| s + (a - b) * (a - b))
}

fn classify(training: &[LabelPixel], pixels: &[int]) -> int {
    training
        .iter()
        // find element of `training` with the smallest distance_sqr to `pixel`
        .min_by(|p| distance_sqr(p.pixels.as_slice(), pixels))
        .unwrap()
        .label
}

fn main() {
    let training_set = slurp_file(&Path::new("trainingsample.csv"));
    let validation_sample = slurp_file(&Path::new("validationsample.csv"));

    let num_correct = validation_sample.iter()
        .filter(|x| {
            classify(training_set.as_slice(), x.pixels.as_slice()) == x.label
        })
        .count();

    println!("Percentage correct: {}%",
             num_correct as f64 / validation_sample.len() as f64 * 100.0);
}
```

(Prints Percentage correct: 94.4%, matching the OCaml.)
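The squared-distance shortcut works because sqrt is strictly increasing, so the nearest point under squared distance is the nearest point under true distance too. A quick sanity check of that claim in modern Rust (the code above predates Rust 1.0; the names and sample points here are mine):

```rust
// Sum of squared differences, as in the post's `distance_sqr`.
fn distance_sqr(x: &[i64], y: &[i64]) -> i64 {
    x.iter().zip(y.iter()).map(|(&a, &b)| (a - b) * (a - b)).sum()
}

fn main() {
    let query = [0i64, 0];
    let points = [[3i64, 4], [1, 1], [6, 8]];

    // nearest neighbour by squared distance (no sqrt needed)
    let by_sqr = points
        .iter()
        .min_by_key(|p| distance_sqr(&p[..], &query))
        .unwrap();

    // nearest neighbour by the true euclidean distance
    let by_true = points
        .iter()
        .min_by(|p, q| {
            let dp = (distance_sqr(&p[..], &query) as f64).sqrt();
            let dq = (distance_sqr(&q[..], &query) as f64).sqrt();
            dp.partial_cmp(&dq).unwrap()
        })
        .unwrap();

    // both orderings agree on the winner
    assert_eq!(by_sqr, by_true);
    println!("nearest: {:?}", by_sqr);
}
```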

How’s it compare?

I don’t have an F# compiler, so I’ll only compare against the fastest
OCaml solution (from the follow-up post), after making the
same modification to distance.

The Rust was compiled with rustc -O, and the OCaml with ocamlopt
str.cmxa (as recommended), using version 4.01.0. I ran each 3 times
(times in seconds) on these CSV files.

| Lang  | Run 1 (s) | Run 2 (s) | Run 3 (s) |
|-------|-----------|-----------|-----------|
| Rust  | 3.56      | 3.46      | 3.86      |
| OCaml | 13.9      | 14.7      | 14.1      |

So the Rust code is about 3.5–4× faster than the
OCaml.

It’s worth noting that the Rust code is entirely safe and built
directly (and mostly minimally) using the abstractions provided by the
standard library. The speed is mainly due to the magic of
Rust’s (lazy) iterators
which provide very efficient sequential access to elements of
vectors/slices, as well as a variety of efficient
adaptors
implementing various useful algorithms. These may look high-level and
hard to optimise, but they are very transparent to the compiler,
resulting in fast machine code.
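The APIs in the code above have since been renamed; as a rough sketch, the same iterator pipeline in today's Rust would look something like this (types adjusted from `int` to `i64`, and `min_by` with a key function replaced by `min_by_key`):

```rust
struct LabelPixel {
    label: i64,
    pixels: Vec<i64>,
}

// Zip the two slices and fold the squared differences into an
// accumulator, exactly as in the post.
fn distance_sqr(x: &[i64], y: &[i64]) -> i64 {
    x.iter().zip(y.iter()).fold(0, |s, (&a, &b)| s + (a - b) * (a - b))
}

// 2014's `min_by` took a key function; the modern equivalent is
// `min_by_key`.
fn classify(training: &[LabelPixel], pixels: &[i64]) -> i64 {
    training
        .iter()
        .min_by_key(|p| distance_sqr(&p.pixels, pixels))
        .unwrap()
        .label
}

fn main() {
    // Tiny stand-in for the CSV training set.
    let training = vec![
        LabelPixel { label: 0, pixels: vec![0, 0, 0] },
        LabelPixel { label: 1, pixels: vec![9, 9, 9] },
    ];
    assert_eq!(classify(&training, &[1, 0, 1]), 0);
    assert_eq!(classify(&training, &[8, 9, 7]), 1);
    println!("ok");
}
```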

Updated 2014-06-11: the Rust code is not as fast as it could be,
due to bugs like
#11751, caused by
LLVM being unable to understand that & pointers are never
null. benh wrote a short slice-zip iterator that may
make its way into the standard library: he even
used it to make the
code 3 times faster.

In comparison, the OCaml code has had to manually write a few
functions (for folding and for reading lines from a file), and
contains two possibly-concerning pieces of code:

```ocaml
let v1 = unsafe_get a1 i in
let v2 = unsafe_get a2 i in
```

It might be interesting to compare against
this D code, but I
can’t get it to compile right.

What about parallelism?

I’m glad you asked! Rust is designed to be good for concurrency, using
the type system
to guarantee that code is threadsafe. As I said before, Rust is under
heavy development, and currently lacks a data parallelism library (so
there’s no parallel-map to just call directly yet), but it’s easy
enough to use
the built-in futures
for this.

The code can be made parallel simply by replacing the main function
with the following.

```rust
// how many chunks should the validation sample be divided into? (==
// how many futures to create.)
static NUM_CHUNKS: uint = 32;

fn main() {
    use sync::{Arc, Future};
    use std::cmp;

    // "atomic reference counted": guaranteed thread-safe shared
    // memory. The type signature and API of `Arc` guarantees that
    // concurrent access to the contents will be safe, due to the `Share`
    // trait.
    let training_set = Arc::new(slurp_file(&Path::new("trainingsample.csv")));
    let validation_sample = Arc::new(slurp_file(&Path::new("validationsample.csv")));

    let chunk_size = (validation_sample.len() + NUM_CHUNKS - 1) / NUM_CHUNKS;

    let mut futures = range(0, NUM_CHUNKS).map(|i| {
        // create new "copies" (just incrementing the reference
        // counts) for our new future to handle.
        let ts = training_set.clone();
        let vs = validation_sample.clone();

        Future::spawn(proc() {
            // compute the region of the vector we are handling...
            let lo = i * chunk_size;
            let hi = cmp::min(lo + chunk_size, vs.len());

            // ... and then handle that region.
            vs.slice(lo, hi).iter()
                .filter(|x| {
                    classify(ts.as_slice(), x.pixels.as_slice()) == x.label
                })
                .count()
        })
    }).collect::<Vec<Future<uint>>>();

    // run through the futures (waiting for each to complete) and sum the results
    let num_correct = futures.mut_iter().map(|f| f.get()).fold(0, |a, b| a + b);

    println!("Percentage correct: {}%",
             num_correct as f64 / validation_sample.len() as f64 * 100.0);
}
```

(Also prints Percentage correct: 94.4%.)

This gives a nice speed up, approximately halving the time required:
the real time is now stable around 1.81 seconds (6.25 s of user time)
on my machine.
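Rust has long since grown better tools for this pattern: `thread::scope` lets worker threads borrow shared data directly, so the `Arc` cloning above is no longer needed. A minimal modern sketch of the same chunk-and-sum structure, using a toy predicate in place of `classify`:

```rust
use std::thread;

// Count elements matching a predicate across parallel chunks,
// mirroring the chunked-futures structure of the 2014 code.
fn parallel_even_count(data: &[u64], num_chunks: usize) -> usize {
    let chunk_size = ((data.len() + num_chunks - 1) / num_chunks).max(1);
    thread::scope(|s| {
        // spawn one scoped thread per chunk; each borrows its slice
        let handles: Vec<_> = data
            .chunks(chunk_size)
            .map(|chunk| s.spawn(move || chunk.iter().filter(|&&x| x % 2 == 0).count()))
            .collect();
        // join the threads and sum their partial counts
        handles.into_iter().map(|h| h.join().unwrap()).sum()
    })
}

fn main() {
    let data: Vec<u64> = (0..1000).collect();
    assert_eq!(parallel_even_count(&data, 32), 500);
    println!("parallel count: {}", parallel_even_count(&data, 32));
}
```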

I'm Huon Wilson (huon_w), a
mathematically and statistically inclined software engineer,
currently working on the Swift team at Apple, but interested in
hearing from you. Before that I was a long-term volunteer
on Rust's core team.