{-# LANGUAGE RecordWildCards #-}
module Data.Conduit.BZlib (
  compress,
  decompress1,
  decompress,

  bzip2,
  bunzip2,

  CompressParams(..),
  DecompressParams(..),
  def,
  ) where

import Control.Monad as CM
import Control.Monad.Trans
import Control.Monad.Trans.Resource
import qualified Data.ByteString as S
import qualified Data.ByteString.Unsafe as S
import Data.Conduit
import Data.Default.Class
import Data.Maybe
import Data.IORef
import Foreign
import Foreign.C

import Data.Conduit.BZlib.Internal

-- | Compression parameters.
--
-- The three fields are passed, in this order, to @BZ2_bzCompressInit@ by
-- 'compress'.
data CompressParams
  = CompressParams
    { cpBlockSize  :: Int -- ^ Compress level [1..9]. default is 9.
    , cpVerbosity  :: Int -- ^ Verbosity mode [0..4]. default is 0.
    , cpWorkFactor :: Int -- ^ Work factor [0..250]. default is 30.
    }

-- | Default compression settings: block size 9, verbosity 0, work factor 30.
instance Default CompressParams where
  def = CompressParams
    { cpBlockSize  = 9
    , cpVerbosity  = 0
    , cpWorkFactor = 30
    }

-- | Decompression parameters.
--
-- Both fields are passed to @BZ2_bzDecompressInit@ by 'decompress1'.
data DecompressParams
  = DecompressParams
    { dpVerbosity :: Int  -- ^ Verbosity mode [0..4]. default is 0.
    , dpSmall     :: Bool -- ^ If True, use an algorithm that uses less memory but is slower. default is False.
    }

-- | Default decompression settings: verbosity 0, normal (non-small) algorithm.
instance Default DecompressParams where
  def = DecompressParams
    { dpVerbosity = 0
    , dpSmall     = False
    }

-- | Size in bytes of the output buffer handed to libbz2, and the initial size
-- of the input staging buffer (both are allocated in 'allocateStream').
bufSize :: Int
bufSize = 4096

-- | Yield whatever libbz2 has written into the output buffer, then rewind the
-- stream's output cursor so the whole buffer can be reused.
--
-- Nothing is yielded when the buffer is still empty, i.e. when @avail_out@
-- is still equal to 'bufSize'.
yieldAvailOutput :: MonadIO m => Ptr C'bz_stream -> ConduitT S.ByteString S.ByteString m ()
yieldAvailOutput ptr = do
  availOut <- liftIO $ fromIntegral <$> peek (p'bz_stream'avail_out ptr)
  when (availOut < bufSize) $
    yieldM $ liftIO $ do
      -- Bytes produced so far = buffer size minus the space still free.
      let len = bufSize - availOut
      -- @next_out@ points just past the produced bytes; step back by @len@
      -- to recover the start of the buffer.
      p <- (`plusPtr` (-len)) <$> peek (p'bz_stream'next_out ptr)
      -- 'S.packCStringLen' copies, so the buffer may be overwritten afterwards.
      out <- S.packCStringLen (p, len)
      poke (p'bz_stream'next_out ptr) p
      poke (p'bz_stream'avail_out ptr) (fromIntegral bufSize)
      return out

-- | Copy one input chunk into the stream's staging buffer and point the
-- stream's @next_in@ / @avail_in@ at it.
--
-- The 'IORef' holds the staging buffer together with its current capacity.
-- When the chunk does not fit, the capacity is doubled until it does and the
-- buffer is reallocated; otherwise the existing buffer is reused as is.
--
-- The stored capacity is assumed to be positive (it starts at 'bufSize');
-- doubling a non-positive capacity would never terminate.
fillInput :: Ptr C'bz_stream -> IORef (Ptr CChar, Int) -> S.ByteString -> IO ()
fillInput ptr mv bs = S.unsafeUseAsCStringLen bs $ \(p, len) -> do
  (buf, bsize) <- readIORef mv
  -- Smallest capacity of the form bsize * 2^k (k >= 0) that holds the chunk.
  let nsize = until (>= len) (* 2) bsize
  -- Only touch the allocator when the buffer actually has to grow.
  nbuf <- if nsize > bsize then reallocBytes buf nsize else return buf
  -- Record the (possibly moved) buffer right away: the cleanup action reads
  -- this reference, and the old pointer is invalid after a realloc.
  writeIORef mv (nbuf, nsize)
  copyBytes nbuf p len
  poke (p'bz_stream'avail_in ptr) $ fromIntegral len
  poke (p'bz_stream'next_in ptr) nbuf

-- | Run an action returning a libbz2 status code and pass the code through,
-- raising a user 'IOError' when the code is negative.
--
-- The error message has the form @\<label\>: \<code\>@, where the label is
-- the first argument.
throwIfMinus :: String -> IO CInt -> IO CInt
throwIfMinus s m = do
  r <- m
  when (r < 0) $ ioError $ userError $ s ++ ": " ++ show r
  return r

-- | Like 'throwIfMinus', but discards the (non-negative) status code.
throwIfMinus_ :: String -> IO CInt -> IO ()
throwIfMinus_ s m = CM.void $ throwIfMinus s m

-- | Allocate a zero-initialised @bz_stream@ together with its I/O buffers,
-- registering each allocation with the enclosing resource scope.
--
-- Returns the stream pointer and a reference to the input staging buffer
-- (pointer and capacity). The output buffer of 'bufSize' bytes is installed
-- directly into the stream's @next_out@ / @avail_out@ fields.
allocateStream :: MonadResource m => m (Ptr C'bz_stream, IORef (Ptr CChar, Int))
allocateStream = do
  (_, ptr)    <- allocate malloc free
  -- The input buffer lives behind an IORef because 'fillInput' may
  -- reallocate it; the release action frees whatever pointer is current.
  (_, inbuf)  <- allocate (mallocBytes bufSize >>= \p -> newIORef (p, bufSize))
                          (readIORef >=> free . fst)
  (_, outbuf) <- allocate (mallocBytes bufSize) free
  liftIO $ poke ptr $ C'bz_stream
    { c'bz_stream'next_in        = nullPtr
    , c'bz_stream'avail_in       = 0
    , c'bz_stream'total_in_lo32  = 0
    , c'bz_stream'total_in_hi32  = 0
    , c'bz_stream'next_out       = outbuf
    , c'bz_stream'avail_out      = fromIntegral bufSize
    , c'bz_stream'total_out_lo32 = 0
    , c'bz_stream'total_out_hi32 = 0
    , c'bz_stream'state          = nullPtr
    , c'bz_stream'bzalloc        = nullPtr
    , c'bz_stream'bzfree         = nullPtr
    , c'bz_stream'opaque         = nullPtr
    }
  return (ptr, inbuf)

-- | Compress a stream of ByteStrings.
--
-- Every non-empty input chunk is fed to @BZ2_bzCompress@ with @BZ_RUN@;
-- empty chunks are skipped. When upstream is exhausted the stream is
-- finished with @BZ_FINISH@ and the remaining output is yielded.
--
-- A negative status from libbz2 is raised as a user 'IOError' (see
-- 'throwIfMinus').
compress
  :: MonadResource m
     => CompressParams -- ^ Compress parameter
     -> ConduitT S.ByteString S.ByteString m ()
compress CompressParams {..} = do
  (ptr, inbuf) <- lift allocateStream
  -- Pair stream initialisation with its teardown in the resource scope.
  _ <- lift $ allocate
    (throwIfMinus_ "bzCompressInit" $
     c'BZ2_bzCompressInit ptr
     (fromIntegral cpBlockSize)
     (fromIntegral cpVerbosity)
     (fromIntegral cpWorkFactor))
    (\_ -> throwIfMinus_ "bzCompressEnd" $ c'BZ2_bzCompressEnd ptr)

  let loop = do
        mbinp <- await
        case mbinp of
          Just inp -> do
            unless (S.null inp) $ do
              liftIO $ fillInput ptr inbuf inp
              yields ptr c'BZ_RUN
            loop
          Nothing ->
            yields ptr c'BZ_FINISH
  loop
  where
    -- Run one compression step with the given action and yield any output.
    -- Repeat while input is still pending or, when finishing, until libbz2
    -- reports BZ_STREAM_END.
    yields :: MonadIO m => Ptr C'bz_stream -> CInt -> ConduitT S.ByteString S.ByteString m ()
    yields ptr action = do
      cont <- liftIO $ throwIfMinus "bzCompress" $ c'BZ2_bzCompress ptr action
      yieldAvailOutput ptr
      availIn <- liftIO $ peek $ p'bz_stream'avail_in ptr
      when (availIn > 0 || (action == c'BZ_FINISH && cont /= c'BZ_STREAM_END)) $
        yields ptr action

-- | Decompress a stream of ByteStrings. Note that this will only decompress
-- the first compressed stream in the input and leave the rest for further
-- processing. See 'decompress'.
--
-- Empty input chunks are skipped. If upstream ends before libbz2 reports the
-- end of the compressed stream, a user 'IOError' with the message
-- @unexpected EOF on decompress@ is raised. A negative status from libbz2 is
-- raised as a user 'IOError' as well (see 'throwIfMinus').
decompress1
  :: MonadResource m
     => DecompressParams -- ^ Decompress parameter
     -> ConduitT S.ByteString S.ByteString m ()
decompress1 DecompressParams {..} = do
  (ptr, inbuf) <- lift allocateStream
  -- Pair stream initialisation with its teardown in the resource scope.
  _ <- lift $ allocate
    (throwIfMinus_ "bzDecompressInit" $
     c'BZ2_bzDecompressInit ptr (fromIntegral dpVerbosity) (fromBool dpSmall))
    (\_ -> throwIfMinus_ "bzDecompressEnd" $ c'BZ2_bzDecompressEnd ptr)

  let loop = do
        mbinp <- await
        case mbinp of
          Just inp | not (S.null inp) -> do
            liftIO $ fillInput ptr inbuf inp
            cont <- yields ptr
            when cont loop
          Just _ ->
            loop
          Nothing ->
            liftIO $ throwM $ userError "unexpected EOF on decompress"
  loop
  where
    -- Run decompression steps over the current input buffer, yielding output
    -- after each step. The result is True when more input should be awaited
    -- and False when this stream is finished.
    yields :: MonadIO m => Ptr C'bz_stream -> ConduitT S.ByteString S.ByteString m Bool
    yields ptr = do
      ret <- liftIO $ throwIfMinus "BZ2_bzDecompress" $ c'BZ2_bzDecompress ptr
      yieldAvailOutput ptr
      availIn <- liftIO $ peek $ p'bz_stream'avail_in ptr
      if availIn > 0
        then
            -- bzip2 files can contain multiple concatenated streams, but the
            -- API requires that we close the stream and start a new
            -- decompression session.
            if ret == c'BZ_STREAM_END
                then do
                    -- Hand the unconsumed tail back upstream so a later
                    -- consumer can read it.
                    dataIn <- liftIO $ peek $ p'bz_stream'next_in ptr
                    unread <- liftIO $ S.packCStringLen (dataIn, fromIntegral availIn)
                    leftover unread
                    return False
                else yields ptr
        else return $ ret == c'BZ_OK

-- | Decompress all the compressed bzip2 streams in the input, as the bzip2
-- command line tool does.
--
-- Runs 'decompress1' repeatedly until upstream is exhausted. Empty chunks
-- between streams are skipped; an input with no data at all produces no
-- output.
decompress
  :: MonadResource m
     => DecompressParams -- ^ Decompress parameter
     -> ConduitT S.ByteString S.ByteString m ()
decompress params = do
    next <- await
    case next of
        Nothing -> return ()
        Just bs
            | S.null bs -> decompress params
            | otherwise -> do
                -- Push the chunk back so 'decompress1' sees the stream from
                -- its first byte.
                leftover bs
                decompress1 params
                decompress params
-- | bzip2 compression with default parameters.
--
-- Equivalent to @'compress' 'def'@.
bzip2 :: MonadResource m => ConduitT S.ByteString S.ByteString m ()
bzip2 = compress def

-- | bzip2 decompression with default parameters. This will decompress all the
-- streams in the input.
--
-- Equivalent to @'decompress' 'def'@.
bunzip2 :: MonadResource m => ConduitT S.ByteString S.ByteString m ()
bunzip2 = decompress def