source: project/release/4/flsim/trunk/sml-lib/tensor/sparse.sml @ 29960

Last change on this file since 29960 was 29960, checked in by Ivan Raikov, 8 years ago

flsim: added tensor library to sml support files

File size: 39.1 KB
Line 
1
2(*
3 Copyright 2013 Ivan Raikov.
4 All rights reserved.
5
6Redistribution and use in source and binary forms, with or
7without modification, are permitted provided that the following
8conditions are met:
9
101. Redistributions of source code must retain the above copyright
11   notice, this list of conditions and the following disclaimer.
12
132. Redistributions in binary form must reproduce the above
14   copyright notice, this list of conditions and the following
15   disclaimer in the documentation and/or other materials provided
16   with the distribution.
17
18*)
19
20
21
22fun putStr out str = 
23    TextIO.output (out, str)
24
25fun putStrLn out str = 
26    (TextIO.output (out, str);
27     TextIO.output (out, "\n"))
28
29
30
31signature SPARSE_INDEX =
32    sig
33        type t
34        type array = IntArray.array
35        type nonzero = { indptr: array, indices: array }
36        type indexer = t -> int option
37        datatype storage = CSR | CSC
38
39        exception Index
40        exception Shape
41
42        val order : storage
43        val toInt : t -> nonzero -> t -> int option
44
45        val inBounds : t -> t -> bool
46
47        val app : t -> nonzero -> (t -> unit) -> unit
48
49    end
50
51
52(*
53 MONO_SPARSE            - Signature -
54
55 Monomorphic sparse matrices.
56
57 structure Number : NUMBER
58        Structure that describe type type of the elements of the
59        matrix and their operations (+,*,etc)
60
61 type elem = Number.t
62        The type of the elements of the matrix
63
64 structure Tensor : MONO_TENSOR
65        Tensors of 'elem' type.
66
67 fromTensor  [[number,number,...]*]
68 fromTensor' [[number,number,...]*]
69        Builds a sparse matrix up from a tensor.
70        fromTensor converts the tensor to sparse format;
71        fromTensor uses it as-is
72
73 fromTensorList  [[number,number,...]*]
74        Builds a sparse matrix up from a list of tensors.
75
76 sub (matrix,row,column)
77 update (matrix,row,column,value)
78        Retrieves or sets an element from/to a sparse matrix
79
80 map op matrix
81 mapi op matrix
82 app op matrix
83 appi op matrix
84        Maps or applies 'op' to the elements of a matrix. 'mapi'
85        and 'appi' have the pecularity that 'op' receives the
86        row and the column indices plust the value.
87
88 matrix + matrix
89 matrix - matrix
90 matrix * matrix
91 matrix / matrix
92 ~ matrix
93        Elementwise operations.
94*)
95
96signature MONO_SPARSE_MATRIX =
97    sig
98        structure Tensor : MONO_TENSOR
99        structure TensorSlice : MONO_TENSOR_SLICE
100        structure Number : NUMBER
101        structure Index : SPARSE_INDEX
102
103        type index = Index.t
104        type elem = Number.t
105        type matrix
106
107        exception Data and Shape
108
109        val fromTensor : index -> (Tensor.tensor * (index option)) -> matrix
110        val fromTensorList : index -> {tensor: Tensor.tensor, offset: index, sparse: bool} list -> matrix
111        val fromGenerator : index -> ((index -> elem) * index * (index option)) -> matrix
112        val fromGeneratorList : index -> ({f: (index -> elem), fshape: index, offset: index} list) -> matrix
113        val insert : matrix * matrix -> matrix
114
115        val shape : matrix -> index
116
117        val sub : matrix * index -> elem
118        val update : matrix * index * elem -> unit
119
120        val map : (elem -> elem) -> matrix -> matrix
121        val app : (elem -> unit) -> matrix -> unit
122
123        datatype slice = SLSPARSE of {offset: index, indices: Index.array, data: Tensor.tensor}
124                       | SLDENSE  of {offset: index, data: TensorSlice.slice}
125        val slice : (matrix * int * int) ->  slice list
126        val sliceAppi: ((int * elem) -> unit) -> slice list -> unit
127        val sliceFoldi: ((int * elem * 'a) -> 'a) -> 'a -> slice list -> 'a
128
129(*
130        val mapi : (index * elem -> elem) -> matrix -> matrix
131        val appi : (index * elem -> unit) -> matrix -> unit
132        val map2 : (elem * elem -> elem) -> matrix -> matrix -> matrix
133
134
135        val + : matrix * matrix -> matrix
136        val - : matrix * matrix -> matrix
137        val * : matrix * matrix -> matrix
138        val / : matrix * matrix -> matrix
139        val ~ : matrix -> matrix
140*)
141    end
142
143
144
145structure SparseIndex =
146    struct
147
148        type t = int list
149        type array = IntArray.array
150        type nonzero = { indptr: array, indices: array }
151        type indexer = t -> int option
152        datatype storage = CSR | CSC
153                           
154        exception Index
155        exception Shape
156
157        val order = CSC
158
159        fun validShape shape = List.all (fn x => x > 0) shape
160        fun validIndex index = List.all (fn x => x >= 0) index
161
162        val sub = IntArray.sub
163
164        fun findFromTo (i,v,s,e) =
165            let fun loop (j) = 
166                    if ((j >= s) andalso (j < e)) 
167                    then (if (sub (v,j) = i) then SOME j else loop (j+1))
168                    else NONE
169            in
170                loop s
171            end
172           
173        fun inBounds shape index =
174            ListPair.all (fn (x,y) => (x >= 0) andalso (x < y))
175            (index, shape)
176
177
178        fun toInt shape {indptr, indices} index  =
179            let
180                val nptr = IntArray.length indptr
181                val nind = IntArray.length indices
182            in
183                case order of 
184                    CSR => 
185                    (case (index, shape) of
186                         ([i,j],[s,rs]) => 
187                         let
188                             val s = sub (indptr, j)
189                             val e = if (j < (nptr-1)) then sub (indptr,j+1) else nind
190                         in
191                             findFromTo (i, indices, s, e)
192                         end
193                       | ([],[]) => SOME 0
194                       | (_,_)   => raise Index)
195                  | CSC => 
196                    (case (index, shape) of
197                         ([i,j],[s,rs]) => 
198                         if (i >= 0) andalso (i < s) 
199                         then
200                             (let
201                                  val s = sub (indptr,i)
202                                  val e = if (i < (nptr-1)) then sub (indptr,i+1) else nind
203                              in
204                                  findFromTo (j, indices, s, e)
205                              end)
206                         else raise Index
207                       | ([],[]) => SOME 0
208                       | (_,_)   => raise Index)
209                   
210            end
211
212        fun app shape {indptr, indices}   =
213            (case order of 
214                 CSR => 
215                 (let 
216                      val ni = IntArray.length indices
217                      val nj = IntArray.length indptr
218                      fun iterator j f =
219                          if (j < nj)
220                          then (let 
221                                    val jj  = sub (indptr, j)
222                                    val jj' = sub (indptr, if (j < (nj-1)) then j+1 else ni-1)
223                                in
224                                    (List.app f (List.tabulate (jj'-jj, fn (jk) => [sub(indices,jj+jk),j]));
225                                     iterator (j+1) f)
226                                end)
227                          else ()
228                  in
229                      iterator 0
230                  end)
231               | CSC =>
232                 (let 
233                      val nj = IntArray.length indices
234                      val ni = IntArray.length indptr
235                      fun iterator i f =
236                          if (i < ni)
237                          then (let 
238                                    val ii  = sub (indptr, i)
239                                    val ii' = sub (indptr, if (i < (ni-1)) then i+1 else nj-1)
240                                in
241                                    (List.app f (List.tabulate (ii'-ii, fn (ik) => [i,sub(indices,ii+ik)]));
242                                     iterator (i+1) f)
243                                end)
244                          else ()
245                  in
246                      iterator 0
247                  end)
248            )
249
250    end
251
252
253
254structure SparseMatrix : MONO_SPARSE_MATRIX =
255
256struct
257    structure Tensor : MONO_TENSOR = RTensor
258    structure TensorSlice : MONO_TENSOR_SLICE = RTensorSlice
259    structure Number = RTensor.Number
260    structure Index = SparseIndex
261
262    type index   = Index.t
263    type nonzero = Index.nonzero
264    type elem    = Number.t
265    datatype block = 
266             SPARSE of {offset: index, shape: index, nz: nonzero, data: elem array}
267           | DENSE of {offset: index, data: Tensor.tensor}
268
269    type matrix  = {shape: index, blocks: block list}
270
271    datatype slice = SLSPARSE of {offset: index, indices: Index.array, data: Tensor.tensor}
272                   | SLDENSE  of {offset: index, data: TensorSlice.slice}
273
274    exception Data
275    exception Shape
276    exception Index
277    exception Overlap
278
279    (* --- LOCALS --- *)
280
281    fun dimVals [m,n] = (m,n) | dimVals _ = raise Shape
282
283    fun array_map f a =
284        let fun apply index = f(Array.sub(a,index)) in
285            Array.tabulate(Array.length a, apply)
286        end
287
288    fun array_mapi f a =
289        let fun apply index = f(index,Array.sub(a,index)) in
290            Array.tabulate(Array.length a, apply)
291        end
292
293    fun findBlock (i,j,blocks) =
294        let
295            val block = List.find
296                            (fn (SPARSE {offset=offset, shape=shape, nz, data}) =>
297                                (let
298                                    val (u,v) = dimVals offset
299                                    val (t,s) = dimVals shape
300                                in
301                                    ((j>=v) andalso (j-v<s) andalso
302                                     (i>=u) andalso (i-u<t))
303                                end)
304                                | (DENSE {offset=offset, data}) =>
305                                (let
306                                    val (u,v) = dimVals offset
307                                    val (t,s) = dimVals (Tensor.shape data)
308                                in
309                                    ((j>=v) andalso (j-v<s) andalso
310                                     (i>=u) andalso (i-u<t))
311                                end))
312                            blocks
313        in
314            block
315        end
316
317    fun intArrayWrite file x = Array.app (fn (i) => putStr file ((Int.toString i) ^ " ")) x
318
319    (* --- CONSTRUCTORS --- *)
320
321    fun fromTensor shape (a: Tensor.tensor, offset) = 
322        (let 
323             val shape_a = Tensor.shape a
324             val (rows,cols) = dimVals shape_a
325        in
326            case Index.order of
327                Index.CSC =>
328                let 
329                    val v0: (int * elem) list = []
330                    val data: (((int * elem) DynArray.array) option) Array.array  = 
331                        Array.array(cols,NONE)
332                    val nzcount = ref 0
333                    val _ = Tensor.Index.app (List.rev shape_a)
334                                      (fn (i) => 
335                                          let 
336                                              val v = Tensor.sub (a, i)
337                                          in
338                                              if not (Number.== (v, Number.zero))
339                                              then
340                                                  let 
341                                                      val (irow,icol) = dimVals i
342                                                      val colv  = Array.sub (data, icol)
343                                                      (*val col' = (irow,v) :: col*)
344                                                  in
345                                                      (case colv of
346                                                          SOME col => 
347                                                          (DynArray.update(col,DynArray.length col,(irow,v)))
348                                                        | NONE => Array.update (data, icol, SOME (DynArray.fromList [(irow,v)]));
349                                                       nzcount := (!nzcount) + 1)
350                                                  end
351                                              else ()
352                                          end)
353                    val data'   = Array.array (!nzcount, Number.zero)
354                    val indices = IntArray.array (!nzcount, 0)
355                    val indptr  = IntArray.array (cols, 0)
356                    val update  = IntArray.update
357                    val fi      = Array.foldli
358                                      (fn (n,SOME cols,i) => 
359                                          let 
360                                              val i' = DynArray.foldr
361                                                           (fn ((rowind,v),i) => 
362                                                               (Array.update (data',i,v); 
363                                                                update (indices,i,rowind); 
364                                                                i+1))
365                                                           i cols
366                                          in
367                                              (update (indptr,n,i); i')
368                                          end
369                                      | (n,NONE,i) => (update (indptr,n,i); i))
370                                      0 data
371                in
372                    {shape=shape,
373                     blocks=[SPARSE {offset=case offset of NONE => [0, 0] | SOME i => i, 
374                                     shape=shape_a, nz={ indptr= indptr, indices=indices }, data=data'}]}
375                end
376              | Index.CSR => 
377                let 
378                    val v0: (int * elem) list = []
379                    val data: (((int * elem) DynArray.array) option) Array.array  = 
380                        Array.array(rows,NONE)
381                    val nzcount = ref 0
382                    val _ = Tensor.Index.app shape_a
383                                              (fn (i) => 
384                                                  let 
385                                                      val v = Tensor.sub (a, i)
386                                                  in
387                                                      if not (Number.== (v, Number.zero))
388                                                      then
389                                                          let 
390                                                              val (irow,icol) = dimVals i
391                                                              val rowv  = Array.sub (data, irow)
392                                                              (*val row' = (icol,v) :: row*)
393                                                          in
394                                                              (case rowv of
395                                                                   (*Array.update(data,irow,row');*)
396                                                                   SOME row => DynArray.update (row,DynArray.length row,(icol,v))
397                                                                 | NONE => Array.update (data, irow, SOME (DynArray.fromList [(icol,v)]));
398                                                               nzcount := (!nzcount) + 1)
399                                                          end
400                                                      else ()
401                                                  end)
402                    val data'   = Array.array (!nzcount, Number.zero)
403                    val indices = IntArray.array (!nzcount, 0)
404                    val indptr  = IntArray.array (rows, 0)
405                    val update  = IntArray.update
406                    val fi      = Array.foldli
407                                      (fn (n,SOME rows,i) => 
408                                          let 
409                                              val i' = DynArray.foldr
410                                                           (fn ((colind,v),i) => 
411                                                               (Array.update (data',i,v); 
412                                                                update (indices,i,colind); 
413                                                                i+1))
414                                                           i rows
415                                          in
416                                              (update (indptr,n,i); i')
417                                          end
418                                      | (n,NONE,i) => (update (indptr,n,i); i))
419                                      0 data
420                in
421                    {shape=shape, 
422                     blocks=[SPARSE {offset = case offset of NONE => [0,0] | SOME i => i, 
423                                     shape=shape_a, nz={ indptr= indptr, indices=indices }, data=data'}]}
424                end
425        end)
426
427
428    fun fromTensor' shape (a: Tensor.tensor, offset) = 
429        let 
430            val shape_a = Tensor.shape a
431            val (rows,cols) = dimVals shape_a
432        in
433            {shape=shape, blocks=[DENSE {offset=(case offset of NONE => [0, 0] | SOME i => i), 
434                                         data=a}]}
435        end
436
437
438    fun fromGenerator shape (f: index -> elem, fshape, offset) = 
439        (let 
440             val (rows,cols) = dimVals fshape
441        in
442            case Index.order of
443                Index.CSC =>
444                let 
445                    val v0: (int * elem) list = []
446                    val data: (((int * elem) DynArray.array) option) Array.array  = 
447                        Array.array(cols,NONE)
448                    val nzcount = ref 0
449                    val _ = Tensor.Index.app
450                                (List.rev fshape)
451                                (fn (i) => 
452                                    let 
453                                        val v = f (i)
454                                    in
455                                        if not (Number.== (v, Number.zero))
456                                        then
457                                            let 
458                                                val (irow,icol) = dimVals i
459                                                val colv  = Array.sub (data, icol)
460                                            (*val col' = (irow,v) :: col*)
461                                            in
462                                                (case colv of
463                                                     SOME col => 
464                                                     (DynArray.update(col,DynArray.length col,(irow,v)))
465                                                   | NONE => Array.update (data, icol, SOME (DynArray.fromList [(irow,v)]));
466                                                 nzcount := (!nzcount) + 1)
467                                            end
468                                        else ()
469                                    end)
470                    val data'   = Array.array (!nzcount, Number.zero)
471                    val indices = IntArray.array (!nzcount, 0)
472                    val indptr  = IntArray.array (cols, 0)
473                    val update  = IntArray.update
474                    val fi      = Array.foldli
475                                      (fn (n,SOME cols,i) => 
476                                          let 
477                                              val i' = DynArray.foldr
478                                                           (fn ((rowind,v),i) => 
479                                                               (Array.update (data',i,v); 
480                                                                update (indices,i,rowind); 
481                                                                i+1))
482                                                           i cols
483                                          in
484                                              (update (indptr,n,i); i')
485                                          end
486                                      | (n,NONE,i) => (update (indptr,n,i); i))
487                                      0 data
488                in
489                    {shape=shape,
490                     blocks=[SPARSE {offset=case offset of NONE => [0, 0] | SOME i => i, 
491                                     shape=fshape, nz={ indptr= indptr, indices=indices }, data=data'}]}
492                end
493              | Index.CSR => 
494                let 
495                    val v0: (int * elem) list = []
496                    val data: (((int * elem) DynArray.array) option) Array.array  = 
497                        Array.array(rows,NONE)
498                    val nzcount = ref 0
499                    val _ = Tensor.Index.app fshape
500                                              (fn (i) => 
501                                                  let 
502                                                      val v = f (i)
503                                                  in
504                                                      if not (Number.== (v, Number.zero))
505                                                      then
506                                                          let 
507                                                              val (irow,icol) = dimVals i
508                                                              val rowv  = Array.sub (data, irow)
509                                                              (*val row' = (icol,v) :: row*)
510                                                          in
511                                                              (case rowv of
512                                                                   (*Array.update(data,irow,row');*)
513                                                                   SOME row => DynArray.update (row,DynArray.length row,(icol,v))
514                                                                 | NONE => Array.update (data, irow, SOME (DynArray.fromList [(icol,v)]));
515                                                               nzcount := (!nzcount) + 1)
516                                                          end
517                                                      else ()
518                                                  end)
519                    val data'   = Array.array (!nzcount, Number.zero)
520                    val indices = IntArray.array (!nzcount, 0)
521                    val indptr  = IntArray.array (rows, 0)
522                    val update  = IntArray.update
523                    val fi      = Array.foldli
524                                      (fn (n,SOME rows,i) => 
525                                          let 
526                                              val i' = DynArray.foldr
527                                                           (fn ((colind,v),i) => 
528                                                               (Array.update (data',i,v); 
529                                                                update (indices,i,colind); 
530                                                                i+1))
531                                                           i rows
532                                          in
533                                              (update (indptr,n,i); i')
534                                          end
535                                      | (n,NONE,i) => (update (indptr,n,i); i))
536                                      0 data
537                in
538                    {shape=shape, 
539                     blocks=[SPARSE {offset = case offset of NONE => [0,0] | SOME i => i, 
540                                     shape=fshape, nz={ indptr= indptr, indices=indices }, data=data'}]}
541                end
542        end)
543
544
545    fun insertBlock ({shape, blocks},b',boffset) =
546        let
547            val (i,j) = dimVals boffset
548            val (m,n) = dimVals (case b' of
549                                     SPARSE {offset, shape, nz, data} => shape
550                                   | DENSE {offset, data} => Tensor.shape data)
551
552                             
553
554            val blocks' = 
555                let
556                    fun merge ([], []) = [b']
557                      | merge (b::rst, ax) =
558                        let 
559                            val (bm,bn) = dimVals (case b of 
560                                                       SPARSE {offset,shape=shape,nz,data} => shape
561                                                     | DENSE {offset,data} => Tensor.shape data)
562                            val (bi,bj) = dimVals (case b of
563                                                       SPARSE {offset=offset,shape,nz,data} => offset
564                                                     | DENSE {offset=offset,data} => offset)
565
566                        in
567                            if (j < bj)
568                            then List.rev (rst@(b::b'::ax))
569                            else (if (i >= bi) andalso (j >= bj) andalso 
570                                     (i+m <= bi+bm) andalso (j+n <= bj+bn)
571                                  then raise Overlap
572                                  else merge (rst, b::ax))
573                        end
574                      | merge ([], ax) = List.rev (b'::ax)
575                in
576                    merge (blocks, [])
577                end
578        in
579                {shape=shape, blocks=blocks'}
580        end
581
582           
583    fun insertTensor (S as {shape, blocks},t,offset) =
584        let
585            val (i,j) = dimVals offset
586            val {shape=_, blocks=bs} = fromTensor shape (t,SOME [i,j])
587            val b': block = case bs of
588                                [b] => b
589                              | _ => raise Match
590        in
591            insertBlock (S,b',offset)
592        end
593
594    fun insertTensor' (S as {shape, blocks},t,offset) =
595        let
596            val (i,j) = dimVals offset
597            val {shape=_, blocks=bs} = fromTensor' shape (t,SOME [i,j])
598            val b': block = case bs of
599                                [b] => b
600                              | _ => raise Match
601        in
602            insertBlock (S,b',offset)
603        end
604                                               
605                 
606    (* Builds a sparse matrix from a list of the form:
607 
608       {tensor,offset=[xoffset,yoffset],sparse) ...
609
610     where xoffset and yoffset are the positions where to insert the
611     given tensor. The tensors to be inserted must be non-overlapping.
612     sparse is a boolean flag that indicates whether the tensor should
613     be converted to sparse form.
614    *)
615
616    fun fromTensorList shape (al: ({tensor: Tensor.tensor, offset: index, sparse: bool}) list) = 
617        (case al of
618             {tensor,offset,sparse}::rst => 
619             (List.foldl (fn ({tensor,offset,sparse},ax) => 
620                             if sparse
621                             then insertTensor (ax,tensor,offset)
622                             else insertTensor' (ax,tensor,offset))
623                         (if sparse
624                          then fromTensor shape (tensor, SOME offset)
625                          else fromTensor' shape (tensor, SOME offset) )
626                         rst)
627           | _ => raise Match)
628
629
630    fun fromGeneratorList shape (gg: ({f: index -> elem, fshape: index, offset: index}) list) = 
631        case gg of 
632            ({f,fshape,offset}::rst) =>
633            (List.foldl
634                 (fn ({f,fshape,offset},S) => 
635                     let
636                         val {shape=_, blocks=bs} = fromGenerator shape (f,fshape,SOME offset)
637                         val b': block = case bs of
638                                             [b] => b
639                                           | _ => raise Match
640                     in
641                         insertBlock (S,b',offset)
642                 end)
643                 (fromGenerator shape (f,fshape,SOME offset)) rst)
644            | _ => raise Match
645
646
647    fun insert (x: matrix, y: matrix) =
648        let
649            val {shape=_, blocks=bs} = x
650        in
651            foldl (fn (b,S) => 
652                      let
653                          val offset = (case b of 
654                                            SPARSE {offset, shape, nz, data} =>  offset
655                                          | DENSE {offset, data} => offset)
656                      in
657                          insertBlock (S,b,offset)
658                      end) y bs
659        end
660
661
662
663    (* --- ACCESSORS --- *)
664
665    fun shape {shape, blocks} = shape
666
667    fun sub ({shape, blocks},index) =
668        let
669            val (i,j) = dimVals index
670            val block = findBlock (i,j,blocks)
671        in
672            case block of
673                SOME (b) => 
674                (case b of 
675                     SPARSE {offset, shape, nz, data} => 
676                     (let 
677                         val (m,n) = dimVals offset
678                         val p = Index.toInt shape nz [i-m,j-n]
679                       in
680                           case p of SOME p' => Array.sub (data, p')
681                                   | NONE => Number.zero
682                       end)
683                     | DENSE {offset, data} => 
684                     (let 
685                         val (m,n) = dimVals offset
686                       in
687                           Tensor.sub (data,[i+m,j+n])
688                       end)
689                )
690              | NONE => Number.zero
691        end
692
693    fun update ({shape,blocks},index,new) =
694        let
695            val (i,j) = dimVals index
696            val block = findBlock (i,j,blocks)
697        in
698            case block of
699                SOME (b) => 
700                (case b of 
701                     SPARSE {offset, shape, nz, data} => 
702                     (let
703                         val (m,n) = dimVals shape
704                         val p     = Index.toInt shape nz [i-m,j-n]
705                     in
706                         case p of SOME p' => Array.update (data, p', new) | NONE => ()
707                     end)
708                     | DENSE {offset, data} =>
709                       (let 
710                         val (m,n) = dimVals offset
711                       in
712                           Tensor.update (data,[i+m,j+n],new)
713                       end)
714                )
715              | NONE => ()
716        end
717
718    fun findBlocks (i,axis,blocks) =
719        let
720            val blocks' = List.mapPartial
721                            (fn (b ) =>
722                                let
723                                    val (u,v) = case b of 
724                                                    SPARSE {offset=offset, shape=shape, nz, data} => 
725                                                      dimVals offset
726                                                    | DENSE {offset, data} =>
727                                                      dimVals offset
728                                    val (t,s) = case b of 
729                                                    SPARSE {offset=offset, shape=shape, nz, data} => 
730                                                    dimVals shape
731                                                    | DENSE {offset, data} =>
732                                                      dimVals (Tensor.shape data)
733                                                           
734                                in
735                                    (case axis of
736                                         1 => if ((i>=v) andalso (i-v<s)) then SOME b else NONE
737                                       | 0 => if ((i>=u) andalso (i-u<t)) then SOME b else NONE
738                                       | _ => raise Match)
739                                end)
740                            blocks
741        in
742            blocks'
743        end
744
745    fun slice ({shape,blocks},axis,i) =
746        let
747            val (m,n) = dimVals shape
748
749            val _ = case axis of
750                        0 => if (i > m) then raise Index else ()
751                      | 1 => if (i > n) then raise Index else ()
752                      | _ => raise Data
753                                                             
754        in
755            List.mapPartial
756            (fn (SPARSE {offset=offset, shape=shape, nz={indptr, indices}, data})  =>
757                let 
758                    val (u,v) = dimVals offset
759                    val (m,n) = dimVals shape
760
761                    val i'  = case axis of 1 => i-v | 0 => i-u | _ => raise Match
762                in
763                (case (Index.order,axis) of
764                     (Index.CSC,1) => (let 
765                                           val s   = IntArray.sub (indptr, i')
766                                           val e   = (if i' < n-1
767                                                      then IntArray.sub (indptr, i'+1) else Array.length data)
768                                           val len = e-s
769                                           val res = RNumberArray.array (len, Number.zero)
770                                           val rsi = IntArray.array (len, 0)
771                                           fun loop (i,n) = if i < e
772                                                            then (RNumberArray.update (res,n,Array.sub (data,i));
773                                                                  IntArray.update (rsi,i-s,Index.sub (indices,i));
774                                                                  loop (i+1,n+1))
775                                                            else ()
776                                       in
777                                           loop (s,0);
778                                           if len > 0 
779                                           then SOME (SLSPARSE {data=Tensor.fromArray ([1,len],res),
780                                                                indices=rsi,
781                                                                offset=offset})
782                                           else NONE
783                                       end)
784                   | (Index.CSR,0) => (let val s   = IntArray.sub (indptr, i')
785                                           val e   = (if i'< (m-1) 
786                                                      then IntArray.sub (indptr, i'+1) else Array.length data)
787                                           val len = e-s
788                                           val res = RNumberArray.array (len, Number.zero)
789                                           val rsi = IntArray.array (len, 0)
790                                           fun loop (i,n) = if i < e
791                                                            then (RNumberArray.update (res,n,Array.sub (data,i));
792                                                                  IntArray.update (rsi,i-s,Index.sub (indices,i));
793                                                                  loop (i+1,n+1))
794                                                            else ()
795                                       in
796                                           loop (s,0);
797                                           if len > 0 
798                                           then SOME (SLSPARSE {data=Tensor.fromArray ([1,len],res),
799                                                                indices=rsi,
800                                                                offset=offset})
801                                           else NONE
802                                       end)
803                   | (Index.CSC,0) => (let val vs = IntArray.foldri
804                                                        (fn (n,ii,ax) =>  if ii=i then (Array.sub(data,n),ii)::ax else ax)
805                                                        [] indices
806                                           val len = List.length vs
807                                       in
808                                           if len > 0 
809                                           then SOME (SLSPARSE {data=Tensor.fromList ([1,len],map #1 vs),
810                                                                indices=IntArray.fromList (map #2 vs),
811                                                                offset=offset})
812                                           else NONE
813                                       end)
814                   | (Index.CSR,1) => (let val vs = IntArray.foldri
815                                                        (fn (n,ii,ax) =>  if ii=i then (Array.sub(data,n),ii)::ax else ax)
816                                                        [] indices
817                                           val len = List.length vs
818                                       in
819                                           if len > 0 
820                                           then SOME (SLSPARSE {data=Tensor.fromList ([1,len],map #1 vs),
821                                                                indices=IntArray.fromList (map #2 vs),
822                                                                offset=offset})
823                                           else NONE
824                                       end)
825                   | (_,_) => raise Index)
826                end
827            | (DENSE {offset=offset, data})  =>
828              let 
829                    val (u,v) = dimVals offset
830                    val (m,n) = dimVals (Tensor.shape data)
831
832                    val i'  = case axis of 1 => i-v | 0 => i-u | _ => raise Match
833                    val sl  = case axis of 
834                                  1 => TensorSlice.fromto ([0,i'],[m-1,i'],data)
835                                | 0 => TensorSlice.fromto ([i',0],[i',n-1],data)
836                                | _ => raise Match
837              in
838                  SOME (SLDENSE {data=sl, offset=offset})
839              end
840            )
841            (findBlocks (i,axis,blocks) )
842           
843                                   
844        end
845
846
847    (* --- MAPPING --- *)
848
849    fun map f {shape, blocks} =
850        {shape=shape,
851         blocks=(List.map
852                     (fn (SPARSE {offset, shape, nz, data}) =>
853                         (SPARSE {offset=offset, shape=shape, nz=nz, data=array_map f data})
854                     |  (DENSE {offset, data}) =>
855                         (DENSE {data=(Tensor.map f data), offset=offset}))
856                     blocks)}
857
858    fun app f {shape, blocks} = 
859        List.app (fn (SPARSE {offset, shape, nz, data}) => 
860                     Array.app f data
861                 | (DENSE {offset, data}) => 
862                   Tensor.app f data)
863                 blocks
864
865    fun sliceAppi f sl =
866        List.app
867            (fn (SLSPARSE {data=sl,indices=si,offset}) => 
868                let val (m,n) = dimVals offset
869                in
870                    (RTensor.foldl
871                         (fn (x,i) => 
872                             let
873                                 val i' = Index.sub (si,i)+m
874                             in
875                                 (f (i',x); i+1)
876                             end) 0 sl; ())
877                end
878            | (SLDENSE {data=sl,offset}) => 
879              let val (m,n) = dimVals offset
880              in
881                  (RTensorSlice.foldl
882                       (fn (x,i) => 
883                           let
884                               val i' = i+m
885                           in
886                               (f (i',x); i+1)
887                           end) 0 sl; ())
888              end)
889            sl 
890
891    fun sliceFoldi f init sl =
892        List.foldl
893            (fn (SLSPARSE {data=sl,indices=si,offset},ax) => 
894                let val (m,n) = dimVals offset
895                in
896                    #2 (RTensor.foldl
897                            (fn (x,(i,ax)) => 
898                                let
899                                    val i' = Index.sub (si,i)+m
900                                in
901                                    (i+1, f (i',x,ax))
902                                end) (0,ax) sl)
903                end
904            | (SLDENSE {data=sl,offset},ax) => 
905              let val (m,n) = dimVals offset
906              in
907                  #2 (RTensorSlice.foldl
908                          (fn (x,(i,ax)) => 
909                              let
910                                  val i' = i+m
911                              in
912                                  (i+1, f (i',x,ax))
913                              end) (0,ax) sl)
914              end)
915            init sl 
916
917
918    (* --- BINOPS --- *)
919(*
920    fun a + b = map2 Number.+ a b
921    fun a * b = map2 Number.* a b
922    fun a - b = map2 Number.- a b
923    fun a / b = map2 Number./ a b
924    fun ~ a = map Number.~ a
925*)
926end
927
928
Note: See TracBrowser for help on using the repository browser.