Scala mapWhile - map with a conditional predicate

October 1, 2013

Earlier today I needed to apply a map function to a list for as long as a predicate was satisfied; once the predicate was false, I needed the map function to bail. I was somewhat surprised to find that Scala has no built-in method for such a task.

The easiest method would be to use takeWhile and map in combination:

val ids = (1 to 100).toList
def f(x: Int): Profile = DB.fetch[Profile](x)
def p(x: Profile): Boolean = !x.isArchived
val profiles = ids.takeWhile(id => p(f(id))).map(f)
// Spoiler - we are fetching the profile twice...

This does not work well when the predicate depends on the result of the map function and that function is expensive: every element that passes the predicate is computed twice, once inside takeWhile and again inside map. In my case, I need to iterate over a list of profile ids, retrieve each profile from the database, and test it against a condition, stopping the iteration as soon as the predicate fails. Instead of returning profile ids, as takeWhile would, I need to return the profiles themselves.
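To make the double fetch concrete, here is a minimal sketch using hypothetical stand-ins rather than the real DB.fetch: a Profile case class, an in-memory fetch that counts its calls, and the assumption that profiles with id 5 and above are archived.

```scala
// Hypothetical stand-in for the post's Profile type.
case class Profile(id: Int, isArchived: Boolean)

object DoubleFetchDemo {
  var fetchCount = 0 // how many times the "expensive" fetch has run

  // Hypothetical replacement for DB.fetch: profiles 5 and above are archived.
  def fetch(id: Int): Profile = {
    fetchCount += 1
    Profile(id, isArchived = id >= 5)
  }

  def p(x: Profile): Boolean = !x.isArchived

  val ids = (1 to 10).toList

  // takeWhile fetches ids 1..5 (5 fails the predicate),
  // then map fetches ids 1..4 a second time.
  val profiles: List[Profile] = ids.takeWhile(id => p(fetch(id))).map(fetch)

  def main(args: Array[String]): Unit =
    println(s"${profiles.size} profiles, $fetchCount fetches")
}
```

Running it prints `4 profiles, 9 fetches`: four profiles are kept, yet nine fetches were made.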

List.takeWhile lets me test each fetched profile against the predicate; however, it returns the original profile ids rather than the fetched profile data I need. I came up with two solutions: one using a while loop and one using recursion (written to be tail recursive and annotated with @tailrec, so the compiler turns it into a loop).
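For comparison, the double evaluation can also be avoided with standard-library calls alone by making the map lazy: mapping over an Iterator before calling takeWhile computes each value once, and the first failing value is computed but discarded. The sketch below uses a hypothetical cheap f (squaring, with a call counter) in place of the database fetch.

```scala
object LazyMapWhileDemo {
  var calls = 0 // how many times f has run

  // Hypothetical stand-in for an expensive computation.
  def f(x: Int): Int = { calls += 1; x * x }

  def p(y: Int): Boolean = y < 10

  // Iterator.map is lazy, so takeWhile pulls one mapped value at a time;
  // f runs once per element up to and including the first failure.
  val result: List[Int] = List(1, 2, 3, 4, 5).iterator.map(f).takeWhile(p).toList

  def main(args: Array[String]): Unit = println(s"$result after $calls calls")
}
```

It should print `List(1, 4, 9) after 4 calls`; the element 5 is never mapped.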

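import scala.annotation.tailrec
import scala.collection.mutable.ListBuffer
import scala.language.implicitConversions

class MapWhileList[A](self: List[A]) {
  // Loop version: map elements with f until the first mapped value fails p.
  // f runs once per visited element (including the first failing one).
  def mapWhile[B](f: A => B, p: B => Boolean): List[B] = {
    val b = new ListBuffer[B]
    var these = self
    // headOption avoids calling head on an empty list
    var a: Option[B] = these.headOption.map(f)
    while (a.exists(p)) {
      b += a.get
      these = these.tail
      a = these.headOption.map(f)
    }
    b.toList
  }

  // Recursive version: @tailrec makes the compiler check that the recursive
  // call is in tail position and compile it into a loop.
  def mapWhileRec[B](f: A => B, p: B => Boolean): List[B] = {
    @tailrec
    def loop(xs: List[A], acc: List[B]): List[B] = {
      if (xs.isEmpty) acc
      else {
        val b = f(xs.head)
        if (!p(b)) acc
        else loop(xs.tail, b :: acc)
      }
    }

    // acc is built in reverse order, so flip it once at the end
    loop(self, Nil).reverse
  }
}

implicit def mapWhileList[A](self: List[A]): MapWhileList[A] = new MapWhileList(self)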

// Usage (slightly contrived)
def f(x: Int): Profile = DB.fetch[Profile](x) // expensive computation
def p(x: Profile): Boolean = !x.isArchived // We stop when we encounter an archived Profile

val ids = (1 to 100).toList // list of profile IDs
val mw1 = ids.mapWhile(f, p)
val mw2 = ids.mapWhileRec(f, p)
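Since Scala 2.10, the wrapper class and the implicit def can be merged into a single implicit class. A possible variant is sketched below, assuming Scala 2.10 or later; the names MapWhileSyntax and MapWhileOps are my own, and unlike the version above, f and p sit in separate parameter lists so that B is inferred from f and lambdas for p need no type annotation.

```scala
import scala.annotation.tailrec

object MapWhileSyntax {
  // Implicit class (Scala 2.10+) replacing the wrapper class + implicit def.
  // Extending AnyVal makes it a value class, usually avoiding an allocation.
  implicit class MapWhileOps[A](val self: List[A]) extends AnyVal {
    // Curried: B is inferred from f before p is type-checked.
    def mapWhile[B](f: A => B)(p: B => Boolean): List[B] = {
      @tailrec
      def loop(xs: List[A], acc: List[B]): List[B] = xs match {
        case head :: tail =>
          val b = f(head) // f runs once per visited element
          if (p(b)) loop(tail, b :: acc) else acc
        case Nil => acc
      }
      loop(self, Nil).reverse
    }
  }
}

object MapWhileSyntaxDemo {
  import MapWhileSyntax._

  val squares: List[Int] = List(1, 2, 3, 4, 1).mapWhile(x => x * x)(_ < 10)
  val empty: List[Int] = List.empty[Int].mapWhile(x => x * x)(_ < 10)

  def main(args: Array[String]): Unit = println(squares)
}
```

Here squares is `List(1, 4, 9)`: mapping stops at 16, so the trailing 1 is never mapped, and an empty list simply yields an empty result.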


I am a software developer from Winnipeg with a passion for learning about new things on a regular basis.
If you want to get in touch, email me.