source: project/release/4/9ML-toolkit/trunk/templates/Network.sml.elec.tmpl @ 29994

Last change on this file since 29994 was 29994, checked in by Ivan Raikov, 8 years ago

9ML-toolkit: additions to electrical synapses infrastructure

File size: 19.3 KB
Line 
1
2signature NODE_GRAPH =
3sig
4    type symbol = Symbol.symbol
5    type nindex = int
6    type rangemap = (int * int) list
7
8    val nodeGraph : (symbol * SparseMatrix.matrix * rangemap) ->
9                    (nindex, real, unit) Graph.graph
10                   
11    val junctionMatrix : ((nindex, real, unit) Graph.graph) -> SparseMatrix.matrix
12
13
14end
15
16structure NodeGraph: NODE_GRAPH =
17struct
18
19type symbol   = Symbol.symbol
20type index    = int
21type rangemap = {localStart: int, localEnd: int, globalStart: int} list
22
23exception Index
24
25fun nodeGraph (S, rangemap) =
26    let
27        val mapIndex m i = case List.find (fn (({localStart,localEnd,globalStart} =>
28                                                i >= localStart andalso i < localEnd)) m of
29                               SOME ({localStart,localEnd,globalStart}) => globalStart + (i-localStart)
30                             | NONE => raise Index
31       
32        val [N,_] = SparseMatrix.shape S
33       
34        val G as Graph.GRAPH g =
35            DirectedGraph.graph(Symbol.name modelname,(),N) :
36            (index,real,unit) Graph.graph
37
38        val idxs  = List.tabulate (N, fn (i) => i)
39        val gidxs = List.tabulate (N, fn (i) => mapIndex rangemap i)
40
41        val _ = ListPair.app (#add_node g) (gidxs,idxs)
42
43        val add_edge = #add_edge g
44
45        val _ List.app
46              (fn (s) =>
47                  let
48                      val sl = SparseMatrix.slice (S,1,i)
49                  in
50                      SparseMatrix.sliceAppi
51                          (fn (t,v) => add_edge (mapIndex rangemap s, mapIndex rangemap t,v))
52                          sl
53                  end)
54              idxs
55
56    in
57        G
58    end
59
60fun junctionMatrix (Graph.GRAPH g) =
61    let
62        fun nodeCoeffs n =
63            let
64                val out  = (#out_edges g) n
65                val self = Real.- (~1.0, foldl (fn ((s,t,v),ax) => Real.+(v,ax)) 0.0 out)
66            in
67                (n, (n,self) :: (map (fn (s,t,v) => (t,v)) out))
68            end
69
70        val lst = ref []
71           
72    in
73        ((#forall_nodes g) (fn (n,_) => (lst := ((nodeCoeffs n) :: !lst)));
74         SparseMatrix.fromLists (!lst))
75    end
76
77end
78
79structure {{group.name}} =
80struct
81
82  fun putStrLn out str =
83      (TextIO.output (out, str);
84       TextIO.output (out, "\n"))
85   
86  fun putStr out str =
87      (TextIO.output (out, str))
88     
89  fun showBoolean b = (if b then "1" else "0")
90                     
91  fun showReal n =
92      let open StringCvt
93          open Real
94      in
95          (if n < 0.0 then "-" else "") ^ (fmt (FIX (SOME 12)) (abs n))
96      end
97     
98  fun foldl1 f lst = let val v = List.hd lst
99                         val lst' = List.tl lst
100                     in
101                         List.foldl f v lst'
102                     end
103
104  fun fromDiag (m, n, a, dflt) =
105      if Index.validShape [m,n]
106      then
107          (let
108               val na  = RTensor.Array.length a
109               val na' = na-1
110               val te  = RTensor.new ([m,n], dflt)
111               fun diag (i, j, ia) =
112                   let
113                       val ia' =
114                           (RTensor.update (te, [i,j], RTensor.Array.sub (a, ia));
115                            if ia = na' then 0 else ia+1)
116                   in
117                       if (i=0) orelse (j=0)
118                       then te
119                       else diag (i-1, j-1, ia)
120                   end
121           in
122               diag (m-1, n-1, 0)
123           end)
124      else
125          raise RTensor.Shape
126
127  val RandomInit = RandomMTZig.fromEntropy
128
129  val ZigInit = RandomMTZig.ztnew
130       
131  exception Index       
132
133  val label = "{{group.name}}"
134           
135  val N = {{group.order}}     (* total population size *)
136
137
138  {% for p in dict (group.properties) %}
139  val {{p.name}} = {{p.value.exprML}}
140  {% endfor %}
141
142  {% with timestep = default(group.properties.timestep.exprML, 0.1) %}
143  val h = {{ timestep }}
144  {% endwith %}
145
146  (* delay expressed as # time steps *)
147  val D: (RTensor.tensor option) list  = List.tabulate (Real.round (Real.max (Real./({{group.properties.delay.exprML}},h),1.0)), fn i => NONE)
148
149  val seed_init = RandomInit() (* seed for initial membrane potential *)
150  val zt_init   = ZigInit()
151  fun randomNormal () = RandomMTZig.randNormal(seed_init,zt_init)
152  fun randomUniform () = RandomMTZig.randUniform(seed_init)
153
154{% for pop in dict (group.populations) %}
155
156  val N_{{pop.name}} = {{ pop.value.size }}
157
158  val {{pop.name}}_initial = {{pop.value.prototype.initialExprML}}
159
160{% if pop.value.prototype.fieldExprML %}
161  val {{pop.name}}_field_vector =
162    Vector.tabulate (N_{{pop.name}}, fn (i) =>  {{pop.value.prototype.fieldExprML}})
163{% endif %}
164
165  val {{pop.name}}_initial_vector =
166    Vector.tabulate (N_{{pop.name}}, fn (i) =>  {{pop.value.prototype.initialStateExprML}})
167
168  val {{pop.name}}_f = Model_{{pop.name}}.{{pop.value.prototype.ivpFn}}
169
170  fun {{pop.name}}_run (Wnet,n0) (i,input as { {{ join (",", pop.value.prototype.states) }} }) =
171    let
172        val initial = {{pop.name}}_initial
173{% if pop.value.prototype.fieldExprML %}
174        val fieldV = Vector.sub ({{pop.name}}_field_vector,i)
175{% endif %}
176        val Isyn_i  = case Wnet of SOME W => RTensor.sub(W,[i+n0,0]) | NONE => 0.0
177        (*val _ = putStrLn TextIO.stdOut ("# {{pop.name}}: t = " ^ (showReal t) ^ " Isyn_i = " ^ (showReal Isyn_i) ^ " V = " ^ (showReal V))*)
178        val nstate = {{pop.name}}_f {{ pop.value.prototype.copyStateIsynML }}
179        val nstate' = {{ pop.value.prototype.copyStateNstateML }}
180    in
181        nstate'
182    end
183
184{% endfor %}
185
186{% if group.psrtypes %}{% for psr in dict (group.psrtypes) %}
187  val {{psr.name}}_initial = {{psr.value.initialExprML}}
188
189  val {{psr.name}}_initial_vector = Vector.tabulate ({{psr.value.range}}, fn (i) => {{psr.value.initialStateExprML}})
190
191  val {{psr.name}}_f = Model_{{psr.name}}.{{psr.value.ivpFn}}
192
193  fun {{psr.name}}_response W (i,input as { {{ join (",", psr.value.states) }} }) =
194    let
195        val initial = {{psr.name}}_initial
196        val Ispike_i  = RTensor.sub(W,[i,0])
197        (*val _ = putStrLn TextIO.stdOut ("# {{psr.name}}: t = " ^ (showReal t) ^ " Ispike_i = " ^ (showReal Ispike_i)*)
198        val nstate = {{psr.name}}_f {{ psr.value.copyStateIspikeML }}
199        val nstate' = {{ psr.value.copyStateNstateML }}
200    in
201        RTensor.update(W,[i,0],(#Isyn nstate'));
202        nstate'
203    end
204{% endfor %}{% endif %}
205
206    val initial = (
207        {% for pop in dict (group.populations) %}
208        {{pop.name}}_initial_vector{% if not loop.last %},{% endif %}
209        {% endfor %}
210    )
211
212    val psr_initial = (
213        {% if group.psrtypes %}{% for psr in dict (group.psrtypes) %}
214        {{psr.name}}_initial_vector{% if not loop.last %},{% endif %}
215        {% endfor %}{% endif %}
216    )
217
218
219    {% for pop in dict (group.populations) %}
220    val {{pop.name}}_n0 = {{pop.value.start}}
221    {% endfor %}
222
223               
224    val Pn = [
225        {% for pop in dict (group.populations) %}
226        {{pop.name}}_n0{% if not loop.last %},{% endif %}
227        {% endfor %}
228    ]
229         
230    fun frun I
231             (
232              {% for pop in dict (group.populations) %}
233              {{pop.name}}_state_vector{% if not loop.last %},{% endif %}
234              {% endfor %} ) =
235        let
236            {% for pop in dict (group.populations) %}
237            val {{pop.name}}_state_vector' =
238                Vector.mapi ({{pop.name}}_run (I,{{pop.name}}_n0))
239                            {{pop.name}}_state_vector
240
241            {% endfor %}
242        in
243            (
244             {% for pop in dict (group.populations) %}
245             {{pop.name}}_state_vector'{% if not loop.last %},{% endif %}
246             {% endfor %}
247            )
248        end
249
250                         
251    fun fresponse I
252             (
253              {% if group.psrtypes %}{% for psr in dict (group.psrtypes) %}
254              {{psr.name}}_state_vector{% if not loop.last %},{% endif %}
255              {% endfor %}{% endif %}
256             ) =
257        let
258            {% if group.psrtypes %}{% for psr in dict (group.psrtypes) %}
259            val I' = case I of SOME I => I
260                             | NONE => RTensor.new ([N,1],0.0)
261
262            val {{psr.name}}_state_vector' =
263                Vector.mapi ({{psr.name}}_response I')
264                            {{psr.name}}_state_vector
265
266            {% endfor %}{% endif %}
267        in
268            ({% if group.psrtypes %}SOME I'{% else %}I{% endif %},
269             (
270              {% if group.psrtypes %}{% for psr in dict (group.psrtypes) %}
271              {{psr.name}}_state_vector'{% if not loop.last %},{% endif %}
272              {% endfor %}{% endif %}
273             ))
274        end
275                         
276
277    fun felec E I
278              ({% for pop in dict (group.populations) %}
279              {{pop.name}}_state_vector{% if not loop.last %},{% endif %}
280              {% endfor %} ) =
281        case E of NONE => I
282                | SOME E =>
283                  let
284                    val update = Unsafe.Real64Array.update
285                    val I' = case I of SOME I => I
286                                     | NONE => RTensor.new ([N,1],0.0)
287                  {% for pr in dict (group.projections) %}
288                  {% if pr.value.type == "cvar" %}
289                  {% for spop in pr.value.source %}
290                  {% for tpop in pr.value.target %}
291                    fun {{spop.name}}_sub i = #({{first (spop.value.prototype.states)}})(Vector.sub ({{spop.name}}_state_vector, i))
292                    fun {{tpop.name}}_sub i = #({{first (tpop.value.prototype.states)}})(Vector.sub ({{tpop.name}}_state_vector, i))
293                    val _ = Loop.app
294                                (0, N_{{spop.name}},
295                                 fn (i) =>
296                                    let
297                                        val Vi = {{spop.name}}_sub i
298                                        val sl = SparseMatrix.slice (#{{spop.name}}(E),1,i)
299                                    in
300                                        SparseMatrix.sliceAppi
301                                            (fn (j,g) => let val Vj = {{tpop.name}}_sub j
302                                                         in update (I,i,Real.- (sub(I,i), Real.* (g,Real.- (Vi,Vj)))) end)
303                                            sl
304                                    end)
305                               
306                  {% endfor %}
307                  {% endfor %}
308                  {% endif %}
309                  {% endfor %}
310                  in
311                  end
312                         
313    fun ftime (
314              {% for pop in dict (group.populations) %}
315              {{pop.name}}_state_vector{% if not loop.last %},{% endif %}
316              {% endfor %} ) =
317       
318        {% with pop = first (dict (group.populations)) %}
319        let val { {{ join (",", pop.value.prototype.states) }} } = Vector.sub ({{pop.name}}_state_vector,0)
320        in {{pop.value.prototype.ivar}} end
321        {% endwith %}
322   
323       
324    fun fspikes (
325              {% for pop in dict (group.populations) %}
326              {{pop.name}}_state_vector{% if not loop.last %},{% endif %}
327              {% endfor %} ) =
328
329        let
330           {% for pop in dict (group.populations) %}
331            val {{pop.name}}_spike_i =
332                Vector.foldri (fn (i,v as { {{ join (",", pop.value.prototype.states) }} },ax) =>
333                               {% if not pop.name in group.spikepoplst %}
334                               (if (#{{first (pop.value.prototype.events)}}(v)) then ((i+{{pop.name}}_n0,#{{first (pop.value.prototype.events)}}Count(v)))::ax else ax))
335                               {% else %}
336                               (if (#{{first (pop.value.prototype.events)}}(v)) then ((i+{{pop.name}}_n0,1.0))::ax else ax))
337                               {% endif %}
338                              [] {{pop.name}}_state_vector
339           {% endfor %}
340
341           val ext_spike_i = List.concat ( {% for pop in dict (group.populations) %}{% if not pop.name in group.spikepoplst %}{{pop.name}}_spike_i ::{% endif %}{% endfor %} [] )
342
343               
344           val neuron_spike_i = List.concat [
345                                 {% for name in (group.spikepoplst) %}
346                                 {{name}}_spike_i{% if not loop.last %},{% endif %}
347                                 {% endfor %}
348                                 ]
349
350            val all_spike_i    = List.concat [neuron_spike_i, ext_spike_i]
351        in
352            (all_spike_i, neuron_spike_i)
353        end
354
355{% macro random_divergent(name, sp, tp, epsilon, weight) %}
356             val Pr_{{name}}_seed = RandomInit()
357             val Pr_{{name}} = SparseMatrix.fromGeneratorList [N,N]
358                                  [
359                                    {% for s in sp %}
360                                    {% for t in tp %}
361                                      {offset=[{{t.start}},{{s.start}}],
362                                       fshape=[{{t.size}},{{s.size}}],
363                                       f=(fn (i) => if Real.> ({{epsilon}},
364                                                                RandomMTZig.randUniform Pr_{{name}}_seed)
365                                                    then {{weight}} else 0.0)}{% if not loop.last %},{% endif %}
366                                    {% endfor %}{% if not loop.last %},{% endif %}
367                                    {% endfor %}
368                                 ]
369{% endmacro %}
370
371
372{% macro all_to_all(name, sp, tp, weight) %}
373             val Pr_{{name}} = SparseMatrix.fromTensorList [N,N]
374                                  [
375                                    {% for s in sp %}
376                                    {% for t in tp %}
377                                    {offset=[{{t.start}},{{s.start}}],
378                                     tensor=(RTensor.*> {{weight}}
379                                                  (RTensor.new ([{{t.size}},{{s.size}}],1.0))),
380                                     sparse=false}{% if not loop.last %},{% endif %}
381                                    {% endfor %}{% if not loop.last %},{% endif %}
382                                    {% endfor %}
383                                 ]
384{% endmacro %}
385
386{% macro one_to_one(name, sp, tp, weight) %}
387             val Pr_{{name}} = SparseMatrix.fromTensorList [N,N]
388                                  [
389                                    {% for s in sp %}
390                                    {% for t in tp %}
391                                    {offset=[{{t.start}},{{s.start}}],
392                                     tensor=(fromDiag ({{t.size}},{{s.size}},Real64Array.fromList [{{weight}}],0.0)),
393                                     sparse=true}{% if not loop.last %},{% endif %}
394                                    {% endfor %}{% if not loop.last %},{% endif %}
395                                    {% endfor %}
396                                 ]
397{% endmacro %}
398
399
400{% macro from_file(name, sp, tp, filename) %}
401             val Pr_{{name}} = let val infile = TextIO.openIn "{{filename}}"
402                                   val S = TensorFile.realTensorRead (infile)
403                                   val _ = TextIO.closeIn infile
404                               in
405                                   SparseMatrix.fromTensorSliceList [N,N]
406                                   [
407                                     {% with %}
408                                     {% set soffset = 0 %}
409                                     {% for s in sp %}
410                                     {% set toffset = 0 %}
411                                     {% for t in tp %}
412                                     {offset=[{{t.start}},{{s.start}}],
413                                      slice=(RTensorSlice.slice ([([{{toffset}},{{soffset}}],[{{toffset}}+{{t.size}}-1,{{soffset}}+{{s.size}}-1])],S)),
414                                      sparse=false}{% if not loop.last %},{% endif %}
415                                     {% set toffset = toffset + t.size %}
416                                     {% endfor %}{% if not loop.last %},{% endif %}
417                                     {% set soffset = soffset + s.size %}
418                                     {% endfor %}
419                                     {% endwith %}
420                                  ]
421                                end
422{% endmacro %}
423
424           
425           
426    fun fprojection () =
427       
428        (let
429             
430             {% for pr in dict (group.projections) %}
431             val _ = putStrLn TextIO.stdOut "constructing {{pr.name}}"
432             {% if pr.value.rule == "random divergent" %}
433             {% call random_divergent(pr.name,
434                                      pr.value.source.populations,
435                                      pr.value.target.populations,
436                                      pr.value.properties.epsilon.exprML,
437                                      pr.value.properties.weight.exprML) %}
438             {% endcall %}
439             {% else %}
440             {% if pr.value.rule == "one-to-one" %}
441             {% call one_to_one(pr.name,
442                                pr.value.source.populations,
443                                pr.value.target.populations,
444                                pr.value.properties.weight.exprML) %}
445             {% endcall %}
446             {% else %}
447             {% if pr.value.rule == "all-to-all" %}
448             {% call all_to_all(pr.name,
449                                pr.value.source.populations,
450                                pr.value.target.populations,
451                                pr.value.properties.weight.exprML) %}
452             {% endcall %}
453             {% else %}
454             {% if pr.value.rule == "from file" %}
455             {% call from_file(pr.name,
456                               pr.value.source.populations,
457                               pr.value.target.populations,
458                               pr.value.properties.filename.exprML) %}
459             {% endcall %}
460             {% endif %}
461             {% endif %}
462             {% endif %}
463             {% endif %}
464             {% endfor %}
465               
466       
467{% if group.psrtypes %}{% for psr in dict (group.psrtypes) %}
468             val S_{{psr.name}} = foldl1 SparseMatrix.insert
469                                  ([
470                                    {% for pr in psr.value.projections %}
471                                    {% if pr.value.type == "event" %}
472                                    Pr_{{pr}}{% if not loop.last %},{% endif %}
473                                    {% endif %}
474                                    {% endfor %}
475                                   ])
476{% endfor %}{% else %}
477             val S = foldl1 SparseMatrix.insert
478                     ([
479                           {% for pr in dict (group.projections) %}
480                               {% if pr.value.type == "event" %}
481                               Pr_{{pr.name}}{% if not loop.last %},{% endif %}
482                               {% endif %}
483                           {% endfor %}
484                     ])
485{% endif %}
486
487             val Elst =
488                      [
489                           {% for pr in dict (group.projections) %}
490                               {% if pr.value.type == "cvar" %}
491                               Pr_{{pr.name}}{% if not loop.last %},{% endif %}
492                               {% endif %}
493                           {% endfor %}
494                     ]
495
496             val E = if List.null Elst then NONE else SOME Elst
497
498             in
499              ([
500{% if group.psrtypes %}{% for psr in dict (group.psrtypes) %}
501               S_{{ psr.name }}{% if not loop.last %},{% endif %}
502{% endfor %}{% else %}
503               S
504{% endif %}
505              ], E)
506             end)
507
508
509end
510       
Note: See TracBrowser for help on using the repository browser.