Implementing IOrderedEnumerable, Part 2: The details

In my previous post, I developed the classes that implement IOrderedEnumerable so that we can create a chain of sort criteria that the GetEnumerator method can use to select the top items from the list. The harder part is writing that GetEnumerator method.

When I started this project I intended to use my generic heap class, DHeap, to do the selection. I actually wrote code to do that and almost published it. But the code was convoluted, and used a few “clever” tricks that were too ugly to post. It was also slow: faster than using OrderBy and Take, but not by a whole lot. That’s not something I wanted to post when there was a better (and, as it turns out, much faster) way to do things.

It’s instructive to understand the way the LINQ to Objects OrderBy implementation works. When code calls OrderBy, the LINQ code creates a chain of ordering criteria quite similar to what I showed in my previous post. The GetEnumerator method, when it’s eventually called, first extracts the sort keys and stores them in individual arrays. For example, if the code were OrderBy(p => p.Age).ThenBy(p => p.FirstName), then there are two key arrays created: one that contains all of the ages, and one that contains all of the first names. When sorting, the keys are compared and swapped as necessary.

Doing things that way costs memory, but lets us create and compare custom keys that might take significant time to generate. Rather than generating the key each time a comparison is made, the keys are generated one time and cached. It’s less than optimal in terms of memory, but it’s a much more flexible way to do things.
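To make the key caching idea concrete, here’s a rough sketch. It is not the LINQ to Objects source. KeyCacheSketch and SortWithCachedKeys are names I made up for illustration, and the sketch assumes the default comparer for each key type. It sorts an array of positions instead of moving the items, which is just one way to arrange things.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

public static class KeyCacheSketch
{
    // Hypothetical helper, not the LINQ source: sorts by two keys while
    // calling each key selector exactly once per element.
    public static List<T> SortWithCachedKeys<T, TKey1, TKey2>(
        IEnumerable<T> source,
        Func<T, TKey1> firstKey,
        Func<T, TKey2> secondKey)
    {
        T[] items = source.ToArray();

        // One key array per ordering criterion, filled up front.
        var keys1 = new TKey1[items.Length];
        var keys2 = new TKey2[items.Length];
        for (int i = 0; i < items.Length; ++i)
        {
            keys1[i] = firstKey(items[i]);
            keys2[i] = secondKey(items[i]);
        }

        var cmp1 = Comparer<TKey1>.Default;
        var cmp2 = Comparer<TKey2>.Default;

        // Sort item positions rather than the items themselves.
        int[] order = Enumerable.Range(0, items.Length).ToArray();
        Array.Sort(order, (a, b) =>
        {
            int c = cmp1.Compare(keys1[a], keys1[b]);
            if (c == 0) c = cmp2.Compare(keys2[a], keys2[b]);
            // Array.Sort is not stable, so fall back to the original position.
            return c != 0 ? c : a.CompareTo(b);
        });

        return order.Select(i => items[i]).ToList();
    }
}
```

Calling KeyCacheSketch.SortWithCachedKeys(people, p => p.Age, p => p.FirstName) should give the same ordering as OrderBy(p => p.Age).ThenBy(p => p.FirstName) with default comparers, and each selector runs once per item no matter how many comparisons the sort makes.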

After studying the problem a bit (that is, after my initial solution was such a disappointment), I determined that I could use a technique that’s very similar to what the LINQ to Objects implementation of OrderBy does. But my key arrays are only the size of the heap (the number of items to be selected) rather than the size of the entire list. So if the code is supposed to select 100 items, then each key array is large enough to hold 101 items, with the last slot used as a scratch pad for the item currently being examined.
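Here’s a stripped-down sketch of that layout, using plain integers and a binary min-heap so the scratch slot is easy to see. ScratchSlotSketch and TopN are hypothetical names, and this is a simplification for illustration, not the HeapSelector code shown later. With integers the scratch slot is trivial, but with separate key arrays it gives every key a place to live while the new item is compared against the root.

```csharp
using System;
using System.Collections.Generic;

public static class ScratchSlotSketch
{
    // Returns up to n of the largest values from source, largest first.
    public static int[] TopN(IEnumerable<int> source, int n)
    {
        var heap = new int[n + 1];   // slots 0..n-1 hold the heap; slot n is scratch
        int count = 0;

        foreach (int value in source)
        {
            if (count < n)
            {
                // Heap is not full yet: append and restore heap order.
                heap[count] = value;
                SiftUp(heap, count);
                ++count;
            }
            else
            {
                // Park the candidate in the scratch slot, then compare with the root.
                heap[n] = value;
                if (n > 0 && heap[n] > heap[0])
                {
                    heap[0] = heap[n];
                    SiftDown(heap, 0, count);
                }
            }
        }

        // Copy out the survivors and order them largest first.
        var result = new int[count];
        Array.Copy(heap, result, count);
        Array.Sort(result);
        Array.Reverse(result);
        return result;
    }

    private static void SiftUp(int[] heap, int i)
    {
        while (i > 0)
        {
            int parent = (i - 1) / 2;
            if (heap[parent] <= heap[i]) break;
            int t = heap[parent]; heap[parent] = heap[i]; heap[i] = t;
            i = parent;
        }
    }

    private static void SiftDown(int[] heap, int i, int count)
    {
        while (true)
        {
            int child = (2 * i) + 1;
            if (child >= count) break;
            // Pick the smaller of the two children.
            if (child + 1 < count && heap[child + 1] < heap[child]) ++child;
            if (heap[i] <= heap[child]) break;
            int t = heap[i]; heap[i] = heap[child]; heap[child] = t;
            i = child;
        }
    }
}
```

The memory used depends on n rather than on the length of the source, which is the whole point of sizing the arrays to the heap.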

Each HeapOrderedEnumerable<TElement, TKey> instance has a _keys array and implements three abstract methods that are defined by the base HeapOrderedEnumerable<TElement> class:

  • ExtractKey(element, index) extracts a key of type TKey from the passed element and stores it in the _keys array at the specified index.
  • CompareKeys(x, y) uses the comparer to compare the keys at indexes x and y.
  • SwapKeys(x, y) swaps the values in the _keys array located at indexes x and y.

The HeapOrderedEnumerable<TElement,TKey> class, then, is pretty simple. Modifying the code I posted last time, I just need to add the _keys array and the three methods that I outlined above. Here’s the completed class.

    internal class HeapOrderedEnumerable<TElement, TKey> : HeapOrderedEnumerable<TElement>
    {
        private readonly Func<TElement, TKey> _keySelector;
        private readonly IComparer<TKey> _comparer;
        private readonly bool _descending;

        private readonly TKey[] _keys;

        internal HeapOrderedEnumerable(
            IEnumerable<TElement> source,
            int numItems,
            Func<TElement, TKey> keySelector,
            IComparer<TKey> comparer,
            bool descending) : base(source, numItems)
        {
            _keySelector = keySelector;
            _comparer = comparer ?? Comparer<TKey>.Default;
            _descending = descending;

            // Allocate one extra key for the working item.
            _keys = new TKey[numItems+1];
        }

        public override int CompareKeys(int x, int y)
        {
            return _descending
                ? _comparer.Compare(_keys[y], _keys[x])
                : _comparer.Compare(_keys[x], _keys[y]);
        }

        public override void SwapKeys(int x, int y)
        {
            var t = _keys[x];
            _keys[x] = _keys[y];
            _keys[y] = t;
        }

        public override void ExtractKey(TElement item, int ix)
        {
            _keys[ix] = _keySelector(item);
        }
    }

I then created a class called HeapSelector that, given a list of ordering criteria (the chain of HeapOrderedEnumerable<TElement,TKey> instances) and a list of items, can select the top items that match those criteria. That class maintains an array of the items currently on the heap, and calls the ExtractKey, CompareKeys, and SwapKeys methods described above to maintain the keys for those items. Internally, HeapSelector implements a custom d-Heap to do the actual selection. The public interface consists of the constructor and the DoSelect method. Here’s the public interface.

    // Signatures only; the method bodies are omitted here.
    internal class HeapSelector<TElement>
    {
        public HeapSelector(
            IEnumerable<TElement> source,
            HeapOrderedEnumerable<TElement>[] criteria,
            int numItems);

        public IEnumerable<TElement> DoSelect();
    }

The full code is shown at the end of this post, along with the rest.

All that’s left is the GetEnumerator method in HeapOrderedEnumerable<TElement> which, in concept, is very simple. It just has to create an array of ordering criteria to pass to the HeapSelector, create the HeapSelector, and then return the sequence generated by the DoSelect method. It’s almost that simple, except for one complication.

OrderBy and ThenBy state that they do a stable ordering. A stable ordering simply means that items that compare equal maintain their original relative order in the output. For example, imagine that I had this list of names and ages.

    Jim,52
    Ralph,30
    Susie,37
    Mary,52
    George,47

If I were to sort those names by descending age, the first two items could be Jim followed by Mary, or Mary followed by Jim. If a stable sort isn’t specified, then either ordering is correct. But if the sort is guaranteed stable, then Jim must come before Mary in the output because Jim appeared before Mary in the original list. Note that if I reversed the sort order (i.e. sorted by ascending age), Jim would still appear before Mary.
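The following small sketch shows that behavior using the standard LINQ OrderBy and OrderByDescending methods, which are documented to perform a stable sort. The StableOrderDemo class and its member names are made up for this illustration.

```csharp
using System;
using System.Linq;

public static class StableOrderDemo
{
    // The sample list from the text, in its original order.
    public static readonly (string Name, int Age)[] People =
    {
        ("Jim", 52), ("Ralph", 30), ("Susie", 37), ("Mary", 52), ("George", 47)
    };

    // Oldest first. Jim and Mary tie on age, so they keep their source order.
    public static string[] NamesByAgeDescending()
    {
        return People.OrderByDescending(p => p.Age).Select(p => p.Name).ToArray();
    }

    // Youngest first. Jim still comes before Mary.
    public static string[] NamesByAgeAscending()
    {
        return People.OrderBy(p => p.Age).Select(p => p.Name).ToArray();
    }
}
```

NamesByAgeDescending returns Jim, Mary, George, Susie, Ralph, and NamesByAgeAscending returns Ralph, Susie, George, Jim, Mary. In both cases Jim precedes Mary.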

It’s reasonable for users of TopBy to expect a stable ordering, because that’s the way that OrderBy works and, more importantly, how ThenBy is documented to work. If I want ThenBy to work with TopBy, then TopBy must do a stable ordering. That complicates things a little bit because heap selection isn’t typically a stable operation.

It can be made stable, though, by adding one final ordering criterion: the record number. In the example above, Jim would be record 0 and Mary would be record 3. Each record is assigned a unique, monotonically increasing numeric key. If the other keys compare as equal, the record numbers are compared. This guarantees the correct ordering of otherwise equal items.
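As a rough illustration of the record number trick on its own, here’s a sketch that uses Array.Sort, which is documented as an unstable sort, and makes the result deterministic by comparing record numbers when the ages are equal. RecordNumberTieBreak and StableByAgeDescending are hypothetical names, and this is not how the heap code does it internally. It only demonstrates the tie-breaking idea.

```csharp
using System;
using System.Linq;

public static class RecordNumberTieBreak
{
    public static string[] StableByAgeDescending((string Name, int Age)[] people)
    {
        // Attach a record number (the position in the original list) to each item.
        var keyed = people.Select((p, i) => (Person: p, Record: i)).ToArray();

        Array.Sort(keyed, (a, b) =>
        {
            // Primary criterion: age, descending.
            int c = b.Person.Age.CompareTo(a.Person.Age);
            // Final criterion: the lower record number wins when ages are equal.
            return c != 0 ? c : a.Record.CompareTo(b.Record);
        });

        return keyed.Select(k => k.Person.Name).ToArray();
    }
}
```

Because no two items share a record number, the comparison never returns zero for distinct items, so the underlying algorithm has no freedom to reorder ties.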

It turns out that adding that final ordering isn’t terribly complicated. The code essentially tacks on a final ThenByDescending that stores the unique record number. It then walks the chain of ordering criteria to build an array that it can pass to the HeapSelector constructor. Finally, it calls the HeapSelector instance’s DoSelect method.

This whole thing turned out to be about 30 lines longer than my initial attempt, but much easier to understand. It’s also about five times faster than my original code, which is a nice bonus. The entire code is shown below. Next time we’ll do a little testing to see how it stacks up with other ways of selecting the top items.

    using System;
    using System.Collections;
    using System.Collections.Generic;
    using System.Linq;

    namespace Heaps
    {
        public static class HeapEnumerable
        {
            public static IOrderedEnumerable<TSource> TopBy<TSource, TKey>(
                this IEnumerable<TSource> source,
                int numItems,
                Func<TSource, TKey> keySelector)
            {
                return new HeapOrderedEnumerable<TSource, TKey>(source, numItems, keySelector, null, false);
            }

            public static IOrderedEnumerable<TSource> TopBy<TSource, TKey>(
                this IEnumerable<TSource> source,
                int numItems,
                Func<TSource, TKey> keySelector,
                IComparer<TKey> comparer)
            {
                return new HeapOrderedEnumerable<TSource, TKey>(source, numItems, keySelector, comparer, false);
            }

            public static IOrderedEnumerable<TSource> TopByDescending<TSource, TKey>(
                this IEnumerable<TSource> source,
                int numItems,
                Func<TSource, TKey> keySelector)
            {
                return new HeapOrderedEnumerable<TSource, TKey>(source, numItems, keySelector, null, true);
            }

            public static IOrderedEnumerable<TSource> TopByDescending<TSource, TKey>(
                this IEnumerable<TSource> source,
                int numItems,
                Func<TSource, TKey> keySelector,
                IComparer<TKey> comparer)
            {
                return new HeapOrderedEnumerable<TSource, TKey>(source, numItems, keySelector, comparer, true);
            }

            internal abstract class HeapOrderedEnumerable<TElement> : IOrderedEnumerable<TElement>
            {
                private readonly IEnumerable<TElement> _source;
                private readonly int _numItems;
                internal HeapOrderedEnumerable<TElement> Parent;

                protected HeapOrderedEnumerable(
                    IEnumerable<TElement> source,
                    int numItems)
                {
                    _source = source;
                    _numItems = numItems;
                }

                public IOrderedEnumerable<TElement> CreateOrderedEnumerable<TKey>(
                    Func<TElement, TKey> keySelector,
                    IComparer<TKey> comparer, bool @descending)
                {
                    var oe = new HeapOrderedEnumerable<TElement, TKey>(
                        _source, _numItems, keySelector, comparer, descending);
                    oe.Parent = this;
                    return oe;
                }

                public IEnumerator<TElement> GetEnumerator()
                {
                    int numRecs = 0;
                    var recordKeySelector = new Func<TElement, int>(item => ++numRecs);

                    // Add a ThenByDescending for the record key.
                    // This ensures a stable ordering.
                    var oe = (HeapOrderedEnumerable<TElement>)CreateOrderedEnumerable(recordKeySelector, null, true);

                    // Get the ordering criteria, starting with the last ordering clause.
                    // Which will always be the record key ordering.
                    var criteria = oe.GetCriteria().ToArray();

                    var selector = new HeapSelector<TElement>(_source, criteria, _numItems);
                    return selector.DoSelect().GetEnumerator();
                }

                // Walks the ordering criteria to build an array that the HeapSelector can use.
                private IEnumerable<HeapOrderedEnumerable<TElement>> GetCriteria()
                {
                    var keys = new Stack<HeapOrderedEnumerable<TElement>>();

                    var current = this;
                    while (current != null)
                    {
                        keys.Push(current);
                        current = current.Parent;
                    }
                    return keys;
                }

                IEnumerator IEnumerable.GetEnumerator()
                {
                    return GetEnumerator();
                }

                // The individual ordering criteria instances (HeapOrderedEnumerable<TElement, TKey>)
                // implement these abstract methods to provide key extraction, comparison, and swapping.
                public abstract void ExtractKey(TElement item, int ix);
                public abstract int CompareKeys(int x, int y);
                public abstract void SwapKeys(int x, int y);
            }

            internal class HeapOrderedEnumerable<TElement, TKey> : HeapOrderedEnumerable<TElement>
            {
                private readonly Func<TElement, TKey> _keySelector;
                private readonly IComparer<TKey> _comparer;
                private readonly bool _descending;

                private readonly TKey[] _keys;

                internal HeapOrderedEnumerable(
                    IEnumerable<TElement> source,
                    int numItems,
                    Func<TElement, TKey> keySelector,
                    IComparer<TKey> comparer,
                    bool descending) : base(source, numItems)
                {
                    _keySelector = keySelector;
                    _comparer = comparer ?? Comparer<TKey>.Default;
                    _descending = descending;

                    // Allocate one extra key for the working item.
                    _keys = new TKey[numItems+1];
                }

                public override int CompareKeys(int x, int y)
                {
                    return _descending
                        ? _comparer.Compare(_keys[y], _keys[x])
                        : _comparer.Compare(_keys[x], _keys[y]);
                }

                public override void SwapKeys(int x, int y)
                {
                    var t = _keys[x];
                    _keys[x] = _keys[y];
                    _keys[y] = t;
                }

                public override void ExtractKey(TElement item, int ix)
                {
                    _keys[ix] = _keySelector(item);
                }
            }

            internal class HeapSelector<TElement>
            {
                private readonly IEnumerable<TElement> _source;
                private readonly HeapOrderedEnumerable<TElement>[] _criteria;
                private readonly int _numItems;
                private readonly TElement[] _items;
                private int _count;

                public HeapSelector(
                    IEnumerable<TElement> source,
                    HeapOrderedEnumerable<TElement>[] criteria,
                    int numItems)
                {
                    _source = source;
                    _criteria = criteria;
                    _numItems = numItems;
                    _items = new TElement[numItems+1];
                }

                public IEnumerable<TElement> DoSelect()
                {
                    // Build a heap from the first _numItems items.
                    // The using block ensures that the source enumerator is disposed.
                    using (var srcEnumerator = _source.GetEnumerator())
                    {
                        while (_count < _numItems && srcEnumerator.MoveNext())
                        {
                            ExtractKeys(srcEnumerator.Current, _count);
                            ++_count;
                        }
                        Heapify();

                        // For each remaining item . . .
                        while (srcEnumerator.MoveNext())
                        {
                            // Put the candidate in the scratch slot at index _numItems.
                            ExtractKeys(srcEnumerator.Current, _numItems);
                            if (Compare(_numItems, 0) > 0)
                            {
                                // The current item is larger than the smallest item.
                                // So move the current item to the root and sift down.
                                Swap(0, _numItems);
                                SiftDown(0);
                            }
                        }
                    }

                    // Top N items are on the heap. Sort them.
                    int saveCount = _count;
                    while (_count > 0)
                    {
                        --_count;
                        Swap(0, _count);
                        SiftDown(0);
                    }

                    // And then return.
                    // Have to use the Take here because it's possible that saveCount
                    // will be smaller than _numItems.
                    return _items.Take(saveCount);
                }

                private const int ary = 3;

                private void Heapify()
                {
                    for (int i = _count / ary; i >= 0; --i)
                    {
                        SiftDown(i);
                    }
                }

                private void SiftDown(int ix)
                {
                    while ((ary*ix) + 1 < _count)
                    {
                        var child = (ix*ary) + 1;
                        // find the smallest child
                        var currentSmallestChild = child;
                        var maxChild = child + ary;
                        if (maxChild > _count) maxChild = _count;
                        ++child;
                        while (child < maxChild)
                        {
                            if (Compare(currentSmallestChild, child) > 0)
                                currentSmallestChild = child;
                            ++child;
                        }

                        child = currentSmallestChild;
                        if (Compare(child, ix) >= 0)
                            break;
                        Swap(ix, child);
                        ix = child;
                    }
                }

                private void ExtractKeys(TElement item, int ix)
                {
                    // Store the item at index ix, and call the ExtractKey method
                    // of each ordering criterion to store that item's keys at index ix.
                    _items[ix] = item;
                    foreach (var t in _criteria)
                    {
                        t.ExtractKey(item, ix);
                    }
                }

                private int Compare(int x, int y)
                {
                    // Walks the list of comparers, doing the comparisons.
                    // The first unequal comparison short-circuits the loop.
                    var rslt = 0;
                    for (int i = 0; rslt == 0 && i < _criteria.Length; ++i)
                    {
                        rslt = _criteria[i].CompareKeys(x, y);
                    }
                    return rslt;
                }

                // Swap two items. This swaps the elements in the local array,
                // and calls the SwapKeys method for each of the ordering criteria.
                private void Swap(int x, int y)
                {
                    var temp = _items[x];
                    _items[x] = _items[y];
                    _items[y] = temp;
                    foreach (var t in _criteria)
                    {
                        t.SwapKeys(x, y);
                    }
                }
            }
        }
    }