Cleaner way to recurse on equivalent pairs of constructors?

I want to write a function that compares two values of a sum type, in the following manner.

  • when the constructors are different, the result is the entire first expression
  • when the constructors are the same, the result is obtained by recursing on the corresponding sub-expressions and combining the results

What is a clean way to do this?

Here is an example with a small number of constructors:

{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DeriveTraversable #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeFamilies #-}

import Control.Comonad.Cofree
import Data.Functor.Foldable.TH
import GHC.Generics (Generic)

type Op = String

data Stx
  = Lit Int
  | Add Stx Stx
  | Mul Stx Stx
  | SpecialBinOp Op Stx Stx
  deriving (Eq, Show, Generic)

makeBaseFunctor ''Stx

type StxAnn = Cofree StxF

data Diff = Diff
  { nodeId :: Int,
    newExpr :: StxAnn Int
  }

-- when the constructors are different, the result is the entire first expression
-- when the constructors are the same, the result is obtained by considering the corresponding sub-expressions
-- i.e. treeDiff (C x y) (D _ _) -> [C x y]
-- and  treeDiff (C x y) (C u v) -> treeDiff x u <> treeDiff y v
treeDiff :: StxAnn Int -> StxAnn Int -> [Diff]
treeDiff oldTree newTree = case (oldTree, newTree) of
  (nodeId :< LitF a, _ :< LitF b)
    | a /= b -> [Diff nodeId newTree]
    | otherwise -> []
  (_ :< AddF x1 y1, _ :< AddF x2 y2) -> treeDiff x1 x2 <> treeDiff y1 y2
  (_ :< MulF x1 y1, _ :< MulF x2 y2) -> treeDiff x1 x2 <> treeDiff y1 y2
  (nodeId :< SpecialBinOpF op1 x1 y1, _ :< SpecialBinOpF op2 x2 y2)
    | op1 == op2 -> treeDiff x1 x2 <> treeDiff y1 y2
    | otherwise -> [Diff nodeId newTree]
  (nodeId :< _, _ :< _) -> [Diff nodeId newTree]
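A small refactoring of the treeDiff above keeps all constructor-specific code in one helper, here hypothetically named matchStx, which compares a single level of two nodes and either fails or pairs up their children; the recursion is then written only once. So that the sketch compiles with base alone, StxF is written out by hand (roughly as makeBaseFunctor would generate it) and Cofree is replaced by a minimal stand-in type Ann. Both are assumptions of the sketch, not part of the original code.

```haskell
-- Hand-written stand-ins (assumptions) so the sketch needs only base.
type Op = String

data StxF r = LitF Int | AddF r r | MulF r r | SpecialBinOpF Op r r

-- Minimal stand-in for Control.Comonad.Cofree.
data Ann f a = a :< f (Ann f a)

infixr 5 :<

data Diff = Diff {nodeId :: Int, newExpr :: Ann StxF Int}

-- One level of matching (hypothetical helper): Nothing if the nodes differ
-- in constructor or in a non-recursive field, otherwise the children paired
-- up in order. This is the only function that lists the constructors.
matchStx :: StxF a -> StxF b -> Maybe [(a, b)]
matchStx (LitF m) (LitF n)
  | m == n = Just []
matchStx (AddF x1 y1) (AddF x2 y2) = Just [(x1, x2), (y1, y2)]
matchStx (MulF x1 y1) (MulF x2 y2) = Just [(x1, x2), (y1, y2)]
matchStx (SpecialBinOpF op1 x1 y1) (SpecialBinOpF op2 x2 y2)
  | op1 == op2 = Just [(x1, x2), (y1, y2)]
-- failed guards above fall through to here
matchStx _ _ = Nothing

-- Mismatch: report the old node's id with the whole new subtree.
-- Match: recurse on each pair of children and concatenate the results.
treeDiff :: Ann StxF Int -> Ann StxF Int -> [Diff]
treeDiff (ann :< old) new@(_ :< new') = case matchStx old new' of
  Nothing -> [Diff ann new]
  Just kids -> concatMap (uncurry treeDiff) kids
```

The recursion no longer changes when a constructor is added; only matchStx does, and its shape is essentially that of zipMatch from unification-fd, minus the Left/Right distinction.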

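This can be done with zipMatch from Control.Unification (in the unification-fd package), which matches one level of two terms: it returns Nothing when the constructors differ, and otherwise the spine with each child either already resolved (Left) or paired with its counterpart for further checking (Right).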

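-- Additionally requires DeriveAnyClass, DerivingStrategies and StandaloneDeriving,
-- plus the imports Control.Unification (Unifiable (..)) and GHC.Generics (Generic1).
deriving stock instance Generic1 StxF

deriving anyclass instance Unifiable StxF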

diff :: StxAnn Int -> StxAnn Int -> [Diff]
diff (annOld :< old) (annNew :< new) = case zipMatch old new of
  Nothing -> [Diff annOld (annNew :< new) ]
  Just tree -> foldMap f tree
    where
      f :: Either (Cofree StxF Int) (Cofree StxF Int, Cofree StxF Int) -> [Diff]
      f (Left _) =
        -- both sides have been completely unified
        []
      f (Right (x, y)) = diff x y
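If the unification-fd dependency is unwanted, a similar generic effect can be had from Functor and Foldable alone, assuming StxF also has an Eq instance (for the Template Haskell version, something like deriving instance Eq r => Eq (StxF r) with StandaloneDeriving). Erasing the children with void compares the constructor and its non-recursive fields in one step, and toList returns the children in the same order for both nodes. The hand-written StxF and the minimal Ann stand-in for Cofree are assumptions made so the sketch needs only base; it can be loaded into GHCi as is.

```haskell
{-# LANGUAGE DeriveTraversable #-}
{-# LANGUAGE FlexibleContexts #-}

import Data.Foldable (toList)
import Data.Functor (void)

type Op = String

-- Hand-written version of the base functor, with an extra (assumed) Eq instance.
data StxF r
  = LitF Int
  | AddF r r
  | MulF r r
  | SpecialBinOpF Op r r
  deriving (Eq, Functor, Foldable, Traversable)

-- Minimal stand-in for Control.Comonad.Cofree.
data Ann f a = a :< f (Ann f a)

infixr 5 :<

data Diff = Diff {nodeId :: Int, newExpr :: Ann StxF Int}

-- Two nodes match when they are equal after erasing their children with void:
-- same constructor and same non-recursive fields (the Int of LitF, the Op).
sameShape :: (Functor f, Eq (f ())) => f a -> f b -> Bool
sameShape x y = void x == void y

-- No constructor is named here: matching nodes are compared child by child
-- (toList visits the children in the same order for both), and any other
-- pair is reported as the old node's id with the whole new subtree.
treeDiff :: Ann StxF Int -> Ann StxF Int -> [Diff]
treeDiff (ann :< old) new@(_ :< new')
  | sameShape old new' = concat (zipWith treeDiff (toList old) (toList new'))
  | otherwise = [Diff ann new]
```

Since treeDiff never mentions a constructor, adding one to Stx needs no change here; the trade-off is relying on a derived Eq for the one-level comparison instead of an explicit match.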