-- Copyright 2013 Evan Laforge
-- This program is distributed under the terms of the GNU General Public
-- License 3.0, see COPYING or http://www.gnu.org/licenses/gpl-3.0.txt

{-# LANGUAGE FlexibleContexts, MultiParamTypeClasses #-}
-- | Vector utilities.
module Util.Vector where
import qualified Data.Vector.Generic as Generic
import qualified Data.Vector.Unboxed as Unboxed
import qualified Data.Vector as V


-- * search

-- | Count the elements of the vector for which the predicate holds.
count :: Generic.Vector v a => (a -> Bool) -> v a -> Int
count p = Generic.foldl' step 0
    where
    -- The strict left fold keeps the running count evaluated.
    step n x
        | p x = succ n
        | otherwise = n

-- | Like 'Generic.find', but from the end: return the last element that
-- satisfies the predicate, or 'Nothing' if no element does.
{-# SPECIALIZE find_end :: (a -> Bool) -> V.Vector a -> Maybe a #-}
{-# INLINEABLE find_end #-}
find_end :: Generic.Vector v a => (a -> Bool) -> v a -> Maybe a
find_end p vec = scan (Generic.length vec - 1)
    where
    -- Walk indices downward from the last one.  Dropping below 0 means we
    -- ran off the front without a match, which also covers the empty
    -- vector, whose starting index is -1.
    scan i
        | i < 0 = Nothing
        | otherwise =
            let x = Generic.unsafeIndex vec i
            in if p x then Just x else scan (i - 1)

-- | Convert a vector to a list in reverse order.  It indexes from the back,
-- so no reversed copy of the vector is built.
to_reverse_list :: Generic.Vector v a => v a -> [a]
to_reverse_list vec =
    [Generic.unsafeIndex vec i | i <- [last_index, last_index - 1 .. 0]]
    where
    -- For an empty vector this is -1, and the enumeration above is empty,
    -- so unsafeIndex is never reached.
    last_index = Generic.length vec - 1

-- | A left fold that can stop early.
--
-- Elements are fed to @f@ from left to right.  As soon as @f@ returns
-- 'Nothing', or the vector runs out, the most recent accumulator is
-- returned.  The accumulator is passed along as-is and not forced between
-- steps.
fold_abort :: Generic.Vector v a => (accum -> a -> Maybe accum) -> accum
    -> v a -> accum
fold_abort f z vec = loop 0 z
    where
    -- '!?' gives Nothing past the end, which ends the fold the same way a
    -- Nothing from @f@ does.
    loop i acc = case vec Generic.!? i >>= f acc of
        Nothing -> acc
        Just acc' -> loop (i + 1) acc'

-- | Count how many leading values fit within the given running-sum budget.
--
-- Returns the number of leading elements whose running sum stays @<= n@.
-- Equivalently, it is the index of the first element that pushes the
-- running sum above @n@, or the vector length if none does.  For example,
-- for the elements @[1, 1, 1]@ and @n = 2@ the result is @2@.
--
-- NOTE(review): the previous doc described this as "the index of the last
-- value whose running sum is still below" @n@, which is one less than what
-- the code has always returned; callers may want checking.
find_before :: Generic.Vector v Int => Int -> v Int -> Int
find_before n = fst . fold_abort go (0, 0)
    where
    go :: (Int, Int) -> Int -> Maybe (Int, Int)
    go (i, total) a
        -- Force the count at each step; otherwise it builds a chain of
        -- (+1) thunks as long as the scanned prefix.  total' is already
        -- forced by the comparison.
        | total' <= n = i' `seq` Just (i', total')
        | otherwise = Nothing
        where
        i' = i + 1
        total' = total + a

-- | Find the first pair of adjacent values bracketing @a@.
--
-- The result is @(i, low, high)@, where @i@ is the index of @low@:
--
-- * If an element equal to @a@ is found, the first element @>= a@ is that
--   match, and it is returned as both bounds: @(i, a, a)@.
--
-- * Otherwise @low@ and @high@ are adjacent, with @high@ being the first
--   element @>= a@.  For sorted input that means @low < a < high@.
--
-- If no element is @>= a@ (@a@ is above the range, or the vector is empty),
-- or the very first element is already @> a@ (@a@ is below the range), the
-- result is 'Nothing'.
--
-- NOTE(review): the bracketing only makes sense if @vec@ is sorted
-- ascending; that is not checked here.
bracket :: Unboxed.Vector Double -> Double -> Maybe (Int, Double, Double)
bracket vec a = case Generic.findIndex (>= a) vec of
    Just i
        | get i == a -> Just (i, a, a)
        | i > 0 -> Just (i - 1, get (i - 1), get i)
    -- Either nothing is >= a, or index 0 is already > a.
    _ -> Nothing
    where
    -- Only called with i from findIndex, or i-1 when i > 0, so it is always
    -- in bounds.
    get = Generic.unsafeIndex vec

-- | Binary search for the lowest index of the given value, or where it would
-- be if it were present.
--
-- That is, the first index whose @key@ is @>= x@, or the vector length if
-- there is none.  This is only meaningful if the vector is sorted
-- ascending by @key@.
--
-- TODO this is likely the same as
-- Data.Vector.Algorithms.Search.binarySearchLBy
{-# SPECIALIZE lowest_index ::
    Ord key => (a -> key) -> key -> V.Vector a -> Int #-}
{-# SPECIALIZE lowest_index ::
    (Generic.Vector Unboxed.Vector a, Ord key) => (a -> key) -> key
        -> Unboxed.Vector a -> Int #-}
{-# INLINEABLE lowest_index #-}
lowest_index :: (Ord key, Generic.Vector v a) => (a -> key) -> key -> v a -> Int
lowest_index key x vec = search 0 (Generic.length vec)
    where
    -- The answer always lies in [lo, hi].  mid is strictly below hi, so
    -- unsafeIndex stays in bounds.
    search lo hi
        | lo == hi = lo
        | otherwise =
            let mid = (lo + hi) `div` 2
            in if x <= key (Generic.unsafeIndex vec mid)
                then search lo mid
                else search (mid + 1) hi

-- | Binary search for the highest index of the given X.  So the next value is
-- guaranteed to have a higher x, if it exists.  Return -1 if @x@ is before
-- the first element.
--
-- This is implemented as an upper-bound search (the first index whose key
-- is @> x@) minus one.  An empty vector also yields -1, because the search
-- over @[0, 0)@ returns 0.
{-# SPECIALIZE highest_index ::
    (Generic.Vector Unboxed.Vector a, Ord key) => (a -> key) -> key
        -> Unboxed.Vector a -> Int #-}
{-# INLINEABLE highest_index #-}
highest_index :: (Ord key, Generic.Vector v a) => (a -> key) -> key -> v a
    -> Int
highest_index key x vec = upper 0 (Generic.length vec) - 1
    where
    -- The answer always lies in [lo, hi].  mid is strictly below hi, so
    -- unsafeIndex stays in bounds.
    upper lo hi
        | lo == hi = lo
        | x >= key (Generic.unsafeIndex vec mid) = upper (mid + 1) hi
        | otherwise = upper lo mid
        where mid = (lo + hi) `div` 2

-- | Partition on a key.  This is the vector version of
-- 'Util.Seq.keyed_group_stable'.
--
-- Groups appear in the order each key first occurs.  Building each group
-- takes a pass over all the elements not yet grouped.
partition_on :: (Eq key, Generic.Vector v a) => (a -> key) -> v a
    -> [(key, v a)]
partition_on key vec
    | Generic.null vec = []
    | otherwise = (first, matching) : partition_on key others
    where
    -- Safe: this is only demanded on the non-empty branch.
    first = key (Generic.head vec)
    (matching, others) = Generic.partition (\elt -> key elt == first) vec