-- |
-- Module      : Auxiliary
-- License     : BSD-3-Clause
-- Copyright   : (c) 2025 Olivier Chéron
--
-- ML-KEM auxiliary functions
--
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UnboxedTuples #-}
module Auxiliary
    ( Zq, Rq, Tq, (..+), (..-)
    , ntt, nttInv, rcompress, rdecompress
    , byteEncode, byteDecode, byteEncode12, byteDecode12
    , byteEncode1, byteDecode1, sampleNTT, samplePolyCBD
#ifdef ML_KEM_TESTING
    , compress, decompress
    , bitRev7, fromZq, toZq, fromCoeffs, toCoeffs
#endif
    ) where

import Crypto.Hash.Algorithms

import Data.ByteArray (ByteArrayAccess, Bytes, View)
import qualified Data.ByteArray as B

import Data.Primitive.Types (Prim(..))

import Control.DeepSeq (NFData(..))
import Control.Monad
import Control.Monad.ST

import Data.Bits
import Data.Int
import Data.Proxy
import Data.Word

import GHC.TypeNats

import Foreign.Ptr (Ptr, plusPtr)
import Foreign.Storable (pokeByteOff)

import Unsafe.Coerce

import Base
import Block (blockIndex)
import BlockN (BlockN, MutableBlockN)
import Builder (Builder)
import Crypto (BlockDigest)
import Machine
import Marking (Classified, SecurityMarking(..), Leak(..))
import SecureBlock (SecureBlock)
import SecureBytes (SecureBytes)
import qualified BlockN
import qualified Builder
import qualified ByteArrayST as ST
import qualified Crypto
import Math

-- | Number of coefficients in a polynomial, as a type-level natural.
type N = 256

-- | Number of coefficients in a polynomial, as a value-level 'Int'
-- (the same number as 'N').
n :: Int
n = 256

-- | The ML-KEM modulus 𝑞 = 3329, as an unbounded 'Integer'.
q :: Integer
q = 3329

-- | 𝑞 as a 16-bit word, the width of a stored coefficient.
q16 :: Word16
q16 = fromIntegral q

-- | 𝑞 as a 32-bit word, the width of unreduced products.
q32 :: Word32
q32 = fromIntegral q

-- | 𝑞 as a 64-bit word, used when computing the Barrett factor.
q64 :: Word64
q64 = fromIntegral q

-- | Reverses the order of the seven low bits of a byte.
--
-- Bit @i@ of the input (0 ≤ i ≤ 6) becomes bit @6 - i@ of the output. Bit 7
-- of the input is ignored and bit 7 of the output is always clear.
bitRev7 :: Word8 -> Word8
bitRev7 w =
        move 0 6 .|. move 1 5 .|. move 2 4 .|. move 3 3
    .|. move 4 2 .|. move 5 1 .|. move 6 0
  where
    -- move i j copies bit i of w into bit j of the result
    move i j = ((w `unsafeShiftR` i) .&. 1) `unsafeShiftL` j

-- | Arithmetic right shift of a 'Word16'.
--
-- The bits are viewed as an 'Int16', so the sign bit is copied into the
-- vacated high positions. The shift amount is not checked, as with
-- 'unsafeShiftR'.
unsafeShiftIR :: Word16 -> Int -> Word16
unsafeShiftIR w k = toWord (toInt w `unsafeShiftR` k)
  where
    toInt :: Word16 -> Int16
    toInt = fromIntegral
    toWord :: Int16 -> Word16
    toWord = fromIntegral
{-# INLINE unsafeShiftIR #-}

-- | Reduction 𝑥 mod 𝑞 for 0 ≤ 𝑥 < 2𝑞, computed without branching.
--
-- In that range, @x - q@ wraps around (top bit set) exactly when @x < q@.
-- An arithmetic shift turns this bit into an all-ones mask, which adds @q@
-- back in that case and leaves @x - q@ otherwise.
reduceSimple :: Word16 -> Word16
reduceSimple x = diff + (q16 .&. borrow)
  where
    diff = x - q16
    -- all ones when the subtraction wrapped, zero otherwise
    borrow = diff `unsafeShiftIR` 15
{-# INLINE reduceSimple #-}

-- | Barrett reduction 𝑥 mod 𝑞 for 0 ≤ 𝑥 < 2𝑞² + 𝑞.
--
-- The quotient is estimated as ⌊𝑥 · ⌊2²⁴/𝑞⌋ / 2²⁴⌋. For inputs in this range
-- the estimate is never too large and falls short of ⌊𝑥/𝑞⌋ by at most one.
-- The remainder therefore lies in [0, 2𝑞), and 'reduceSimple' finishes the
-- reduction.
reduce :: Word32 -> Word16
reduce x = reduceSimple (fromIntegral r)
  where
    barrettFactor = (1 `unsafeShiftL` 24) `div` q64 :: Word64
    estimate =
        fromIntegral ((fromIntegral x * barrettFactor) `unsafeShiftR` 24) :: Word32
    r = x - estimate * q32
{-# INLINE reduce #-}

-- | An element of ℤ_𝑞, stored as a 'Word16'.
--
-- The arithmetic instances in this module produce values in [0, 𝑞).
newtype Zq = Zq Word16
#ifdef ML_KEM_TESTING
    deriving (Eq, Show)
#else
    deriving (Eq)
#endif

-- | Unboxed array and address access for 'Zq'.
--
-- Every method forwards to the 'Word16' instance, adding or removing the
-- 'Zq' wrapper around the value.
instance Prim Zq where
    sizeOf# (Zq w) = sizeOf# w
    {-# INLINE sizeOf# #-}
    alignment# (Zq w) = alignment# w
    {-# INLINE alignment# #-}
#if MIN_VERSION_primitive(0,9,0)
    sizeOfType# _ = sizeOfType# (Proxy :: Proxy Word16)
    {-# INLINE sizeOfType# #-}
    alignmentOfType# _ = alignmentOfType# (Proxy :: Proxy Word16)
    {-# INLINE alignmentOfType# #-}
#endif
    indexByteArray# arr ix = Zq (indexByteArray# arr ix)
    {-# INLINE indexByteArray# #-}
    readByteArray# marr ix st0 =
        case readByteArray# marr ix st0 of
            (# st1, w #) -> (# st1, Zq w #)
    {-# INLINE readByteArray# #-}
    writeByteArray# marr ix (Zq w) = writeByteArray# marr ix w
    {-# INLINE writeByteArray# #-}
    setByteArray# marr ix cnt (Zq w) = setByteArray# marr ix cnt w
    {-# INLINE setByteArray# #-}
    indexOffAddr# ptr ix = Zq (indexOffAddr# ptr ix)
    {-# INLINE indexOffAddr# #-}
    readOffAddr# ptr ix st0 =
        case readOffAddr# ptr ix st0 of
            (# st1, w #) -> (# st1, Zq w #)
    {-# INLINE readOffAddr# #-}
    writeOffAddr# ptr ix (Zq w) = writeOffAddr# ptr ix w
    {-# INLINE writeOffAddr# #-}
    setOffAddr# ptr ix cnt (Zq w) = setOffAddr# ptr ix cnt w
    {-# INLINE setOffAddr# #-}

-- | The 'PrimSize' of a 'Zq' is 2, the number of bytes in its 'Word16' payload.
instance PrimSized Zq where
    type PrimSize Zq = 2

-- | Addition, subtraction and negation in ℤ_𝑞.
--
-- For operands in [0, 𝑞), every intermediate value lies in [0, 2𝑞), which
-- is the precondition of 'reduceSimple'.
instance Add Zq where
    zero = Zq 0
    (.+) (Zq x) (Zq y) = Zq (reduceSimple (x + y))
    -- q is added before y is subtracted so the Word16 never wraps around
    (.-) (Zq x) (Zq y) = Zq (reduceSimple (x + q16 - y))
    neg (Zq x) = Zq (reduceSimple (q16 - x))

-- | Multiplication in ℤ_𝑞.
--
-- The product is formed in 32 bits. For operands in [0, 𝑞) it stays below
-- 𝑞², inside the range accepted by 'reduce'.
instance Mul Zq where
    one = Zq 1
    Zq x .* Zq y = Zq (reduce (widen x * widen y))
      where
        widen :: Word16 -> Word32
        widen = fromIntegral

#ifdef ML_KEM_TESTING
-- Test-only multiply-accumulate instances for 'Zq'.
instance MulAdd Zq where
    -- x*y + z stays below q^2 + q, inside the range accepted by reduce
    mulAdd (Zq x) (Zq y) (Zq z) =
        Zq (reduce (fromIntegral x * fromIntegral y + fromIntegral z))

instance BiMul Zq Zq where
    x ..* y = x .* y

instance BiMulAdd Zq Zq where
    biMulAdd = mulAdd

-- | Raw 'Word16' representative stored in a 'Zq'.
fromZq :: Zq -> Word16
fromZq (Zq w) = w
#endif

-- | Maps an arbitrary 'Word16' into ℤ_𝑞 by reducing it modulo 𝑞.
--
-- Any 'Word16' is far below the 2𝑞² + 𝑞 input bound of 'reduce'.
toZq :: Word16 -> Zq
toZq w = Zq (reduce (fromIntegral w))

-- | A polynomial (an element of 𝑅_𝑞) given by its 'N' coefficients in 'Zq'
-- and tagged with a security marking.
newtype Rq marking = Rq (BlockN marking N Zq)
#ifdef ML_KEM_TESTING
    deriving (Eq, Show)
#endif

-- | Addition, subtraction and negation of polynomials, delegated to the
-- 'Add' instance of the underlying 'BlockN'.
instance Classified marking => Add (Rq marking) where
    zero = Rq zero
    (.+) (Rq xs) (Rq ys) = Rq (xs .+ ys)
    {-# INLINE (.+) #-}
    (.-) (Rq xs) (Rq ys) = Rq (xs .- ys)
    {-# INLINE (.-) #-}
    neg (Rq xs) = Rq (neg xs)
    {-# INLINE neg #-}

infixl 6 ..+, ..-

-- | Adds two secret polynomials and declassifies the sum.
--
-- Meant to be called only at the expected location in the LWE problem,
-- after noise has been added to the secret information.
(..+) :: Rq Sec -> Rq Sec -> Rq Pub
(..+) x y = leak (x .+ y)
{-# INLINE (..+) #-}

-- | Subtracts a secret polynomial from a public one, coefficient by
-- coefficient. The difference stays secret.
(..-) :: Rq Pub -> Rq Sec -> Rq Sec
Rq pubs ..- Rq secs = Rq (BlockN.zipWith subFrom secs pubs)
  where
    -- receives the secret coefficient first, following the argument order
    -- passed to zipWith
    subFrom s p = p .- s
{-# INLINE (..-) #-}

-- | Declassification of polynomials ('leak' from 'Sec' to 'Pub'), using the
-- class's default method.
instance Leak Rq

#ifdef ML_KEM_TESTING
-- | Builds a secret polynomial from a list of coefficients.
--
-- The 'Maybe' comes from 'BlockN.fromList' (presumably 'Nothing' when the
-- list does not hold exactly 'N' elements; verify against BlockN).
fromCoeffs :: [Zq] -> Maybe (Rq Sec)
fromCoeffs coeffs = Rq <$> BlockN.fromList coeffs

-- | Lists the coefficients of a secret polynomial.
toCoeffs :: Rq Sec -> [Zq]
toCoeffs (Rq blk) = BlockN.toList blk
#endif

-- | A polynomial in NTT representation (an element of 𝑇_𝑞): 'N' values in
-- 'Zq', tagged with a security marking.
newtype Tq marking = Tq (BlockN marking N Zq)
#ifdef ML_KEM_TESTING
    deriving (Eq, Show, NFData)
#else
    deriving (NFData)
#endif

-- | Addition, subtraction and negation of NTT representations, delegated to
-- the 'Add' instance of the underlying 'BlockN'.
instance Classified marking => Add (Tq marking) where
    zero = Tq zero
    (.+) (Tq xs) (Tq ys) = Tq (xs .+ ys)
    {-# INLINE (.+) #-}
    (.-) (Tq xs) (Tq ys) = Tq (xs .- ys)
    {-# INLINE (.-) #-}
    neg (Tq xs) = Tq (neg xs)
    {-# INLINE neg #-}

-- | Declassification of NTT representations ('leak' from 'Sec' to 'Pub'),
-- using the class's default method.
instance Leak Tq

-- | Product of a public and a secret NTT representation, computed by
-- 'multiplyNTTs'.
instance BiMul (Tq Pub) (Tq Sec) where
    f ..* g = multiplyNTTs f g
    {-# INLINE (..*) #-}

-- | Folds public-by-secret products into an accumulator through
-- 'multiplyNTTsFold'.
instance BiMulAdd (Tq Pub) (Tq Sec) where
    biMulFold acc pairs = multiplyNTTsFold acc pairs
    {-# INLINE biMulFold #-}

#ifdef ML_KEM_TESTING
-- Test-only ring structure on secret NTT representations. The left operand
-- is declassified so that the public-by-secret product can be reused.
instance Mul (Tq Sec) where
    -- 1 at every even index, 0 at every odd index
    one = Tq $ BlockN.create $ \(Offset i) ->
        if odd i then zero else one
    x .* y = leak x ..* y

instance MulAdd (Tq Sec) where
    mulAdd x = biMulAdd (leak x)
#endif

-- | Equality of secret NTT representations through 'Crypto.constEqW'.
--
-- The 'Zq' values are reinterpreted as a block of machine words before the
-- comparison.
instance Crypto.ConstEqW (Tq Sec) where
    constEqW (Tq xs) (Tq ys) = Crypto.constEqW (asWords xs) (asWords ys)
      where
        asWords :: BlockN Sec N Zq -> SecureBlock Sec Word
        asWords = BlockN.unsafeCast

-- | Equality of public NTT representations through 'Crypto.constEqW'.
--
-- The 'Zq' values are reinterpreted as a block of machine words before the
-- comparison.
instance Crypto.ConstEqW (Tq Pub) where
    constEqW (Tq xs) (Tq ys) = Crypto.constEqW (asWords xs) (asWords ys)
      where
        asWords :: BlockN Pub N Zq -> SecureBlock Pub Word
        asWords = BlockN.unsafeCast

-- | Computes the NTT representation of the given polynomial (FIPS 203,
-- Algorithm 9).
--
-- The in-place transform 'mutNtt' runs under 'BlockN.runThaw'.
ntt :: Classified marking => Rq marking -> Tq marking
ntt (Rq coeffs) = Tq (BlockN.runThaw coeffs mutNtt)
{-# INLINE ntt #-}

-- | Forward NTT performed in place on a mutable block of 'N' coefficients
-- (FIPS 203, Algorithm 9).
--
-- The butterfly half-width goes 128, 64, …, 2. Each group of butterflies
-- takes the next entry of 'zetaPowBitRev', starting from index 1.
mutNtt :: MutableBlockN marking N Zq s -> ST s ()
mutNtt !b = layerLoop 1 128
  where
    -- Runs one layer with half-width len, unless every layer has been done.
    layerLoop !k len
        | len < 2   = pure ()
        | otherwise = groupLoop k len 0

    -- Handles the group of 2*len coefficients that begins at start. Once
    -- past the last group, continues with the next layer at half the width.
    groupLoop !k !len start
        | start >= 256 = layerLoop k (offsetShiftR 1 len)
        | otherwise = do
            let zeta = BlockN.index zetaPowBitRev k -- 17 ^ bitRev7 k
            butterflyLoop zeta (start + len) len start
            groupLoop (k + 1) len (start + offsetShiftL 1 len)

    -- Butterflies for every j from start up to, but excluding, end.
    butterflyLoop !zeta end len j
        | j >= end  = pure ()
        | otherwise = do
            t <- (zeta .*) <$> BlockN.read b (j + len)
            x <- BlockN.read b j
            BlockN.write b (j + len) (x .- t)
            BlockN.write b j (x .+ t)
            butterflyLoop zeta end len (j + 1)
{-# NOINLINE mutNtt #-}

-- | Computes the polynomial that corresponds to the given NTT representation
-- (FIPS 203, Algorithm 10).
--
-- The in-place transform 'mutNttInv' runs under 'BlockN.runThaw'.
nttInv :: Tq Sec -> Rq Sec
nttInv (Tq vals) = Rq (BlockN.runThaw vals mutNttInv)
{-# INLINE nttInv #-}

-- | Inverse NTT performed in place on a mutable block of 'N' coefficients
-- (FIPS 203, Algorithm 10).
--
-- The butterfly half-width goes 2, 4, …, 128. 'zetaPowBitRev' entries are
-- taken in decreasing order from index 127. At the end every coefficient is
-- multiplied by 3303 = 128⁻¹ mod 𝑞.
mutNttInv :: MutableBlockN Sec N Zq s -> ST s ()
mutNttInv !b = do
    layerLoop 127 2
    BlockN.iterModify (.* inv128) b
  where
    -- 128 * 3303 = 422784 = 127 * 3329 + 1
    inv128 = Zq 3303

    -- Runs one layer with half-width len, unless every layer has been done.
    layerLoop !k len
        | len > 128 = pure ()
        | otherwise = groupLoop k len 0

    -- Handles the group of 2*len coefficients that begins at start. Once
    -- past the last group, continues with the next layer at double the width.
    groupLoop !k !len start
        | start >= 256 = layerLoop k (offsetShiftL 1 len)
        | otherwise = do
            let zeta = BlockN.index zetaPowBitRev k -- 17 ^ bitRev7 k
            butterflyLoop zeta (start + len) len start
            groupLoop (k - 1) len (start + offsetShiftL 1 len)

    -- Butterflies for every j from start up to, but excluding, end.
    butterflyLoop !zeta end len j
        | j >= end  = pure ()
        | otherwise = do
            t <- BlockN.read b j
            x <- BlockN.read b (j + len)
            BlockN.write b j (t .+ x)
            BlockN.write b (j + len) (zeta .* (x .- t))
            butterflyLoop zeta end len (j + 1)
{-# NOINLINE mutNttInv #-}

-- | Computes the product of two NTT representations (FIPS 203, Algorithm 11).
--
-- The result is written by 'mutMultiplyNTTs' into a secret block created
-- with 'BlockN.runNew'.
multiplyNTTs :: Tq Pub -> Tq Sec -> Tq Sec
multiplyNTTs f g =
    Tq (BlockN.runNew (Proxy :: Proxy Sec) (mutMultiplyNTTs f g))
{-# INLINE multiplyNTTs #-}

-- | Writes the product of two NTT representations into the output block.
--
-- For each i in [0, 128), the values at indices 2i and 2i + 1 of both inputs
-- are combined by 'baseCaseMultiply' with entry i of 'gamma'. The two results
-- are stored at the same indices of the output.
mutMultiplyNTTs :: Tq Pub -> Tq Sec -> MutableBlockN Sec N Zq s -> ST s ()
mutMultiplyNTTs (Tq !f) (Tq !g) out = go out 0
  where
    go :: MutableBlockN Sec N Zq s -> Offset Zq -> ST s ()
    go !b i
        | i >= 128  = pure ()
        | otherwise = do
            let lo = offsetShiftL 1 i
                hi = lo + 1
                (c0, c1) = baseCaseMultiply
                    (BlockN.index f lo) (BlockN.index f hi)
                    (BlockN.index g lo) (BlockN.index g hi)
                    (BlockN.index gamma i)
            BlockN.write b lo c0
            BlockN.write b hi c1
            go b (i + 1)

-- | Computes the product of two degree-one polynomials with respect to a
-- quadratic modulus (FIPS 203, Algorithm 12):
--
-- (a0 + a1·X)(b0 + b1·X) mod (X² − γ) = (a0·b0 + a1·b1·γ) + (a0·b1 + a1·b0)·X
--
-- The result is returned as the pair (constant term, X coefficient).
baseCaseMultiply :: Zq -> Zq -> Zq -> Zq -> Zq -> (Zq, Zq)
baseCaseMultiply (Zq a0) (Zq a1) (Zq b0) (Zq b1) (Zq gam) =
    let wide :: Word16 -> Word32
        wide = fromIntegral
        -- b1·γ is reduced first so that each sum below stays under 2𝑞²
        b1g = reduce (wide b1 * wide gam)
        !c0 = reduce (wide a0 * wide b0 + wide a1 * wide b1g)
        !c1 = reduce (wide a0 * wide b1 + wide a1 * wide b0)
    in (Zq c0, Zq c1)

-- | Multiply-accumulate a collection of (public, secret) NTT pairs onto the
-- initial value @acc@: each pair is handed to 'multiplyNTTsAdd' through
-- 'BlockN.runFold' (which presumably runs the step on a mutable copy of
-- @acc@ for every element, in fold order).
multiplyNTTsFold :: Foldable t => Tq Sec -> t (Tq Pub, Tq Sec) -> Tq Sec
multiplyNTTsFold (Tq acc) = Tq . BlockN.runFold acc accumulate
  where
    -- closed binding, so it is generalised over the ST state thread
    accumulate (f, g) out = multiplyNTTsAdd f g out
{-# INLINE multiplyNTTsFold #-}

-- | NTT-domain multiply-accumulate: for every i in [0, 128) the coefficient
-- pair at indices 2i and 2i+1 of the mutable block is replaced by itself
-- plus the base-case product of the matching pairs of @f@ and @g@, using
-- the twiddle factor @gamma[i]@.
multiplyNTTsAdd :: Tq Pub -> Tq Sec -> MutableBlockN Sec N Zq s -> ST s ()
multiplyNTTsAdd (Tq !f) (Tq !g) acc = go acc 0
  where
    go :: MutableBlockN Sec N Zq s -> Offset Zq -> ST s ()
    go !out i
        | i >= 128 = return ()
        | otherwise = do
            let lo = offsetShiftL 1 i   -- 2·i
                hi = lo + 1
            c0 <- BlockN.read out lo
            c1 <- BlockN.read out hi
            let (d0, d1) = baseCaseMultiplyAdd
                               (BlockN.index f lo) (BlockN.index f hi)
                               (BlockN.index g lo) (BlockN.index g hi)
                               c0 c1
                               (BlockN.index gamma i)
            BlockN.write out lo d0
            BlockN.write out hi d1
            go out (i + 1)

-- | Fused 'baseCaseMultiply' plus addition of an existing pair:
--
-- > d0 = c0 + a0·b0 + a1·b1·γ
-- > d1 = c1 + a0·b1 + a1·b0
--
-- Arguments are a0, a1, b0, b1, c0, c1, γ.  Sums are formed in Word32;
-- assuming every input is below q they stay under 3·q², so a single
-- 'reduce' per output is enough.
baseCaseMultiplyAdd :: Zq -> Zq -> Zq -> Zq -> Zq -> Zq -> Zq -> (Zq, Zq)
baseCaseMultiplyAdd (Zq a0) (Zq a1) (Zq b0) (Zq b1) (Zq c0) (Zq c1) (Zq g) =
    let wide :: Word16 -> Word16 -> Word32
        wide u v = fromIntegral u * fromIntegral v
        -- b1·γ is reduced first so the next product stays within Word32
        b1g = reduce (wide b1 g)
        !d0 = reduce (fromIntegral c0 + wide a0 b0 + wide a1 b1g)
        !d1 = reduce (fromIntegral c1 + wide a0 b1 + wide a1 b0)
    in (Zq d0, Zq d1)

-- | Table of 17^BitRev7(𝑖) mod 𝑞 for 𝑖 ∈ {0, … , 127}.
--
-- The powers 17^0, 17^1, … are generated in order, and 17^k is stored at
-- index BitRev7(k).  Bit reversal on 7 bits is its own inverse, so index i
-- ends up holding 17^BitRev7(i).
zetaPowBitRev :: BlockN Pub 128 Zq
zetaPowBitRev = BlockN.runNew (Proxy :: Proxy Pub) $ \out ->
    forM_ (Prelude.zip [0 .. 127] (Prelude.iterate (Zq 17 .*) one)) $ \(k, z) ->
        BlockN.write out (fromIntegral (bitRev7 k)) z

-- | Table of 17^(2·BitRev7(𝑖) + 1) mod 𝑞 for 𝑖 ∈ {0, … , 127}.  Each entry
-- is derived from 'zetaPowBitRev' by squaring and multiplying by 17 once
-- more.
gamma :: BlockN Pub 128 Zq
gamma = BlockN.mapEqPrimSize oddPower zetaPowBitRev
  where
    oddPower z = z .* z .* Zq 17

-- | Compress a field element to 𝑑 bits, for 1 ≤ 𝑑 < 12 (FIPS 203, eq. 4.7):
--
-- > Compress_d(x) = round((2^d / q) · x) mod 2^d
--
-- The division by 𝑞 is done as a multiplication by ⌊2^34/𝑞⌋ followed by a
-- right shift of 34.  Adding (𝑞+1)/2 rather than (𝑞−1)/2 before dividing is
-- intentional: the truncation in ⌊2^34/𝑞⌋ makes the product fall just short
-- whenever the numerator is an exact multiple of 𝑞.  For every 𝑥 < 𝑞 and
-- 𝑑 < 12 this yields the correctly rounded quotient.
--
-- The final mask is the "mod 2^d" of the definition.  Inputs close to 𝑞
-- round up to exactly 2^𝑑 and must wrap to 0, so the result always fits in
-- 𝑑 bits.
compress :: Int -> Zq -> Word16
compress d (Zq x) = fromIntegral (rounded .&. mask)
  where
    rounded :: Word64
    rounded = ((fromIntegral x `unsafeShiftL` d + qHalf) * factor) `unsafeShiftR` 34
    mask :: Word64
    mask = (1 `unsafeShiftL` d) - 1
    qHalf :: Word64
    qHalf = (q64 + 1) `unsafeShiftR` 1
    factor :: Word64
    factor = (1 `unsafeShiftL` 34) `div` q64
{-# INLINE compress #-}

-- | Decompress a 𝑑-bit value to a field element (FIPS 203, eq. 4.8):
-- round((𝑞 / 2^𝑑) · y), computed as ⌊(𝑞·y + 2^(𝑑−1)) / 2^𝑑⌋.
--
-- Requires 1 ≤ 𝑑 < 12 (𝑑 = 0 would shift by −1).  Presumably y < 2^𝑑;
-- larger inputs would give a result that is not below 𝑞.
decompress :: Int -> Word16 -> Zq
decompress d y = Zq (fromIntegral scaled)
  where
    half, scaled :: Word32
    half = 1 `unsafeShiftL` (d - 1)
    scaled = (fromIntegral y * q32 + half) `unsafeShiftR` d
{-# INLINE decompress #-}

-- | Compress every coefficient of a polynomial to 𝑑 bits with 'compress'
-- (1 ≤ 𝑑 < 12).  'BlockN.seq' ties the evaluation of 𝑑 to the result block.
rcompress :: Classified marking => Int -> Rq marking -> BlockN marking N Word16
rcompress d (Rq coeffs) = BlockN.seq d (BlockN.mapEqPrimSize (compress d) coeffs)
{-# INLINE rcompress #-}

-- | Decompress every 𝑑-bit coefficient into a polynomial with 'decompress'
-- (1 ≤ 𝑑 < 12).  'BlockN.seq' ties the evaluation of 𝑑 to the result block.
rdecompress :: Classified marking => Int -> BlockN marking N Word16 -> Rq marking
rdecompress d = \packed ->
    Rq (BlockN.seq d (BlockN.mapEqPrimSize (decompress d) packed))
{-# INLINE rdecompress #-}

-- | Generate a pseudorandom element of T𝑞 from a seed and two indices.
--
-- The byte stream is the SHAKE128 digest of @seed ‖ x ‖ y@.  Every three
-- bytes yield two 12-bit candidates.  A candidate below 𝑞 is stored as the
-- next coefficient; anything else is rejected.
--
-- The digest length is a type-level natural, so the stream has a fixed
-- length (initially 840 bytes).  If it runs out before 256 coefficients are
-- accepted, it is recomputed 168 bytes longer and parsing resumes at the
-- same offset.  This relies on SHAKE128 output of a longer length starting
-- with the output of any shorter length.
sampleNTT :: SecureBytes Pub -> Word8 -> Word8 -> Tq Pub
sampleNTT seed !x !y =
    Tq (BlockN.runNew (Proxy :: Proxy Pub) (\out -> squeeze out initialLen 0 0))
  where
    -- Initial digest length in bytes.  Both this and the 168-byte extension
    -- are multiples of 3, so a byte triple never straddles the end.
    initialLen :: Int
    initialLen = 280 * 3

    -- Compute a digest of @len@ bytes, then parse it from byte @off@,
    -- filling coefficients from index @k@.
    squeeze !out !len !off !k = case someNatVal (fromIntegral (8 * len)) of
        SomeNat bits -> parse out len (Crypto.unBlockDigest (digest bits)) off k

    parse !out !len !bytes !off k
        | k == 256 = return ()
        | off >= Offset len = squeeze out (len + 56 * 3) off k
        | otherwise = do
            let byteAt i = fromIntegral (blockIndex bytes (off + i))
                b0, b1, b2 :: Word16
                b0 = byteAt 0
                b1 = byteAt 1
                b2 = byteAt 2
                -- two 12-bit candidates packed little-endian in b0 b1 b2
                lo = b0 + ((b1 .&. 0xF) `unsafeShiftL` 8)
                hi = (b1 `unsafeShiftR` 4) + (b2 `unsafeShiftL` 4)
            k' <- accept out k lo
            when (k' < 256) $ accept out k' hi >>= parse out len bytes (off + 3)

    -- Store @v@ at index @k@ when it is a valid field element, and return
    -- the index of the next free coefficient.
    accept out k v =
        if v < q16
            then BlockN.write out k (Zq v) >> return (k + 1)
            else return k

    -- The explicit signatures on 'digest' and 'input' keep the recursion
    -- above closed, so it is generalised over the ST state thread.
    digest :: KnownNat bitlen => proxy bitlen -> BlockDigest (SHAKE128 bitlen)
    digest _ = Crypto.hashToBlock input

    -- seed ‖ x ‖ y, built once (strict binding)
    input :: SecureBytes Pub
    !input = B.unsafeCreate (seedLen + 2) $ \dst -> do
        B.copyByteArrayToPtr seed dst
        pokeByteOff dst seedLen x
        pokeByteOff dst (seedLen + 1) y

    seedLen = B.length seed

-- | Read the little-endian word at @p@ and convert it to host order.
peekWord :: Ptr WordLE -> ST s WordM
peekWord = fmap fromLE . ST.peek

-- | Read, in host order, the little-endian word that contains bit position
-- @bp@ of the buffer starting at @a@.
peekWordPos :: Ptr WordLE -> BitPos -> ST s WordM
peekWordPos a bp = do
    w <- ST.peekElemOff a (wordOff bp)
    return (fromLE w)

-- | Store a host-order word, converted to little-endian, into the slot that
-- contains bit position @bp@ of the buffer starting at @a@.
pokeWordPos :: Ptr WordLE -> BitPos -> WordM -> ST s ()
pokeWordPos a bp w = ST.pokeElemOff a (wordOff bp) (toLE w)

-- | Absolute position, in bits, inside a buffer processed one 'WordM' at a
-- time.  'wordOff' and 'bitPos' split it into a word index and a bit index.
newtype BitPos = BitPos Int

-- | Position of the first bit of a buffer.
zeroPos :: BitPos
zeroPos = BitPos (0 :: Int)

-- | Index of the word that holds the bit at the given position.
wordOff :: BitPos -> Int
wordOff (BitPos p) = p `div` wordBits

-- | Index of the bit inside its word.  The mask equals @p `mod` wordBits@
-- only because 'wordBits' is presumably a power of two.
bitPos :: BitPos -> Int
bitPos (BitPos p) = (wordBits - 1) .&. p

-- | How many of @requested@ bits can be taken at position @bp@ without
-- crossing into the next word: the request, capped by the bits left in the
-- current word.
availPos :: Int -> BitPos -> Int
availPos requested bp = min (wordBits - bitPos bp) requested

-- | Take up to @requested@ bits at position @bp@ (bounded as in
-- 'availPos').  Returns how many were taken and the position just after
-- them.
nextPos :: Int -> BitPos -> (Int, BitPos)
nextPos requested bp@(BitPos p) =
    let taken = availPos requested bp
    in (taken, BitPos (p + taken))

-- | Mask selecting the @howMany@ least-significant bits of a word.
getMask :: Int -> WordM
getMask howMany =
    if howMany < wordBits
        then (1 `unsafeShiftL` howMany) - 1
        -- Shifting by the full word width is not allowed with unsafeShiftL,
        -- so a whole-word request is answered with all ones.  This branch is
        -- useful only when processing one byte at a time, due to the
        -- architecture not supporting unaligned memory access.
        else maxBound

-- | Sample a polynomial from the centered binomial distribution D_η, using
-- the bytes of @input@ as the randomness source.  The sampling itself is
-- done in place by 'mutSamplePolyCBD'.
samplePolyCBD :: Word -> SecureBytes Sec -> Rq Sec
samplePolyCBD eta input =
    Rq (BlockN.runNew (Proxy :: Proxy Sec) (mutSamplePolyCBD eta input))
{-# INLINE samplePolyCBD #-}

-- | In-place core of 'samplePolyCBD'.
--
-- @input@ is read as a bit stream, least-significant bit first within each
-- little-endian word.  Each of the 𝑛 output coefficients consumes 2·η bits
-- and is set to (popcount of the first η bits) − (popcount of the next η).
mutSamplePolyCBD :: Word -> SecureBytes Sec -> MutableBlockN Sec N Zq s -> ST s ()
mutSamplePolyCBD !eta !input out =
    ST.withByteArray input $ \src -> fill src out 0 zeroPos
  where
    fill :: Ptr WordLE -> MutableBlockN Sec N Zq s -> Offset Zq -> BitPos -> ST s ()
    fill !src !dst !i !pos
        | i >= Offset n = return ()
        | otherwise = do
            (x, pos') <- sumBits src pos 0 (fromIntegral eta)
            (y, pos'') <- sumBits src pos' 0 (fromIntegral eta)
            BlockN.write dst i (Zq x .- Zq y)
            fill src dst (i + 1) pos''

    -- Add to @acc@ the number of set bits among the next @remaining@ bits,
    -- reading at most one word per step.
    sumBits :: Ptr WordLE -> BitPos -> Word16 -> Int -> ST s (Word16, BitPos)
    sumBits !src !pos !acc !remaining
        | remaining == 0 = return (acc, pos)
        | otherwise = do
            word <- peekWordPos src pos
            let (taken, pos') = nextPos remaining pos
                chunk = (word `unsafeShiftR` bitPos pos) .&. getMask taken
            sumBits src pos' (acc + fromIntegral (popCount chunk)) (remaining - taken)
{-# NOINLINE mutSamplePolyCBD #-}

-- | Encode the 𝑛 coefficients of @f@ as 𝑑-bit integers into 𝑛·𝑑/8 bytes,
-- for 1 ≤ 𝑑 ≤ 12.  Only the low 𝑑 bits of each coefficient are emitted
-- (see 'runByteEncode').
byteEncode :: Int -> BlockN marking N Word16 -> Builder marking
byteEncode d f = Builder.create outputBytes (runByteEncode d f)
  where
    outputBytes = (n * d) `quot` 8
{-# INLINE byteEncode #-}

-- | Worker for 'byteEncode'.  Packs the low 𝑑 bits of each of the 𝑛
-- coefficients of @f@, least-significant bit first, into consecutive
-- little-endian words starting at @dst@.
--
-- The current word is accumulated and stored only once it is full.  There
-- is no trailing flush, so this relies on 𝑛·𝑑 being a multiple of
-- 'wordBits' (true whenever 'wordBits' divides 256).
runByteEncode :: Int -> BlockN marking N Word16 -> Ptr WordLE -> ST s ()
runByteEncode !d !f dst = go dst 0 zeroPos 0 (coeff 0) d
  where
    coeff :: Int -> Word16
    coeff = BlockN.index f . Offset
    {-# INLINE coeff #-}

    -- go ptr idx pos word val left
    --   idx  : index of the coefficient being packed
    --   pos  : current bit position in the output
    --   word : bits of the current output word not yet stored
    --   val  : bits of coefficient idx not yet packed (low bits first)
    --   left : number of bits of coefficient idx still to pack
    go !ptr !idx !pos !word !val left
        | left == 0 =
            if next == n
                then return ()
                else go ptr next pos word (coeff next) d
        | bitPos pos + taken < wordBits = go ptr idx pos' word' val' left'
        | otherwise = pokeWordPos ptr pos word' >> go ptr idx pos' 0 val' left'
      where
        next = idx + 1
        (taken, pos') = nextPos left pos
        word' = word .|. ((fromIntegral val .&. getMask taken) `unsafeShiftL` bitPos pos)
        val' = val `unsafeShiftR` taken
        left' = left - taken

-- | Specialisation of 'byteEncode' for @d = 1@.
--
-- Each of the 'N' coefficients contributes exactly one bit to the output,
-- so the produced 'Builder' is one eighth of 'n' bytes long (32 bytes).
-- The bit packing itself is done by 'runByteEncode1'.
byteEncode1 :: BlockN Sec N Word16 -> Builder Sec
byteEncode1 !coeffs = Builder.create encodedBytes (runByteEncode1 coeffs)
  where
    -- one bit per coefficient
    encodedBytes = n `div` 8
{-# INLINE byteEncode1 #-}

-- | Writes the low bit of each of the 'N' coefficients of @f@ to @dst@.
-- Coefficient @i@ becomes bit @i@ of the output stream.
--
-- Bits are accumulated in a 'WordM'. The accumulator is stored with
-- 'pokeWordPos' and reset to zero whenever @bitPos pos + 1@ reaches
-- 'wordBits', i.e. after the last bit of a machine word has been placed.
-- NOTE(review): @dst@ is presumably a buffer of at least 32 bytes (as
-- allocated by 'byteEncode1'). Its size is not checked here.
runByteEncode1 :: BlockN marking N Word16 -> Ptr WordLE -> ST s ()
runByteEncode1 !f !dst = pack 0 0
  where
    -- @acc@ holds the bits of the current word collected so far.
    -- @i@ is the index of the next coefficient.
    pack :: WordM -> Int -> ST s ()
    pack !acc i
        | i == n = return ()
        | otherwise = do
            let pos  = BitPos i
                lsb  = fromIntegral (BlockN.index f (Offset i) .&. 1)
                acc' = acc .|. (lsb `unsafeShiftL` bitPos pos)
            if bitPos pos + 1 < wordBits
                then pack acc' (i + 1)
                else pokeWordPos dst pos acc' >> pack 0 (i + 1)

-- | 'byteEncode' with @d = 12@, taking its input directly from the field
-- representation 'Tq'.
--
-- The 'Zq' coefficients are reinterpreted as 'Word16' with 'unsafeCoerce'.
-- This costs nothing at runtime, but the type checker cannot verify it.
-- NOTE(review): this is only sound if 'Zq' has the same runtime
-- representation as 'Word16' and holds fully reduced residues. Confirm
-- against the definition of 'Zq'.
byteEncode12 :: Tq marking -> Builder marking
byteEncode12 = byteEncode 12 . asWords
  where
    -- Zero-copy view of the field coefficients as raw 16-bit words.
    asWords :: Tq marking -> BlockN marking N Word16
    asWords (Tq coeffs) = unsafeCoerce coeffs
{-# INLINE byteEncode12 #-}

-- | Decodes a byte array into an array of 'N' @d@-bit integers, for
-- 1 ≤ @d@ ≤ 12. The result carries the security marking chosen by the
-- caller.
--
-- The decoded values are stored as they are read. No reduction is done
-- here ('byteDecode12' applies 'toZq' afterwards).
byteDecode :: forall marking ba. (Classified marking, ByteArrayAccess ba) => Int -> ba -> BlockN marking N Word16
byteDecode d input = BlockN.runNew marking (mutByteDecode d input)
  where
    marking :: Proxy marking
    marking = Proxy
{-# INLINE byteDecode #-}

-- | Fills @f@ with 'N' coefficients decoded from @b@. Each coefficient is
-- made of the next @d@ bits of the input, with the first bit read as the
-- least significant one.
--
-- A coefficient may be gathered in several chunks. The size of each chunk
-- and the following bit position come from 'nextPos'. Presumably a chunk
-- never crosses a machine-word boundary, since each one is extracted from a
-- single 'peekWordPos' read.
-- NOTE(review): the length of @b@ is not checked here. Callers are
-- presumably expected to supply at least @32 * d@ bytes; confirm.
mutByteDecode :: ByteArrayAccess ba => Int -> ba -> MutableBlockN marking N Word16 s -> ST s ()
mutByteDecode !d !b !f = ST.withByteArray b $ \src -> decodeFrom src zeroPos 0
  where
    -- Decode coefficient @i@ and all the following ones, starting at
    -- input bit position @pos@.
    decodeFrom !src !pos i
        | i < Offset n = collect src i pos 0 0
        | otherwise    = return ()

    -- Gather the @d@ bits of coefficient @i@. @acc@ holds the @got@ bits
    -- read so far, and @pos@ is the position of the next unread bit.
    collect !src !i !pos !acc got
        | got == d  = BlockN.write f i acc >> decodeFrom src pos (i + 1)
        | otherwise = do
            let (chunk, pos') = nextPos (d - got) pos
            piece <- fetchBits src pos chunk
            let acc' = acc .|. (fromIntegral piece `unsafeShiftL` got)
                got' = got + chunk
            collect src i pos' acc' got'

    -- Extract @count@ bits at position @pos@ from the word containing it.
    fetchBits :: Ptr WordLE -> BitPos -> Int -> ST s WordM
    fetchBits src pos count =
        (\w -> (w `unsafeShiftR` bitPos pos) .&. getMask count) <$> peekWordPos src pos
{-# SPECIALIZE mutByteDecode :: forall marking s. Int -> View Bytes -> MutableBlockN marking N Word16 s -> ST s () #-}

-- | Specialisation of 'byteDecode' for @d = 1@. Every input bit becomes
-- one coefficient, either 0 or 1.
--
-- The result is always marked 'Sec'.
byteDecode1 :: ByteArrayAccess ba => ba -> BlockN Sec N Word16
byteDecode1 input = BlockN.runNew secret (mutByteDecode1 input)
  where
    secret = Proxy :: Proxy Sec
{-# INLINE byteDecode1 #-}

-- | Fills @f@ with the first 'N' bits of @b@, one bit per coefficient.
--
-- The input is read one machine word at a time with 'peekWord', advancing
-- by 'wordBytes'. The 'wordBits' bits of each word are then extracted from
-- least to most significant.
-- NOTE(review): the length of @b@ is not checked here. It presumably must
-- hold at least 32 bytes; confirm at the call sites.
mutByteDecode1 :: ByteArrayAccess ba => ba -> MutableBlockN Sec N Word16 s -> ST s ()
mutByteDecode1 !b !f = ST.withByteArray b $ \src -> wordLoop src 0
  where
    -- Load the next input word, unless all coefficients are done.
    wordLoop !src i
        | i < n = do
            w <- peekWord src
            bitLoop (src `plusPtr` wordBytes) w i 0
        | otherwise = return ()

    -- Emit the remaining bits of @w@. @k@ counts the bits already consumed
    -- from the current word.
    bitLoop !src !w !i k
        | k == wordBits = wordLoop src i
        | otherwise = do
            BlockN.write f (Offset i) (fromIntegral (w .&. 1))
            bitLoop src (w `unsafeShiftR` 1) (i + 1) (k + 1)

-- | 'byteDecode' with @d = 12@, followed by conversion of every
-- coefficient to the field with 'toZq'. The result is wrapped as a 'Tq'
-- polynomial.
--
-- NOTE(review): 'toZq' presumably reduces the 12-bit values modulo q, as
-- ByteDecode_12 requires; confirm against its definition.
byteDecode12 :: (Classified marking, ByteArrayAccess ba) => ba -> Tq marking
byteDecode12 input = Tq (BlockN.mapEqPrimSize toZq (byteDecode 12 input))
{-# INLINE byteDecode12 #-}