source: project/release/3/atlas-lapack/trunk/atlas-lapack.scm @ 13043

Last change on this file since 13043 was 13043, checked in by Ivan Raikov, 12 years ago

Version set to 1.10.

File size: 17.5 KB
Line 
1;;
2;; Chicken Scheme bindings for the LAPACK routines in the ATLAS
3;; library.
4;;
5;; Copyright 2007-2009 Ivan Raikov and the Okinawa Institute of Science and Technology
6;;
7;; This program is free software: you can redistribute it and/or
8;; modify it under the terms of the GNU General Public License as
9;; published by the Free Software Foundation, either version 3 of the
10;; License, or (at your option) any later version.
11;;
12;; This program is distributed in the hope that it will be useful, but
13;; WITHOUT ANY WARRANTY; without even the implied warranty of
14;; MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
15;; General Public License for more details.
16;;
17;; A full copy of the GPL license can be found at
18;; <http://www.gnu.org/licenses/>.
19;;
20
21(require-extension blas)
22(require-extension srfi-4)
23
24(define-extension atlas-lapack)
25
26(declare (export atlas-lapack:sgesv
27                 atlas-lapack:dgesv
28                 atlas-lapack:cgesv
29                 atlas-lapack:zgesv
30                 atlas-lapack:sposv
31                 atlas-lapack:dposv
32                 atlas-lapack:cposv
33                 atlas-lapack:zposv
34                 atlas-lapack:sgetrf
35                 atlas-lapack:dgetrf
36                 atlas-lapack:cgetrf
37                 atlas-lapack:zgetrf
38                 atlas-lapack:sgetrs
39                 atlas-lapack:dgetrs
40                 atlas-lapack:cgetrs
41                 atlas-lapack:zgetrs
42                 atlas-lapack:sgetri
43                 atlas-lapack:dgetri
44                 atlas-lapack:cgetri
45                 atlas-lapack:zgetri
46                 atlas-lapack:spotrf
47                 atlas-lapack:dpotrf
48                 atlas-lapack:cpotrf
49                 atlas-lapack:zpotrf
50                 atlas-lapack:spotrs
51                 atlas-lapack:dpotrs
52                 atlas-lapack:cpotrs
53                 atlas-lapack:zpotrs
54                 atlas-lapack:spotri
55                 atlas-lapack:dpotri
56                 atlas-lapack:cpotri
57                 atlas-lapack:zpotri
58                 atlas-lapack:strtri
59                 atlas-lapack:dtrtri
60                 atlas-lapack:ctrtri
61                 atlas-lapack:ztrtri
62                 atlas-lapack:slauum
63                 atlas-lapack:dlauum
64                 atlas-lapack:clauum
65                 atlas-lapack:zlauum
66                 atlas-lapack:sgesv!
67                 atlas-lapack:dgesv!
68                 atlas-lapack:cgesv!
69                 atlas-lapack:zgesv!
70                 atlas-lapack:sposv!
71                 atlas-lapack:dposv!
72                 atlas-lapack:cposv!
73                 atlas-lapack:zposv!
74                 atlas-lapack:sgetrf!
75                 atlas-lapack:dgetrf!
76                 atlas-lapack:cgetrf!
77                 atlas-lapack:zgetrf!
78                 atlas-lapack:sgetrs!
79                 atlas-lapack:dgetrs!
80                 atlas-lapack:cgetrs!
81                 atlas-lapack:zgetrs!
82                 atlas-lapack:sgetri!
83                 atlas-lapack:dgetri!
84                 atlas-lapack:cgetri!
85                 atlas-lapack:zgetri!
86                 atlas-lapack:spotrf!
87                 atlas-lapack:dpotrf!
88                 atlas-lapack:cpotrf!
89                 atlas-lapack:zpotrf!
90                 atlas-lapack:spotrs!
91                 atlas-lapack:dpotrs!
92                 atlas-lapack:cpotrs!
93                 atlas-lapack:zpotrs!
94                 atlas-lapack:spotri!
95                 atlas-lapack:dpotri!
96                 atlas-lapack:cpotri!
97                 atlas-lapack:zpotri!
98                 atlas-lapack:strtri!
99                 atlas-lapack:dtrtri!
100                 atlas-lapack:ctrtri!
101                 atlas-lapack:ztrtri!
102                 atlas-lapack:slauum!
103                 atlas-lapack:dlauum!
104                 atlas-lapack:clauum!
105                 atlas-lapack:zlauum!
106                 unsafe-atlas-lapack:sgesv!
107                 unsafe-atlas-lapack:dgesv!
108                 unsafe-atlas-lapack:cgesv!
109                 unsafe-atlas-lapack:zgesv!
110                 unsafe-atlas-lapack:sposv!
111                 unsafe-atlas-lapack:dposv!
112                 unsafe-atlas-lapack:cposv!
113                 unsafe-atlas-lapack:zposv!
114                 unsafe-atlas-lapack:sgetrf!
115                 unsafe-atlas-lapack:dgetrf!
116                 unsafe-atlas-lapack:cgetrf!
117                 unsafe-atlas-lapack:zgetrf!
118                 unsafe-atlas-lapack:sgetrs!
119                 unsafe-atlas-lapack:dgetrs!
120                 unsafe-atlas-lapack:cgetrs!
121                 unsafe-atlas-lapack:zgetrs!
122                 unsafe-atlas-lapack:sgetri!
123                 unsafe-atlas-lapack:dgetri!
124                 unsafe-atlas-lapack:cgetri!
125                 unsafe-atlas-lapack:zgetri!
126                 unsafe-atlas-lapack:spotrf!
127                 unsafe-atlas-lapack:dpotrf!
128                 unsafe-atlas-lapack:cpotrf!
129                 unsafe-atlas-lapack:zpotrf!
130                 unsafe-atlas-lapack:spotrs!
131                 unsafe-atlas-lapack:dpotrs!
132                 unsafe-atlas-lapack:cpotrs!
133                 unsafe-atlas-lapack:zpotrs!
134                 unsafe-atlas-lapack:spotri!
135                 unsafe-atlas-lapack:dpotri!
136                 unsafe-atlas-lapack:cpotri!
137                 unsafe-atlas-lapack:zpotri!
138                 unsafe-atlas-lapack:strtri!
139                 unsafe-atlas-lapack:dtrtri!
140                 unsafe-atlas-lapack:ctrtri!
141                 unsafe-atlas-lapack:ztrtri!
142                 unsafe-atlas-lapack:slauum!
143                 unsafe-atlas-lapack:dlauum!
144                 unsafe-atlas-lapack:clauum!
145                 unsafe-atlas-lapack:zlauum!))
146
147#>!
148typedef float  CCOMPLEX;
149typedef double ZCOMPLEX;
150
151___declare(substitute,"clapack_;lapack:")
152
153
154/*
155 * Enumerated and derived types from cblas.h
156 */
157
158enum CBLAS_ORDER {CblasRowMajor=101, CblasColMajor=102};
159enum CBLAS_TRANSPOSE {CblasNoTrans=111, CblasTrans=112, CblasConjTrans=113};
160enum CBLAS_UPLO {CblasUpper=121, CblasLower=122};
161enum CBLAS_DIAG {CblasNonUnit=131, CblasUnit=132};
162enum CBLAS_SIDE {CblasLeft=141, CblasRight=142};
163
164/*
165 * ===========================================================================
166 * Prototypes for LAPACK driver routines
167 * ===========================================================================
168 */
169
170/* Driver routines for linear equations */
171
172int clapack_sgesv(const enum CBLAS_ORDER Order, const int N, const int NRHS,
173                  float *A, const int lda, int *ipiv,
174                  float *B, const int ldb);
175
176int clapack_dgesv(const enum CBLAS_ORDER Order, const int N, const int NRHS,
177                  double *A, const int lda, int *ipiv,
178                  double *B, const int ldb);
179
180int clapack_cgesv(const enum CBLAS_ORDER Order, const int N, const int NRHS,
181                  const CCOMPLEX *A, const int lda, int *ipiv,
182                  const CCOMPLEX *B, const int ldb);
183
184int clapack_zgesv(const enum CBLAS_ORDER Order, const int N, const int NRHS,
185                  const ZCOMPLEX *A, const int lda, int *ipiv,
186                  const ZCOMPLEX *B, const int ldb);
187
188int clapack_sposv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
189                  const int N, const int NRHS, float *A, const int lda,
190                  float *B, const int ldb);
191
192int clapack_dposv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
193                  const int N, const int NRHS, double *A, const int lda,
194                  double *B, const int ldb);
195
196int clapack_cposv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
197                  const int N, const int NRHS, const CCOMPLEX *A, const int lda,
198                  const CCOMPLEX *B, const int ldb);
199
200int clapack_zposv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
201                  const int N, const int NRHS, const ZCOMPLEX *A, const int lda,
202                  const ZCOMPLEX *B, const int ldb);
203
204/*
205 * ===========================================================================
206 * Prototypes for LAPACK computational routines
207 * ===========================================================================
208 */
209
210/* Computational routines for  linear equations */
211
212/* General matrix factorize */
213
214int clapack_sgetrf(const enum CBLAS_ORDER Order, const int M, const int N,
215                   float *A, const int lda, int *ipiv);
216
217int clapack_dgetrf(const enum CBLAS_ORDER Order, const int M, const int N,
218                   double *A, const int lda, int *ipiv);
219
220int clapack_cgetrf(const enum CBLAS_ORDER Order, const int M, const int N,
221                   const CCOMPLEX *A, const int lda, int *ipiv);
222
223int clapack_zgetrf(const enum CBLAS_ORDER Order, const int M, const int N,
224                   const ZCOMPLEX *A, const int lda, int *ipiv);
225
226/* General matrix solve using factorization */
227
228int clapack_sgetrs(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE Trans,
229                         const int N, const int NRHS, const float *A, const int lda,
230                         const int *ipiv, float *B, const int ldb);
231
232int clapack_dgetrs
233   (const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE Trans,
234    const int N, const int NRHS, const double *A, const int lda,
235    const int *ipiv, double *B, const int ldb);
236
237int clapack_cgetrs
238   (const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE Trans,
239    const int N, const int NRHS, const const CCOMPLEX *A, const int lda,
240    const int *ipiv, const CCOMPLEX *B, const int ldb);
241
242int clapack_zgetrs
243   (const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE Trans,
244    const int N, const int NRHS, const const ZCOMPLEX *A, const int lda,
245    const int *ipiv, const ZCOMPLEX *B, const int ldb);
246
247
248/* General matrix invert using factorization */
249
250int clapack_sgetri(const enum CBLAS_ORDER Order, const int N, float *A,
251                   const int lda, const int *ipiv);
252
253int clapack_dgetri(const enum CBLAS_ORDER Order, const int N, double *A,
254                   const int lda, const int *ipiv);
255
256int clapack_cgetri(const enum CBLAS_ORDER Order, const int N, const CCOMPLEX *A,
257                   const int lda, const int *ipiv);
258
259int clapack_zgetri(const enum CBLAS_ORDER Order, const int N, const ZCOMPLEX *A,
260                   const int lda, const int *ipiv);
261
262/* Symmetric/hermitian positive definite matrix factorize */
263
264int clapack_spotrf(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
265                   const int N, float *A, const int lda);
266
267int clapack_dpotrf(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
268                   const int N, double *A, const int lda);
269
270int clapack_cpotrf(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
271                   const int N, const CCOMPLEX *A, const int lda);
272
273int clapack_zpotrf(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
274                   const int N, const ZCOMPLEX *A, const int lda);
275
276/* Symmetric/hermitian positive definite matrix solve using factorization */
277
278int clapack_spotrs(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
279                   const int N, const int NRHS, const float *A, const int lda,
280                   float *B, const int ldb);
281
282int clapack_dpotrs(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
283                   const int N, const int NRHS, const double *A, const int lda,
284                   double *B, const int ldb);
285
286int clapack_cpotrs(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
287                   const int N, const int NRHS, const const CCOMPLEX *A, const int lda,
288                   const CCOMPLEX *B, const int ldb);
289
290int clapack_zpotrs(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
291                   const int N, const int NRHS, const const ZCOMPLEX *A, const int lda,
292                   const ZCOMPLEX *B, const int ldb);
293
294
295/* Symmetric/hermitian positive definite matrix invert using factorization */
296
297int clapack_spotri(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
298                   const int N, float *A, const int lda);
299
300int clapack_dpotri(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
301                   const int N, double *A, const int lda);
302
303int clapack_cpotri(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
304                   const int N, const CCOMPLEX *A, const int lda);
305
306int clapack_zpotri(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
307                   const int N, const ZCOMPLEX *A, const int lda);
308
309
310/* Triangular matrix invert */
311
312int clapack_strtri(const enum CBLAS_ORDER Order,const enum CBLAS_UPLO Uplo,
313                  const enum CBLAS_DIAG Diag,const int N, float *A, const int lda);
314
315int clapack_dtrtri(const enum CBLAS_ORDER Order,const enum CBLAS_UPLO Uplo,
316                  const enum CBLAS_DIAG Diag,const int N, double *A, const int lda);
317
318int clapack_ctrtri(const enum CBLAS_ORDER Order,const enum CBLAS_UPLO Uplo,
319                  const enum CBLAS_DIAG Diag,const int N, const CCOMPLEX *A, const int lda);
320
321int clapack_ztrtri(const enum CBLAS_ORDER Order,const enum CBLAS_UPLO Uplo,
322                   const enum CBLAS_DIAG Diag,const int N, const ZCOMPLEX *A, const int lda);
323
324/* Auxilliary routines  */
325
326int clapack_slauum(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
327                   const int N, float *A, const int lda);
328
329int clapack_clauum(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
330                   const int N, const CCOMPLEX *A, const int lda);
331
332int clapack_dlauum(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
333                   const int N, double *A, const int lda);
334
335int clapack_zlauum(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo,
336                   const int N, const ZCOMPLEX *A, const int lda);
337
338<#
339
340(define (atlas-lapack:error x . rest)
341  (let ((port (open-output-string)))
342    (let loop ((objs (if (symbol? x) rest (cons x rest))))
343      (if (null? objs)
344          (begin
345            (newline port)
346            (error (if (symbol? x) x 'atlas-lapack) 
347                   (get-output-string port)))
348          (begin (display (car objs) port)
349                 (display " " port)
350                 (loop (cdr objs)))))))
351
352
353(define-macro (lapack-wrap fn ret errs vsize copy)
354  (let ((fname (string->symbol (conc (if vsize "atlas-" "unsafe-atlas-") 
355                                     (symbol->string (car fn))
356                                     (if copy "" "!"))))
357        (args  (reverse (cdr fn))))
358    `(define ,(let loop ((args args) (sig 'rest))
359                (if (null? args) (cons fname sig)
360                    (let ((x (car args)))
361                      (let ((sig (case x 
362                                   ((opiv) sig)
363                                   ((lda)  sig)
364                                   ((ldb)  sig)
365                                   (else   (cons x sig)))))
366                        (loop (cdr args) sig)))))
367     (let-optionals rest ,(if (memq 'ldb fn)
368                              `((lda ,(if (memq 'm fn) 'm 'n)) (ldb ,(if (memq 'm fn) 'm 'n)))
369                              `((lda ,(if (memq 'm fn) 'm 'n))))
370     ,(if vsize
371          `(begin
372             (let ((asize (,vsize a)))
373               ,(if (memq 'm fn) 
374                    `(if (< asize (fx* m n)) 
375                         (atlas-lapack:error ',fname (conc "matrix A is allocated " asize " elements "
376                                                           "but given dimensions are " m " by " n)))
377                    `(if (< asize (fx* n n)) 
378                         (atlas-lapack:error ',fname (conc "matrix A is allocated " asize " elements "
379                                                           "but given dimensions are " n " by " n)))))
380            ,(if (memq 'b fn)
381                 `(let ((bsize (,vsize b)))
382                    ,(if (memq 'nrhs fn) 
383                         `(if (< bsize (fx* nrhs n)) 
384                              (atlas-lapack:error ',fname (conc "matrix B is allocated " bsize " elements "
385                                                                "but given dimensions are " n " by " nrhs)))
386                         `(if (< bsize (fx* n 1)) 
387                              (atlas-lapack:error ,fname (conc "matrix B is allocated " bsize " elements "
388                                                               "but given dimensions are " n " by " 1)))))
389                 `(noop)))
390          `(noop))
391     (let ,(let loop ((fn fn) (bnds '()))
392             (if (null? fn) bnds
393                 (let ((x (car fn)))
394                   (let ((bnds (case x 
395                                 ((opiv)  (cons `(opiv (make-s32vector n)) bnds))
396                                 (else    (if (and copy (memq x ret))
397                                              (cons `(,x (,copy ,x)) bnds)
398                                              bnds)))))
399                     (loop (cdr fn) bnds)))))
400       (let ((info ,fn))
401           (cond ((= info 0) (values . ,ret))
402                 ((< info 0) (atlas-lapack:error ',fname (,(car errs) info)))
403                 ((> info 0) (atlas-lapack:error ',fname (,(cadr errs) info))))))))))
404
405
406(define-macro (lapack-wrapx fn ret errs)
407  `(begin
408     (lapack-wrap ,(cons (string->symbol (conc "lapack:s" (symbol->string (car fn)))) (cdr fn))
409                   ,ret ,errs #f #f)
410     (lapack-wrap ,(cons (string->symbol (conc "lapack:d" (symbol->string (car fn)))) (cdr fn))
411                   ,ret ,errs #f #f)
412     (lapack-wrap ,(cons (string->symbol (conc "lapack:c" (symbol->string (car fn)))) (cdr fn))
413                   ,ret ,errs #f #f)
414     (lapack-wrap ,(cons (string->symbol (conc "lapack:z" (symbol->string (car fn)))) (cdr fn))
415                   ,ret ,errs #f #f)
416
417     (lapack-wrap ,(cons (string->symbol (conc "lapack:s" (symbol->string (car fn)))) (cdr fn))
418                   ,ret ,errs f32vector-length #f)
419     (lapack-wrap ,(cons (string->symbol (conc "lapack:d" (symbol->string (car fn)))) (cdr fn))
420                   ,ret ,errs f64vector-length #f)
421     (lapack-wrap ,(cons (string->symbol (conc "lapack:c" (symbol->string (car fn)))) (cdr fn))
422                   ,ret ,errs (lambda (v) (fx/ 2 (f32vector-length v))) #f)
423     (lapack-wrap ,(cons (string->symbol (conc "lapack:z" (symbol->string (car fn)))) (cdr fn))
424                    ,ret ,errs (lambda (v) (fx/ 2 (f64vector-length v))) #f)
425
426     (lapack-wrap ,(cons (string->symbol (conc "lapack:s" (symbol->string (car fn)))) (cdr fn))
427                   ,ret ,errs f32vector-length  blas:scopy)
428     (lapack-wrap ,(cons (string->symbol (conc "lapack:d" (symbol->string (car fn)))) (cdr fn))
429                   ,ret ,errs f64vector-length  blas:dcopy)
430     (lapack-wrap ,(cons (string->symbol (conc "lapack:c" (symbol->string (car fn)))) (cdr fn))
431                  ,ret ,errs (lambda (v) (fx/ 2 (f32vector-length v))) blas:ccopy)
432     (lapack-wrap ,(cons (string->symbol (conc "lapack:z" (symbol->string (car fn)))) (cdr fn))
433                  ,ret ,errs (lambda (v) (fx/ 2 (f64vector-length v))) blas:zcopy)))
434
435
436
437(lapack-wrapx (gesv order n nrhs a lda opiv b ldb)
438               (a b opiv)
439               ((lambda (i) (conc i "-th argument had an illegal value"))
440                (lambda (i) "upper triangular matrix is singular")))
441     
442
443(lapack-wrapx (posv order uplo n nrhs a lda b ldb)
444               (a b)
445               ((lambda (i) (conc i "-th argument had an illegal value"))
446                (lambda (i) (conc "leading minor of order " i 
447                                  " of A is not positive definite"))))
448
449(lapack-wrapx (getrf order m n a lda opiv)
450               (a opiv)
451               ((lambda (i) (conc i "-th argument had an illegal value"))
452                (lambda (i) "factor U is singular")))
453     
454(lapack-wrapx (getrs order trans n nrhs a lda ipiv b ldb)
455               (b)
456               ((lambda (i) (conc i "-th argument had an illegal value"))
457                (lambda (i) "unknown error")))
458
459(lapack-wrapx (getri order n a lda ipiv)
460               (a)
461               ((lambda (i) (conc i "-th argument had an illegal value"))
462                (lambda (i) "factor U is singular")))
463
464(lapack-wrapx (potrf order uplo n a lda)
465               (a)
466               ((lambda (i) (conc i "-th argument had an illegal value"))
467                (lambda (i) (conc "leading minor of order " i " is not positive definite"))))
468
469(lapack-wrapx (potrs order uplo n nrhs a lda b ldb)
470               (b)
471               ((lambda (i) (conc i "-th argument had an illegal value"))
472                (lambda (i) "unknown error")))
473               
474(lapack-wrapx (potri order uplo n  a lda)
475               (a)
476               ((lambda (i) (conc i "-th argument had an illegal value"))
477                (lambda (i) (conc "element " "(" i "," i")" " of factor U or L is zero"))))
478               
479(lapack-wrapx (trtri order uplo diag n a lda)
480               (a)
481               ((lambda (i) (conc i "-th argument had an illegal value"))
482                (lambda (i) "the triangular matrix is singular")))
483
484(lapack-wrapx (lauum order uplo n a lda)
485               (a)
486               ((lambda (i) (conc i "-th argument had an illegal value"))
487                (lambda (i) "unknown error")))
488
Note: See TracBrowser for help on using the repository browser.