source: project/release/4/nemo/trunk/nemo-utils.scm @ 15425

Last change on this file since 15425 was 15425, checked in by Ivan Raikov, 12 years ago

added symbolic differentiation to nemo-utils

File size: 10.1 KB
Line 
1;;       
2;;
3;; Utility procedures for NEMO code generators.
4;;
5;; Copyright 2008-2009 Ivan Raikov and the Okinawa Institute of Science and Technology
6;;
7;; This program is free software: you can redistribute it and/or
8;; modify it under the terms of the GNU General Public License as
9;; published by the Free Software Foundation, either version 3 of the
10;; License, or (at your option) any later version.
11;;
12;; This program is distributed in the hope that it will be useful, but
13;; WITHOUT ANY WARRANTY; without even the implied warranty of
14;; MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
15;; General Public License for more details.
16;;
17;; A full copy of the GPL license can be found at
18;; <http://www.gnu.org/licenses/>.
19;;
20
21(module nemo-utils
22       
23 (lookup-def enum-bnds enum-freevars sum 
24             if-convert let-enum let-elim let-lift
25             s+ sw+ sl\ nl spaces ppf
26             transitions-graph state-lineqs
27             differentiate simplify )
28
29 (import scheme chicken data-structures srfi-1 srfi-13)
30
31 (require-extension matchable strictly-pretty 
32                    varsubst digraph nemo-core)
33
34
35(define (lookup-def k lst . rest)
36  (let-optionals rest ((default #f))
37      (let ((kv (assoc k lst)))
38        (if (not kv) default
39            (match kv ((k v) v) (else (cdr kv)))))))
40
41
42(define (enum-bnds expr ax)
43  (match expr 
44         (('if . es)        (fold enum-bnds ax es))
45         (('let bnds body)  (enum-bnds body (append (map car bnds) (fold enum-bnds ax (map cadr bnds)))))
46         ((s . es)          (if (symbol? s)  (fold enum-bnds ax es) ax))
47         (else ax)))
48
49
50(define (enum-freevars expr bnds ax)
51  (match expr 
52         (('if . es) 
53          (fold (lambda (x ax) (enum-freevars x bnds ax)) ax es))
54         (('let bnds body) 
55          (let ((bnds1 (append (map first bnds) bnds)))
56            (enum-freevars body bnds1 (fold (lambda (x ax) (enum-freevars x bnds ax)) ax (map second bnds)))))
57         ((s . es)    (if (symbol? s)  (fold (lambda (x ax) (enum-freevars x bnds ax)) ax es) ax))
58         (id          (if (and (symbol? id) (not (member id bnds)))  (cons id ax) ax))))
59
60
61(define (sum lst)
62  (if (null? lst) lst
63      (match lst
64             ((x)   x)
65             ((x y) `(+ ,x ,y))
66             ((x y . rest) `(+ (+ ,x ,y) ,(sum rest)))
67             ((x . rest) `(+ ,x ,(sum rest))))))
68
69
70(define (if-convert expr)
71  (match expr 
72         (('if c t e) 
73          (let ((r (gensym "if")))
74            `(let ((,r (if ,(if-convert c) ,(if-convert t) ,(if-convert e)))) 
75               ,r)))
76         (('let bs e)
77          `(let ,(map (lambda (b) `(,(car b) ,(if-convert (cadr b)))) bs) ,(if-convert e)))
78         ((f . es)
79          (cons f (map if-convert es)))
80         ((? atom? ) expr)))
81
82         
83(define (let-enum expr ax)
84  (match expr
85         (('let ((x ('if c t e))) y)
86          (let ((ax (fold let-enum ax (list c ))))
87            (if (eq? x y)  (append ax (list (list x `(if ,c ,t ,e)))) ax)))
88
89         (('let bnds body)  (let-enum body (append ax bnds)))
90
91         (('if c t e)  (let-enum ax c))
92
93         ((f . es)  (fold let-enum ax es))
94
95         (else ax)))
96
97
98(define (let-elim expr)
99  (match expr
100         (('let ((x ('if c t e))) y)
101          (if (eq? x y)  y expr))
102
103         (('let bnds body) (let-elim body))
104
105         (('if c t e)  `(if ,(let-elim c) ,(let-lift t) ,(let-lift e)))
106
107         ((f . es)  `(,f . ,(map let-elim es)))
108
109         (else expr)))
110 
111
112(define (let-lift expr)
113  (let ((bnds (let-enum expr (list))))
114    (if (null? bnds) expr
115        `(let ,(map (lambda (b) (list (car b) (let-elim (cadr b)))) bnds) ,(let-elim expr)))))
116
117
118(define (s+ . lst)    (string-concatenate (map ->string lst)))
119(define (sw+ lst)     (string-intersperse (filter-map (lambda (x) (and x (->string x))) lst) " "))
120(define (sl\ p lst)   (string-intersperse (map ->string lst) p))
121(define nl "\n")
122(define (spaces n)    (list->string (list-tabulate n (lambda (x) #\space))))
123
124(define (ppf indent . lst)
125  (let ((sp (spaces indent)))
126    (for-each (lambda (x)
127                (and x (match x 
128                              ((i . x1) (if (and (number? i) (positive? i))
129                                            (for-each (lambda (x) (ppf (+ indent i) x)) x1)
130                                            (print sp (sw+ x))))
131                              (else   (print sp (if (list? x) (sw+ x) x))))))
132              lst)))
133
134
135(define (transitions-graph n open transitions state-name)
136  (let* ((subst-convert  (subst-driver (lambda (x) (and (symbol? x) x)) 
137                                       nemo:binding? identity nemo:bind nemo:subst-term))
138         (g          (make-digraph n (string-append (->string n) " transitions graph")))
139         (add-node!  (g 'add-node!))
140         (add-edge!  (g 'add-edge!))
141         (out-edges  (g 'out-edges))
142         (in-edges   (g 'in-edges))
143         (node-info  (g 'node-info))
144         (node-list  (let loop ((lst (list)) (tlst transitions))
145                       (if (null? tlst)  (delete-duplicates lst eq?)
146                           (match (car tlst) 
147                                  (('-> (and (? symbol?) s0) (and (? symbol?) s1) rate-expr)
148                                   (loop (cons* s0 s1 lst) (cdr tlst)))
149                                  (((and (? symbol?) s0) '-> (and (? symbol? s1)) rate-expr)
150                                   (loop (cons* s0 s1 lst) (cdr tlst)))
151                                  (('<-> (and (? symbol?) s0) (and (? symbol?) s1) rate-expr1 rate-expr2)
152                                   (loop (cons* s0 s1 lst) (cdr tlst)))
153                                  (((and (? symbol?) s0) 'M-> (and (? symbol? s1)) rate-expr1 rate-expr2)
154                                   (loop (cons* s0 s1 lst) (cdr tlst)))
155                                  (else
156                                   (nemo:error 'state-eqs ": invalid transition equation " 
157                                                  (car tlst) " in state complex " n))
158                                  (else (loop lst (cdr tlst)))))))
159         (node-ids      (list-tabulate (length node-list) identity))
160         (name->id-map  (zip node-list node-ids))
161         (node-subs     (fold (lambda (s ax) (subst-extend s (state-name n s) ax)) subst-empty node-list)))
162    ;; insert state nodes in the dependency graph
163    (for-each (lambda (i n) (add-node! i n)) node-ids node-list)
164    (let* ((nodes  ((g 'nodes)))
165           (snode   (find (lambda (s) (not (eq? (second s) open))) nodes))
166           (snex   (let ((nodes/s (filter-map (lambda (s) (and (not (= (first s) (first snode))) (second s))) nodes))
167                         (sumvar  (gensym "sum")))
168                     `(let ((,sumvar ,(sum nodes/s))) (- 1 ,sumvar))))
169           (add-tredge (lambda (s0 s1 rexpr1 rexpr2)
170                         (let* ((i   (car (alist-ref s0 name->id-map)))
171                                (j   (car (alist-ref s1 name->id-map)))
172                                (x0  (if (eq? s0 (second snode)) snex s0))
173                                (x1  (if (eq? s1 (second snode)) snex s1))
174                                (ij-expr  `(* ,(subst-convert x0 node-subs) ,(subst-convert rexpr1 node-subs)))
175                                (ji-expr  (and rexpr2
176                                               `(* ,(subst-convert x1 node-subs) ,(subst-convert rexpr2 node-subs)))))
177                           (add-edge! (list i j ij-expr))
178                           (if rexpr2 (add-edge! (list j i ji-expr)))))))
179      ;; create rate edges in the graph
180      (for-each (lambda (e) 
181                  (match e
182                         (('-> s0 s1 rexpr)  (add-tredge s0 s1 rexpr #f))
183                         ((s0 '-> s1 rexpr)  (add-tredge s0 s1 rexpr #f))
184                         (('<-> s0 s1 rexpr1 rexpr2)  (add-tredge s0 s1 rexpr1 rexpr2))
185                         ((s0 '<-> s1 rexpr1 rexpr2)  (add-tredge s0 s1 rexpr1 rexpr2))
186                         ))
187                transitions)
188
189      (list g node-subs))))
190
191
192(define (state-lineqs n transitions lineqs state-name)
193  (let* ((subst-convert  (subst-driver (lambda (x) (and (symbol? x) x)) 
194                                       nemo:binding? identity nemo:bind nemo:subst-term))
195         (state-list     (let loop ((lst (list)) (tlst transitions))
196                           (if (null? tlst)  (delete-duplicates lst eq?)
197                               (match (car tlst) 
198                                      (('-> (and (? symbol?) s0) (and (? symbol?) s1) rate-expr)
199                                       (loop (cons* s0 s1 lst) (cdr tlst)))
200                                      (((and (? symbol?) s0) '-> (and (? symbol? s1)) rate-expr)
201                                       (loop (cons* s0 s1 lst) (cdr tlst)))
202                                      (('<-> (and (? symbol?) s0) (and (? symbol?) s1) rate-expr1 rate-expr2)
203                                       (loop (cons* s0 s1 lst) (cdr tlst)))
204                                      (((and (? symbol?) s0) 'M-> (and (? symbol? s1)) rate-expr1 rate-expr2)
205                                       (loop (cons* s0 s1 lst) (cdr tlst)))
206                                      (else
207                                       (nemo:error 'nemo:state-lineq ": invalid transition equation " 
208                                                   (car tlst) " in state complex " n))
209                                      (else (loop lst (cdr tlst)))))))
210         (state-subs     (fold (lambda (s ax) (subst-extend s (state-name n s) ax)) subst-empty state-list))
211         (lineqs1        (map (lambda (lineq) (match lineq ((i '= . expr) `(,i = . ,(subst-convert expr state-subs)))))
212                              lineqs)))
213    (list n lineqs1)))
214
215;;    `(+ - * / pow neg abs atan asin acos sin cos exp ln
216;;      sqrt tan cosh sinh tanh hypot gamma lgamma log10 log2 log1p ldexp cube
217;;      > < <= >= = and or round ceiling floor max min
218;;      fpvector-ref))
219
220(define LOG10E 0.434294481903252)
221(define LOG2E  1.44269504088896)
222
223(define (differentiate x t)
224  (cond ((number? t)  0.0)
225        ((symbol? t)  (if (equal? x t) 1.0 0.0))
226        (else (match t
227                (('neg u)  `(neg ,(differentiate x u)))
228
229                (('+ u v)  `(+ ,(differentiate x u) ,(differentiate x v)))
230                (('- u v)  `(- ,(differentiate x u) ,(differentiate x v)))
231
232                (('* (and u (? number?)) v)    `(* ,u ,(differentiate x v)))
233                (('* v (and u (? number?)))    `(* ,u ,(differentiate x v)))
234
235                (('* u v)     `(+ (* ,(differentiate x u) ,v)
236                                  (* ,u ,(differentiate x v))))
237
238                (('/ u v)     `(/ (- (* ,(differentiate x u) ,v)
239                                     (* ,u ,(differentiate x v)))
240                                  (pow ,v 2.0)))
241
242                (('cube u)     (differentiate x `(pow ,u 3.0)))
243
244                (('pow u n)    (chain x u `(* ,n (pow ,u (- ,n 1.0)))))
245               
246                (('sqrt u)     (chain x u `(/ 1.0 (* 2.0 (sqrt ,u)))))
247
248                (('exp u)      (chain x u `(exp ,u)))
249
250                (('log a u)    (chain x u `(/ 1.0 ,u)))
251
252                (('log10 u)    (chain x u `(* ,LOG10E (/ ,(differentiate x u) ,u))))
253       
254                (('log2 u)     (chain x u `(* ,LOG2E (/ ,(differentiate x u) ,u))))
255
256                (('log1p u)    (differentiate x `(log (+ 1.0 ,u))))
257
258                (('ldexp u n)  (differentiate x `(* ,u ,(expt 2 n))))
259       
260                (('sin u)      (chain x u `(cos ,u)))
261               
262                (('cos u)      (chain x u `(neg (sin ,u))))
263
264                (('tan u)      (differentiate x `(* (sin ,u) (/ 1.0 (cos ,u)))))
265               
266                (('asin u)     (chain x u `(/ 1.0 (sqrt (- 1.0 (pow ,u 2.0))))))
267                                             
268                (('acos u)     (chain x u `(/ (neg 1.0) (sqrt (- 1.0 (pow ,u 2.0))))))
269                                             
270                (('atan u)     (chain x u `(/ 1.0 (+ 1.0 (pow ,u 2.0)))))
271                                             
272                (('sinh u)     (differentiate x `(/ (- (exp ,u) (exp (neg ,u))) 2.0)))
273
274                (('cosh u)     (differentiate x `(/ (+ (exp ,u) (exp (neg ,u))) 2.0)))
275               
276                (('tanh u)     (differentiate x `(/ (sinh ,u) (cosh ,u))))
277
278                (else          #f)))))
279
280(define (chain x t u)
281  (if (symbol? t) u
282      `(* ,(differentiate x t) ,u)))
283       
284(define (simplify t)
285  (match t
286         (('+ 0.0 t1)  t1)
287         (('+ t1 0.0)  t1)
288         (('+ t1 ('neg t2))  `(- ,t1 ,t2))
289         
290         (('- 0.0 t1)  `(neg ,t1))
291         (('- t1 0.0)  t1)
292         (('neg ('neg t1))  t1)
293
294         (('* 0.0 t1)  0.0)
295         (('* t1 0.0)  0.0)
296         (('* 1.0 t1)  t1)
297         (('* t1 1.0)  t1)
298         (('* ('neg t1) ('neg t2))  `(* ,t1 ,t2))
299
300         (else t)))
301
302)
Note: See TracBrowser for help on using the repository browser.