Higher-Order Functions in Java

This summer I’ve been teaching Intro to Computer Languages. It’s a really awesome class, with a lot of good material in it. Last week, my students got back the results of their midterms, and were disappointed in their scores. Immediately, I heard some clamoring for extra credit. Not only was the average midterm score pretty low, but also the class has been jumping rapidly from one topic to another. The schedule was originally planned for a 10-week course, but has been compressed for a 6-week Summer Session. I don’t think that student learning rate improves by almost a factor of 2 just because it’s summer and the lecture is 3 hours instead of the normal 1 hour 30 minutes. Quite the contrary, I think retention is best achieved as a result of long-term exposure to material, precisely the opposite of our current situation.

Prior to the midterm, in Project 2, we had done an exercise writing an in-order iterator over a binary search tree. This exercise had two important components: (1) The iterator allows pausing between elements. So, it must have a way to remember where it left off the last time an element was returned. (2) The binary search tree could be storing any type of data. So, the entire project had to satisfy the typing constraints of Java’s generics system.

Then, in the following week (I told you the class moves quickly), we used Haskell in Project 3 to cover the concept of higher-order functions (HOFs). Now, most of my students have had their minds trained on Java, which doesn’t provide first-class functions. So, the concept that a function might take other functions as arguments or return a function as a result is really foreign.

Since my students were clamoring for some extra credit, I tried to think up something that might tie together what has been done so far. My predecessor, Alex Thornton, had done a higher-order function implementation in C# as a fun and interesting demo for his last lecture. Now, I don’t know C#, but I was pretty sure that I could do the same in Java. It turns out that, with a bit of work, you can create for yourself a Java framework that provides higher-order functions, although, as we shall see, it’s verbose and clunky.

Let’s begin by defining what a function looks like. Since Java doesn’t support first-class functions, we start by defining a Function Object.

public interface IFunction<A, B> {
	B call(A arg);
}

Basically, any object that wishes to act as a function from A → B implements the IFunction interface, so that it has a call method.

Next, I try and think of some task that would use some HOFs. So, let’s first create a list of Strings, and print them out.

public static void main(String[] args)
{
    ArrayList<String> words = new ArrayList<String>();
    words.add("File");
    words.add("Edit");
    words.add("Source");
    words.add("Refactor");
    words.add("Navigate");
    words.add("Search");
    words.add("Project");
    words.add("Run");
    words.add("Window");
    words.add("Help");
 
    System.out.println("Starting With: ");
    FunctionLib.map(new PrintFunction<String>(), words);
    System.out.println("\n");

Instead of using the normal for-each loop to print out all the elements in a list, I map a PrintFunction (specialized on String) onto the list. Let’s look at the HOF map first. We know from using Haskell that map takes a function and a list, and applies that function to every element in the list. So, map takes two arguments: (1) a function from A → B, and (2) a list of A. Finally, after applying the function to every element, it returns a list of B. We collect all the HOFs we’ll implement into a FunctionLib class.

public class FunctionLib {
    // map :: (a -> b) -> [a] -> [b]
    public static <A,B> List<B> map(IFunction<A, B> fun, List<A> args)
    {
        List<B> results = new ArrayList<B>();
        for(A a : args)
            results.add(fun.call(a));
        return results;
    }
}

I’m mapping a PrintFunction that has the side effect of printing each element it is called on. The primary difficulty with doing this is that, according to my IFunction interface, I have to return something that fits in Java’s generics system. So, I specialize the return type parameter to Object and return null. Technically, the map method will remember the resulting nulls and build a list of them, which then gets discarded. A bit wasteful, but it’s an eager language, so there’s not much we can do to avoid that.

public class PrintFunction<E> implements IFunction<E, Object> {
    @Override
    public Object call(E arg) {
        System.out.println(arg);
        return null;
    }
}

I don’t especially like creating a whole class file for each function I might want to pass into an HOF. Fortunately, Java supports anonymous classes. Let’s create one for getting the length of a string, call it lengthFunction, and pass it into map, so that we get back a list of Integers representing the lengths.

    System.out.println("The length of each String is: ");
    IFunction<String, Integer> lengthFunction = new IFunction<String, Integer>() {
        public Integer call(String arg) {
            return arg.length();
        }
    };
 
    List<Integer> lengths= FunctionLib.map(lengthFunction, words);
    FunctionLib.map(new PrintFunction<Integer>(), lengths);
    System.out.println("\n");

Although not creating a class file for each new function is nice, it’s still hideously verbose to create a whole class. Oh, do I wish that Java provided syntax for lambdas.
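For comparison, here is roughly what that wish looks like in Java 8 and later, which did add lambda expressions. This is only a sketch: it assumes a Java 8+ compiler, and LambdaDemo is a name I made up so the snippet stands alone. Because IFunction has a single abstract method, a lambda can stand in for the anonymous class.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

// Same shape as the IFunction interface defined earlier in the post.
interface IFunction<A, B> {
    B call(A arg);
}

// Hypothetical demo class, not part of the project source.
class LambdaDemo {
    // Same map as in FunctionLib, repeated so the sketch is self-contained.
    static <A, B> List<B> map(IFunction<A, B> fun, List<A> args) {
        List<B> results = new ArrayList<B>();
        for (A a : args)
            results.add(fun.call(a));
        return results;
    }

    public static void main(String[] args) {
        List<String> words = Arrays.asList("File", "Edit", "Source");

        // A lambda expression replaces the whole anonymous class (Java 8+ only).
        IFunction<String, Integer> lengthFunction = s -> s.length();

        List<Integer> lengths = map(lengthFunction, words);
        System.out.println(lengths); // prints [4, 4, 6]
    }
}
```

The map method itself is unchanged; only the way the function object gets created is shorter.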

Now that I have two lists, one of words and another of their lengths, I should be able to stitch them together. Haskell provides another HOF for precisely this task: zip is a function that takes two lists and returns a list of pairs of their corresponding elements. The resulting list of pairs is the same length as the shorter input list.

public class FunctionLib {
    // zip :: [a] -> [b] -> [(a, b)]
    public static <A,B> List<Pair<A,B>> zip(List<A> listA, List<B> listB)
    {
        List<Pair<A,B>> results = new ArrayList<Pair<A,B>>();
        int len = Math.min(listA.size(), listB.size());
        for(int i=0; i<len; i++) {
            results.add( new Pair<A,B>(listA.get(i), listB.get(i)) );
        }
        return results;    
    }
}

Java is so hideously dysfunctional that it doesn’t come with a native Pair class. However, it’s easy to make one.

public class Pair<A,B> {
    public A first;
    public B second;
 
    Pair(A a, B b)
    {
        first = a;
        second = b;
    }
 
    @Override
    public String toString()
    {
        // Only call toString() on non-null elements; print "null" otherwise.
        String firstStr = first != null ? first.toString() : "null";
        String secondStr = second != null ? second.toString() : "null";
        return "(" + firstStr + ", " + secondStr + ")";
    }
}

After all that monstrous cruft, we can finally get the list of pairs containing each word and its length.

    List<Pair<String, Integer>> associations = FunctionLib.zip(words, lengths);
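To make the "shorter list wins" behaviour concrete, here is a small stand-alone sketch. The Pair below is a stripped-down stand-in for the one above, and ZipDemo is a made-up name used only for this illustration.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

// Stripped-down Pair, just enough for this sketch.
class Pair<A, B> {
    public A first;
    public B second;

    Pair(A a, B b) {
        first = a;
        second = b;
    }

    @Override
    public String toString() {
        // String concatenation already renders a null element as "null".
        return "(" + first + ", " + second + ")";
    }
}

// Hypothetical demo class, not part of the project source.
class ZipDemo {
    // zip :: [a] -> [b] -> [(a, b)]
    static <A, B> List<Pair<A, B>> zip(List<A> listA, List<B> listB) {
        List<Pair<A, B>> results = new ArrayList<Pair<A, B>>();
        int len = Math.min(listA.size(), listB.size());
        for (int i = 0; i < len; i++)
            results.add(new Pair<A, B>(listA.get(i), listB.get(i)));
        return results;
    }

    public static void main(String[] args) {
        List<String> words = Arrays.asList("File", "Edit", "Source");
        List<Integer> lengths = Arrays.asList(4, 4); // deliberately one element short

        // "Source" has no partner in the second list, so it is dropped.
        System.out.println(zip(words, lengths)); // prints [(File, 4), (Edit, 4)]
    }
}
```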

But we’ve yet to do anything really interesting with that. So, let’s filter the list for words that are longer than 5 letters. Haskell also provides an HOF, filter, which takes a predicate function and a list to apply it to, and returns a list of those elements for which the function evaluated to true.

public class FunctionLib {
    // filter :: (a -> bool) -> [a] -> [a]
    public static <A> List<A> filter(IFunction<A, Boolean> pred, List<A> listA)
    {
        List<A> results = new ArrayList<A>();
        for(A a : listA) {
            if (pred.call(a))
                results.add(a);
        }
        return results;
    }
}

I use this function to pull out the pairs which have a string longer than 5 characters. Again, I use the verbose anonymous class.

    System.out.println("The words with length bigger than 5 are:");
    IFunction<Pair<String, Integer>, Boolean> biggerThanFive = new IFunction<Pair<String, Integer>, Boolean>() {
        public Boolean call(Pair<String, Integer> arg)
        {
            return arg.second > 5;
        }
    };
    List<Pair<String, Integer>> longerThanFive = FunctionLib.filter(biggerThanFive, associations);

But, I’m really only interested in the strings, not the pairs. So, now I want a way to pull out only the first element of each pair.

    List<String> stringsLongerThanFive = FunctionLib.map(Pair.firstFunction(String.class, Integer.class), longerThanFive);
    FunctionLib.map(new PrintFunction<String>(), stringsLongerThanFive);
    System.out.println("\n");

Notice that this time I’ve given myself a convenience method, Pair.firstFunction. I’d really have preferred something like C++’s syntax here, but Java’s generics system has some pretty fundamental brokenness. First, as a result of type erasure, the system doesn’t know what the type parameters into the function are at runtime. That makes an expression like Pair<String,Integer>.firstFunction invalid, because the type parameters are thrown away well before the method call takes place. Second, the type parameters to a generic method are normally determined by inspecting the types of the provided actual arguments. An expression like Pair.firstFunction<String,Integer> is invalid: you can’t append type parameters after the method name at the call site (as you would when you instantiate a generic object). Java does accept explicit type arguments placed before the method name, in the form Pair.<String,Integer>firstFunction(...), but the workaround I use here is to provide dummy actual arguments that the type system uses to specialize the call.

public class Pair<A,B> {
    ...
 
    public static <A,B> IFunction<Pair<A, B>, A> firstFunction(Class<A> a, Class<B> b) {
        return new IFunction<Pair<A, B>, A>() {
            public A call(Pair<A, B> arg) {
                return arg.first;
            }
        };
    }
 
    public static <A,B> IFunction<Pair<A,B>, B> secondFunction(Class<A> a, Class<B> b) {
        return new IFunction<Pair<A, B>, B>() {
            public B call(Pair<A, B> arg) {
                return arg.second;
            }
        };
    }
}

Things are a little bit different here than with the PrintFunction we implemented earlier. We should like to keep the function that pulls out the first element of a pair in the Pair class. Doing this is tricky not only because of the deficiencies of Java’s generic type system, but also because we can’t simply name a function as a reference (as you can in C). Instead, we create a firstFunction method that, when called with type parameters as actual arguments, returns an appropriately specialized IFunction which returns the first element when you pass it a pair in the call. Creation of the returned IFunction is also achieved via an anonymous class.
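A possible alternative to the dummy Class arguments, sketched here rather than taken from the project source: Java lets you spell out a generic method’s type arguments explicitly if you put them between the dot and the method name. The no-argument firstFunction below is a hypothetical variant, and TypeWitnessDemo is a made-up name.

```java
// Same shape as the IFunction interface defined earlier in the post.
interface IFunction<A, B> {
    B call(A arg);
}

class Pair<A, B> {
    public A first;
    public B second;

    Pair(A a, B b) {
        first = a;
        second = b;
    }

    // Hypothetical variant of firstFunction that takes no dummy Class arguments.
    public static <A, B> IFunction<Pair<A, B>, A> firstFunction() {
        return new IFunction<Pair<A, B>, A>() {
            public A call(Pair<A, B> arg) {
                return arg.first;
            }
        };
    }
}

// Hypothetical demo class, not part of the project source.
class TypeWitnessDemo {
    public static void main(String[] args) {
        // Explicit type arguments go before the method name, not after it.
        IFunction<Pair<String, Integer>, String> first =
            Pair.<String, Integer>firstFunction();

        System.out.println(first.call(new Pair<String, Integer>("Refactor", 8))); // prints Refactor
    }
}
```

Whether this reads any better than passing String.class and Integer.class is a matter of taste; either way the caller has to name both types.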

We now turn to the coup de grâce. We have a list of all the strings with length bigger than 5, so let’s use an HOF to concatenate them. Haskell again provides us with a function, foldl1, that will start at the beginning of a list and inductively apply a two-argument function onto it, returning a single result. Essentially, if + were my string concatenation operator and my list were [a, b, c, d, ...], I’d be calculating (...(((a + b) + c) + d) + ...). It’s a bit tricky in our implementation though. We have to pass into the HOF a function that takes two arguments, but my IFunction interface is only able to represent functions of a single argument. What shall we do?

Currying to the rescue! Haskell types the first argument to foldl1 as (a -> a -> a), which you might naïvely think is a function of two arguments, but the arrow operator is right-associative. So, it’s actually parsed as (a -> (a -> a)), which is a function of one argument that returns a function of one argument. So, a function of two arguments is equivalent to a nested chain of functions, each accepting one argument. Which means the IFunction interface is enough!

public class FunctionLib {
    // foldl1 :: (a -> a -> a) -> [a] -> a
    // In Haskell the function type operator -> is right associative,
    // So, the function description "a -> a -> a" is parsed as "a -> (a -> a)"
    // This order of application is represented in the nesting of IFunction's below
    public static <A> A foldl1(IFunction<A, IFunction<A,A>> fn, List<A> listA)
    {
        A accumulator = null;
        if (listA.size() == 0)
            return accumulator;
 
        accumulator = listA.get(0);
        if (listA.size() == 1)
            return accumulator;
 
        for(int i=1; i<listA.size(); i++) {
            // currying means two .call invocations
            accumulator = fn.call(accumulator).call(listA.get(i));
        }
        return accumulator;
    }
}
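Since string concatenation hides the grouping, here is a quick sanity check of the left-to-right order using subtraction instead. It is a sketch only: Subtract, SubtractFrom, and FoldDemo are hypothetical names, and foldl1 is a condensed copy of the method above so the snippet stands alone.

```java
import java.util.Arrays;
import java.util.List;

// Same shape as the IFunction interface defined earlier in the post.
interface IFunction<A, B> {
    B call(A arg);
}

// Hypothetical helper: subtraction with its left operand already filled in.
class SubtractFrom implements IFunction<Integer, Integer> {
    private final int left;

    SubtractFrom(int left) {
        this.left = left;
    }

    public Integer call(Integer right) {
        return left - right;
    }
}

// Hypothetical helper: curried subtraction, Integer -> (Integer -> Integer).
class Subtract implements IFunction<Integer, IFunction<Integer, Integer>> {
    public IFunction<Integer, Integer> call(Integer left) {
        return new SubtractFrom(left);
    }
}

// Hypothetical demo class, not part of the project source.
class FoldDemo {
    // Condensed copy of foldl1; like the version above, it returns null for an empty list.
    static <A> A foldl1(IFunction<A, IFunction<A, A>> fn, List<A> listA) {
        if (listA.isEmpty())
            return null;
        A accumulator = listA.get(0);
        for (int i = 1; i < listA.size(); i++)
            accumulator = fn.call(accumulator).call(listA.get(i));
        return accumulator;
    }

    public static void main(String[] args) {
        List<Integer> numbers = Arrays.asList(10, 2, 3);

        // Left fold: (10 - 2) - 3. A right fold would give 10 - (2 - 3) = 11.
        System.out.println(foldl1(new Subtract(), numbers)); // prints 5
    }
}
```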

Finally, I can wrap up this exercise by applying a concatenation operator to a call to foldl1. This ends up being some of the most ugly code that can be written. Unlike the previous calls to map and filter, I did not want to save the temporary function into a variable. Rather, I’ve written it inline. Unfortunately, because of the currying, we have to nest anonymous functions!

    System.out.println("When we concatenate these we get:");
    String concat = FunctionLib.foldl1(
        new IFunction<String, IFunction<String, String>>() {
            public IFunction<String, String> call(final String arg1) {
                return new IFunction<String, String>() {
                    public String call(String arg2) {
                        return arg1 + arg2;
                    }
                };
            }
        },
        stringsLongerThanFive);
    System.out.println(concat);
    System.out.println("\n");
}

What’s going on here? Remember, we are representing a function of two arguments as nested functions of a single argument. So, for example, if I want to concatenate the list [a, b, c, d], I would pass in the concatenation operator, (+). Inside of foldl1, we take out the first element, a, and invoke the call method with it as the first argument. This returns (a+), which is the concatenation operator with the first argument filled in. A second invocation is used to pass in the second argument, and that returns a string, which is saved in the accumulator.

This level of indirection means that I have to have some way of returning a function with the first argument filled in. I accomplish this by nesting anonymous classes and using a closure. The outer call method takes final String arg1 as an argument. final is necessary, because I have to tell Java that it is OK to ‘save’ or ‘cache’ this argument. I’m guaranteeing that I don’t have any code within this method that might change what arg1 refers to. Once I’ve made this guarantee, I’m able to create and return a new IFunction with its own call(String arg2) that uses the outer arg1.
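The ‘saving’ of arg1 can be seen in isolation with a small sketch. It pulls the curried concatenation operator out into a constant (CONCAT and ClosureDemo are names I made up for illustration) and fills in the first argument twice; each returned function holds on to its own arg1.

```java
// Same shape as the IFunction interface defined earlier in the post.
interface IFunction<A, B> {
    B call(A arg);
}

// Hypothetical demo class, not part of the project source.
class ClosureDemo {
    // The curried concatenation operator, written the same way as in the foldl1 call.
    static final IFunction<String, IFunction<String, String>> CONCAT =
        new IFunction<String, IFunction<String, String>>() {
            public IFunction<String, String> call(final String arg1) {
                return new IFunction<String, String>() {
                    public String call(String arg2) {
                        return arg1 + arg2; // arg1 is captured from the outer call
                    }
                };
            }
        };

    public static void main(String[] args) {
        // Two partial applications: (a+) and (b+).
        IFunction<String, String> aPlus = CONCAT.call("a");
        IFunction<String, String> bPlus = CONCAT.call("b");

        System.out.println(aPlus.call("x")); // prints ax
        System.out.println(bPlus.call("x")); // prints bx
        System.out.println(aPlus.call("y")); // prints ay
    }
}
```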

After going through all this awesome Computer Science (weird rules about the generics system, anonymous inner classes, currying, and closures), I decided that this is not something I could reasonably expect my students to do as extra credit. First, I would want extra credit to assist my students’ absorption of concepts, and yet still not be so difficult that it becomes a distraction from actual coursework. Second, it should be easy enough to help the students who need it to improve their grade. Although this example does hit the generics and higher-order functions topics, I think it has too many complex issues rolled together. For example, I’ve never actually had to use closures and anonymous inner classes before. Now, suddenly, I use them both in a single project!

UPDATE Mon Aug 8 19:56:37 PDT 2011: Jim Duey has a talk Functional Programming: A Pragmatic Introduction. At around 30mins he shows how Laziness can be used to implement a message queue for concurrent programs.

UPDATE Tue Aug 9 11:30:16 PDT 2011: You can download the project source here.