source: project/release/4/nemo/trunk/templates/NEST-nodes.tmpl @ 31508

Last change on this file since 31508 was 31508, checked in by Ivan Raikov, 6 years ago

nemo: continuation of nest backend refactoring

File size: 10.5 KB
Line 
1
2{% if (ODEmethod == "cvode") %}
3
4{{modelName}}::{{modelName}}()
5    : Archiving_Node(),
6      P_(),
7      S_(P_),
8      B_(*this)
9{
10    recordablesMap_.create();
11}
12
13{{modelName}}::{{modelName}}(const {{modelName}}& n)
14    : Archiving_Node(n),
15      P_(n.P_),
16      S_(n.S_),
17      B_(n.B_, *this)
18{
19}
20
21
22{{modelName}}::~{{modelName}}()
23{
24
25  if ( B_.sys_ )
26  {
27    /* Free y vector */
28    N_VDestroy_Serial(B_.y);
29
30    /* Free integrator memory */
31    if (B_.sys_ != NULL)
32    {
33      CVodeFree(&B_.sys_);
34      B_.sys_ = NULL;
35    }
36  }
37}
38EOF
39   ))
40
41
42void {{modelName}}::init_node_(const Node& proto)
43{
44    const {{modelName}}& pr = downcast<{{modelName}}>(proto);
45    P_ = pr.P_;
46    S_ = State_(P_);
47}
48
49
50void {{modelName}}::init_state_(const Node& proto)
51{
52    const {{modelName}}& pr = downcast<{{modelName}}>(proto);
53    S_ = State_(pr.P_);
54}
55
56
57void {{modelName}}::init_buffers_()
58{
59
60   {% for synapticEvent in synapticEventDefs %}
61   B_.spike_{{synapticEvent.pscId}}.clear();
62   {% endfor %}
63
64   B_.currents_.clear();           
65   Archiving_Node::clear_history();
66
67   B_.logger_.reset();
68
69   B_.step_ = Time::get_resolution().get_ms();
70   B_.IntegrationStep_ = B_.step_;
71
72   B_.I_stim_ = 0.0;
73
74   int status, N, rootdir;
75
76   N = {{stateSize}};
77   // only positive direction (rising edge) of spike events will be detected
78   rootdir = 1;
79
80   /* Creates serial vector of length N */
81   B_.y = N_VNew_Serial(N);
82   if (check_flag((void *)B_.y, "N_VNew_Serial", 0)) throw CVodeSolverFailure (get_name(), 0);
83
84   for (int i = 0; i < N; i++)
85   {
86      Ith(B_.y,i) = S_.y_[i];
87   }
88 
89   /* Calls CVodeCreate to create the solver memory and specify the
90    * Backward Differentiation Formula and the use of a Newton iteration */
91   B_.sys_ = CVodeCreate(CV_BDF, CV_NEWTON);
92   if (check_flag((void *)B_.sys_, "CVodeCreate", 0)) throw CVodeSolverFailure (get_name(), 0);
93
94   /* Calls CVodeInit to initialize the integrator memory and specifies the
95    * right hand side function in y'=f(t,y), the initial time, and
96    * the initial values. */
97   status = CVodeInit (B_.sys_, {{modelName}}_dynamics, 0.0, B_.y);
98   if (check_flag(&status, "CVodeInit", 1)) throw CVodeSolverFailure (get_name(), status);
99
100{% if haskey(defaults,"V_t") %}
101
102   /* Spike event handler (detects zero-crossing of V-V_t) */
103   status = CVodeRootInit(B_.sys_, 1, (CVRootFn){{modelName}}_event);
104   if (check_flag(&status, "CVodeRootInit", 1)) throw CVodeSolverFailure (get_name(), status);
105
106   /* Detect only the rising edge of spikes */
107   status = CVodeSetRootDirection(B_.sys_, &rootdir);
108   if (check_flag(&status, "CVodeSetRootDirection", 1)) throw CVodeSolverFailure (get_name(), status);
109
110{% endif %}
111
112   /* Sets the relative and absolute error tolerances of CVode  */
113   status = CVodeSStolerances (B_.sys_,
114                               {% if abstol %}{{abstol}}{% else %}1e-7{% endif %},
115                               {% if reltol %}{{reltol}}{% else %}1e-7{% endif %});
116   if (check_flag(&status, "CVodeSStolerances", 1)) throw CVodeSolverFailure (get_name(), status);
117
118   /* Turns on CVode stability limit detection (only applicable for order 3 and greater) */
119   status = CVodeSetStabLimDet (B_.sys_,true);
120   if (check_flag(&status, "CVodeSetStabLimDet", 1)) throw CVodeSolverFailure (get_name(), status);
121
122   /* Sets the maximum order of CVode  */
123   status = CVodeSetMaxOrd (B_.sys_,5);
124   if (check_flag(&status, "CVodeSetMaxOrd", 1)) throw CVodeSolverFailure (get_name(), status);
125
126   /* Sets maximum step size. */
127   status = CVodeSetMaxStep (B_.sys_,{% if maxstep %}{{maxstep}}{% else %}B_.step_{% endif %});
128   if (check_flag(&status, "CVodeSetMaxStep", 1)) throw CVodeSolverFailure (get_name(), status);
129
130   /* Configures the integrator to pass the params structure to the right-hand function */
131   status = CVodeSetUserData(B_.sys_, reinterpret_cast<void*>(this));
132   if (check_flag(&status, "CVodeSetUserData", 1)) throw CVodeSolverFailure (get_name(), status);
133
134   /* Initializes diagonal linear solver. */
135   status = CVDiag (B_.sys_);
136   if (check_flag(&status, "CVDiag", 1)) throw CVodeSolverFailure (get_name(), status);
137  }
138
139
140void {{modelName}}::calibrate()
141{
142    B_.logger_.init(); 
143}
144
145
146{% elif (ODEmethod == "ida") %}
147
148{{modelName}}::{{modelName}}()
149    : Archiving_Node(),
150      P_(),
151      S_(P_),
152      B_(*this)
153{
154    recordablesMap_.create();
155}
156
157
158{{modelName}}::{{modelName}}(const {{modelName}}& n)
159    : Archiving_Node(n),
160      P_(n.P_),
161      S_(n.S_),
162      B_(n.B_, *this)
163{
164}
165
166
167{{modelName}}::~{{modelName}}()
168{
169
170  if ( B_.sys_ )
171  {
172    /* Free y vector */
173    N_VDestroy_Serial(B_.y);
174    N_VDestroy_Serial(B_.yp);
175
176    /* Free integrator memory */
177    if (B_.sys_ != NULL)
178    {
179      IDAFree(&B_.sys_);
180      B_.sys_ = NULL;
181    }
182
183  }
184}
185
186
187void {{modelName}}::init_node_(const Node& proto)
188{
189    const {{modelName}}& pr = downcast<{{modelName}}>(proto);
190    P_ = pr.P_;
191    S_ = State_(P_);
192}
193
194
195void {{modelName}}::init_state_(const Node& proto)
196{
197    const {{modelName}}& pr = downcast<{{modelName}}>(proto);
198    S_ = State_(pr.P_);
199}
200
201
202void {{modelName}}::init_buffers_()
203{
204
205   {% for synapticEvent in synapticEventDefs %}
206   B_.spike_{{synapticEvent.pscId}}.clear();
207   {% endfor %}
208
209   B_.currents_.clear();           
210   Archiving_Node::clear_history();
211
212   B_.logger_.reset();
213
214   B_.step_ = Time::get_resolution().get_ms();
215   B_.IntegrationStep_ = B_.step_;
216
217   B_.I_stim_ = 0.0;
218
219
220   int status, N, rootdir;
221
222   N = {{stateSize}};
223
224   // only positive direction (rising edge) of spike events will be detected
225   rootdir = 1;
226
227   /* Creates serial vectors of length N */
228   B_.y = N_VNew_Serial(N);
229   B_.y1 = N_VNew_Serial(N);
230   B_.yp = N_VNew_Serial(N);
231   if (check_flag((void *)B_.y, "N_VNew_Serial", 0)) throw IDASolverFailure (get_name(), 0);
232
233   for (int i = 0; i < N; i++)
234   {
235      Ith(B_.y,i) = S_.y_[i];
236   }
237 
238   {{modelName}}_dynamics (0.0, B_.y, B_.yp, reinterpret_cast<void*>(this));
239
240   /* Calls IDACreate to create the solver memory */
241   B_.sys_ = IDACreate();
242   if (check_flag((void *)B_.sys_, "IDACreate", 0)) throw IDASolverFailure (get_name(), 0);
243
244  /* Calls IDAInit to initialize the integrator memory and specify the
245   * resdual function, the initial time, and the initial values. */
246   status = IDAInit (B_.sys_, {{modelName}}_residual, 0.0, B_.y, B_.yp);
247
248   if (check_flag(&status, "IDAInit", 1)) throw IDASolverFailure (get_name(), status);
249
250{% if haskey(defaults,"V_t") %}
251
252   /* Spike event handler (detects zero-crossing of V-V_t) */
253   status = IDARootInit(B_.sys_, 1, (IDARootFn){{modelName}}_event);
254   if (check_flag(&status, "IDARootInit", 1)) throw IDASolverFailure (get_name(), status);
255
256   /* Detect only the rising edge of spikes */
257   status = IDASetRootDirection(B_.sys_, &rootdir);
258   if (check_flag(&status, "IDASetRootDirection", 1)) throw IDASolverFailure (get_name(), status);
259
260{% endif %}
261
262   /* Sets the relative and absolute error tolerances of IDA  */
263   status = IDASStolerances (B_.sys_,
264                             {% if abstol %}{{abstol}}{% else %}1e-7{% endif %},
265                             {% if reltol %}{{reltol}}{% else %}1e-7{% endif %});
266   if (check_flag(&status, "IDASStolerances", 1)) throw IDASolverFailure (get_name(), status);
267
268   /* Sets the maximum order of IDA  */
269   status = IDASetMaxOrd (B_.sys_,5);
270   if (check_flag(&status, "IDASetMaxOrd", 1)) throw IDASolverFailure (get_name(), status);
271
272   /* Sets maximum step size. */
273   status = IDASetMaxStep (B_.sys_,{% if maxstep %}{{maxstep}}{% else %}B_.step_{% endif %});
274   if (check_flag(&status, "IDASetMaxStep", 1)) throw IDASolverFailure (get_name(), status);
275
276   /* Calls IDASetUserData to configure the integrator to pass the
277    * params structure to the right-hand function */
278   status = IDASetUserData(B_.sys_, reinterpret_cast<void*>(this));
279   if (check_flag(&status, "IDASetUserData", 1)) throw IDASolverFailure (get_name(), status);
280
281   /* Initializes dense linear solver. */
282   status = IDADense (B_.sys_, N);
283   if (check_flag(&status, "IDADense", 1)) throw IDASolverFailure (get_name(), status);
284
285   status = IDACalcIC(B_.sys_, IDA_Y_INIT, 0.0);
286   if (check_flag(&status, "IDACalcIC", 1)) throw IDASolverFailure (get_name(), status);
287
288}
289
290void {{modelName}}::calibrate()
291{
292   B_.logger_.init(); 
293}
294
295{% elif (ODEmethod == "gsl") %}
296
297{{modelName}}::{{modelName}}()
298    : Archiving_Node(),
299      P_(),
300      S_(P_),
301      B_(*this)
302{
303    recordablesMap_.create();
304}
305
306
307{{modelName}}::{{modelName}}(const {{modelName}}& n)
308    : Archiving_Node(n),
309      P_(n.P_),
310      S_(n.S_),
311      B_(n.B_, *this)
312{
313}
314
315
316{{modelName}}::~{{modelName}} ()
317{
318    // GSL structs only allocated by init_nodes_(), so we need to protect destruction
319    if ( B_.s_ != NULL) gsl_odeiv2_step_free (B_.s_);
320    if ( B_.c_ != NULL) gsl_odeiv2_control_free (B_.c_);
321    if ( B_.e_ != NULL) gsl_odeiv2_evolve_free (B_.e_);
322    if ( B_.u != NULL) free (B_.u);
323    if ( B_.jac != NULL) free (B_.jac);
324}
325
326
327void {{modelName}}::init_node_(const Node& proto)
328{
329    const {{modelName}}& pr = downcast<{{modelName}}>(proto);
330    P_ = pr.P_;
331    S_ = State_(P_);
332}
333
334
335void {{modelName}}::init_state_(const Node& proto)
336{
337    const {{modelName}}& pr = downcast<{{modelName}}>(proto);
338    S_ = State_(pr.P_);
339}
340
341
342void {{modelName}}::init_buffers_()
343{
344   {% for synapticEvent in synapticEventDefs %}
345   B_.spike_{{synapticEvent.pscId}}.clear();
346   {% endfor %}
347
348   B_.currents_.clear();           
349   Archiving_Node::clear_history();
350
351   B_.logger_.reset();
352
353   B_.step_ = Time::get_resolution().get_ms();
354   B_.IntegrationStep_ = B_.step_;
355
356   B_.I_stim_ = 0.0;
357
358
359   static const gsl_odeiv2_step_type* T1 = gsl_odeiv2_step_rk2;
360   B_.N = {{stateSize}};
361 
362   if ( B_.s_ == 0 )
363     B_.s_ = gsl_odeiv2_step_alloc (T1, B_.N);
364   else
365     gsl_odeiv2_step_reset(B_.s_);
366   
367   if ( B_.c_ == 0 ) 
368     B_.c_ = gsl_odeiv2_control_standard_new (#{(or abstol 1e-7)}, #{(or reltol 1e-7)}, 1.0, 0.0);
369   else
370     gsl_odeiv2_control_init(B_.c_, #{(or abstol 1e-7)}, #{(or reltol 1e-7)}, 1.0, 0.0);
371   
372   if ( B_.e_ == 0 ) 
373     B_.e_ = gsl_odeiv2_evolve_alloc(B_.N);
374   else
375     gsl_odeiv2_evolve_reset(B_.e_);
376   
377   B_.sys_.function  = {{modelName}}_dynamics;
378   B_.sys_.jacobian  = {{modelName}}_jacobian;
379   B_.sys_.dimension = B_.N;
380   B_.sys_.params    = reinterpret_cast<void*>(this);
381
382   B_.u = (double *)malloc(sizeof(double) * B_.N);
383   assert (B_.u);
384   B_.jac = (double *)malloc(sizeof(double) * B_.N);
385   assert (B_.jac);
386
387}
388
389
390void {{modelName}}::calibrate()
391{
392    B_.logger_.init(); 
393    #{(if iv (sprintf "V_.U_old_ = S_.y_[~A];" iv) "")}
394    #{(if (lookup-def 't_ref defaults)
395          "V_.RefractoryCounts_ = Time(Time::ms(P_.t_ref)).get_steps();"
396          "V_.RefractoryCounts_ = 0;")}
397}
398
399
400{% endif %}
Note: See TracBrowser for help on using the repository browser.