I want to write a function that compares two values of a sum type in the following manner:
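- when the constructors are different, the whole node is reported as changed (the first tree's node ID together with the entire second expression)
- when the constructors are the same (and any non-recursive fields, such as the `Int` in `Lit` or the `Op` in `SpecialBinOp`, are equal), the result is obtained by recursing on the corresponding sub-expressions and combining the results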
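The "same constructor" test in the list above can be phrased generically: blank out the recursive positions of one layer of the base functor with `void` (from `Data.Functor`) and compare what remains with `==`. That compares the constructor and its non-recursive fields, but not the children. A minimal sketch, assuming a hand-written copy of the base functor with a derived `Eq` instance (if `StxF` comes from `makeBaseFunctor` and has no `Eq` instance, a standalone `deriving instance Eq r => Eq (StxF r)` would be needed); `sameShape` is a hypothetical helper name:

```haskell
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE FlexibleContexts #-}

import Data.Functor (void)

type Op = String

-- Hand-written copy of the base functor for illustration; r marks the
-- recursive positions.
data StxF r
  = LitF Int
  | AddF r r
  | MulF r r
  | SpecialBinOpF Op r r
  deriving (Eq, Show, Functor)

-- True when both layers use the same constructor and agree on every
-- non-recursive field. The children are replaced by () before comparing,
-- so they never influence the result.
sameShape :: (Functor f, Eq (f ())) => f a -> f b -> Bool
sameShape x y = void x == void y
```

For example, `sameShape (AddF 1 2) (AddF 'a' 'b')` is `True`, while `SpecialBinOpF "+"` and `SpecialBinOpF "-"` layers do not match, because the `Op` field is not a recursive position.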
What is a clean way to do this?
Here is an example with a small number of constructors:
```haskell
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DeriveTraversable #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeFamilies #-}

import Control.Comonad.Cofree
import Data.Functor.Foldable.TH
import GHC.Generics (Generic)

type Op = String

data Stx
  = Lit Int
  | Add Stx Stx
  | Mul Stx Stx
  | SpecialBinOp Op Stx Stx
  deriving (Eq, Show, Generic)

-- Generates the base functor StxF with constructors LitF, AddF, MulF, SpecialBinOpF
makeBaseFunctor ''Stx

-- A syntax tree in which every node carries an annotation (here, a node ID)
type StxAnn = Cofree StxF

data Diff = Diff
  { nodeId :: Int,
    newExpr :: StxAnn Int
  }

-- When the constructors are different, report the whole node: the old node's ID
-- paired with the entire new expression, i.e.
--   treeDiff (C x y) (D u v) -> [Diff (ID of C x y) (D u v)]
-- When the constructors are the same, recurse on the corresponding sub-expressions:
--   treeDiff (C x y) (C u v) -> treeDiff x u <> treeDiff y v
treeDiff :: StxAnn Int -> StxAnn Int -> [Diff]
treeDiff oldTree newTree = case (oldTree, newTree) of
  (nodeId :< LitF a, _ :< LitF b)
    | a /= b -> [Diff nodeId newTree]
    | otherwise -> []
  (_ :< AddF x1 y1, _ :< AddF x2 y2) -> treeDiff x1 x2 <> treeDiff y1 y2
  (_ :< MulF x1 y1, _ :< MulF x2 y2) -> treeDiff x1 x2 <> treeDiff y1 y2
  (nodeId :< SpecialBinOpF op1 x1 y1, _ :< SpecialBinOpF op2 x2 y2)
    | op1 == op2 -> treeDiff x1 x2 <> treeDiff y1 y2
    | otherwise -> [Diff nodeId newTree]
  -- Any other pair has different constructors
  (nodeId :< _, _ :< _) -> [Diff nodeId newTree]
```
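One possible way to drop the per-constructor cases, sketched under assumptions: compare the two layers after erasing their children with `void`, and if they match, zip the children obtained from `Foldable`'s `toList` and recurse. To keep the block dependency-free, `StxF` is written by hand in the shape `makeBaseFunctor ''Stx` generates (with an added `Eq` instance), and `Cofree` is a minimal local stand-in for `Control.Comonad.Cofree`. With the real packages you would keep the definitions from the question and add `deriving instance Eq r => Eq (StxF r)` (with `StandaloneDeriving`) after the splice. `genericDiff` is a hypothetical name.

```haskell
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE DeriveFoldable #-}
{-# LANGUAGE FlexibleContexts #-}

import Data.Foldable (toList)
import Data.Functor (void)

type Op = String

-- Same shape as the base functor makeBaseFunctor generates for Stx,
-- plus a derived Eq instance (an assumption needed for the shape test).
data StxF r
  = LitF Int
  | AddF r r
  | MulF r r
  | SpecialBinOpF Op r r
  deriving (Eq, Show, Functor, Foldable)

-- Minimal local stand-in for Control.Comonad.Cofree.
data Cofree f a = a :< (f (Cofree f a))

infixr 5 :<

type StxAnn = Cofree StxF

data Diff = Diff
  { nodeId :: Int,
    newExpr :: StxAnn Int
  }

-- Works for any base functor f. Two layers match when they are equal after
-- their children are replaced by (): same constructor and same non-recursive
-- fields (the Int of LitF, the Op of SpecialBinOpF). Matching layers are
-- compared child by child; derived Foldable lists children left to right,
-- so they line up pairwise. Anything else is reported as one change:
-- the old node's ID paired with the whole new subtree.
genericDiff ::
  (Functor f, Foldable f, Eq (f ())) =>
  Cofree f Int ->
  Cofree f Int ->
  [(Int, Cofree f Int)]
genericDiff (i :< oldLayer) new@(_ :< newLayer)
  | void oldLayer == void newLayer =
      concat (zipWith genericDiff (toList oldLayer) (toList newLayer))
  | otherwise = [(i, new)]

-- Same results as the hand-written treeDiff, without per-constructor cases.
treeDiff :: StxAnn Int -> StxAnn Int -> [Diff]
treeDiff old new = map (uncurry Diff) (genericDiff old new)
```

The design choice here is that all constructor-specific knowledge lives in the derived `Eq` and `Foldable` instances, so adding a constructor to `Stx` needs no change to the diff itself. For example, diffing `Add (Lit 1) (Mul (Lit 2) (Lit 3))` against `Add (Lit 1) (SpecialBinOp "-" (Lit 2) (Lit 3))` reports only the node that held `Mul`.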