Stateモナドはいくつあるか?


この記事はHaskell Advent Calendar 2024の10日目です


去年の12月という古い話ではありますが、 @Kory__3 さんが発した疑問(ツイート)に答える形で、次の事実を証明することができました。 (Agdaで書いた証明

Monad (State s)のインスタンスは唯一
Haskellにおいて、instance Monad (State s) where ...と書けるMonadのインスタンスであってMonad則を満たすものは、 標準的なStateモナドと同値なものに限られる。

ただし、State sは以下の型とします。

newtype State s a = State { runState :: s -> (s, a) }

この記事では、その証明の概略を述べます。

何を証明したのか

証明の前に、まずは示そうとしていることが正確に何であるのかを詳しく述べます。 「instance Monad (State s) where ...と書けるインスタンス」というのは、 文字通り次の形をしたMonadのインスタンスです。

instance Monad (State s) where
  return = ...
  (>>=) = ...

このように書くことができるインスタンスのうち、きちんとMonad則を満たすものは、 以下に示す”普通の”Stateモナドだけしか存在しないことを示すことが証明のゴールになります。

instance Monad (State s) where
  return a = State $ \s -> (s,a)
  (>>=) = usualBind

usualBind :: forall s a b. State s a -> (a -> State s b) -> State s b
usualBind ma k = State $ \s0 ->
  let (s1,a) = runState ma s0
  in runState (k a) s1

ただし、State s型の内容である関数s -> (s,a)return, (>>=)といった関数は、 いずれも部分関数でないものだけを考えています。undefinedや再帰を用いて、 一部の入力に対してになるような定義は認めていません。1

この主張の意味するところには、勘違いしやすい箇所が2つあります。まず、 Monad則を満たすインスタンスだけを数えている点です。

実は、Monad則を満たさなくてもよいならば、幾らでも異なるインスタンスを作ることができます。 例えば、以下のインスタンスはコンパイルできますが、Monad則を満たしません。

instance Monad (State s) where
  return a = State $ \s -> (s, a)
  (>>=) = strangeBind

strangeBind :: forall s a b. State s a -> (a -> State s b) -> State s b
strangeBind ma k = State $ \s0 ->
  let (s1, _) = runState ma s0
      (s2, _) = runState ma s1
      (_,  a) = runState ma s2
  in runState (k a) s1

2つ目に、「どんな型Aについても、Monad (State A)のインスタンスは唯一である」 と混同しそうになる点があります。ここで主張していることは

  • ⭕「『(任意の型sに対して使える)Monad (State s)のインスタンス』は通常のStateモナドに限られる」

であって、

  • ❌「任意の型sに対して『Monad (State s)のインスタンスは通常のStateモナドに限られる』が成り立つ」

ではありません。 実際、後者は誤った主張です。 反例を挙げると、sBoolを代入した場合にそれは成り立ちません。

instance Monad (State Bool) where
  -- 要 FlexibleInstances 拡張
  ......

と書けるインスタンスには、きちんとMonad則を満たす”普通でない”インスタンスが多数存在します。

“普通でない”インスタンスの具体例

例えば、以下のようなMonadインスタンスがあり、普通のStateモナドとは異なりますがMonad則は満たしています。

toTriple :: State Bool a -> (Bool -> Bool, a, a)
toTriple (State f) = (fst . f, snd (f False), snd (f True))

fromTriple :: (Bool -> Bool, a, a) -> State Bool a
fromTriple (g, a0, a1) = State $ \s ->
  case s of
    False -> (g s, a0)
    True -> (g s, a1)

instance Monad (State Bool) where
  return a = fromTriple (id, a, a)
  ma >>= k =
    let (f, a0, a1) = toTriple ma
        (g, b0, _)  = toTriple (k a0)
        (_, _,  b1) = toTriple (k a1)
    in fromTriple (g . f, b0, b1)

更に、上記のMonadインスタンスがMonad則を満たすことを確かめると、 その際にBool -> Boolidを単位元、flip (.)を積とするモノイドを成していること だけが重要であることに気づきます。

実際、4個(Bool -> Boolと同じ数)の要素をもつ任意のモノイドを用いても State Bool a上に類似のMonadインスタンスを定めることが可能です。 そのようなモノイドで同型でないものは35種類あり、それらが全て別々の”普通でない”Monadインスタンス を定めます。

証明の概略

証明は3つのパートに分けられます。

パート1: joinの変化だけ考えればよい

詳細は省きますが、parametricityというHaskellの言語としての性質から、 Monad (State s)のインスタンスとして変化を付けられる可能性があるのは、 (>>=)の部分だけであることがわかります。詳しく言うと、

  • 任意のデータ型Fに対してFunctor Fのインスタンスは唯一である

    • Functor則を満たす2つのfmap, fmap' :: forall a b. (a -> b) -> F a -> F bに対して常にfmap f x === fmap' f xが成り立つので、 “異なる”Functorのインスタンスは作ることができない
  • Monad則を考慮に入れずとも、returnは一つしか存在しない

    • forall s a. a -> State s aが付く式の値は、いずれも

      return :: forall s a. a -> State s a
      return a = State $ \s -> (s, a)

      と同値である

が成り立ちます。したがって、「Monad則を満たす(>>=)としてあり得る値はいくつあるか」 が分かればよいのですが、ここで更に、証明での扱いやすさのため、(>>=)の代わりに joinを使うことにします。

元来join(>>=)を使って定義されていますが、逆にjoinから(>>=)を復元することもできるので、 それぞれの型BindTyJoinTyの間には全単射があります。 つまり、(>>=)としてあり得る変化を調べる代わりに、joinとしてあり得る変化を考えれば必要十分というわけです。

type BindTy = forall s a b. State s a -> (a -> State s b) -> State s b
type JoinTy = forall s a. State s (State s a) -> State s a

-- BindTy と JoinTy は同型
bindToJoin :: BindTy -> JoinTy
bindToJoin bind' = \mma -> mma `bind'` id

joinToBind :: JoinTy -> BindTy
joinToBind join' = \ma k -> join' (fmap k ma)

例えば、どちらもBindTy型の値であるusualBindstrangeBindをそれぞれ JoinTy型に変換すると、以下のようになります。

-- usualJoin = bindToJoin usualBind
usualJoin :: JoinTy
usualJoin mma = State $ \s0 ->
  let (s1, ma) = runState mma s0
  in runState ma s1

-- strangeJoin = bindToJoin strangeBind
strangeJoin :: JoinTy
strangeJoin mma = State $ \s0 ->
  let (s1, _) = runState mma s0
      (s2, _) = runState mma s1
      (_,  ma) = runState mma s2
  in runState ma s1

つまり、証明のゴールは、

JoinTy型の値のうち、returnと組み合わせてMonad則を満たすものは、 通常のStateモナドを定めるusualJoinの1つに限られる

を示すことになります。

また、ここでいうMonad則は、joinを用いて表現した以下の3つの等式のことです。

  • 左単位則(left unit) join . return === id
  • 右単位則(right unit) join . fmap return === id
  • 結合則(associativity) join . join === join . fmap join

パート2: ポリモーフィックな関数の代わりに代数的データ型で考えればよい

パート1で、証明のゴールを

JoinTy型の値のうち、returnと組み合わせてMonad則を満たすものは、 通常のStateモナドを定める1つだけである

に絞り込むことができました。しかし、直接この証明にとりかかるのはまだ早いです。

JoinTy型というのは、型パラメータs, aに関してforallが付いた、ポリモーフィックな関数

type JoinTy = forall s a. State s (State s a) -> State s a

のことでした。そのようなポリモーフィックな関数すべてに対してMonad則を満たすかどうかを調べていくというのは、 どうやればいいのでしょうか?

ここで、Boehm–Beraducciエンコーディングという、「代数的データ型をポリモーフィックな関数で表現する」 手法を使うことができます。 これは代数的データ型とポリモーフィックな関数が同型になるようなエンコーディングなので、逆に関数のほうを代数的データ型と見做すためにも使えます。

厳密な言明はここでは説明しませんが、Boehm–Beraducciエンコーディングは以下のような例をすべて包含する一般化になっています。2

  • 自然数Nat ↔︎ forall r. (r -> r) -> r -> r
    {-# LANGUAGE RankNTypes #-}
    -- 以下、他のすべての例にもRankNTypes拡張が必要
    
    data Nat = Suc Nat | Z
    type NatEnc = forall r. (r -> r) -> r -> r
    
    -- encodeNat (Suc Z) = \s z -> s z
    -- encodeNat (Suc (Suc Z)) = \s z -> s (s z)
    encodeNat :: Nat -> NatEnc
    encodeNat (Suc n) = \s z -> s (encodeNat n s z)
    encodeNat Z       = \_ z -> z
    
    decodeNat :: NatEnc -> Nat
    decodeNat f = f Suc Z
  • Bool ↔︎ forall r. r -> r -> r
    data Bool = False | True
    type BoolEnc = forall r. r -> r -> r
    
    encodeBool :: Bool -> BoolEnc
    encodeBool False = \x _ -> x
    encodeBool True  = \_ y -> y
    
    decodeBool :: BoolEnc -> Bool
    decodeBool f = f False True
  • リスト List a ↔︎ forall r. (a -> r -> r) -> r -> r
    data List a = Cons a (List a) | Nil
    type ListEnc a = forall r. (a -> r -> r) -> r -> r
    
    -- encodeList (Cons a0 Nil) = \c n -> c a0 n
    -- encodeList (Cons a0 (Cons a1 Nil)) = \c n -> c a0 (c a1 n)
    encodeList :: List a -> ListEnc a
    encodeList (Cons a as) = \c n -> c a (encodeList as c n)
    encodeList Nil         = \_ n -> n
    
    decodeList :: ListEnc a -> List a
    decodeList f = f Suc Z
  • 二分木 BinTree a ↔︎ forall r. (a -> r) -> (r -> r -> r) -> r
    data BinTree a = Tip a | Branch (BinTree a) (BinTree a)
    type BinTreeEnc a = forall r. (a -> r) -> (r -> r -> r) -> r
    
    encodeBT :: BinTree a -> BinTreeEnc a
    encodeBT (Tip x)      = \t _ -> t x
    encodeBT (Branch l r) = \t b -> b (encodeBT l t b) (encodeBT r t b)
    
    decodeBT :: BinTreeEnc -> BinTree a
    decodeBT f = f Tip Branch

JoinTyを、Boehm–Beraducciエンコーディングが適用できる形になるまで同型な型に変形させていくと、 以下のようになります。

JoinTy
 = forall s a. State s (State s a) -> State s a
   -- Stateのnewtypeを剥がす
 ~ forall s a. (s -> (s, s -> (s, a))) -> (s -> (s, a))
   -- (x ->)はタプルに分配できる:
   --     (x -> (y,z)) ~ (x -> y, x -> z)
 ~ forall s a. (s -> s, s -> s -> (s, a)) -> s -> (s, a)
 ~ forall s a. (s -> s, (s -> s -> s, s -> s -> a)) -> s -> (s, a)
   -- カリー化:
   --     ((x, y) -> z) ~ (x -> y -> z) 
 ~ forall s a. (s -> s) -> (s -> s -> s) -> (s -> s -> a) -> s -> (s, a)
   -- 引数の順序を入れ替え
 ~ forall s a. (s -> s) -> (s -> s -> s) -> s -> (s -> s -> a) -> (s, a)
   -- 引数の型が a に依存しない関数と forall a. を交換
 ~ forall s. (s -> s) -> (s -> s -> s) -> s -> forall a. (s -> s -> a) -> (s, a)
   -- カリー化
 ~ forall s. (s -> s) -> (s -> s -> s) -> s -> forall a. ((s,s) -> a) -> (s, a)
   -- Yoneda
 ~ forall s. (s -> s) -> (s -> s -> s) -> s -> Yoneda ((,) s) (s,s)
 ~ forall s. (s -> s) -> (s -> s -> s) -> s -> (s,(s,s))
   -- (s,(s,s)) ~ (s,s,s)
   -- (x ->)をタプルに分配
 ~ forall s. (
      (s -> s) -> (s -> s -> s) -> s -> s,
      (s -> s) -> (s -> s -> s) -> s -> s,
      (s -> s) -> (s -> s -> s) -> s -> s
   )
   -- forall をタプルに分配:
   --     forall x. (f x, g x) ~ (forall x. f x, forall x. g x)
 ~ (
      forall s. (s -> s) -> (s -> s -> s) -> s -> s,
      forall s. (s -> s) -> (s -> s -> s) -> s -> s,
      forall s. (s -> s) -> (s -> s -> s) -> s -> s
   )
   -- 各成分の型を (TEnc := forall s. (s -> s) -> (s -> s -> s) -> s -> s) とおく
 = (TEnc, TEnc, TEnc)

ここで出てきたTEnc = forall s. (s -> s) -> (s -> s -> s) -> s -> sという型は、 代数的データ型TにBoehm–Beraducciエンコーディングを適用した結果になっています。

data T = F T | G T T | X

-- TEnc は T の Boehm--Beraducciエンコーディング
type TEnc = forall s. (s -> s) -> (s -> s -> s) -> s -> s

encodeT :: T -> TEnc
encodeT (F t)   = \f g x -> f (encodeT t f g x)
encodeT (G t u) = \f g x -> g (encodeT t f g x) (encodeT u f g x)
encodeT X       = \_ _ x -> x

decodeT :: TEnc -> T
decodeT enc = enc F G X

したがって、JoinTyは、ある全単射encodeJoin, decodeJoinによって、 代数的データ型であるTの3つ組(T,T,T)と同型です。 これらの同型を実際にHaskellで書くと以下の通りになります。

encodeJoin :: (T,T,T) -> JoinTy
encodeJoin (t,l,r) mma = State $ \s0 ->
    let s2 = encodeT t f g s0
        sLeft = encodeT l f g s0
        sRight = encodeT r f g s0
    in (s2, h sLeft sRight)
  where
    f s0 =
      let (s1,_) = runState mma s0
      in s1
    g s0 s1 =
      let (_,ma) = runState mma s0
          (s2,_) = runState mma s1
      in s2
    h s0 s1 =
      let (_,ma) = runState mma s0
          (_,a) = runState ma s1
      in a

decodeJoin :: JoinTy -> (T,T,T)
decodeJoin join' = combine $ runState (join' mmT) X
  where
    combine (t,(l,r)) = (t,l,r)
    mmT :: State T (State T (T,T))
    mmT = State $ \s -> (F s, State $ \s' -> (G s s', (s,s')))

例えば、通常のStateモナドを定めるusualJoinや、Monad則を満たさない例として挙げたstrangeJoinをエンコードすると、 以下のようになります。

-- usualDef = decodeJoin usualJoin
usualDef :: (T,T,T)
usualDef = (G X (F X), X, F X)

strangeDef :: (T,T,T)
strangeDef = (G (F (F X)) (F X),F (F X),F X)

(T,T,T)のような代数的データ型であれば、それがとる値すべてに対する性質を証明していくことはできそうですね! 示すべき証明のゴールは

join = encodeJoin defMonad則を満たすようなdef :: (T,T,T)usualDefのみである

になりました。 join = encodeJoin defを代入したMonad則は以下の通りです。

  • 左単位則(left unit) encodeJoin def . return === id
  • 右単位則(right unit) encodeJoin def . fmap return === id
  • 結合則(associativity)encodeJoin def . encodeJoin def === encodeJoin def . fmap encodeJoin def

そして更に、これらのMonad則において、

  • (左|右)単位則の両辺はforall s a. State s a -> State s aという型
  • 結合則の両辺はforall s a. State s (State s (State s a)) -> State s aという型

です。JoinTy = forall s a. State s (State s a) -> State s aに対して行ったのと同様にして、 これらの型も適切な代数的データ型と同型になっています。

type UnitLawTy = forall s a. State s a -> State s a

-- UnitLawTy ~ (Nat, Nat)
encode1 :: (Nat, Nat) -> UnitLawTy
encode1 (n,m) (ma :: State s a) =
  let f = fst . runState ma   :: s -> s
      ret = snd . runState ma :: s -> a
  in State $ \s0 -> (encodeNat n f s0, ret (encodeNat m f s0))

decode1 :: UnitLawTy -> (Nat, Nat)
decode1 mm = 
  let mNat = State Nat Nat
      mNat = State $ \n -> (Suc n, n)
  in runState (mm mNat) Z

type AssocLawTy = forall s a. State s (State s (State s a)) -> State s a

data S = Leaf | A S | B S S | C S S S

encode3 :: (S, S, S, S) -> AssocLawTy
encode3 = {- 省略 -}
decode3 :: AssocLawTy -> (S, S, S, S)
decode3 = {- 省略 -}

これらのエンコーディングは同型写像なので、等式の両辺に適用しても 等式が成り立つ条件を変えません。したがって、以下のような補助関数を用いて…

idRep :: (Nat, Nat)
idRep = decode1 (id :: forall s a. State s a -> State s a)

leftUnitLHS :: (T,T,T) -> (Nat, Nat)
leftUnitLHS def = decode1 (encodeJoin def . return)

rightUnitLHS :: (T,T,T) -> (Nat, Nat)
rightUnitLHS def = decode1 (encodeJoin def . fmap return)

assocLHS, assocRHS :: (T,T,T) -> (S, S, S, S)
assocLHS def = decode3 (encodeJoin def . encodeJoin def)
assocRHS def = decode3 (encodeJoin def . fmap (encodeJoin def))

Monad則は以下のように記述できます。

  • 左単位則 leftUnitLHS def = idRep
  • 右単位則 rightUnitLHS def = idRep
  • 結合則 assocLHS def = assocRHS def

この等式は代数的データ型の値の間の等式なので、変数defの動く範囲も含めて 「ポリモーフィックな関数」に一度も言及せずにMonad則を表すことができたことになります。

パート3: 場合分けを頑張れば答えが出せる

さて、前パートのleftUnitLHSなどの補助関数は、型を見るとわかるように 代数的データ型から代数的データ型へのごく普通の関数であり、 その計算方法もencode*, decode*などを展開すると具体的にわかります。 例えば、idRepを実際に評価してみれば3idRep = (Suc Z, Z)がわかりますし、leftUnitLHS, rightUnitLHSも、 具体的には以下のような関数(に同値な振る舞いの関数)であることが計算できます。

leftUnitLHS (t,l,r)
  = (countGs t, countGs r)

rightUnitLHS (t,l,r)
  = (countFs t, countFs l)

-- ただしcountGs, countFsは以下の関数
countGs :: T -> Nat
countGs (F t') = countGs t'
countGs (G _ t') = Suc (countGs t')
countGs X = Z

countFs :: T -> Nat
countFs (F t') = Suc (countFs t')
countFs (G _ t') = countFs t'
countFs X = Z

これにより、Monad則は、具体的な定義をもつ補助関数たちによって書かれた、 代数的データ型(T,T,T)上を動く変数defについての”連立方程式”と見做すことができます。

後はひたすら場合分けをしていくことでこの”連立方程式”を解き、 その解がdef = usualDef = (G X (F X), X, F X)の一点に限ることを示すのみです。 本記事にその細部までは書ききれませんが、雰囲気を感じていただくため、 単位則だけからわかる部分までを解説します。

上に結果だけ記したleftUnitLHS, rightUnitLHSの計算方法をMonad則に代入すると、

  • 左単位則 leftUnitLHS (t,l,r) = (countGs t, countGs r) = (Suc Z, Z)
  • 右単位則 rightUnitLHS (t,l,r) = (countFs t, countFs l) = (Suc Z, Z)

となります。タプル間の等式を成分ごとに書けば、(eq1)–(eq4)が得られます。

countGs t = Suc Z        -- (eq1)
countGs r = Z            -- (eq2)
countFs t = Suc Z        -- (eq3)
countFs l = Z            -- (eq4)

これから、以下の3式が成り立ちます。

  • ある自然数nを用いて、r = F^n X

    • (eq2)よりcountGs r = Zですが、これはrG _ _という部分項を含まないこと、すなわち

      r = X | F X | F (F X) | ...

      であることがわかります。

  • ある自然数mを用いて、l = (G _)^m X

    • (eq4)よりcountFs l = Zです。これは

      l = X | G u₁ X | G u₁ (G u₂ X) | ...

      を意味します。正確な情報を忘れてしまってよいのであれば、l = (G _)^m Xと表記してもよいでしょう。

  • あるu :: Tを用いて、t = F (G u X)またはt = G u (F X)

    • (eq1)と(eq3)より、countFs t = countGs t = Suc Zです。 ここでT型の値tの値について場合分けをします。

      1. t = X ではありえません。

      2. t = F t' の場合:

        まず、countFsの定義より countFs (F t') = Suc (countFs t') = Suc Zであり、 countFs t' = Zでなければなりません。(eq4)と同様にしてt' = (G _)^m X がわかります。

        また、countGs (F t') = countGs t' = countGs ((G _)^m X) = Suc Z であるので、m = 1が得られます。すなわち、t'はあるuを用いてt' = G u Xと表されます。

      3. t = G u t' の場合:

        countGsの定義より countGs (G u t') = Suc (countGs t') = Suc Zであり、 countGs t' = Zでなければなりません。(eq2)と同様にしてt' = F^n Xがわかります。

        また、countFs (G u t') = countFs t' = countFs (F^n X) t = Suc Z であるので、n = 1が得られます。すなわち、t' = F Xです。

      場合分けをまとめると、あるuを使ってt = F (G u X)またはt = G u (F X)のどちらかであることがわかります。

通常のStateモナドはMonad則を満たしているので、 それに対応するusualDef = (G X (F X), X, F X)も当然にこれらの条件をすべてクリアしているはずです。 確かめてみてください。

まとめ

Haskellにおいてinstance Monad (State s) where ...と書けるインスタンスでMonad則を満たすものは、 “普通の”Stateモナドしか存在しません。この記事ではその証明の大まかな方針を説明しました。


  1. 実は、上記の形のインスタンスを持つReverse State Monadというものがあります。 Reverse State Monadは、通常のStateモナドとは異なり、遅延評価と再帰を使って(>>=)が定義されるモナドで、一定の条件下に限ればMonad則を満たしますが、 その条件を満たさなかった場合(>>=)が「全域でない関数s -> (s, a)」を作り出すなど、いまひとつ同じ土俵に乗りません。

    私の知識不足が大きいでしょうが、このモナドは後述する Boehm–Beraducci エンコーディングの枠組みで捉えることもできず、 どう考えるべきかはっきりしません。なるべくad hocに見えない方法でReverse State Monadの類を除外する為に “部分関数はないものとする”という前提を設けました。↩︎

  2. チャーチ・エンコーディング をご存知の場合、これは”厳密な型の付いた”チャーチ・エンコーディングと考えてもかまいません。↩︎

  3. GHCiなどを使ってidRepを評価させることができます↩︎