2013-12-31

How to implement in-place set intersection in C++

This blog post shows and explains C++ source code implementing in-place set intersection, i.e. removing each element from a set (or another sorted container) *bc which is not also a member of container ac.

The std::set_intersection function template in the <algorithm> header of the C++ standard library writes all elements of the intersection to a new output (e.g. a new set, through an output iterator). This can be too slow and a waste of memory if one of the inputs is not needed afterwards. In this case an in-place intersection is desired instead, but unfortunately such a function template is not part of the C++ standard library.
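For comparison, here is a minimal sketch of the out-of-place approach with std::set_intersection. The helper name IntersectionCopy is made up for this illustration, and it assumes that both inputs are std::set with the default ordering:

```cpp
#include <algorithm>  // std::set_intersection.
#include <iterator>   // std::inserter.
#include <set>

// Hypothetical helper (not part of the standard library): returns a new set
// containing the elements present in both ac and bc. Neither input is
// modified, so the result needs extra memory.
template<typename T>
std::set<T> IntersectionCopy(const std::set<T> &ac, const std::set<T> &bc) {
  std::set<T> result;
  std::set_intersection(ac.begin(), ac.end(), bc.begin(), bc.end(),
                        std::inserter(result, result.end()));
  return result;
}
```

The in-place versions below avoid building the result container.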

Here is a simple in-place implementation which looks up each element of *bc in ac, and removes (erases) it from *bc if not found:

#include <set>

// Remove elements from bc which are missing from ac.
//
// The time required is proportional to log(ac.size()) * bc->size(), so it's
// faster than IntersectionUpdate if ac is large compared to bc.
template<typename Input, typename Output>
static void IntersectionUpdateLargeAc(
    const std::set<Input> &ac, std::set<Output> *bc) {
  const typename std::set<Input >::const_iterator a_end = ac.end();
  const typename std::set<Output>::const_iterator b_end = bc->end();
  for (typename std::set<Output>::iterator b = bc->begin(); b != b_end; ) {
    if (ac.find(*b) == a_end) {  // Not found.
      // Not a const_iterator, erase wouldn't accept it until C++11.
      const typename std::set<Output>::iterator b_old = b++;
      bc->erase(b_old);  // erase doesn't invalidate b.
    } else {
      ++b;
    }
  }
}

Removing from *bc above is a bit tricky, because we don't want to invalidate the iterator b. Since C++11, erase returns an iterator to the element just after the removed one, but we don't use it, to remain compatible with older compilers. Instead, we make use of the fact that erase keeps iterators to the non-removed elements valid for set, multiset and list: we copy b to the temporary iterator b_old and advance b before the erase, so only b_old is invalidated, and b remains valid.
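If compatibility with compilers older than C++11 is not needed, the same loop can rely on the iterator returned by erase. A minimal sketch under that assumption; the name IntersectionUpdateLargeAcCxx11 is made up for this illustration:

```cpp
#include <set>

// Remove elements from bc which are missing from ac. Needs C++11 or later,
// because it uses the return value of std::set::erase.
template<typename T>
void IntersectionUpdateLargeAcCxx11(const std::set<T> &ac, std::set<T> *bc) {
  for (typename std::set<T>::iterator b = bc->begin(); b != bc->end(); ) {
    if (ac.find(*b) == ac.end()) {  // Not found.
      b = bc->erase(b);  // Returns the iterator following the erased element.
    } else {
      ++b;
    }
  }
}
```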

We need the typename keyword in the local variable declarations, because they have a dependent type (i.e. a type nested within another type which depends on a template parameter).

The time complexity is O(log(as) · bs), where as = ac.size() and bs = bc->size(), so it is fast if ac is large when compared to *bc. For example, when as = 3^k and bs = k, then it's O(k^2).

As an alternative, we could iterate over the two sets in increasing (ascending) order at the same time, similarly to the merge operation (as implemented by std::merge, and as used in mergesort), but dropping elements from *bc if there is no corresponding element in ac. One possible implementation:

#include <set>

// Remove elements from bc which are missing from ac.
//
// The time required is proportional to ac.size() + bc->size().
template<typename Input, typename Output>
static void IntersectionUpdate(
    const std::set<Input> &ac, std::set<Output> *bc) {
  typename std::set<Input>::const_iterator a = ac.begin();
  const typename std::set<Input>::const_iterator a_end = ac.end();
  typename std::set<Output>::iterator b = bc->begin();
  const typename std::set<Output>::iterator b_end = bc->end();
  while (a != a_end && b != b_end) {
    if (*a < *b) {
      ++a;
    } else if (*a > *b) {
      const typename std::set<Output>::iterator b_old = b++;
      bc->erase(b_old);  // erase doesn't invalidate b.
    } else {  // Elements are equal, keep them in the intersection.
      ++a;
      ++b;
    }
  }
  bc->erase(b, b_end);  // Remove remaining elements in bc but not in ac.
}

The time complexity of the implementation above (IntersectionUpdate) is O(as + bs), which is faster than IntersectionUpdateLargeAc unless ac is much larger than *bc. For example, when as = 3^k and bs = k, then it's O(3^k + k), whereas IntersectionUpdateLargeAc needs only O(k^2), so IntersectionUpdateLargeAc is faster.
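The trade-off between the two estimates can be turned into a rough rule of thumb. The sketch below only compares log2(as) · bs with as + bs and ignores constant factors, which matter in practice, so treat it as a heuristic. The function name PreferLargeAc is hypothetical:

```cpp
#include <cmath>    // std::log2 (C++11).
#include <cstddef>  // std::size_t.

// Hypothetical helper: returns true if the estimated cost of the lookup-based
// IntersectionUpdateLargeAc, log2(as) * bs, is smaller than the estimated cost
// of the merge-like IntersectionUpdate, as + bs. Constant factors are ignored.
inline bool PreferLargeAc(std::size_t as, std::size_t bs) {
  if (as < 2) return false;  // The logarithm is not meaningful for a tiny ac.
  const double lookup_cost =
      std::log2(static_cast<double>(as)) * static_cast<double>(bs);
  const double scan_cost = static_cast<double>(as) + static_cast<double>(bs);
  return lookup_cost < scan_cost;
}
```

With as = 3^10 = 59049 and bs = 10 the lookup estimate is about 158 against 59059 for the scan, so the helper prefers IntersectionUpdateLargeAc. With as = bs = 1000 it prefers IntersectionUpdate.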

Example usage of both (just to see if they compile):

int main(int, char**) {
  std::set<int> a, b;
  IntersectionUpdateLargeAc(a, &b);
  IntersectionUpdate(a, &b);
  return 0;
}

It's natural to ask if these function templates can be generalized to C++ containers other than set. They take advantage of the input being sorted, so let's consider sorted std::vector, sorted std::list and std::multiset in addition to std::set. To avoid the complexity of having to distinguish keys from values, let's ignore std::map and std::multimap.

The generalization of IntersectionUpdateLargeAc from set to multiset is trivial: no code change is necessary. The std::multiset::find operation returns any matching element, which is good for us. However, with IntersectionUpdate, the last ++a; must be removed: without the removal subsequent occurrences of the same value in *bc would be removed if ac contains this value only once. No other code change is needed. It is tempting to introduce a loop in the previous (*a > *b) if branch:

for (;;) {
  const typename Output::iterator b_old = b++;
  const bool do_break = b == b_end || *b_old != *b;
  bc->erase(b_old);  // erase doesn't invalidate b.
  if (do_break) break;
}

However, this change is not necessary, because subsequent equal values in *bc would be removed in subsequent iterations of the outer loop.

Here are the full generalized implementations:

#if __cplusplus >= 201103 || __GXX_EXPERIMENTAL_CXX0X__
#include <type_traits>
#endif

// Remove elements from bc which are missing from ac. Supported containers for
// bc: list (only if sorted), set, multiset. A nonempty vector is not safe
// here, because std::vector::erase invalidates b and b_end. Supported
// containers for ac: set, multiset.
//
// The time required is proportional to log(ac.size()) * bc->size(), so it's
// faster than IntersectionUpdate if ac is large compared to bc.
template<typename Input, typename Output>
static void IntersectionUpdateLargeAc(const Input &ac, Output *bc) {
#if __cplusplus >= 201103 || __GXX_EXPERIMENTAL_CXX0X__
  // We could use std::is_convertible (both ways) instead of std::is_same.
  static_assert(std::is_same<typename Input::value_type,
                             typename Output::value_type>::value,
                "the containers passed to IntersectionUpdateLargeAc() need to "
                "have the same value_type");
#endif
  const typename Input::const_iterator a_end = ac.end();
  const typename Output::const_iterator b_end = bc->end();
  for (typename Output::iterator b = bc->begin(); b != b_end; ) {
    if (ac.find(*b) == a_end) {  // Not found.
      // Not a const_iterator, erase wouldn't accept it until C++11.
      const typename Output::iterator b_old = b++;
      bc->erase(b_old);  // erase doesn't invalidate b.
    } else {
      ++b;
    }
  }
}

// Remove elements from bc which are missing from ac. Supported containers for 
// ac and bc: list (only if sorted), vector (only if sorted), set, multiset.
template<typename Input, typename Output>
static void IntersectionUpdate(const Input &ac, Output *bc) {
#if __cplusplus >= 201103 || __GXX_EXPERIMENTAL_CXX0X__
  static_assert(std::is_same<typename Input::value_type,
                             typename Output::value_type>::value,
                "the containers passed to IntersectionUpdate() need to "
                "have the same value_type");
#endif
  typename Input::const_iterator a = ac.begin();
  const typename Input::const_iterator a_end = ac.end();
  typename Output::iterator b = bc->begin();
  // Can't be a const_iterator, similarly to b_old.
  const typename Output::iterator b_end = bc->end();
  while (a != a_end && b != b_end) {
    if (*a < *b) {
      ++a;
    } else if (*a > *b) {
      const typename Output::iterator b_old = b++;
      bc->erase(b_old);  // erase doesn't invalidate b.
    } else {  // Elements are equal, keep it in the intersection.
      // Don't do ++a, in case ac is a multiset.
      ++b;
    }
  }
  bc->erase(b, b_end);  // Remove remaining elements in bc but not in ac.
}

These work as expected for set, multiset and sorted list. They also don't require that the two containers are of the same kind. For C++0x and C++11, an extra static_assert is present in the code to print a helpful, compact error message if the value types of the two containers are different.
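As a quick check of the multiset behavior, here is a small self-contained example. IntersectionUpdateDemo is a condensed copy of the generic IntersectionUpdate (renamed, without the static_assert), repeated only so that the example builds on its own. MultisetExample is a hypothetical helper:

```cpp
#include <set>

// Condensed copy of the generic IntersectionUpdate, for illustration only.
template<typename Input, typename Output>
void IntersectionUpdateDemo(const Input &ac, Output *bc) {
  typename Input::const_iterator a = ac.begin();
  typename Output::iterator b = bc->begin();
  while (a != ac.end() && b != bc->end()) {
    if (*a < *b) {
      ++a;
    } else if (*b < *a) {
      typename Output::iterator b_old = b++;
      bc->erase(b_old);  // Safe for set, multiset and list: b stays valid.
    } else {
      ++b;  // Keep *b. Don't advance a, so repeated values in bc are kept.
    }
  }
  bc->erase(b, bc->end());  // Remove remaining elements in bc but not in ac.
}

// Intersects the multiset {1, 1, 2, 3, 3} with the set {1, 3}.
inline std::multiset<int> MultisetExample() {
  const std::set<int> ac = {1, 3};
  std::multiset<int> bc = {1, 1, 2, 3, 3};
  IntersectionUpdateDemo(ac, &bc);
  return bc;  // Only the value 2 has been removed.
}
```

Both copies of 1 and both copies of 3 survive, even though ac contains each value only once.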

However, when *bc is a vector, the generic code above is incorrect: std::vector::erase invalidates every iterator at or after the erased position, so both b and b_end become invalid after the first erase. This could be fixed by continuing from the iterator returned by std::vector::erase, and by using bc->end() instead of b_end everywhere. However, if we didn't make any other changes, the algorithm would be slower than necessary, because std::vector::erase moves each element behind the erased one, so the time complexity would be O(as + bs^2). To speed it up, let's not erase anything inside the loop: we copy each element to keep toward the front of the vector (the same idea as in std::remove_if), which preserves the sorted order, and we do the actual removal once, at the end of the function:

#include <vector>
#if __cplusplus >= 201103 || __GXX_EXPERIMENTAL_CXX0X__
#include <type_traits>
#endif

// Overload for vector output (more specialized than the generic template).
template<typename Input, typename T>
static void IntersectionUpdate(const Input &ac, std::vector<T> *bc) {
#if __cplusplus >= 201103 || __GXX_EXPERIMENTAL_CXX0X__
  static_assert(std::is_same<typename Input::value_type, T>::value,
                "the containers passed to IntersectionUpdate() need to "
                "have the same value_type");
#endif
  typename Input::const_iterator a = ac.begin();
  const typename Input::const_iterator a_end = ac.end();
  typename std::vector<T>::iterator b = bc->begin();
  // Stays valid, because nothing is erased or inserted inside the loop.
  const typename std::vector<T>::iterator b_end = bc->end();
  // Elements to keep are collected between bc->begin() and b_out, in their
  // original order. Elements between b_out and bc->end() will be removed
  // (erased) right before the function returns.
  typename std::vector<T>::iterator b_out = bc->begin();
  while (a != a_end && b != b_end) {
    if (*a < *b) {
      ++a;
    } else if (*a > *b) {
      ++b;  // Not in ac: skip it, it will be overwritten or erased later.
    } else {  // Elements are equal, keep *b in the intersection.
      if (b_out != b) *b_out = *b;
      ++b_out;
      // Don't do ++a, in case bc contains this value more than once.
      ++b;
    }
  }
  bc->erase(b_out, bc->end());  // Remove everything which was not kept.
}

Once we have the generic implementation and the overload for vector in the same file, the C++ compiler takes care of choosing the right (most specialized) one, depending on whether *bc is a vector or not. So all these compile now:

#include <list>
#include <set>
#include <vector>

int main(int, char**) {
  std::set<int> s;
  std::multiset<int> ms;
  std::vector<int> v;
  // std::list<unsigned> l;  // Won't work in C++0x and C++11.
  std::list<int> l;
  IntersectionUpdate(s, &ms);
  IntersectionUpdate(ms, &v);
  IntersectionUpdate(v, &l);
  IntersectionUpdate(l, &s);
  IntersectionUpdateLargeAc(s, &ms);
  IntersectionUpdateLargeAc(ms, &v);  // Safe only because v is empty.
  // IntersectionUpdateLargeAc(v, &l);  // v is not good as ac.
  // IntersectionUpdateLargeAc(l, &s);  // l is not good as ac.
  IntersectionUpdateLargeAc(s, &l);
  IntersectionUpdateLargeAc(ms, &s);
  return 0;
}
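The generic IntersectionUpdateLargeAc erases inside its loop, which is not safe for a nonempty vector, because std::vector::erase invalidates the iterators at and after the erased position. Below is a sketch of an overload for vector output which uses lookups in ac and a single erase at the end. It is not part of the original source code, so treat it as an assumption about how the gap could be filled:

```cpp
#include <set>  // For the typical ac argument (set or multiset).
#include <vector>

// Sketch of an overload for vector output. Remove elements from bc which are
// missing from ac. ac needs a find() method (e.g. set, multiset). The time
// required is proportional to log(ac.size()) * bc->size().
template<typename Input, typename T>
void IntersectionUpdateLargeAc(const Input &ac, std::vector<T> *bc) {
  // Elements to keep are collected between bc->begin() and b_out.
  typename std::vector<T>::iterator b_out = bc->begin();
  for (typename std::vector<T>::iterator b = bc->begin();
       b != bc->end(); ++b) {
    if (ac.find(*b) != ac.end()) {  // Found in ac, keep it.
      if (b_out != b) *b_out = *b;
      ++b_out;
    }
  }
  bc->erase(b_out, bc->end());  // Remove everything which was not kept.
}
```

Since each element is looked up independently, this overload doesn't need *bc to be sorted, and it preserves the order of the kept elements.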

The full source code is available on GitHub.
