The MergeWithBy LINQ extension

Last time, I showed how to write the standard two-way merge using LINQ. With that code, you can easily merge two sorted sequences from any type of container that implements the IEnumerable<T> interface. That in itself is very useful, and with the addition of a parameter that defines the comparison function, it would likely suffice for most merge operations. However, it’s less than satisfactory for the same reasons I pointed out in my introduction to the TopBy method earlier this year: sometimes creating an IComparer is a giant pain in the neck. To make a merge that works like the rest of LINQ, with key selectors and ThenBy, I need to implement IOrderedEnumerable<T>.
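Implementing IOrderedEnumerable<T> is what makes ThenBy work. Enumerable.ThenBy and ThenByDescending don't sort anything themselves; they call CreateOrderedEnumerable on their source, passing null for the comparer when none is supplied and false or true for descending. The sketch below uses a hypothetical RecordingOrderedEnumerable<T> (not part of the merge code) that performs no ordering at all; it only records each call and links the new instance to its parent, which is the same chaining idea the merge implementation relies on.

```csharp
using System;
using System.Collections;
using System.Collections.Generic;
using System.Linq;

// Hypothetical illustration: performs no ordering. Each instance stores a
// description of one ordering criterion and a link to the previous one.
internal class RecordingOrderedEnumerable<T> : IOrderedEnumerable<T>
{
    private readonly IEnumerable<T> _source;
    private readonly RecordingOrderedEnumerable<T> _parent;
    private readonly string _description;

    public RecordingOrderedEnumerable(
        IEnumerable<T> source, RecordingOrderedEnumerable<T> parent, string description)
    {
        _source = source;
        _parent = parent;
        _description = description;
    }

    // LINQ's ThenBy/ThenByDescending end up here.
    public IOrderedEnumerable<T> CreateOrderedEnumerable<TKey>(
        Func<T, TKey> keySelector, IComparer<TKey> comparer, bool descending)
    {
        var description = typeof(TKey).Name + (descending ? " desc" : " asc");
        return new RecordingOrderedEnumerable<T>(_source, this, description);
    }

    // Enumeration passes the source through unchanged.
    public IEnumerator<T> GetEnumerator() => _source.GetEnumerator();
    IEnumerator IEnumerable.GetEnumerator() => GetEnumerator();

    // Lists the recorded criteria from first to last.
    public List<string> Chain()
    {
        var result = new List<string>();
        for (var node = this; node != null; node = node._parent)
        {
            result.Insert(0, node._description);
        }
        return result;
    }
}
```

Calling `root.ThenBy(s => s.Length).ThenByDescending(s => s)` on an instance described as "root" yields a chain of "root", "Int32 asc", "String desc": one object per criterion, each pointing back to the one before it.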

I’ve struggled with what to call the method. Obvious candidates are Merge, MergeBy, MergeWith, and MergeWithBy, but each of those has its drawbacks. Consider, for example, the possibilities:

    var merged = List1.Merge(List2, k => k.Age).ThenBy(k => k.LastName);
    var merged = List1.MergeBy(List2, k => k.Age).ThenBy(k => k.LastName);
    var merged = List1.MergeWith(List2, k => k.Age).ThenBy(k => k.LastName);
    var merged = List1.MergeWithBy(List2, k => k.Age).ThenBy(k => k.LastName);

I’m not 100% happy with any of those options, but after struggling with it for a while I decided on MergeWithBy, which, although a little clunky to the ear (MergeWithByDescending in particular), is the most descriptive: I’m merging one list with another by the specified ordering criteria.

With the naming out of the way, the easy part is defining the LINQ extension methods. These exactly mirror the standard OrderBy methods as well as my TopBy methods.

    public static class MergeExtension
    {
        public static IOrderedEnumerable<TSource> MergeWithBy<TSource, TKey>(
            this IEnumerable<TSource> list1,
            IEnumerable<TSource> list2,
            Func<TSource, TKey> keySelector)
        {
            return MergeWithBy(list1, list2, keySelector, null);
        }

        public static IOrderedEnumerable<TSource> MergeWithBy<TSource, TKey>(
            this IEnumerable<TSource> list1,
            IEnumerable<TSource> list2,
            Func<TSource, TKey> keySelector,
            IComparer<TKey> comparer)
        {
            return new MergeOrderedEnumerable<TSource, TKey>(list1, list2, keySelector, comparer, false);
        }

        public static IOrderedEnumerable<TSource> MergeWithByDescending<TSource, TKey>(
            this IEnumerable<TSource> list1,
            IEnumerable<TSource> list2,
            Func<TSource, TKey> keySelector)
        {
            return MergeWithByDescending(list1, list2, keySelector, null);
        }

        public static IOrderedEnumerable<TSource> MergeWithByDescending<TSource, TKey>(
            this IEnumerable<TSource> list1,
            IEnumerable<TSource> list2,
            Func<TSource, TKey> keySelector,
            IComparer<TKey> comparer)
        {
            return new MergeOrderedEnumerable<TSource, TKey>(list1, list2, keySelector, comparer, true);
        }
    }

The harder part is implementing the IOrderedEnumerable<T> interface that actually does the work. The idea behind it, as I described in my two-part series on IOrderedEnumerable, is to create, for each ordering criterion, a class instance that can extract and compare the relevant key. There is a base class, MergeOrderedEnumerable<TSource>, and a derived class for each key type: MergeOrderedEnumerable<TSource, TKey>.

I covered the details of how those two classes work in both parts of that series, linked above. For the most part, those two classes are just bookkeeping: they set things up so that a third class, the selector, can do the actual work. The full source of the merge extension class is shown at the end of this post.
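Part of that bookkeeping is getting the criteria into priority order. Each ThenBy creates a new instance whose Parent points to the previous one, so the object you actually enumerate holds the last criterion. GetCriteria, in the full listing at the end of this post, walks the Parent links and pushes each instance onto a Stack; because a Stack enumerates in last-in, first-out order, the primary key comes out first. Criterion and CriteriaOrder below are hypothetical stand-ins that isolate just that walk.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Hypothetical stand-in for an ordering criterion: a name plus the Parent
// link that CreateOrderedEnumerable sets up.
internal class Criterion
{
    public string Name;
    public Criterion Parent;
}

internal static class CriteriaOrder
{
    // Same walk as GetCriteria: start at the last criterion added and follow
    // Parent links back to the first. Stack<T> returns items in LIFO order,
    // so the primary criterion (pushed last) ends up at index 0.
    public static Criterion[] InPriorityOrder(Criterion last)
    {
        var keys = new Stack<Criterion>();
        for (var current = last; current != null; current = current.Parent)
        {
            keys.Push(current);
        }
        return keys.ToArray();
    }
}
```

For a chain built as `OrderBy(Age).ThenBy(LastName).ThenBy(FirstName)`, the walk starts at FirstName but returns Age, LastName, FirstName, which is exactly the order the selector's Compare loop needs.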

The guts of the merge (the part that implements the Merge method I showed last time) is the MergeSelector.DoSelect method, shown here:

    public IEnumerable<TSource> DoSelect()
    {
        // Initialize the iterators
        var iterators = new IEnumerator<TSource>[2];

        var next = new Func<int, bool>(ix =>
        {
            if (!iterators[ix].MoveNext()) return false;
            ExtractKeys(iterators[ix].Current, ix);
            return true;
        });

        iterators[0] = _list1.GetEnumerator();
        iterators[1] = _list2.GetEnumerator();
        var i1HasItems = next(0);
        var i2HasItems = next(1);
        while (i1HasItems && i2HasItems)
        {
            if (Compare(0, 1) <= 0)
            {
                yield return iterators[0].Current;
                i1HasItems = next(0);
            }
            else
            {
                yield return iterators[1].Current;
                i2HasItems = next(1);
            }
        }

        while (i1HasItems)
        {
            yield return iterators[0].Current;
            i1HasItems = next(0);
        }

        while (i2HasItems)
        {
            yield return iterators[1].Current;
            i2HasItems = next(1);
        }
    }

This code is nearly a direct translation of the Merge I showed last time. The main structural difference is that I put the iterators in an array to reduce duplicated code. The next function advances the specified iterator and, if it produced an item, extracts that item’s keys. Because the keys are maintained by the individual MergeOrderedEnumerable<TSource, TKey> instances, without that function I’d have to write code like this four times (twice for i1 and twice for i2):

    yield return i1.Current;
    i1HasItems = i1.MoveNext();
    if (i1HasItems)
    {
        ExtractKeys(i1.Current, 0);
    }

I suppose I could have combined the last two statements:

    if ((i1HasItems = i1.MoveNext())) ExtractKeys(i1.Current, 0);

Call it a matter of style.
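If you want to single-step the same loop without the surrounding classes, here is a stripped-down, single-key sketch. MergeSketch.MergeByKey is a hypothetical name, not part of the extension. It keeps the iterator array and the next helper, caches one key per sequence, and adds one thing the listing's DoSelect doesn't have: it disposes both enumerators through using blocks, which matters when a source holds a resource such as an open file.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

internal static class MergeSketch
{
    // Hypothetical single-key version of the DoSelect loop, for experimenting.
    // Each item's key is computed once, when its iterator advances, and cached
    // in keys[ix]; comparisons only read the cache.
    public static IEnumerable<T> MergeByKey<T, TKey>(
        IEnumerable<T> first,
        IEnumerable<T> second,
        Func<T, TKey> keySelector,
        IComparer<TKey> comparer = null)
    {
        comparer = comparer ?? Comparer<TKey>.Default;
        var keys = new TKey[2];

        // Addition not in DoSelect: the using blocks dispose both enumerators
        // when the merge finishes or the caller stops enumerating early.
        using (var e0 = first.GetEnumerator())
        using (var e1 = second.GetEnumerator())
        {
            var iterators = new[] { e0, e1 };

            // Advance one iterator; if it produced an item, cache its key.
            Func<int, bool> next = ix =>
            {
                if (!iterators[ix].MoveNext()) return false;
                keys[ix] = keySelector(iterators[ix].Current);
                return true;
            };

            var has0 = next(0);
            var has1 = next(1);
            while (has0 && has1)
            {
                // "<= 0" takes ties from the first sequence, so items with
                // equal keys come out first-sequence items first.
                if (comparer.Compare(keys[0], keys[1]) <= 0)
                {
                    yield return iterators[0].Current;
                    has0 = next(0);
                }
                else
                {
                    yield return iterators[1].Current;
                    has1 = next(1);
                }
            }

            // Copy whatever remains in either sequence.
            while (has0) { yield return iterators[0].Current; has0 = next(0); }
            while (has1) { yield return iterators[1].Current; has1 = next(1); }
        }
    }
}
```

Merging "1L", "3L", "5L" with "2R", "3R", "6R" by first character yields 1L, 2R, 3L, 3R, 5L, 6R, and the key selector runs six times, once per item, no matter how many comparisons are made.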

The MergeOrderedEnumerable classes look unnecessarily complex at first glance, but after you study them for a few minutes, you’ll see why all that code is there. It’s particularly instructive to set up a short merge of a few items and then single-step the code. Not only will you see how all this stuff works together, you’ll gain a better understanding of how LINQ works in general.

That’s almost all there is to merging two sequences. The only thing remaining is the matter of uniqueness, which I’ll talk about next time in Removing duplicates.

    using System;
    using System.Collections;
    using System.Collections.Generic;
    using System.Linq;

    namespace EnumerableExtensions
    {
        public static class MergeExtension
        {
            public static IOrderedEnumerable<TSource> MergeWithBy<TSource, TKey>(
                this IEnumerable<TSource> list1,
                IEnumerable<TSource> list2,
                Func<TSource, TKey> keySelector)
            {
                return MergeWithBy(list1, list2, keySelector, null);
            }

            public static IOrderedEnumerable<TSource> MergeWithBy<TSource, TKey>(
                this IEnumerable<TSource> list1,
                IEnumerable<TSource> list2,
                Func<TSource, TKey> keySelector,
                IComparer<TKey> comparer)
            {
                return new MergeOrderedEnumerable<TSource, TKey>(list1, list2, keySelector, comparer, false);
            }

            public static IOrderedEnumerable<TSource> MergeWithByDescending<TSource, TKey>(
                this IEnumerable<TSource> list1,
                IEnumerable<TSource> list2,
                Func<TSource, TKey> keySelector)
            {
                return MergeWithByDescending(list1, list2, keySelector, null);
            }

            public static IOrderedEnumerable<TSource> MergeWithByDescending<TSource, TKey>(
                this IEnumerable<TSource> list1,
                IEnumerable<TSource> list2,
                Func<TSource, TKey> keySelector,
                IComparer<TKey> comparer)
            {
                return new MergeOrderedEnumerable<TSource, TKey>(list1, list2, keySelector, comparer, true);
            }
        }

        internal abstract class MergeOrderedEnumerable<TSource> : IOrderedEnumerable<TSource>
        {
            private readonly IEnumerable<TSource> _list1;
            private readonly IEnumerable<TSource> _list2;
            internal MergeOrderedEnumerable<TSource> Parent;

            protected MergeOrderedEnumerable(
                IEnumerable<TSource> list1,
                IEnumerable<TSource> list2)
            {
                _list1 = list1;
                _list2 = list2;
            }

            public IOrderedEnumerable<TSource> CreateOrderedEnumerable<TKey>(
                Func<TSource, TKey> keySelector,
                IComparer<TKey> comparer,
                bool @descending)
            {
                var oe = new MergeOrderedEnumerable<TSource, TKey>(
                    _list1, _list2, keySelector, comparer, descending) {Parent = this};
                return oe;
            }

            public IEnumerator<TSource> GetEnumerator()
            {
                var criteria = GetCriteria().ToArray();

                var selector = new MergeSelector<TSource>(_list1, _list2, criteria);
                return selector.DoSelect().GetEnumerator();
            }

            // Walks the Parent chain to list the ordering criteria, primary key first, for the MergeSelector.
            private IEnumerable<MergeOrderedEnumerable<TSource>> GetCriteria()
            {
                var keys = new Stack<MergeOrderedEnumerable<TSource>>();

                var current = this;
                while (current != null)
                {
                    keys.Push(current);
                    current = current.Parent;
                }
                return keys;
            }

            IEnumerator IEnumerable.GetEnumerator()
            {
                return GetEnumerator();
            }

            // The individual ordering criteria instances (MergeOrderedEnumerable<TSource, TKey>)
            // implement these abstract methods to provide key extraction and comparison.
            public abstract void ExtractKey(TSource item, int ix);
            public abstract int CompareKeys(int x, int y);
        }

        internal class MergeOrderedEnumerable<TSource, TKey> : MergeOrderedEnumerable<TSource>
        {
            private readonly Func<TSource, TKey> _keySelector;
            private readonly IComparer<TKey> _comparer;
            private readonly bool _descending;
            private readonly TKey[] _keys;

            internal MergeOrderedEnumerable(
                IEnumerable<TSource> list1,
                IEnumerable<TSource> list2,
                Func<TSource, TKey> keySelector,
                IComparer<TKey> comparer,
                bool descending)
                : base(list1, list2)
            {
                _keySelector = keySelector;
                _comparer = comparer ?? Comparer<TKey>.Default;
                _descending = descending;

                _keys = new TKey[2];
            }

            public override int CompareKeys(int x, int y)
            {
                return _descending
                    ? _comparer.Compare(_keys[y], _keys[x])
                    : _comparer.Compare(_keys[x], _keys[y]);
            }

            public override void ExtractKey(TSource item, int ix)
            {
                _keys[ix] = _keySelector(item);
            }
        }

        internal class MergeSelector<TSource>
        {
            private readonly IEnumerable<TSource> _list1;
            private readonly IEnumerable<TSource> _list2;
            private readonly MergeOrderedEnumerable<TSource>[] _criteria;

            public MergeSelector(
                IEnumerable<TSource> list1,
                IEnumerable<TSource> list2,
                MergeOrderedEnumerable<TSource>[] criteria)
            {
                _list1 = list1;
                _list2 = list2;
                _criteria = criteria;
            }

            public IEnumerable<TSource> DoSelect()
            {
                // Initialize the iterators
                var iterators = new IEnumerator<TSource>[2];

                var next = new Func<int, bool>(ix =>
                {
                    if (!iterators[ix].MoveNext()) return false;
                    ExtractKeys(iterators[ix].Current, ix);
                    return true;
                });

                iterators[0] = _list1.GetEnumerator();
                iterators[1] = _list2.GetEnumerator();
                var i1HasItems = next(0);
                var i2HasItems = next(1);
                while (i1HasItems && i2HasItems)
                {
                    if (Compare(0, 1) <= 0)
                    {
                        yield return iterators[0].Current;
                        // I could do a loop here using ExtractCompare to compare against
                        // item 2. That would reduce the key extraction as long as the
                        // new item from list 1 is smaller than the new item from list 2.
                        // Then extract all of the keys once l1 goes high.
                        // Lots of code that might not be particularly useful.
                        i1HasItems = next(0);
                    }
                    else
                    {
                        yield return iterators[1].Current;
                        i2HasItems = next(1);
                    }
                }

                while (i1HasItems)
                {
                    yield return iterators[0].Current;
                    // TODO: Could add an "extract" parameter to the next function
                    // If "extract" is false, it doesn't extract the keys.
                    // That would speed up the tail copying,
                    // but might not be worth the trouble.
                    i1HasItems = next(0);
                }

                while (i2HasItems)
                {
                    yield return iterators[1].Current;
                    i2HasItems = next(1);
                }
            }

            private int Compare(int x, int y)
            {
                var rslt = 0;
                for (var i = 0; rslt == 0 && i < _criteria.Length; ++i)
                {
                    rslt = _criteria[i].CompareKeys(x, y);
                }
                return rslt;
            }

            private void ExtractKeys(TSource item, int ix)
            {
                foreach (var t in _criteria)
                {
                    t.ExtractKey(item, ix);
                }
            }
        }
    }
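One small design choice in CompareKeys is worth pointing out: descending order swaps the arguments to Compare rather than negating the result. IComparer<T>.Compare only promises the sign of its result, so a comparer may legitimately return int.MinValue, and negating that overflows back to int.MinValue. Whether or not that was the motivation, swapping sidesteps the problem. The hypothetical ExtremeComparer below shows the difference.

```csharp
using System;
using System.Collections.Generic;

internal static class DescendingDemo
{
    // Hypothetical comparer that is legal but extreme: Compare only has to
    // return something less than, equal to, or greater than zero.
    public sealed class ExtremeComparer : IComparer<int>
    {
        public int Compare(int x, int y)
        {
            return x < y ? int.MinValue : x > y ? int.MaxValue : 0;
        }
    }

    // Negating the result: -int.MinValue wraps to int.MinValue in an
    // unchecked context, so "less than" stays negative instead of flipping.
    public static int DescendingByNegation(IComparer<int> comparer, int x, int y)
    {
        return unchecked(-comparer.Compare(x, y));
    }

    // Swapping the arguments, as CompareKeys does, involves no arithmetic.
    public static int DescendingBySwap(IComparer<int> comparer, int x, int y)
    {
        return comparer.Compare(y, x);
    }
}
```

In descending order 1 should sort after 2, so the correct result for (1, 2) is positive: the swap returns int.MaxValue, while the negation returns int.MinValue and silently gets the order wrong.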