source: project/release/4/9ML-toolkit/trunk/templates/Network.sml.tmpl @ 30881

Last change on this file since 30881 was 30881, checked in by Ivan Raikov, 6 years ago

9ML-toolkit: a round of bug fixes to synapse handling and code generation for relations

File size: 19.9 KB
Line 
1
2structure {{group.name}} =
3struct
4
5  fun putStrLn out str =
6      (TextIO.output (out, str);
7       TextIO.output (out, "\n"))
8   
9  fun putStr out str =
10      (TextIO.output (out, str))
11     
12  fun showBoolean b = (if b then "1" else "0")
13                     
14  fun showReal n =
15      let open StringCvt
16          open Real
17      in
18          (if n < 0.0 then "-" else "") ^ (fmt (FIX (SOME 12)) (abs n))
19      end
20     
21  fun foldl1 f lst = let val v = List.hd lst
22                         val lst' = List.tl lst
23                     in
24                         List.foldl f v lst'
25                     end
26
27  fun fromDiag (m, n, a, dflt) =
28      if Index.validShape [m,n]
29      then
30          (let
31               val na  = RTensor.Array.length a
32               val na' = na-1
33               val te  = RTensor.new ([m,n], dflt)
34               fun diag (i, j, ia) =
35                   let
36                       val ia' =
37                           (RTensor.update (te, [i,j], RTensor.Array.sub (a, ia));
38                            if ia = na' then 0 else ia+1)
39                   in
40                       if (i=0) orelse (j=0)
41                       then te
42                       else diag (i-1, j-1, ia)
43                   end
44           in
45               diag (m-1, n-1, 0)
46           end)
47      else
48          raise RTensor.Shape
49
50  val RandomInit = RandomMTZig.fromEntropy
51
52  val ZigInit = RandomMTZig.ztnew
53       
54  exception Index       
55
56  val label = "{{group.name}}"
57           
58  val N = {{group.order}}     (* total population size *)
59
60
61  {% for p in dict (group.properties) %}
62  val {{p.name}} = {{p.value.exprML}}
63  {% endfor %}
64
65  {% with timestep = default(group.properties.timestep.exprML, 0.1) %}
66  val h = {{ timestep }}
67  {% endwith %}
68
69  (* network propagation delay *)
70  {% if group.properties.delay %}
71  val D: real = {{group.properties.delay.exprML}}
72  {% else %}
73  {% if group.properties.delayMatrix %}
74  val D: SparseMatrix.matrix = {{group.properties.delayMatrix.exprML}}
75  {% endif %}
76  {% endif %}
77
78  structure TEQ = TEventQueue (val h = h)
79  val DQ = TEQ.empty
80
81  val seed_init = RandomInit() (* seed for randomized initial values *)
82  val zt_init   = ZigInit()
83  fun random_normal () = RandomMTZig.randNormal(seed_init,zt_init)
84  fun random_uniform () = RandomMTZig.randUniform(seed_init)
85
86{% for pop in dict (group.populations) %}
87
88  val N_{{pop.name}} = {{ pop.value.size }}
89
90  val {{pop.name}}_initial = {{pop.value.prototype.initialExprML}}
91
92{% if pop.value.prototype.fieldExprML %}
93  val {{pop.name}}_field_vector =
94    Vector.tabulate (N_{{pop.name}}, fn (i) =>  {{pop.value.prototype.fieldExprML}})
95{% endif %}
96
97  val {{pop.name}}_initial_vector =
98    Vector.tabulate (N_{{pop.name}}, fn (i) =>  {{pop.value.prototype.initialStateExprML}})
99
100  val {{pop.name}}_f = Model_{{pop.name}}.{{pop.value.prototype.ivpFn}}
101
102  fun {{pop.name}}_run (Wnet,n0) (i,input as { {{ join (",", pop.value.prototype.states) }} }) =
103    let
104        val initial = {{pop.name}}_initial
105{% if pop.value.prototype.fieldExprML %}
106        val fieldV = Vector.sub ({{pop.name}}_field_vector,i)
107{% endif %}
108        val Isyn_i  = case Wnet of SOME W => RTensor.sub(W,[i+n0,0]) | NONE => 0.0
109        (*val _ = putStrLn TextIO.stdOut ("# {{pop.name}}: t = " ^ (showReal t) ^ " Isyn_i = " ^ (showReal Isyn_i) ^ " V = " ^ (showReal V))*)
110        val nstate = {{pop.name}}_f {{ pop.value.prototype.updateStateML }}
111        val nstate' = {{ pop.value.prototype.copyStateML }}
112    in
113        nstate'
114    end
115
116{% endfor %}
117
118
119{% if group.plastypes %}{% for pl in dict (group.plastypes) %}
120  val {{pl.name}}_initial = {{pl.value.initialExprML}}
121{% endfor %}{% endif %}
122
123
124{% if group.psrtypes %}{% for psr in dict (group.psrtypes) %}
125  val {{psr.name}}_initial = {{psr.value.initialExprML}}
126
127  val {{psr.name}}_initial_vector = Vector.tabulate ({{psr.value.range}},
128                                                     fn (i) => {{psr.value.initialStateExprML}})
129
130  val {{psr.name}}_f = Model_{{psr.name}}.{{psr.value.ivpFn}}
131
132  fun {{psr.name}}_response (W,T) (i,input as { {{ join (",", psr.value.states) }} }) =
133    let
134        val initial   = {{psr.name}}_initial
135        val Ispike_i  = RTensor.sub(W,[i,0])
136        val spike_i   = Real.!= (Ispike_i, 0.0)
137        val tspike_i  = RTensor.sub(T,[i,0])
138        (*val _ = putStrLn TextIO.stdOut ("# {{psr.name}}: t = " ^ (showReal t) ^ " Ispike_i = " ^ (showReal Ispike_i)*)
139        val nstate  = {{psr.name}}_f {{ psr.value.updateStateML }}
140        val nstate' = {{ psr.value.copyStateML }}
141    in
142        RTensor.update(W,[i,0],(#Isyn nstate'));
143        nstate'
144    end
145{% endfor %}{% endif %}
146
147
148{% if group.conntypes %}{% for conn in dict (group.conntypes) %}
149  val {{conn.name}}_initial = {{conn.value.initialExprML}}
150
151  val {{conn.name}}_f = Model_{{conn.name}}.{{conn.value.sysFn}}
152{% endfor %}{% endif %}
153
154    val initial = (
155        {% for pop in dict (group.populations) %}
156        {{pop.name}}_initial_vector{% if not loop.last %},{% endif %}
157        {% endfor %}
158    )
159
160    val psr_initial = (
161        {% if group.psrtypes %}{% for psr in dict (group.psrtypes) %}
162        {{psr.name}}_initial_vector{% if not loop.last %},{% endif %}
163        {% endfor %}{% endif %}
164    )
165
166
167    {% for pop in dict (group.populations) %}
168    val {{pop.name}}_n0 = {{pop.value.start}}
169    {% endfor %}
170
171               
172    val Pn = [
173        {% for pop in dict (group.populations) %}
174        {{pop.name}}_n0{% if not loop.last %},{% endif %}
175        {% endfor %}
176    ]
177         
178    fun frun I
179             (
180              {% for pop in dict (group.populations) %}
181              {{pop.name}}_state_vector{% if not loop.last %},{% endif %}
182              {% endfor %} ) =
183        let
184            {% for pop in dict (group.populations) %}
185            val {{pop.name}}_state_vector' =
186                Vector.mapi ({{pop.name}}_run (I,{{pop.name}}_n0))
187                            {{pop.name}}_state_vector
188
189            {% endfor %}
190        in
191            (
192             {% for pop in dict (group.populations) %}
193             {{pop.name}}_state_vector'{% if not loop.last %},{% endif %}
194             {% endfor %}
195            )
196        end
197
198                         
199    fun fresponse I
200             (
201              {% if group.psrtypes %}{% for psr in dict (group.psrtypes) %}
202              {{psr.name}}_state_vector{% if not loop.last %},{% endif %}
203              {% endfor %}{% endif %}
204             ) =
205        let
206            {% if group.psrtypes %}{% for psr in dict (group.psrtypes) %}
207            val (W,T) = case I of SOME (W,T) => (W,T)
208                             | NONE => (RTensor.new ([N,1],0.0),
209                                        RTensor.new ([N,1],0.0))
210
211            val {{psr.name}}_state_vector' =
212                Vector.mapi ({{psr.name}}_response (W,T))
213                            {{psr.name}}_state_vector
214
215            {% endfor %}{% endif %}
216        in
217            ({% if group.psrtypes %}SOME W{% else %}W{% endif %},
218             (
219              {% if group.psrtypes %}{% for psr in dict (group.psrtypes) %}
220              {{psr.name}}_state_vector'{% if not loop.last %},{% endif %}
221              {% endfor %}{% endif %}
222             ))
223        end
224
225
226    fun felec E I
227              ({% for pop in dict (group.populations) %}
228              {{pop.name}}_state_vector{% if not loop.last %},{% endif %}
229              {% endfor %}) =
230        case E of NONE => I
231                | SOME E =>
232                  let
233                    val update = Unsafe.Real64Array.update
234                    val I' = case I of SOME I => I
235                                     | NONE => RTensor.new ([N,1],0.0)
236                   
237                  {% for pr in dict (group.projections) %}
238                  {% if pr.value.type == "continuous" %}
239                  {% for spop in pr.value.source %}
240                  {% for tpop in pr.value.target %}
241                    fun {{spop.name}}_sub i = #({{first (spop.value.prototype.states)}})(Vector.sub ({{spop.name}}_state_vector, i))
242                    fun {{tpop.name}}_sub i = #({{first (tpop.value.prototype.states)}})(Vector.sub ({{tpop.name}}_state_vector, i))
243                    val _ = Loop.app
244                                (0, N_{{spop.name}},
245                                 fn (i) =>
246                                    let
247                                        val Vi = {{spop.name}}_sub i
248                                        val sl = SparseMatrix.slice (#{{spop.name}}(E),1,i)
249                                    in
250                                        SparseMatrix.sliceAppi
251                                            (fn (j,g) => let val Vj = {{tpop.name}}_sub j
252                                                         in update (I,i,Real.- (sub(I,i), Real.* (g,Real.- (Vi,Vj)))) end)
253                                            sl
254                                    end)
255                               
256                  {% endfor %}
257                  {% endfor %}
258                  {% endif %}
259                  {% endfor %}
260                  in
261                      SOME I'
262                  end
263                         
264                         
265    fun ftime (
266              {% for pop in dict (group.populations) %}
267              {{pop.name}}_state_vector{% if not loop.last %},{% endif %}
268              {% endfor %} ) =
269       
270        {% with pop = first (dict (group.populations)) %}
271        let val { {{ join (",", pop.value.prototype.states) }} } = Vector.sub ({{pop.name}}_state_vector,0)
272        in {{pop.value.prototype.ivar}} end
273        {% endwith %}
274   
275       
276    fun fspikes (
277              {% for pop in dict (group.populations) %}
278              {{pop.name}}_state_vector{% if not loop.last %},{% endif %}
279              {% endfor %} ) =
280
281        let
282           {% for pop in dict (group.populations) %}
283            val {{pop.name}}_spike_i =
284                Vector.foldri (fn (i,v as { {{ join (",", pop.value.prototype.states) }} },ax) =>
285                               {% if not pop.name in group.spikepoplst %}
286                               (if (#{{first (pop.value.prototype.events)}}(v)) then ((i+{{pop.name}}_n0,#{{first (pop.value.prototype.events)}}Count(v)))::ax else ax))
287                               {% else %}
288                               (if (#{{first (pop.value.prototype.events)}}(v)) then ((i+{{pop.name}}_n0,1.0))::ax else ax))
289                               {% endif %}
290                              [] {{pop.name}}_state_vector
291           {% endfor %}
292
293           val ext_spike_i = List.concat ( {% for pop in dict (group.populations) %}{% if not pop.name in group.spikepoplst %}{{pop.name}}_spike_i ::{% endif %}{% endfor %} [] )
294
295               
296           val neuron_spike_i = List.concat [
297                                 {% for name in (group.spikepoplst) %}
298                                 {{name}}_spike_i{% if not loop.last %},{% endif %}
299                                 {% endfor %}
300                                 ]
301
302            val all_spike_i    = List.concat [neuron_spike_i, ext_spike_i]
303        in
304            (all_spike_i, neuron_spike_i)
305        end
306
307{% macro cartesian_product(sp, tp) %}
308   {% for s in sp %}
309   {% for t in tp %}{{ caller(s,t) }}{% if not loop.last %},{% endif %}{% endfor %}{% if not loop.last %},{% endif %}
310   {% endfor %}
311{% endmacro %}
312
313{% macro for_each(name, sp, tp, plasticity, component, cstate) %}
314             val Pr_{{name}} = let
315                                  val weight  = {% if plasticity %}#weight({{plasticity}}_initial){% else %}1.0{% endif %}
316                               in
317                                  SparseMatrix.fromGeneratorList [N,N]
318                                  [
319                                    {% call cartesian_product (sp,tp) %}
320                                      {offset=[{{t.start}},{{s.start}}],
321                                       fshape=[{{t.size}},{{s.size}}],
322                                       f=(fn (i) => Real.* (weight, #{{cstate}}({{component}}_f {{component}}_initial) ))}
323                                    {% endcall %}
324                                  ]
325                               end
326{% endmacro %}
327
328
329{% macro all_to_all(name, sp, tp, plasticity)  %}
330             val Pr_{{name}} = let
331                                  val weight = {% if plasticity %}#weight({{plasticity}}_initial){% else %}1.0{% endif %}
332                               in
333                                  SparseMatrix.fromTensorList [N,N]
334                                  [
335                                    {% call cartesian_product (sp,tp) %}
336                                    {offset=[{{t.start}},{{s.start}}],
337                                     tensor=(RTensor.*> weight (RTensor.new ([{{t.size}},{{s.size}}],1.0))),
338                                     sparse=false}
339                                    {% endcall %}
340                                  ]
341                                end
342{% endmacro %}
343
344{% macro one_to_one(name, sp, tp, plasticity) %}
345             val Pr_{{name}} = let
346                                  val weight = {% if plasticity %}#weight({{plasticity}}_initial){% else %}1.0{% endif %}
347                               in
348                                  SparseMatrix.fromTensorList [N,N]
349                                  [
350                                    {% call cartesian_product (sp,tp) %}
351                                    {offset=[{{t.start}},{{s.start}}],
352                                     tensor=(fromDiag ({{t.size}},{{s.size}},Real64Array.fromList [{{weight}}],0.0)),
353                                     sparse=true}
354                                    {% endcall %}
355                                  ]
356                               end
357{% endmacro %}
358
359
360{% macro from_file(name, sp, tp, filename) %}
361             val Pr_{{name}} = let val infile = TextIO.openIn "{{filename}}"
362                                   val S = TensorFile.realTensorRead (infile)
363                                   val _ = TextIO.closeIn infile
364                               in
365                                   SparseMatrix.fromTensorSliceList [N,N]
366                                   [
367                                     {% with %}
368                                     {% set soffset = 0 %}
369                                     {% for s in sp %}
370                                     {% set toffset = 0 %}
371                                     {% for t in tp %}
372                                     {offset=[{{t.start}},{{s.start}}],
373                                      slice=(RTensorSlice.slice ([([{{toffset}},{{soffset}}],[{{toffset}}+{{t.size}}-1,{{soffset}}+{{s.size}}-1])],S)),
374                                      sparse=false}{% if not loop.last %},{% endif %}
375                                     {% set toffset = toffset + t.size %}
376                                     {% endfor %}{% if not loop.last %},{% endif %}
377                                     {% set soffset = soffset + s.size %}
378                                     {% endfor %}
379                                     {% endwith %}
380                                  ]
381                                end
382{% endmacro %}
383
384
385{% macro range_map(name, sp, tp) %}
386             val srangemap_{{name}} =
387                                    [
388                                     {% with %}
389                                     {% set soffset = 0 %}
390                                     {% for s in sp %}
391                                     {size={{s.size}}
392                                      localStart={{soffset}},
393                                      globalStart={{s.start}} }{% if not loop.last %},{% endif %}
394                                     {% set soffset = soffset + s.size %}
395                                     {% endfor %}
396                                     {% endwith %}
397                                    ]
398             val trangemap_{{name}} =
399                                    [
400                                     {% with %}
401                                     {% set toffset = 0 %}
402                                     {% for t in tp %}
403                                     {size={{t.size}}
404                                      localStart={{toffset}},
405                                      globalStart={{t.start}} }{% if not loop.last %},{% endif %}
406                                     {% set toffset = toffset + t.size %}
407                                     {% endfor %}
408                                     {% endwith %}
409                                    ]
410{% endmacro %}
411           
412           
413    fun fprojection () =
414       
415        (let
416             
417             {% for pr in dict (group.projections) %}
418             val _ = putStrLn TextIO.stdOut "constructing {{pr.name}}"
419
420             {% if pr.value.rule.operator == "for-each" %}
421             {% call for_each(pr.name,
422                              pr.value.source.populations,
423                              pr.value.target.populations,
424                              pr.value.plasticity,
425                              pr.value.rule.component,
426                              pr.value.rule.cstate) %}
427             {% endcall %}
428             {% else %}
429             {% if pr.value.rule.operator == "one-to-one" %}
430             {% call one_to_one(pr.name,
431                                pr.value.source.populations,
432                                pr.value.target.populations,
433                                pr.value.plasticity) %}
434             {% endcall %}
435             {% else %}
436             {% if pr.value.rule.operator == "all-to-all" %}
437             {% call all_to_all(pr.name,
438                                pr.value.source.populations,
439                                pr.value.target.populations,
440                                pr.value.plasticity) %}
441             {% endcall %}
442             {% else %}
443             {% if pr.value.rule.operator == "from-file" %}
444             {% call from_file(pr.name,
445                               pr.value.source.populations,
446                               pr.value.target.populations,
447                               pr.value.rule.properties.filename.exprML) %}
448             {% endcall %}
449             {% endif %}
450             {% endif %}
451             {% endif %}
452             {% endif %}
453             {% endfor %}
454               
455       
456{% if group.psrtypes %}{% for psr in dict (group.psrtypes) %}
457{% if psr.value.type == "event" %}
458             val S_{{psr.name}} = foldl1 SparseMatrix.insert
459                                  ([
460                                    {% for pr in psr.value.projections %}
461                                    Pr_{{pr}}{% if not loop.last %},{% endif %}
462                                    {% endfor %}
463                                   ])
464{% endif %}
465{% endfor %}{% else %}
466             val S = foldl1 SparseMatrix.insert
467                     ([
468                       {% for pr in dict (group.projections) %}
469                       {% if pr.value.type == "event" %}
470                       Pr_{{pr.name}}{% if not loop.last %},{% endif %}
471                       {% endif %}
472                       {% endfor %}
473                     ])
474{% endif %}
475
476             {% for pr in dict (group.projections) %}
477             {% if pr.value.type == "continuous" %}
478             {% call range_map(pr.name,
479                               pr.value.source.populations,
480                               pr.value.target.populations) %}
481             {% endcall %}
482             {% endif %}
483             {% endfor %}
484
485             val Elst = 
486                      [
487                        {% for pr in dict (group.projections) %}
488                        {% if pr.value.type == "continuous" %}
489                          ElecGraph.junctionMatrix ([N,N],ElecGraph.elecGraph ({{pr.name}}(srangemap_{{pr.name}},trangemap_{{pr.name}}))
490                          Pr_{{pr.name}}){% if not loop.last %},{% endif %}
491                        {% endif %}
492                        {% endfor %}
493                      ]
494
495             val E = if List.null Elst then NONE else SOME Elst
496
497             in
498              ([
499{% if group.psrtypes %}{% for psr in dict (group.psrtypes) %}
500               S_{{ psr.name }}{% if not loop.last %},{% endif %}
501{% endfor %}{% else %}
502               S
503{% endif %}
504              ])
505             end)
506
507
508end
509       
Note: See TracBrowser for help on using the repository browser.