;;
;;
;; Utility procedures for NEMO code generators.
;;
;; Copyright 2008-2012 Ivan Raikov and the Okinawa Institute of Science and Technology
;;
;; This program is free software: you can redistribute it and/or
;; modify it under the terms of the GNU General Public License as
;; published by the Free Software Foundation, either version 3 of the
;; License, or (at your option) any later version.
;;
;; This program is distributed in the hope that it will be useful, but
;; WITHOUT ANY WARRANTY; without even the implied warranty of
;; MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
;; General Public License for more details.
;;
;; A full copy of the GPL license can be found at
;; <http://www.gnu.org/licenses/>.
;;

21 | (module nemo-utils |
---|
22 | |
---|
23 | (lookup-def enum-bnds enum-freevars sum |
---|
24 | if-convert let-enum let-elim let-lift |
---|
25 | s+ sw+ slp nl spaces ppf |
---|
26 | transitions-graph state-conseqs |
---|
27 | differentiate simplify distribute) |
---|
28 | |
---|
29 | (import scheme chicken data-structures srfi-1 srfi-13) |
---|
30 | |
---|
31 | (require-extension matchable strictly-pretty |
---|
32 | varsubst digraph nemo-core) |
---|
33 | |
---|
34 | |
---|
;; lookup-def: look up key K in the association list LST.
;; Keys are compared by their ->string representations, so a symbol key
;; and a string key with the same spelling are interchangeable.  The
;; first matching entry wins.  A two-element entry (key value) yields
;; VALUE; any other pair (key . rest) yields REST.  The optional third
;; argument is returned when nothing matches (default #f).
(define (lookup-def k lst . rest)
  (let-optionals rest ((default #f))
    (let* ((wanted (->string k))
           (entry (find (lambda (kv) (string=? (->string (car kv)) wanted))
                        lst)))
      (cond ((not entry) default)
            ;; exactly (key value): unwrap the single value
            ((and (pair? (cdr entry)) (null? (cddr entry)))
             (cadr entry))
            (else (cdr entry))))))
---|
45 | |
---|
46 | |
---|
;; enum-bnds: accumulate onto AX every name bound by a `let' in EXPR.
;; For `(let bindings body)' the binding values are scanned first, then
;; the bound names are prepended and the body is scanned with that
;; accumulator.  Any other form whose head is a symbol is scanned argument
;; by argument; everything else leaves AX unchanged.
(define (enum-bnds expr ax)
  (let ((let-form? (and (list? expr)
                        (= (length expr) 3)
                        (eq? (car expr) 'let))))
    (cond (let-form?
           (let* ((bindings (cadr expr))
                  (inner (fold enum-bnds ax (map cadr bindings))))
             (enum-bnds (caddr expr) (append (map car bindings) inner))))
          ((and (pair? expr) (symbol? (car expr)))
           (fold enum-bnds ax (cdr expr)))
          (else ax))))
---|
53 | |
---|
54 | |
---|
;; enum-freevars: accumulate onto AX the free variables of EXPR.
;; A variable is free when it is a symbol that is neither in BNDS nor
;; bound by an enclosing `let' inside EXPR.  The head of an application is
;; treated as an operator and is never collected.  Variables are consed on
;; as they are met, so the result is in reverse discovery order and may
;; contain duplicates.
(define (enum-freevars expr bnds ax)
  (define (scan-all exprs bound acc)
    (fold (lambda (e a) (enum-freevars e bound a)) acc exprs))
  (cond ((and (list? expr) (= (length expr) 3) (eq? (car expr) 'let))
         ;; `let' values see only the outer bindings; the body also sees
         ;; the names this let introduces.
         (let ((lbnds (cadr expr)))
           (enum-freevars (caddr expr)
                          (append (map first lbnds) bnds)
                          (scan-all (map second lbnds) bnds ax))))
        ((pair? expr)
         (if (symbol? (car expr))
             (scan-all (cdr expr) bnds ax)
             ax))
        ((and (symbol? expr) (not (member expr bnds)))
         (cons expr ax))
        (else ax)))
---|
66 | |
---|
67 | |
---|
;; sum: build a nested binary `+' expression adding up the terms in LST.
;; Terms are paired from the front:
;;   (a)         => a
;;   (a b)       => (+ a b)
;;   (a b c)     => (+ (+ a b) c)
;;   (a b c d)   => (+ (+ a b) (+ c d))
;; An empty LST is returned as is (the empty list).
(define (sum lst)
  (cond ((null? lst) lst)
        ((null? (cdr lst)) (car lst))
        (else
         (let ((head `(+ ,(car lst) ,(cadr lst)))
               (tail (cddr lst)))
           (if (null? tail)
               head
               `(+ ,head ,(sum tail)))))))
---|
75 | |
---|
76 | |
---|
;; if-convert: name the value of every conditional in EXPR.
;; Each 4-element `(if c t e)' becomes
;;   (let ((ifN (if c' t' e'))) ifN)
;; where ifN is a fresh gensym and c', t', e' are converted recursively.
;; 3-element `let' forms are rebuilt with converted binding values and
;; body.  Other applications keep their operator and have their arguments
;; converted.  Atoms are returned unchanged.
(define (if-convert expr)
  (define (form-of? head len)
    (and (pair? expr)
         (eq? (car expr) head)
         (list? expr)
         (= (length expr) len)))
  (cond ((form-of? 'if 4)
         ;; the gensym is drawn before converting the sub-expressions
         (let ((tmp (gensym "if")))
           (let ((c (cadr expr)) (t (caddr expr)) (e (cadddr expr)))
             `(let ((,tmp (if ,(if-convert c) ,(if-convert t) ,(if-convert e))))
                ,tmp))))
        ((form-of? 'let 3)
         (let ((bs (cadr expr)) (body (caddr expr)))
           `(let ,(map (lambda (b) `(,(car b) ,(if-convert (cadr b)))) bs)
              ,(if-convert body))))
        ((pair? expr)
         (cons (car expr) (map if-convert (cdr expr))))
        (else expr)))
---|
88 | |
---|
89 | |
---|
;; let-enum: collect onto the list AX the let-bindings in EXPR that
;; let-lift hoists outward; returns the extended list.
;;  - `(let ((x (if c t e))) y)': bindings found in the condition C are
;;    collected first.  Then, only if the body Y is the bound name X itself
;;    (the shape if-convert produces), the binding (x (if c t e)) is
;;    appended.  The branches T and E are not searched.
;;  - any other 3-element `let': its binding list is appended as is; the
;;    binding values and body are not searched.
;;  - `(if c t e)': only the condition C is searched.
;;  - other applications: each argument (not the operator) is searched,
;;    left to right.
;;  - anything else contributes nothing.
(define (let-enum expr ax)
  (match expr
    (('let ((x ('if c t e))) y)
     (let ((ax (fold let-enum ax (list c ))))
       (if (eq? x y) (append ax (list (list x `(if ,c ,t ,e)))) ax)))

    (('let bnds body) (append ax bnds))

    (('if c t e) (let-enum c ax))

    ((f . es) (fold let-enum ax es))

    (else ax)))
---|
103 | |
---|
104 | |
---|
;; let-elim: strip from EXPR the let wrappers whose bindings let-enum
;; collects, leaving the bound names (or bodies) in their place.
;;  - `(let ((x (if c t e))) y)': replaced by Y when Y is X itself,
;;    otherwise returned unchanged.
;;  - any other 3-element `let': replaced by its body; the bindings are
;;    dropped here (let-lift re-binds what let-enum / fbnds collect).
;;  - `(if c t e)': the condition goes through let-elim, while each branch
;;    goes through let-lift, so lets inside a branch stay in that branch.
;;  - other applications: arguments are processed, the operator is kept.
;;  - anything else is returned unchanged.
(define (let-elim expr)
  (match expr
    (('let ((x ('if c t e))) y)
     (if (eq? x y) y expr))

    (('let bnds body) body)

    (('if c t e) `(if ,(let-elim c) ,(let-lift t) ,(let-lift e)))

    ((f . es) `(,f . ,(map let-elim es)))

    (else expr)))
---|
117 | |
---|
118 | |
---|
;; let-lift: hoist let-bindings in EXPR outward so that they wrap the
;; expression that uses them; returns the rewritten expression.
;; One pass:
;;  - `(let bnds body)': each binding value is cleaned with let-elim and
;;    the body is let-lifted recursively.  Bindings nested inside the
;;    binding values (collected by fbnds) are bound by a new `let' wrapped
;;    around the result.
;;  - anything else: the bindings reported by let-enum are bound, with
;;    let-elim applied to their values, in a new `let' around
;;    (let-elim expr).  Nested bindings from fbnds get the same outer
;;    `let' treatment.  If let-enum reports nothing, the pass yields just
;;    (let-elim expr).
;; Passes are repeated until one leaves the expression unchanged
;; (compared with equal?).
;; NOTE(review): nothing here bounds the number of passes; termination
;; relies on the rewrite reaching a fixed point.
(define (let-lift expr)
  ;; fbnds: gather the bindings nested inside each binding value of BNDS.
  ;; When the value is itself a `let', its whole binding list is taken;
  ;; otherwise whatever let-enum finds in the value is taken.  Because of
  ;; the fold, groups from later bindings come before earlier ones.
  (define (fbnds bnds)
    (let ((bnds0
           (fold (lambda (b ax)
                   (let ((bexpr (cadr b)))
                     (match bexpr
                       (('let bnds expr) (append bnds ax))
                       (else (append (let-enum bexpr (list)) ax)))))
                 '() bnds)))
      bnds0))
  (let ((expr1
         (match expr
           (('let bnds expr)
            (let ((bnds0 (fbnds bnds))
                  (expr1
                   `(let ,(map (lambda (b) (list (car b) (let-elim (cadr b)))) bnds)
                      ,(let-lift expr))))
              (if (null? bnds0) expr1 `(let ,bnds0 ,expr1))))

           (else
            (let ((bnds (let-enum expr (list))))
              (if (null? bnds)
                  (let-elim expr)
                  (let ((bnds0 (fbnds bnds))
                        (expr1 `(let ,(map (lambda (b) (list (car b) (let-elim (cadr b)))) bnds)
                                  ,(let-elim expr))))
                    (if (null? bnds0) expr1 `(let ,bnds0 ,expr1))))))
           )))
    ;; iterate to a fixed point
    (if (equal? expr expr1) expr1
        (let-lift expr1))))
---|
149 | |
---|
150 | |
---|
;; s+: concatenate the ->string representations of all the arguments.
(define (s+ . lst)
  (apply string-append (map ->string lst)))

;; sw+: join the ->string representations of the non-#f elements of LST,
;; separated by single spaces.
(define (sw+ lst)
  (string-intersperse (map ->string (filter identity lst)) " "))

;; slp: join the ->string representations of the elements of LST, with
;; the separator string P between consecutive elements.
(define (slp p lst)
  (string-join (map ->string lst) p))

;; nl: a newline, as a string.
(define nl "\n")

;; spaces: a string made of N space characters.
(define (spaces n)
  (list->string (list-tabulate n (constantly #\space))))
---|
156 | |
---|
;; ppf: print each argument in LST to the current output port, prefixed
;; by INDENT spaces.  #f arguments are skipped.  An argument (I item ...)
;; whose head I is a positive number prints each ITEM recursively with
;; the indentation increased by I.  Any other list is printed as its
;; space-joined elements (via sw+), and any other value is printed as is.
(define (ppf indent . lst)
  (let ((prefix (spaces indent)))
    (define (emit item)
      (cond ((not item) #f)
            ((and (pair? item) (number? (car item)) (positive? (car item)))
             (let ((deeper (+ indent (car item))))
               (for-each (lambda (sub) (ppf deeper sub)) (cdr item))))
            ((or (pair? item) (null? item))
             (print prefix (sw+ item)))
            (else
             (print prefix item))))
    (for-each emit lst)))
---|
166 | |
---|
167 | |
---|
;; transitions-graph: build the transition graph of the state complex N.
;;
;;   N           -- name of the state complex (used in the graph name, in
;;                  error messages and as the first argument of STATE-NAME).
;;   OPEN        -- a state that is never chosen as the state eliminated
;;                  through the conservation equation.
;;   TRANSITIONS -- transition equations, each one of
;;                    (-> s0 s1 rate)          (s0 -> s1 rate)
;;                    (<-> s0 s1 rate rate')   (s0 <-> s1 rate rate')
;;                  with symbols s0, s1; anything else is reported through
;;                  nemo:error.
;;   CONSERVE    -- list of conservation equations; only the first is used.
;;                  It is read as (total _ expr) -- presumably
;;                  (total = expr); TODO confirm against callers.
;;   STATE-NAME  -- procedure (STATE-NAME n s) giving the name substituted
;;                  for state S in the generated expressions.
;;
;; Each state becomes a node numbered 0, 1, ... in node-list order.  Each
;; transition direction becomes an edge (i j flux), where flux is
;; (* source-state rate) after substituting the STATE-NAME names.  When a
;; conservation equation is present, the first graph node (in the graph's
;; node order) that is a free variable of its right-hand side, other than
;; OPEN, is eliminated.  Wherever that state is the source of a flux it is
;; replaced by (total - sum of the remaining variables).
;;
;; Returns (list graph eliminated-node-or-#f state-substitution).
(define (transitions-graph n open transitions conserve state-name)
  (let* ((subst-convert (subst-driver (lambda (x) (and (symbol? x) x))
                                      nemo:binding? identity nemo:bind nemo:subst-term))
         (g (make-digraph n (string-append (->string n) " transitions graph")))
         (add-node! (g 'add-node!))
         (add-edge! (g 'add-edge!))
         ;; Distinct state names used by the transitions.  The accepted
         ;; forms must match those handled when edges are created below;
         ;; the infix bidirectional form is (s0 <-> s1 rate rate').
         (node-list
          (let loop ((lst (list)) (tlst transitions))
            (if (null? tlst)
                (delete-duplicates lst eq?)
                (match (car tlst)
                  (('-> (and (? symbol?) s0) (and (? symbol?) s1) rate-expr)
                   (loop (cons* s0 s1 lst) (cdr tlst)))
                  (((and (? symbol?) s0) '-> (and (? symbol?) s1) rate-expr)
                   (loop (cons* s0 s1 lst) (cdr tlst)))
                  (('<-> (and (? symbol?) s0) (and (? symbol?) s1) rate-expr1 rate-expr2)
                   (loop (cons* s0 s1 lst) (cdr tlst)))
                  (((and (? symbol?) s0) '<-> (and (? symbol?) s1) rate-expr1 rate-expr2)
                   (loop (cons* s0 s1 lst) (cdr tlst)))
                  (else
                   (nemo:error 'state-eqs ": invalid transition equation "
                               (car tlst) " in state complex " n))))))
         (node-ids (list-tabulate (length node-list) identity))
         (name->id-map (zip node-list node-ids))
         (node-subs (fold (lambda (s ax) (subst-extend s (state-name n s) ax))
                          subst-empty node-list)))
    ;; insert state nodes in the graph
    (for-each (lambda (i n) (add-node! i n)) node-ids node-list)
    (let* ((nodes ((g 'nodes)))
           (conserve (and (pair? conserve) (car conserve)))
           ;; if a conservation equation is present, one state is
           ;; eliminated from the system and expressed through the others
           (cvars (and conserve (enum-freevars (third conserve) '() '())))
           (cnode (and conserve
                       (find (lambda (s)
                               (let ((n (second s)))
                                 (and (member n cvars) (not (eq? n open)))))
                             nodes)))
           (cname (and cnode (second cnode)))
           (cnexpr (and cnode
                        (let ((cvars1 (delete cname cvars eq?)))
                          (if (null? cvars1)
                              ;; the eliminated state is the only variable,
                              ;; so it equals the total ((sum '()) is '(),
                              ;; which is not a valid expression)
                              (first conserve)
                              (let ((sumvar (gensym "sum")))
                                `(let ((,sumvar ,(sum cvars1)))
                                   (- ,(first conserve) ,sumvar)))))))
           (add-tredge
            (lambda (s0 s1 rexpr1 rexpr2)
              (let* ((i (car (alist-ref s0 name->id-map)))
                     (j (car (alist-ref s1 name->id-map)))
                     (x0 (if (and cnode (eq? s0 cname)) cnexpr s0))
                     (x1 (if (and cnode (eq? s1 cname)) cnexpr s1))
                     (ij-expr `(* ,(subst-convert x0 node-subs)
                                  ,(subst-convert rexpr1 node-subs)))
                     (ji-expr (and rexpr2
                                   `(* ,(subst-convert x1 node-subs)
                                       ,(subst-convert rexpr2 node-subs)))))
                (add-edge! (list i j ij-expr))
                (if rexpr2 (add-edge! (list j i ji-expr)))))))
      ;; create rate edges in the graph
      (for-each (lambda (e)
                  (match e
                    (('-> s0 s1 rexpr) (add-tredge s0 s1 rexpr #f))
                    ((s0 '-> s1 rexpr) (add-tredge s0 s1 rexpr #f))
                    (('<-> s0 s1 rexpr1 rexpr2) (add-tredge s0 s1 rexpr1 rexpr2))
                    ((s0 '<-> s1 rexpr1 rexpr2) (add-tredge s0 s1 rexpr1 rexpr2))))
                transitions)
      (list g cnode node-subs))))
---|
235 | |
---|
236 | |
---|
;; state-conseqs: rewrite the consequence equations CONSEQS of the state
;; complex N, substituting (STATE-NAME n s) for every state S that appears
;; in TRANSITIONS.
;; Each consequence must have the form (i = . expr); the substitution is
;; applied to EXPR.  Transitions must use one of the forms
;;   (-> s0 s1 rate)  (s0 -> s1 rate)  (<-> s0 s1 rate rate')  (s0 <-> s1 rate rate')
;; with symbols s0, s1; anything else is reported through nemo:error.
;; The infix bidirectional form was previously misspelt as M->, which made
;; valid (s0 <-> s1 ...) equations fail.
;; Returns (list n rewritten-conseqs).
(define (state-conseqs n transitions conseqs state-name)
  (let* ((subst-convert (subst-driver (lambda (x) (and (symbol? x) x))
                                      nemo:binding? identity nemo:bind nemo:subst-term))
         (state-list
          (let loop ((lst (list)) (tlst transitions))
            (if (null? tlst)
                (delete-duplicates lst eq?)
                (match (car tlst)
                  (('-> (and (? symbol?) s0) (and (? symbol?) s1) rate-expr)
                   (loop (cons* s0 s1 lst) (cdr tlst)))
                  (((and (? symbol?) s0) '-> (and (? symbol?) s1) rate-expr)
                   (loop (cons* s0 s1 lst) (cdr tlst)))
                  (('<-> (and (? symbol?) s0) (and (? symbol?) s1) rate-expr1 rate-expr2)
                   (loop (cons* s0 s1 lst) (cdr tlst)))
                  (((and (? symbol?) s0) '<-> (and (? symbol?) s1) rate-expr1 rate-expr2)
                   (loop (cons* s0 s1 lst) (cdr tlst)))
                  (else
                   (nemo:error 'nemo:state-conseq ": invalid transition equation "
                               (car tlst) " in state complex " n))))))
         (state-subs (fold (lambda (s ax) (subst-extend s (state-name n s) ax))
                           subst-empty state-list))
         (conseqs1 (map (lambda (conseq)
                          (match conseq
                            ((i '= . expr) `(,i = . ,(subst-convert expr state-subs)))))
                        conseqs)))
    (list n conseqs1)))
---|
259 | |
---|
;; `(+ - * / pow neg abs atan asin acos sin cos exp ln
;;   sqrt tan cosh sinh tanh hypot gamma lgamma log10 log2 log1p ldexp cube
;;   > < <= >= = and or round ceiling floor max min
;;   fpvector-ref))
---|
264 | |
---|
;; log10(e) = 1/ln(10) and log2(e) = 1/ln(2); differentiate uses them in
;; the derivatives of log10 and log2.  They are written to full double
;; precision (the values of C's M_LOG10E and M_LOG2E); the previous
;; literals were truncated, e.g. LOG2E was off by about 3e-15.
(define LOG10E 0.4342944819032518)
(define LOG2E 1.4426950408889634)
---|
267 | |
---|
268 | |
---|
;; differentiate: symbolic derivative of the expression T with respect to
;; the variable X.  The result is not simplified (see simplify).
;; Numbers differentiate to 0.0.  A symbol differentiates to 1.0 if it
;; names X (compared via ->string) and to 0.0 otherwise.
;; FENV is an association list consulted through lookup-def for operators
;; without a built-in rule; its entry for an operator is used as the
;; derivative function name(s) to apply (presumably one name per argument
;; when it is a list).
;; #f is produced for an application that depends on X but whose operator
;; has no rule and no FENV entry, and for any other unrecognized value.
;; Such a #f may end up nested inside a larger result.
(define (differentiate fenv x t)
  (define subst-convert
    (subst-driver
     (lambda (x) (and (symbol? x) x))
     nemo:binding? identity nemo:bind nemo:subst-term))

  (cond ((number? t) 0.0)
        ((symbol? t) (cond ((string=? (->string x) (->string t)) 1.0)
                           (else 0.0)))
        (else (match t
                (('neg u) `(neg ,(differentiate fenv x u)))

                (('+ u v) `(+ ,(differentiate fenv x u)
                              ,(differentiate fenv x v)))

                (('- u v) `(- ,(differentiate fenv x u)
                              ,(differentiate fenv x v)))

                ;; constant factor on either side
                (('* (and u (? number?)) v) `(* ,u ,(differentiate fenv x v)))
                (('* v (and u (? number?))) `(* ,u ,(differentiate fenv x v)))

                ;; product rule
                (('* u v) `(+ (* ,(differentiate fenv x u) ,v)
                              (* ,u ,(differentiate fenv x v))))

                ;; quotient rule
                (('/ u v) `(/ (- (* ,(differentiate fenv x u) ,v)
                                 (* ,u ,(differentiate fenv x v)))
                              (pow ,v 2.0)))

                (('cube u) (differentiate fenv x `(pow ,u 3.0)))

                ;; the exponent N is treated as a constant
                (('pow u n) (chain fenv x u `(* ,n (pow ,u (- ,n 1.0)))))

                (('sqrt u) (chain fenv x u `(/ 1.0 (* 2.0 (sqrt ,u)))))

                (('exp u) (chain fenv x u `(exp ,u)))

                (('log u) (chain fenv x u `(/ 1.0 ,u)))

                ;; d/du log10(u) = log10(e)/u and d/du log2(u) = log2(e)/u.
                ;; chain multiplies by du/dx itself, so the outer derivative
                ;; must not contain it too (it used to, squaring du/dx).
                (('log10 u) (chain fenv x u `(* ,LOG10E (/ 1.0 ,u))))

                (('log2 u) (chain fenv x u `(* ,LOG2E (/ 1.0 ,u))))

                (('log1p u) (differentiate fenv x `(log (+ 1.0 ,u))))

                ;; ldexp(u, n) = u * 2^n.  2^n is folded only when N is a
                ;; literal number (expt cannot take a symbolic N).
                (('ldexp u n) (differentiate fenv x `(* ,u ,(if (number? n)
                                                               (expt 2 n)
                                                               `(pow 2.0 ,n)))))

                (('sin u) (chain fenv x u `(cos ,u)))

                (('cos u) (chain fenv x u `(neg (sin ,u))))

                (('tan u) (differentiate fenv x `(* (sin ,u) (/ 1.0 (cos ,u)))))

                (('asin u) (chain fenv x u `(/ 1.0 (sqrt (- 1.0 (pow ,u 2.0))))))

                (('acos u) (chain fenv x u `(/ (neg 1.0) (sqrt (- 1.0 (pow ,u 2.0))))))

                (('atan u) (chain fenv x u `(/ 1.0 (+ 1.0 (pow ,u 2.0)))))

                (('sinh u) (differentiate fenv x `(/ (- (exp ,u) (exp (neg ,u))) 2.0)))

                (('cosh u) (differentiate fenv x `(/ (+ (exp ,u) (exp (neg ,u))) 2.0)))

                (('tanh u) (differentiate fenv x `(/ (sinh ,u) (cosh ,u))))

                ;; NOTE(review): the let binding list BNDS is passed directly as
                ;; the substitution; confirm it has the shape subst-driver expects.
                (('let bnds body) (let* ((body1 (subst-convert body bnds)))
                                    (differentiate fenv x body1)))

                ;; any other operator: 0.0 if T does not mention X, otherwise
                ;; use the derivative(s) registered in FENV
                ((op . us) (let ((fv (enum-freevars t '() '())))
                             (if (member x fv)
                                 (cond ((lookup-def op fenv) =>
                                        (lambda (fs)
                                          (cond ((and (pair? fs) (pair? us))
                                                 `(+ . ,(map (lambda (fu u) (chain fenv x u `(,fu ,u)))
                                                             fs us)))
                                                ;; NOTE(review): US is the whole argument list here,
                                                ;; so this builds (fs (arg ...)); confirm this is the
                                                ;; intended convention for single-name entries.
                                                (else (chain fenv x us `(,fs ,us))))))
                                       (else #f))
                                 0.0)))

                (else #f)))))
---|
348 | |
---|
;; chain: chain rule.  T is the inner expression and U the derivative of
;; the outer function evaluated at T; the result is dT/dX * U.
;; When T is a symbol its derivative is 1.0 if it names X and 0.0
;; otherwise, so the product is folded to U or to 0.0.  Previously U was
;; returned for every symbol, which made, e.g., d/dx sin(y) = cos(y).
(define (chain fenv x t u)
  (cond ((not (symbol? t))
         `(* ,(differentiate fenv x t) ,u))
        ;; same variable test differentiate applies to symbols
        ((string=? (->string x) (->string t)) u)
        (else 0.0)))
---|
352 | |
---|
353 | |
---|
;; simplify: apply one round of algebraic simplification to the
;; expression T.  The rules are:
;;   - identities involving the literals 0.0 and 1.0;
;;   - double negation;
;;   - (+ a (neg b)) => (- a b);
;;   - folding of +, -, * and pow applied to two numbers.
;; When a rule fires, its result is returned without further simplification.
;; Otherwise sub-expressions are simplified recursively, including let
;; binding values and let bodies.  Only the inexact literals 0.0 and 1.0
;; are recognized, so exact 0 or 1 operands are left alone.
(define (simplify t)
  (match t
    (('neg 0.0) 0.0)

    (('+ 0.0 0.0) 0.0)
    (('+ 0.0 t1) t1)
    (('+ t1 0.0) t1)
    (('+ t1 ('neg t2)) `(- ,t1 ,t2))
    (('+ (and t1 (? number?)) (and t2 (? number?))) (+ t1 t2))

    (('- 0.0 0.0) 0.0)
    (('- 0.0 t1) `(neg ,t1))
    (('- t1 0.0) t1)
    (('neg ('neg t1)) t1)
    (('- (and t1 (? number?)) (and t2 (? number?))) (- t1 t2))

    (('* 0.0 0.0) 0.0)
    (('* 0.0 t1) 0.0)
    (('* t1 0.0) 0.0)
    (('* 1.0 t1) t1)
    (('* t1 1.0) t1)
    (('* ('neg t1) ('neg t2)) `(* ,t1 ,t2))
    (('* (and t1 (? number?)) (and t2 (? number?))) (* t1 t2))

    (('/ 0.0 t1) 0.0)

    (('pow t1 0.0) 1.0)
    (('pow t1 1.0) t1)
    (('pow (and t1 (? number?)) (and t2 (? number?))) (expt t1 t2))

    (('let () body) (simplify body))

    ;; Keep each bound name and simplify only its value.  The name must be
    ;; unquoted (,v); it used to be emitted as the literal symbol v.
    (('let bnds body)
     `(let ,(map (match-lambda ((v b) `(,v ,(simplify b)))
                               (else #f)) bnds)
        ,(simplify body)))

    ((op . ts)
     `(,op . ,(map simplify ts)))

    (else t)))
---|
395 | |
---|
;; distribute: rewrite the arithmetic operators +, -, * and / in T so that
;; each application has at most two operands.  The rewrite recurses into
;; all sub-expressions, including let binding values and let bodies.
;;  - (+ a b c ...) and (* a b c ...) become a tree of binary applications;
;;    regrouping is valid because both operators are associative.
;;  - (- a b c ...) and (/ a b c ...) fold to the left, so they become
;;    (- a (+ b c ...)) and (/ a (* b c ...)), and the inner sum or product
;;    is distributed in turn.  Regrouping them like + and * would change
;;    their value.
;;  - Unary (op a) keeps its single operand, and nullary (op) is returned
;;    unchanged.  Both previously recursed forever.
(define (distribute t)
  (match t

    (((or '+ '- '* '/)) t)

    (((and (or '+ '- '* '/) op) x)
     `(,op ,(distribute x)))

    (((and (or '+ '- '* '/) op) x y)
     `(,op ,(distribute x) ,(distribute y)))

    ;; a - b - c ... = a - (b + c ...);  a / b / c ... = a / (b * c ...)
    (((and (or '- '/) op) x . rest)
     (let ((combine (if (eq? op '-) '+ '*)))
       `(,op ,(distribute x) ,(distribute `(,combine . ,rest)))))

    (((and (or '+ '*) op) x y z)
     `(,op ,(distribute x) (,op ,(distribute y) ,(distribute z))))

    ;; four or more operands: split into two halves
    (((and (or '+ '*) op) . lst)
     (let* ((n (length lst))
            (n/2 (inexact->exact (round (/ n 2)))))
       `(,op ,(distribute `(,op . ,(take lst n/2)))
             ,(distribute `(,op . ,(drop lst n/2))))))

    (('let bnds body)
     `(let ,(map (match-lambda ((v b) `(,v ,(distribute b)))
                               (else #f)) bnds)
        ,(distribute body)))

    ((op . ts)
     `(,op . ,(map distribute ts)))

    (else t)))
---|
420 | |
---|
421 | |
---|
422 | ) |
---|