source: project/release/4/9ML-toolkit/trunk/templates/Network.sml.tmpl @ 30902

Last change on this file since 30902 was 30902, checked in by Ivan Raikov, 6 years ago

9ML-toolkit: bugfixes related to multiple synapse handling in code generation templates

File size: 20.2 KB
Line 
1
2structure {{group.name}} =
3struct
4
5  fun putStrLn out str =
6      (TextIO.output (out, str);
7       TextIO.output (out, "\n"))
8   
9  fun putStr out str =
10      (TextIO.output (out, str))
11     
12  fun showBoolean b = (if b then "1" else "0")
13                     
14  fun showReal n =
15      let open StringCvt
16          open Real
17      in
18          (if n < 0.0 then "-" else "") ^ (fmt (FIX (SOME 12)) (abs n))
19      end
20     
21  fun foldl1 f lst = let val v = List.hd lst
22                         val lst' = List.tl lst
23                     in
24                         List.foldl f v lst'
25                     end
26
27  fun fromDiag (m, n, a, dflt) =
28      if Index.validShape [m,n]
29      then
30          (let
31               val na  = RTensor.Array.length a
32               val na' = na-1
33               val te  = RTensor.new ([m,n], dflt)
34               fun diag (i, j, ia) =
35                   let
36                       val ia' =
37                           (RTensor.update (te, [i,j], RTensor.Array.sub (a, ia));
38                            if ia = na' then 0 else ia+1)
39                   in
40                       if (i=0) orelse (j=0)
41                       then te
42                       else diag (i-1, j-1, ia)
43                   end
44           in
45               diag (m-1, n-1, 0)
46           end)
47      else
48          raise RTensor.Shape
49
50  val RandomInit = RandomMTZig.fromEntropy
51
52  val ZigInit = RandomMTZig.ztnew
53       
54  exception Index       
55
56  val label = "{{group.name}}"
57           
58  val N = {{group.order}}     (* total population size *)
59
60
61  {% for p in dict (group.properties) %}
62  val {{p.name}} = {{p.value.exprML}}
63  {% endfor %}
64
65  {% with timestep = default(group.properties.timestep.exprML, 0.1) %}
66  val h = {{ timestep }}
67  {% endwith %}
68
69  (* network propagation delay *)
70  {% if group.properties.delay %}
71  val D: real = {{group.properties.delay.exprML}}
72  {% else %}
73  {% if group.properties.delayMatrix %}
74  val D: SparseMatrix.matrix = {{group.properties.delayMatrix.exprML}}
75  {% endif %}
76  {% endif %}
77
78  structure TEQ = TEventQueue (val h = h)
79  val DQ = TEQ.empty
80
81  val seed_init = RandomInit() (* seed for randomized initial values *)
82  val zt_init   = ZigInit()
83  fun random_normal () = RandomMTZig.randNormal(seed_init,zt_init)
84  fun random_uniform () = RandomMTZig.randUniform(seed_init)
85
86{% for pop in dict (group.populations) %}
87
88  val N_{{pop.name}} = {{ pop.value.size }}
89
90  val {{pop.name}}_initial = {{pop.value.prototype.initialExprML}}
91
92{% if pop.value.prototype.fieldExprML %}
93  val {{pop.name}}_field_vector =
94    Vector.tabulate (N_{{pop.name}}, fn (i) =>  {{pop.value.prototype.fieldExprML}})
95{% endif %}
96
97  val {{pop.name}}_initial_vector =
98    Vector.tabulate (N_{{pop.name}}, fn (i) =>  {{pop.value.prototype.initialStateExprML}})
99
100  val {{pop.name}}_f = Model_{{pop.name}}.{{pop.value.prototype.ivpFn}}
101
102  fun {{pop.name}}_run (Wnet,n0) (i,input as { {{ join (",", pop.value.prototype.states) }} }) =
103    let
104        val initial = {{pop.name}}_initial
105{% if pop.value.prototype.fieldExprML %}
106        val fieldV = Vector.sub ({{pop.name}}_field_vector,i)
107{% endif %}
108        val Isyn_i  = case Wnet of SOME W => RTensor.sub(W,[i+n0,0]) | NONE => 0.0
109        (*val _ = putStrLn TextIO.stdOut ("# {{pop.name}}: t = " ^ (showReal t) ^ " Isyn_i = " ^ (showReal Isyn_i) ^ " V = " ^ (showReal V))*)
110        val nstate = {{pop.name}}_f {{ pop.value.prototype.updateStateML }}
111        val nstate' = {{ pop.value.prototype.copyStateML }}
112    in
113        nstate'
114    end
115
116{% endfor %}
117
118
119{% if group.plastypes %}{% for pl in dict (group.plastypes) %}
120  val {{pl.name}}_initial = {{pl.value.initialExprML}}
121{% endfor %}{% endif %}
122
123
124{% if group.psrtypes %}{% for psr in dict (group.psrtypes) %}
125  val {{psr.name}}_initial = {{psr.value.initialExprML}}
126
127  val {{psr.name}}_initial_vector = Vector.tabulate ({{psr.value.range}},
128                                                     fn (i) => {{psr.value.initialStateExprML}})
129
130  val {{psr.name}}_f = Model_{{psr.name}}.{{psr.value.ivpFn}}
131
132  fun {{psr.name}}_response (W,T) (i,input as { {{ join (",", psr.value.states) }} }) =
133    let
134        val initial   = {{psr.name}}_initial
135        val Ispike_i  = RTensor.sub(W,[i,0])
136        val spike_i   = Real.!= (Ispike_i, 0.0)
137        val tspike_i  = RTensor.sub(T,[i,0])
138        (*val _ = putStrLn TextIO.stdOut ("# {{psr.name}}: t = " ^ (showReal t) ^ " Ispike_i = " ^ (showReal Ispike_i)*)
139        val nstate  = {{psr.name}}_f {{ psr.value.updateStateML }}
140        val nstate' = {{ psr.value.copyStateML }}
141    in
142        RTensor.update(W,[i,0],Real.+(RTensor.sub(W,[i,0]),(#Isyn nstate)));
143        nstate'
144    end
145{% endfor %}{% endif %}
146
147
148{% if group.conntypes %}{% for conn in dict (group.conntypes) %}
149  val {{conn.name}}_initial = {{conn.value.initialExprML}}
150
151  val {{conn.name}}_f = Model_{{conn.name}}.{{conn.value.sysFn}}
152{% endfor %}{% endif %}
153
154    val initial = (
155        {% for pop in dict (group.populations) %}
156        {{pop.name}}_initial_vector{% if not loop.last %},{% endif %}
157        {% endfor %}
158    )
159
160    val psr_initial = (
161        {% if group.psrtypes %}{% for psr in dict (group.psrtypes) %}
162        {{psr.name}}_initial_vector{% if not loop.last %},{% endif %}
163        {% endfor %}{% endif %}
164    )
165
166
167    {% for pop in dict (group.populations) %}
168    val {{pop.name}}_n0 = {{pop.value.start}}
169    {% endfor %}
170
171               
172    val Pn = [
173        {% for pop in dict (group.populations) %}
174        {{pop.name}}_n0{% if not loop.last %},{% endif %}
175        {% endfor %}
176    ]
177         
178    fun frun I
179             (
180              {% for pop in dict (group.populations) %}
181              {{pop.name}}_state_vector{% if not loop.last %},{% endif %}
182              {% endfor %} ) =
183        let
184            {% for pop in dict (group.populations) %}
185            val {{pop.name}}_state_vector' =
186                Vector.mapi ({{pop.name}}_run (I,{{pop.name}}_n0))
187                            {{pop.name}}_state_vector
188
189            {% endfor %}
190        in
191            (
192             {% for pop in dict (group.populations) %}
193             {{pop.name}}_state_vector'{% if not loop.last %},{% endif %}
194             {% endfor %}
195            )
196        end
197
198                         
199    fun fresponse I
200             (
201              {% if group.psrtypes %}{% for psr in dict (group.psrtypes) %}
202              {{psr.name}}_state_vector{% if not loop.last %},{% endif %}
203              {% endfor %}{% endif %}
204             ) =
205        let
206            {% if group.psrtypes %}
207            val (T,W) = case I of SOME I' => List.nth (I',{{loop.index0}})
208                                | NONE => (RTensor.new ([N,1],0.0),
209                                           List.tabulate ({{length (group.psrtypes)}}, fn (i) => (RTensor.new ([N,1],0.0))))
210            {% for psr in dict (group.psrtypes) %}
211            val W' = List.nth (W,{{loop.index0}})
212            val {{psr.name}}_state_vector' =
213                Vector.mapi ({{psr.name}}_response (W',T)) {{psr.name}}_state_vector
214            {% endfor %}{% else %}
215            val W = case I of SOME (W,T) => W
216                            | NONE => RTensor.new ([N,1],0.0)
217            {% endif %}
218        in
219            (SOME W,
220             (
221              {% if group.psrtypes %}{% for psr in dict (group.psrtypes) %}
222              {{psr.name}}_state_vector'{% if not loop.last %},{% endif %}
223              {% endfor %}{% endif %}
224             ))
225        end
226
227
228    fun felec E I
229              ({% for pop in dict (group.populations) %}
230              {{pop.name}}_state_vector{% if not loop.last %},{% endif %}
231              {% endfor %}) =
232        case E of NONE => I
233                | SOME E =>
234                  let
235                    val update = Unsafe.Real64Array.update
236                    val I' = case I of SOME I => I
237                                     | NONE => RTensor.new ([N,1],0.0)
238                   
239                  {% for pr in dict (group.projections) %}
240                  {% if pr.value.type == "continuous" %}
241                  {% for spop in pr.value.source %}
242                  {% for tpop in pr.value.target %}
243                    fun {{spop.name}}_sub i = #({{first (spop.value.prototype.states)}})(Vector.sub ({{spop.name}}_state_vector, i))
244                    fun {{tpop.name}}_sub i = #({{first (tpop.value.prototype.states)}})(Vector.sub ({{tpop.name}}_state_vector, i))
245                    val _ = Loop.app
246                                (0, N_{{spop.name}},
247                                 fn (i) =>
248                                    let
249                                        val Vi = {{spop.name}}_sub i
250                                        val sl = SparseMatrix.slice (#{{spop.name}}(E),1,i)
251                                    in
252                                        SparseMatrix.sliceAppi
253                                            (fn (j,g) => let val Vj = {{tpop.name}}_sub j
254                                                         in update (I,i,Real.- (sub(I,i), Real.* (g,Real.- (Vi,Vj)))) end)
255                                            sl
256                                    end)
257                               
258                  {% endfor %}
259                  {% endfor %}
260                  {% endif %}
261                  {% endfor %}
262                  in
263                      SOME I'
264                  end
265                         
266                         
267    fun ftime (
268              {% for pop in dict (group.populations) %}
269              {{pop.name}}_state_vector{% if not loop.last %},{% endif %}
270              {% endfor %} ) =
271       
272        {% with pop = first (dict (group.populations)) %}
273        let val { {{ join (",", pop.value.prototype.states) }} } = Vector.sub ({{pop.name}}_state_vector,0)
274        in {{pop.value.prototype.ivar}} end
275        {% endwith %}
276   
277       
278    fun fspikes (
279              {% for pop in dict (group.populations) %}
280              {{pop.name}}_state_vector{% if not loop.last %},{% endif %}
281              {% endfor %} ) =
282
283        let
284           {% for pop in dict (group.populations) %}
285            val {{pop.name}}_spike_i =
286                Vector.foldri (fn (i,v as { {{ join (",", pop.value.prototype.states) }} },ax) =>
287                               {% if not pop.name in group.spikepoplst %}
288                               (if (#{{first (pop.value.prototype.events)}}(v)) then ((i+{{pop.name}}_n0,#{{first (pop.value.prototype.events)}}Count(v)))::ax else ax))
289                               {% else %}
290                               (if (#{{first (pop.value.prototype.events)}}(v)) then ((i+{{pop.name}}_n0,1.0))::ax else ax))
291                               {% endif %}
292                              [] {{pop.name}}_state_vector
293           {% endfor %}
294
295           val ext_spike_i = List.concat ( {% for pop in dict (group.populations) %}{% if not pop.name in group.spikepoplst %}{{pop.name}}_spike_i ::{% endif %}{% endfor %} [] )
296
297               
298           val neuron_spike_i = List.concat [
299                                 {% for name in (group.spikepoplst) %}
300                                 {{name}}_spike_i{% if not loop.last %},{% endif %}
301                                 {% endfor %}
302                                 ]
303
304            val all_spike_i    = List.concat [neuron_spike_i, ext_spike_i]
305        in
306            (all_spike_i, neuron_spike_i)
307        end
308
309{% macro cartesian_product(sp, tp) %}
310   {% for s in sp %}
311   {% for t in tp %}{{ caller(s,t) }}{% if not loop.last %},{% endif %}{% endfor %}{% if not loop.last %},{% endif %}
312   {% endfor %}
313{% endmacro %}
314
315{% macro for_each(name, sp, tp, plasticity, component, cstate) %}
316             val Pr_{{name}} = let
317                                  val weight  = {% if plasticity %}#weight({{plasticity}}_initial){% else %}1.0{% endif %}
318                               in
319                                  SparseMatrix.fromGeneratorList [N,N]
320                                  [
321                                    {% call cartesian_product (sp,tp) %}
322                                      {offset=[{{t.start}},{{s.start}}],
323                                       fshape=[{{t.size}},{{s.size}}],
324                                       f=(fn (i) => Real.* (weight, #{{cstate}}({{component}}_f {{component}}_initial) ))}
325                                    {% endcall %}
326                                  ]
327                               end
328{% endmacro %}
329
330
331{% macro all_to_all(name, sp, tp, plasticity)  %}
332             val Pr_{{name}} = let
333                                  val weight = {% if plasticity %}#weight({{plasticity}}_initial){% else %}1.0{% endif %}
334                               in
335                                  SparseMatrix.fromTensorList [N,N]
336                                  [
337                                    {% call cartesian_product (sp,tp) %}
338                                    {offset=[{{t.start}},{{s.start}}],
339                                     tensor=(RTensor.*> weight (RTensor.new ([{{t.size}},{{s.size}}],1.0))),
340                                     sparse=false}
341                                    {% endcall %}
342                                  ]
343                                end
344{% endmacro %}
345
346{% macro one_to_one(name, sp, tp, plasticity) %}
347             val Pr_{{name}} = let
348                                  val weight = {% if plasticity %}#weight({{plasticity}}_initial){% else %}1.0{% endif %}
349                               in
350                                  SparseMatrix.fromTensorList [N,N]
351                                  [
352                                    {% call cartesian_product (sp,tp) %}
353                                    {offset=[{{t.start}},{{s.start}}],
354                                     tensor=(fromDiag ({{t.size}},{{s.size}},Real64Array.fromList [{{weight}}],0.0)),
355                                     sparse=true}
356                                    {% endcall %}
357                                  ]
358                               end
359{% endmacro %}
360
361
362{% macro from_file(name, sp, tp, filename) %}
363             val Pr_{{name}} = let val infile = TextIO.openIn "{{filename}}"
364                                   val S = TensorFile.realTensorRead (infile)
365                                   val _ = TextIO.closeIn infile
366                               in
367                                   SparseMatrix.fromTensorSliceList [N,N]
368                                   [
369                                     {% with %}
370                                     {% set soffset = 0 %}
371                                     {% for s in sp %}
372                                     {% set toffset = 0 %}
373                                     {% for t in tp %}
374                                     {offset=[{{t.start}},{{s.start}}],
375                                      slice=(RTensorSlice.slice ([([{{toffset}},{{soffset}}],[{{toffset}}+{{t.size}}-1,{{soffset}}+{{s.size}}-1])],S)),
376                                      sparse=false}{% if not loop.last %},{% endif %}
377                                     {% set toffset = toffset + t.size %}
378                                     {% endfor %}{% if not loop.last %},{% endif %}
379                                     {% set soffset = soffset + s.size %}
380                                     {% endfor %}
381                                     {% endwith %}
382                                  ]
383                                end
384{% endmacro %}
385
386
387{% macro range_map(name, sp, tp) %}
388             val srangemap_{{name}} =
389                                    [
390                                     {% with %}
391                                     {% set soffset = 0 %}
392                                     {% for s in sp %}
393                                     {size={{s.size}}
394                                      localStart={{soffset}},
395                                      globalStart={{s.start}} }{% if not loop.last %},{% endif %}
396                                     {% set soffset = soffset + s.size %}
397                                     {% endfor %}
398                                     {% endwith %}
399                                    ]
400             val trangemap_{{name}} =
401                                    [
402                                     {% with %}
403                                     {% set toffset = 0 %}
404                                     {% for t in tp %}
405                                     {size={{t.size}}
406                                      localStart={{toffset}},
407                                      globalStart={{t.start}} }{% if not loop.last %},{% endif %}
408                                     {% set toffset = toffset + t.size %}
409                                     {% endfor %}
410                                     {% endwith %}
411                                    ]
412{% endmacro %}
413           
414           
415    fun fprojection () =
416       
417        (let
418             
419             {% for pr in dict (group.projections) %}
420             val _ = putStrLn TextIO.stdOut "constructing {{pr.name}}"
421
422             {% if pr.value.rule.operator == "for-each" %}
423             {% call for_each(pr.name,
424                              pr.value.source.populations,
425                              pr.value.target.populations,
426                              pr.value.plasticity,
427                              pr.value.rule.component,
428                              pr.value.rule.cstate) %}
429             {% endcall %}
430             {% else %}
431             {% if pr.value.rule.operator == "one-to-one" %}
432             {% call one_to_one(pr.name,
433                                pr.value.source.populations,
434                                pr.value.target.populations,
435                                pr.value.plasticity) %}
436             {% endcall %}
437             {% else %}
438             {% if pr.value.rule.operator == "all-to-all" %}
439             {% call all_to_all(pr.name,
440                                pr.value.source.populations,
441                                pr.value.target.populations,
442                                pr.value.plasticity) %}
443             {% endcall %}
444             {% else %}
445             {% if pr.value.rule.operator == "from-file" %}
446             {% call from_file(pr.name,
447                               pr.value.source.populations,
448                               pr.value.target.populations,
449                               pr.value.rule.properties.filename.exprML) %}
450             {% endcall %}
451             {% endif %}
452             {% endif %}
453             {% endif %}
454             {% endif %}
455             {% endfor %}
456               
457       
458{% if group.psrtypes %}{% for psr in dict (group.psrtypes) %}
459{% if psr.value.type == "event" %}
460             val S_{{psr.name}} = foldl1 SparseMatrix.insert
461                                  ([
462                                    {% for pr in psr.value.projections %}
463                                    Pr_{{pr}}{% if not loop.last %},{% endif %}
464                                    {% endfor %}
465                                   ])
466{% endif %}
467{% endfor %}{% else %}
468             val S = foldl1 SparseMatrix.insert
469                     ([
470                       {% for pr in dict (group.projections) %}
471                       {% if pr.value.type == "event" %}
472                       Pr_{{pr.name}}{% if not loop.last %},{% endif %}
473                       {% endif %}
474                       {% endfor %}
475                     ])
476{% endif %}
477
478             {% for pr in dict (group.projections) %}
479             {% if pr.value.type == "continuous" %}
480             {% call range_map(pr.name,
481                               pr.value.source.populations,
482                               pr.value.target.populations) %}
483             {% endcall %}
484             {% endif %}
485             {% endfor %}
486
487             val Elst = 
488                      [
489                        {% for pr in dict (group.projections) %}
490                        {% if pr.value.type == "continuous" %}
491                          ElecGraph.junctionMatrix ([N,N],ElecGraph.elecGraph ({{pr.name}}(srangemap_{{pr.name}},trangemap_{{pr.name}}))
492                          Pr_{{pr.name}}){% if not loop.last %},{% endif %}
493                        {% endif %}
494                        {% endfor %}
495                      ]
496
497             val E = if List.null Elst then NONE else SOME Elst
498
499             in
500              ([
501{% if group.psrtypes %}{% for psr in dict (group.psrtypes) %}
502               S_{{ psr.name }}{% if not loop.last %},{% endif %}
503{% endfor %}{% else %}
504               S
505{% endif %}
506              ])
507             end)
508
509
510end
511       
Note: See TracBrowser for help on using the repository browser.