-- Copyright 2013 Evan Laforge
-- This program is distributed under the terms of the GNU General Public
-- License 3.0, see COPYING or http://www.gnu.org/licenses/gpl-3.0.txt

module Util.Thread (
    start, startLogged, asyncLogged
    , Seconds, delay
    , timeout
    -- * Flag
    , Flag, flag, set, wait, poll
    -- * timing
    , force, timeAction, timeActionText
    , printTimer, printTimer_, printTimerVal
    , currentCpu
    -- * map concurrent
    , forCpu_
    -- * Metric
    , Metric(..), metric, diffMetric, showMetric
) where
import qualified Control.Concurrent as Concurrent
import qualified Control.Concurrent.Async as Async
import qualified Control.Concurrent.QSem as QSem
import qualified Control.Concurrent.STM as STM
import qualified Control.DeepSeq as DeepSeq
import qualified Control.Exception as Exception
import qualified Control.Monad.Trans as Trans

import qualified Data.Text as Text
import           Data.Text (Text)
import qualified Data.Text.IO as Text.IO
import qualified Data.Time as Time

import qualified GHC.Conc as Conc
import qualified System.CPUTime as CPUTime
import qualified System.IO as IO
import qualified System.Timeout as Timeout

import qualified Text.Printf as Printf

import qualified Util.Log as Log


-- | Fork a thread running the given action.  This is plain
-- 'Concurrent.forkIO'; 'startLogged' is the variant that logs the thread's
-- start, completion, and death.
start :: IO () -> IO Concurrent.ThreadId
start = Concurrent.forkIO

-- | Start a noisy thread that will log when it starts and stops, and warn if
-- it dies from an exception.  The thread is also labelled with @name@ via
-- 'Conc.labelThread'.
startLogged :: String -> IO () -> IO Concurrent.ThreadId
startLogged name = Concurrent.forkIO . logged name

-- | Like 'startLogged', but start the action with 'Async.async' so the
-- caller can wait for its result.  If the async is cancelled, the resulting
-- 'Async.AsyncCancelled' is rethrown without logging a warning.
asyncLogged :: String -> IO a -> IO (Async.Async a)
asyncLogged name = Async.async . logged name

-- | Wrap a thread body so that it:
--
-- * labels the current thread with @name@,
-- * logs a debug message when it starts and when it completes,
-- * logs a warning if it dies from an exception.
--
-- Any exception is always rethrown.  'Async.AsyncCancelled' skips the
-- warning, since cancelling an async is an expected way for it to stop.
logged :: String -> IO a -> IO a
logged name thread = do
    threadId <- Concurrent.myThreadId
    Conc.labelThread threadId name
    -- Prefix for every log line, e.g. "ThreadId 42 worker: ".
    let threadName :: Text
        threadName = Text.pack $ show threadId ++ " " ++ name ++ ": "
    Log.debug $ threadName <> "started"
    Exception.try thread >>= \case
        Left exc
            | Just Async.AsyncCancelled <- Exception.fromException exc ->
                Exception.throwIO exc
            | otherwise -> do
                Log.warn $ threadName <> "died: "
                    <> Text.pack (show (exc :: Exception.SomeException))
                Exception.throwIO exc
        Right a -> do
            Log.debug $ threadName <> "completed"
            return a

-- | A duration in seconds.  This is just 'Time.NominalDiffTime', under a
-- name that is easier to remember.
type Seconds = Time.NominalDiffTime

-- | Block the current thread for the given number of seconds.  The duration
-- is rounded to whole microseconds by 'toUsec' before being passed to
-- 'Concurrent.threadDelay'.
delay :: Seconds -> IO ()
delay = Concurrent.threadDelay . toUsec

-- | Run an action, returning Nothing if it has not finished within the given
-- number of seconds.  The time is rounded to microseconds by 'toUsec' and
-- passed to 'Timeout.timeout', so that function's special cases apply:
--
-- * a time that rounds to 0 returns Nothing without running the action;
-- * a negative time waits indefinitely.
timeout :: Seconds -> IO a -> IO (Maybe a)
timeout = Timeout.timeout . toUsec

-- | Convert seconds to whole microseconds, rounding with 'round' (exact
-- halves go to the even neighbour).
toUsec :: Seconds -> Int
toUsec = round . (* 1000000)

-- * Flag

-- | A Flag starts False, and can eventually become True.  It never goes back
-- to False again: the constructor is not exported, and 'set' only ever
-- writes True.
--
-- 'Eq' compares the underlying 'STM.TVar's, so two flags are equal only if
-- they are the same flag.
newtype Flag = Flag (STM.TVar Bool)
    deriving (Eq)

-- | A TVar's contents can't be read purely, so this is a fixed placeholder.
instance Show Flag where
    show _ = "((Flag))"

-- | Create a new 'Flag', initially False.
flag :: IO Flag
flag = Flag <$> STM.newTVarIO False

-- | Set the flag to True, which wakes up any thread blocked in 'wait'.
-- Setting a flag that is already set has no further effect.
set :: Flag -> IO ()
set (Flag var) = STM.atomically $ STM.writeTVar var True

-- | Wait a finite amount of time for the flag to become true, and return
-- whether it did.  A time <= 0 just checks the current value without
-- blocking.
poll :: Seconds -> Flag -> IO Bool
poll time fl@(Flag var) = do
    -- Check first: 'timeout' rounds to microseconds and returns Nothing
    -- without running the action when the time rounds to 0.  Without this
    -- check, a tiny positive time would report False for a flag that is
    -- already set.
    isSet <- STM.readTVarIO var
    if isSet || time <= 0
        then return isSet
        else maybe False (const True) <$> timeout time (wait fl)

-- | Block until the flag becomes true.  Returns immediately if it already
-- is.  'STM.check' retries the transaction while the flag is False, so the
-- thread sleeps until 'set' writes the TVar.
wait :: Flag -> IO ()
wait (Flag var) = STM.atomically $ STM.check =<< STM.readTVar var

-- * timing

-- | Fully evaluate a value to normal form as an IO action, so evaluation
-- happens at a predictable point.  Any exception hidden inside the value is
-- thrown here.
force :: DeepSeq.NFData a => a -> IO ()
force x = Exception.evaluate (DeepSeq.rnf x)

-- | Time an IO action in CPU and wall seconds.  Technically not thread
-- related, but I don't have a better place at the moment.
--
-- The result is forced to WHNF inside the timed region, but no further; use
-- 'printTimerVal' or 'force' to time deep evaluation.
timeAction :: Trans.MonadIO m => m a -> m (a, Metric Seconds)
timeAction action = do
    before <- Trans.liftIO metric
    !val <- action
    after <- Trans.liftIO metric
    return (val, diffMetric before after)

-- | Like 'timeAction', but return the timing as 'showMetric' Text instead of
-- a 'Metric'.  The inner 'fmap' maps over the second element of the pair.
timeActionText :: Trans.MonadIO m => m a -> m (a, Text)
timeActionText = fmap (fmap showMetric) . timeAction

-- | Convert a 'CPUTime.getCPUTime' reading, in picoseconds, to seconds.
cpuToSec :: Integer -> Seconds
cpuToSec picos = fromIntegral picos / 1e12

-- | Print @msg@, run the action, then complete the line on stdout with its
-- CPU and wall time and @showVal@ of the result.
--
-- If the action throws, the line is completed with the exception, which is
-- then rethrown.
printTimer :: Text -> (a -> String) -> IO a -> IO a
printTimer msg showVal action = do
    Text.IO.putStr $ msg <> " - "
    -- Flush so the label is visible while the action runs.
    IO.hFlush IO.stdout
    -- 'timeAction' already forces the result to WHNF inside the timed region.
    result <- Exception.try $ timeActionText action
    case result of
        Right (val, timing) -> do
            Text.IO.putStrLn $
                "time: " <> timing <> " - " <> Text.pack (showVal val)
            return val
        Left (exc :: Exception.SomeException) -> do
            -- Complete the line so the exception doesn't interrupt it.  This
            -- is important if it's a 'failure' line!
            putStrLn $ "threw exception: " <> show exc
            Exception.throwIO exc

-- | Run the action, then print @msg@ and its CPU and wall time to stderr.
-- Unlike 'printTimer', nothing is printed before the action runs and the
-- result is not shown.
printTimer_ :: Trans.MonadIO m => Text -> m a -> m a
printTimer_ msg action = do
    (a, elapsed) <- timeAction action
    Trans.liftIO $ Text.IO.hPutStrLn IO.stderr $
        msg <> ": " <> showMetric elapsed
    return a

-- | Time how long it takes to fully evaluate a pure value to normal form,
-- reporting it like 'printTimer_'.  'DeepSeq.force' ties deep evaluation to
-- the returned value, and 'timeAction' forces that value inside the timed
-- region.
printTimerVal :: (DeepSeq.NFData a, Trans.MonadIO m) => Text -> a -> m a
printTimerVal msg val = printTimer_ msg $ return (DeepSeq.force val)

-- | CPU time used by this program so far, in seconds.
currentCpu :: IO Seconds
currentCpu = cpuToSec <$> CPUTime.getCPUTime

-- | Convert 'Seconds' to a 'Double', e.g. for printf formatting.
toSecs :: Seconds -> Double
toSecs = realToFrac

-- * concurrent map

-- | Run @f@ on every element concurrently, with at most
-- 'Concurrent.getNumCapabilities' calls of @f@ running at once.  Results are
-- discarded.
--
-- A thread is still forked per element; the semaphore only limits how many
-- of them run @f@ at the same time.  'Exception.bracket_' makes sure a slot
-- is released even if @f@ throws.
forCpu_ :: [a] -> (a -> IO b) -> IO ()
forCpu_ xs f = do
    sem <- QSem.newQSem =<< Concurrent.getNumCapabilities
    Async.forConcurrently_ xs $ \x ->
        Exception.bracket_ (QSem.waitQSem sem) (QSem.signalQSem sem) (f x)

-- * Metric

-- | CPU and wall-clock time.  'metric' produces an absolute snapshot
-- (@Metric UTCTime@), and 'diffMetric' turns two snapshots into elapsed time
-- (@Metric Seconds@).
data Metric time = Metric {
    -- | CPU time, from 'currentCpu' (or a difference of two readings).
    metricCpu :: Seconds
    -- | Wall time: a clock reading, or an elapsed duration.
    , metricWall :: time
    } deriving (Show)

-- | Take a snapshot of current CPU time and wall-clock time.
metric :: IO (Metric Time.UTCTime)
metric = Metric <$> currentCpu <*> Time.getCurrentTime

-- | Elapsed CPU and wall time between two snapshots.  The first argument is
-- the earlier snapshot.
diffMetric :: Metric Time.UTCTime -> Metric Time.UTCTime -> Metric Seconds
diffMetric (Metric cpu1 time1) (Metric cpu2 time2) =
    Metric (cpu2 - cpu1) (time2 `Time.diffUTCTime` time1)

-- | Format elapsed CPU and wall time with two decimals each,
-- e.g. @"0.12 cpu / 3.45s"@.
showMetric :: Metric Seconds -> Text
showMetric (Metric cpu wall) =
    Text.pack $ Printf.printf "%.2f cpu / %.2fs" (toSecs cpu) (toSecs wall)