-- |
-- Module      : ScrubbedBlock
-- License     : BSD-3-Clause
-- Copyright   : (c) 2025 Olivier Chéron
--
-- A block that is always pinned in memory and automatically erased by a
-- finalizer when not referenced anymore.  Same pattern as ScrubbedBytes from
-- package memory but for blocks.
--
-- A complication here is that we distinguish between mutable and immutable
-- values.  And for resilience against asynchronous exceptions, we need to
-- schedule block scrubbing with a finalizer right at the beginning, when the
-- block is still in mutable form.  Fortunately, from the perspective of the GC,
-- ByteArray# and MutableByteArray# are really the same heap object in disguise
-- and unsafeFreezeByteArray# is a true no-op.  So the finalizer set on the
-- initial MutableByteArray# value gets transferred transparently to the final
-- ByteArray# form.
--
-- See GHC note [primOpEffect of unsafe freezes and thaws]
--
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE UnboxedTuples #-}
module ScrubbedBlock
    ( ScrubbedBlock, foldZipWith, ScrubbedBlock.length
    , new, thaw, unsafeFreeze
    ) where

import Data.Primitive.PrimArray as Block

import Control.Exception (assert)
import Control.Monad.ST

import Data.Word

import Unsafe.Coerce

import Base
import Block (Block, MutableBlock)
import qualified Block

import GHC.Base (IO(IO), Int(I#), setByteArray#)
import GHC.Exts (mkWeak#)

-- | An immutable block that is pinned in memory and whose bytes are
-- overwritten with zeros by a finalizer once it is no longer referenced.
--
-- The data constructor is not exported, so outside this module values can
-- only be produced by 'unsafeFreeze', which should be applied to a mutable
-- block created by 'new' or 'thaw' (that is where the scrubbing finalizer is
-- attached).
--
-- NOTE(review): the derived 'Show' instance prints the raw contents, and the
-- derived 'Eq' delegates to the underlying block comparison, which is
-- presumably not constant-time.  Both can be undesirable for secret material
-- (memory's ScrubbedBytes, which this module mirrors, hides its contents in
-- 'show'); confirm callers never show or compare secrets this way.
newtype ScrubbedBlock ty = ScrubbedBlock (Block ty)
    deriving (Eq, Show)

-- | Fold over two scrubbed blocks together with an accumulator-first step
-- function, delegating to 'Block.foldZipWith' on the underlying blocks.
--
-- Behaviour for blocks of different lengths is whatever
-- 'Block.foldZipWith' does.
foldZipWith :: (PrimType a, PrimType b)
            => (c -> a -> b -> c) -> c -> ScrubbedBlock a -> ScrubbedBlock b -> c
foldZipWith step acc (ScrubbedBlock xs) (ScrubbedBlock ys) =
    Block.foldZipWith step acc xs ys
{-# INLINE foldZipWith #-}

-- | Number of elements in the block, as reported by 'Block.length' on the
-- underlying block.
length :: PrimType ty => ScrubbedBlock ty -> CountOf ty
length (ScrubbedBlock blk) = Block.length blk

-- | Allocate a new pinned mutable block of @n@ elements and schedule it to be
-- zeroed by a finalizer once it becomes unreachable.
--
-- The finalizer is attached while the block is still mutable, so it is
-- already in place if an asynchronous exception interrupts the caller before
-- 'unsafeFreeze'.  Freezing keeps the same heap object, so the finalizer
-- carries over to the resulting 'ScrubbedBlock'.
new :: (PrimType ty, PrimMonad prim) => CountOf ty -> prim (MutableBlock ty (PrimState prim))
new n = Block.newPinned n >>= scrubbed  -- pinned: the GC never moves (copies) it

-- | Obtain a pinned mutable block from a 'ScrubbedBlock' via
-- 'Block.thawPinned' and schedule scrubbing on it.
--
-- NOTE(review): whether 'Block.thawPinned' copies the input or reuses it is
-- not visible from this module.  If it reuses it, the array simply ends up
-- with two zeroing finalizers.
thaw :: PrimMonad m => ScrubbedBlock ty -> m (MutableBlock ty (PrimState m))
thaw (ScrubbedBlock blk) = Block.thawPinned blk >>= scrubbed  -- always pinned

-- | Freeze a mutable block (normally one obtained from 'new' or 'thaw') into
-- a 'ScrubbedBlock' without copying, via 'Block.unsafeFreeze'.
--
-- The scrubbing finalizer registered on the mutable block stays attached,
-- because freezing does not create a new heap object.  As with any unsafe
-- freeze, @mb@ must not be mutated afterwards.  Pinnedness is only checked
-- through 'assert', so the check disappears when assertions are disabled.
unsafeFreeze :: PrimMonad prim => MutableBlock ty (PrimState prim) -> prim (ScrubbedBlock ty)
unsafeFreeze mb = fmap checkPinned (Block.unsafeFreeze mb)


{- internal -}

-- | @assertPinned blk x@ returns @x@ after asserting, with
-- 'Control.Exception.assert', that @blk@ is pinned.
--
-- The check only runs when assertions are enabled.  GHC drops it under @-O@
-- (which implies @-fignore-asserts@), leaving a plain identity on @x@.
assertPinned :: Block ty -> a -> a
assertPinned blk = assert (Block.isPrimArrayPinned blk)

-- | Wrap a block as a 'ScrubbedBlock', asserting (when assertions are
-- enabled) that it is pinned.
checkPinned :: Block ty -> ScrubbedBlock ty
checkPinned blk = assertPinned blk (ScrubbedBlock blk)

-- | Schedule zeroing of the block for when it becomes unreachable, then
-- return the same block.
--
-- The registration is an 'IO' action lifted with 'unsafePrimFromIO', so that
-- this works in any 'PrimMonad', including 'ST'.
scrubbed :: PrimMonad prim => MutableBlock ty (PrimState prim) -> prim (MutableBlock ty (PrimState prim))
scrubbed blk = unsafePrimFromIO $ do
    scheduleBlockScrubbing blk
    return blk

-- | Reinterpret a mutable block from an arbitrary state thread @s@ as
-- belonging to 'RealWorld', so that the 'IO' finalizer can operate on it.
--
-- Only the state-thread type index changes; 'unsafeCoerce' leaves the
-- underlying array untouched.  The result is only used by the scrubbing
-- finalizer.  That finalizer runs once the array is no longer reachable,
-- i.e. after the originating 'ST' computation can no longer observe it.
wakeUpAfterInception :: MutableBlock ty s -> MutableBlock ty RealWorld
wakeUpAfterInception = unsafeCoerce  -- sometimes disappointing

-- | Attach a finalizer to the block's underlying byte array that zeroes it
-- once the array becomes unreachable.
--
-- The block is viewed as 'Word8' elements before being handed to 'scrub'.
-- That way the length 'scrub' reads is (presumably) a byte count covering the
-- whole allocation, whatever @ty@ is.
--
-- NOTE(review): the NOINLINE presumably keeps GHC from inlining the coercion
-- and weak-pointer creation into callers; confirm the intent.
scheduleBlockScrubbing :: MutableBlock ty s -> IO ()
scheduleBlockScrubbing blk =
    addBlockFinalizer blk (scrub (Block.unsafeCastMut realWorldBlk))
  where
    realWorldBlk = wakeUpAfterInception blk
{-# NOINLINE scheduleBlockScrubbing #-}

-- | Zero every element of a 'Word8' block, using the length reported by
-- 'Block.getMutableLength' at the time the scrub runs.
scrub :: MutableBlock Word8 RealWorld -> IO ()
scrub blk = do
    CountOf len <- Block.getMutableLength blk
    erase len blk

-- | Register an 'IO' action as a finalizer keyed on the block's underlying
-- @MutableByteArray#@, with @()@ as the weak pointer's value.
--
-- The weak pointer returned by 'mkWeak#' is discarded, since only the
-- finalizer is wanted.  The finalizer still runs when the key dies, just as
-- with 'System.Mem.Weak.addFinalizer'.  The key is the raw byte array, and
-- freezing reuses the same heap object, so the finalizer also covers the
-- frozen 'ScrubbedBlock'.
addBlockFinalizer :: MutableBlock ty s -> IO () -> IO ()
addBlockFinalizer (Block.MutablePrimArray mbarr) (IO finalizer) = IO $ \s0 ->
    case mkWeak# mbarr () finalizer s0 of
        (# s1, _weak #) -> (# s1, () #)

-- | Overwrite the first @len@ bytes of the block's array with zero, using
-- 'setByteArray#' (offset 0, fill value 0).
--
-- 'setByteArray#' performs no bounds checking, so @len@ must be between 0
-- and the array's byte size.  'scrub' passes the length reported for the
-- 'Word8' view of the block.
erase :: Int -> MutableBlock Word8 RealWorld -> IO ()
erase (I# len) (Block.MutablePrimArray mbarr) = IO $ \s0 ->
    case setByteArray# mbarr 0# len 0# s0 of
        s1 -> (# s1, () #)