{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE RecordWildCards #-}

module Internal.Sparse(
    GMatrix(..), CSR(..), mkCSR, fromCSR, impureCSR,
    mkSparse, mkDiagR, mkDense,
    AssocMatrix,
    toDense,
    gmXv, (!#>)
)where

import Internal.Vector
import Internal.Matrix
import Internal.Numeric
import qualified Data.Vector.Storable as V
import qualified Data.Vector.Storable.Mutable as M
import Control.Arrow((***))
import Control.Monad(when, foldM)
import Control.Monad.ST (runST)
import Control.Monad.Primitive (PrimMonad)
import Data.List(sort)
import Foreign.C.Types(CInt(..))

import Internal.Devel
import System.IO.Unsafe(unsafePerformIO)
import Foreign(Ptr)
import Text.Printf(printf)

-- | Sparse matrix in coordinate form: a list of ((row, column), value)
--   entries with 0-based indices.
type AssocMatrix = [(IndexOf Matrix, Double)]

-- | Compressed Sparse Row storage.
--
--   As built by 'mkCSR' / 'impureCSR' the index vectors are 1-based:
--   'csrCols' holds @column + 1@ for every stored value, and 'csrRows'
--   holds @csrNRows + 1@ row pointers, each equal to 1 + the offset in
--   'csrVals' of the first value of that row (the last one is 1 + the
--   total number of stored values).
data CSR = CSR
        { csrVals  :: Vector Double  -- ^ stored values, row by row
        , csrCols  :: Vector CInt    -- ^ 1-based column index of each value
        , csrRows  :: Vector CInt    -- ^ 1-based row pointers
        , csrNRows :: Int            -- ^ number of rows
        , csrNCols :: Int            -- ^ number of columns
        } deriving Show

-- | Compressed Sparse Column storage.
--
--   In this module a 'CSC' is produced by transposing a 'CSR' (see the
--   'Transposable' instances). The CSR arrays are reused unchanged: the
--   CSR column indices become 'cscRows' and the CSR row pointers become
--   'cscCols', so both keep the 1-based convention of 'CSR'.
data CSC = CSC
        { cscVals  :: Vector Double  -- ^ stored values, column by column
        , cscRows  :: Vector CInt    -- ^ row index of each value
        , cscCols  :: Vector CInt    -- ^ column pointers
        , cscNRows :: Int            -- ^ number of rows
        , cscNCols :: Int            -- ^ number of columns
        } deriving Show


-- | Produce a CSR sparse matrix from an association matrix.
--
--   The entries are sorted first, so they may be given in any order; they
--   end up ordered by row and then by column. Entries with the same
--   coordinates are not merged: each one is stored.
mkCSR :: AssocMatrix -> CSR
mkCSR ms =
  runST $ impureCSR runFold $ sort ms
  where
    -- Drive the incremental builder with a plain monadic left fold.
    runFold :: (Monad m, Foldable t)
            => (x -> a -> m x) -> m x -> (x -> m b) -> t a -> m b
    runFold next initialise xtract as0 = do
      i0  <- initialise
      acc <- foldM next i0 as0
      xtract acc

-- | Produce a CSR sparse matrix by applying a generic folding function.
--
--   This allows one to build a CSR from an effectful streaming source
--   when combined with libraries like pipes, io-streams, or streaming.
--
--   For example
--
--   > impureCSR Pipes.Prelude.foldM :: PrimMonad m => Producer (IndexOf Matrix, Double) m () -> m CSR
--   > impureCSR Streaming.Prelude.foldM :: PrimMonad m => Stream (Of (IndexOf Matrix, Double)) m r -> m (Of CSR r)
--
--   Rules for the input entries:
--
--   * They must arrive in non-decreasing row order; otherwise 'error' is called.
--   * Row and column indices must be non-negative; otherwise 'error' is called.
--   * Within a row, values are stored in the order they arrive.
--   * Rows without entries, including long runs of consecutive empty rows,
--     are allowed.
impureCSR
    :: PrimMonad m
    => (forall x . (x -> (IndexOf Matrix, Double) -> m x) -> m x -> (x -> m CSR) -> r)
    -> r
impureCSR f = f next begin done
  where
    -- Indices are stored 1-based.
    sfi = succ . fi

    -- Builder state is a 7-tuple:
    --   (values, row pointers, column indices,
    --    next value slot, next row-pointer slot, largest column seen,
    --    current row)
    -- The current row starts at -1, so the first entry emits pointers for
    -- every row up to and including its own.
    begin = do
      mv <- M.unsafeNew 64
      mr <- M.unsafeNew 64
      mc <- M.unsafeNew 64
      return (mv, mr, mc, 0, 0, 0, -1)

    next (!mv, !mr, !mc, !idxVC, !idxR, !maxC, !curRow) ((r,c),d) = do
      when (r < curRow) $
        error (printf "impureCSR: row %i specified after %i" r curRow)
      when (r < 0 || c < 0) $
        error (printf "impureCSR: negative index (%i,%i)" r c)

      let lenVC = M.length mv
          lenR  = M.length mr
          maxC' = max maxC c
          -- Pointers for the newly started rows curRow+1 .. r go into slots
          -- idxR .. idxR'-1. Slot idxR' is written later, either by 'done'
          -- or by a following 'next', so idxR' + 1 slots must exist.
          idxR' = idxR + (r - curRow)
          needR = idxR' + 1

      -- Values and column indices share one length; double both when full.
      (mv', mc') <-
        if idxVC >= lenVC then do
          mv' <- M.unsafeGrow mv lenVC
          mc' <- M.unsafeGrow mc lenVC
          return (mv', mc')
        else
          return (mv, mc)

      -- One doubling is not enough when many empty rows are skipped at
      -- once, so grow by at least the shortfall.
      mr' <-
        if needR > lenR then
          M.unsafeGrow mr (max lenR (needR - lenR))
        else
          return mr

      M.unsafeWrite mc' idxVC (sfi c)
      M.unsafeWrite mv' idxVC d

      -- Every newly started row begins at the value just written.
      mapM_ (\i -> M.unsafeWrite mr' i (sfi idxVC)) [idxR .. idxR' - 1]

      return (mv', mr', mc', idxVC + 1, idxR', maxC', r)

    done (!mv, !mr, !mc, !idxVC, !idxR, !maxC, !curR) = do
      -- Closing row pointer: one past the last stored value.
      M.unsafeWrite mr idxR (sfi idxVC)
      vv <- V.unsafeFreeze (M.unsafeTake idxVC mv)
      vc <- V.unsafeFreeze (M.unsafeTake idxVC mc)
      vr <- V.unsafeFreeze (M.unsafeTake (idxR + 1) mr)
      return $ CSR vv vc vr (succ curR) (succ maxC)


{- | General matrix with specialized internal representations.

     Currently available:

     * row-compressed sparse ('SparseR')
     * column-compressed sparse ('SparseC')
     * diagonal ('Diag')
     * dense ('Dense')

     Banded and constant representations are not implemented yet.
     The output below is wrapped for readability.

>>> let m = mkSparse [((0,999),1.0),((1,1999),2.0)]
>>> m
SparseR {gmCSR = CSR {csrVals = fromList [1.0,2.0],
                      csrCols = fromList [1000,2000],
                      csrRows = fromList [1,2,3],
                      csrNRows = 2,
                      csrNCols = 2000},
                      nRows = 2,
                      nCols = 2000}

>>> let m = mkDense (mat 2 [1..4])
>>> m
Dense {gmDense = (2><2)
 [ 1.0, 2.0
 , 3.0, 4.0 ], nRows = 2, nCols = 2}

-}
data GMatrix
    = SparseR
        { gmCSR   :: CSR
        , nRows   :: Int
        , nCols   :: Int
        }
    | SparseC
        { gmCSC   :: CSC
        , nRows   :: Int
        , nCols   :: Int
        }
    | Diag
        { diagVals :: Vector Double
        , nRows    :: Int
        , nCols    :: Int
        }
    | Dense
        { gmDense :: Matrix Double
        , nRows   :: Int
        , nCols   :: Int
        }
--    | Banded
    deriving Show


-- | Wrap a dense matrix as a 'GMatrix', recording its dimensions.
mkDense :: Matrix Double -> GMatrix
mkDense m = Dense { gmDense = m, nRows = rows m, nCols = cols m }

-- | Build a row-compressed sparse 'GMatrix' from an association list
--   (see 'mkCSR' for how the entries are ordered and stored).
mkSparse :: AssocMatrix -> GMatrix
mkSparse = fromCSR . mkCSR

-- | Wrap a 'CSR' matrix as a 'GMatrix', taking the dimensions from the
--   CSR itself.
fromCSR :: CSR -> GMatrix
fromCSR csr = SparseR { gmCSR = csr
                      , nRows = csrNRows csr
                      , nCols = csrNCols csr
                      }


-- | @mkDiagR r c v@ is an @r@ x @c@ diagonal 'GMatrix' whose leading
--   diagonal elements come from @v@. Any remaining diagonal positions
--   act as zero in 'gmXv'.
--
--   Calls 'error' if @v@ is longer than @min r c@.
mkDiagR :: Int -> Int -> Vector Double -> GMatrix
mkDiagR r c v
    | dim v <= min r c = Diag { diagVals = v, nRows = r, nCols = c }
    | otherwise        =
        error $ printf "mkDiagR: incorrect sizes (%d,%d) [%d]" r c (dim v)


-- FFI argument shapes: each array is passed as a CInt (presumably its
-- length) followed by a pointer to its data.
type IV t = CInt -> Ptr CInt   -> t
type  V t = CInt -> Ptr Double -> t
-- Signature shared by the C sparse kernels. As used in 'gmXv', the
-- arguments are:
--   * the value array
--   * two index arrays
--   * the input vector
--   * the output vector
-- The CInt result is checked with '#|' (presumably a status code;
-- confirm against the C side).
type SMxV = V (IV (IV (V (V (IO CInt)))))

-- | Multiply a 'GMatrix' by a dense vector, dispatching on the
--   representation. Calls 'error' if the vector length differs from
--   'nCols'.
--
--   The sparse cases use 'unsafePerformIO'. They only allocate a fresh
--   result vector and fill it through the C kernel, assuming the kernel
--   reads but does not modify its input arrays.
gmXv :: GMatrix -> Vector Double -> Vector Double
gmXv SparseR { gmCSR = CSR{..}, .. } v = unsafePerformIO $ do
    when (dim v /= nCols) $
      error (printf "gmXv (CSR): incorrect sizes: (%d,%d) x %d" nRows nCols (dim v))

    r <- createVector nRows
    (csrVals # csrCols # csrRows # v #! r) c_smXv #|"CSRXv"
    return r

gmXv SparseC { gmCSC = CSC{..}, .. } v = unsafePerformIO $ do
    when (dim v /= nCols) $
      error (printf "gmXv (CSC): incorrect sizes: (%d,%d) x %d" nRows nCols (dim v))

    r <- createVector nRows
    (cscVals # cscRows # cscCols # v #! r) c_smTXv #|"CSCXv"
    return r

gmXv Diag{..} v
    | dim v == nCols
        -- Scale the leading part of v, then pad with zeros up to nRows.
        = vjoin [ subVector 0 (dim diagVals) v `mul` diagVals
                , konst 0 (nRows - dim diagVals) ]
    | otherwise = error $ printf "gmXv (Diag): incorrect sizes: (%d,%d) [%d] x %d"
                                 nRows nCols (dim diagVals) (dim v)

gmXv Dense{..} v
    | dim v == nCols
        = mXv gmDense v
    | otherwise = error $ printf "gmXv (Dense): incorrect sizes: (%d,%d) x %d"
                                 nRows nCols (dim v)


{- | General matrix-vector product; operator form of 'gmXv'.

>>> let m = mkSparse [((0,999),1.0),((1,1999),2.0)]
m :: GMatrix
>>> m !#> vector [1..2000]
[1000.0,4000.0]
it :: Vector Double

-}
infixr 8 !#>
(!#>) :: GMatrix -> Vector Double -> Vector Double
(!#>) = gmXv

--------------------------------------------------------------------------------

-- Sparse matrix times vector on CSR arrays; used by 'gmXv' for 'SparseR'.
foreign import ccall unsafe "smXv"
  c_smXv :: SMxV

-- Used by 'gmXv' for 'SparseC', whose arrays are those of the original CSR.
-- Presumably computes the transposed product on that layout; confirm
-- against the C source.
foreign import ccall unsafe "smTXv"
  c_smTXv :: SMxV

--------------------------------------------------------------------------------

-- | Convert an association list to a dense matrix just large enough to
--   hold the largest row and column index; unspecified entries are 0.
--   An empty list yields a 0x0 matrix.
toDense :: AssocMatrix -> Matrix Double
toDense asm = assoc (r + 1, c + 1) 0 asm
  where
    -- Largest index in each coordinate. Seeding with -1 makes the empty
    -- list give a 0x0 result instead of failing inside 'maximum'.
    top = maximum . ((-1) :)
    (r, c) = (top *** top) . unzip . map fst $ asm


-- | Transposing a CSR reinterprets its arrays as the CSC form of the
--   transposed matrix. No data is copied; only the dimensions swap.
instance Transposable CSR CSC
  where
    tr CSR{..} = CSC { cscVals  = csrVals
                     , cscRows  = csrCols
                     , cscCols  = csrRows
                     , cscNRows = csrNCols
                     , cscNCols = csrNRows
                     }
    tr' = tr

-- | Inverse of the 'CSR' -> 'CSC' reinterpretation: the same arrays read
--   as a CSR of the transposed matrix, with the dimensions swapped.
instance Transposable CSC CSR
  where
    tr CSC{..} = CSR { csrVals  = cscVals
                     , csrCols  = cscRows
                     , csrRows  = cscCols
                     , csrNRows = cscNCols
                     , csrNCols = cscNRows
                     }
    tr' = tr

-- | Transpose a 'GMatrix', swapping 'nRows' and 'nCols'.
--
--   * A row-compressed matrix becomes column-compressed, and vice versa.
--   * A diagonal matrix keeps its values.
--   * A dense matrix is transposed with the dense 'tr'.
--
--   Only real ('Double') entries are stored, so 'tr'' is the same as 'tr'.
instance Transposable GMatrix GMatrix
  where
    tr (SparseR s n m) = SparseC (tr s) m n
    tr (SparseC s n m) = SparseR (tr s) m n
    tr (Diag v n m)    = Diag v m n
    tr (Dense a n m)   = Dense (tr a) m n
    tr' = tr