Changeset 28666 in project


Ignore:
Timestamp:
04/12/13 06:49:11 (7 years ago)
Author:
Ivan Raikov
Message:

9ML-toolkit: bug fixes to indexing in tensor SML library

File:
1 edited

Legend:

Unmodified
Added
Removed
  • release/4/9ML-toolkit/trunk/examples/tensor.sml

    r28665 r28666  
    210210        val app : t -> (t -> unit) -> unit
    211211    end
     212
    212213structure Index : INDEX =
    213214    struct
    214         type t = int list
    215         type indexer = t -> int
    216         datatype storage = RowMajor | ColumnMajor
    217 
    218         exception Index
    219         exception Shape
    220 
    221         val order = ColumnMajor
    222 
    223         fun validShape shape = List.all (fn x => x > 0) shape
    224 
    225         fun validIndex index = List.all (fn x => x >= 0) index
    226 
    227         fun toInt shape index =
    228             let fun loop ([], [], accum, _) = accum
    229                   | loop ([], _, _, _) = raise Index
    230                   | loop (_, [], _, _) = raise Index
    231                   | loop (i::ri, l::rl, accum, fac) =
    232                 if (i >= 0) andalso (i < l) then
    233                     loop (ri, rl, i*fac + accum, fac*l)
    234                 else
    235                     raise Index
    236             in loop (index, shape, 0, 1)
    237             end
    238 
    239         (* ----- CACHED LINEAR INDEXER -----
    240 
    241            An indexer is a function that takes a list of
    242            indices, validates it and produces a nonnegative
    243            integer number. In short, the indexer is the
    244            mapper from indices to element positions in
    245            arrays.
    246 
    247            'indexer' builds such a mapper by optimizing
    248            the most common cases, which are 1d and 2d
    249            tensors.
    250         *)
     215        type t = int list
     216        type indexer = t -> int
     217        datatype storage = RowMajor | ColumnMajor
     218
     219        exception Index
     220        exception Shape
     221
     222        val order = ColumnMajor
     223
     224        fun validShape shape = List.all (fn x => x > 0) shape
     225
     226        fun validIndex index = List.all (fn x => x >= 0) index
     227
     228        fun toInt shape index =
     229            let fun loop ([], [], accum, p) = accum
     230                  | loop ([], _, _, _) = raise Index
     231                  | loop (_, [], _, _) = raise Index
     232                  | loop (i::ri, l::rl, accum, p) =
     233                if (i >= 0) andalso (i < l) then
     234                    loop (ri, rl, i * p + accum, p * l)
     235                else
     236                    raise Index
     237            in loop (index, shape, 0, 1)
     238            end
     239
     240        (* ----- CACHED LINEAR INDEXER -----
     241
     242           An indexer is a function that takes a list of
     243           indices, validates it and produces a nonnegative
     244           integer number. In short, the indexer is the
     245           mapper from indices to element positions in
     246           arrays.
     247
     248           'indexer' builds such a mapper by optimizing
     249           the most common cases, which are 1d and 2d
     250           tensors.
     251        *)
    251252    local
    252         fun doindexer [] _ = raise Shape
    253           | doindexer [a] [dx] =
    254             let fun f [x] = if (x > 0) andalso (x < a)
    255                             then x
    256                             else raise Index
    257                   | f _ = raise Index
    258             in f end
    259           | doindexer [a,b] [dx, dy] =
    260             let fun f [x,y] = if ((x > 0) andalso (x < a) andalso
    261                                   (y > 0) andalso (y < b))
    262                               then x + dy * y
    263                               else raise Index
    264                   | f _ = raise Index
    265             in f end
    266           | doindexer [a,b,c] [dx,dy,dz] =
    267             let fun f [x,y,z] = if ((x > 0) andalso (x < a) andalso
    268                                     (y > 0) andalso (y < b) andalso
    269                                     (z > 0) andalso (z < c))
    270                                 then x + dy * y + dz * z
    271                                 else raise Index
    272                   | f _ = raise Index
    273             in f end
    274           | doindexer shape memo =
    275             let fun f [] [] accum [] = accum
    276                   | f _  _  _ [] = raise Index
    277                   | f (fac::rf) (ndx::ri) accum (dim::rd) =
    278                     if (ndx >= 0) andalso (ndx < dim) then
    279                         f rf ri (accum + ndx * fac) rd
    280                     else
    281                         raise Index
    282             in f shape memo 0
    283             end
     253        fun doindexer [] _ = raise Shape
     254          | doindexer [a] [dx] =
     255            let fun f [x] = if (x > 0) andalso (x < a)
     256                            then x
     257                            else raise Index
     258                  | f _ = raise Index
     259            in f end
     260          | doindexer [a,b] [dx, dy] =
     261                let fun f [x,y] = if ((x >= 0) andalso (x < a) andalso
     262                                      (y >= 0) andalso (y < b))
     263                                  then x + dy * y
     264                                  else raise Index
     265                      | f _ = raise Index
     266                in f end
     267          | doindexer [a,b,c] [dx,dy,dz] =
     268                let fun f [x,y,z] = if ((x >= 0) andalso (x < a) andalso
     269                                        (y >= 0) andalso (y < b) andalso
     270                                        (z >= 0) andalso (z < c))
     271                                    then x + dy * y + dz * z
     272                                    else raise Index
     273                      | f _ = raise Index
     274                in f end
     275          | doindexer shape memo =
     276                let fun f [] [] accum [] = accum
     277                      | f _  _  _ [] = raise Index
     278                      | f (fact::rf) (ndx::ri) accum (dim::rd) =
     279                        if (ndx >= 0) andalso (ndx < dim) then
     280                            f rf ri (accum + ndx * fact) rd
     281                        else
     282                            raise Index
     283                in f shape memo 0
     284                end
    284285    in
    285         fun indexer shape =
    286             let fun memoize accum [] = []
    287                   | memoize accum (dim::rd) =
    288                     accum :: (memoize (dim * accum) rd)
    289             in
    290                 if validShape shape
    291                 then doindexer shape (memoize 1 shape)
    292                 else raise Shape
    293             end
    294     end
    295 
    296         fun length shape =
    297             let fun prod (a,b) =
    298                 if b < 0 then raise Shape else a * b
    299             in foldl prod 1 shape
    300             end
    301 
    302         fun first shape = map (fn x => 0) shape
    303 
    304         fun last [] = []
    305           | last (size :: rest) =
    306             if size < 1
    307             then raise Shape
    308             else size - 1 :: last rest
    309 
    310         fun next' [] [] = raise Subscript
    311           | next' _ [] = raise Index
    312           | next' [] _ = raise Index
    313           | next' (dimension::restd) (index::resti) =
    314             if (index + 1) < dimension
    315             then (index + 1) :: resti
    316             else 0 :: (next' restd resti)
    317 
    318         fun prev' [] [] = raise Subscript
    319           | prev' _ [] = raise Index
    320           | prev' [] _ = raise Index
    321           | prev' (dimension::restd) (index::resti) =
    322             if (index > 0)
    323             then index - 1 :: resti
    324             else dimension - 1 :: prev' restd resti
    325 
    326         fun next shape index = (SOME (next' shape index)) handle
    327             Subscript => NONE
    328 
    329         fun prev shape index = (SOME (prev' shape index)) handle
    330             Subscript => NONE
    331 
    332         fun inBounds shape index =
    333             ListPair.all (fn (x,y) => (x >= 0) andalso (x < y))
    334             (index, shape)
    335 
    336         fun compare ([],[]) = EQUAL
    337           | compare (_, []) = raise Index
    338           | compare ([],_) = raise Index
    339           | compare (a::ra, b::rb) =
    340             case Int.compare (a,b) of
    341                 EQUAL => compare (ra,rb)
    342               | LESS => LESS
    343               | GREATER => GREATER
     286        fun indexer shape =
     287            let fun memoize accum [] = []
     288                  | memoize accum (dim::rd) =
     289                accum :: (memoize (dim * accum) rd)
     290            in if validShape shape then
     291                   doindexer shape (memoize 1 shape)
     292               else
     293                   raise Shape
     294            end
     295    end
     296
     297        fun length shape =
     298            let fun prod (a,b) =
     299                if b < 0 then raise Shape else a * b
     300            in foldl prod 1 shape
     301            end
     302
     303        fun first shape = map (fn x => 0) shape
     304
     305        fun last [] = []
     306          | last (size :: rest) = size - 1 :: last rest
     307
     308        fun next' [] [] = raise Subscript
     309          | next' _ [] = raise Index
     310          | next' [] _ = raise Index
     311          | next' (dimension::restd) (index::resti) =
     312            if (index + 1) < dimension then
     313                (index + 1) :: resti
     314            else
     315                0 :: (next' restd resti)
     316
     317        fun prev' [] [] = raise Subscript
     318          | prev' _ [] = raise Index
     319          | prev' [] _ = raise Index
     320          | prev' (dimension::restd) (index::resti) =
     321            if (index > 0) then
     322                (index - 1) :: resti
     323            else
     324                (dimension - 1) :: (prev' restd resti)
     325
     326        fun next shape index = (SOME (next' shape index)) handle
     327            Subscript => NONE
     328
     329        fun prev shape index = (SOME (prev' shape index)) handle
     330            Subscript => NONE
     331
     332        fun inBounds shape index =
     333            ListPair.all (fn (x,y) => (x >= 0) andalso (x < y))
     334            (index, shape)
     335
     336        fun compare ([],[]) = EQUAL
     337          | compare (_, []) = raise Index
     338          | compare ([],_) = raise Index
     339          | compare (a::ra, b::rb) =
     340            case Int.compare (a,b) of
     341                EQUAL => compare (ra,rb)
     342              | LESS => LESS
     343              | GREATER => GREATER
    344344
    345345    local
    346         fun iterator a inner =
    347             let fun loop accum f =
    348                 let fun innerloop i =
    349                     if i < a
    350                     then if inner (i::accum) f
    351                          then innerloop (i+1)
    352                          else false
    353                     else true
    354                 in innerloop 0
    355                 end
    356             in loop
    357             end
    358         fun build_iterator [a] =
    359             let fun loop accum f =
    360                 let fun innerloop i =
    361                     if i < a
    362                     then if f (i::accum)
    363                          then innerloop (i+1)
    364                          else false
    365                     else true
    366                 in innerloop 0
    367                 end
    368             in loop
    369             end
    370           | build_iterator (a::rest) = iterator a (build_iterator rest)
     346        fun iterator a inner =
     347            let fun loop accum f =
     348                let fun innerloop i =
     349                    if i < a then
     350                        if inner (i::accum) f then
     351                            innerloop (i+1)
     352                        else
     353                            false
     354                    else
     355                        true
     356                in innerloop 0
     357                end
     358            in loop
     359            end
     360        fun build_iterator [a] =
     361            let fun loop accum f =
     362                let fun innerloop i =
     363                    if i < a then
     364                        if f (i::accum) then
     365                            innerloop (i+1)
     366                        else
     367                            false
     368                    else
     369                        true
     370                in innerloop 0
     371                end
     372            in loop
     373            end
     374          | build_iterator (a::rest) = iterator a (build_iterator rest)
    371375    in
    372         fun all shape = build_iterator shape []
     376        fun all shape = build_iterator shape []
    373377    end
    374378
    375379    local
    376         fun iterator a inner =
    377             let fun loop accum f =
    378                 let fun innerloop i =
    379                     if i < a
    380                     then if inner (i::accum) f
    381                          then true
    382                          else innerloop (i+1)
    383                     else false
    384                 in innerloop 0
    385                 end
    386             in loop
    387             end
    388         fun build_iterator [a] =
    389             let fun loop accum f =
    390                 let fun innerloop i =
    391                     if i < a
    392                     then if f (i::accum)
    393                          then true
    394                          else innerloop (i+1)
    395                     else false
    396                 in innerloop 0
    397                 end
    398             in loop
    399             end
    400           | build_iterator (a::rest) = iterator a (build_iterator rest)
     380        fun iterator a inner =
     381            let fun loop accum f =
     382                let fun innerloop i =
     383                    if i < a then
     384                        if inner (i::accum) f then
     385                            true
     386                        else
     387                            innerloop (i+1)
     388                    else
     389                        false
     390                in innerloop 0
     391                end
     392            in loop
     393            end
     394        fun build_iterator [a] =
     395            let fun loop accum f =
     396                let fun innerloop i =
     397                    if i < a then
     398                        if f (i::accum) then
     399                            true
     400                        else
     401                            innerloop (i+1)
     402                    else
     403                        false
     404                in innerloop 0
     405                end
     406            in loop
     407            end
     408          | build_iterator (a::rest) = iterator a (build_iterator rest)
    401409    in
    402         fun any shape = build_iterator shape []
     410        fun any shape = build_iterator shape []
    403411    end
    404412
    405413    local
    406         fun iterator a inner =
    407             let fun loop accum f =
    408                 let fun innerloop i =
    409                     if i < a
    410                     then (inner (i::accum) f;
    411                           innerloop (i+1))
    412                     else ()
    413                 in innerloop 0
    414                 end
    415             in loop
    416             end
    417         fun build_iterator [a] =
    418             let fun loop accum f =
    419                 let fun innerloop i =
    420                     if i < a
    421                     then (f (i::accum); innerloop (i+1))
    422                     else ()
    423                 in innerloop 0
    424                 end
    425             in loop
    426             end
    427           | build_iterator (a::rest) = iterator a (build_iterator rest)
     414        fun iterator a inner =
     415            let fun loop accum f =
     416                let fun innerloop i =
     417                    case i < a of
     418                        true => (inner (i::accum) f; innerloop (i+1))
     419                      | false => ()
     420                in innerloop 0
     421                end
     422            in loop
     423            end
     424        fun build_iterator [a] =
     425            let fun loop accum f =
     426                let fun innerloop i =
     427                    case i < a of
     428                        true => (f (i::accum); innerloop (i+1))
     429                      | false => ()
     430                in innerloop 0
     431                end
     432            in loop
     433            end
     434          | build_iterator (a::rest) = iterator a (build_iterator rest)
    428435    in
    429         fun app shape = build_iterator shape []
    430     end
    431 
    432         fun a < b = compare(a,b) = LESS
    433         fun a > b = compare(a,b) = GREATER
    434         fun eq (a, b) = compare(a,b) = EQUAL
    435         fun a <> b = not (a = b)
    436         fun a <= b = not (a > b)
    437         fun a >= b = not (a < b)
    438         fun a - b = ListPair.map Int.- (a,b)
     436        fun app shape = build_iterator shape []
     437    end
     438
     439        fun a < b = compare(a,b) = LESS
     440        fun a > b = compare(a,b) = GREATER
     441        fun eq (a, b) = compare(a,b) = EQUAL
     442        fun a <> b = not (a = b)
     443        fun a <= b = not (a > b)
     444        fun a >= b = not (a < b)
     445        fun a - b = ListPair.map Int.- (a,b)
    439446
    440447    end
     
    463470        val prev : t -> index -> index option
    464471        val ranges : index -> ((index * index) list) -> t
     472
    465473        val iteri : (index -> bool) -> t -> bool
     474        val iteri2 : (index * index -> bool) -> (t * t) -> bool
     475
     476       
    466477    end
    467478
     
    525536            let fun loop (ndx: index) (g: index -> bool) =
    526537                let fun innerloop i =
    527                     (if i > last then
     538                     (if i > last then
    528539                        true
    529540                    else if (f (i::ndx) g) then
     
    541552            nested_loop (build_iterator ra rb) a b
    542553
     554        fun simple_loop2 (first : int) (last : int) (first' : int) (last' : int) =
     555            let fun loop (ndx : index) (ndx' : index) (g: index * index -> bool) =
     556                let fun innerloop (i,j) =
     557                    if i > last andalso j > last' then
     558                        true
     559                    else if g (i::ndx,j::ndx') then
     560                        innerloop (i+1,j+1)
     561                    else
     562                        false
     563                in innerloop (first,first') end
     564            in loop end
     565
     566        fun nested_loop2 f (first : int) (last : int) (first' : int) (last' : int) =
     567            let fun loop (ndx: index) (ndx': index) (g: index * index -> bool) =
     568                let fun innerloop (i,j) =
     569                    (if i > last andalso j > last' then
     570                        true
     571                    else if (f (i::ndx) (j::ndx') g) then
     572                        innerloop (i+1,j+1)
     573                    else
     574                        false)
     575                in
     576                     innerloop (first,first')
     577                end
     578            in loop end
     579
     580        fun build_iterator2 ([a] : index) ([b] : index) ([a'] : index) ([b'] : index) =
     581            simple_loop2 a b a' b'
     582          | build_iterator2 (a::ra) (b::rb) (a'::ra') (b'::rb') =
     583            nested_loop2 (build_iterator2 ra rb ra' rb') a b a' b'
     584
    543585    in
    544586
     
    565607          | range_append ((lo,up),[]) = [(lo,up)]
    566608
     609
     610        fun listWrite converter file x =
     611        (List.app (fn x => (TextIO.output(file, "," ^ (converter x)))) x)
     612
     613        fun intListWrite file x = listWrite Int.toString file x
     614
    567615        fun ranges' shape ((lo, up) :: rest) =
    568             if (Index.validShape shape) andalso (Index.validIndex lo) andalso (Index.validIndex up) andalso
     616            (if (Index.validShape shape) andalso (Index.validIndex lo) andalso (Index.validIndex up) andalso
    569617               (Index.inBounds shape lo) andalso (Index.inBounds shape up) andalso Index.< (lo,up)
    570                then range_append ((lo,up), (ranges' shape rest)) else (ranges' shape rest)
     618               then range_append ((lo,up), (ranges' shape rest)) else (ranges' shape rest))
    571619            | ranges' shape ([]) = []
    572620
     
    660708        fun iteri f RangeEmpty = f []
    661709          | iteri (f: index -> bool) (RangeIn(shape,lo: index,up: index)) =
    662            (build_iterator lo up) lo f
     710           (build_iterator lo up) [] f
    663711          | iteri (f: index -> bool) (RangeSet(shape,set)) =
    664            List.all (fn (lo,up) => (build_iterator lo up) lo f) set
     712           List.all (fn (lo,up) => (build_iterator lo up) [] f) set
     713
     714        (* Builds an interator that applies 'f' sequentially to
     715           all the indices of the two ranges, *)
     716        fun iteri2 f (RangeEmpty,RangeEmpty) = f ([],[])
     717          | iteri2 (f: index * index -> bool) (RangeIn(shape,lo: index,up: index),RangeIn(shape',lo': index,up': index)) =
     718            if shape=shape' then (build_iterator2 lo up lo' up') [] [] f else raise Range
     719          | iteri2 (f: index * index -> bool) (RangeSet(shape,set),RangeSet(shape',set')) =
     720           if shape=shape' then ListPair.all (fn ((lo,up),(lo',up')) => (build_iterator2 lo up lo' up') [] [] f) (set,set') else raise Range
    665721
    666722    end
     
    775831        val shape  : 'a slice -> (index)
    776832
     833        val app : ('a -> unit) -> 'a slice -> unit
    777834        val map : ('a -> 'b) -> 'a slice -> 'b Vector.vector
    778835        val map2 : ('a * 'b -> 'c) -> 'a slice -> 'b slice -> 'c Vector.vector
     
    9781035        type 'a slice = {range : range, shape: index, tensor : 'a tensor}
    9791036
    980 (*
    981         val slice : ((index * index) list) * 'a tensor -> 'a slice
    982         val length : 'a slice -> int
    983         val base : 'a slice -> range * 'a tensor
    984         val shape : 'a slice -> (index)
    985 *)
    9861037        fun slice (rs,tensor) =
    9871038            let val r = (Range.ranges (Tensor.shape tensor) rs)
     
    9911042                 tensor=tensor}
    9921043            end
     1044
    9931045        fun length ({range, shape, tensor}) = Range.length range
    9941046        fun base ({range, shape, tensor}) = tensor
    9951047        fun shape ({range, shape, tensor}) = Range.shape range
    9961048
     1049        fun map f slice =
     1050        let
     1051           val te  = #tensor slice
     1052           val ra  = #range slice
     1053           val fndx  = Range.first ra
     1054           val arr = Array.array(length(slice),f (Tensor.sub(te,fndx)))
     1055           val i   = ref 0
     1056        in
     1057           Range.iteri (fn (ndx) => let val v = f (Tensor.sub (te,ndx)) in (Array.update (arr, !i, v); i := (!i + 1); true) end) ra;
     1058           Array.vector arr
     1059        end
     1060
     1061        fun app f (slice: 'a slice) =
     1062        let
     1063           val te  = #tensor slice
     1064           val ra  = #range slice
     1065           val fndx  = Range.first ra
     1066        in
     1067           Range.iteri (fn (ndx) => (f (Tensor.sub (te,ndx)); true)) ra; ()
     1068        end
     1069
     1070        fun map2 f (sl1: 'a slice) (sl2: 'b slice) =
     1071        let
     1072           val _    = if not ((#shape sl1) = (#shape sl2)) then raise Index.Shape else ()
     1073           val te1  = #tensor sl1
     1074           val te2  = #tensor sl2
     1075           val ra1  = #range sl1
     1076           val ra2  = #range sl2
     1077           val fndx1  = Range.first ra1
     1078           val fndx2  = Range.first ra2
     1079           val arr   = Array.array(length(sl1),f (Tensor.sub(te1,fndx1),Tensor.sub(te2,fndx2)))
     1080           val i     = ref 0
     1081        in
     1082           Range.iteri2 (fn (ndx,ndx') => let val v = f (Tensor.sub (te1,ndx),Tensor.sub (te2,ndx')) in (Array.update (arr, !i, v); i := (!i + 1); true) end) (ra1,ra2);
     1083           Array.vector arr
     1084        end
     1085
     1086        fun foldl f init (slice: 'a slice) =
     1087        let
     1088           val te     = #tensor slice
     1089           val ra     = #range slice
     1090           val ax     = ref init
     1091        in
     1092           Range.iteri (fn (ndx) => let val ax' = f (Tensor.sub (te,ndx),!ax) in ((ax := ax'); true) end) ra;
     1093           !ax
     1094        end
    9971095
    9981096    end                               
     
    21992297            end
    22002298    end (* NumberTensor *)
     2299
    22012300structure RTensor =
    22022301    struct
     
    28512950
    28522951
     2952structure RTensorSlice =
     2953    struct
     2954        structure Tensor = RTensor
     2955        structure Index  = Tensor.Index
     2956        structure Range  = Range
     2957           
     2958        type index = Tensor.Index.t
     2959        type range = Range.t
     2960        type tensor = RTensor.tensor
     2961
     2962        type slice = {range : range, shape: index, tensor : tensor}
     2963
     2964        fun slice (rs,tensor) =
     2965            let val r = (Range.ranges (Tensor.shape tensor) rs)
     2966            in
     2967                {range=r,
     2968                 shape=(Range.shape r),
     2969                 tensor=tensor}
     2970            end
     2971
     2972        fun length ({range, shape, tensor}) = Range.length range
     2973        fun base ({range, shape, tensor}) = tensor
     2974        fun shape ({range, shape, tensor}) = Range.shape range
     2975
     2976        fun map f slice =
     2977        let
     2978           val te  = #tensor slice
     2979           val ra  = #range slice
     2980           val fndx  = Range.first ra
     2981           val arr = Array.array(length(slice),f (Tensor.sub(te,fndx)))
     2982           val i   = ref 0
     2983        in
     2984           Range.iteri (fn (ndx) => let val v = f (Tensor.sub (te,ndx)) in (Array.update (arr, !i, v); i := (!i + 1); true) end) ra;
     2985           Array.vector arr
     2986        end
     2987
     2988        fun listWrite converter file x =
     2989        (List.app (fn x => (TextIO.output(file, "," ^ (converter x)))) x)
     2990
     2991        fun intListWrite file x = listWrite Int.toString file x
     2992
     2993        fun app f (slice: slice) =
     2994        let
     2995           val te  = #tensor slice
     2996           val ra  = #range slice
     2997           val fndx  = Range.first ra
     2998        in
     2999           Range.iteri (fn (ndx) => (f (Tensor.sub (te,ndx)); true)) ra; ()
     3000        end
     3001
     3002        fun map2 f (sl1: slice) (sl2: slice) =
     3003        let
     3004           val _    = if not ((#shape sl1) = (#shape sl2)) then raise Index.Shape else ()
     3005           val te1  = #tensor sl1
     3006           val te2  = #tensor sl2
     3007           val ra1  = #range sl1
     3008           val ra2  = #range sl2
     3009           val fndx1  = Range.first ra1
     3010           val fndx2  = Range.first ra2
     3011           val arr   = Array.array(length(sl1),f (Tensor.sub(te1,fndx1),Tensor.sub(te2,fndx2)))
     3012           val i     = ref 0
     3013        in
     3014           Range.iteri2 (fn (ndx,ndx') => let val v = f (Tensor.sub (te1,ndx),Tensor.sub (te2,ndx')) in (Array.update (arr, !i, v); i := (!i + 1); true) end) (ra1,ra2);
     3015           Array.vector arr
     3016        end
     3017
     3018        fun foldl f init (slice: slice) =
     3019        let
     3020           val te     = #tensor slice
     3021           val ra     = #range slice
     3022           val ax     = ref init
     3023        in
     3024           Range.iteri (fn (ndx) => let val ax' = f (Tensor.sub (te,ndx),!ax) in ((ax := ax'); true) end) ra;
     3025           !ax
     3026        end
     3027
     3028
     3029    end                               
     3030
    28533031
    28543032structure TensorFile =
     
    29383116fun realTensorWrite file x = (intListWrite file (RTensor.shape x); RTensor.app (fn x => (realWrite file x)) x)
    29393117fun complexTensorWrite file x = (intListWrite file (CTensor.shape x); CTensor.app (fn x => (complexWrite file x)) x)
     3118
     3119fun realTensorSliceWrite file x = (intListWrite file (RTensorSlice.shape x); RTensorSlice.app (fn x => (realWrite file x)) x)
     3120
    29403121end
    29413122
     
    29703151
    29713152
    2972 val Ne = 8
    2973 val Ni = 2
    2974 
    2975 
    2976 val S  = RTensor.cat (1,
    2977                       (RTensor.*> 0.5 (RandomTensor.realRandomTensor (13,17) [(Ne+Ni),Ne]) ),
    2978                       (RTensor.~ (RandomTensor.realRandomTensor (19,23) [(Ne+Ni),Ni])))
    2979 
     3153val Ne = 2
     3154val Ni = 1
     3155
     3156val S0  = RTensor.fromList ([2,2],[1.0,2.0,3.0,4.0])
     3157val _ = (print "S0 = "; TensorFile.realTensorWrite (TextIO.stdOut) S0)
     3158val v = RTensor.sub (S0,[0,0])
     3159val _ = (print "v = "; TensorFile.realWrite (TextIO.stdOut) v)
     3160val v = RTensor.sub (S0,[0,1])
     3161val _ = (print "v = "; TensorFile.realWrite (TextIO.stdOut) v)
     3162val v = RTensor.sub (S0,[1,0])
     3163val _ = (print "v = "; TensorFile.realWrite (TextIO.stdOut) v)
     3164val v = RTensor.sub (S0,[1,1])
     3165val _ = (print "v = "; TensorFile.realWrite (TextIO.stdOut) v)
     3166
     3167val SNe = (RTensor.*> 0.5 (RandomTensor.realRandomTensor (13,17) [(Ne+Ni),Ne]) )
     3168val SNi = (RTensor.~ (RandomTensor.realRandomTensor (19,23) [(Ne+Ni),Ni]))
     3169
     3170val _ = TensorFile.realTensorWrite (TextIO.stdOut) SNe
     3171
     3172val v = RTensor.sub (SNe,[0,0])
     3173val _ = (print "v = "; TensorFile.realWrite (TextIO.stdOut) v)
     3174val v = RTensor.sub (SNe,[0,1])
     3175val _ = (print "v = "; TensorFile.realWrite (TextIO.stdOut) v)
     3176val v = RTensor.sub (SNe,[1,0])
     3177val _ = (print "v = "; TensorFile.realWrite (TextIO.stdOut) v)
     3178val v = RTensor.sub (SNe,[1,1])
     3179val _ = (print "v = "; TensorFile.realWrite (TextIO.stdOut) v)
     3180val v = RTensor.sub (SNe,[2,0])
     3181val _ = (print "v = "; TensorFile.realWrite (TextIO.stdOut) v)
     3182val v = RTensor.sub (SNe,[2,1])
     3183val _ = (print "v = "; TensorFile.realWrite (TextIO.stdOut) v)
     3184
     3185val S  = RTensor.cat (1, SNe, SNi)
    29803186val _ = TensorFile.realTensorWrite (TextIO.stdOut) S
    29813187
    2982 
     3188val v = RTensor.sub (S,[0,0])
     3189val _ = (print "v = "; TensorFile.realWrite (TextIO.stdOut) v)
     3190val v = RTensor.sub (S,[0,1])
     3191val _ = (print "v = "; TensorFile.realWrite (TextIO.stdOut) v)
     3192val v = RTensor.sub (S,[0,2])
     3193val _ = (print "v = "; TensorFile.realWrite (TextIO.stdOut) v)
     3194val v = RTensor.sub (S,[1,0])
     3195val _ = (print "v = "; TensorFile.realWrite (TextIO.stdOut) v)
     3196val v = RTensor.sub (S,[1,1])
     3197val _ = (print "v = "; TensorFile.realWrite (TextIO.stdOut) v)
     3198val v = RTensor.sub (S,[1,2])
     3199val _ = (print "v = "; TensorFile.realWrite (TextIO.stdOut) v)
     3200val v = RTensor.sub (S,[2,0])
     3201val _ = (print "v = "; TensorFile.realWrite (TextIO.stdOut) v)
     3202val v = RTensor.sub (S,[2,1])
     3203val _ = (print "v = "; TensorFile.realWrite (TextIO.stdOut) v)
     3204val v = RTensor.sub (S,[2,2])
     3205val _ = (print "v = "; TensorFile.realWrite (TextIO.stdOut) v)
     3206
     3207val S1 = RTensorSlice.slice ([([0,0],[2,0])],S)
     3208val S2 = RTensorSlice.slice ([([0,1],[2,1])],S)
     3209val S3 = RTensorSlice.slice ([([0,2],[2,2])],S)
     3210
     3211val _ = TensorFile.realTensorSliceWrite (TextIO.stdOut) S1
     3212val _ = TensorFile.realTensorSliceWrite (TextIO.stdOut) S2
     3213val _ = TensorFile.realTensorSliceWrite (TextIO.stdOut) S3
     3214
     3215
Note: See TracChangeset for help on using the changeset viewer.