Chaining different states on type-level

I’d like to chain state monads with different state types together, like:

stage1 :: a -> State t1 b
stage2 :: b -> State t2 c

-- | Run @stage1@ and then @stage2@, each against its own component of the
-- state pair, and return both final states together with the final result.
chain :: (t1, t2) -> a -> ((t1, t2), c)
chain (s1, s2) = go
  where
    -- 'runState' yields @(result, finalState)@ for each stage.
    go x = ((s1', s2'), z)
      where
        (y, s1') = runState (stage1 x) s1
        (z, s2') = runState (stage2 y) s2

And I’m thinking of making a more generic version, turning the above example into (pseudocode):

chain (s1 :+: s2 :+: HNil) (stage1 :+: stage2 :+: HNil) = \x ->
  let (y, s1') = runState (stage1 x) s1
      (z, s2') = runState (stage2 y) s2
  in (s1' :+: s2' :+: HNil, z)

In the end I came up with the following code, but it fails to compile (GHC 9.4.8):

{-# LANGUAGE DataKinds #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE AllowAmbiguousTypes #-}

module Main where

import Data.Kind (Type)
import Control.Monad.State

-- | Input type of a single Kleisli arrow @a -> State t b@.
-- There is no equation for any other shape of type.
type family KFrom (f :: Type) :: Type where
  KFrom (a -> State t b) = a
-- | Result type of a single Kleisli arrow @a -> State t b@.
-- There is no equation for any other shape of type.
type family KTo (f :: Type) :: Type where
  KTo (a -> State t b) = b
-- | Input type of a whole chain: the input of its first arrow.
-- There is no equation for the empty list.
type family From (ks :: [Type]) :: Type where
  From (k ': ks) = KFrom k
-- | Result type of a whole chain: the result of its last arrow.
-- The singleton equation comes first, so the second equation is only meant
-- for lists with at least two elements. There is no equation for the empty
-- list.
type family To (ks :: [Type]) :: Type where
  To '[k] = KTo k
  To (k ': ks) = To ks

-- Errors occur in the definition of Chain
-- | @Chain ts ks@ runs the Kleisli arrows in @ks@ one after another, giving
-- each arrow the state at the same position in @ts@.
class Chain (ts :: [Type]) (ks :: [Type]) where
  -- | Takes the initial states, the arrows and the input of the first arrow;
  -- returns the final states together with the result of the last arrow.
  chain :: HList ts -> HList ks -> From ks -> (HList ts, To ks)
-- Base case: a single arrow with its single state.
instance Chain '[t] '[a -> State t b] where
  chain (s :+: HNil) (f :+: HNil) = \x ->
    let (y, s') = runState (f x) s
    in (s' :+: HNil, y)
-- Recursive case: run the head arrow, then feed its result to the rest.
-- This head also matches the base case above (take @ts@ and @ks@ to be
-- empty), so the two instance heads overlap.
instance Chain ts ks => Chain (t ': ts) ((a -> State t b) ': ks) where
  chain (s :+: ss) (f :+: fs) = \x ->
    let (y, s') = runState (f x) s
        -- The recursive call expects an argument of type @From ks@, but
        -- @y :: b@ and nothing in this instance relates @b@ to @From ks@.
        (ss', z) = chain ss fs y -- Compile Error One
    -- @To ((a -> State t b) ': ks)@ cannot be reduced here: while @ks@ is
    -- unknown it may still be @'[]@, in which case the first equation of 'To'
    -- would apply, so the second equation cannot be selected.
    in (s' :+: ss', z) -- Compile Error Two
-- NOTE(review): writing the recursive head as @(a -> State t b) ': k ': ks@
-- with context @(Chain ts (k ': ks), From (k ': ks) ~ b)@ should address both
-- errors and the overlap -- not compiled here, confirm before relying on it.

-- Heterogeneous List Definitions
infixr 4 :+:
-- | A list whose element types are tracked, in order, by the type-level
-- list @ts@.
data HList :: [Type] -> Type where
  HNil :: HList '[]
  (:+:) :: t -> HList ts -> HList (t ': ts)

-- The instances define 'showsPrec' rather than 'show' so that the output
-- respects the @infixr 4@ fixity of ':+:'. A list nested inside another list
-- or under a constructor is parenthesised, while a flat list renders as
-- @x :+: y :+: HNil@.
instance Show (HList '[]) where
  showsPrec _ HNil = showString "HNil"
instance (Show x, Show (HList xs)) => Show (HList (x ': xs)) where
  showsPrec d (x :+: xs) =
    showParen (d > 4) $
      -- Left operand one level tighter than the operator, right operand at
      -- the operator's own level, as befits a right-associative operator.
      showsPrec 5 x . showString " :+: " . showsPrec 4 xs

Compile Error One

The first error: y :: b does not match x :: From ((a -> State t b) : ks). This is due to the lack of evidence that these Kleisli arrows line up, i.e. that the result type of one arrow is the input type of the next. But when I changed the instance declaration from

instance Chain ts ks => Chain (t ': ts) ((a -> State t b) ': ks) where
-- into
instance Chain ts ks => Chain (t ': ts) ((a -> State t (From ks)) ': ks) where

GHC complains: Illegal type synonym family application ‘From ks’ in instance.

Compile Error Two

The second error, z :: To ks mismatches with the result type To ((a -> State t b) : ks), which is strange because according to the definition of To,

type family To (ks :: [Type]) :: Type where
  To '[k] = KTo k
  To (k ': ks) = To ks

it should be clear that the two types are equal.


I’d appreciate it if you could give me some suggestions on the problem.

1 Like

What do you need it for? If you can use a custom Monad type class, it is really easy to implement a type-changing state monad as a parameterised (indexed) monad:

-- | State transformer whose state type may change: a computation starts with
-- a state of type @si@ and finishes with one of type @sn@.
--
-- NOTE(review): the class below abstracts over types of the shape @m p q a@
-- (two indices, then the result), whereas here the indices come before @m@.
-- The parameters would need reordering before an instance can be written.
data StateT si sn m a = MkStateT (si -> m (sn,a))

-- | Monad-like class for types carrying a "before" and an "after" index.
-- It reuses the names of the Prelude class, so the Prelude versions have to
-- be hidden wherever this one is used.
class Monad m where
  -- | Inject a value without changing the index.
  return :: a -> m p p a
  -- | Sequence two computations: the index the first one ends in (@p@) must
  -- be the index the second one starts in. The computation comes first and
  -- the continuation second, matching the usual argument order of @>>=@.
  (>>=) :: m s p a -> (a -> m p q b) -> m s q b
2 Likes

One other alternative is to use an effect system or mtl:

stage1 :: State t1 < es => a -> Eff es b
stage2 :: State t2 < es => b -> Eff es c

-- Composing the two stages takes stage1's input @a@ and yields stage2's
-- result @c@; both state effects must be available in @es@.
chain :: (State t1 < es, State t2 < es) => a -> Eff es c
chain = stage1 >=> stage2

This is just pseudocode but it should be implementable in any effect library you can find. Although having two state effects with different states is not always ergonomic to use.

What a weird way to say effectful :stuck_out_tongue:

How does effectful cope with t1 ~ t2, i.e. if you want two distinct states of the same type?

It doesn’t, you use a newtype :man_shrugging:

1 Like

I would recommend a different approach; storing your chainable functions in an HList is going to be more complex than directly composing them and storing their composition.

{-# LANGUAGE BlockArguments #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}

module Main where

import Data.Functor
import Data.Kind (Type)
import Control.Monad.State

-- Heterogeneous List Definitions
infixr 4 :+:
-- | A list whose element types are tracked, in order, by the type-level
-- list @ts@.
data HList :: [Type] -> Type where
  HNil :: HList '[]
  (:+:) :: t -> HList ts -> HList (t ': ts)

-- The instances define 'showsPrec' rather than 'show' so that the output
-- respects the @infixr 4@ fixity of ':+:'. A list nested inside another list
-- or under a constructor is parenthesised, while a flat list renders as
-- @x :+: y :+: HNil@.
instance Show (HList '[]) where
  showsPrec _ HNil = showString "HNil"
instance (Show x, Show (HList xs)) => Show (HList (x ': xs)) where
  showsPrec d (x :+: xs) =
    showParen (d > 4) $
      -- Left operand one level tighter than the operator, right operand at
      -- the operator's own level, as befits a right-associative operator.
      showsPrec 5 x . showString " :+: " . showsPrec 4 xs

-- Composition of heterogeneous state kleislis
--
-- @f >>>: g@ runs @f@ against the head of the state list, hands its result
-- to @g@, which runs against the tail, and reassembles the updated list.
infixr 9 >>>:
(>>>:) :: (a -> State s b) -> (b -> State (HList ss) c) -> a -> State (HList (s ': ss)) c
(first >>>: rest) input = state $ \(headSt :+: tailSt) ->
  let (mid, headSt') = runState (first input) headSt
      (out, tailSt') = runState (rest mid) tailSt
  in (out, headSt' :+: tailSt')

-- Example usage
-- Chains three stages whose states are a number, a String and a Bool.
main :: IO ()
main = do
  -- Adds its input to the numeric state and passes the input along.
  let f a = modify (+ a) $> a
  -- Repeats the String state b times, then yields True.
  let g b = modify (concat . replicate b) $> True
  -- ORs its input into the Bool state, then yields "done".
  let h c = modify (|| c) $> "done"
  -- The trailing 'pure' closes the chain at the 'HNil' end of the state list.
  -- Prints: ("done",3 :+: "hellohello" :+: True :+: HNil)
  print $ flip runState (1 :+: "hello" :+: False :+: HNil) $
    (f >>>: g >>>: h >>>: pure) 2
1 Like

This is one of the reasons that I don’t want to use effectful.

This approach is much better than mine! Thank you for your advice.

1 Like

This is one of the reasons I developed Bluefin (package: bluefin). In Bluefin effects are disambiguated by value level handles. It does, however, require you to pass those handles around manually, which may or may not be what you want.

https://hackage.haskell.org/package/bluefin-0.0.2.0/docs/Bluefin-State.html

1 Like

Another generalization avoiding heterogeneous lists is to use zoom from the lens library. This works especially well if you already collect the various fragments of state in a big record, so you don’t have to unpack it into a heterogeneous list.

import Control.Monad
import Control.Lens hiding (zoom)
import Control.Monad.Trans.State.Strict

-- | Run an action that works on a part of the state against the whole state:
-- the part is read through the lens, and the part the action leaves behind
-- is written back through the same lens.
zoom :: Functor m => Lens' s a -> StateT a m x -> StateT s m x
zoom l inner = StateT $ \outer ->
  fmap (\(x, part) -> (x, set l part outer)) (runStateT inner (view l outer))

-- | Run two Kleisli arrows in sequence, the first against the left component
-- of a state pair and the second against the right component.
--
-- The state types @s1@ and @s2@ are independent of the value types @a@, @b@
-- and @c@, so stages whose state differs from their input can be chained too.
chain :: Monad m => (a -> StateT s1 m b) -> (b -> StateT s2 m c) -> (a -> StateT (s1, s2) m c)
chain m1 m2 = (zoom _1 . m1) >=> (zoom _2 . m2)

-- | Chain the same counting stage twice, starting from input 10 and a pair of
-- zeroed counters.
main :: IO ()
main = print (runState (chain step step 10) (0 :: Int, 0 :: Int))
  where
    -- Bump the stage's own counter and pass on the incremented input.
    step n = modify (+ 1) >> pure (n + 1)
-- Output: (12,(1,1))
2 Likes