Speeding up Shortest Common Supersequence

I’m having a hard time passing the last group of tests of Kattis’ Nafnagift, which asks for the shortest common supersequence of two strings.

I implemented memoization following the advice in one of Brent Yorgey’s blog posts about automatic memoization. It noticeably sped up my solution, but it is still not fast enough.

Here is the full solution. What improvements would you make?

{-# LANGUAGE LambdaCase #-}

module Main where

import Control.Arrow ((>>>))
import Control.Monad (join)
import Data.Array (Array, Ix, listArray, (!))
import Data.Bifunctor (bimap)
import Data.Either (isLeft)
import Data.Foldable (maximumBy)
import Data.List (tails)
import Data.Ord (comparing)

main :: IO ()
main = interact $ lines >>> (\[m, n] -> solve (m, n))

solve :: (Eq a) => ([a], [a]) -> [a]
solve (m, n) = map (either id id) $ scs (m, n)
 where
  rng = ((0, 0), (length m, length n))
  is = reverse $ (,) <$> tails m <*> tails n

  scs = memo is rng $ \case
    (x : xs, y : ys)
      | x == y -> Left x : scs (xs, ys)
      | otherwise ->
          maximumBy
            (comparing (length . filter isLeft))
            [Right y : scs (x : xs, ys), Right x : scs (xs, y : ys)]
    (xs, ys) -> Right <$> xs <> ys

tabulate :: (Ix i) => [a] -> (i, i) -> (a -> e) -> Array i e
tabulate is rng f = listArray rng (map f is)

memo is rng f i = tabulate is rng f ! join bimap length i

The biggest issue I can see is that your scs function takes actual lists as inputs and computes their lengths in order to look up memoized results in the table. Every lookup therefore walks both tails in full, which costs O(n) per call, and since there are O(n^2) pairs of tails, I think that makes the whole thing O(n^3) instead of O(n^2). Instead, put the two input strings into immutable arrays so you can look up individual characters by index in constant time, and write scs so that it takes a pair of indices rather than a pair of lists and reads characters from those arrays instead of pattern-matching.
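
To make the suggestion concrete, here is a minimal sketch of the index-based approach. It only computes the length of the shortest common supersequence, not the sequence itself, and the names scsLength, table and go are made up for illustration; it is not meant as the definitive way to structure the solution.

```haskell
import Data.Array (Array, Ix (..), listArray, (!))

-- | Length of the shortest common supersequence of two strings,
-- memoized over index pairs rather than over pairs of list tails.
-- (Illustrative sketch; 'scsLength', 'table' and 'go' are made-up names.)
scsLength :: String -> String -> Int
scsLength xs ys = table ! (0, 0)
  where
    lx = length xs
    ly = length ys
    -- Immutable arrays give O(1) character lookup by index.
    xa = listArray (0, lx - 1) xs :: Array Int Char
    ya = listArray (0, ly - 1) ys :: Array Int Char
    bnds = ((0, 0), (lx, ly))
    -- The array is lazy in its elements, so each cell is computed
    -- at most once, the first time it is demanded.
    table = listArray bnds (map go (range bnds)) :: Array (Int, Int) Int
    go (i, j)
      | i == lx = ly - j -- first string exhausted: the rest of ys remains
      | j == ly = lx - i -- second string exhausted: the rest of xs remains
      | xa ! i == ya ! j = 1 + table ! (i + 1, j + 1)
      | otherwise = 1 + min (table ! (i + 1, j)) (table ! (i, j + 1))
```

Each cell does a constant amount of work here, so the table as a whole costs O(n^2) when both strings have length about n.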

Yes, that was it! Thank you! The last thing I had to do was replace the output list with a Seq from Data.Sequence, to get the length in constant time. Also, for the keen-eyed: I had complicated things slightly by doing the comparison with maximumBy; minimumBy is correct and removes some complexity.
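
As a rough sketch of why the Seq change matters: Seq.length is O(1), while length on a plain list walks the whole list every time two candidates are compared. The names shorter and example below are only for illustration and are not part of the solution.

```haskell
import Data.Foldable (toList)
import Data.Sequence (Seq, (<|))
import qualified Data.Sequence as Seq

-- | Pick the shorter of two candidates. 'Seq.length' is O(1), whereas
-- 'length' on a list is linear in the length of the list.
-- ('shorter' is an illustrative name, not part of the solution.)
shorter :: Seq a -> Seq a -> Seq a
shorter a b
  | Seq.length a <= Seq.length b = a
  | otherwise = b

-- | Prepending with '<|' is also O(1), so extending a memoized
-- result by one character stays cheap.
example :: String
example = toList (shorter ('x' <| Seq.fromList "ab") ('y' <| Seq.fromList "abc"))
```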

For anyone finding this in the future, here is what I ended up with:

module Nafnagift where

import Control.Arrow ((>>>))
import Data.Array (Array, Ix (..), listArray, (!))
import Data.Foldable (Foldable (toList), minimumBy)
import Data.Ord (comparing)
import Data.Sequence ((<|))
import qualified Data.Sequence as Seq

main :: IO ()
main = interact $ lines >>> (\[m, n] -> solve m n)

solve :: (Eq a) => [a] -> [a] -> [a]
solve m n = toList $ scs (0, 0)
 where
  rng = ((0, 0), (length m, length n))
  m' = listArray (0, length m) m
  n' = listArray (0, length n) n

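  -- scs (i, j) is a shortest common supersequence of the suffixes
  -- drop i m and drop j n, memoized over every index pair in rng.
  scs = memo rng $ \(i, j) ->
    let x = m' ! i
        y = n' ! j
        i' = i + 1
        j' = j + 1
     in if inRange rng (i', j')
          then
            if x == y
              then x <| scs (i', j')
              else minimumBy (comparing length) [y <| scs (i, j'), x <| scs (i', j)]
          -- At least one suffix is empty here, so emit whatever is left.
          else Seq.fromList $ drop i m <> drop j n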

tabulate :: (Ix i) => (i, i) -> (i -> e) -> Array i e
tabulate rng f = listArray rng (map f $ range rng)

memo :: (Ix i) => (i, i) -> (i -> e) -> i -> e
memo rng = (!) . tabulate rng
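
If you want to sanity-check an answer locally, something like the following sketch may help. It only checks that a candidate really is a common supersequence, not that it is a shortest one, and isCommonSupersequence is a name I made up for this example.

```haskell
import Data.List (isSubsequenceOf)

-- | True when 's' contains both 'm' and 'n' as subsequences.
-- This checks validity only, not minimality.
-- ('isCommonSupersequence' is an illustrative name.)
isCommonSupersequence :: (Eq a) => [a] -> [a] -> [a] -> Bool
isCommonSupersequence m n s = m `isSubsequenceOf` s && n `isSubsequenceOf` s
```

For example, isCommonSupersequence m n (solve m n) should hold for any pair of inputs.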
