-- Copyright 2019 Evan Laforge
-- This program is distributed under the terms of the GNU General Public
-- License 3.0, see COPYING or http://www.gnu.org/licenses/gpl-3.0.txt

-- | Utilities for Data.Vector that dispatch to C.
module Util.VectorC (
    mixFloats
) where
import qualified Control.Monad.ST.Unsafe as ST.Unsafe
import qualified Data.Vector.Storable as V
import qualified Data.Vector.Storable.Mutable as VM
import qualified Foreign
import           Foreign (Ptr)
import qualified Foreign.C as C


-- | Mix the given vectors into one freshly allocated vector by handing them
-- to the C function @mix_vectors@.
--
-- The output length is the maximum of @minLen@ and the length of the longest
-- input vector.  Since the input lengths are folded together with a leading
-- 0, the allocation size is never negative when the input list is non-empty
-- or @minLen@ is non-negative.
--
-- The output is zero-filled before the C call; its final contents are
-- whatever @mix_vectors@ writes (presumably the element-wise sum of the
-- inputs -- confirm against the C source).
mixFloats :: Int -> [V.Vector Float] -> V.Vector Float
mixFloats minLen vs = V.create $ do
    -- The leading 0 keeps 'maximum' total when @vs@ is empty.
    let vsLen = maximum $ 0 : map V.length vs
    v <- VM.new (max minLen vsLen)
    VM.set v 0
    -- NOTE(review): running IO inside ST is only sound as long as
    -- mix_vectors does nothing observable except write into @v@ -- confirm.
    ST.Unsafe.unsafeIOToST $ withFPtrM v $ \vp vp_len ->
        withFPtrs vs $ \inspp lenspp ->
        -- 'Float' pointers are cast to 'C.CFloat' pointers to match the
        -- foreign import's signature.
        c_mix_vectors (Foreign.castPtr vp) vp_len (Foreign.castPtr inspp) lenspp
            (fromIntegral (length vs))
    return v

-- Binding to the C function @mix_vectors@, whose prototype is:
--
-- void
-- mix_vectors(float *out, size_t out_len,
--     const float **ins, size_t *in_lens, size_t ins_len)
--
-- The Haskell arguments follow the prototype's order: the output buffer and
-- its length, the array of input buffer pointers, the parallel array of input
-- lengths, and the number of inputs.
foreign import ccall "mix_vectors"
    c_mix_vectors :: Ptr C.CFloat -> C.CSize
        -> Ptr (Ptr C.CFloat) -> Ptr C.CSize -> C.CSize -> IO ()


-- * foreign utils

-- | Expose a list of vectors to an action as two parallel temporary arrays:
-- one holding a data pointer per vector, and one holding each vector's
-- element count, both in the same order as the input list.
--
-- The action runs nested inside one 'withFPtr' per vector and inside both
-- 'Foreign.withArray' allocations, so the pointers it receives must not be
-- used after it returns.
withFPtrs :: Foreign.Storable a => [V.Vector a]
    -> (Ptr (Ptr a) -> Ptr C.CSize -> IO b) -> IO b
withFPtrs vs action = go [] vs
    where
    -- All vectors are open: marshal the collected (pointer, length) pairs.
    go accum [] = Foreign.withArray vps $ \vpp ->
        Foreign.withArray vpLens $ \lenpp -> action vpp lenpp
        -- Pairs were consed on, so reverse to restore the input order.
        where (vps, vpLens) = unzip (reverse accum)
    -- Open the next vector and recurse inside its callback, so it stays open
    -- until the action has finished.
    go accum (v : rest) = withFPtr v $ \vp len ->
        go ((vp, len) : accum) rest

-- | Run an action with the data pointer and element count of a mutable
-- storable vector.  The pointer is obtained with 'VM.unsafeToForeignPtr0',
-- so it refers to the vector's own memory, not a copy.
withFPtrM :: Foreign.Storable a => VM.MVector s a
    -> (Ptr a -> C.CSize -> IO b) -> IO b
withFPtrM = withFPtr_ VM.unsafeToForeignPtr0

-- | Run an action with the data pointer and element count of an immutable
-- storable vector.  The pointer is obtained with 'V.unsafeToForeignPtr0',
-- so it refers to the vector's own memory, not a copy; the action must not
-- write through it.
withFPtr :: Foreign.Storable a => V.Vector a
    -> (Ptr a -> C.CSize -> IO b) -> IO b
withFPtr = withFPtr_ V.unsafeToForeignPtr0

-- | Shared implementation of 'withFPtr' and 'withFPtrM'.
--
-- @toFptr@ splits the container into a 'Foreign.ForeignPtr' and an element
-- count.  The action is run under 'Foreign.withForeignPtr' with the raw
-- pointer and the count converted to 'C.CSize'.
withFPtr_ :: (v -> (Foreign.ForeignPtr a, Int)) -> v
    -> (Ptr a -> C.CSize -> IO b) -> IO b
withFPtr_ toFptr v action =
    Foreign.withForeignPtr fptr $ \ptr -> action ptr (fromIntegral len)
    where (fptr, len) = toFptr v