source: project/release/4/9ML-toolkit/trunk/examples/tensor.sml @ 28665

Last change on this file since 28665 was 28665, checked in by Ivan Raikov, 7 years ago

9ML-toolkit: SML tensor lib updates

File size: 103.1 KB
Line 
1(* Obtained at http://www.arrakis.es/~worm/ *)
2
3signature MONO_VECTOR =
4  sig
5    type vector
6    type elem
7    val maxLen : int
8    val fromList : elem list -> vector
9    val tabulate : (int * (int -> elem)) -> vector
10    val length : vector -> int
11    val sub : (vector * int) -> elem
12    val extract : (vector * int * int option) -> vector
13    val concat : vector list -> vector
14    val mapi : ((int * elem) -> elem) -> (vector * int * int option) -> vector
15    val map : (elem -> elem) -> vector -> vector
16    val appi : ((int * elem) -> unit) -> (vector * int * int option) -> unit
17    val app : (elem -> unit) -> vector -> unit
18    val foldli : ((int * elem * 'a) -> 'a) -> 'a -> (vector * int * int option) -> 'a
19    val foldri : ((int * elem * 'a) -> 'a) -> 'a -> (vector * int * int option) -> 'a
20    val foldl : ((elem * 'a) -> 'a) -> 'a -> vector -> 'a
21    val foldr : ((elem * 'a) -> 'a) -> 'a -> vector -> 'a 
22  end
23
24(*
25 Copyright (c) Juan Jose Garcia Ripoll.
26 All rights reserved.
27
28 Refer to the COPYRIGHT file for license conditions
29*)
30
31(* COPYRIGHT
32 
33Redistribution and use in source and binary forms, with or
34without modification, are permitted provided that the following
35conditions are met:
36
371. Redistributions of source code must retain the above copyright
38   notice, this list of conditions and the following disclaimer.
39
402. Redistributions in binary form must reproduce the above
41   copyright notice, this list of conditions and the following
42   disclaimer in the documentation and/or other materials provided
43   with the distribution.
44
453. All advertising materials mentioning features or use of this
46   software must display the following acknowledgement:
47        This product includes software developed by Juan Jose
48        Garcia Ripoll.
49
504. The name of Juan Jose Garcia Ripoll may not be used to endorse
51   or promote products derived from this software without
52   specific prior written permission.
53
54THIS SOFTWARE IS PROVIDED BY JUAN JOSE GARCIA RIPOLL ``AS IS''
55AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED
56TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
57PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL HE BE
58LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY,
59OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
60PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA,
61OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
62THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR
63TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT
64OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY
65OF SUCH DAMAGE.
66*)
67
68structure Loop =
69    struct
70        fun all (a, b, f) =
71            if a > b then
72                true
73            else if f a then
74                all (a+1, b, f)
75            else
76                false
77
78        fun any (a, b, f) =
79            if a > b then
80                false
81            else if f a then
82                true
83            else
84                any (a+1, b, f)
85
86        fun app (a, b, f) =
87            if a < b then
88                (f a; app (a+1, b, f))
89            else
90                ()
91
92        fun app' (a, b, d, f) =
93            if a < b then
94                (f a; app' (a+d, b, d, f))
95            else
96                ()
97
98        fun appi' (a, b, d, f) =
99            if a < b then
100                (f a; appi' (a+d, b, d, f))
101            else
102                ()
103    end
104(*
105  INDEX         -Signature-
106
107  Indices are a enumerable finite set of data with an order and a map
108  to a continous nonnegative interval of integers.  In the sample
109  implementation, Index, each index is a list of integers,
110        [i1,...,in]
111  and each set of indices is defined by a shape, which has the same
112  shape of an index but with each integer incremented by one
113        shape = [k1,...,kn]
114        0 <= i1 < k1
115
116  type storage = RowMajor | ColumnMajor
117  order : storage
118        Identifies:
119                1) the underlying algorithms for this structure
120                2) the most significant index
121                3) the index that varies more slowly
122                4) the total order
123        RowMajor means that first index is most significant and varies
124        more slowly, while ColumnMajor means that last index is the most
125        significant and varies more slowly. For instance
126                RowMajor => [0,0]<[0,1]<[1,0]<[1,1] (C, C++, Pascal)
127                ColumnMajor => [0,0]>[1,0]>[0,1]>[1,1] (Fortran)
128  last shape
129  first shape
130        Returns the last/first index that belongs to the sed defined by
131        'shape'.
132  inBounds shape index
133        Checkes whether 'index' belongs to the set defined by 'shape'.
134  toInt shape index
135        As we said, indices can be sorted and mapped to a finite set of
136        integers. 'toInt' obtaines the integer number that corresponds to
137        a certain index.
138  indexer shape
139        It is equivalent to the partial evaluation 'toInt shape' but
140        optimized for 'shape'.
141
142  next shape index
143  prev shape index
144  next' shape index
145  prev' shape index
146        Obtain the following or previous index to the one we supply.
147        next and prev return an object of type 'index option' so that
148        if there is no such following/previous, the output is NONE.
149        On the other hand, next'/prev' raise an exception when the
150        output is not well defined and their output is always of type
151        index. next/prev/next'/prev' raise an exception if 'index'
152        does not belong to the set of 'shape'.
153
154  all shape f
155  any shape f
156  app shape f
157        Iterates 'f' over every index of the set defined by 'shape'.
158        'all' stops when 'f' first returns false, 'any' stops when
159        'f' first returns true and 'app' does not stop and discards the
160        output of 'f'.
161
162  compare(a,b)
163        Returns LESS/GREATER/EQUAL according to the total order which
164        is defined in the set of all indices.
165  <,>,eq,<=,>=,<>
166        Reduced comparisons which are defined in terms of 'compare'.
167
168  validShape t
169  validIndex t
170        Checks whether 't' conforms a valid shape or index.
171
172  iteri shape f
173*)
174
175signature INDEX =
176    sig
177        type t
178        type indexer = t -> int
179        datatype storage = RowMajor | ColumnMajor
180
181        exception Index
182        exception Shape
183
184        val order : storage
185        val toInt : t -> t -> int
186        val length : t -> int
187        val first : t -> t
188        val last : t -> t
189        val next : t -> t -> t option
190        val prev : t -> t -> t option
191        val next' : t -> t -> t
192        val prev' : t -> t -> t
193        val indexer : t -> (t -> int)
194
195        val inBounds : t -> t -> bool
196        val compare : t * t -> order
197        val < : t * t -> bool
198        val > : t * t -> bool
199        val eq : t * t -> bool
200        val <= : t * t -> bool
201        val >= : t * t -> bool
202        val <> : t * t -> bool
203        val - : t * t -> t
204
205        val validShape : t -> bool
206        val validIndex : t -> bool
207
208        val all : t -> (t -> bool) -> bool
209        val any : t -> (t -> bool) -> bool
210        val app : t -> (t -> unit) -> unit
211    end
212structure Index : INDEX =
213    struct
214        type t = int list
215        type indexer = t -> int
216        datatype storage = RowMajor | ColumnMajor
217
218        exception Index
219        exception Shape
220
221        val order = ColumnMajor
222
223        fun validShape shape = List.all (fn x => x > 0) shape
224
225        fun validIndex index = List.all (fn x => x >= 0) index
226
227        fun toInt shape index =
228            let fun loop ([], [], accum, _) = accum
229                  | loop ([], _, _, _) = raise Index
230                  | loop (_, [], _, _) = raise Index
231                  | loop (i::ri, l::rl, accum, fac) =
232                if (i >= 0) andalso (i < l) then
233                    loop (ri, rl, i*fac + accum, fac*l)
234                else
235                    raise Index
236            in loop (index, shape, 0, 1)
237            end
238
239        (* ----- CACHED LINEAR INDEXER -----
240
241           An indexer is a function that takes a list of
242           indices, validates it and produces a nonnegative
243           integer number. In short, the indexer is the
244           mapper from indices to element positions in
245           arrays.
246
247           'indexer' builds such a mapper by optimizing
248           the most common cases, which are 1d and 2d
249           tensors.
250         *)
251    local
252        fun doindexer [] _ = raise Shape
253          | doindexer [a] [dx] =
254            let fun f [x] = if (x > 0) andalso (x < a)
255                            then x
256                            else raise Index
257                  | f _ = raise Index
258            in f end
259          | doindexer [a,b] [dx, dy] =
260            let fun f [x,y] = if ((x > 0) andalso (x < a) andalso
261                                  (y > 0) andalso (y < b))
262                              then x + dy * y
263                              else raise Index
264                  | f _ = raise Index
265            in f end
266          | doindexer [a,b,c] [dx,dy,dz] =
267            let fun f [x,y,z] = if ((x > 0) andalso (x < a) andalso
268                                    (y > 0) andalso (y < b) andalso
269                                    (z > 0) andalso (z < c))
270                                then x + dy * y + dz * z
271                                else raise Index
272                  | f _ = raise Index
273            in f end
274          | doindexer shape memo =
275            let fun f [] [] accum [] = accum
276                  | f _  _  _ [] = raise Index
277                  | f (fac::rf) (ndx::ri) accum (dim::rd) =
278                    if (ndx >= 0) andalso (ndx < dim) then
279                        f rf ri (accum + ndx * fac) rd
280                    else
281                        raise Index
282            in f shape memo 0
283            end
284    in
285        fun indexer shape =
286            let fun memoize accum [] = []
287                  | memoize accum (dim::rd) =
288                    accum :: (memoize (dim * accum) rd)
289            in
290                if validShape shape
291                then doindexer shape (memoize 1 shape)
292                else raise Shape
293            end
294    end
295
296        fun length shape =
297            let fun prod (a,b) =
298                if b < 0 then raise Shape else a * b
299            in foldl prod 1 shape
300            end
301
302        fun first shape = map (fn x => 0) shape
303
304        fun last [] = []
305          | last (size :: rest) =
306            if size < 1
307            then raise Shape
308            else size - 1 :: last rest
309
310        fun next' [] [] = raise Subscript
311          | next' _ [] = raise Index
312          | next' [] _ = raise Index
313          | next' (dimension::restd) (index::resti) =
314            if (index + 1) < dimension
315            then (index + 1) :: resti
316            else 0 :: (next' restd resti)
317
318        fun prev' [] [] = raise Subscript
319          | prev' _ [] = raise Index
320          | prev' [] _ = raise Index
321          | prev' (dimension::restd) (index::resti) =
322            if (index > 0)
323            then index - 1 :: resti
324            else dimension - 1 :: prev' restd resti
325
326        fun next shape index = (SOME (next' shape index)) handle
327            Subscript => NONE
328
329        fun prev shape index = (SOME (prev' shape index)) handle
330            Subscript => NONE
331
332        fun inBounds shape index =
333            ListPair.all (fn (x,y) => (x >= 0) andalso (x < y))
334            (index, shape)
335
336        fun compare ([],[]) = EQUAL
337          | compare (_, []) = raise Index
338          | compare ([],_) = raise Index
339          | compare (a::ra, b::rb) =
340            case Int.compare (a,b) of
341                EQUAL => compare (ra,rb)
342              | LESS => LESS
343              | GREATER => GREATER
344
345    local
346        fun iterator a inner =
347            let fun loop accum f =
348                let fun innerloop i =
349                    if i < a
350                    then if inner (i::accum) f
351                         then innerloop (i+1)
352                         else false
353                    else true
354                in innerloop 0
355                end
356            in loop
357            end
358        fun build_iterator [a] =
359            let fun loop accum f =
360                let fun innerloop i =
361                    if i < a
362                    then if f (i::accum)
363                         then innerloop (i+1)
364                         else false
365                    else true
366                in innerloop 0
367                end
368            in loop
369            end
370          | build_iterator (a::rest) = iterator a (build_iterator rest)
371    in
372        fun all shape = build_iterator shape []
373    end
374
375    local
376        fun iterator a inner =
377            let fun loop accum f =
378                let fun innerloop i =
379                    if i < a
380                    then if inner (i::accum) f
381                         then true
382                         else innerloop (i+1)
383                    else false
384                in innerloop 0
385                end
386            in loop
387            end
388        fun build_iterator [a] =
389            let fun loop accum f =
390                let fun innerloop i =
391                    if i < a
392                    then if f (i::accum)
393                         then true
394                         else innerloop (i+1)
395                    else false
396                in innerloop 0
397                end
398            in loop
399            end
400          | build_iterator (a::rest) = iterator a (build_iterator rest)
401    in
402        fun any shape = build_iterator shape []
403    end
404
405    local
406        fun iterator a inner =
407            let fun loop accum f =
408                let fun innerloop i =
409                    if i < a
410                    then (inner (i::accum) f;
411                          innerloop (i+1))
412                    else ()
413                in innerloop 0
414                end
415            in loop
416            end
417        fun build_iterator [a] =
418            let fun loop accum f =
419                let fun innerloop i =
420                    if i < a
421                    then (f (i::accum); innerloop (i+1))
422                    else ()
423                in innerloop 0
424                end
425            in loop
426            end
427          | build_iterator (a::rest) = iterator a (build_iterator rest)
428    in
429        fun app shape = build_iterator shape []
430    end
431
432        fun a < b = compare(a,b) = LESS
433        fun a > b = compare(a,b) = GREATER
434        fun eq (a, b) = compare(a,b) = EQUAL
435        fun a <> b = not (a = b)
436        fun a <= b = not (a > b)
437        fun a >= b = not (a < b)
438        fun a - b = ListPair.map Int.- (a,b)
439
440    end
441
442
443signature RANGE =
444    sig
445        structure Index : INDEX
446        type index = Index.t
447        type t
448
449        exception Range
450
451        val fromto : index -> index * index -> t
452        val fromto' : index -> index * index -> t
453        val upto : index -> index -> t
454        val below : index -> index -> t
455
456        val first : t -> index
457        val last : t -> index
458
459        val length : t -> int
460        val shape : t -> index
461        val inRange : t -> index -> bool
462        val next : t -> index -> index option
463        val prev : t -> index -> index option
464        val ranges : index -> ((index * index) list) -> t
465        val iteri : (index -> bool) -> t -> bool
466    end
467
468structure Range : RANGE =
469struct
470
471    structure Index = Index
472    type index = Index.t
473
474    datatype t = RangeEmpty | RangeIn of index * index * index | RangeSet of index * ((index * index) list)
475
476    exception Range
477
478    local
479        fun next'' [] [] [] = raise Range
480          | next'' [] _  _  = raise Index.Index
481          | next'' _  [] _  = raise Index.Index
482          | next'' _  _  [] = raise Index.Index
483          | next'' (low::rl) (up::ru) (index::ri) =
484            if index < up then
485                index + 1 :: ri
486            else
487                low :: (next'' rl ru ri)
488
489        fun prev'' [] [] [] = raise Range
490          | prev'' [] _  _  = raise Index.Index
491          | prev'' _  [] _  = raise Index.Index
492          | prev'' _  _  [] = raise Index.Index
493          | prev'' (low::rl) (up::ru) (index::ri) =
494            if index > low then
495                index - 1 :: ri
496            else
497                up :: (prev'' rl ru ri)
498
499        (* Builds the simple loop
500           for i := first to last
501              if not (g (i::ndx)) then
502                  break
503           endfor;
504           *)
505
506        fun simple_loop (first : int) (last : int) =
507            let fun loop (ndx : index) (g: index -> bool) =
508                let fun innerloop i =
509                    if i > last then
510                        true
511                    else if g (i::ndx) then
512                        innerloop (i+1)
513                    else
514                        false
515                in innerloop first end
516            in loop end
517
518        (* Builds the nested loop
519           for i := first to last
520              if not (f (i:ndx) g) then
521                  break
522           endfor
523         *)
524        fun nested_loop f (first : int) (last : int) =
525            let fun loop (ndx: index) (g: index -> bool) =
526                let fun innerloop i =
527                    (if i > last then
528                        true
529                    else if (f (i::ndx) g) then
530                        innerloop (i+1)
531                    else
532                        false)
533                in 
534                     innerloop first
535                end
536            in loop end
537
538        fun build_iterator ([a] : index) ([b] : index) = 
539            simple_loop a b
540          | build_iterator (a::ra) (b::rb) =
541            nested_loop (build_iterator ra rb) a b
542
543    in
544
545        (* ----- CONSTRUCTORS ----- *)
546
547        fun fromto shape (lo, up) =
548            if (Index.validShape shape) andalso (Index.validIndex lo) andalso (Index.validIndex up) andalso
549               (Index.inBounds shape lo) andalso (Index.inBounds shape up) andalso Index.< (lo,up)
550            then
551                RangeIn(shape,lo,up)
552            else
553                RangeEmpty
554
555        fun fromto' shape (lo, up) = fromto shape (lo, (Index.prev' shape up))
556
557        fun upto shape index = fromto shape (Index.first index, index)         
558
559        fun below shape index = fromto' shape (Index.first index, index)
560
561        fun range_append ((lo,up),((lo',up')::ranges)) = 
562            if Index.< (up,lo') then ((lo,up)::(lo',up')::ranges) else
563               (if Index.> (up,up') then (lo',up')::(range_append ((lo,up),ranges)) else
564                 (if Index.> (lo,lo') then ((lo',up')::ranges) else ((lo,up')::ranges) ))
565          | range_append ((lo,up),[]) = [(lo,up)]
566
567        fun ranges' shape ((lo, up) :: rest) = 
568            if (Index.validShape shape) andalso (Index.validIndex lo) andalso (Index.validIndex up) andalso
569               (Index.inBounds shape lo) andalso (Index.inBounds shape up) andalso Index.< (lo,up)
570               then range_append ((lo,up), (ranges' shape rest)) else (ranges' shape rest)
571            | ranges' shape ([]) = []
572
573        fun ranges shape xs = 
574            let val set = ranges' shape xs
575            in 
576                if List.null set then RangeEmpty else RangeSet (shape,set)
577            end
578
579        fun length RangeEmpty = 0
580          | length (RangeIn(shape,lo,up)) =
581            let fun diff (x,y) = (y-x+1) in
582                Index.length (ListPair.map diff (lo,up))
583            end
584          | length (RangeSet(shape,set)) =
585            let fun diff (x,y) = (y-x+1) in
586                foldl (fn ((lo,up),ax) => (Index.length (ListPair.map diff (lo,up))) + ax)
587                      0 set
588            end
589 
590        fun shape RangeEmpty = []
591          | shape (RangeIn(shape,lo,up)) =
592            let fun diff (x,y) = (y-x+1) in
593                ListPair.map diff (lo,up)
594            end
595          | shape (RangeSet(shape,set)) =
596            let fun diff (x,y) = (y-x+1) in
597                foldl (fn ((lo,up),ax) => ListPair.map (op +) ((ListPair.map diff (lo,up)), ax))
598                      (List.map (fn (x) => 0) shape) set
599            end
600 
601
602        fun first RangeEmpty = raise Range
603          | first (RangeIn(shape,lo,up)) = lo
604          | first (RangeSet(shape,(lo,up)::_)) = lo
605
606        fun last RangeEmpty = raise Range
607          | last (RangeIn(shape,lo,up)) = up
608          | last (RangeSet(shape,set)) = let val (lo,up) = List.last set in up end
609
610        (* ----- PREDICATES & OPERATIONS ----- *)
611
612        fun inRange RangeEmpty _ = false
613          | inRange (RangeIn(shape,lo,up)) ndx =
614            (Index.<=(lo,ndx)) andalso (Index.<=(ndx,up))
615          | inRange (RangeSet(shape,set)) ndx =
616            List.exists (fn ((lo,up)) => (Index.<=(lo,ndx)) andalso (Index.<=(ndx,up))) set
617
618        fun next RangeEmpty _ = NONE
619          | next (RangeIn(shape,lo,up)) index =
620            (SOME (next'' lo up index) handle Range => NONE)
621          | next (RangeSet(shape,set)) index =
622            let val m = List.find (fn ((lo,up)) => (Index.<=(lo,index)) andalso (Index.<=(index,up))) set
623            in
624               case m of NONE => NONE
625                       | SOME (lo,up) => (SOME (next'' lo up index) handle Range => NONE)
626            end
627
628        fun prev RangeEmpty _ = NONE
629          | prev (RangeIn(shape,lo,up)) index =
630            (SOME (prev'' lo up index) handle Range => NONE)
631          | prev (RangeSet(shape,set)) index =
632            let val m = List.find (fn ((lo,up)) => (Index.<=(lo,index)) andalso (Index.<=(index,up))) set
633            in
634               case m of NONE => NONE
635                       | SOME (lo,up) => (SOME (prev'' lo up index) handle Range => NONE)
636            end
637
638        fun next' RangeEmpty _ = raise Range
639          | next' (RangeIn(shape,lo,up)) index = next'' lo up index
640          | next' (RangeSet(shape,set)) index =
641            let val m = List.find (fn ((lo,up)) => (Index.<=(lo,index)) andalso (Index.<=(index,up))) set
642                val (lo,up) = valOf m
643            in
644                next'' lo up index
645            end
646
647        fun prev' RangeEmpty _ = raise Range
648          | prev' (RangeIn(shape,lo,up)) index = prev'' lo up index
649          | prev' (RangeSet(shape,set)) index =
650            let val m = List.find (fn ((lo,up)) => (Index.<=(lo,index)) andalso (Index.<=(index,up))) set
651                val (lo,up) = valOf m
652            in
653                prev'' lo up index
654            end
655
656        (* ----- ITERATION ----- *)
657
658        (* Builds an interator that applies 'f' sequentially to
659           all the indices in the range, *)
660        fun iteri f RangeEmpty = f []
661          | iteri (f: index -> bool) (RangeIn(shape,lo: index,up: index)) = 
662           (build_iterator lo up) lo f
663          | iteri (f: index -> bool) (RangeSet(shape,set)) = 
664           List.all (fn (lo,up) => (build_iterator lo up) lo f) set
665
666    end
667end
668
669
670(*
671 Copyright (c) Juan Jose Garcia Ripoll.
672 All rights reserved.
673 
674 Refer to the COPYRIGHT file for license conditions
675*)
676
677(*
678 TENSOR         - Signature -
679
680 Polymorphic tensors of any type. With 'tensor' we denote a (mutable)
681 array of any rank, with as many indices as one wishes, and that may
682 be traversed (map, fold, etc) according to any of those indices.
683
684 type 'a tensor
685        Polymorphic tensor whose elements are all of type 'a.
686 val storage = RowMajor | ColumnMajor
687        RowMajor = data is stored in consecutive cells, first index
688        varying fastest (FORTRAN convention)
689        ColumnMajor = data is stored in consecutive cells, last
690        index varying fastest (C,C++,Pascal,CommonLisp convention)
691 new ([i1,...,in],init)
692        Build a new tensor with n indices, each of sizes i1...in,
693        filled with 'init'.
694 fromArray (shape,data)
695 fromList (shape,data)
696        Use 'data' to fill a tensor of that shape. An exception is
697        raised if 'data' is too large or too small to properly
698        fill the vector. Later use of a 'data' array is disregarded
699        -- one must think that the tensor now owns the array.
700 length tensor
701 rank tensor
702 shape tensor
703        Return the number of elements, the number of indices and
704        the shape (size of each index) of the tensor.
705 toArray tensor
706        Return the data of the tensor in the form of an array.
707        Mutation of this array may lead to unexpected behavior.
708
709 sub (tensor,[i1,...,in])
710 update (tensor,[i1,...,in],new_value)
711        Access the element that is indexed by the numbers [i1,..,in]
712
713 app f a
714 appi f a
715        The same as 'map' and 'mapi' but the function 'f' outputs
716        nothing and no new array is produced, i.e. one only seeks
717        the side effect that 'f' may produce.
718 map2 operation a b
719        Apply function 'f' to pairs of elements of 'a' and 'b'
720        and build a new tensor with the output. Both operands
721        must have the same shape or an exception is raised.
722        The procedure is sequential, as specified by 'storage'.
723 foldl operation a n
724        Fold-left the elements of tensor 'a' along the n-th
725        index.
726 all test a
727 any test a
728        Folded boolean tests on the elements of the tensor.
729*)
730
731signature TENSOR =
732    sig
733        structure Array : ARRAY
734        structure Index : INDEX
735        type index = Index.t
736        type 'a tensor
737
738        val new : index * 'a -> 'a tensor
739        val tabulate : index * (index -> 'a) -> 'a tensor
740        val length : 'a tensor -> int
741        val rank : 'a tensor -> int
742        val shape : 'a tensor -> (index)
743        val reshape : index -> 'a tensor -> 'a tensor
744        val fromList : index * 'a list -> 'a tensor
745        val fromArray : index * 'a array -> 'a tensor
746        val toArray : 'a tensor -> 'a array
747
748        val sub : 'a tensor * index -> 'a
749        val update : 'a tensor * index * 'a -> unit
750        val map : ('a -> 'b) -> 'a tensor -> 'b tensor
751        val map2 : ('a * 'b -> 'c) -> 'a tensor -> 'b tensor -> 'c tensor
752        val app : ('a -> unit) -> 'a tensor -> unit
753        val appi : (int * 'a -> unit) -> 'a tensor -> unit
754        val foldl : ('c * 'a -> 'c) -> 'c -> 'a tensor -> int -> 'c tensor
755        val all : ('a -> bool) -> 'a tensor -> bool
756        val any : ('a -> bool) -> 'a tensor -> bool
757
758        val cat : int * 'a tensor * 'a tensor -> 'a tensor       
759    end
760
761
762signature TENSOR_SLICE =
763    sig
764        structure Tensor : TENSOR
765        structure Range : RANGE
766
767        type index = Tensor.Index.t
768        type range = Range.t
769        type 'a tensor = 'a Tensor.tensor
770        type 'a slice
771
772        val slice  : ((index * index) list) * 'a tensor -> 'a slice
773        val length : 'a slice -> int
774        val base   : 'a slice -> 'a tensor
775        val shape  : 'a slice -> (index)
776
777        val map : ('a -> 'b) -> 'a slice -> 'b Vector.vector
778        val map2 : ('a * 'b -> 'c) -> 'a slice -> 'b slice -> 'c Vector.vector
779        val foldl  : ('a * 'b -> 'b) -> 'b -> 'a slice -> 'b
780
781    end
782
783(*
784 Copyright (c) Juan Jose Garcia Ripoll.
785 All rights reserved.
786 
787 Refer to the COPYRIGHT file for license conditions
788*)
789
790structure Tensor : TENSOR =
791    struct
792        structure Array = Array
793        structure Index = Index
794           
795        type index = Index.t
796        type 'a tensor = {shape : index, indexer : Index.indexer, data : 'a array}
797
798        exception Shape
799        exception Match
800        exception Index
801
802    local
803    (*----- LOCALS -----*)
804
805        fun make' (shape, data) =
806            {shape = shape, indexer = Index.indexer shape, data = data}
807
808        fun toInt {shape, indexer, data} index = indexer index
809
810        fun array_map f a =
811            let fun apply index = f(Array.sub(a,index)) in
812                Array.tabulate(Array.length a, apply)
813            end
814
815        fun splitList (l as (a::rest), place) =
816            let fun loop (left,here,right) 0 =  (List.rev left,here,right)
817                  | loop (_,_,[]) place = raise Index
818                  | loop (left,here,a::right) place = 
819                loop (here::left,a,right) (place-1)
820            in
821                if place <= 0 then
822                    loop ([],a,rest) (List.length rest - place)
823                else
824                    loop ([],a,rest) (place - 1)
825            end
826
827    in
828    (*----- STRUCTURAL OPERATIONS & QUERIES ------*)
829
830      fun cat (dim, x: 'a tensor, y: 'a tensor) =
831        (let val xshape = (#shape x)
832             val yshape = (#shape y)
833             val xdata  = (#data x)
834             val ydata  = (#data y)
835        in
836           if  not (length xshape  = length yshape) then
837           raise Shape
838           else
839                let 
840                   val (_,newshape)   = ListPair.foldl
841                                      (fn (x,y,(i,ax)) => if (dim = i) then (i+1,(x+y) :: ax) 
842                                                                       else if not (x=y) then raise Shape else (i+1,x :: ax))
843                                       (0,[]) (xshape, yshape)
844                   val newlength  = Index.length newshape
845                   val newdata    = Array.array(newlength,Array.sub(xdata,0))
846                in
847                    Array.copy {src=xdata,dst=newdata,di=0};
848                    Array.copy {src=ydata,dst=newdata,di=(Index.length xshape)};
849                    {shape = newshape,
850                     indexer = Index.indexer newshape,
851                     data = newdata}
852                end
853        end)
854
855        fun new (shape, init) =
856            if not (Index.validShape shape) then
857                raise Shape
858            else
859                let val length = Index.length shape in
860                    {shape = shape,
861                     indexer = Index.indexer shape,
862                     data = Array.array(length,init)}
863                end
864
865        fun toArray {shape, indexer, data} = data
866
867        fun length {shape, indexer, data} =  Array.length data
868
869        fun shape {shape, indexer, data} = shape
870
871        fun rank t = List.length (shape t)
872
873        fun reshape new_shape tensor =
874            if Index.validShape new_shape then
875                case (Index.length new_shape) = length tensor of
876                    true => make'(new_shape, toArray tensor)
877                  | false => raise Match
878            else
879                raise Shape
880
881        fun fromArray (s, a) =
882            case Index.validShape s andalso 
883                 ((Index.length s) = (Array.length a)) of
884                 true => make'(s, a)
885               | false => raise Shape
886
887        fun fromList (s, a) = fromArray (s, Array.fromList a)
888
889        fun tabulate (shape,f) =
890            if Index.validShape shape then
891                let val last = Index.last shape
892                    val length = Index.length shape
893                    val c = Array.array(length, f last)
894                    fun dotable (c, indices, i) =
895                        (Array.update(c, i, f indices);
896                         case i of
897                             0 => c
898                           | i => dotable(c, Index.prev' shape indices, i-1))
899                in
900                    make'(shape,dotable(c, Index.prev' shape last, length-1))
901                end
902            else
903                raise Shape
904
905        (*----- ELEMENTWISE OPERATIONS -----*)
906
907        fun sub (t, index) = Array.sub(#data t, toInt t index)
908
909        fun update (t, index, value) =
910            Array.update(toArray t, toInt t index, value)
911
912        fun map f {shape, indexer, data} =
913            {shape = shape, indexer = indexer, data = array_map f data}
914
915        fun map2 f t1 t2=
916            let val {shape, indexer, data} = t1
917                val {shape=shape2, indexer=indexer2, data=data2} = t2
918                fun apply i = f (Array.sub(data,i), Array.sub(data2,i))
919                val len = Array.length data
920            in
921                if Index.eq(shape, shape2) then
922                    {shape = shape,
923                     indexer = indexer,
924                     data = Array.tabulate(len, apply)}
925                else
926                    raise Match
927        end
928
929        fun appi f tensor = Array.appi f (toArray tensor)
930
931        fun app f tensor = Array.app f (toArray tensor)
932
933        fun all f tensor =
934            let val a = toArray tensor
935            in Loop.all(0, length tensor - 1, fn i =>
936                        f (Array.sub(a, i)))
937            end
938
939        fun any f tensor =
940            let val a = toArray tensor
941            in Loop.any(0, length tensor - 1, fn i =>
942                        f (Array.sub(a, i)))
943            end
944
945        fun foldl f init {shape, indexer, data=a} index =
946            let val (head,lk,tail) = splitList(shape, index)
947                val li = Index.length head
948                val lj = Index.length tail
949                val c = Array.array(li * lj,init)
950                fun loopi (0, _,  _)  = ()
951                  | loopi (i, ia, ic) =
952                    (Array.update(c, ic, f(Array.sub(c,ic), Array.sub(a,ia)));
953                     loopi (i-1, ia+1, ic+1))
954                fun loopk (0, ia, _)  = ia
955                  | loopk (k, ia, ic) = (loopi (li, ia, ic);
956                                         loopk (k-1, ia+li, ic))
957                fun loopj (0, _,  _)  = ()
958                  | loopj (j, ia, ic) = loopj (j-1, loopk(lk,ia,ic), ic+li)
959            in
960                loopj (lj, 0, 0);
961                make'(head @ tail, c)
962            end
963
964    end
965    end (* Tensor *)
966
967
968structure TensorSlice : TENSOR_SLICE =
969    struct
970        structure Tensor = Tensor
971        structure Index  = Tensor.Index
972        structure Range  = Range
973           
974        type index = Tensor.Index.t
975        type range = Range.t
976        type 'a tensor = 'a Tensor.tensor
977
978        type 'a slice = {range : range, shape: index, tensor : 'a tensor}
979
980(*
981        val slice : ((index * index) list) * 'a tensor -> 'a slice
982        val length : 'a slice -> int
983        val base : 'a slice -> range * 'a tensor
984        val shape : 'a slice -> (index)
985*)
986        fun slice (rs,tensor) =
987            let val r = (Range.ranges (Tensor.shape tensor) rs)
988            in
989                {range=r,
990                 shape=(Range.shape r),
991                 tensor=tensor}
992            end
993        fun length ({range, shape, tensor}) = Range.length range
994        fun base ({range, shape, tensor}) = tensor
995        fun shape ({range, shape, tensor}) = Range.shape range
996
997
998    end                               
999
1000
1001(*
1002 Copyright (c) Juan Jose Garcia Ripoll.
1003 All rights reserved.
1004 
1005 Refer to the COPYRIGHT file for license conditions
1006*)
1007
1008(*
1009 MONO_TENSOR            - signature -
1010
1011 Monomorphic tensor of arbitrary data (not only numbers). Operations
1012 should be provided to run the data in several ways, according to one
1013 index.
1014
1015 type tensor
1016        The type of the tensor itself
1017 type elem
1018        The type of every element
1019 val storage = RowMajor | ColumnMajor
1020        RowMajor = data is stored in consecutive cells, first index
1021        varying fastest (FORTRAN convention)
1022        ColumnMajor = data is stored in consecutive cells, last
1023        index varying fastest (C,C++,Pascal,CommonLisp convention)
1024 new ([i1,...,in],init)
1025        Build a new tensor with n indices, each of sizes i1...in,
1026        filled with 'init'.
1027 fromArray (shape,data)
1028 fromList (shape,data)
1029        Use 'data' to fill a tensor of that shape. An exception is
1030        raised if 'data' is too large or too small to properly
1031        fill the vector. Later use of a 'data' array is disregarded
1032        -- one must think that the tensor now owns the array.
1033 length tensor
1034 rank tensor
1035 shape tensor
1036        Return the number of elements, the number of indices and
1037        the shape (size of each index) of the tensor.
1038 toArray tensor
1039        Return the data of the tensor in the form of an array.
1040        Mutation of this array may lead to unexpected behavior.
1041        The data in the array is stored according to `storage'.
1042
1043 sub (tensor,[i1,...,in])
1044 update (tensor,[i1,...,in],new_value)
1045        Access the element that is indexed by the numbers [i1,..,in]
1046
1047 map f a
1048 mapi f a
1049        Produce a new array by mapping the function sequentially
1050        as specified by 'storage', to each element of tensor 'a'.
1051        In 'mapi' the function receives a (indices,value) tuple,
1052        while in 'map' it only receives the value.
1053 app f a
1054 appi f a
1055        The same as 'map' and 'mapi' but the function 'f' outputs
1056        nothing and no new array is produced, i.e. one only seeks
1057        the side effect that 'f' may produce.
1058 map2 operation a b
1059        Apply function 'f' to pairs of elements of 'a' and 'b'
1060        and build a new tensor with the output. Both operands
1061        must have the same shape or an exception is raised.
1062        The procedure is sequential, as specified by 'storage'.
1063 foldl operation a n
1064        Fold-left the elements of tensor 'a' along the n-th
1065        index.
1066 all test a
1067 any test a
1068        Folded boolean tests on the elements of the tensor.
1069
1070 map', map2', foldl'
1071        Polymorphic versions of map, map2, foldl.
1072*)
1073
1074signature MONO_TENSOR =
1075    sig
1076        structure Array : MONO_ARRAY
1077        structure Index : INDEX
1078        type index = Index.t
1079        type elem
1080        type tensor
1081        type t = tensor
1082
1083        val new : index * elem -> tensor
1084        val tabulate : index * (index -> elem) -> tensor
1085        val length : tensor -> int
1086        val rank : tensor -> int
1087        val shape : tensor -> (index)
1088        val reshape : index -> tensor -> tensor
1089        val fromList : index * elem list -> tensor
1090        val fromArray : index * Array.array -> tensor
1091        val toArray : tensor -> Array.array
1092
1093        val sub : tensor * index -> elem
1094        val update : tensor * index * elem -> unit
1095        val map : (elem -> elem) -> tensor -> tensor
1096        val map2 : (elem * elem -> elem) -> tensor -> tensor -> tensor
1097        val app : (elem -> unit) -> tensor -> unit
1098        val appi : (int * elem -> unit) -> tensor -> unit
1099        val foldl : (elem * 'a -> 'a) -> 'a -> tensor -> tensor
1100        val foldln : (elem * elem -> elem) -> elem -> tensor -> int -> tensor
1101        val all : (elem -> bool) -> tensor -> bool
1102        val any : (elem -> bool) -> tensor -> bool
1103
1104        val map' : (elem -> 'a) -> tensor -> 'a Tensor.tensor
1105        val map2' : (elem * elem -> 'a) -> tensor -> tensor -> 'a Tensor.tensor
1106        val foldl' : ('a * elem -> 'a) -> 'a -> tensor -> int -> 'a Tensor.tensor
1107    end
1108
1109(*
1110 NUMBER         - Signature -
1111
1112 Guarantees a structure with a minimal number of mathematical operations
1113 so as to build an algebraic structure named Tensor.
1114 *)
1115
1116signature NUMBER =
1117    sig
1118        type t
1119        val zero : t
1120        val one  : t
1121        val ~ : t -> t
1122        val + : t * t -> t
1123        val - : t * t -> t
1124        val * : t * t -> t
1125        val / : t * t -> t
1126        val toString : t -> string
1127    end
1128
1129signature NUMBER =
1130    sig
1131        type t
1132        val zero : t
1133        val one : t
1134
1135        val + : t * t -> t
1136        val - : t * t -> t
1137        val * : t * t -> t
1138        val *+ : t * t * t -> t
1139        val *- : t * t * t -> t
1140        val ** : t * int -> t
1141
1142        val ~ : t -> t
1143        val abs : t -> t
1144        val signum : t -> t
1145
1146        val == : t * t -> bool
1147        val != : t * t -> bool
1148
1149        val toString : t -> string
1150        val fromInt : int -> t
1151        val scan : (char,'a) StringCvt.reader -> (t,'a) StringCvt.reader
1152    end
1153
1154signature INTEGRAL_NUMBER =
1155    sig
1156        include NUMBER
1157
1158        val quot : t * t -> t
1159        val rem  : t * t -> t
1160        val mod  : t * t -> t
1161        val div  : t * t -> t
1162
1163        val compare : t * t -> order
1164        val < : t * t -> bool
1165        val > : t * t -> bool
1166        val <= : t * t -> bool
1167        val >= : t * t -> bool
1168
1169        val max : t * t -> t
1170        val min : t * t -> t
1171    end
1172
1173signature FRACTIONAL_NUMBER =
1174    sig
1175        include NUMBER
1176
1177        val pi : t
1178        val e : t
1179
1180        val / : t * t -> t
1181        val recip : t -> t
1182
1183        val ln : t -> t
1184        val pow : t * t -> t
1185        val exp : t -> t
1186        val sqrt : t -> t
1187
1188        val cos : t -> t
1189        val sin : t -> t
1190        val tan : t -> t
1191        val sinh : t -> t
1192        val cosh : t -> t
1193        val tanh : t -> t
1194
1195        val acos : t -> t
1196        val asin : t -> t
1197        val atan : t -> t
1198        val asinh : t -> t
1199        val acosh : t -> t
1200        val atanh : t -> t
1201        val atan2 : t * t -> t
1202    end
1203
1204signature REAL_NUMBER =
1205    sig
1206        include FRACTIONAL_NUMBER
1207
1208        val compare : t * t -> order
1209        val < : t * t -> bool
1210        val > : t * t -> bool
1211        val <= : t * t -> bool
1212        val >= : t * t -> bool
1213
1214        val max : t * t -> t
1215        val min : t * t -> t
1216    end
1217
1218signature COMPLEX_NUMBER =
1219    sig
1220        include FRACTIONAL_NUMBER
1221
1222        structure Real : REAL_NUMBER
1223        type real = Real.t
1224
1225        val make : real * real -> t
1226        val split : t -> real * real
1227        val realPart : t -> real
1228        val imagPart : t -> real
1229        val abs2 : t -> real
1230    end
1231
1232structure INumber : INTEGRAL_NUMBER =
1233    struct
1234        open Int
1235        type t = Int.int
1236        val zero = 0
1237        val one = 1
1238
1239        infix **
1240        fun i ** n =
1241            let fun loop 0 = 1
1242                  | loop 1 = i
1243                  | loop n =
1244                let val x = loop (Int.div(n, 2))
1245                    val m = Int.mod(n, 2)
1246                in
1247                    if m = 0 then
1248                        x * x
1249                    else
1250                        x * x * i
1251                end
1252            in if n < 0
1253               then raise Domain
1254               else loop n
1255            end
1256
1257        fun signum i = case compare(i, 0) of
1258            GREATER => 1
1259          | EQUAL => 0
1260          | LESS => ~1
1261
1262        infix ==
1263        infix !=
1264        fun a == b = a = b
1265        fun a != b = (a <> b)
1266        fun *+(b,c,a) = b * c + a
1267        fun *-(b,c,a) = b * c - b
1268
1269        fun scan getc = Int.scan StringCvt.DEC getc
1270    end
1271
1272structure RNumber : REAL_NUMBER =
1273    struct
1274        open Real
1275        open Real.Math
1276        type t = Real.real
1277        val zero = 0.0
1278        val one = 1.0
1279
1280        fun signum x = case compare(x,0.0) of
1281            LESS => ~1.0
1282          | GREATER => 1.0
1283          | EQUAL => 0.0
1284
1285        fun recip x = 1.0 / x
1286
1287        infix **
1288        fun i ** n =
1289            let fun loop 0 = one
1290                  | loop 1 = i
1291                  | loop n =
1292                let val x = loop (Int.div(n, 2))
1293                    val m = Int.mod(n, 2)
1294                in
1295                    if m = 0 then
1296                        x * x
1297                    else
1298                        x * x * i
1299                end
1300            in if Int.<(n, 0)
1301               then raise Domain
1302               else loop n
1303            end
1304
1305        fun max (a, b) = if a < b then b else a
1306        fun min (a, b) = if a < b then a else b
1307
1308        fun asinh x = ln (x + sqrt(1.0 + x * x))
1309        fun acosh x = ln (x + (x + 1.0) * sqrt((x - 1.0)/(x + 1.0)))
1310        fun atanh x = ln ((1.0 + x) / sqrt(1.0 - x * x))
1311
1312    end
1313(*
1314 Complex(R)     - Functor -
1315
1316 Provides support for complex numbers based on tuples. Should be
1317 highly efficient as most operations can be inlined.
1318 *)
1319
1320structure CNumber : COMPLEX_NUMBER =
1321struct
1322        structure Real = RNumber
1323
1324        type t = Real.t * Real.t
1325        type real = Real.t
1326
1327        val zero = (0.0,0.0)
1328        val one = (1.0,0.0)
1329        val pi = (Real.pi, 0.0)
1330        val e = (Real.e, 0.0)
1331
1332        fun make (r,i) = (r,i) : t
1333        fun split z = z
1334        fun realPart (r,_) = r
1335        fun imagPart (_,i) = i
1336
1337        fun abs2 (r,i) = Real.+(Real.*(r,r),Real.*(i,i)) (* FIXME!!! *)
1338        fun arg (r,i) = Real.atan2(i,r)
1339        fun modulus z = Real.sqrt(abs2 z)
1340        fun abs z = (modulus z, 0.0)
1341        fun signum (z as (r,i)) =
1342            let val m = modulus z
1343            in (Real./(r,m), Real./(i,m))
1344            end
1345
1346        fun ~ (r1,i1) = (Real.~ r1, Real.~ i1)
1347        fun (r1,i1) + (r2,i2) = (Real.+(r1,r2), Real.+(i1,i2))
1348        fun (r1,i1) - (r2,i2) = (Real.-(r1,r2), Real.-(i1,i1))
1349        fun (r1,i1) * (r2,i2) = (Real.-(Real.*(r1,r2),Real.*(i1,i2)),
1350                                 Real.+(Real.*(r1,i2),Real.*(r2,i1)))
1351        fun (r1,i1) / (r2,i2) =
1352            let val modulus = abs2(r2,i2)
1353                val (nr,ni) = (r1,i1) * (r2,i2)
1354            in
1355                (Real./(nr,modulus), Real./(ni,modulus))
1356            end
1357        fun *+((r1,i1),(r2,i2),(r0,i0)) =
1358            (Real.*+(Real.~ i1, i2, Real.*+(r1,r2,r0)),
1359             Real.*+(r2, i2, Real.*+(r1,i2,i0)))
1360        fun *-((r1,i1),(r2,i2),(r0,i0)) =
1361            (Real.*+(Real.~ i1, i2, Real.*-(r1,r2,r0)),
1362             Real.*+(r2, i2, Real.*-(r1,i2,i0)))
1363
1364        infix **
1365        fun i ** n =
1366            let fun loop 0 = one
1367                  | loop 1 = i
1368                  | loop n =
1369                let val x = loop (Int.div(n, 2))
1370                    val m = Int.mod(n, 2)
1371                in
1372                    if m = 0 then
1373                        x * x
1374                    else
1375                        x * x * i
1376                end
1377            in if Int.<(n, 0)
1378                   then raise Domain
1379               else loop n
1380            end
1381
1382        fun recip (r1, i1) = 
1383            let val modulus = abs2(r1, i1)
1384            in (Real./(r1, modulus), Real./(Real.~ i1, modulus))
1385            end
1386        fun ==(z, w) = Real.==(realPart z, realPart w) andalso Real.==(imagPart z, imagPart w)
1387        fun !=(z, w) = Real.!=(realPart z, realPart w) andalso Real.!=(imagPart z, imagPart w)
1388        fun fromInt i = (Real.fromInt i, 0.0)
1389        fun toString (r,i) =
1390            String.concat ["(",Real.toString r,",",Real.toString i,")"]
1391
1392        fun exp (x, y) =
1393            let val expx = Real.exp x
1394            in (Real.*(x, (Real.cos y)), Real.*(x, (Real.sin y)))
1395            end
1396
1397    local
1398        val half = Real.recip (Real.fromInt 2)
1399    in
1400        fun sqrt (z as (x,y)) =
1401            if Real.==(x, 0.0) andalso Real.==(y, 0.0) then
1402                zero
1403            else
1404                let val m = Real.+(modulus z, Real.abs x)
1405                    val u' = Real.sqrt (Real.*(m, half))
1406                    val v' = Real./(Real.abs y , Real.+(u',u'))
1407                    val (u,v) = if Real.<(x, 0.0) then (v',u') else (u',v')
1408                in (u, if Real.<(y, 0.0) then Real.~ v else v)
1409                end
1410    end
1411        fun ln z = (Real.ln (modulus z), arg z)
1412
1413        fun pow (z, n) =
1414            let val l = ln z
1415            in exp (l * n)
1416            end
1417
1418        fun sin (x, y) = (Real.*(Real.sin x, Real.cosh y),
1419                          Real.*(Real.cos x, Real.sinh y))
1420        fun cos (x, y) = (Real.*(Real.cos x, Real.cosh y),
1421                          Real.~ (Real.*(Real.sin x, Real.sinh y)))
1422        fun tan (x, y) =
1423            let val (sx, cx) = (Real.sin x, Real.cos x)
1424                val (shy, chy) = (Real.sinh y, Real.cosh y)
1425                val a = (Real.*(sx, chy), Real.*(cx, shy))
1426                val b = (Real.*(cx, chy), Real.*(Real.~ sx, shy))
1427            in a / b
1428            end
1429
1430        fun sinh (x, y) = (Real.*(Real.cos y, Real.sinh x),
1431                           Real.*(Real.sin y, Real.cosh x))
1432        fun cosh (x, y) = (Real.*(Real.cos y, Real.cosh x),
1433                           Real.*(Real.sin y, Real.sinh x))
1434        fun tanh (x, y) =
1435            let val (sy, cy) = (Real.sin y, Real.cos y)
1436                val (shx, chx) = (Real.sinh x, Real.cosh x)
1437                val a = (Real.*(cy, shx), Real.*(sy, chx))
1438                val b = (Real.*(cy, chx), Real.*(sy, shx))
1439            in a / b
1440            end
1441
1442        fun asin (z as (x,y)) =
1443            let val w = sqrt (one - z * z)
1444                val (x',y') = ln ((Real.~ y, x) + w)
1445            in (y', Real.~ x')
1446            end
1447
1448        fun acos (z as (x,y)) = 
1449            let val (x', y') = sqrt (one + z * z)
1450                val (x'', y'') = ln (z + (Real.~ y', x'))
1451            in (y'', Real.~ x'')
1452            end
1453
1454        fun atan (z as (x,y)) =
1455            let val w = sqrt (one + z*z)
1456                val (x',y') = ln ((Real.-(1.0, y), x) / w)
1457            in (y', Real.~ x')
1458            end
1459
1460        fun atan2 (y, x) = atan(y / x)
1461
1462        fun asinh x = ln (x + sqrt(one + x * x))
1463        fun acosh x = ln (x + (x + one) * sqrt((x - one)/(x + one)))
1464        fun atanh x = ln ((one + x) / sqrt(one - x * x))
1465
1466        fun scan getc =
1467            let val scanner = Real.scan getc
1468            in fn stream => 
1469                  case scanner stream of
1470                      NONE => NONE
1471                    | SOME (a, rest) =>
1472                      case scanner rest of
1473                          NONE => NONE
1474                        | SOME (b, rest) => SOME (make(a,b), rest)
1475            end
1476
1477end (* ComplexNumber *)
1478
1479(*
1480 Copyright (c) Juan Jose Garcia Ripoll.
1481 All rights reserved.
1482 
1483 Refer to the COPYRIGHT file for license conditions
1484*)
1485
1486structure INumberArray =
1487    struct
1488        open Array
1489        type array = INumber.t array
1490        type vector = INumber.t vector
1491        type elem  = INumber.t
1492        structure Vector =
1493            struct
1494                open Vector
1495                type vector = INumber.t Vector.vector
1496                type elem = INumber.t
1497            end
1498        fun map f a = tabulate(length a, fn x => (f (sub(a,x))))
1499        fun mapi f a = tabulate(length a, fn x => (f (x,sub(a,x))))
1500        fun map2 f a b = tabulate(length a, fn x => (f(sub(a,x),sub(b,x))))
1501    end
1502
1503structure RNumberArray =
1504    struct
1505        open Real64Array
1506        val sub = Unsafe.Real64Array.sub
1507        val update = Unsafe.Real64Array.update
1508        fun map f a = tabulate(length a, fn x => (f (sub(a,x))))
1509        fun mapi f a = tabulate(length a, fn x => (f (x,sub(a,x))))
1510        fun map2 f a b = tabulate(length a, fn x => (f(sub(a,x),sub(b,x))))
1511    end
1512structure RNumber : REAL_NUMBER =
1513    struct
1514        open Real
1515        open Real.Math
1516        type t = Real.real
1517        val zero = 0.0
1518        val one = 1.0
1519        fun signum x = case compare(x,0.0) of
1520            LESS => ~1.0
1521          | GREATER => 1.0
1522          | EQUAL => 0.0
1523        fun recip x = 1.0 / x
1524        infix **
1525        fun i ** n =
1526            let fun loop 0 = one
1527                  | loop 1 = i
1528                  | loop n =
1529                let val x = loop (Int.div(n, 2))
1530                    val m = Int.mod(n, 2)
1531                in
1532                    if m = 0 then
1533                        x * x
1534                    else
1535                        x * x * i
1536                end
1537            in if Int.<(n, 0)
1538               then raise Domain
1539               else loop n
1540            end
1541        fun max (a, b) = if a < b then b else a
1542        fun min (a, b) = if a < b then a else b
1543        fun asinh x = ln (x + sqrt(1.0 + x * x))
1544        fun acosh x = ln (x + (x + 1.0) * sqrt((x - 1.0)/(x + 1.0)))
1545        fun atanh x = ln ((1.0 + x) / sqrt(1.0 - x * x))
1546    end
1547(*
1548 Complex(R)     - Functor -
1549 Provides support for complex numbers based on tuples. Should be
1550 highly efficient as most operations can be inlined.
1551 *)
1552structure CNumber : COMPLEX_NUMBER =
1553struct
1554        structure Real = RNumber
1555        type t = Real.t * Real.t
1556        type real = Real.t
1557        val zero = (0.0,0.0)
1558        val one = (1.0,0.0)
1559        val pi = (Real.pi, 0.0)
1560        val e = (Real.e, 0.0)
1561        fun make (r,i) = (r,i) : t
1562        fun split z = z
1563        fun realPart (r,_) = r
1564        fun imagPart (_,i) = i
1565        fun abs2 (r,i) = Real.+(Real.*(r,r),Real.*(i,i)) (* FIXME!!! *)
1566        fun arg (r,i) = Real.atan2(i,r)
1567        fun modulus z = Real.sqrt(abs2 z)
1568        fun abs z = (modulus z, 0.0)
1569        fun signum (z as (r,i)) =
1570            let val m = modulus z
1571            in (Real./(r,m), Real./(i,m))
1572            end
1573        fun ~ (r1,i1) = (Real.~ r1, Real.~ i1)
1574        fun (r1,i1) + (r2,i2) = (Real.+(r1,r2), Real.+(i1,i2))
1575        fun (r1,i1) - (r2,i2) = (Real.-(r1,r2), Real.-(i1,i1))
1576        fun (r1,i1) * (r2,i2) = (Real.-(Real.*(r1,r2),Real.*(i1,i2)),
1577                                 Real.+(Real.*(r1,i2),Real.*(r2,i1)))
1578        fun (r1,i1) / (r2,i2) =
1579            let val modulus = abs2(r2,i2)
1580                val (nr,ni) = (r1,i1) * (r2,i2)
1581            in
1582                (Real./(nr,modulus), Real./(ni,modulus))
1583            end
1584        fun *+((r1,i1),(r2,i2),(r0,i0)) =
1585            (Real.*+(Real.~ i1, i2, Real.*+(r1,r2,r0)),
1586             Real.*+(r2, i2, Real.*+(r1,i2,i0)))
1587        fun *-((r1,i1),(r2,i2),(r0,i0)) =
1588            (Real.*+(Real.~ i1, i2, Real.*-(r1,r2,r0)),
1589             Real.*+(r2, i2, Real.*-(r1,i2,i0)))
1590        infix **
1591        fun i ** n =
1592            let fun loop 0 = one
1593                  | loop 1 = i
1594                  | loop n =
1595                let val x = loop (Int.div(n, 2))
1596                    val m = Int.mod(n, 2)
1597                in
1598                    if m = 0 then
1599                        x * x
1600                    else
1601                        x * x * i
1602                end
1603            in if Int.<(n, 0)
1604                   then raise Domain
1605               else loop n
1606            end
1607        fun recip (r1, i1) = 
1608            let val modulus = abs2(r1, i1)
1609            in (Real./(r1, modulus), Real./(Real.~ i1, modulus))
1610            end
1611        fun ==(z, w) = Real.==(realPart z, realPart w) andalso Real.==(imagPart z, imagPart w)
1612        fun !=(z, w) = Real.!=(realPart z, realPart w) andalso Real.!=(imagPart z, imagPart w)
1613        fun fromInt i = (Real.fromInt i, 0.0)
1614        fun toString (r,i) =
1615            String.concat ["(",Real.toString r,",",Real.toString i,")"]
1616        fun exp (x, y) =
1617            let val expx = Real.exp x
1618            in (Real.*(x, (Real.cos y)), Real.*(x, (Real.sin y)))
1619            end
1620    local
1621        val half = Real.recip (Real.fromInt 2)
1622    in
1623        fun sqrt (z as (x,y)) =
1624            if Real.==(x, 0.0) andalso Real.==(y, 0.0) then
1625                zero
1626            else
1627                let val m = Real.+(modulus z, Real.abs x)
1628                    val u' = Real.sqrt (Real.*(m, half))
1629                    val v' = Real./(Real.abs y , Real.+(u',u'))
1630                    val (u,v) = if Real.<(x, 0.0) then (v',u') else (u',v')
1631                in (u, if Real.<(y, 0.0) then Real.~ v else v)
1632                end
1633    end
1634        fun ln z = (Real.ln (modulus z), arg z)
1635        fun pow (z, n) =
1636            let val l = ln z
1637            in exp (l * n)
1638            end
1639        fun sin (x, y) = (Real.*(Real.sin x, Real.cosh y),
1640                          Real.*(Real.cos x, Real.sinh y))
1641        fun cos (x, y) = (Real.*(Real.cos x, Real.cosh y),
1642                          Real.~ (Real.*(Real.sin x, Real.sinh y)))
1643        fun tan (x, y) =
1644            let val (sx, cx) = (Real.sin x, Real.cos x)
1645                val (shy, chy) = (Real.sinh y, Real.cosh y)
1646                val a = (Real.*(sx, chy), Real.*(cx, shy))
1647                val b = (Real.*(cx, chy), Real.*(Real.~ sx, shy))
1648            in a / b
1649            end
1650        fun sinh (x, y) = (Real.*(Real.cos y, Real.sinh x),
1651                           Real.*(Real.sin y, Real.cosh x))
1652        fun cosh (x, y) = (Real.*(Real.cos y, Real.cosh x),
1653                           Real.*(Real.sin y, Real.sinh x))
1654        fun tanh (x, y) =
1655            let val (sy, cy) = (Real.sin y, Real.cos y)
1656                val (shx, chx) = (Real.sinh x, Real.cosh x)
1657                val a = (Real.*(cy, shx), Real.*(sy, chx))
1658                val b = (Real.*(cy, chx), Real.*(sy, shx))
1659            in a / b
1660            end
1661        fun asin (z as (x,y)) =
1662            let val w = sqrt (one - z * z)
1663                val (x',y') = ln ((Real.~ y, x) + w)
1664            in (y', Real.~ x')
1665            end
1666        fun acos (z as (x,y)) = 
1667            let val (x', y') = sqrt (one + z * z)
1668                val (x'', y'') = ln (z + (Real.~ y', x'))
1669            in (y'', Real.~ x'')
1670            end
1671        fun atan (z as (x,y)) =
1672            let val w = sqrt (one + z*z)
1673                val (x',y') = ln ((Real.-(1.0, y), x) / w)
1674            in (y', Real.~ x')
1675            end
1676        fun atan2 (y, x) = atan(y / x)
1677        fun asinh x = ln (x + sqrt(one + x * x))
1678        fun acosh x = ln (x + (x + one) * sqrt((x - one)/(x + one)))
1679        fun atanh x = ln ((one + x) / sqrt(one - x * x))
1680        fun scan getc =
1681            let val scanner = Real.scan getc
1682            in fn stream => 
1683                  case scanner stream of
1684                      NONE => NONE
1685                    | SOME (a, rest) =>
1686                      case scanner rest of
1687                          NONE => NONE
1688                        | SOME (b, rest) => SOME (make(a,b), rest)
1689            end
1690end (* ComplexNumber *)
1691
1692(*
1693 Copyright (c) Juan Jose Garcia Ripoll.
1694 All rights reserved.
1695 Refer to the COPYRIGHT file for license conditions
1696*)
1697structure INumberArray =
1698    struct
1699        open Array
1700        type array = INumber.t array
1701        type vector = INumber.t vector
1702        type elem  = INumber.t
1703        structure Vector =
1704            struct
1705                open Vector
1706                type vector = INumber.t Vector.vector
1707                type elem = INumber.t
1708            end
1709        fun map f a = tabulate(length a, fn x => (f (sub(a,x))))
1710        fun mapi f a = tabulate(length a, fn x => (f (x,sub(a,x))))
1711        fun map2 f a b = tabulate(length a, fn x => (f(sub(a,x),sub(b,x))))
1712    end
1713structure RNumberArray =
1714    struct
1715        open Real64Array
1716        val sub = Unsafe.Real64Array.sub
1717        val update = Unsafe.Real64Array.update
1718        fun map f a = tabulate(length a, fn x => (f (sub(a,x))))
1719        fun mapi f a = tabulate(length a, fn x => (f (x,sub(a,x))))
1720        fun map2 f a b = tabulate(length a, fn x => (f(sub(a,x),sub(b,x))))
1721    end
1722(*--------------------- COMPLEX ARRAY -------------------------*)
1723structure BasicCNumberArray =
1724struct
1725        structure Complex : COMPLEX_NUMBER = CNumber
1726        structure Array : MONO_ARRAY = RNumberArray
1727        type elem = Complex.t
1728        type array = Array.array * Array.array
1729        val maxLen = Array.maxLen
1730        fun length (a,b) = Array.length a
1731        fun sub ((a,b),index) = Complex.make(Array.sub(a,index),Array.sub(b,index))
1732        fun update ((a,b),index,z) =
1733            let val (re,im) = Complex.split z in
1734                Array.update(a, index, re);
1735                Array.update(b, index, im)
1736            end
1737    local
1738        fun makeRange (a, start, NONE) = makeRange(a, start, SOME (length a - 1))
1739          | makeRange (a, start, SOME last) =
1740            let val len = length a
1741                val diff = last - start
1742            in
1743                if (start >= len) orelse (last >= len) then
1744                    raise Subscript
1745                else if diff < 0 then
1746                    (a, start, 0)
1747                else
1748                    (a, start, diff + 1)
1749            end
1750    in
1751        fun array (size,z:elem) =
1752            let val realsize = size * 2
1753                val r = Complex.realPart z
1754                val i = Complex.imagPart z in
1755                    (Array.array(size,r), Array.array(size,i))
1756            end
1757        fun zeroarray size =
1758            (Array.array(size,Complex.Real.zero),
1759             Array.array(size,Complex.Real.zero))
1760        fun tabulate (size,f) =
1761            let val a = array(size, Complex.zero)
1762                fun loop i =
1763                    case i = size of
1764                        true => a
1765                      | false => (update(a, i, f i); loop (i+1))
1766            in
1767                loop 0
1768            end
1769        fun fromList list =
1770            let val length = List.length list
1771                val a = zeroarray length
1772                fun loop (_, []) = a
1773                  | loop (i, z::rest) = (update(a, i, z);
1774                                         loop (i+1, rest))
1775            in
1776                loop(0,list)
1777            end
1778        fun extract range =
1779            let val (a, start, len) = makeRange range
1780                fun copy i = sub(a, i + start)
1781            in tabulate(len, copy)
1782            end
1783        fun concat array_list =
1784            let val total_length = foldl (op +) 0 (map length array_list)
1785                val a = array(total_length, Complex.zero)
1786                fun copy (_, []) = a
1787                  | copy (pos, v::rest) =
1788                    let fun loop i =
1789                        case i = 0 of
1790                            true => ()
1791                          | false => (update(a, i+pos, sub(v, i)); loop (i-1))
1792                    in (loop (length v - 1); copy(length v + pos, rest))
1793                    end
1794            in
1795                copy(0, array_list)
1796            end
1797        fun copy {src : array, si : int, len : int option, dst : array, di : int } =
1798            let val (a, ia, la) = makeRange (src, si, len)
1799                val (b, ib, lb) = makeRange (dst, di, len)
1800                fun copy i =
1801                    case i < 0 of
1802                        true => ()
1803                      | false => (update(b, i+ib, sub(a, i+ia)); copy (i-1))
1804            in copy (la - 1)
1805            end
1806        val copyVec = copy
1807        fun modifyi f range =
1808            let val (a, start, len) = makeRange range
1809                val last = start + len
1810                fun loop i =
1811                    case i >= last of
1812                        true => ()
1813                      | false => (update(a, i, f(i, sub(a,i))); loop (i+1))
1814            in loop start
1815            end
1816        fun modify f a =
1817            let val last = length a
1818                fun loop i =
1819                    case i >= last of
1820                        true => ()
1821                      | false => (update(a, i, f(sub(a,i))); loop (i+1))
1822            in loop 0
1823            end
1824        fun app f a =
1825            let val size = length a
1826                fun loop i =
1827                    case i = size of
1828                        true => ()
1829                      | false => (f(sub(a,i)); loop (i+1))
1830            in
1831                loop 0
1832            end
1833        fun appi f range =
1834            let val (a, start, len) = makeRange range
1835                val last = start + len
1836                fun loop i =
1837                    case i >= last of
1838                        true => ()
1839                      | false => (f(i, sub(a,i)); loop (i+1))
1840            in
1841                loop start
1842            end
1843        fun map f a =
1844            let val len = length a
1845                val c = zeroarray len
1846                fun loop ~1 = c
1847                  | loop i = (update(a, i, f(sub(a,i))); loop (i-1))
1848            in loop (len-1)
1849            end
1850        fun map2 f a b =
1851            let val len = length a
1852                val c = zeroarray len
1853                fun loop ~1 = c
1854                  | loop i = (update(c, i, f(sub(a,i),sub(b,i)));
1855                              loop (i-1))
1856            in loop (len-1)
1857            end
1858        fun mapi f range =
1859            let val (a, start, len) = makeRange range
1860                fun rule i = f (i+start, sub(a, i+start))
1861            in tabulate(len, rule)
1862            end
1863        fun foldli f init range =
1864            let val (a, start, len) = makeRange range
1865                val last = start + len - 1
1866                fun loop (i, accum) =
1867                    case i > last of
1868                        true => accum
1869                      | false => loop (i+1, f(i, sub(a,i), accum))
1870            in loop (start, init)
1871            end
1872        fun foldri f init range =
1873            let val (a, start, len) = makeRange range
1874                val last = start + len - 1
1875                fun loop (i, accum) =
1876                    case i < start of
1877                        true => accum
1878                      | false => loop (i-1, f(i, sub(a,i), accum))
1879            in loop (last, init)
1880            end
1881        fun foldl f init a = foldli (fn (_, a, x) => f(a,x)) init (a,0,NONE)
1882        fun foldr f init a = foldri (fn (_, x, a) => f(x,a)) init (a,0,NONE)
1883    end
1884end (* BasicCNumberArray *)
1885structure CNumberArray =
1886    struct
1887        structure Vector =
1888            struct
1889                open BasicCNumberArray
1890                type vector = array
1891            end : MONO_VECTOR
1892        type vector = Vector.vector
1893        open BasicCNumberArray
1894    end (* CNumberArray *)
1895structure ITensor =
1896    struct
1897        structure Number = INumber
1898        structure Array = INumberArray
1899(*
1900 Copyright (c) Juan Jose Garcia Ripoll.
1901 All rights reserved.
1902 Refer to the COPYRIGHT file for license conditions
1903*)
1904structure MonoTensor  =
1905    struct
1906(* PARAMETERS
1907        structure Array = Array
1908*)
1909        structure Index  = Index
1910        type elem = Array.elem
1911        type index = Index.t
1912        type tensor = {shape : index, indexer : Index.indexer, data : Array.array}
1913        type t = tensor
1914        exception Shape
1915        exception Match
1916        exception Index
1917    local
1918    (*----- LOCALS -----*)
1919        fun make' (shape, data) =
1920            {shape = shape, indexer = Index.indexer shape, data = data}
1921        fun toInt {shape, indexer, data} index = indexer index
1922        fun splitList (l as (a::rest), place) =
1923            let fun loop (left,here,right) 0 =  (List.rev left,here,right)
1924                  | loop (_,_,[]) place = raise Index
1925                  | loop (left,here,a::right) place = 
1926                loop (here::left,a,right) (place-1)
1927            in
1928                if place <= 0 then
1929                    loop ([],a,rest) (List.length rest - place)
1930                else
1931                    loop ([],a,rest) (place - 1)
1932            end
1933    in
1934    (*----- STRUCTURAL OPERATIONS & QUERIES ------*)
1935      fun cat (dim, x: tensor, y: tensor) =
1936        (let val xshape = (#shape x)
1937             val yshape = (#shape y)
1938             val xdata  = (#data x)
1939             val ydata  = (#data y)
1940        in
1941           if  not (length xshape  = length yshape) then
1942           raise Shape
1943           else
1944                let 
1945                   val (_,newshape)   = ListPair.foldl
1946                                      (fn (x,y,(i,ax)) => if (dim = i) then (i+1,(x+y) :: ax) 
1947                                                                       else if not (x=y) then raise Shape else (i+1,x :: ax))
1948                                       (0,[]) (xshape, yshape)
1949                   val newlength  = Index.length newshape
1950                   val newdata    = Array.array(newlength,Array.sub(xdata,0))
1951                in
1952                    Array.copy {src=xdata,dst=newdata,di=0};
1953                    Array.copy {src=ydata,dst=newdata,di=(Index.length xshape)};
1954                    {shape = newshape,
1955                     indexer = Index.indexer newshape,
1956                     data = newdata}
1957                end
1958        end)
1959
1960        fun new (shape, init) =
1961            if not (Index.validShape shape) then
1962                raise Shape
1963            else
1964                let val length = Index.length shape in
1965                    {shape = shape,
1966                     indexer = Index.indexer shape,
1967                     data = Array.array(length,init)}
1968                end
1969        fun toArray {shape, indexer, data} = data
1970        fun length {shape, indexer, data} =  Array.length data
1971        fun shape {shape, indexer, data} = shape
1972        fun rank t = List.length (shape t)
1973        fun reshape new_shape tensor =
1974            if Index.validShape new_shape then
1975                case (Index.length new_shape) = length tensor of
1976                    true => make'(new_shape, toArray tensor)
1977                  | false => raise Match
1978            else
1979                raise Shape
1980        fun fromArray (s, a) =
1981            case Index.validShape s andalso 
1982                 ((Index.length s) = (Array.length a)) of
1983                 true => make'(s, a)
1984               | false => raise Shape
1985        fun fromList (s, a) = fromArray (s, Array.fromList a)
1986        fun tabulate (shape,f) =
1987            if Index.validShape shape then
1988                let val last = Index.last shape
1989                    val length = Index.length shape
1990                    val c = Array.array(length, f last)
1991                    fun dotable (c, indices, i) =
1992                        (Array.update(c, i, f indices);
1993                         if i <= 1
1994                         then c
1995                         else dotable(c, Index.prev' shape indices, i-1))
1996                in make'(shape,dotable(c, Index.prev' shape last, length-2))
1997                end
1998            else
1999                raise Shape
2000        (*----- ELEMENTWISE OPERATIONS -----*)
2001        fun sub (t, index) = Array.sub(#data t, toInt t index)
2002        fun update (t, index, value) =
2003            Array.update(toArray t, toInt t index, value)
2004        fun map f {shape, indexer, data} =
2005            {shape = shape, indexer = indexer, data = Array.map f data}
2006        fun map2 f t1 t2=
2007            let val {shape=shape1, indexer=indexer1, data=data1} = t1
2008                val {shape=shape2, indexer=indexer2, data=data2} = t2
2009            in
2010                if Index.eq(shape1,shape2) then
2011                    {shape = shape1,
2012                     indexer = indexer1,
2013                     data = Array.map2 f data1 data2}
2014                else
2015                    raise Match
2016        end
2017        fun appi f tensor = Array.appi f (toArray tensor)
2018        fun app f tensor = Array.app f (toArray tensor)
2019        fun all f tensor =
2020            let val a = toArray tensor
2021            in Loop.all(0, length tensor - 1, fn i =>
2022                        f (Array.sub(a, i)))
2023            end
2024        fun any f tensor =
2025            let val a = toArray tensor
2026            in Loop.any(0, length tensor - 1, fn i =>
2027                        f (Array.sub(a, i)))
2028            end
2029        fun foldl f init tensor = Array.foldl f init (toArray tensor)
2030        fun foldln f init {shape, indexer, data=a} index =
2031            let val (head,lk,tail) = splitList(shape, index)
2032                val li = Index.length head
2033                val lj = Index.length tail
2034                val c = Array.array(li * lj,init)
2035                fun loopi (0, _,  _)  = ()
2036                  | loopi (i, ia, ic) =
2037                    (Array.update(c, ic, f(Array.sub(c,ic), Array.sub(a,ia)));
2038                     loopi (i-1, ia+1, ic+1))
2039                fun loopk (0, ia, _)  = ia
2040                  | loopk (k, ia, ic) = (loopi (li, ia, ic);
2041                                         loopk (k-1, ia+li, ic))
2042                fun loopj (0, _,  _)  = ()
2043                  | loopj (j, ia, ic) = loopj (j-1, loopk(lk,ia,ic), ic+li)
2044            in
2045                loopj (lj, 0, 0);
2046                make'(head @ tail, c)
2047            end
2048        (* --- POLYMORPHIC ELEMENTWISE OPERATIONS --- *)
2049        fun array_map' f a =
2050            let fun apply index = f(Array.sub(a,index)) in
2051                Tensor.Array.tabulate(Array.length a, apply)
2052            end
2053        fun map' f t = Tensor.fromArray(shape t, array_map' f (toArray t))
2054        fun map2' f t1 t2 =
2055            let val d1 = toArray t1
2056                val d2 = toArray t2
2057                fun apply i = f (Array.sub(d1,i), Array.sub(d2,i))
2058                val len = Array.length d1
2059            in
2060                if Index.eq(shape t1, shape t2) then
2061                    Tensor.fromArray(shape t1, Tensor.Array.tabulate(len,apply))
2062                else
2063                    raise Match
2064            end
2065        fun foldl' f init {shape, indexer, data=a} index =
2066            let val (head,lk,tail) = splitList(shape, index)
2067                val li = Index.length head
2068                val lj = Index.length tail
2069                val c = Tensor.Array.array(li * lj,init)
2070                fun loopi (0, _,  _)  = ()
2071                  | loopi (i, ia, ic) =
2072                    (Tensor.Array.update(c,ic,f(Tensor.Array.sub(c,ic),Array.sub(a,ia)));
2073                     loopi (i-1, ia+1, ic+1))
2074                fun loopk (0, ia, _)  = ia
2075                  | loopk (k, ia, ic) = (loopi (li, ia, ic);
2076                                         loopk (k-1, ia+li, ic))
2077                fun loopj (0, _,  _)  = ()
2078                  | loopj (j, ia, ic) = loopj (j-1, loopk(lk,ia,ic), ic+li)
2079            in
2080                loopj (lj, 0, 0);
2081                make'(head @ tail, c)
2082            end
2083    end
2084    end (* MonoTensor *)
2085        open MonoTensor
2086    local
2087        (*
2088         LEFT INDEX CONTRACTION:
2089         a = a(i1,i2,...,in)
2090         b = b(j1,j2,...,jn)
2091         c = c(i2,...,in,j2,...,jn)
2092         = sum(a(k,i2,...,jn)*b(k,j2,...jn)) forall k
2093         MEANINGFUL VARIABLES:
2094         lk = i1 = j1
2095         li = i2*...*in
2096         lj = j2*...*jn
2097         *)
2098        fun do_fold_first a b c lk lj li =
2099            let fun loopk (0, _,  _,  accum) = accum
2100                  | loopk (k, ia, ib, accum) =
2101                    let val delta = Number.*(Array.sub(a,ia),Array.sub(b,ib))
2102                    in loopk (k-1, ia+1, ib+1, Number.+(delta,accum))
2103                    end
2104                fun loopj (0, ib, ic) = c
2105                  | loopj (j, ib, ic) =
2106                    let fun loopi (0, ia, ic) = ic
2107                          | loopi (i, ia, ic) =
2108                        (Array.update(c, ic, loopk(lk, ia, ib, Number.zero));
2109                         loopi(i-1, ia+lk, ic+1))
2110                    in
2111                        loopj(j-1, ib+lk, loopi(li, 0, ic))
2112                    end
2113            in loopj(lj, 0, 0)
2114            end
2115    in
2116        fun +* ta tb =
2117            let val (rank_a,lk::rest_a,a) = (rank ta, shape ta, toArray ta)
2118                val (rank_b,lk2::rest_b,b) = (rank tb, shape tb, toArray tb)
2119            in if not(lk = lk2)
2120               then raise Match
2121               else let val li = Index.length rest_a
2122                        val lj = Index.length rest_b
2123                        val c = Array.array(li*lj,Number.zero)
2124                    in fromArray(rest_a @ rest_b,
2125                                 do_fold_first a b c lk li lj)
2126                    end
2127            end
2128    end
2129    local
2130        (*
2131         LAST INDEX CONTRACTION:
2132         a = a(i1,i2,...,in)
2133         b = b(j1,j2,...,jn)
2134         c = c(i2,...,in,j2,...,jn)
2135         = sum(mult(a(i1,i2,...,k),b(j1,j2,...,k))) forall k
2136         MEANINGFUL VARIABLES:
2137         lk = in = jn
2138         li = i1*...*i(n-1)
2139         lj = j1*...*j(n-1)
2140         *)
2141        fun do_fold_last a b c lk lj li =
2142            let fun loopi (0, ia, ic, fac) = ()
2143                  | loopi (i, ia, ic, fac) =
2144                    let val old = Array.sub(c,ic)
2145                        val inc = Number.*(Array.sub(a,ia),fac)
2146                    in
2147                        Array.update(c,ic,Number.+(old,inc));
2148                        loopi(i-1, ia+1, ic+1, fac)
2149                    end
2150                fun loopj (j, ib, ic) =
2151                    let fun loopk (0, ia, ib) = ()
2152                          | loopk (k, ia, ib) =
2153                            (loopi(li, ia, ic, Array.sub(b,ib));
2154                             loopk(k-1, ia+li, ib+lj))
2155                    in case j of
2156                           0 => c
2157                         | _ => (loopk(lk, 0, ib);
2158                                 loopj(j-1, ib+1, ic+li))
2159                    end (* loopj *)
2160            in
2161                loopj(lj, 0, 0)
2162            end
2163    in
2164        fun *+ ta tb  =
2165            let val (rank_a,shape_a,a) = (rank ta, shape ta, toArray ta)
2166                val (rank_b,shape_b,b) = (rank tb, shape tb, toArray tb)
2167                val (lk::rest_a) = List.rev shape_a
2168                val (lk2::rest_b) = List.rev shape_b
2169            in if not(lk = lk2)
2170               then raise Match
2171               else let val li = Index.length rest_a
2172                        val lj = Index.length rest_b
2173                        val c = Array.array(li*lj,Number.zero)
2174                    in fromArray(List.rev rest_a @ List.rev rest_b,
2175                                 do_fold_last a b c lk li lj)
2176                    end
2177            end
2178    end
2179        (* ALGEBRAIC OPERATIONS *)
2180        infix **
2181        infix ==
2182        infix !=
2183        fun a + b = map2 Number.+ a b
2184        fun a - b = map2 Number.- a b
2185        fun a * b = map2 Number.* a b
2186        fun a ** i = map (fn x => (Number.**(x,i))) a
2187        fun ~ a = map Number.~ a
2188        fun abs a = map Number.abs a
2189        fun signum a = map Number.signum a
2190        fun a == b = map2' Number.== a b
2191        fun a != b = map2' Number.!= a b
2192        fun toString a = raise Domain
2193        fun fromInt a = new([1], Number.fromInt a)
2194        (* TENSOR SPECIFIC OPERATIONS *)
2195        fun *> n = map (fn x => Number.*(n,x))
2196        fun normInf a =
2197            let fun accum (y,x) = Number.max(x,Number.abs y)
2198            in  foldl accum Number.zero a
2199            end
2200    end (* NumberTensor *)
2201structure RTensor =
2202    struct
2203        structure Number = RNumber
2204        structure Array = RNumberArray
2205(*
2206 Copyright (c) Juan Jose Garcia Ripoll.
2207 All rights reserved.
2208 Refer to the COPYRIGHT file for license conditions
2209*)
2210structure MonoTensor  =
2211    struct
2212(* PARAMETERS
2213        structure Array = Array
2214*)
2215        structure Index  = Index
2216        type elem = Array.elem
2217        type index = Index.t
2218        type tensor = {shape : index, indexer : Index.indexer, data : Array.array}
2219        type t = tensor
2220        exception Shape
2221        exception Match
2222        exception Index
2223    local
2224    (*----- LOCALS -----*)
2225        fun make' (shape, data) =
2226            {shape = shape, indexer = Index.indexer shape, data = data}
2227        fun toInt {shape, indexer, data} index = indexer index
2228        fun splitList (l as (a::rest), place) =
2229            let fun loop (left,here,right) 0 =  (List.rev left,here,right)
2230                  | loop (_,_,[]) place = raise Index
2231                  | loop (left,here,a::right) place = 
2232                loop (here::left,a,right) (place-1)
2233            in
2234                if place <= 0 then
2235                    loop ([],a,rest) (List.length rest - place)
2236                else
2237                    loop ([],a,rest) (place - 1)
2238            end
2239    in
2240    (*----- STRUCTURAL OPERATIONS & QUERIES ------*)
2241      fun cat (dim, x: tensor, y: tensor) =
2242        (let val xshape = (#shape x)
2243             val yshape = (#shape y)
2244             val xdata  = (#data x)
2245             val ydata  = (#data y)
2246        in
2247           if  not (length xshape  = length yshape) then
2248           raise Shape
2249           else
2250                let 
2251                   val (_,newshape)   = ListPair.foldl
2252                                      (fn (x,y,(i,ax)) => if (dim = i) then (i+1,(x+y) :: ax) 
2253                                                                       else if not (x=y) then raise Shape else (i+1,x :: ax))
2254                                       (0,[]) (xshape, yshape)
2255                   val newlength  = Index.length newshape
2256                   val newdata    = Array.array(newlength,Array.sub(xdata,0))
2257                in
2258                    Array.copy {src=xdata,dst=newdata,di=0};
2259                    Array.copy {src=ydata,dst=newdata,di=(Index.length xshape)};
2260                    {shape = newshape,
2261                     indexer = Index.indexer newshape,
2262                     data = newdata}
2263                end
2264        end)
2265
2266        fun new (shape, init) =
2267            if not (Index.validShape shape) then
2268                raise Shape
2269            else
2270                let val length = Index.length shape in
2271                    {shape = shape,
2272                     indexer = Index.indexer shape,
2273                     data = Array.array(length,init)}
2274                end
2275        fun toArray {shape, indexer, data} = data
2276        fun length {shape, indexer, data} =  Array.length data
2277        fun shape {shape, indexer, data} = shape
2278        fun rank t = List.length (shape t)
2279        fun reshape new_shape tensor =
2280            if Index.validShape new_shape then
2281                case (Index.length new_shape) = length tensor of
2282                    true => make'(new_shape, toArray tensor)
2283                  | false => raise Match
2284            else
2285                raise Shape
2286        fun fromArray (s, a) =
2287            case Index.validShape s andalso 
2288                 ((Index.length s) = (Array.length a)) of
2289                 true => make'(s, a)
2290               | false => raise Shape
2291        fun fromList (s, a) = fromArray (s, Array.fromList a)
2292        fun tabulate (shape,f) =
2293            if Index.validShape shape then
2294                let val last = Index.last shape
2295                    val length = Index.length shape
2296                    val c = Array.array(length, f last)
2297                    fun dotable (c, indices, i) =
2298                        (Array.update(c, i, f indices);
2299                         if i <= 1
2300                         then c
2301                         else dotable(c, Index.prev' shape indices, i-1))
2302                in make'(shape,dotable(c, Index.prev' shape last, length-2))
2303                end
2304            else
2305                raise Shape
2306        (*----- ELEMENTWISE OPERATIONS -----*)
2307        fun sub (t, index) = Array.sub(#data t, toInt t index)
2308        fun update (t, index, value) =
2309            Array.update(toArray t, toInt t index, value)
2310        fun map f {shape, indexer, data} =
2311            {shape = shape, indexer = indexer, data = Array.map f data}
2312        fun map2 f t1 t2=
2313            let val {shape=shape1, indexer=indexer1, data=data1} = t1
2314                val {shape=shape2, indexer=indexer2, data=data2} = t2
2315            in
2316                if Index.eq(shape1,shape2) then
2317                    {shape = shape1,
2318                     indexer = indexer1,
2319                     data = Array.map2 f data1 data2}
2320                else
2321                    raise Match
2322        end
2323        fun appi f tensor = Array.appi f (toArray tensor)
2324        fun app f tensor = Array.app f (toArray tensor)
2325        fun all f tensor =
2326            let val a = toArray tensor
2327            in Loop.all(0, length tensor - 1, fn i =>
2328                        f (Array.sub(a, i)))
2329            end
2330        fun any f tensor =
2331            let val a = toArray tensor
2332            in Loop.any(0, length tensor - 1, fn i =>
2333                        f (Array.sub(a, i)))
2334            end
2335        fun foldl f init tensor = Array.foldl f init (toArray tensor)
2336        fun foldln f init {shape, indexer, data=a} index =
2337            let val (head,lk,tail) = splitList(shape, index)
2338                val li = Index.length head
2339                val lj = Index.length tail
2340                val c = Array.array(li * lj,init)
2341                fun loopi (0, _,  _)  = ()
2342                  | loopi (i, ia, ic) =
2343                    (Array.update(c, ic, f(Array.sub(c,ic), Array.sub(a,ia)));
2344                     loopi (i-1, ia+1, ic+1))
2345                fun loopk (0, ia, _)  = ia
2346                  | loopk (k, ia, ic) = (loopi (li, ia, ic);
2347                                         loopk (k-1, ia+li, ic))
2348                fun loopj (0, _,  _)  = ()
2349                  | loopj (j, ia, ic) = loopj (j-1, loopk(lk,ia,ic), ic+li)
2350            in
2351                loopj (lj, 0, 0);
2352                make'(head @ tail, c)
2353            end
2354        (* --- POLYMORPHIC ELEMENTWISE OPERATIONS --- *)
2355        fun array_map' f a =
2356            let fun apply index = f(Array.sub(a,index)) in
2357                Tensor.Array.tabulate(Array.length a, apply)
2358            end
2359        fun map' f t = Tensor.fromArray(shape t, array_map' f (toArray t))
2360        fun map2' f t1 t2 =
2361            let val d1 = toArray t1
2362                val d2 = toArray t2
2363                fun apply i = f (Array.sub(d1,i), Array.sub(d2,i))
2364                val len = Array.length d1
2365            in
2366                if Index.eq(shape t1, shape t2) then
2367                    Tensor.fromArray(shape t1, Tensor.Array.tabulate(len,apply))
2368                else
2369                    raise Match
2370            end
2371        fun foldl' f init {shape, indexer, data=a} index =
2372            let val (head,lk,tail) = splitList(shape, index)
2373                val li = Index.length head
2374                val lj = Index.length tail
2375                val c = Tensor.Array.array(li * lj,init)
2376                fun loopi (0, _,  _)  = ()
2377                  | loopi (i, ia, ic) =
2378                    (Tensor.Array.update(c,ic,f(Tensor.Array.sub(c,ic),Array.sub(a,ia)));
2379                     loopi (i-1, ia+1, ic+1))
2380                fun loopk (0, ia, _)  = ia
2381                  | loopk (k, ia, ic) = (loopi (li, ia, ic);
2382                                         loopk (k-1, ia+li, ic))
2383                fun loopj (0, _,  _)  = ()
2384                  | loopj (j, ia, ic) = loopj (j-1, loopk(lk,ia,ic), ic+li)
2385            in
2386                loopj (lj, 0, 0);
2387                make'(head @ tail, c)
2388            end
2389    end
2390    end (* MonoTensor *)
2391        open MonoTensor
2392    local
2393        (*
2394         LEFT INDEX CONTRACTION:
2395         a = a(i1,i2,...,in)
2396         b = b(j1,j2,...,jn)
2397         c = c(i2,...,in,j2,...,jn)
2398         = sum(a(k,i2,...,jn)*b(k,j2,...jn)) forall k
2399         MEANINGFUL VARIABLES:
2400         lk = i1 = j1
2401         li = i2*...*in
2402         lj = j2*...*jn
2403         *)
2404        fun do_fold_first a b c lk lj li =
2405            let fun loopk (0, _,  _,  accum) = accum
2406                  | loopk (k, ia, ib, accum) =
2407                    let val delta = Number.*(Array.sub(a,ia),Array.sub(b,ib))
2408                    in loopk (k-1, ia+1, ib+1, Number.+(delta,accum))
2409                    end
2410                fun loopj (0, ib, ic) = c
2411                  | loopj (j, ib, ic) =
2412                    let fun loopi (0, ia, ic) = ic
2413                          | loopi (i, ia, ic) =
2414                        (Array.update(c, ic, loopk(lk, ia, ib, Number.zero));
2415                         loopi(i-1, ia+lk, ic+1))
2416                    in
2417                        loopj(j-1, ib+lk, loopi(li, 0, ic))
2418                    end
2419            in loopj(lj, 0, 0)
2420            end
2421    in
2422        fun +* ta tb =
2423            let val (rank_a,lk::rest_a,a) = (rank ta, shape ta, toArray ta)
2424                val (rank_b,lk2::rest_b,b) = (rank tb, shape tb, toArray tb)
2425            in if not(lk = lk2)
2426               then raise Match
2427               else let val li = Index.length rest_a
2428                        val lj = Index.length rest_b
2429                        val c = Array.array(li*lj,Number.zero)
2430                    in fromArray(rest_a @ rest_b,
2431                                 do_fold_first a b c lk li lj)
2432                    end
2433            end
2434    end
2435    local
2436        (*
2437         LAST INDEX CONTRACTION:
2438         a = a(i1,i2,...,in)
2439         b = b(j1,j2,...,jn)
2440         c = c(i2,...,in,j2,...,jn)
2441         = sum(mult(a(i1,i2,...,k),b(j1,j2,...,k))) forall k
2442         MEANINGFUL VARIABLES:
2443         lk = in = jn
2444         li = i1*...*i(n-1)
2445         lj = j1*...*j(n-1)
2446         *)
2447        fun do_fold_last a b c lk lj li =
2448            let fun loopi (0, ia, ic, fac) = ()
2449                  | loopi (i, ia, ic, fac) =
2450                    let val old = Array.sub(c,ic)
2451                        val inc = Number.*(Array.sub(a,ia),fac)
2452                    in
2453                        Array.update(c,ic,Number.+(old,inc));
2454                        loopi(i-1, ia+1, ic+1, fac)
2455                    end
2456                fun loopj (j, ib, ic) =
2457                    let fun loopk (0, ia, ib) = ()
2458                          | loopk (k, ia, ib) =
2459                            (loopi(li, ia, ic, Array.sub(b,ib));
2460                             loopk(k-1, ia+li, ib+lj))
2461                    in case j of
2462                           0 => c
2463                         | _ => (loopk(lk, 0, ib);
2464                                 loopj(j-1, ib+1, ic+li))
2465                    end (* loopj *)
2466            in
2467                loopj(lj, 0, 0)
2468            end
2469    in
2470        fun *+ ta tb  =
2471            let val (rank_a,shape_a,a) = (rank ta, shape ta, toArray ta)
2472                val (rank_b,shape_b,b) = (rank tb, shape tb, toArray tb)
2473                val (lk::rest_a) = List.rev shape_a
2474                val (lk2::rest_b) = List.rev shape_b
2475            in if not(lk = lk2)
2476               then raise Match
2477               else let val li = Index.length rest_a
2478                        val lj = Index.length rest_b
2479                        val c = Array.array(li*lj,Number.zero)
2480                    in fromArray(List.rev rest_a @ List.rev rest_b,
2481                                 do_fold_last a b c lk li lj)
2482                    end
2483            end
2484    end
2485        (* ALGEBRAIC OPERATIONS *)
2486        infix **
2487        infix ==
2488        infix !=
2489        fun a + b = map2 Number.+ a b
2490        fun a - b = map2 Number.- a b
2491        fun a * b = map2 Number.* a b
2492        fun a ** i = map (fn x => (Number.**(x,i))) a
2493        fun ~ a = map Number.~ a
2494        fun abs a = map Number.abs a
2495        fun signum a = map Number.signum a
2496        fun a == b = map2' Number.== a b
2497        fun a != b = map2' Number.!= a b
2498        fun toString a = raise Domain
2499        fun fromInt a = new([1], Number.fromInt a)
2500        (* TENSOR SPECIFIC OPERATIONS *)
2501        fun *> n = map (fn x => Number.*(n,x))
2502        fun a / b = map2 Number./ a b
2503        fun recip a = map Number.recip a
2504        fun ln a = map Number.ln a
2505        fun pow (a, b) = map (fn x => (Number.pow(x,b))) a
2506        fun exp a = map Number.exp a
2507        fun sqrt a = map Number.sqrt a
2508        fun cos a = map Number.cos a
2509        fun sin a = map Number.sin a
2510        fun tan a = map Number.tan a
2511        fun sinh a = map Number.sinh a
2512        fun cosh a = map Number.cosh a
2513        fun tanh a = map Number.tanh a
2514        fun asin a = map Number.asin a
2515        fun acos a = map Number.acos a
2516        fun atan a = map Number.atan a
2517        fun asinh a = map Number.asinh a
2518        fun acosh a = map Number.acosh a
2519        fun atanh a = map Number.atanh a
2520        fun atan2 (a,b) = map2 Number.atan2 a b
2521        fun normInf a =
2522            let fun accum (y,x) = Number.max(x,Number.abs y)
2523            in  foldl accum Number.zero a
2524            end
2525        fun norm1 a =
2526            let fun accum (y,x) = Number.+(x,Number.abs y)
2527            in  foldl accum Number.zero a
2528            end
2529        fun norm2 a =
2530            let fun accum (y,x) = Number.+(x, Number.*(y,y))
2531            in Number.sqrt(foldl accum Number.zero a)
2532            end
2533    end (* RTensor *)
2534structure CTensor =
2535struct
2536    structure Number = CNumber
2537    structure Array = CNumberArray
2538(*
2539 Copyright (c) Juan Jose Garcia Ripoll.
2540 All rights reserved.
2541 Refer to the COPYRIGHT file for license conditions
2542*)
2543structure MonoTensor  =
2544    struct
2545(* PARAMETERS
2546        structure Array = Array
2547*)
2548        structure Index  = Index
2549        type elem = Array.elem
2550        type index = Index.t
2551        type tensor = {shape : index, indexer : Index.indexer, data : Array.array}
2552        type t = tensor
2553        exception Shape
2554        exception Match
2555        exception Index
2556    local
2557    (*----- LOCALS -----*)
2558        fun make' (shape, data) =
2559            {shape = shape, indexer = Index.indexer shape, data = data}
2560        fun toInt {shape, indexer, data} index = indexer index
2561        fun splitList (l as (a::rest), place) =
2562            let fun loop (left,here,right) 0 =  (List.rev left,here,right)
2563                  | loop (_,_,[]) place = raise Index
2564                  | loop (left,here,a::right) place = 
2565                loop (here::left,a,right) (place-1)
2566            in
2567                if place <= 0 then
2568                    loop ([],a,rest) (List.length rest - place)
2569                else
2570                    loop ([],a,rest) (place - 1)
2571            end
2572    in
2573    (*----- STRUCTURAL OPERATIONS & QUERIES ------*)
2574        fun new (shape, init) =
2575            if not (Index.validShape shape) then
2576                raise Shape
2577            else
2578                let val length = Index.length shape in
2579                    {shape = shape,
2580                     indexer = Index.indexer shape,
2581                     data = Array.array(length,init)}
2582                end
2583        fun toArray {shape, indexer, data} = data
2584        fun length {shape, indexer, data} =  Array.length data
2585        fun shape {shape, indexer, data} = shape
2586        fun rank t = List.length (shape t)
2587        fun reshape new_shape tensor =
2588            if Index.validShape new_shape then
2589                case (Index.length new_shape) = length tensor of
2590                    true => make'(new_shape, toArray tensor)
2591                  | false => raise Match
2592            else
2593                raise Shape
2594        fun fromArray (s, a) =
2595            case Index.validShape s andalso 
2596                 ((Index.length s) = (Array.length a)) of
2597                 true => make'(s, a)
2598               | false => raise Shape
2599        fun fromList (s, a) = fromArray (s, Array.fromList a)
2600        fun tabulate (shape,f) =
2601            if Index.validShape shape then
2602                let val last = Index.last shape
2603                    val length = Index.length shape
2604                    val c = Array.array(length, f last)
2605                    fun dotable (c, indices, i) =
2606                        (Array.update(c, i, f indices);
2607                         if i <= 1
2608                         then c
2609                         else dotable(c, Index.prev' shape indices, i-1))
2610                in make'(shape,dotable(c, Index.prev' shape last, length-2))
2611                end
2612            else
2613                raise Shape
2614        (*----- ELEMENTWISE OPERATIONS -----*)
2615        fun sub (t, index) = Array.sub(#data t, toInt t index)
2616        fun update (t, index, value) =
2617            Array.update(toArray t, toInt t index, value)
2618        fun map f {shape, indexer, data} =
2619            {shape = shape, indexer = indexer, data = Array.map f data}
2620        fun map2 f t1 t2=
2621            let val {shape=shape1, indexer=indexer1, data=data1} = t1
2622                val {shape=shape2, indexer=indexer2, data=data2} = t2
2623            in
2624                if Index.eq(shape1,shape2) then
2625                    {shape = shape1,
2626                     indexer = indexer1,
2627                     data = Array.map2 f data1 data2}
2628                else
2629                    raise Match
2630        end
2631        fun appi f tensor = Array.appi f (toArray tensor, 0, NONE)
2632        fun app f tensor = Array.app f (toArray tensor)
2633        fun all f tensor =
2634            let val a = toArray tensor
2635            in Loop.all(0, length tensor - 1, fn i =>
2636                        f (Array.sub(a, i)))
2637            end
2638        fun any f tensor =
2639            let val a = toArray tensor
2640            in Loop.any(0, length tensor - 1, fn i =>
2641                        f (Array.sub(a, i)))
2642            end
2643        fun foldl f init tensor = Array.foldl f init (toArray tensor)
2644        fun foldln f init {shape, indexer, data=a} index =
2645            let val (head,lk,tail) = splitList(shape, index)
2646                val li = Index.length head
2647                val lj = Index.length tail
2648                val c = Array.array(li * lj,init)
2649                fun loopi (0, _,  _)  = ()
2650                  | loopi (i, ia, ic) =
2651                    (Array.update(c, ic, f(Array.sub(c,ic), Array.sub(a,ia)));
2652                     loopi (i-1, ia+1, ic+1))
2653                fun loopk (0, ia, _)  = ia
2654                  | loopk (k, ia, ic) = (loopi (li, ia, ic);
2655                                         loopk (k-1, ia+li, ic))
2656                fun loopj (0, _,  _)  = ()
2657                  | loopj (j, ia, ic) = loopj (j-1, loopk(lk,ia,ic), ic+li)
2658            in
2659                loopj (lj, 0, 0);
2660                make'(head @ tail, c)
2661            end
2662        (* --- POLYMORPHIC ELEMENTWISE OPERATIONS --- *)
2663        fun array_map' f a =
2664            let fun apply index = f(Array.sub(a,index)) in
2665                Tensor.Array.tabulate(Array.length a, apply)
2666            end
2667        fun map' f t = Tensor.fromArray(shape t, array_map' f (toArray t))
2668        fun map2' f t1 t2 =
2669            let val d1 = toArray t1
2670                val d2 = toArray t2
2671                fun apply i = f (Array.sub(d1,i), Array.sub(d2,i))
2672                val len = Array.length d1
2673            in
2674                if Index.eq(shape t1, shape t2) then
2675                    Tensor.fromArray(shape t1, Tensor.Array.tabulate(len,apply))
2676                else
2677                    raise Match
2678            end
2679        fun foldl' f init {shape, indexer, data=a} index =
2680            let val (head,lk,tail) = splitList(shape, index)
2681                val li = Index.length head
2682                val lj = Index.length tail
2683                val c = Tensor.Array.array(li * lj,init)
2684                fun loopi (0, _,  _)  = ()
2685                  | loopi (i, ia, ic) =
2686                    (Tensor.Array.update(c,ic,f(Tensor.Array.sub(c,ic),Array.sub(a,ia)));
2687                     loopi (i-1, ia+1, ic+1))
2688                fun loopk (0, ia, _)  = ia
2689                  | loopk (k, ia, ic) = (loopi (li, ia, ic);
2690                                         loopk (k-1, ia+li, ic))
2691                fun loopj (0, _,  _)  = ()
2692                  | loopj (j, ia, ic) = loopj (j-1, loopk(lk,ia,ic), ic+li)
2693            in
2694                loopj (lj, 0, 0);
2695                make'(head @ tail, c)
2696            end
2697    end
2698    end (* MonoTensor *)
2699    open MonoTensor
2700    local
2701        (*
2702         LEFT INDEX CONTRACTION:
2703         a = a(i1,i2,...,in)
2704         b = b(j1,j2,...,jn)
2705         c = c(i2,...,in,j2,...,jn)
2706         = sum(a(k,i2,...,jn)*b(k,j2,...jn)) forall k
2707         MEANINGFUL VARIABLES:
2708         lk = i1 = j1
2709         li = i2*...*in
2710         lj = j2*...*jn
2711         *)
2712        fun do_fold_first a b c lk lj li =
2713            let fun loopk (0, _, _, r, i) = Number.make(r,i)
2714                  | loopk (k, ia, ib, r, i) =
2715                    let val (ar, ai) = Array.sub(a,ia)
2716                        val (br, bi) = Array.sub(b,ib)
2717                        val dr = ar * br - ai * bi
2718                        val di = ar * bi + ai * br
2719                    in loopk (k-1, ia+1, ib+1, r+dr, i+di)
2720                    end
2721                fun loopj (0, ib, ic) = c
2722                  | loopj (j, ib, ic) =
2723                    let fun loopi (0, ia, ic) = ic
2724                          | loopi (i, ia, ic) =
2725                            (Array.update(c, ic, loopk(lk, ia, ib, RNumber.zero, RNumber.zero));
2726                             loopi(i-1, ia+lk, ic+1))
2727                    in loopj(j-1, ib+lk, loopi(li, 0, ic))
2728                    end
2729            in loopj(lj, 0, 0)
2730            end
2731    in
2732        fun +* ta tb =
2733            let val (rank_a,lk::rest_a,a) = (rank ta, shape ta, toArray ta)
2734                val (rank_b,lk2::rest_b,b) = (rank tb, shape tb, toArray tb)
2735            in if not(lk = lk2)
2736               then raise Match
2737               else let val li = Index.length rest_a
2738                        val lj = Index.length rest_b
2739                        val c = Array.array(li*lj,Number.zero)
2740                    in fromArray(rest_a @ rest_b, do_fold_first a b c lk li lj)
2741                    end
2742            end
2743    end
2744    local
2745        (*
2746         LAST INDEX CONTRACTION:
2747         a = a(i1,i2,...,in)
2748         b = b(j1,j2,...,jn)
2749         c = c(i2,...,in,j2,...,jn)
2750         = sum(mult(a(i1,i2,...,k),b(j1,j2,...,k))) forall k
2751         MEANINGFUL VARIABLES:
2752         lk = in = jn
2753         li = i1*...*i(n-1)
2754         lj = j1*...*j(n-1)
2755         *)
2756        fun do_fold_last a b c lk lj li =
2757            let fun loopi(0, _, _, _, _) = ()
2758                  | loopi(i, ia, ic, br, bi) =
2759                    let val (cr,ci) = Array.sub(c,ic)
2760                        val (ar,ai) = Array.sub(a,ia)
2761                        val dr = (ar * br - ai * bi)
2762                        val di = (ar * bi + ai * br)
2763                    in
2764                        Array.update(c,ic,Number.make(cr+dr,ci+di));
2765                        loopi(i-1, ia+1, ic+1, br, bi)
2766                    end
2767                fun loopj(j, ib, ic) =
2768                    let fun loopk(0, _, _) = ()
2769                          | loopk(k, ia, ib) =
2770                            let val (br, bi) = Array.sub(b,ib)
2771                            in
2772                                loopi(li, ia, ic, br, bi);
2773                                loopk(k-1, ia+li, ib+lj)
2774                            end
2775                in case j of
2776                    0 => c
2777                  | _ => (loopk(lk, 0, ib);
2778                          loopj(j-1, ib+1, ic+li))
2779                end (* loopj *)
2780            in
2781                loopj(lj, 0, 0)
2782            end
2783    in
2784        fun *+ ta tb  =
2785            let val (rank_a,shape_a,a) = (rank ta, shape ta, toArray ta)
2786                val (rank_b,shape_b,b) = (rank tb, shape tb, toArray tb)
2787                val (lk::rest_a) = List.rev shape_a
2788                val (lk2::rest_b) = List.rev shape_b
2789            in
2790                if not(lk = lk2) then
2791                    raise Match
2792                else
2793                    let val li = Index.length rest_a
2794                        val lj = Index.length rest_b
2795                        val c = Array.array(li*lj,Number.zero)
2796                    in
2797                        fromArray(List.rev rest_a @ List.rev rest_b,
2798                                  do_fold_last a b c lk li lj)
2799                    end
2800            end
2801    end
2802    (* ALGEBRAIC OPERATIONS *)
2803    infix **
2804    infix ==
2805    infix !=
2806    fun a + b = map2 Number.+ a b
2807    fun a - b = map2 Number.- a b
2808    fun a * b = map2 Number.* a b
2809    fun a ** i = map (fn x => (Number.**(x,i))) a
2810    fun ~ a = map Number.~ a
2811    fun abs a = map Number.abs a
2812    fun signum a = map Number.signum a
2813    fun a == b = map2' Number.== a b
2814    fun a != b = map2' Number.!= a b
2815    fun toString a = raise Domain
2816    fun fromInt a = new([1], Number.fromInt a)
2817    (* TENSOR SPECIFIC OPERATIONS *)
2818    fun *> n = map (fn x => Number.*(n,x))
2819    fun a / b = map2 Number./ a b
2820    fun recip a = map Number.recip a
2821    fun ln a = map Number.ln a
2822    fun pow (a, b) = map (fn x => (Number.pow(x,b))) a
2823    fun exp a = map Number.exp a
2824    fun sqrt a = map Number.sqrt a
2825    fun cos a = map Number.cos a
2826    fun sin a = map Number.sin a
2827    fun tan a = map Number.tan a
2828    fun sinh a = map Number.sinh a
2829    fun cosh a = map Number.cosh a
2830    fun tanh a = map Number.tanh a
2831    fun asin a = map Number.asin a
2832    fun acos a = map Number.acos a
2833    fun atan a = map Number.atan a
2834    fun asinh a = map Number.asinh a
2835    fun acosh a = map Number.acosh a
2836    fun atanh a = map Number.atanh a
2837    fun atan2 (a,b) = map2 Number.atan2 a b
2838    fun normInf a =
2839        let fun accum (y,x) = RNumber.max(x, Number.realPart(Number.abs y))
2840        in  foldl accum RNumber.zero a
2841        end
2842    fun norm1 a =
2843        let fun accum (y,x) = RNumber.+(x, Number.realPart(Number.abs y))
2844        in  foldl accum RNumber.zero a
2845        end
2846    fun norm2 a =
2847        let fun accum (y,x) = RNumber.+(x, Number.abs2 y)
2848        in RNumber.sqrt(foldl accum RNumber.zero a)
2849        end
2850end (* CTensor *)
2851
2852
2853
2854structure TensorFile =
2855struct
2856
2857type file = TextIO.instream
2858
2859exception Data
2860
2861fun assert NONE = raise Data
2862  | assert (SOME a) = a
2863
2864(* ------------------ INPUT --------------------- *)
2865
2866fun intRead file = assert(TextIO.scanStream INumber.scan file)
2867fun realRead file = assert(TextIO.scanStream RNumber.scan file)
2868fun complexRead file = assert(TextIO.scanStream CNumber.scan file)
2869
2870fun listRead eltScan file =
2871    let val length = intRead file
2872        fun eltRead file = assert(TextIO.scanStream eltScan file)
2873        fun loop (0,accum) = accum
2874          | loop (i,accum) = loop(i-1, eltRead file :: accum)
2875    in
2876        if length < 0
2877        then raise Data
2878        else List.rev(loop(length,[]))
2879    end
2880
2881fun intListRead file = listRead INumber.scan file
2882fun realListRead file = listRead RNumber.scan file
2883fun complexListRead file = listRead CNumber.scan file
2884
2885fun intTensorRead file =
2886    let val shape = intListRead file
2887        val length = Index.length shape
2888        val first = intRead file
2889        val a = ITensor.Array.array(length, first)
2890        fun loop 0 = ITensor.fromArray(shape, a)
2891          | loop j = (ITensor.Array.update(a, length-j, intRead file);
2892                      loop (j-1))
2893    in loop (length - 1)
2894    end
2895
2896fun realTensorRead file =
2897    let val shape = intListRead file
2898        val length = Index.length shape
2899        val first = realRead file
2900        val a = RTensor.Array.array(length, first)
2901        fun loop 0 = RTensor.fromArray(shape, a)
2902          | loop j = (RTensor.Array.update(a, length-j, realRead file);
2903                      loop (j-1))
2904    in loop (length - 1)
2905    end
2906
2907fun complexTensorRead file =
2908    let val shape = intListRead file
2909        val length = Index.length shape
2910        val first = complexRead file
2911        val a = CTensor.Array.array(length, first)
2912        fun loop j = if j = length
2913                     then CTensor.fromArray(shape, a)
2914                     else (CTensor.Array.update(a, j, complexRead file);
2915                           loop (j+1))
2916    in loop 1
2917    end
2918
2919(* ------------------ OUTPUT -------------------- *)
2920fun linedOutput(file, x) = (TextIO.output(file, x); TextIO.output(file, "\n"))
2921
2922fun intWrite file x = linedOutput(file, INumber.toString x)
2923fun realWrite file x = linedOutput(file, RNumber.toString x)
2924fun complexWrite file x =
2925    let val (r,i) = CNumber.split x
2926    in linedOutput(file, concat [RNumber.toString r, " ", RNumber.toString i])
2927    end
2928
2929fun listWrite converter file x =
2930    (intWrite file (length x);
2931     List.app (fn x => (linedOutput(file, converter x))) x)
2932
2933fun intListWrite file x = listWrite INumber.toString file x
2934fun realListWrite file x = listWrite RNumber.toString file x
2935fun complexListWrite file x = listWrite CNumber.toString file x
2936
2937fun intTensorWrite file x = (intListWrite file (ITensor.shape x); ITensor.app (fn x => (intWrite file x)) x)
2938fun realTensorWrite file x = (intListWrite file (RTensor.shape x); RTensor.app (fn x => (realWrite file x)) x)
2939fun complexTensorWrite file x = (intListWrite file (CTensor.shape x); CTensor.app (fn x => (complexWrite file x)) x)
2940end
2941
2942
2943structure RandomTensor =
2944struct
2945
2946fun realRandomTensor (xseed,yseed) shape =
2947    let 
2948        val length = Index.length shape
2949        val seed   = Random.rand (xseed,yseed)
2950        val a      = RTensor.Array.array(length, Random.randReal seed)
2951        fun loop 0 = RTensor.fromArray(shape, a)
2952          | loop j = (RTensor.Array.update(a, length-j, Random.randReal seed);
2953                      loop (j-1))
2954    in loop (length - 1)
2955    end
2956
2957fun intRandomTensor (xseed,yseed) shape =
2958    let 
2959        val length = Index.length shape
2960        val seed   = Random.rand (xseed,yseed)
2961        val a      = ITensor.Array.array(length, Random.randInt seed)
2962        fun loop 0 = ITensor.fromArray(shape, a)
2963          | loop j = (ITensor.Array.update(a, length-j, Random.randInt seed);
2964                      loop (j-1))
2965    in loop (length - 1)
2966    end
2967
2968
2969end
2970
2971
2972val Ne = 8
2973val Ni = 2
2974
2975
2976val S  = RTensor.cat (1, 
2977                      (RTensor.*> 0.5 (RandomTensor.realRandomTensor (13,17) [(Ne+Ni),Ne]) ),
2978                      (RTensor.~ (RandomTensor.realRandomTensor (19,23) [(Ne+Ni),Ni])))
2979
2980val _ = TensorFile.realTensorWrite (TextIO.stdOut) S
2981
2982
Note: See TracBrowser for help on using the repository browser.