"TDD with Idris"-style implementation of the State monad


#1

In Edwin Brady’s “Type-Driven Development with Idris” (TDD with Idris), an implementation of the state monad is given by

data State : (stateType : Type) -> Type -> Type where
     Get : State stateType stateType
     Put : stateType -> State stateType ()

     Pure : ty -> State stateType ty
     Bind : State stateType a -> (a -> State stateType b) ->
             State stateType b

get : State stateType stateType
get = Get

put : stateType -> State stateType ()
put = Put

mutual
  Functor (State stateType) where
      map func x = do val <- x
                      pure (func val)

  Applicative (State stateType) where
      pure = Pure
      (<*>) f a = do f' <- f
                     a' <- a
                     pure (f' a')

  Monad (State stateType) where
      (>>=) = Bind

runState : State stateType a -> (st : stateType) -> (a, stateType)
runState Get st = (st, st)
runState (Put newState) st = ((), newState)

runState (Pure x) st = (x, st)
runState (Bind cmd prog) st = let (val, nextState) = runState cmd st in
                                  runState (prog val) nextState

I know the “standard” implementation of the state monad, newtype State s a = State { runState :: s -> (s, a) } (in fact, I have written a Monad instance for it as an exercise).
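For contrast with the GADT version discussed below, here is a minimal sketch of that newtype-based State monad. It assumes the result comes first, (a, s), matching the runState functions elsewhere in this thread; the (s, a) ordering works the same way with the components swapped. The tick program is a hypothetical example. The state-passing function is baked into the representation here, whereas the "free" version stores Get, Put and Bind as data for runState to interpret.

```haskell
-- Usual newtype-based State monad (sketch; result first, then state).
newtype State s a = State { runState :: s -> (a, s) }

instance Functor (State s) where
  fmap f (State g) = State $ \s ->
    let (a, s') = g s in (f a, s')

instance Applicative (State s) where
  pure a = State $ \s -> (a, s)
  State mf <*> State ma = State $ \s ->
    let (f, s1) = mf s
        (a, s2) = ma s1
    in (f a, s2)

instance Monad (State s) where
  -- Thread the state from the first computation into the continuation.
  State m >>= k = State $ \s ->
    let (a, s1) = m s
    in runState (k a) s1

get :: State s s
get = State $ \s -> (s, s)

put :: s -> State s ()
put s = State $ \_ -> ((), s)

-- Hypothetical example: return the counter, then increment it.
tick :: State Int Int
tick = do
  n <- get
  put (n + 1)
  return n

main :: IO ()
main = print (runState (tick >> tick >> tick) 0)  -- prints (2,3)
```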

But I would like to have a Haskell “free” implementation which is as close as possible to the Idris one. The closest I got was this:

data State s a where
  Return :: a -> State s a
  Bind :: State s a -> (a -> State s b) -> State s b 
  Get :: State s s
  Put :: s -> State s ()

I am more familiar with Agda’s inductive data types than with Haskell’s GADTs, but it seems to me that Idris’ data State : (stateType : Type) -> Type -> Type and Haskell’s data State s a are different: in the Idris version, stateType is “fixed”, while the “type in the functorial position” can “vary” between constructors (I am not sure what the exact terminology is here).

I tried to write Functor, Applicative, and Monad instances for this data type, but I couldn’t do it without pattern matching on the four cases (Put, Get, Return, Bind), and even then I was at a loss. It is very elegant how the Idris implementation makes State an instance of those three type classes by mutual recursion; does Haskell allow mutually recursive type class instances? If not, how should I define them for the Idris-style State data type?


#2

Yes, GADTs do not let you express the difference between an index and a parameter the way inductive families in Agda and Idris do. This doesn’t matter much for your question, though.
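To make the first point concrete, here is a sketch with a hypothetical extra constructor, ResetCounter, that is not part of the original code. GHC accepts it even though it fixes the first type argument to Int, so s behaves like an index rather than a parameter; a genuine parameter (in Agda, one declared before the colon) could not be instantiated like that.

```haskell
{-# LANGUAGE GADTs #-}

data State s a where
  Return :: a -> State s a
  Bind   :: State s a -> (a -> State s b) -> State s b
  Get    :: State s s
  Put    :: s -> State s ()
  -- Hypothetical constructor: accepted although it pins s to Int.
  ResetCounter :: State Int ()

runState :: State s a -> s -> (a, s)
runState Get st = (st, st)
runState (Put newState) st = ((), newState)
runState (Return x) st = (x, st)
runState (Bind cmd prog) st =
  let (val, nextState) = runState cmd st in runState (prog val) nextState
-- Matching ResetCounter teaches GHC that s ~ Int, so 0 :: s typechecks.
runState ResetCounter _ = ((), 0)

main :: IO ()
main = print (runState (Bind (Put 5) (\_ -> ResetCounter)) (1 :: Int))
-- prints ((),0)
```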

It would have been cool if the Idris definition depended heavily on that, but it doesn’t. What you see in the Idris version is simply syntactic sugar for (>>=). Idris doesn’t reserve do-notation for monads; instead, it desugars it into applications of (>>=). That (>>=) doesn’t have to come from a Monad instance: Idris lets you overload function names as long as each overload has a different type, so you could define a separate (>>=) function and use do-notation with it if you wanted to.
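As far as I know, GHC's closest analogue is the RebindableSyntax extension: with it, do-notation desugars to whatever (>>=) and (>>) are in scope, so the GADT can be used with do-notation without any Functor, Applicative or Monad instances at all. A minimal sketch; the top-level (>>=), (>>), return and the increment program are illustrative definitions, not part of the original code.

```haskell
{-# LANGUAGE GADTs, RebindableSyntax #-}
module Main where

-- RebindableSyntax implies NoImplicitPrelude, so import Prelude
-- explicitly, hiding the names we rebind below.
import Prelude hiding ((>>=), (>>), return)

data State s a where
  Return :: a -> State s a
  Bind   :: State s a -> (a -> State s b) -> State s b
  Get    :: State s s
  Put    :: s -> State s ()

-- do-notation in this module now desugars to these plain functions.
(>>=) :: State s a -> (a -> State s b) -> State s b
(>>=) = Bind

(>>) :: State s a -> State s b -> State s b
m >> k = Bind m (\_ -> k)

return :: a -> State s a
return = Return

runState :: State s a -> s -> (a, s)
runState Get st = (st, st)
runState (Put newState) st = ((), newState)
runState (Return x) st = (x, st)
runState (Bind cmd prog) st =
  let (val, nextState) = runState cmd st in runState (prog val) nextState

-- Hypothetical program: return the old counter and increment it.
increment :: State Int Int
increment = do
  n <- Get
  Put (n + 1)
  return n

main :: IO ()
main = print (runState increment 41)  -- prints (41,42)
```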

The only reason the mutual block above is needed is that the Functor and Applicative implementations use do-notation, which relies on the definition (>>=) = Bind from the Monad implementation. That is not a particularly interesting definition: you could just desugar the do-notation by hand and write everything using Binds and lambdas. Here’s the Haskell version written that way:

{-# LANGUAGE GADTs #-}

data State s a where
  Return :: a -> State s a
  Bind :: State s a -> (a -> State s b) -> State s b
  Get :: State s s
  Put :: s -> State s ()

instance Functor (State s) where
  fmap f x = Bind x (\val -> Return (f val))

instance Applicative (State s) where
  pure = Return
  (<*>) f a = Bind f (\f' -> Bind a (\a' -> pure (f' a')))

instance Monad (State s) where
  return = pure
  (>>=) = Bind

runState :: State s a -> s -> (a, s)
runState Get st = (st, st)
runState (Put newState) st = ((), newState)
runState (Return x) st = (x, st)
runState (Bind cmd prog) st =
  let (val, nextState) = runState cmd st in runState (prog val) nextState
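One consequence of this representation, sketched below under my own assumptions: a program such as the hypothetical tick is just a tree of constructors, so the same program can be handed to more than one interpreter. runStateLogged is an illustrative second interpreter (not from the original code) that also records every value passed to Put.

```haskell
{-# LANGUAGE GADTs #-}

data State s a where
  Return :: a -> State s a
  Bind   :: State s a -> (a -> State s b) -> State s b
  Get    :: State s s
  Put    :: s -> State s ()

instance Functor (State s) where
  fmap f x = Bind x (\val -> Return (f val))

instance Applicative (State s) where
  pure = Return
  f <*> a = Bind f (\f' -> Bind a (\a' -> Return (f' a')))

instance Monad (State s) where
  (>>=) = Bind

-- Building this value runs nothing; it only allocates constructors.
tick :: State Int Int
tick = do
  n <- Get
  Put (n + 1)
  return n

-- The usual interpreter.
runState :: State s a -> s -> (a, s)
runState Get st = (st, st)
runState (Put newState) st = ((), newState)
runState (Return x) st = (x, st)
runState (Bind cmd prog) st =
  let (val, nextState) = runState cmd st in runState (prog val) nextState

-- Hypothetical second interpreter: also logs every Put, in order.
runStateLogged :: State s a -> s -> (a, s, [s])
runStateLogged Get st = (st, st, [])
runStateLogged (Put newState) _ = ((), newState, [newState])
runStateLogged (Return x) st = (x, st, [])
runStateLogged (Bind cmd prog) st =
  let (val, st1, log1) = runStateLogged cmd st
      (res, st2, log2) = runStateLogged (prog val) st1
  in (res, st2, log1 ++ log2)

main :: IO ()
main = do
  print (runState (tick >> tick >> tick) 0)        -- prints (2,3)
  print (runStateLogged (tick >> tick >> tick) 0)  -- prints (2,3,[1,2,3])
```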

I already gave one way to define the instances above. But, funnily enough, Haskell does let you define type class instances in a mutually recursive way. So you can write

{-# LANGUAGE GADTs #-}

data State s a where
  Return :: a -> State s a
  Bind :: State s a -> (a -> State s b) -> State s b
  Get :: State s s
  Put :: s -> State s ()

instance Functor (State s) where
  fmap f x = do val <- x
                return (f val)

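instance Applicative (State s) where
  -- Use the constructor here: Monad's return is defined as pure below,
  -- so pure = return would make both loop forever at runtime.
  pure = Return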
  (<*>) f a = do f' <- f
                 a' <- a
                 return (f' a')

instance Monad (State s) where
  return = pure
  (>>=) = Bind


runState :: State s a -> s -> (a, s)
runState Get st = (st, st)
runState (Put newState) st = ((), newState)
runState (Return x) st = (x, st)
runState (Bind cmd prog) st =
  let (val, nextState) = runState cmd st in runState (prog val) nextState

Both code blocks I gave here compile for me on GHC 8.4.3 without any problems. What kind of errors did you get in your instances that forced you to pattern match on the cases?
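A small sketch exercising mutually recursive instances like the ones above: fmap and (<*>) are written with do-notation and therefore go through (>>=) = Bind, while pure is defined directly as the Return constructor so that the recursion between the instances bottoms out. The two example programs in main are illustrative only.

```haskell
{-# LANGUAGE GADTs #-}

data State s a where
  Return :: a -> State s a
  Bind   :: State s a -> (a -> State s b) -> State s b
  Get    :: State s s
  Put    :: s -> State s ()

instance Functor (State s) where
  -- do-notation here uses (>>=) from the Monad instance below.
  fmap f x = do
    val <- x
    return (f val)

instance Applicative (State s) where
  -- Must bottom out in a constructor: if pure were defined as return
  -- (and return as pure), evaluating either would loop.
  pure = Return
  f <*> a = do
    f' <- f
    a' <- a
    return (f' a')

instance Monad (State s) where
  (>>=) = Bind

runState :: State s a -> s -> (a, s)
runState Get st = (st, st)
runState (Put newState) st = ((), newState)
runState (Return x) st = (x, st)
runState (Bind cmd prog) st =
  let (val, nextState) = runState cmd st in runState (prog val) nextState

main :: IO ()
main = do
  print (runState (fmap (* 2) Get) (21 :: Int))                  -- prints (42,21)
  print (runState ((+) <$> Get <*> (Put 10 >> Get)) (1 :: Int))  -- prints (11,10)
```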


#3

Maybe it’s a bit late to reply now, but in case you’ve never come across them before: your State type is very similar to operational monads. (I think these are also known as freer monads; refer to the paper for more information.)
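To illustrate the resemblance, here is a rough sketch in the spirit of those approaches, not the API of any particular library: a generic Program type, parameterised by an instruction set, supplies the monad structure once, and the state-specific part shrinks to a small GADT of instructions. The names Program, Done, Then, liftInstr and StateInstr are all hypothetical.

```haskell
{-# LANGUAGE GADTs #-}

-- A generic program over an instruction set: either finished, or one
-- instruction followed by a continuation that receives its result.
data Program instr a where
  Done :: a -> Program instr a
  Then :: instr x -> (x -> Program instr a) -> Program instr a

instance Functor (Program instr) where
  fmap f (Done a)   = Done (f a)
  fmap f (Then i k) = Then i (fmap f . k)

instance Applicative (Program instr) where
  pure = Done
  pf <*> pa = pf >>= \f -> fmap f pa

instance Monad (Program instr) where
  Done a    >>= k = k a
  Then i k' >>= k = Then i (\x -> k' x >>= k)

-- Lift a single instruction into a program.
liftInstr :: instr a -> Program instr a
liftInstr i = Then i Done

-- Only the instructions are specific to state.
data StateInstr s a where
  Get :: StateInstr s s
  Put :: s -> StateInstr s ()

runState :: Program (StateInstr s) a -> s -> (a, s)
runState (Done a) st = (a, st)
runState (Then Get k) st = runState (k st) st
runState (Then (Put new) k) _ = runState (k ()) new

tick :: Program (StateInstr Int) Int
tick = do
  n <- liftInstr Get
  liftInstr (Put (n + 1))
  return n

main :: IO ()
main = print (runState (tick >> tick >> tick) 0)  -- prints (2,3)
```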