source: project/release/4/9ML-toolkit/trunk/templates/Network.sml.tmpl @ 30895

Last change on this file since 30895 was 30895, checked in by Ivan Raikov, 6 years ago

9ML-toolkit: updates to network code generation template

File size: 20.1 KB
Line 
1
2structure {{group.name}} =
3struct
4
5  fun putStrLn out str =
6      (TextIO.output (out, str);
7       TextIO.output (out, "\n"))
8   
9  fun putStr out str =
10      (TextIO.output (out, str))
11     
12  fun showBoolean b = (if b then "1" else "0")
13                     
14  fun showReal n =
15      let open StringCvt
16          open Real
17      in
18          (if n < 0.0 then "-" else "") ^ (fmt (FIX (SOME 12)) (abs n))
19      end
20     
21  fun foldl1 f lst = let val v = List.hd lst
22                         val lst' = List.tl lst
23                     in
24                         List.foldl f v lst'
25                     end
26
27  fun fromDiag (m, n, a, dflt) =
28      if Index.validShape [m,n]
29      then
30          (let
31               val na  = RTensor.Array.length a
32               val na' = na-1
33               val te  = RTensor.new ([m,n], dflt)
34               fun diag (i, j, ia) =
35                   let
36                       val ia' =
37                           (RTensor.update (te, [i,j], RTensor.Array.sub (a, ia));
38                            if ia = na' then 0 else ia+1)
39                   in
40                       if (i=0) orelse (j=0)
41                       then te
42                       else diag (i-1, j-1, ia)
43                   end
44           in
45               diag (m-1, n-1, 0)
46           end)
47      else
48          raise RTensor.Shape
49
50  val RandomInit = RandomMTZig.fromEntropy
51
52  val ZigInit = RandomMTZig.ztnew
53       
54  exception Index       
55
56  val label = "{{group.name}}"
57           
58  val N = {{group.order}}     (* total population size *)
59
60
61  {% for p in dict (group.properties) %}
62  val {{p.name}} = {{p.value.exprML}}
63  {% endfor %}
64
65  {% with timestep = default(group.properties.timestep.exprML, 0.1) %}
66  val h = {{ timestep }}
67  {% endwith %}
68
69  (* network propagation delay *)
70  {% if group.properties.delay %}
71  val D: real = {{group.properties.delay.exprML}}
72  {% else %}
73  {% if group.properties.delayMatrix %}
74  val D: SparseMatrix.matrix = {{group.properties.delayMatrix.exprML}}
75  {% endif %}
76  {% endif %}
77
78  structure TEQ = TEventQueue (val h = h)
79  val DQ = TEQ.empty
80
81  val seed_init = RandomInit() (* seed for randomized initial values *)
82  val zt_init   = ZigInit()
83  fun random_normal () = RandomMTZig.randNormal(seed_init,zt_init)
84  fun random_uniform () = RandomMTZig.randUniform(seed_init)
85
86{% for pop in dict (group.populations) %}
87
88  val N_{{pop.name}} = {{ pop.value.size }}
89
90  val {{pop.name}}_initial = {{pop.value.prototype.initialExprML}}
91
92{% if pop.value.prototype.fieldExprML %}
93  val {{pop.name}}_field_vector =
94    Vector.tabulate (N_{{pop.name}}, fn (i) =>  {{pop.value.prototype.fieldExprML}})
95{% endif %}
96
97  val {{pop.name}}_initial_vector =
98    Vector.tabulate (N_{{pop.name}}, fn (i) =>  {{pop.value.prototype.initialStateExprML}})
99
100  val {{pop.name}}_f = Model_{{pop.name}}.{{pop.value.prototype.ivpFn}}
101
102  fun {{pop.name}}_run (Wnet,n0) (i,input as { {{ join (",", pop.value.prototype.states) }} }) =
103    let
104        val initial = {{pop.name}}_initial
105{% if pop.value.prototype.fieldExprML %}
106        val fieldV = Vector.sub ({{pop.name}}_field_vector,i)
107{% endif %}
108        val Isyn_i  = case Wnet of SOME W => RTensor.sub(W,[i+n0,0]) | NONE => 0.0
109        (*val _ = putStrLn TextIO.stdOut ("# {{pop.name}}: t = " ^ (showReal t) ^ " Isyn_i = " ^ (showReal Isyn_i) ^ " V = " ^ (showReal V))*)
110        val nstate = {{pop.name}}_f {{ pop.value.prototype.updateStateML }}
111        val nstate' = {{ pop.value.prototype.copyStateML }}
112    in
113        nstate'
114    end
115
116{% endfor %}
117
118
119{% if group.plastypes %}{% for pl in dict (group.plastypes) %}
120  val {{pl.name}}_initial = {{pl.value.initialExprML}}
121{% endfor %}{% endif %}
122
123
124{% if group.psrtypes %}{% for psr in dict (group.psrtypes) %}
125  val {{psr.name}}_initial = {{psr.value.initialExprML}}
126
127  val {{psr.name}}_initial_vector = Vector.tabulate ({{psr.value.range}},
128                                                     fn (i) => {{psr.value.initialStateExprML}})
129
130  val {{psr.name}}_f = Model_{{psr.name}}.{{psr.value.ivpFn}}
131
132  fun {{psr.name}}_response (W,T) (i,input as { {{ join (",", psr.value.states) }} }) =
133    let
134        val initial   = {{psr.name}}_initial
135        val Ispike_i  = RTensor.sub(W,[i,0])
136        val spike_i   = Real.!= (Ispike_i, 0.0)
137        val tspike_i  = RTensor.sub(T,[i,0])
138        (*val _ = putStrLn TextIO.stdOut ("# {{psr.name}}: t = " ^ (showReal t) ^ " Ispike_i = " ^ (showReal Ispike_i)*)
139        val nstate  = {{psr.name}}_f {{ psr.value.updateStateML }}
140        val nstate' = {{ psr.value.copyStateML }}
141    in
142        RTensor.update(W,[i,0],Real.+(RTensor.sub(W,[i,0]),(#Isyn nstate)));
143        nstate'
144    end
145{% endfor %}{% endif %}
146
147
148{% if group.conntypes %}{% for conn in dict (group.conntypes) %}
149  val {{conn.name}}_initial = {{conn.value.initialExprML}}
150
151  val {{conn.name}}_f = Model_{{conn.name}}.{{conn.value.sysFn}}
152{% endfor %}{% endif %}
153
154    val initial = (
155        {% for pop in dict (group.populations) %}
156        {{pop.name}}_initial_vector{% if not loop.last %},{% endif %}
157        {% endfor %}
158    )
159
160    val psr_initial = (
161        {% if group.psrtypes %}{% for psr in dict (group.psrtypes) %}
162        {{psr.name}}_initial_vector{% if not loop.last %},{% endif %}
163        {% endfor %}{% endif %}
164    )
165
166
167    {% for pop in dict (group.populations) %}
168    val {{pop.name}}_n0 = {{pop.value.start}}
169    {% endfor %}
170
171               
172    val Pn = [
173        {% for pop in dict (group.populations) %}
174        {{pop.name}}_n0{% if not loop.last %},{% endif %}
175        {% endfor %}
176    ]
177         
178    fun frun I
179             (
180              {% for pop in dict (group.populations) %}
181              {{pop.name}}_state_vector{% if not loop.last %},{% endif %}
182              {% endfor %} ) =
183        let
184            {% for pop in dict (group.populations) %}
185            val {{pop.name}}_state_vector' =
186                Vector.mapi ({{pop.name}}_run (I,{{pop.name}}_n0))
187                            {{pop.name}}_state_vector
188
189            {% endfor %}
190        in
191            (
192             {% for pop in dict (group.populations) %}
193             {{pop.name}}_state_vector'{% if not loop.last %},{% endif %}
194             {% endfor %}
195            )
196        end
197
198                         
199    fun fresponse I
200             (
201              {% if group.psrtypes %}{% for psr in dict (group.psrtypes) %}
202              {{psr.name}}_state_vector{% if not loop.last %},{% endif %}
203              {% endfor %}{% endif %}
204             ) =
205        let
206            {% if group.psrtypes %}
207            val (W,T) = case I of SOME (W,T) => (W,T)
208                             | NONE => (RTensor.new ([N,1],0.0),
209                                        RTensor.new ([N,1],0.0))
210            {% for psr in dict (group.psrtypes) %}
211            val {{psr.name}}_state_vector' =
212                Vector.mapi ({{psr.name}}_response (W,T))
213                            {{psr.name}}_state_vector
214
215            {% endfor %}{% else %}
216            val W = case I of SOME (W,T) => W
217                            | NONE => RTensor.new ([N,1],0.0)
218            {% endif %}
219        in
220            (SOME W,
221             (
222              {% if group.psrtypes %}{% for psr in dict (group.psrtypes) %}
223              {{psr.name}}_state_vector'{% if not loop.last %},{% endif %}
224              {% endfor %}{% endif %}
225             ))
226        end
227
228
229    fun felec E I
230              ({% for pop in dict (group.populations) %}
231              {{pop.name}}_state_vector{% if not loop.last %},{% endif %}
232              {% endfor %}) =
233        case E of NONE => I
234                | SOME E =>
235                  let
236                    val update = Unsafe.Real64Array.update
237                    val I' = case I of SOME I => I
238                                     | NONE => RTensor.new ([N,1],0.0)
239                   
240                  {% for pr in dict (group.projections) %}
241                  {% if pr.value.type == "continuous" %}
242                  {% for spop in pr.value.source %}
243                  {% for tpop in pr.value.target %}
244                    fun {{spop.name}}_sub i = #({{first (spop.value.prototype.states)}})(Vector.sub ({{spop.name}}_state_vector, i))
245                    fun {{tpop.name}}_sub i = #({{first (tpop.value.prototype.states)}})(Vector.sub ({{tpop.name}}_state_vector, i))
246                    val _ = Loop.app
247                                (0, N_{{spop.name}},
248                                 fn (i) =>
249                                    let
250                                        val Vi = {{spop.name}}_sub i
251                                        val sl = SparseMatrix.slice (#{{spop.name}}(E),1,i)
252                                    in
253                                        SparseMatrix.sliceAppi
254                                            (fn (j,g) => let val Vj = {{tpop.name}}_sub j
255                                                         in update (I,i,Real.- (sub(I,i), Real.* (g,Real.- (Vi,Vj)))) end)
256                                            sl
257                                    end)
258                               
259                  {% endfor %}
260                  {% endfor %}
261                  {% endif %}
262                  {% endfor %}
263                  in
264                      SOME I'
265                  end
266                         
267                         
268    fun ftime (
269              {% for pop in dict (group.populations) %}
270              {{pop.name}}_state_vector{% if not loop.last %},{% endif %}
271              {% endfor %} ) =
272       
273        {% with pop = first (dict (group.populations)) %}
274        let val { {{ join (",", pop.value.prototype.states) }} } = Vector.sub ({{pop.name}}_state_vector,0)
275        in {{pop.value.prototype.ivar}} end
276        {% endwith %}
277   
278       
279    fun fspikes (
280              {% for pop in dict (group.populations) %}
281              {{pop.name}}_state_vector{% if not loop.last %},{% endif %}
282              {% endfor %} ) =
283
284        let
285           {% for pop in dict (group.populations) %}
286            val {{pop.name}}_spike_i =
287                Vector.foldri (fn (i,v as { {{ join (",", pop.value.prototype.states) }} },ax) =>
288                               {% if not pop.name in group.spikepoplst %}
289                               (if (#{{first (pop.value.prototype.events)}}(v)) then ((i+{{pop.name}}_n0,#{{first (pop.value.prototype.events)}}Count(v)))::ax else ax))
290                               {% else %}
291                               (if (#{{first (pop.value.prototype.events)}}(v)) then ((i+{{pop.name}}_n0,1.0))::ax else ax))
292                               {% endif %}
293                              [] {{pop.name}}_state_vector
294           {% endfor %}
295
296           val ext_spike_i = List.concat ( {% for pop in dict (group.populations) %}{% if not pop.name in group.spikepoplst %}{{pop.name}}_spike_i ::{% endif %}{% endfor %} [] )
297
298               
299           val neuron_spike_i = List.concat [
300                                 {% for name in (group.spikepoplst) %}
301                                 {{name}}_spike_i{% if not loop.last %},{% endif %}
302                                 {% endfor %}
303                                 ]
304
305            val all_spike_i    = List.concat [neuron_spike_i, ext_spike_i]
306        in
307            (all_spike_i, neuron_spike_i)
308        end
309
310{% macro cartesian_product(sp, tp) %}
311   {% for s in sp %}
312   {% for t in tp %}{{ caller(s,t) }}{% if not loop.last %},{% endif %}{% endfor %}{% if not loop.last %},{% endif %}
313   {% endfor %}
314{% endmacro %}
315
316{% macro for_each(name, sp, tp, plasticity, component, cstate) %}
317             val Pr_{{name}} = let
318                                  val weight  = {% if plasticity %}#weight({{plasticity}}_initial){% else %}1.0{% endif %}
319                               in
320                                  SparseMatrix.fromGeneratorList [N,N]
321                                  [
322                                    {% call cartesian_product (sp,tp) %}
323                                      {offset=[{{t.start}},{{s.start}}],
324                                       fshape=[{{t.size}},{{s.size}}],
325                                       f=(fn (i) => Real.* (weight, #{{cstate}}({{component}}_f {{component}}_initial) ))}
326                                    {% endcall %}
327                                  ]
328                               end
329{% endmacro %}
330
331
332{% macro all_to_all(name, sp, tp, plasticity)  %}
333             val Pr_{{name}} = let
334                                  val weight = {% if plasticity %}#weight({{plasticity}}_initial){% else %}1.0{% endif %}
335                               in
336                                  SparseMatrix.fromTensorList [N,N]
337                                  [
338                                    {% call cartesian_product (sp,tp) %}
339                                    {offset=[{{t.start}},{{s.start}}],
340                                     tensor=(RTensor.*> weight (RTensor.new ([{{t.size}},{{s.size}}],1.0))),
341                                     sparse=false}
342                                    {% endcall %}
343                                  ]
344                                end
345{% endmacro %}
346
347{% macro one_to_one(name, sp, tp, plasticity) %}
348             val Pr_{{name}} = let
349                                  val weight = {% if plasticity %}#weight({{plasticity}}_initial){% else %}1.0{% endif %}
350                               in
351                                  SparseMatrix.fromTensorList [N,N]
352                                  [
353                                    {% call cartesian_product (sp,tp) %}
354                                    {offset=[{{t.start}},{{s.start}}],
355                                     tensor=(fromDiag ({{t.size}},{{s.size}},Real64Array.fromList [{{weight}}],0.0)),
356                                     sparse=true}
357                                    {% endcall %}
358                                  ]
359                               end
360{% endmacro %}
361
362
363{% macro from_file(name, sp, tp, filename) %}
364             val Pr_{{name}} = let val infile = TextIO.openIn "{{filename}}"
365                                   val S = TensorFile.realTensorRead (infile)
366                                   val _ = TextIO.closeIn infile
367                               in
368                                   SparseMatrix.fromTensorSliceList [N,N]
369                                   [
370                                     {% with %}
371                                     {% set soffset = 0 %}
372                                     {% for s in sp %}
373                                     {% set toffset = 0 %}
374                                     {% for t in tp %}
375                                     {offset=[{{t.start}},{{s.start}}],
376                                      slice=(RTensorSlice.slice ([([{{toffset}},{{soffset}}],[{{toffset}}+{{t.size}}-1,{{soffset}}+{{s.size}}-1])],S)),
377                                      sparse=false}{% if not loop.last %},{% endif %}
378                                     {% set toffset = toffset + t.size %}
379                                     {% endfor %}{% if not loop.last %},{% endif %}
380                                     {% set soffset = soffset + s.size %}
381                                     {% endfor %}
382                                     {% endwith %}
383                                  ]
384                                end
385{% endmacro %}
386
387
388{% macro range_map(name, sp, tp) %}
389             val srangemap_{{name}} =
390                                    [
391                                     {% with %}
392                                     {% set soffset = 0 %}
393                                     {% for s in sp %}
394                                     {size={{s.size}}
395                                      localStart={{soffset}},
396                                      globalStart={{s.start}} }{% if not loop.last %},{% endif %}
397                                     {% set soffset = soffset + s.size %}
398                                     {% endfor %}
399                                     {% endwith %}
400                                    ]
401             val trangemap_{{name}} =
402                                    [
403                                     {% with %}
404                                     {% set toffset = 0 %}
405                                     {% for t in tp %}
406                                     {size={{t.size}}
407                                      localStart={{toffset}},
408                                      globalStart={{t.start}} }{% if not loop.last %},{% endif %}
409                                     {% set toffset = toffset + t.size %}
410                                     {% endfor %}
411                                     {% endwith %}
412                                    ]
413{% endmacro %}
414           
415           
416    fun fprojection () =
417       
418        (let
419             
420             {% for pr in dict (group.projections) %}
421             val _ = putStrLn TextIO.stdOut "constructing {{pr.name}}"
422
423             {% if pr.value.rule.operator == "for-each" %}
424             {% call for_each(pr.name,
425                              pr.value.source.populations,
426                              pr.value.target.populations,
427                              pr.value.plasticity,
428                              pr.value.rule.component,
429                              pr.value.rule.cstate) %}
430             {% endcall %}
431             {% else %}
432             {% if pr.value.rule.operator == "one-to-one" %}
433             {% call one_to_one(pr.name,
434                                pr.value.source.populations,
435                                pr.value.target.populations,
436                                pr.value.plasticity) %}
437             {% endcall %}
438             {% else %}
439             {% if pr.value.rule.operator == "all-to-all" %}
440             {% call all_to_all(pr.name,
441                                pr.value.source.populations,
442                                pr.value.target.populations,
443                                pr.value.plasticity) %}
444             {% endcall %}
445             {% else %}
446             {% if pr.value.rule.operator == "from-file" %}
447             {% call from_file(pr.name,
448                               pr.value.source.populations,
449                               pr.value.target.populations,
450                               pr.value.rule.properties.filename.exprML) %}
451             {% endcall %}
452             {% endif %}
453             {% endif %}
454             {% endif %}
455             {% endif %}
456             {% endfor %}
457               
458       
459{% if group.psrtypes %}{% for psr in dict (group.psrtypes) %}
460{% if psr.value.type == "event" %}
461             val S_{{psr.name}} = foldl1 SparseMatrix.insert
462                                  ([
463                                    {% for pr in psr.value.projections %}
464                                    Pr_{{pr}}{% if not loop.last %},{% endif %}
465                                    {% endfor %}
466                                   ])
467{% endif %}
468{% endfor %}{% else %}
469             val S = foldl1 SparseMatrix.insert
470                     ([
471                       {% for pr in dict (group.projections) %}
472                       {% if pr.value.type == "event" %}
473                       Pr_{{pr.name}}{% if not loop.last %},{% endif %}
474                       {% endif %}
475                       {% endfor %}
476                     ])
477{% endif %}
478
479             {% for pr in dict (group.projections) %}
480             {% if pr.value.type == "continuous" %}
481             {% call range_map(pr.name,
482                               pr.value.source.populations,
483                               pr.value.target.populations) %}
484             {% endcall %}
485             {% endif %}
486             {% endfor %}
487
488             val Elst = 
489                      [
490                        {% for pr in dict (group.projections) %}
491                        {% if pr.value.type == "continuous" %}
492                          ElecGraph.junctionMatrix ([N,N],ElecGraph.elecGraph ({{pr.name}}(srangemap_{{pr.name}},trangemap_{{pr.name}}))
493                          Pr_{{pr.name}}){% if not loop.last %},{% endif %}
494                        {% endif %}
495                        {% endfor %}
496                      ]
497
498             val E = if List.null Elst then NONE else SOME Elst
499
500             in
501              ([
502{% if group.psrtypes %}{% for psr in dict (group.psrtypes) %}
503               S_{{ psr.name }}{% if not loop.last %},{% endif %}
504{% endfor %}{% else %}
505               S
506{% endif %}
507              ])
508             end)
509
510
511end
512       
Note: See TracBrowser for help on using the repository browser.