-----------------------------------------------------------------------------
-- |
-- Module      :  Control.Concurrent.STM.SSem
-- Copyright   :  (c) Chris Kuklewicz, 2012
-- License     :  BSD-style
-- 
-- Maintainer  :  haskell@list.mightyreason.com
-- Stability   :  experimental
-- Portability :  non-portable (concurrency)
--
-- Very simple quantity semaphore.
--
-----------------------------------------------------------------------------
module Control.Concurrent.STM.SSem(SSem, new, wait, signal, tryWait
                                  , waitN, signalN, tryWaitN
                                  , getValue) where

import Control.Monad.STM(STM,retry)
import Control.Concurrent.STM.TVar(newTVar,readTVar,writeTVar)
import Control.Concurrent.STM.SSemInternals(SSem(SSem))

-- | Create a new semaphore with the given argument as the initially available quantity.  This
-- allows new semaphores to start with a negative, zero, or positive quantity.
new :: Int -> STM SSem
new = fmap SSem . newTVar

-- | Try to take a unit of value from the semaphore.  This succeeds when the current quantity is
-- positive, and then reduces the quantity by one.  Otherwise this will 'retry'.  This will never
-- result in a negative quantity.  If several threads are retrying then which one succeeds next is
-- undefined -- an unlucky thread might starve.
--
-- @wait s@ is defined as @'waitN' s 1@.
wait :: SSem -> STM ()
wait s = waitN s 1

-- | Try to take the given value from the semaphore.  This succeeds when the quantity is greater
-- than or equal to the given value, and then subtracts the given value from the quantity.
-- Otherwise this will 'retry'.  This will never result in a negative quantity.  If several threads
-- are retrying then which one succeeds next is undefined -- an unlucky thread might starve.
--
-- Note: the requested amount is not validated; a negative amount that passes the comparison
-- increases the quantity.
waitN :: SSem -> Int -> STM ()
waitN (SSem s) i = do
  v <- readTVar s
  if v >= i
    then writeTVar s $! v - i  -- force the new count so no thunk is stored in the TVar
    else retry

-- | Signal that a single unit of the semaphore is available.  This increases the available
-- quantity by one.
--
-- @signal s@ is defined as @'signalN' s 1@.
signal :: SSem -> STM ()
signal s = signalN s 1

-- | Signal that many units of the semaphore are available.  This changes the available quantity by
-- adding the passed size.  The size is not validated, so a negative size decreases the quantity.
signalN :: SSem -> Int -> STM ()
signalN (SSem s) i = do
  v <- readTVar s
  writeTVar s $! v + i  -- force the sum so no thunk is stored in the TVar

-- | Non-retrying version of 'wait'.  @tryWait s@ is defined as @'tryWaitN' s 1@, so it returns
-- @Just 1@ when a unit was taken and @Nothing@ when the quantity was insufficient.
tryWait :: SSem -> STM (Maybe Int)
tryWait s = tryWaitN s 1

-- | Non-retrying version of 'waitN'.  It either takes the quantity from the semaphore like
-- 'waitN' and returns @Just@ the value taken, or finds insufficient quantity to take and returns
-- @Nothing@, leaving the semaphore unchanged.
tryWaitN :: SSem -> Int -> STM (Maybe Int)
tryWaitN (SSem s) i = do
  v <- readTVar s
  if v >= i
    then do
      writeTVar s $! v - i  -- force the new count so no thunk is stored in the TVar
      return (Just i)
    else return Nothing

-- | Return the current quantity in the semaphore.  This is potentially useful in a larger STM
-- transaction and less useful as @atomically (getValue s) :: IO Int@ due to race conditions: the
-- quantity may have changed by the time the caller acts on the result.
getValue :: SSem -> STM Int
getValue (SSem s) = readTVar s