1 | ;; |
---|
2 | ;; |
---|
3 | ;; Chicken MPI regression test |
---|
4 | ;; |
---|
5 | ;; Based on the Caml/MPI interface by Xavier Leroy. |
---|
6 | ;; |
---|
7 | ;; Copyright 2007-2009 Ivan Raikov and the Okinawa Institute of Science and Technology |
---|
8 | ;; |
---|
9 | ;; This program is free software: you can redistribute it and/or |
---|
10 | ;; modify it under the terms of the GNU General Public License as |
---|
11 | ;; published by the Free Software Foundation, either version 3 of the |
---|
12 | ;; License, or (at your option) any later version. |
---|
13 | ;; |
---|
14 | ;; This program is distributed in the hope that it will be useful, but |
---|
15 | ;; WITHOUT ANY WARRANTY; without even the implied warranty of |
---|
16 | ;; MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU |
---|
17 | ;; General Public License for more details. |
---|
18 | ;; |
---|
19 | ;; A full copy of the GPL license can be found at |
---|
20 | ;; <http://www.gnu.org/licenses/>. |
---|
21 | ;; |
---|
22 | |
---|
(require-extension posix srfi-1 srfi-4 srfi-13 srfi-14 mpi test)
---|
24 | |
---|
;; Variadic logical AND as a procedure (unlike the `and` special form,
;; it can be passed to `apply`).  Returns #t when no argument is #f
;; (including the zero-argument case), otherwise #f.
(define (land . args)
  (let scan ((remaining args))
    (cond ((null? remaining) #t)
          ((car remaining) (scan (cdr remaining)))
          (else #f))))
---|
28 | |
---|
;; Variadic logical OR as a procedure (usable with `apply`).
;; Returns the first argument that is not #f, or #f if there is none
;; (including the zero-argument case).
(define (lor . args)
  (let scan ((remaining args))
    (cond ((null? remaining) #f)
          ((car remaining))              ; a test-only clause yields the true value itself
          (else (scan (cdr remaining))))))
---|
32 | |
---|
;; Apply the Scheme procedure corresponding to the MPI reduction constant
;; OP to the operand list ARGS and return the result.
;;   OP   - one of MPI:i_max, MPI:i_min, MPI:i_sum, MPI:i_prod, MPI:i_land,
;;          MPI:i_lor, MPI:i_xor, MPI:f_max, MPI:f_min, MPI:f_sum, MPI:f_prod
;;   ARGS - list of operands
;; Signals an error for any other OP.
;; NOTE(review): land/lor follow Scheme truthiness (only #f is false, so 0
;; counts as true); confirm this matches the mpi egg's i_land / i_lor
;; semantics before comparing against real reduction results.
(define (eval-op op args)
  (apply
   (cond ((= op MPI:i_max) max)
         ((= op MPI:i_min) min)
         ((= op MPI:i_sum) +)
         ((= op MPI:i_prod) *)
         ((= op MPI:i_land) land)
         ((= op MPI:i_lor) lor)
         ;; fxxor accepts exactly two fixnums, so applying it directly to
         ;; ARGS fails for any other arity; fold it instead (0 is the
         ;; identity element for xor).
         ((= op MPI:i_xor) (lambda operands (fold fxxor 0 operands)))
         ((= op MPI:f_max) max)
         ((= op MPI:f_min) min)
         ((= op MPI:f_sum) +)
         ((= op MPI:f_prod) *)
         (else (error 'eval-op "unknown op " op)))
   args))
---|
48 | |
---|
;; Return a fresh blob holding bytes I (inclusive) through J (exclusive)
;; of blob X, by round-tripping through a string.
(define (blob-range x i j)
  (let ((bytes (blob->string x)))
    (string->blob (substring bytes i j))))
---|
51 | |
---|
;; Build an element-wise map procedure for one SRFI-4 vector type from
;; its constructor MAKEV, length VLEN, setter VSET! and accessor VREF.
;; The result (lambda (v f) ...) returns a fresh vector whose k-th
;; element is (f (vref v k)); F is applied from the last index down to 0.
(define (make-srfi4-vector-map makev vlen vset! vref)
  (lambda (v f)
    (let* ((len (vlen v))
           (out (makev len)))
      (do ((k (- len 1) (- k 1)))
          ((< k 0) out)
        (vset! out k (f (vref v k)))))))
---|
61 | |
---|
;; Build a sub-range extractor for one SRFI-4 vector type from its
;; constructor MAKEV, length VLEN, setter VSET! and accessor VREF.
;; The result (lambda (v i j) ...) returns a fresh vector holding elements
;; I (inclusive) through J (exclusive) of V, or #f when the bounds are
;; invalid, i.e. unless 0 <= i <= j <= (vlen v).  An empty range (i = j)
;; yields an empty vector.
(define (make-srfi4-vector-range makev vlen vset! vref)
  (lambda (v i j)
    ;; Validate both ends against the real length: the previous check
    ;; (< (- j i) (vlen v)) let J run past the end (out-of-bounds vref)
    ;; and rejected the full range [0, len).
    (and (<= 0 i j (vlen v))
         (let ((newv (makev (- j i))))
           (do ((src i (+ src 1))
                (dst 0 (+ dst 1)))
               ((= src j) newv)
             (vset! newv dst (vref v src)))))))
---|
71 | |
---|
72 | |
---|
;; Element-wise map for every SRFI-4 vector type used by the tests:
;;   (<type>vector-map v f) => fresh vector whose k-th element is (f (vref v k)).
;; Each is make-srfi4-vector-map specialised with that type's
;; constructor, length, setter and accessor.
(define s8vector-map
  (make-srfi4-vector-map make-s8vector s8vector-length s8vector-set! s8vector-ref))
(define u8vector-map
  (make-srfi4-vector-map make-u8vector u8vector-length u8vector-set! u8vector-ref))

(define s16vector-map
  (make-srfi4-vector-map make-s16vector s16vector-length s16vector-set! s16vector-ref))
(define u16vector-map
  (make-srfi4-vector-map make-u16vector u16vector-length u16vector-set! u16vector-ref))

(define s32vector-map
  (make-srfi4-vector-map make-s32vector s32vector-length s32vector-set! s32vector-ref))
(define u32vector-map
  (make-srfi4-vector-map make-u32vector u32vector-length u32vector-set! u32vector-ref))

(define f32vector-map
  (make-srfi4-vector-map make-f32vector f32vector-length f32vector-set! f32vector-ref))
(define f64vector-map
  (make-srfi4-vector-map make-f64vector f64vector-length f64vector-set! f64vector-ref))
---|
112 | |
---|
113 | |
---|
114 | |
---|
;; Sub-range extraction for every SRFI-4 vector type used by the tests:
;;   (<type>vector-range v i j) => fresh vector of elements [i, j) of V,
;;   or #f when make-srfi4-vector-range rejects the bounds.
;; Each is make-srfi4-vector-range specialised with that type's
;; constructor, length, setter and accessor.
(define s8vector-range
  (make-srfi4-vector-range make-s8vector s8vector-length s8vector-set! s8vector-ref))
(define u8vector-range
  (make-srfi4-vector-range make-u8vector u8vector-length u8vector-set! u8vector-ref))

(define s16vector-range
  (make-srfi4-vector-range make-s16vector s16vector-length s16vector-set! s16vector-ref))
(define u16vector-range
  (make-srfi4-vector-range make-u16vector u16vector-length u16vector-set! u16vector-ref))

(define s32vector-range
  (make-srfi4-vector-range make-s32vector s32vector-length s32vector-set! s32vector-ref))
(define u32vector-range
  (make-srfi4-vector-range make-u32vector u32vector-length u32vector-set! u32vector-ref))

(define f32vector-range
  (make-srfi4-vector-range make-f32vector f32vector-length f32vector-set! f32vector-ref))
(define f64vector-range
  (make-srfi4-vector-range make-f64vector f64vector-length f64vector-set! f64vector-ref))
---|
155 | |
---|
;; Log RANK, N and SIZE, then check that string N has exactly SIZE+1
;; characters and that every one of them is the character C.
;; Returns a true value on success, #f otherwise.
(define (check-string rank n c size)
  (print "rank = " rank " n = " n " size = " size)
  (let ((chars (string->list n)))
    (and (= (length chars) (+ size 1))
         (every (lambda (ch) (char=? c ch)) chars))))
---|
160 | |
---|
;; Start MPI, report the host, and build the fixtures that the test
;; sections read.  Element i of each per-rank list below belongs to rank i.
(MPI:init)

(print "Host " (get-host-name))

(define comm-world (MPI:get-comm-world))
(define size (MPI:comm-size comm-world))
(define myrank (MPI:comm-rank comm-world))

;; Element count of the fixed-length payloads (vsdata, vintdata, vflodata).
(define vsize 3)

;; Scalars for rank i: 10*i and 0.1*i.
(define intdata (iota size 0 10))
(define flodata (list-tabulate size (lambda (i) (* 0.1 i))))

;; Strings made of rank i's letter (#\a for rank 0, #\b for rank 1, ...):
;; vsize characters long in vsdata, i+1 characters long in vvsdata.
(define vsdata
  (list-tabulate size (lambda (i) (make-string vsize (integer->char (+ i 97))))))
(define vvsdata
  (list-tabulate size (lambda (i) (make-string (+ i 1) (integer->char (+ i 97))))))

;; Numeric payloads for rank i: fixed length vsize (v*) or length i+1 (vv*).
;; Integer elements are 10*i, 10*i+1, ...; float elements are i, i+0.1, ...
(define vintdata
  (list-tabulate size (lambda (i) (iota vsize (* 10 i)))))
(define vflodata
  (list-tabulate size (lambda (i) (list-tabulate vsize (lambda (j) (+ i (* 0.1 j)))))))
(define vvintdata
  (list-tabulate size (lambda (i) (iota (+ i 1) (* 10 i)))))
(define vvflodata
  (list-tabulate size (lambda (i) (list-tabulate (+ i 1) (lambda (j) (+ i (* 0.1 j)))))))
---|
181 | |
---|
182 | |
---|
;; Main regression suite.  Every rank evaluates the whole group; most
;; sections end with an MPI:barrier so the ranks move between sections
;; together.  The ring sections send to rank 1, so at least 2 ranks are
;; needed; the Cartesian section builds a 2x2 grid.
(test-group "MPI test"

  ;; Point-to-point ring: rank 0 sends "aa" to rank 1, every other rank
  ;; appends one "a" and forwards to (myrank+1) mod size, so rank r sees
  ;; r+1 characters and rank 0 gets size+1 characters back.
  (if (zero? myrank)
      (let ((data "aa"))
        (print myrank ": sending " data)
        (MPI:send (string->blob data) 1 0 comm-world)
        (let ((n (blob->string (MPI:receive MPI:any-source MPI:any-tag comm-world))))
          (print myrank ": received " n)
          (test-assert (check-string myrank n #\a size))))
      (let* ((n (blob->string (MPI:receive MPI:any-source MPI:any-tag comm-world)))
             (n1 (string-append n "a")))
        (print myrank ": received " n ", resending " n1)
        (MPI:send (string->blob n1) (modulo (+ myrank 1) size) 0 comm-world)
        (test-assert (check-string myrank n #\a myrank))))

  ;; Barrier
  (MPI:barrier comm-world)

  ;; Tagged ring: "aa" (tag 0) and "bb" (tag 1) travel the ring together.
  ;; Non-root ranks receive a tag-0 message first, then any message, and
  ;; forward both with their original tags after appending one character.
  (if (zero? myrank)
      (let ((data1 "aa")
            (data2 "bb"))
        (print myrank ": sending (tag 0) " data1)
        (MPI:send (string->blob data1) 1 0 comm-world)
        (print myrank ": sending (tag 1) " data2)
        (MPI:send (string->blob data2) 1 1 comm-world)
        (let-values (((n src tag) (MPI:receive-with-status MPI:any-source MPI:any-tag comm-world)))
          (print myrank ": received " (blob->string n) " (tag " tag ")" " from " src)
          (if (zero? tag)
              (test-assert (check-string myrank (blob->string n) #\a size))
              (test-assert (check-string myrank (blob->string n) #\b size)))
          (let-values (((n src tag) (MPI:receive-with-status MPI:any-source MPI:any-tag comm-world)))
            (print myrank ": received " (blob->string n) " (tag " tag ")" " from " src)
            (if (zero? tag)
                (test-assert (check-string myrank (blob->string n) #\a size))
                (test-assert (check-string myrank (blob->string n) #\b size))))))
      (let-values (((n1 src tag1) (MPI:receive-with-status MPI:any-source 0 comm-world)))
        (let* ((n1 (blob->string n1))
               (nn1 (if (zero? tag1) (string-append n1 "a") (string-append n1 "b"))))
          (print myrank ": received " n1 " (tag " tag1 ")" " from " src
                 ", resending " nn1)
          (if (zero? tag1)
              (test-assert (check-string myrank n1 #\a myrank))
              (test-assert (check-string myrank n1 #\b myrank)))
          (let-values (((n2 src tag2) (MPI:receive-with-status MPI:any-source MPI:any-tag comm-world)))
            (let* ((n2 (blob->string n2))
                   (nn2 (if (zero? tag2) (string-append n2 "a") (string-append n2 "b"))))
              (if (zero? tag2)
                  (test-assert (check-string myrank n2 #\a myrank))
                  (test-assert (check-string myrank n2 #\b myrank)))
              (print myrank ": received " n2 " (tag " tag2 ")" " from " src
                     ", resending " nn2)
              (MPI:send (string->blob nn1) (modulo (+ 1 myrank) size) tag1 comm-world)
              (MPI:send (string->blob nn2) (modulo (+ 1 myrank) size) tag2 comm-world))))))

  ;; Barrier
  (MPI:barrier comm-world)

  ;; Typed point-to-point: rank 0 sends element k of DATA to rank k+1.
  ;; Each receiver checks membership in DATA, applies TRANSF and sends the
  ;; result back, and rank 0 checks every reply against (map transf data).
  (let ((test-send-recv
         (lambda (sendfun recvfun transf data)
           (if (zero? myrank)
               (begin
                 (print myrank ": test-send-recv: data = " data)
                 (print myrank ": test-send-recv: size = " size)
                 (let loop ((lst data) (i 1))
                   (if (and (not (null? lst)) (< i size))
                       (begin
                         (print myrank ": sending " (car lst) " to " i)
                         (sendfun (car lst) i 0 comm-world)
                         (loop (cdr lst) (+ 1 i)))))
                 (let loop ((i size))
                   (if (positive? (- i 1))
                       (let ((x (recvfun (- i 1) 0 comm-world)))
                         (print myrank ": received " x)
                         (test-assert (any (lambda (y) (equal? x y)) (map transf data)))
                         (loop (- i 1))))))
               (let ((x (recvfun 0 0 comm-world)))
                 (print myrank ": received " x)
                 (test-assert (member x data))
                 (let ((y (transf x)))
                   (sendfun y 0 0 comm-world))))
           (MPI:barrier comm-world))))
    (test-send-recv MPI:send-fixnum MPI:receive-fixnum (lambda (x) (+ 1 x)) intdata)
    (test-send-recv MPI:send-int MPI:receive-int (lambda (x) (+ 1 x)) intdata)
    (test-send-recv MPI:send-flonum MPI:receive-flonum (lambda (x) (* 2 x)) flodata)
    (let ((srfi4-test-send-recv
           (lambda (len vsend vreceive vmap list->vector)
             (lambda (data)
               (test-send-recv vsend
                               (lambda (src tag comm) (vreceive len src tag comm))
                               (lambda (v) (vmap v (lambda (x) (+ 1 x))))
                               (map list->vector data))))))
      ((srfi4-test-send-recv vsize MPI:send-u8vector MPI:receive-u8vector u8vector-map list->u8vector)
       vintdata)
      ((srfi4-test-send-recv vsize MPI:send-s8vector MPI:receive-s8vector s8vector-map list->s8vector)
       vintdata)
      ((srfi4-test-send-recv vsize MPI:send-u16vector MPI:receive-u16vector u16vector-map list->u16vector)
       vintdata)
      ((srfi4-test-send-recv vsize MPI:send-s16vector MPI:receive-s16vector s16vector-map list->s16vector)
       vintdata)
      ((srfi4-test-send-recv vsize MPI:send-u32vector MPI:receive-u32vector u32vector-map list->u32vector)
       vintdata)
      ((srfi4-test-send-recv vsize MPI:send-s32vector MPI:receive-s32vector s32vector-map list->s32vector)
       vintdata)
      ((srfi4-test-send-recv vsize MPI:send-f32vector MPI:receive-f32vector f32vector-map list->f32vector)
       vflodata)
      ((srfi4-test-send-recv vsize MPI:send-f64vector MPI:receive-f64vector f64vector-map list->f64vector)
       vflodata)))

  ;; Staggered barrier: non-root ranks sleep myrank seconds before
  ;; reaching the barrier.
  (begin
    (if (positive? myrank)
        (sleep myrank))
    (print myrank ": hitting barrier")
    (MPI:barrier comm-world)
    (if (zero? myrank)
        (print "jumped barrier")))

  ;; Broadcast: rank 0's value must arrive unchanged on every rank.
  (let* ((test-broadcast
          (lambda (bcast data)
            (if (zero? myrank)
                (print myrank ": broadcasting " data))
            (let ((res (bcast data 0 comm-world)))
              (print myrank ": received " (if (blob? res) (blob->string res) res))
              (test-assert (equal? res data))
              (MPI:barrier comm-world)))))
    (test-broadcast MPI:broadcast-bytevector (string->blob "Hello!"))
    (test-broadcast MPI:broadcast-int 123456)
    (test-broadcast MPI:broadcast-flonum 3.141592654)
    (let ((intdata (list 12 45 78))
          (flodata (list 3.14 2.718 0.578))
          (srfi4-test-broadcast
           (lambda (bcast list->vector data)
             (test-broadcast bcast (list->vector data)))))
      (srfi4-test-broadcast MPI:broadcast-s8vector list->s8vector intdata)
      (srfi4-test-broadcast MPI:broadcast-u8vector list->u8vector intdata)
      (srfi4-test-broadcast MPI:broadcast-s16vector list->s16vector intdata)
      (srfi4-test-broadcast MPI:broadcast-u16vector list->u16vector intdata)
      (srfi4-test-broadcast MPI:broadcast-s32vector list->s32vector intdata)
      (srfi4-test-broadcast MPI:broadcast-u32vector list->u32vector intdata)
      (srfi4-test-broadcast MPI:broadcast-f32vector list->f32vector flodata)
      (srfi4-test-broadcast MPI:broadcast-f64vector list->f64vector flodata)))

  ;; Scatter: rank r must receive elements [r*vsize, (r+1)*vsize) of DATA.
  ;; The per-rank count is vsize, the same length the fixtures and the
  ;; expected slice below use.
  (let* ((test-scatter
          (lambda (scatter vrange data)
            (if (zero? myrank)
                (print myrank ": scatter " (if (blob? data) (blob->string data) data)))
            (let ((res (scatter data vsize 0 comm-world)))
              (print myrank ": received (scatter) " (if (blob? res) (blob->string res) res))
              (test-assert
               (equal? res (vrange data (* myrank vsize) (+ vsize (* myrank vsize))))))
            (MPI:barrier comm-world))))
    (test-scatter MPI:scatter-bytevector blob-range (string->blob (string-concatenate vsdata)))
    (let ((srfi4-test-scatter
           (lambda (scatter vrange list->vector data)
             (test-scatter scatter vrange (list->vector (concatenate data))))))
      (srfi4-test-scatter MPI:scatter-s8vector s8vector-range list->s8vector vintdata)
      (srfi4-test-scatter MPI:scatter-u8vector u8vector-range list->u8vector vintdata)
      (srfi4-test-scatter MPI:scatter-s16vector s16vector-range list->s16vector vintdata)
      (srfi4-test-scatter MPI:scatter-u16vector u16vector-range list->u16vector vintdata)
      (srfi4-test-scatter MPI:scatter-s32vector s32vector-range list->s32vector vintdata)
      (srfi4-test-scatter MPI:scatter-u32vector u32vector-range list->u32vector vintdata)
      (srfi4-test-scatter MPI:scatter-f32vector f32vector-range list->f32vector vflodata)
      (srfi4-test-scatter MPI:scatter-f64vector f64vector-range list->f64vector vflodata)))

  ;; Scatterv: rank r must receive element r of the variable-length list.
  (let* ((test-scatterv
          (lambda (scatterv data)
            (if (zero? myrank)
                (print myrank ": scatterv " data))
            (let ((res (scatterv data 0 comm-world)))
              (print myrank ": received (scatterv) " res)
              (test res (list-ref data myrank)))
            (MPI:barrier comm-world))))
    (test-scatterv MPI:scatterv-bytevector (map string->blob vvsdata))
    (let ((srfi4-test-scatterv
           (lambda (scatterv list->vector data)
             (test-scatterv scatterv (map list->vector data)))))
      (srfi4-test-scatterv MPI:scatterv-s8vector list->s8vector vvintdata)
      (srfi4-test-scatterv MPI:scatterv-u8vector list->u8vector vvintdata)
      (srfi4-test-scatterv MPI:scatterv-s16vector list->s16vector vvintdata)
      (srfi4-test-scatterv MPI:scatterv-u16vector list->u16vector vvintdata)
      (srfi4-test-scatterv MPI:scatterv-s32vector list->s32vector vvintdata)
      (srfi4-test-scatterv MPI:scatterv-u32vector list->u32vector vvintdata)
      (srfi4-test-scatterv MPI:scatterv-f32vector list->f32vector vvflodata)
      (srfi4-test-scatterv MPI:scatterv-f64vector list->f64vector vvflodata)))

  ;; Gather: rank 0 must receive the concatenation of every rank's
  ;; vsize-element slice.
  (let* ((test-gather
          (lambda (gather data total)
            (print myrank ": gather " (if (blob? data) (blob->string data) data))
            (let ((res (gather data vsize 0 comm-world)))
              (if (zero? myrank)
                  (begin
                    (print myrank ": received (gather) " (if (blob? res) (blob->string res) res))
                    (test res total))))
            (MPI:barrier comm-world))))
    (test-gather MPI:gather-bytevector (string->blob (list-ref vsdata myrank))
                 (string->blob (string-concatenate vsdata)))
    (test-gather MPI:gather-s8vector (list->s8vector (list-ref vintdata myrank))
                 (list->s8vector (concatenate vintdata)))
    (test-gather MPI:gather-u8vector (list->u8vector (list-ref vintdata myrank))
                 (list->u8vector (concatenate vintdata)))
    (test-gather MPI:gather-s16vector (list->s16vector (list-ref vintdata myrank))
                 (list->s16vector (concatenate vintdata)))
    (test-gather MPI:gather-u16vector (list->u16vector (list-ref vintdata myrank))
                 (list->u16vector (concatenate vintdata)))
    (test-gather MPI:gather-s32vector (list->s32vector (list-ref vintdata myrank))
                 (list->s32vector (concatenate vintdata)))
    (test-gather MPI:gather-u32vector (list->u32vector (list-ref vintdata myrank))
                 (list->u32vector (concatenate vintdata)))
    (test-gather MPI:gather-f32vector (list->f32vector (list-ref vflodata myrank))
                 (list->f32vector (concatenate vflodata)))
    (test-gather MPI:gather-f64vector (list->f64vector (list-ref vflodata myrank))
                 (list->f64vector (concatenate vflodata))))

  ;; Gatherv: rank 0 must receive the list of every rank's variable-length
  ;; payload.
  (let* ((test-gatherv
          (lambda (gatherv data total)
            (print myrank ": gatherv " (if (blob? data) (blob->string data) data))
            (let ((res (gatherv data 0 comm-world)))
              (if (zero? myrank)
                  (begin
                    (print myrank ": received (gatherv) "
                           (map (lambda (x) (if (blob? x) (blob->string x) x)) res))
                    (test res total))))
            (MPI:barrier comm-world))))
    (test-gatherv MPI:gatherv-bytevector (string->blob (list-ref vvsdata myrank))
                  (map string->blob vvsdata))
    (test-gatherv MPI:gatherv-s8vector (list->s8vector (list-ref vvintdata myrank))
                  (map list->s8vector vvintdata))
    (test-gatherv MPI:gatherv-u8vector (list->u8vector (list-ref vvintdata myrank))
                  (map list->u8vector vvintdata))
    (test-gatherv MPI:gatherv-s16vector (list->s16vector (list-ref vvintdata myrank))
                  (map list->s16vector vvintdata))
    (test-gatherv MPI:gatherv-u16vector (list->u16vector (list-ref vvintdata myrank))
                  (map list->u16vector vvintdata))
    (test-gatherv MPI:gatherv-s32vector (list->s32vector (list-ref vvintdata myrank))
                  (map list->s32vector vvintdata))
    (test-gatherv MPI:gatherv-u32vector (list->u32vector (list-ref vvintdata myrank))
                  (map list->u32vector vvintdata))
    (test-gatherv MPI:gatherv-f32vector (list->f32vector (list-ref vvflodata myrank))
                  (map list->f32vector vvflodata))
    (test-gatherv MPI:gatherv-f64vector (list->f64vector (list-ref vvflodata myrank))
                  (map list->f64vector vvflodata)))

  ;; Gather to all: every rank must receive the list of all payloads.
  (let* ((test-allgather
          (lambda (allgather data total)
            (print myrank ": allgather " data)
            (let ((res (allgather data 0 comm-world)))
              (print myrank ": received (allgather) "
                     (map (lambda (x) (if (blob? x) (blob->string x) x)) res))
              (test res total))
            (MPI:barrier comm-world))))
    (test-allgather MPI:allgather-bytevector (string->blob (list-ref vvsdata myrank))
                    (map string->blob vvsdata))
    (test-allgather MPI:allgather-s8vector (list->s8vector (list-ref vvintdata myrank))
                    (map list->s8vector vvintdata))
    (test-allgather MPI:allgather-u8vector (list->u8vector (list-ref vvintdata myrank))
                    (map list->u8vector vvintdata))
    (test-allgather MPI:allgather-s16vector (list->s16vector (list-ref vvintdata myrank))
                    (map list->s16vector vvintdata))
    (test-allgather MPI:allgather-u16vector (list->u16vector (list-ref vvintdata myrank))
                    (map list->u16vector vvintdata))
    (test-allgather MPI:allgather-s32vector (list->s32vector (list-ref vvintdata myrank))
                    (map list->s32vector vvintdata))
    (test-allgather MPI:allgather-u32vector (list->u32vector (list-ref vvintdata myrank))
                    (map list->u32vector vvintdata))
    (test-allgather MPI:allgather-f32vector (list->f32vector (list-ref vvflodata myrank))
                    (map list->f32vector vvflodata))
    (test-allgather MPI:allgather-f64vector (list->f64vector (list-ref vvflodata myrank))
                    (map list->f64vector vvflodata)))

  ;; Reduce: run every op in REDUCEOPS to rank 0.  Only checks that a
  ;; result came back (test-assert res), not its value.
  (let* ((test-reduce
          (lambda (reducefun reduceops data)
            (for-each (lambda (op)
                        (print myrank ": reduce")
                        (let ((res (reducefun data op 0 comm-world)))
                          (if (zero? myrank)
                              (begin
                                (print myrank ": the result of reduction " op " is " res)
                                (test-assert res)))
                          (MPI:barrier comm-world)))
                      reduceops)
            (MPI:barrier comm-world))))
    (test-reduce MPI:reduce-int
                 (list MPI:i_max MPI:i_min MPI:i_sum MPI:i_prod)
                 (+ 1 myrank))
    (test-reduce MPI:reduce-flonum
                 (list MPI:f_max MPI:f_min MPI:f_sum MPI:f_prod)
                 (+ 1 myrank))
    (test-reduce MPI:reduce-s8vector
                 (list MPI:i_max MPI:i_min MPI:i_sum MPI:i_prod)
                 (s8vector (* 2 myrank) (+ 1 (* 2 myrank)) (+ 2 (* 2 myrank))))
    (test-reduce MPI:reduce-u8vector
                 (list MPI:i_max MPI:i_min MPI:i_sum MPI:i_prod)
                 (u8vector (* 2 myrank) (+ 1 (* 2 myrank)) (+ 2 (* 2 myrank))))
    (test-reduce MPI:reduce-s16vector
                 (list MPI:i_max MPI:i_min MPI:i_sum MPI:i_prod)
                 (s16vector (* 2 myrank) (+ 1 (* 2 myrank)) (+ 2 (* 2 myrank))))
    (test-reduce MPI:reduce-u16vector
                 (list MPI:i_max MPI:i_min MPI:i_sum MPI:i_prod)
                 (u16vector (* 2 myrank) (+ 1 (* 2 myrank)) (+ 2 (* 2 myrank))))
    (test-reduce MPI:reduce-s32vector
                 (list MPI:i_max MPI:i_min MPI:i_sum MPI:i_prod)
                 (s32vector (* 2 myrank) (+ 1 (* 2 myrank)) (+ 2 (* 2 myrank))))
    (test-reduce MPI:reduce-u32vector
                 (list MPI:i_max MPI:i_min MPI:i_sum MPI:i_prod)
                 (u32vector (* 2 myrank) (+ 1 (* 2 myrank)) (+ 2 (* 2 myrank))))
    (test-reduce MPI:reduce-f32vector
                 (list MPI:f_max MPI:f_min MPI:f_sum MPI:f_prod)
                 (f32vector (* 2 myrank) (+ 0.1 (* 2 myrank)) (+ 0.2 (* 2 myrank))))
    (test-reduce MPI:reduce-f64vector
                 (list MPI:f_max MPI:f_min MPI:f_sum MPI:f_prod)
                 (f64vector (* 2 myrank) (+ 0.1 (* 2 myrank)) (+ 0.2 (* 2 myrank)))))

  ;; Reduce all: every rank gets the reduction result (presence only is
  ;; checked).
  (let* ((test-allreduce
          (lambda (allreducefun reduceop data)
            (print myrank ": data is " data)
            (let ((res (allreducefun data reduceop comm-world)))
              (MPI:barrier comm-world)
              (print myrank ": the result of reduction " reduceop " is " res)
              (test-assert res)
              (MPI:barrier comm-world)))))
    (test-allreduce MPI:allreduce-int MPI:i_sum (+ 1 myrank))
    (test-allreduce MPI:allreduce-flonum MPI:f_prod (+ 1.0 myrank))
    (test-allreduce MPI:allreduce-s8vector MPI:i_sum
                    (s8vector (* 2 myrank) (+ 1 (* 2 myrank)) (+ 2 (* 2 myrank))))
    (test-allreduce MPI:allreduce-u8vector MPI:i_sum
                    (u8vector (* 2 myrank) (+ 1 (* 2 myrank)) (+ 2 (* 2 myrank))))
    (test-allreduce MPI:allreduce-s16vector MPI:i_sum
                    (s16vector (* 2 myrank) (+ 1 (* 2 myrank)) (+ 2 (* 2 myrank))))
    (test-allreduce MPI:allreduce-u16vector MPI:i_sum
                    (u16vector (* 2 myrank) (+ 1 (* 2 myrank)) (+ 2 (* 2 myrank))))
    (test-allreduce MPI:allreduce-s32vector MPI:i_sum
                    (s32vector (* 2 myrank) (+ 1 (* 2 myrank)) (+ 2 (* 2 myrank))))
    (test-allreduce MPI:allreduce-u32vector MPI:i_sum
                    (u32vector (* 2 myrank) (+ 1 (* 2 myrank)) (+ 2 (* 2 myrank))))
    (test-allreduce MPI:allreduce-f32vector MPI:f_sum
                    (f32vector (* 2 myrank) (+ 0.1 (* 2 myrank)) (+ 0.2 (* 2 myrank))))
    (test-allreduce MPI:allreduce-f64vector MPI:f_sum
                    (f64vector (* 2 myrank) (+ 0.1 (* 2 myrank)) (+ 0.2 (* 2 myrank)))))

  ;; Scan: every rank gets its prefix reduction (presence only is checked).
  (let* ((test-scan
          (lambda (scanfun reduceop data)
            (print myrank ": data is " data)
            (let ((res (scanfun data reduceop comm-world)))
              (MPI:barrier comm-world)
              (print myrank ": the result of scan " reduceop " is " res)
              (test-assert res))
            (MPI:barrier comm-world))))
    (test-scan MPI:scan-int MPI:i_sum (+ 1 myrank))
    (test-scan MPI:scan-flonum MPI:f_prod (+ 1.0 myrank))
    (test-scan MPI:scan-s8vector MPI:i_sum
               (s8vector (* 2 myrank) (+ 1 (* 2 myrank)) (+ 2 (* 2 myrank))))
    (test-scan MPI:scan-u8vector MPI:i_sum
               (u8vector (* 2 myrank) (+ 1 (* 2 myrank)) (+ 2 (* 2 myrank))))
    (test-scan MPI:scan-s16vector MPI:i_sum
               (s16vector (* 2 myrank) (+ 1 (* 2 myrank)) (+ 2 (* 2 myrank))))
    (test-scan MPI:scan-u16vector MPI:i_sum
               (u16vector (* 2 myrank) (+ 1 (* 2 myrank)) (+ 2 (* 2 myrank))))
    (test-scan MPI:scan-s32vector MPI:i_sum
               (s32vector (* 2 myrank) (+ 1 (* 2 myrank)) (+ 2 (* 2 myrank))))
    (test-scan MPI:scan-u32vector MPI:i_sum
               (u32vector (* 2 myrank) (+ 1 (* 2 myrank)) (+ 2 (* 2 myrank))))
    (test-scan MPI:scan-f32vector MPI:f_sum
               (f32vector (* 2 myrank) (+ 0.1 (* 2 myrank)) (+ 0.2 (* 2 myrank))))
    (test-scan MPI:scan-f64vector MPI:f_sum
               (f64vector (* 2 myrank) (+ 0.1 (* 2 myrank)) (+ 0.2 (* 2 myrank)))))

  ;; Comm split: even and odd ranks form separate communicators, and a
  ;; ring message ("aa" / "bb") circulates within each one.
  (let ((send-in-comm
         (lambda (c init incr)
           (let ((rank-in-c (MPI:comm-rank c))
                 (size-of-c (MPI:comm-size c)))
             (if (zero? rank-in-c)
                 (begin
                   (print rank-in-c "[" myrank "]: sending " (blob->string init))
                   (MPI:send init 1 0 c)
                   (let ((n (MPI:receive MPI:any-source MPI:any-tag c)))
                     (print rank-in-c "[" myrank "]: received " (blob->string n))))
                 (let ((n (MPI:receive MPI:any-source MPI:any-tag c)))
                   (let ((n1 (string->blob (string-append (blob->string n) incr))))
                     (print rank-in-c "[" myrank "]: received " (blob->string n)
                            ", resending " (blob->string n1))
                     (MPI:send n1 (modulo (+ 1 rank-in-c) size-of-c) 0 c))))
             (MPI:barrier comm-world)))))
    (let ((c (MPI:comm-split comm-world (modulo myrank 2) 0)))
      (if (zero? (modulo myrank 2))
          (send-in-comm c (string->blob "aa") "a")
          (send-in-comm c (string->blob "bb") "b"))))

  ;; Cartesian topology on a 2x2 grid.
  (let ((cart (MPI:make-cart comm-world (u32vector 2 2) (u32vector 0 0) #t))
        (test-dims-create
         (lambda (n hints)
           (print "make-dims " n " " hints " = " (MPI:make-dims n hints)))))
    (if (zero? myrank)
        (begin
          ;; Query each coordinate of the 2x2 grid exactly once (previously
          ;; (1 0) was listed twice and (0 1) was never checked).
          (print "ranks = " (map (lambda (x) (cons x (MPI:cart-rank cart x)))
                                 (list
                                  (u32vector 0 0) (u32vector 0 1)
                                  (u32vector 1 0) (u32vector 1 1))))
          (print "coords = " (list-tabulate (MPI:comm-size cart)
                                            (lambda (n) (cons n (MPI:cart-coords cart n)))))
          (test-dims-create 60 (u32vector 0 0 0))
          (test-dims-create 60 (u32vector 3 0 0))
          (test-dims-create 60 (u32vector 0 4 0))
          (test-dims-create 60 (u32vector 3 0 5)))))

  (MPI:barrier comm-world)

  ;; Wtime
  (print myrank ": wtime is " (MPI:wtime))

  )
---|