source: project/release/4/9ML-toolkit/trunk/templates/Network.sml.tmpl @ 30932

Last change on this file since 30932 was 30932, checked in by Ivan Raikov, 6 years ago

9ML-toolkit: adding source and destination port fields to projection

File size: 20.3 KB
Line 
1
2structure {{group.name}} =
3struct
4
5  fun putStrLn out str =
6      (TextIO.output (out, str);
7       TextIO.output (out, "\n"))
8   
9  fun putStr out str =
10      (TextIO.output (out, str))
11     
12  fun showBoolean b = (if b then "1" else "0")
13                     
14  fun showReal n =
15      let open StringCvt
16          open Real
17      in
18          (if n < 0.0 then "-" else "") ^ (fmt (FIX (SOME 12)) (abs n))
19      end
20     
21  fun foldl1 f lst = let val v = List.hd lst
22                         val lst' = List.tl lst
23                     in
24                         List.foldl f v lst'
25                     end
26
27  fun fromDiag (m, n, a, dflt) =
28      if Index.validShape [m,n]
29      then
30          (let
31               val na  = RTensor.Array.length a
32               val na' = na-1
33               val te  = RTensor.new ([m,n], dflt)
34               fun diag (i, j, ia) =
35                   let
36                       val ia' =
37                           (RTensor.update (te, [i,j], RTensor.Array.sub (a, ia));
38                            if ia = na' then 0 else ia+1)
39                   in
40                       if (i=0) orelse (j=0)
41                       then te
42                       else diag (i-1, j-1, ia)
43                   end
44           in
45               diag (m-1, n-1, 0)
46           end)
47      else
48          raise RTensor.Shape
49
50  val RandomInit = RandomMTZig.fromEntropy
51
52  val ZigInit = RandomMTZig.ztnew
53       
54  exception Index       
55
56  val label = "{{group.name}}"
57           
58  val N = {{group.order}}     (* total population size *)
59
60
61  {% for p in dict (group.properties) %}
62  val {{p.name}} = {{p.value.exprML}}
63  {% endfor %}
64
65  {% with timestep = default(group.properties.timestep.exprML, 0.1) %}
66  val h = {{ timestep }}
67  {% endwith %}
68
69  (* network propagation delay *)
70  {% if group.properties.delay %}
71  val D: real = {{group.properties.delay.exprML}}
72  {% else %}
73  {% if group.properties.delayMatrix %}
74  val D: SparseMatrix.matrix = {{group.properties.delayMatrix.exprML}}
75  {% endif %}
76  {% endif %}
77
78  structure TEQ = TEventQueue (val h = h)
79  val DQ = TEQ.empty
80
81  val seed_init = RandomInit() (* seed for randomized initial values *)
82  val zt_init   = ZigInit()
83  fun random_normal () = RandomMTZig.randNormal(seed_init,zt_init)
84  fun random_uniform () = RandomMTZig.randUniform(seed_init)
85
86{% for pop in dict (group.populations) %}
87
88  val N_{{pop.name}} = {{ pop.value.size }}
89
90  val {{pop.name}}_initial = {{pop.value.prototype.initialExprML}}
91
92{% if pop.value.prototype.fieldExprML %}
93  val {{pop.name}}_field_vector =
94    Vector.tabulate (N_{{pop.name}}, fn (i) =>  {{pop.value.prototype.fieldExprML}})
95{% endif %}
96
97  val {{pop.name}}_initial_vector =
98    Vector.tabulate (N_{{pop.name}}, fn (i) =>  {{pop.value.prototype.initialStateExprML}})
99
100  val {{pop.name}}_f = Model_{{pop.name}}.{{pop.value.prototype.ivpFn}}
101
102  fun {{pop.name}}_run (Wnet,n0) (i,input as { {{ join (",", pop.value.prototype.states) }} }) =
103    let
104        val initial = {{pop.name}}_initial
105{% if pop.value.prototype.fieldExprML %}
106        val fieldV = Vector.sub ({{pop.name}}_field_vector,i)
107{% endif %}
108        val {{pop.inputAnalogPorts}}_i  =
109            case Wnet of
110                SOME W => foldl (fn (W,ax) => Real.+ (RTensor.sub(W,[i+n0,0]), ax)) 0.0 W
111              | NONE => 0.0
112        (*val _ = putStrLn TextIO.stdOut ("# {{pop.name}}: t = " ^ (showReal t) ^ " Isyn_i = " ^ (showReal Isyn_i) ^ " V = " ^ (showReal V))*)
113        val nstate = {{pop.name}}_f {{ pop.value.prototype.updateStateML }}
114        val nstate' = {{ pop.value.prototype.copyStateML }}
115    in
116        nstate'
117    end
118
119{% endfor %}
120
121
122{% if group.plastypes %}{% for pl in dict (group.plastypes) %}
123  val {{pl.name}}_initial = {{pl.value.initialExprML}}
124{% endfor %}{% endif %}
125
126
127{% if group.psrtypes %}{% for psr in dict (group.psrtypes) %}
128  val {{psr.name}}_initial = {{psr.value.initialExprML}}
129
130  val {{psr.name}}_initial_vector = Vector.tabulate ({{psr.value.range}},
131                                                     fn (i) => {{psr.value.initialStateExprML}})
132
133  val {{psr.name}}_f = Model_{{psr.name}}.{{psr.value.ivpFn}}
134
135  fun {{psr.name}}_response (W,T) (i,input as { {{ join (",", psr.value.states) }} }) =
136    let
137        val initial   = {{psr.name}}_initial
138        val Ispike_i  = RTensor.sub(W,[i,0])
139        val spike_i   = Real.!= (Ispike_i, 0.0)
140        val tspike_i  = RTensor.sub(T,[i,0])
141        (*val _ = putStrLn TextIO.stdOut ("# {{psr.name}}: t = " ^ (showReal t) ^ " Ispike_i = " ^ (showReal Ispike_i)*)
142        val nstate  = {{psr.name}}_f {{ psr.value.updateStateML }}
143        val nstate' = {{ psr.value.copyStateML }}
144    in
145        RTensor.update(W,[i,0],Real.+(RTensor.sub(W,[i,0]),(#Isyn nstate)));
146        nstate'
147    end
148{% endfor %}{% endif %}
149
150
151{% if group.conntypes %}{% for conn in dict (group.conntypes) %}
152  val {{conn.name}}_initial = {{conn.value.initialExprML}}
153
154  val {{conn.name}}_f = Model_{{conn.name}}.{{conn.value.sysFn}}
155{% endfor %}{% endif %}
156
157    val initial = (
158        {% for pop in dict (group.populations) %}
159        {{pop.name}}_initial_vector{% if not loop.last %},{% endif %}
160        {% endfor %}
161    )
162
163    val psr_initial = (
164        {% if group.psrtypes %}{% for psr in dict (group.psrtypes) %}
165        {{psr.name}}_initial_vector{% if not loop.last %},{% endif %}
166        {% endfor %}{% endif %}
167    )
168
169
170    {% for pop in dict (group.populations) %}
171    val {{pop.name}}_n0 = {{pop.value.start}}
172    {% endfor %}
173
174               
175    val Pn = [
176        {% for pop in dict (group.populations) %}
177        {{pop.name}}_n0{% if not loop.last %},{% endif %}
178        {% endfor %}
179    ]
180         
181    fun frun I
182             (
183              {% for pop in dict (group.populations) %}
184              {{pop.name}}_state_vector{% if not loop.last %},{% endif %}
185              {% endfor %} ) =
186        let
187            {% for pop in dict (group.populations) %}
188            val {{pop.name}}_state_vector' =
189                Vector.mapi ({{pop.name}}_run (I,{{pop.name}}_n0))
190                            {{pop.name}}_state_vector
191
192            {% endfor %}
193        in
194            (
195             {% for pop in dict (group.populations) %}
196             {{pop.name}}_state_vector'{% if not loop.last %},{% endif %}
197             {% endfor %}
198            )
199        end
200
201                         
202    fun fresponse I
203             (
204              {% if group.psrtypes %}{% for psr in dict (group.psrtypes) %}
205              {{psr.name}}_state_vector{% if not loop.last %},{% endif %}
206              {% endfor %}{% endif %}
207             ) =
208        let
209            {% if group.psrtypes %}
210            val (T,W) = case I of SOME I' => I'
211                                | NONE => (RTensor.new ([N,1],0.0),
212                                           List.tabulate ({{length (dict (group.psrtypes))}},
213                                                          fn (i) => (RTensor.new ([N,1],0.0))))
214            {% for psr in dict (group.psrtypes) %}
215            val W' = List.nth (W,{{loop.index0}})
216            val {{psr.name}}_state_vector' =
217                Vector.mapi ({{psr.name}}_response (W',T)) {{psr.name}}_state_vector
218            {% endfor %}{% else %}
219            val W = case I of SOME (T,W) => W
220                            | NONE => [RTensor.new ([N,1],0.0)]
221            {% endif %}
222        in
223            (SOME W,
224             (
225              {% if group.psrtypes %}{% for psr in dict (group.psrtypes) %}
226              {{psr.name}}_state_vector'{% if not loop.last %},{% endif %}
227              {% endfor %}{% endif %}
228             ))
229        end
230
231
232    fun felec E I
233              ({% for pop in dict (group.populations) %}
234              {{pop.name}}_state_vector{% if not loop.last %},{% endif %}
235              {% endfor %}) =
236        case E of NONE => I
237                | SOME E =>
238                  let
239                    val update = Unsafe.Real64Array.update
240                    val I' = case I of SOME I => I
241                                     | NONE => RTensor.new ([N,1],0.0)
242                   
243                  {% for pr in dict (group.projections) %}
244                  {% if pr.value.type == "continuous" %}
245                  {% for spop in pr.value.source %}
246                  {% for tpop in pr.value.target %}
247                    fun {{spop.name}}_sub i = #({{first (spop.value.prototype.states)}})(Vector.sub ({{spop.name}}_state_vector, i))
248                    fun {{tpop.name}}_sub i = #({{first (tpop.value.prototype.states)}})(Vector.sub ({{tpop.name}}_state_vector, i))
249                    val _ = Loop.app
250                                (0, N_{{spop.name}},
251                                 fn (i) =>
252                                    let
253                                        val Vi = {{spop.name}}_sub i
254                                        val sl = SparseMatrix.slice (#{{spop.name}}(E),1,i)
255                                    in
256                                        SparseMatrix.sliceAppi
257                                            (fn (j,g) => let val Vj = {{tpop.name}}_sub j
258                                                         in update (I,i,Real.- (sub(I,i), Real.* (g,Real.- (Vi,Vj)))) end)
259                                            sl
260                                    end)
261                               
262                  {% endfor %}
263                  {% endfor %}
264                  {% endif %}
265                  {% endfor %}
266                  in
267                      SOME I'
268                  end
269                         
270                         
271    fun ftime (
272              {% for pop in dict (group.populations) %}
273              {{pop.name}}_state_vector{% if not loop.last %},{% endif %}
274              {% endfor %} ) =
275       
276        {% with pop = first (dict (group.populations)) %}
277        let val { {{ join (",", pop.value.prototype.states) }} } = Vector.sub ({{pop.name}}_state_vector,0)
278        in {{pop.value.prototype.ivar}} end
279        {% endwith %}
280   
281       
282    fun fspikes (
283              {% for pop in dict (group.populations) %}
284              {{pop.name}}_state_vector{% if not loop.last %},{% endif %}
285              {% endfor %} ) =
286
287        let
288           {% for pop in dict (group.populations) %}
289            val {{pop.name}}_spike_i =
290                Vector.foldri (fn (i,v as { {{ join (",", pop.value.prototype.states) }} },ax) =>
291                               {% if not pop.name in group.spikepoplst %}
292                               (if (#{{first (pop.value.prototype.events)}}(v)) then ((i+{{pop.name}}_n0,#{{first (pop.value.prototype.events)}}Count(v)))::ax else ax))
293                               {% else %}
294                               (if (#{{first (pop.value.prototype.events)}}(v)) then ((i+{{pop.name}}_n0,1.0))::ax else ax))
295                               {% endif %}
296                              [] {{pop.name}}_state_vector
297           {% endfor %}
298
299           val ext_spike_i = List.concat ( {% for pop in dict (group.populations) %}{% if not pop.name in group.spikepoplst %}{{pop.name}}_spike_i ::{% endif %}{% endfor %} [] )
300
301               
302           val neuron_spike_i = List.concat [
303                                 {% for name in (group.spikepoplst) %}
304                                 {{name}}_spike_i{% if not loop.last %},{% endif %}
305                                 {% endfor %}
306                                 ]
307
308            val all_spike_i    = List.concat [neuron_spike_i, ext_spike_i]
309        in
310            (all_spike_i, neuron_spike_i)
311        end
312
313{% macro cartesian_product(sp, tp) %}
314   {% for s in sp %}
315   {% for t in tp %}{{ caller(s,t) }}{% if not loop.last %},{% endif %}{% endfor %}{% if not loop.last %},{% endif %}
316   {% endfor %}
317{% endmacro %}
318
319{% macro for_each(name, sp, tp, plasticity, component, cstate) %}
320             val Pr_{{name}} = let
321                                  val weight  = {% if plasticity %}#weight({{plasticity}}_initial){% else %}1.0{% endif %}
322                               in
323                                  SparseMatrix.fromGeneratorList [N,N]
324                                  [
325                                    {% call cartesian_product (sp,tp) %}
326                                      {offset=[{{t.start}},{{s.start}}],
327                                       fshape=[{{t.size}},{{s.size}}],
328                                       f=(fn (i) => Real.* (weight, #{{cstate}}({{component}}_f {{component}}_initial) ))}
329                                    {% endcall %}
330                                  ]
331                               end
332{% endmacro %}
333
334
335{% macro all_to_all(name, sp, tp, plasticity)  %}
336             val Pr_{{name}} = let
337                                  val weight = {% if plasticity %}#weight({{plasticity}}_initial){% else %}1.0{% endif %}
338                               in
339                                  SparseMatrix.fromTensorList [N,N]
340                                  [
341                                    {% call cartesian_product (sp,tp) %}
342                                    {offset=[{{t.start}},{{s.start}}],
343                                     tensor=(RTensor.*> weight (RTensor.new ([{{t.size}},{{s.size}}],1.0))),
344                                     sparse=false}
345                                    {% endcall %}
346                                  ]
347                                end
348{% endmacro %}
349
350{% macro one_to_one(name, sp, tp, plasticity) %}
351             val Pr_{{name}} = let
352                                  val weight = {% if plasticity %}#weight({{plasticity}}_initial){% else %}1.0{% endif %}
353                               in
354                                  SparseMatrix.fromTensorList [N,N]
355                                  [
356                                    {% call cartesian_product (sp,tp) %}
357                                    {offset=[{{t.start}},{{s.start}}],
358                                     tensor=(fromDiag ({{t.size}},{{s.size}},Real64Array.fromList [{{weight}}],0.0)),
359                                     sparse=true}
360                                    {% endcall %}
361                                  ]
362                               end
363{% endmacro %}
364
365
366{% macro from_file(name, sp, tp, filename) %}
367             val Pr_{{name}} = let val infile = TextIO.openIn "{{filename}}"
368                                   val S = TensorFile.realTensorRead (infile)
369                                   val _ = TextIO.closeIn infile
370                               in
371                                   SparseMatrix.fromTensorSliceList [N,N]
372                                   [
373                                     {% with %}
374                                     {% set soffset = 0 %}
375                                     {% for s in sp %}
376                                     {% set toffset = 0 %}
377                                     {% for t in tp %}
378                                     {offset=[{{t.start}},{{s.start}}],
379                                      slice=(RTensorSlice.slice ([([{{toffset}},{{soffset}}],[{{toffset}}+{{t.size}}-1,{{soffset}}+{{s.size}}-1])],S)),
380                                      sparse=false}{% if not loop.last %},{% endif %}
381                                     {% set toffset = toffset + t.size %}
382                                     {% endfor %}{% if not loop.last %},{% endif %}
383                                     {% set soffset = soffset + s.size %}
384                                     {% endfor %}
385                                     {% endwith %}
386                                  ]
387                                end
388{% endmacro %}
389
390
391{% macro range_map(name, sp, tp) %}
392             val srangemap_{{name}} =
393                                    [
394                                     {% with %}
395                                     {% set soffset = 0 %}
396                                     {% for s in sp %}
397                                     {size={{s.size}}
398                                      localStart={{soffset}},
399                                      globalStart={{s.start}} }{% if not loop.last %},{% endif %}
400                                     {% set soffset = soffset + s.size %}
401                                     {% endfor %}
402                                     {% endwith %}
403                                    ]
404             val trangemap_{{name}} =
405                                    [
406                                     {% with %}
407                                     {% set toffset = 0 %}
408                                     {% for t in tp %}
409                                     {size={{t.size}}
410                                      localStart={{toffset}},
411                                      globalStart={{t.start}} }{% if not loop.last %},{% endif %}
412                                     {% set toffset = toffset + t.size %}
413                                     {% endfor %}
414                                     {% endwith %}
415                                    ]
416{% endmacro %}
417           
418           
419    fun fprojection () =
420       
421        (let
422             
423             {% for pr in dict (group.projections) %}
424             val _ = putStrLn TextIO.stdOut "constructing {{pr.name}}"
425
426             {% if pr.value.rule.operator == "for-each" %}
427             {% call for_each(pr.name,
428                              pr.value.source.populations,
429                              pr.value.target.populations,
430                              pr.value.plasticity,
431                              pr.value.rule.component,
432                              pr.value.rule.cstate) %}
433             {% endcall %}
434             {% else %}
435             {% if pr.value.rule.operator == "one-to-one" %}
436             {% call one_to_one(pr.name,
437                                pr.value.source.populations,
438                                pr.value.target.populations,
439                                pr.value.plasticity) %}
440             {% endcall %}
441             {% else %}
442             {% if pr.value.rule.operator == "all-to-all" %}
443             {% call all_to_all(pr.name,
444                                pr.value.source.populations,
445                                pr.value.target.populations,
446                                pr.value.plasticity) %}
447             {% endcall %}
448             {% else %}
449             {% if pr.value.rule.operator == "from-file" %}
450             {% call from_file(pr.name,
451                               pr.value.source.populations,
452                               pr.value.target.populations,
453                               pr.value.rule.properties.filename.exprML) %}
454             {% endcall %}
455             {% endif %}
456             {% endif %}
457             {% endif %}
458             {% endif %}
459             {% endfor %}
460               
461       
462{% if group.psrtypes %}{% for psr in dict (group.psrtypes) %}
463{% if psr.value.type == "event" %}
464             val S_{{psr.name}} = foldl1 SparseMatrix.insert
465                                  ([
466                                    {% for pr in psr.value.projections %}
467                                    Pr_{{pr}}{% if not loop.last %},{% endif %}
468                                    {% endfor %}
469                                   ])
470{% endif %}
471{% endfor %}{% else %}
472             val S = foldl1 SparseMatrix.insert
473                     ([
474                       {% for pr in dict (group.projections) %}
475                       {% if pr.value.type == "event" %}
476                       Pr_{{pr.name}}{% if not loop.last %},{% endif %}
477                       {% endif %}
478                       {% endfor %}
479                     ])
480{% endif %}
481
482             {% for pr in dict (group.projections) %}
483             {% if pr.value.type == "continuous" %}
484             {% call range_map(pr.name,
485                               pr.value.source.populations,
486                               pr.value.target.populations) %}
487             {% endcall %}
488             {% endif %}
489             {% endfor %}
490
491             val Elst = 
492                      [
493                        {% for pr in dict (group.projections) %}
494                        {% if pr.value.type == "continuous" %}
495                          ElecGraph.junctionMatrix ([N,N],ElecGraph.elecGraph ({{pr.name}}(srangemap_{{pr.name}},trangemap_{{pr.name}}))
496                          Pr_{{pr.name}}){% if not loop.last %},{% endif %}
497                        {% endif %}
498                        {% endfor %}
499                      ]
500
501             val E = if List.null Elst then NONE else SOME Elst
502
503             in
504              ([
505{% if group.psrtypes %}{% for psr in dict (group.psrtypes) %}
506               S_{{ psr.name }}{% if not loop.last %},{% endif %}
507{% endfor %}{% else %}
508               S
509{% endif %}
510              ])
511             end)
512
513
514end
515       
Note: See TracBrowser for help on using the repository browser.