{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE ScopedTypeVariables #-}

-- ---------------------------------------------------------------------------
-- |
-- Module      : Data.Vector.Algorithms.Common
-- Copyright   : (c) 2008-2011 Dan Doel
-- Maintainer  : Dan Doel
-- Stability   : Experimental
-- Portability : Portable
--
-- Common operations and utility functions shared by all of the sorting algorithms.

module Data.Vector.Algorithms.Common
  ( type Comparison
  , copyOffset
  , inc
  , countLoop
  , midPoint
  , uniqueMutableBy
  )
  where

import Prelude hiding (read, length)

import Control.Monad.Primitive

import Data.Vector.Generic.Mutable
import Data.Word (Word)

import qualified Data.Vector.Primitive.Mutable as PV

-- | A comparison function between two values of a given type.
--
-- It has the same shape as 'compare', so @compare@ and @comparing f@ are
-- valid 'Comparison's. Within this module, 'uniqueMutableBy' treats a
-- result of 'EQ' as meaning "these two elements are duplicates".
type Comparison e = e -> e -> Ordering

-- | Copy @len@ consecutive elements from @from@, starting at index @iFrom@,
-- into @to@, starting at index @iTo@.
--
-- Both ranges are taken with 'unsafeSlice' and copied with 'unsafeCopy', so
-- no bounds checking happens here. The caller must guarantee that both
-- ranges lie inside their vectors and, as 'unsafeCopy' requires, that the
-- two ranges do not overlap.
copyOffset :: (PrimMonad m, MVector v e)
           => v (PrimState m) e -> v (PrimState m) e -> Int -> Int -> Int -> m ()
copyOffset from to iFrom iTo len = unsafeCopy dstRange srcRange
  where
    srcRange = unsafeSlice iFrom len from
    dstRange = unsafeSlice iTo len to
{-# INLINE copyOffset #-}

-- | Increment the 'Int' stored at index @i@ of @arr@ by one.
--
-- Returns the value the slot held /before/ the increment. The index is not
-- bounds-checked, because 'unsafeRead' and 'unsafeWrite' are used.
inc :: (PrimMonad m, MVector v Int) => v (PrimState m) Int -> Int -> m Int
inc arr i = do
  old <- unsafeRead arr i
  unsafeWrite arr i (old + 1)
  return old
{-# INLINE inc #-}

-- Shared bucket-sorting support.

-- | Build a histogram of bucket indices.
--
-- @count@ is first reset to all zeros. Then, for every element @x@ of
-- @src@ in index order, the slot @count[rdx x]@ is incremented.
--
-- @rdx@ must map every element to a valid index of @count@. The increments
-- go through 'inc', which does not bounds-check.
countLoop :: (PrimMonad m, MVector v e)
          => (e -> Int)
          -> v (PrimState m) e -> PV.MVector (PrimState m) Int -> m ()
countLoop rdx src count = do
    set count 0
    tally 0
  where
    n = length src

    tally i
      | i >= n    = return ()
      | otherwise = do
          x <- unsafeRead src i
          _ <- inc count (rdx x)
          tally (i + 1)
{-# INLINE countLoop #-}

-- | The midpoint of two non-negative indices, rounded down.
--
-- The sum is formed in 'Word' rather than 'Int', so adding two
-- non-negative 'Int's cannot overflow. For example,
-- @midPoint maxBound maxBound == maxBound@. Negative arguments wrap around
-- when converted to 'Word', so the result is not a meaningful midpoint for
-- them.
midPoint :: Int -> Int -> Int
midPoint lo hi = fromIntegral halfSum
  where
    halfSum :: Word
    halfSum = (fromIntegral lo + fromIntegral hi) `quot` 2
{-# INLINE midPoint #-}

-- | Remove adjacent duplicates from a mutable vector. Two elements count as
-- duplicates when @cmp@ returns 'EQ'.
--
-- Adapted from Andrew Martin's @uniqueMutable@ in the primitive-sort package.
--
-- An element is dropped when it compares 'EQ' to the element immediately
-- before it in the input. As a result, every run of such elements keeps only
-- its first member.
--
-- The input vector itself is returned, and is never written to, in two
-- cases: when it has fewer than two elements, and when no element compares
-- 'EQ' to its predecessor. Otherwise the surviving elements are compacted
-- towards the front of the input vector /in place/, which modifies its
-- contents. A freshly allocated vector holding exactly the surviving
-- elements is then returned.
uniqueMutableBy :: forall m v a . (PrimMonad m, MVector v a)
  => Comparison a -> v (PrimState m) a -> m (v (PrimState m) a)
uniqueMutableBy cmp mv
  | n <= 1    = return mv
  | otherwise = do
      !firstElem <- unsafeRead mv 0
      dupIx <- scanForDup firstElem 1
      if dupIx == n
        then return mv
        else do
          -- Everything before dupIx is already duplicate-free and stays put.
          -- Compaction starts by overwriting the duplicate at dupIx.
          !dupElem <- unsafeRead mv dupIx
          !keptLen <- compactFrom dupElem (dupIx + 1) dupIx
          resizeVector mv keptLen
  where
    n = basicLength mv

    -- Return the index of the first element that compares 'EQ' to its
    -- predecessor, or @n@ if there is none. @prev@ is the element just
    -- before index @i@.
    scanForDup :: a -> Int -> m Int
    scanForDup !prev !i
      | i >= n    = return i
      | otherwise = do
          x <- unsafeRead mv i
          case cmp x prev of
            EQ -> return i
            _  -> scanForDup x (i + 1)

    -- Walk the input from @srcIx@. Each element that does not compare 'EQ'
    -- to its predecessor @prev@ is written at @dstIx@. Because
    -- @dstIx < srcIx@ holds throughout, unread input is never overwritten.
    -- Returns the number of surviving elements.
    compactFrom :: a -> Int -> Int -> m Int
    compactFrom !prev !srcIx !dstIx
      | srcIx >= n = return dstIx
      | otherwise  = do
          x <- unsafeRead mv srcIx
          case cmp x prev of
            EQ -> compactFrom x (srcIx + 1) dstIx
            _  -> do
              unsafeWrite mv dstIx x
              compactFrom x (srcIx + 1) (dstIx + 1)
{-# INLINABLE uniqueMutableBy #-}

-- Used internally by uniqueMutableBy. Allocates a fresh vector of size @sz@
-- (via 'unsafeNew') and fills it with the first @sz@ elements of @src@.
-- Requires @sz <= length src@; this is not checked.
resizeVector
  :: (MVector v a, PrimMonad m)
  => v (PrimState m) a -> Int -> m (v (PrimState m) a)
resizeVector !src !sz =
  unsafeNew sz >>= \dst -> copyToSmaller dst src >> pure dst
{-# inline resizeVector #-}

-- Used internally by resizeVector. Copies the first @length dst@ elements of
-- @src@ into @dst@, one element at a time, using the unchecked
-- 'basicUnsafeRead' and 'basicUnsafeWrite'.
-- @src@ must be at least as long as @dst@; otherwise reads run past the end
-- of @src@.
copyToSmaller
  :: (MVector v a, PrimMonad m)
  => v (PrimState m) a -> v (PrimState m) a -> m ()
copyToSmaller !dst !src = stToPrim (step 0)
  where
    !limit = basicLength dst

    step i
      | i >= limit = return ()
      | otherwise  = do
          basicUnsafeRead src i >>= basicUnsafeWrite dst i
          step (i + 1)