-- Copyright 2014 Evan Laforge
-- This program is distributed under the terms of the GNU General Public
-- License 3.0, see COPYING or http://www.gnu.org/licenses/gpl-3.0.txt

{-# LANGUAGE Rank2Types #-}
-- | This is the core of the Deriver monad, instantiated in detail in
-- "Derive.Deriver.Monad".
module Derive.Deriver.DeriveM (
    Deriver, RunResult, run, write
    , throw, modify, get, gets, put
    , annotate
) where
import qualified Control.Monad.Except as Except

import qualified Util.Log as Log


-- | A state, log, and error monad written in continuation-passing style.
--
-- 'runD' is given the current state, the log messages written so far, a
-- continuation to call on failure, and a continuation to call on success.
-- Log messages are accumulated newest first ('write' conses onto the list)
-- and are put back into emission order by 'run'.
newtype Deriver st err a = Deriver
    { runD :: forall r. st -> [Log.Msg] -> Failure st err r
        -> Success st err a r -> RunResult st err r
    }

-- | Final outcome of a computation: the error or the result, the final
-- state, and the log messages.
type RunResult st err a = (Either err a, st, [Log.Msg])

-- | Continuation invoked when a computation throws.  It receives the state
-- and logs at the point of failure, and the error.
type Failure st err r = st -> [Log.Msg] -> err -> RunResult st err r

-- | Continuation invoked when a computation succeeds.  It receives the
-- state, the logs, and the result value.
type Success st err a r = st -> [Log.Msg] -> a -> RunResult st err r

-- | Run a 'Deriver' from an initial state, starting with an empty log.
--
-- The result is 'Left' if the computation threw and 'Right' otherwise.  In
-- both cases the final state is returned, along with the log messages in
-- the order they were written.
run :: st -> Deriver st err a -> RunResult st err a
run st m = runD m st []
    (\st logs err -> (Left err, st, reverse logs))
    (\st logs a -> (Right a, st, reverse logs))

-- | Record a log message.  Messages are stored newest first internally;
-- 'run' reverses them back into emission order.
write :: Log.Msg -> Deriver st err ()
write msg = Deriver $ \st logs _ win -> win st (msg : logs) ()

-- | Abort the computation with an error.  The current state and logs are
-- passed to the failure continuation, so they are still returned by 'run'
-- and are visible to an 'Except.catchError' handler.
throw :: err -> Deriver st err a
throw err = Deriver $ \st logs lose _ -> lose st logs err

-- TODO: this INLINE style is just cargo-cult; I can probably put these
-- definitions in the instance declarations directly.

{-# INLINE modify #-}
-- | Apply a function to the state.  The new state is forced to WHNF (via
-- the bang pattern) before the computation continues, so a chain of
-- modifications does not accumulate as unevaluated thunks.
modify :: (st -> st) -> Deriver st err ()
modify f = Deriver $ \st1 logs _ win -> let !st2 = f st1 in win st2 logs ()

{-# INLINE get #-}
-- | Return the current state, leaving it unchanged.
get :: Deriver st err st
get = Deriver $ \st logs _ win -> win st logs st

{-# INLINE put #-}
-- | Replace the state, discarding the old one.  The bang pattern forces the
-- new state to WHNF when the 'put' action itself is evaluated.
put :: st -> Deriver st err ()
put !st = Deriver $ \_ logs _ win -> win st logs ()

-- | 'fmap' is implemented by the INLINE'd top-level 'fmapC'.
instance Functor (Deriver st err) where
    fmap = fmapC

{-# INLINE fmapC #-}
-- | 'fmap' for 'Deriver': run @m@, then apply @f@ to its result before
-- passing it to the success continuation.  A failure in @m@ goes straight to
-- the caller's failure continuation.
fmapC :: (a -> b) -> Deriver st err a -> Deriver st err b
fmapC f m = Deriver $ \st1 logs1 lose win ->
    runD m st1 logs1 lose (\st2 logs2 a -> win st2 logs2 (f a))

-- | Methods are implemented by the INLINE'd top-level 'pureC' and 'apC'.
instance Applicative (Deriver st err) where
    pure = pureC
    (<*>) = apC

{-# INLINE pureC #-}
-- | 'pure' for 'Deriver': succeed with @a@ without touching state or logs.
pureC :: a -> Deriver st err a
pureC a = Deriver $ \st logs _ win -> win st logs a

{-# INLINE apC #-}
-- | '<*>' for 'Deriver': run @mf@, then @ma@, threading state and logs
-- through both, and apply the function to the value.  A failure in either
-- action goes straight to the caller's failure continuation, and @ma@ is not
-- run if @mf@ fails.
--
-- Written directly in continuation-passing style, like 'fmapC' and
-- 'bindC'.  This is equivalent to @do { f <- mf; a <- ma; pure (f a) }@.
apC :: Deriver st err (a -> b) -> Deriver st err a -> Deriver st err b
apC mf ma = Deriver $ \st1 logs1 lose win ->
    runD mf st1 logs1 lose $ \st2 logs2 f ->
        runD ma st2 logs2 lose $ \st3 logs3 a ->
            win st3 logs3 (f a)

-- | '>>=' is implemented by the INLINE'd top-level 'bindC'; 'return'
-- defaults to 'pure'.
instance Monad (Deriver st err) where
    (>>=) = bindC

{-# INLINE bindC #-}
-- | '>>=' for 'Deriver': run @m@, then run @f@ on its result with the
-- resulting state and logs.  Both stages share the caller's failure and
-- success continuations.
bindC :: Deriver st err a -> (a -> Deriver st err b) -> Deriver st err b
bindC m f = Deriver $ \st1 logs1 lose win ->
    runD m st1 logs1 lose (\st2 logs2 a -> runD (f a) st2 logs2 lose win)

-- | 'Except.throwError' is 'throw'.  'Except.catchError' runs the handler
-- with the state and logs as they were when the error was thrown, so state
-- changes and log messages made before the failure are kept.
instance Except.MonadError err (Deriver st err) where
    throwError = throw
    catchError m handle = Deriver $ \st1 logs1 lose win ->
        runD m st1 logs1
            (\st2 logs2 err -> runD (handle err) st2 logs2 lose win)
            win

{-# INLINE gets #-}
-- | Project a value out of the current state.  The projected value is
-- forced to WHNF (via '$!') before it is returned.
gets :: (st -> a) -> Deriver st err a
gets f = do
    st <- get
    pure $! f st

-- | Catch an error from @m@ and rethrow it transformed by @f@, typically to
-- annotate it with more information.  Because 'Except.catchError' resumes
-- with the state and logs from the point of failure, those are preserved.
annotate :: (err -> err) -> Deriver st err a -> Deriver st err a
annotate f m = Except.catchError m (throw . f)