{-# OPTIONS_GHC -fno-warn-missing-signatures #-}
{-# OPTIONS_GHC -fno-warn-unused-top-binds #-}

-----------------------------------------------------------------------------
-- |
-- Module      :  Numeric.GSL.LinearAlgebra
-- Copyright   :  (c) Alberto Ruiz 2007-14
-- License     :  GPL
-- Maintainer  :  Alberto Ruiz
-- Stability   :  provisional
--
-----------------------------------------------------------------------------

module Numeric.GSL.LinearAlgebra (
    RandDist(..), randomVector,
    saveMatrix,
    fwriteVector, freadVector, fprintfVector, fscanfVector,
    fileDimensions, loadMatrix, fromFile
) where

import Numeric.LinearAlgebra.HMatrix hiding (RandDist,randomVector,saveMatrix,loadMatrix)
import Numeric.GSL.Internal hiding (TV,TM,TCV,TCM)

import Foreign.Marshal.Alloc(free)
import Foreign.Ptr(Ptr)
import Foreign.C.Types
import Foreign.C.String(newCString, withCString)
import System.IO.Unsafe(unsafePerformIO)
import System.Process(readProcess)

-- | Converts any 'Enum' value to its ordinal position as a C integer.
fromei :: Enum a => a -> CInt
fromei = fromIntegral . fromEnum

-----------------------------------------------------------------------

-- | Distributions that 'randomVector' can sample from. The constructor's
-- 'fromEnum' ordinal is what gets passed to the C side.
data RandDist
    = Uniform   -- ^ uniform distribution in [0,1)
    | Gaussian  -- ^ normal distribution with mean zero and standard deviation one
    deriving (Enum)

-- | Obtains a vector of pseudorandom elements from the mt19937 generator in GSL,
-- with a given seed. Use randomIO to get a random seed.
--
-- The result is exposed as a pure value through 'unsafePerformIO'.
-- NOTE(review): this relies on @random_vector@ producing the same output for
-- the same (seed, distribution, size) and keeping no hidden shared state —
-- confirm against the C implementation.
randomVector :: Int      -- ^ seed
             -> RandDist -- ^ distribution
             -> Int      -- ^ vector size
             -> Vector Double
randomVector seed dist n = unsafePerformIO $ do
    out <- createVector n
    (out `applyRaw` id) (c_random_vector cSeed cDist) #| "randomVector"
    return out
  where
    cSeed = fi seed
    -- the distribution is sent to C as its constructor ordinal
    cDist = fi (fromEnum dist)

foreign import ccall unsafe "random_vector" c_random_vector :: CInt -> CInt -> TV

--------------------------------------------------------------------------------

-- | Saves a matrix as a 2D ASCII table.
--
-- The file name and format are marshalled with 'withCString', so both C
-- strings are released even when the C call reports an error (the @#|@
-- check throws on a nonzero status).
saveMatrix :: FilePath
           -> String     -- ^ format (%f, %g, %e)
           -> Matrix Double
           -> IO ()
saveMatrix filename fmt m =
    withCString filename $ \charname ->
        withCString fmt $ \charfmt ->
            (m `applyRaw` id) (matrix_fprintf charname charfmt orderFlag) #| "matrix_fprintf"
  where
    -- storage-order flag for the C routine: 1 when m is RowMajor, 0 otherwise
    orderFlag = if orderOf m == RowMajor then 1 else 0

foreign import ccall unsafe "matrix_fprintf" matrix_fprintf :: Ptr CChar -> Ptr CChar -> CInt -> TM

--------------------------------------------------------------------------------

-- | Loads a vector from an ASCII file (the number of elements must be known in advance).
--
-- The file name is marshalled with 'withCString', so the C string is
-- released even when the read fails and @#|@ throws.
fscanfVector :: FilePath -> Int -> IO (Vector Double)
fscanfVector filename n =
    withCString filename $ \charname -> do
        res <- createVector n
        (res `applyRaw` id) (gsl_vector_fscanf charname) #| "gsl_vector_fscanf"
        return res

foreign import ccall unsafe "vector_fscanf" gsl_vector_fscanf :: Ptr CChar -> TV

-- | Saves the elements of a vector, with a given format (%f, %e, %g), to an ASCII file.
--
-- Both C strings are managed by 'withCString', so they are released even
-- when the write fails and @#|@ throws.
fprintfVector :: FilePath -> String -> Vector Double -> IO ()
fprintfVector filename fmt v =
    withCString filename $ \charname ->
        withCString fmt $ \charfmt ->
            (v `applyRaw` id) (gsl_vector_fprintf charname charfmt) #| "gsl_vector_fprintf"

foreign import ccall unsafe "vector_fprintf" gsl_vector_fprintf :: Ptr CChar -> Ptr CChar -> TV

-- | Loads a vector from a binary file (the number of elements must be known in advance).
--
-- The file name is marshalled with 'withCString', so the C string is
-- released even when the read fails and @#|@ throws.
freadVector :: FilePath -> Int -> IO (Vector Double)
freadVector filename n =
    withCString filename $ \charname -> do
        res <- createVector n
        (res `applyRaw` id) (gsl_vector_fread charname) #| "gsl_vector_fread"
        return res

foreign import ccall unsafe "vector_fread" gsl_vector_fread :: Ptr CChar -> TV

-- | Saves the elements of a vector to a binary file.
--
-- The file name is marshalled with 'withCString', so the C string is
-- released even when the write fails and @#|@ throws.
fwriteVector :: FilePath -> Vector Double -> IO ()
fwriteVector filename v =
    withCString filename $ \charname ->
        (v `applyRaw` id) (gsl_vector_fwrite charname) #| "gsl_vector_fwrite"

foreign import ccall unsafe "vector_fwrite" gsl_vector_fwrite :: Ptr CChar -> TV

-- | Pointer to a buffer of 'Double' elements.
type PD = Ptr Double

-- | Shape of a C routine receiving one vector. NOTE(review): the two arguments
-- are presumably the element count and data pointer supplied by 'applyRaw' —
-- confirm. The 'CInt' result is a status code that callers check with @#|@.
type TV = CInt -> PD -> IO CInt

-- | Shape of a C routine receiving one matrix. NOTE(review): presumably rows,
-- columns and data pointer — confirm. The 'CInt' result is checked with @#|@.
type TM = CInt -> CInt -> PD -> IO CInt

--------------------------------------------------------------------------------

{- | Obtains the number of rows and columns in an ASCII data file.

     The column count is the number of whitespace-separated fields on the
     first non-blank line. The row count is the total number of fields divided
     (integer division) by the column count, so ragged tables are not detected.
     A file containing no fields yields @(0,0)@.

     Fields are counted in Haskell rather than by running an external @wc@
     process. This needs no unix tools, reads the file only once, and works
     with file names that start with @-@.
-}
fileDimensions :: FilePath -> IO (Int,Int)
fileDimensions fname = do
    contents <- readFile fname
    let -- counting every field forces the whole lazily-read file, so the
        -- handle is closed at EOF instead of lingering semi-closed
        tot = length (words contents)
        c   = case filter (not . null) (map words (lines contents)) of
                  (firstRow:_) -> length firstRow
                  []           -> 0
    if tot > 0
        then return (tot `div` c, c)
        else return (0,0)

-- | Loads a matrix from an ASCII file formatted as a 2D table.
--
-- The table shape is obtained with 'fileDimensions' and then passed to
-- 'fromFile', which reads the data.
loadMatrix :: FilePath -> IO (Matrix Double)
loadMatrix file = do
    dims <- fileDimensions file
    fromFile file dims

-- | Loads a matrix from an ASCII file (the number of rows and columns must be known in advance).
--
-- Reads @rows * cols@ values with 'fscanfVector', then turns the flat vector
-- into a matrix with @reshape cols@.
fromFile :: FilePath -> (Int,Int) -> IO (Matrix Double)
fromFile filename (rows, cols) = do
    flat <- fscanfVector filename (rows * cols)
    return (reshape cols flat)