-- Copyright 2013 Evan Laforge
-- This program is distributed under the terms of the GNU General Public
-- License 3.0, see COPYING or http://www.gnu.org/licenses/gpl-3.0.txt

{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE LambdaCase #-}
-- | Control flow and monadic utilities.
module Util.Control (
    module Util.Control
    , module Data.Bifunctor, module Control.Monad.Extra, module Util.CallStack
) where
import qualified Control.Monad as Monad
import qualified Control.Monad.Except as Except
import           Control.Monad.Extra
    (allM, andM, anyM, findM, mapMaybeM, notM, orM, partitionM)
import qualified Control.Monad.Fix as Fix

import           Data.Bifunctor (Bifunctor(bimap, first, second))

import           Util.CallStack (errorIO, errorStack)


-- These are the same as Control.Monad.Extra, but they are frequently used, and
-- by defining them here I can explicitly INLINE them.  Surely they're short
-- enough that ghc will inline anyway, but -fprof-auto-exported isn't that
-- clever.  I got around it by recompiling all of hackage with
-- 'profiling-detail: none', but I might as well keep the definitions anyway
-- since it gives me more control.

{-# INLINE whenJust #-}
-- | Run @f@ on the contents of a 'Just'.  On 'Nothing', do nothing.
whenJust :: Applicative m => Maybe a -> (a -> m ()) -> m ()
whenJust ma f = maybe (pure ()) f ma

{-# INLINE whenJustM #-}
-- | Like 'whenJust', but the 'Maybe' comes from a monadic action.  The
-- action always runs; @f@ runs only if it returned 'Just'.
whenJustM :: Monad m => m (Maybe a) -> (a -> m ()) -> m ()
whenJustM mma f = mma >>= maybe (pure ()) f

{-# INLINE whenM #-}
-- | Run the condition action, then run @true@ only if it returned 'True'.
whenM :: Monad m => m Bool -> m () -> m ()
whenM mb true = ifM mb true (pure ())

{-# INLINE unlessM #-}
-- | Run the condition action, then run @false@ only if it returned 'False'.
unlessM :: Monad m => m Bool -> m () -> m ()
unlessM mb false = ifM mb (pure ()) false

{-# INLINE ifM #-}
-- | Monadic @if@: run the condition action, then run exactly one of the two
-- branches depending on its result.
ifM :: Monad m => m Bool -> m a -> m a -> m a
ifM mb true false = mb >>= \b -> if b then true else false

-- | 'uncurry' for three-argument functions and triples.
uncurry3 :: (a -> b -> c -> d) -> (a, b, c) -> d
uncurry3 f (a, b, c) = f a b c

-- | 'uncurry' for four-argument functions and 4-tuples.
uncurry4 :: (a -> b -> c -> d -> e) -> (a, b, c, d) -> e
uncurry4 f (a, b, c, d) = f a b c d

-- * local

-- | Run @op@ repeatedly for as long as @cond@ returns 'True', and collect
-- the results in order.  @cond@ is re-run before every iteration.
while :: Monad m => m Bool -> m a -> m [a]
while cond op = do
    b <- cond
    case b of
        True -> do
            val <- op
            rest <- while cond op
            return (val : rest)
        False -> return []

-- | Like 'while', but discard the results of @op@.
while_ :: Monad m => m Bool -> m a -> m ()
while_ cond op = do
    b <- cond
    Monad.when b $ op >> while_ cond op

-- | Loop with no arguments.  This is the same as 'Fix.fix' but the name is
-- clearer: the function receives itself as its argument, so it can recurse.
loop0 :: (a -> a) -> a
loop0 = Fix.fix

-- | Loop with a single state argument.  @f@ is called with a continuation
-- @again@ (which loops with a new state) and the current state.
loop1 :: forall state a. state -> ((state -> a) -> state -> a) -> a
loop1 state f = f again state
    where
    -- Relies on ScopedTypeVariables to refer to the outer @state@ and @a@.
    again :: state -> a
    again = f again

-- | Loop with two state arguments.  You could use loop1 with a pair, but
-- sometimes the currying is convenient.  @f@ is called with a continuation
-- @again@ (which loops with new states) and the current two states.
loop2 :: forall s1 s2 a. s1 -> s2 -> ((s1 -> s2 -> a) -> s1 -> s2 -> a) -> a
loop2 s1 s2 f = f again s1 s2
    where
    -- Relies on ScopedTypeVariables to refer to the outer @s1@, @s2@, @a@.
    again :: s1 -> s2 -> a
    again = f again

-- | This is 'Foldable.foldMap' specialized to lists: map @f@ over the list
-- and 'mconcat' the results.
mconcatMap :: Monoid b => (a -> b) -> [a] -> b
mconcatMap f = mconcat . map f

-- | This is actually an mconcatMapM: run @f@ on every element in order and
-- 'mconcat' the results, so it works for any 'Monoid', not only lists.
--
-- A further generalized version would be:
--
-- > foldMapA :: (Applicative f, Traversable t, Monoid m) =>
-- >    (a -> f m) -> t a -> f m
-- > foldMapA f = fmap Foldable.fold . traverse f
concatMapM :: (Monad m, Monoid b) => (a -> m b) -> [a] -> m b
concatMapM f = fmap mconcat . mapM f

-- | Run the second action only if the first action returns Just.  If the
-- first returns Nothing, the result is Nothing and @op2@ is never run.
--
-- This is like MaybeT, but using MaybeT itself required lots of annoying
-- explicit lifting.
justm :: Monad m => m (Maybe a) -> (a -> m (Maybe b)) -> m (Maybe b)
justm op1 op2 = maybe (return Nothing) op2 =<< op1

-- | The Either equivalent of 'justm'.  Run @op2@ only if @op1@ returns
-- 'Right'.  A 'Left' from @op1@ is passed through unchanged.  EitherT solves
-- the same problem, but requires a runEitherT and lots of hoistEithers.
rightm :: Monad m => m (Either err a) -> (a -> m (Either err b))
    -> m (Either err b)
rightm op1 op2 = either (return . Left) op2 =<< op1

{-
    I could generalize justm and rightm with:

    bind2 :: (Monad m1, Traversable m2, Monad m2)
        => m1 (m2 a) -> (a -> m1 (m2 b)) -> m1 (m2 b)
    bind2 ma mb = ma >>= traverse mb >>= return . Monad.join

    But I can't think of any other Traversables I want.
-}

-- | Return the result of the first action that returns Just.  @alternative@
-- only runs if @action@ returned Nothing.
firstJust :: Monad m => m (Maybe a) -> m (Maybe a) -> m (Maybe a)
firstJust action alternative = maybe alternative (return . Just) =<< action

-- | 'firstJust' applied to a list.  The actions run left to right until one
-- returns Just.  An empty list, or all actions returning Nothing, gives
-- Nothing.
firstJusts :: Monad m => [m (Maybe a)] -> m (Maybe a)
firstJusts = foldr firstJust (return Nothing)

-- * errors

-- The names are chosen to be consistent with the @errors@ package.

-- | Throw on Nothing: convert a 'Maybe' to 'Either', using @err@ as the
-- 'Left' value when there is nothing.
justErr :: err -> Maybe a -> Either err a
justErr err = maybe (Left err) Right

-- | Return the value in a 'Just', or 'Except.throwError' @err@ on Nothing.
-- I usually call this @require@.
tryJust :: Except.MonadError e m => e -> Maybe a -> m a
tryJust err = maybe (Except.throwError err) return

-- | Lift an 'Either' into any 'Except.MonadError': a 'Left' is thrown with
-- 'Except.throwError', and a 'Right' is returned.
-- I usually call this @require_right@.
tryRight :: Except.MonadError e m => Either e a -> m a
tryRight = either Except.throwError return

-- | Run @action@.  If it throws an error, rethrow that error after passing
-- it through @modify@, e.g. to add context.
rethrow :: Except.MonadError e m => (e -> e) -> m a -> m a
rethrow modify action =
    action `Except.catchError` \e -> Except.throwError (modify e)