(* ExtLib.Data.Monads.StateMonad *)
Require Import ExtLib.Structures.Monads.
Require Import ExtLib.Structures.Monoid.
Set Implicit Arguments.
Set Strict Implicit.
Section StateType.
Variable S : Type.
(** [state t] wraps a function that takes an initial state of type [S] and
    produces a pair of a [t] value and a state of type [S]. [mkState] builds
    the record and [runState] projects the wrapped function back out. *)
Record state (t : Type) : Type := mkState
{ runState : S -> t * S }.
(** Run [c] starting from state [s] and keep only the first component of the
    resulting pair (the computed value). *)
Definition evalState {t} (c : state t) (s : S) : t :=
  let outcome := runState c s in
  fst outcome.
(** Run [c] starting from state [s] and keep only the second component of the
    resulting pair (the final state). *)
Definition execState {t} (c : state t) (s : S) : S :=
  let outcome := runState c s in
  snd outcome.
(** [Monad] instance for [state].
    - [ret v] pairs [v] with the incoming state, unchanged.
    - [bind c1 c2] runs [c1] on the incoming state, then runs [c2] applied to
      the value [c1] produced, starting from the state [c1] produced. *)
Global Instance Monad_state : Monad state :=
{ ret _ v := mkState (fun st => (v, st))
; bind _ _ c1 c2 :=
    mkState (fun st0 =>
      match runState c1 st0 with
      | (v, st1) => runState (c2 v) st1
      end)
}.
(** [MonadState] instance for [state].
    - [get] returns the current state as the value and leaves it in place.
    - [put v] discards the incoming state, installs [v], and returns [tt]. *)
Global Instance MonadState_state : MonadState S state :=
{ get := mkState (fun st => (st, st))
; put new_st := mkState (fun _ => (tt, new_st))
}.
Variable m : Type -> Type.
(** [stateT t] wraps a function that takes an initial state of type [S] and
    produces, inside [m], a pair of a [t] value and a state of type [S].
    [mkStateT] builds the record and [runStateT] projects the function out. *)
Record stateT (t : Type) : Type := mkStateT
{ runStateT : S -> m (t * S)%type }.
Variable M : Monad m.
(** Run [c] from state [s] in [m], then return only the first component of
    the produced pair (the computed value). *)
Definition evalStateT {t} (c : stateT t) (s : S) : m t :=
  bind (runStateT c s)
       (fun res => ret (fst res)).
(** Run [c] from state [s] in [m], then return only the second component of
    the produced pair (the final state). *)
Definition execStateT {t} (c : stateT t) (s : S) : m S :=
  bind (runStateT c s)
       (fun res => ret (snd res)).
Require Import ExtLib.Structures.Monoid.
Set Implicit Arguments.
Set Strict Implicit.
Section StateType.
Variable S : Type.
(** [state t] wraps a function that takes an initial state of type [S] and
    produces a pair of a [t] value and a state of type [S]. [mkState] builds
    the record and [runState] projects the wrapped function back out. *)
Record state (t : Type) : Type := mkState
{ runState : S -> t * S }.
(** Run [c] starting from state [s] and keep only the first component of the
    resulting pair (the computed value). *)
Definition evalState {t} (c : state t) (s : S) : t :=
  let outcome := runState c s in
  fst outcome.
(** Run [c] starting from state [s] and keep only the second component of the
    resulting pair (the final state). *)
Definition execState {t} (c : state t) (s : S) : S :=
  let outcome := runState c s in
  snd outcome.
(** [Monad] instance for [state].
    - [ret v] pairs [v] with the incoming state, unchanged.
    - [bind c1 c2] runs [c1] on the incoming state, then runs [c2] applied to
      the value [c1] produced, starting from the state [c1] produced. *)
Global Instance Monad_state : Monad state :=
{ ret _ v := mkState (fun st => (v, st))
; bind _ _ c1 c2 :=
    mkState (fun st0 =>
      match runState c1 st0 with
      | (v, st1) => runState (c2 v) st1
      end)
}.
(** [MonadState] instance for [state].
    - [get] returns the current state as the value and leaves it in place.
    - [put v] discards the incoming state, installs [v], and returns [tt]. *)
Global Instance MonadState_state : MonadState S state :=
{ get := mkState (fun st => (st, st))
; put new_st := mkState (fun _ => (tt, new_st))
}.
Variable m : Type -> Type.
(** [stateT t] wraps a function that takes an initial state of type [S] and
    produces, inside [m], a pair of a [t] value and a state of type [S].
    [mkStateT] builds the record and [runStateT] projects the function out. *)
Record stateT (t : Type) : Type := mkStateT
{ runStateT : S -> m (t * S)%type }.
Variable M : Monad m.
(** Run [c] from state [s] in [m], then return only the first component of
    the produced pair (the computed value). *)
Definition evalStateT {t} (c : stateT t) (s : S) : m t :=
  bind (runStateT c s)
       (fun res => ret (fst res)).
(** Run [c] from state [s] in [m], then return only the second component of
    the produced pair (the final state). *)
Definition execStateT {t} (c : stateT t) (s : S) : m S :=
  bind (runStateT c s)
       (fun res => ret (snd res)).
(** [Monad_stateT] is not a [Global Instance] because it can cause an infinite
    loop in typeclass inference under certain circumstances. Use
    [Existing Instance Monad_stateT.] to bring the instance into context. *)
(** [Monad] instance for [stateT], built on the underlying instance [M].
    Declared with plain [Instance] rather than [Global Instance].
    - [ret x] returns, in [m], the pair of [x] and the unchanged state.
    - [bind c1 c2] runs [c1] on the incoming state and, using [M]'s [bind],
      continues with [c2] applied to the produced value, from the produced
      state. *)
Instance Monad_stateT : Monad stateT :=
{ ret _ x := mkStateT (fun st => @ret _ M _ (x, st))
; bind _ _ c1 c2 :=
    mkStateT (fun st0 =>
      @bind _ M _ _ (runStateT c1 st0) (fun res =>
        match res with
        | (v, st1) => runStateT (c2 v) st1
        end))
}.
(** [MonadState] instance exposing [stateT]'s own state of type [S].
    - [get] returns the current state as the value and leaves it in place.
    - [put v] discards the incoming state, installs [v], and returns [tt]. *)
Global Instance MonadState_stateT : MonadState S stateT :=
{ get := mkStateT (fun st => ret (st, st))
; put new_st := mkStateT (fun _ => ret (tt, new_st))
}.
(** [lift c] runs the [m]-computation [c] and pairs its result with the
    incoming state, which is passed through unchanged. *)
Global Instance MonadT_stateT : MonadT stateT m :=
{ lift _ c := mkStateT (fun st => bind c (fun v => ret (v, st)))
}.
(** When the underlying monad [m] is itself a [MonadState T], expose that
    state through [stateT] by lifting [m]'s [get] and [put]. *)
Global Instance State_State_stateT T (MS : MonadState T m) : MonadState T stateT :=
{ get := lift get
; put x := lift (put x)
}.
(** [MonadReader] instance for [stateT] over a reader monad [m].
    - [ask] performs [m]'s [ask] and pairs the result with the unchanged state.
    - [local f c] applies [m]'s [local f] to [c] run on the incoming state. *)
Global Instance MonadReader_stateT T (MR : MonadReader T m) : MonadReader T stateT :=
{ ask := mkStateT (fun st => bind ask (fun env => ret (env, st)))
; local _ f c := mkStateT (fun st => local f (runStateT c st))
}.
(** [MonadWriter] instance for [stateT] over a writer monad [m].
    - [tell x] performs [m]'s [tell x] and passes the state through unchanged.
    - [listen c] applies [m]'s [listen] to [c] run on the incoming state, then
      regroups [((a, s), t)] into [((a, t), s)] so the state stays outermost.
    - [pass c] applies [m]'s [pass] to the WHOLE computation [c] run on the
      incoming state, regrouped from [((a, f), s)] into [((a, s), f)].
      Wrapping the entire computation (rather than a trailing [ret]) is what
      makes the returned function [f] apply to the output written by [c]. *)
Global Instance MonadWriter_stateT T (Mon : Monoid T) (MR : MonadWriter Mon m) : MonadWriter Mon stateT :=
{ tell := fun x => mkStateT (fun s => bind (tell x) (fun v => ret (v, s)))
; listen := fun _ c => mkStateT (fun s => bind (listen (runStateT c s))
                                 (fun x => let '(a,s,t) := x in
                                           ret (a,t,s)))
; pass := fun _ c => mkStateT (fun s =>
    pass (bind (runStateT c s) (fun x =>
      let '(a,t,s) := x in ret ((a,s),t))))
}.
(** [MonadExc] instance for [stateT] over an exception monad [m].
    - [raise e] lifts [m]'s [raise e].
    - [catch body hnd] runs [body] on the incoming state under [m]'s [catch];
      the handler [hnd e] is run on that same incoming state. *)
Global Instance Exc_stateT T (MR : MonadExc T m) : MonadExc T stateT :=
{ raise _ e := lift (raise e)
; catch _ body hnd :=
    mkStateT (fun st =>
      catch (runStateT body st)
            (fun err => runStateT (hnd err) st))
}.
(** [mzero] for [stateT] is [m]'s [mzero], lifted. *)
Global Instance MonadZero_stateT (MR : MonadZero m) : MonadZero stateT :=
{ mzero := fun _ => lift mzero
}.
(** [MonadFix] instance for [stateT] over a monad [m] with fixpoints.
    [mfix r v] ties the knot with [m]'s two-argument fixpoint [mfix2], using
    the state as the second argument. The recursive handle [self] supplied by
    [mfix2] is wrapped back into a [stateT] with [mkStateT] and handed to the
    caller's functional [r], so that [r] actually drives the recursion. *)
Global Instance MonadFix_stateT (MF : MonadFix m) : MonadFix stateT :=
{ mfix := fun _ _ r v =>
    mkStateT (fun s =>
      mfix2 _ (fun self x st =>
                 runStateT (r (fun y => mkStateT (self y)) x) st) v s)
}.
(** [mplus a b] runs both [a] and [b] on the same incoming state under [m]'s
    [mplus], then moves the sum tag inside the pair: [inl (x, s)] becomes
    [(inl x, s)] and [inr (y, s)] becomes [(inr y, s)]. *)
Global Instance MonadPlus_stateT (MP : MonadPlus m) : MonadPlus stateT :=
{ mplus := fun _ _ a b =>
    mkStateT (fun st =>
      bind (mplus (runStateT a st) (runStateT b st))
           (fun choice =>
              match choice with
              | inl (x, st') => ret (inl x, st')
              | inr (y, st') => ret (inr y, st')
              end))
}.
End StateType.
(* Argument status for constants defined in the StateType section:
   arguments written in braces are implicit, those in parentheses explicit. *)
Arguments mkStateT {S} {m} {t} (_).
Arguments evalState {S} {t} (c) (s).
Arguments execState {S} {t} (c) (s).
Arguments evalStateT {S} {m} {M} {t} (c) (s).
Arguments execStateT {S} {m} {M} {t} (c) (s).
Arguments MonadWriter_stateT {S} {m} {_} {T} {Mon} (_).
{ ret := fun _ x => mkStateT (fun s => @ret _ M _ (x,s))
; bind := fun _ _ c1 c2 =>
mkStateT (fun s =>
@bind _ M _ _ (runStateT c1 s) (fun vs =>
let (v,s) := vs in
runStateT (c2 v) s))
}.
(** [MonadState] instance exposing [stateT]'s own state of type [S].
    - [get] returns the current state as the value and leaves it in place.
    - [put v] discards the incoming state, installs [v], and returns [tt]. *)
Global Instance MonadState_stateT : MonadState S stateT :=
{ get := mkStateT (fun st => ret (st, st))
; put new_st := mkStateT (fun _ => ret (tt, new_st))
}.
(** [lift c] runs the [m]-computation [c] and pairs its result with the
    incoming state, which is passed through unchanged. *)
Global Instance MonadT_stateT : MonadT stateT m :=
{ lift _ c := mkStateT (fun st => bind c (fun v => ret (v, st)))
}.
(** When the underlying monad [m] is itself a [MonadState T], expose that
    state through [stateT] by lifting [m]'s [get] and [put]. *)
Global Instance State_State_stateT T (MS : MonadState T m) : MonadState T stateT :=
{ get := lift get
; put x := lift (put x)
}.
(** [MonadReader] instance for [stateT] over a reader monad [m].
    - [ask] performs [m]'s [ask] and pairs the result with the unchanged state.
    - [local f c] applies [m]'s [local f] to [c] run on the incoming state. *)
Global Instance MonadReader_stateT T (MR : MonadReader T m) : MonadReader T stateT :=
{ ask := mkStateT (fun st => bind ask (fun env => ret (env, st)))
; local _ f c := mkStateT (fun st => local f (runStateT c st))
}.
(** [MonadWriter] instance for [stateT] over a writer monad [m].
    - [tell x] performs [m]'s [tell x] and passes the state through unchanged.
    - [listen c] applies [m]'s [listen] to [c] run on the incoming state, then
      regroups [((a, s), t)] into [((a, t), s)] so the state stays outermost.
    - [pass c] applies [m]'s [pass] to the WHOLE computation [c] run on the
      incoming state, regrouped from [((a, f), s)] into [((a, s), f)].
      Wrapping the entire computation (rather than a trailing [ret]) is what
      makes the returned function [f] apply to the output written by [c]. *)
Global Instance MonadWriter_stateT T (Mon : Monoid T) (MR : MonadWriter Mon m) : MonadWriter Mon stateT :=
{ tell := fun x => mkStateT (fun s => bind (tell x) (fun v => ret (v, s)))
; listen := fun _ c => mkStateT (fun s => bind (listen (runStateT c s))
                                 (fun x => let '(a,s,t) := x in
                                           ret (a,t,s)))
; pass := fun _ c => mkStateT (fun s =>
    pass (bind (runStateT c s) (fun x =>
      let '(a,t,s) := x in ret ((a,s),t))))
}.
(** [MonadExc] instance for [stateT] over an exception monad [m].
    - [raise e] lifts [m]'s [raise e].
    - [catch body hnd] runs [body] on the incoming state under [m]'s [catch];
      the handler [hnd e] is run on that same incoming state. *)
Global Instance Exc_stateT T (MR : MonadExc T m) : MonadExc T stateT :=
{ raise _ e := lift (raise e)
; catch _ body hnd :=
    mkStateT (fun st =>
      catch (runStateT body st)
            (fun err => runStateT (hnd err) st))
}.
(** [mzero] for [stateT] is [m]'s [mzero], lifted. *)
Global Instance MonadZero_stateT (MR : MonadZero m) : MonadZero stateT :=
{ mzero := fun _ => lift mzero
}.
(** [MonadFix] instance for [stateT] over a monad [m] with fixpoints.
    [mfix r v] ties the knot with [m]'s two-argument fixpoint [mfix2], using
    the state as the second argument. The recursive handle [self] supplied by
    [mfix2] is wrapped back into a [stateT] with [mkStateT] and handed to the
    caller's functional [r], so that [r] actually drives the recursion. *)
Global Instance MonadFix_stateT (MF : MonadFix m) : MonadFix stateT :=
{ mfix := fun _ _ r v =>
    mkStateT (fun s =>
      mfix2 _ (fun self x st =>
                 runStateT (r (fun y => mkStateT (self y)) x) st) v s)
}.
(** [mplus a b] runs both [a] and [b] on the same incoming state under [m]'s
    [mplus], then moves the sum tag inside the pair: [inl (x, s)] becomes
    [(inl x, s)] and [inr (y, s)] becomes [(inr y, s)]. *)
Global Instance MonadPlus_stateT (MP : MonadPlus m) : MonadPlus stateT :=
{ mplus := fun _ _ a b =>
    mkStateT (fun st =>
      bind (mplus (runStateT a st) (runStateT b st))
           (fun choice =>
              match choice with
              | inl (x, st') => ret (inl x, st')
              | inr (y, st') => ret (inr y, st')
              end))
}.
End StateType.
(* Argument status for constants defined in the StateType section:
   arguments written in braces are implicit, those in parentheses explicit. *)
Arguments mkStateT {S} {m} {t} (_).
Arguments evalState {S} {t} (c) (s).
Arguments execState {S} {t} (c) (s).
Arguments evalStateT {S} {m} {M} {t} (c) (s).
Arguments execStateT {S} {m} {M} {t} (c) (s).
Arguments MonadWriter_stateT {S} {m} {_} {T} {Mon} (_).