module Codec.Compression.Zstd.Internal
(
CCtx(..)
, DCtx(..)
, compressWith
, decompressWith
, decompressedSize
, withCCtx
, withDCtx
, withDict
, trainFromSamples
, getDictID
) where
import Codec.Compression.Zstd.Types (Decompress(..), Dict(..))
import Control.Exception.Base (bracket)
import Data.ByteString.Internal (ByteString(..))
import Data.Word (Word, Word8)
import Foreign.C.Types (CInt, CSize)
import Foreign.Marshal.Array (withArray)
import Foreign.ForeignPtr (withForeignPtr)
import Foreign.Ptr (Ptr, castPtr, plusPtr)
import System.IO.Unsafe (unsafePerformIO)
import qualified Codec.Compression.Zstd.FFI as C
import qualified Data.ByteString as B
import qualified Data.ByteString.Internal as B
-- | Compress a strict 'ByteString' using the supplied low-level compressor.
--
-- Arguments, in order:
--
-- * the name of the calling function, used as a prefix in error messages;
-- * the compressor, called as @compressor dst dstCapacity src srcLen level@
--   and returning either the compressed size or an error code;
-- * the compression level, which must lie in @[1 .. 'C.maxCLevel']@;
-- * the payload to compress.
--
-- An out-of-range level, or an error code reported by the compressor, is
-- raised through 'bail' (that is, via 'error').
compressWith
    :: String
    -> (Ptr Word8 -> CSize -> Ptr Word8 -> CSize -> CInt -> IO CSize)
    -> Int
    -> ByteString
    -> IO ByteString
compressWith name compressor level (PS sfp off len)
  | level < 1 || level > C.maxCLevel
  = bail name "unsupported compression level"
  | otherwise =
    withForeignPtr sfp $ \sp -> do
      -- Allocate a destination buffer of the size reported by
      -- 'C.compressBound' for this input length.
      maxSize <- C.compressBound (fromIntegral len)
      dfp <- B.mallocByteString (fromIntegral maxSize)
      withForeignPtr dfp $ \dst -> do
        let src = sp `plusPtr` off
        csz <- compressor dst maxSize src (fromIntegral len) (fromIntegral level)
        handleError csz name $ do
          let size = fromIntegral csz
          -- Reuse the freshly allocated buffer when the result is tiny or
          -- fills at least half of it; otherwise copy into an exactly
          -- sized buffer (presumably so the oversized allocation is not
          -- kept alive by a small result).
          if csz < 128 || csz >= maxSize `div` 2
            then return (PS dfp 0 size)
            else B.create size $ \p -> B.memcpy p dst size
-- | Report the decompressed size of a compressed payload, as returned by
-- 'C.getDecompressedSize'.
--
-- Returns 'Nothing' when the reported size is @0@ or does not fit in an
-- 'Int'; otherwise returns the size wrapped in 'Just'.
--
-- NOTE(review): 'unsafePerformIO' is used on the assumption that the call
-- only reads the (immutable) bytes of the input -- confirm against the FFI
-- binding.
decompressedSize :: ByteString -> Maybe Int
decompressedSize (PS fp off len) =
  unsafePerformIO . withForeignPtr fp $ \ptr -> do
    sz <- C.getDecompressedSize (ptr `plusPtr` off) (fromIntegral len)
    return $ if sz == 0 || sz > fromIntegral (maxBound :: Int)
             then Nothing
             else Just (fromIntegral sz)
-- | Decompress a strict 'ByteString' using the supplied low-level
-- decompressor, called as @decompressor dst dstCapacity src srcLen@.
--
-- The destination size is taken from 'C.getDecompressedSize':
--
-- * a reported size of @0@ yields 'Skip';
-- * a size that does not fit in an 'Int' yields an 'Error';
-- * otherwise a buffer of exactly that size is allocated and filled, and
--   the result is either 'Error' (carrying the library's error name) or
--   'Decompress' with the number of bytes the decompressor reported.
decompressWith :: (Ptr Word8 -> CSize -> Ptr Word8 -> CSize -> IO CSize)
               -> ByteString
               -> IO Decompress
decompressWith decompressor (PS sfp off len) =
  withForeignPtr sfp $ \sp -> do
    let src = sp `plusPtr` off
    dstSize <- C.getDecompressedSize src (fromIntegral len)
    if dstSize == 0
      then return Skip
      else if dstSize > fromIntegral (maxBound :: Int)
        then return (Error "invalid compressed payload size")
        else do
          dfp <- B.mallocByteString (fromIntegral dstSize)
          size <- withForeignPtr dfp $ \dst ->
            decompressor dst (fromIntegral dstSize) src (fromIntegral len)
          return $ if C.isError size
                   then Error (C.getErrorName size)
                   else Decompress (PS dfp 0 (fromIntegral size))
-- | Wrapper around a raw pointer to a 'C.CCtx', the context type passed
-- to the action run by 'withCCtx'.
newtype CCtx = CCtx { getCCtx :: Ptr C.CCtx }
-- | Run an action with a freshly created 'CCtx'.
--
-- The context is created with 'C.createCCtx' (checked by 'C.checkAlloc')
-- and released with 'C.freeCCtx'. 'bracket' guarantees the release even
-- if the action throws, so the context must not be used after the action
-- returns.
withCCtx :: (CCtx -> IO a) -> IO a
withCCtx act =
  bracket (fmap CCtx (C.checkAlloc "withCCtx" C.createCCtx))
          (C.freeCCtx . getCCtx)
          act
-- | Wrapper around a raw pointer to a 'C.DCtx', the context type passed
-- to the action run by 'withDCtx'.
newtype DCtx = DCtx { getDCtx :: Ptr C.DCtx }
-- | Run an action with a freshly created 'DCtx'.
--
-- The context is created with 'C.createDCtx' (checked by 'C.checkAlloc')
-- and released with 'C.freeDCtx'. 'bracket' guarantees the release even
-- if the action throws, so the context must not be used after the action
-- returns.
withDCtx :: (DCtx -> IO a) -> IO a
withDCtx act =
  bracket (fmap DCtx (C.checkAlloc "withDCtx" C.createDCtx))
          (C.freeDCtx . getDCtx)
          act
-- | Run an action with a pointer to the first byte of a dictionary and
-- the dictionary's length in bytes.
--
-- The pointer is only valid for the duration of the action, because the
-- underlying buffer is kept alive by 'withForeignPtr'.
withDict :: Dict -> (Ptr dict -> CSize -> IO a) -> IO a
withDict (Dict (PS fp off len)) act =
  withForeignPtr fp $ \ptr ->
    act (ptr `plusPtr` off) (fromIntegral len)
-- | Train a compression dictionary from a collection of samples.
--
-- The first argument is the capacity, in bytes, of the buffer allocated
-- for the dictionary; the second is the list of samples. The samples are
-- concatenated and handed to 'C.trainFromBuffer' together with an array
-- of their individual lengths.
--
-- Returns 'Left' with the library's error name if training fails, and
-- 'Right' with the trained dictionary otherwise.
trainFromSamples :: Int
                 -> [ByteString]
                 -> Either String Dict
trainFromSamples capacity samples = unsafePerformIO $
  withArray (map B.length samples) $ \sizes -> do
    dfp <- B.mallocByteString capacity
    -- 'B.concat' may hand back one of its inputs unchanged (for example
    -- when there is a single non-empty sample), so the result can be a
    -- slice with a non-zero offset. The offset must be honoured, or the
    -- trainer would read from the start of the underlying buffer instead
    -- of the sample data.
    let PS sfp soff _ = B.concat samples
    withForeignPtr dfp $ \dict ->
      withForeignPtr sfp $ \sampPtr -> do
        -- NOTE(review): 'castPtr' reinterprets the 'Int' array as 'CSize';
        -- this assumes both have the same size -- confirm for all targets.
        dsz <- C.trainFromBuffer
                 dict (fromIntegral capacity)
                 (sampPtr `plusPtr` soff :: Ptr Word8) (castPtr sizes)
                 (fromIntegral (length samples))
        if C.isError dsz
          then return (Left (C.getErrorName dsz))
          else fmap (Right . Dict) $ do
            let size = fromIntegral dsz
            -- Reuse the allocated buffer when the dictionary is tiny or
            -- fills at least half of it; otherwise copy into an exactly
            -- sized buffer.
            if size < 128 || size >= capacity `div` 2
              then return (PS dfp 0 size)
              else B.create size $ \p -> B.memcpy p dict size
-- | Return the identifier of a dictionary, as reported by 'C.getDictID'.
--
-- A reported identifier of @0@ is mapped to 'Nothing'; any other value is
-- returned in 'Just'. The result is forced before it is returned, so no
-- thunk referring to the raw value escapes the 'unsafePerformIO' block.
getDictID :: Dict -> Maybe Word
getDictID dict = unsafePerformIO $ do
  n <- withDict dict C.getDictID
  return $! if n == 0
            then Nothing
            else Just (fromIntegral n)
-- | Inspect a size-or-error code returned by the C library.
--
-- If 'C.isError' reports that the value is an error code, fail through
-- 'bail', using the given function name and the library's error name.
-- Otherwise run the supplied continuation.
handleError :: CSize -> String -> IO a -> IO a
handleError sizeOrError func act
  | C.isError sizeOrError = bail func (C.getErrorName sizeOrError)
  | otherwise             = act
-- | Abort with an error message of the form
-- @Codec.Compression.Zstd.\<func\>: \<str\>@.
--
-- The first argument is the name of the function reporting the failure;
-- the second is the description of what went wrong. The failure is
-- raised with 'error', so the result type is unconstrained.
bail :: String -> String -> a
bail func str = error $ "Codec.Compression.Zstd." ++ func ++ ": " ++ str