""" This file contains different utility functions that are not connected
in any way to the networks presented in the tutorials, but rather help in
processing the outputs into a more understandable form.

For example ``tile_raster_images`` helps in generating an easy-to-grasp
image from a set of samples or weights.
"""
8
9
10 import numpy
11
12
def scale_to_unit_interval(ndar, eps=1e-8):
    """ Scales all values in the ndarray ndar to be between 0 and 1.

    The input array is never modified; a new array is returned.
    Floating-point (and complex) inputs keep their dtype. Integer and
    boolean inputs are promoted to float64, because the in-place float
    scaling below cannot be stored back into an integer/boolean array.

    :param ndar: a numpy ndarray of any shape.
    :param eps: small constant added to the denominator so that a constant
        array maps to all zeros instead of triggering a division by zero.
    :returns: a new array shifted so its minimum is 0 and scaled so its
        maximum is (just under) 1.
    """
    if numpy.issubdtype(ndar.dtype, numpy.inexact):
        ndar = ndar.copy()
    else:
        # Integer/bool arrays: in-place float ops on them would fail.
        ndar = ndar.astype(numpy.float64)
    ndar -= ndar.min()
    ndar *= 1.0 / (ndar.max() + eps)
    return ndar
19
20
def tile_raster_images(X, img_shape, tile_shape, tile_spacing=(0, 0),
                       scale_rows_to_unit_interval=True,
                       output_pixel_vals=True):
    """
    Transform an array with one flattened image per row, into an array in
    which images are reshaped and laid out like tiles on a floor.

    This function is useful for visualizing datasets whose rows are images,
    and also columns of matrices for transforming those rows
    (such as the first layer of a neural net).

    :type X: a 2-D ndarray or a tuple of 4 channels, elements of which can
    be 2-D ndarrays or None;
    :param X: a 2-D array in which every row is a flattened image. When a
    4-tuple is given, each non-None entry is tiled separately into one
    plane of a (height, width, 4) output; None entries are filled with a
    constant: 0 for planes 0-2, and 255 (or 1.0 for float output) for
    plane 3 (presumably the alpha channel of an RGBA image).

    :type img_shape: tuple; (height, width)
    :param img_shape: the original shape of each image

    :type tile_shape: tuple; (rows, cols)
    :param tile_shape: the number of images to tile (rows, cols)

    :type tile_spacing: tuple; (rows, cols)
    :param tile_spacing: number of blank pixels left between vertically
    and horizontally neighbouring tiles

    :param scale_rows_to_unit_interval: if each image should be scaled to
    [0, 1] before being plotted or not

    :param output_pixel_vals: if output should be pixel values (i.e. uint8
    values in [0, 255]) or floats

    :returns: array suitable for viewing as an image.
    (See:`PIL.Image.fromarray`.)
    :rtype: a 2-D array, or a 3-D array with a trailing axis of size 4 when
    X is a tuple. Its dtype is uint8 if output_pixel_vals is true;
    otherwise it is the dtype of X (for a tuple, the common dtype of the
    non-None channels, or float64 if every channel is None).
    """

    assert len(img_shape) == 2
    assert len(tile_shape) == 2
    assert len(tile_spacing) == 2

    # Every tile occupies its image size plus one gap, except that no gap
    # follows the last tile along each axis (hence the trailing "- tsp").
    out_shape = [(ishp + tsp) * tshp - tsp for ishp, tshp, tsp
                 in zip(img_shape, tile_shape, tile_spacing)]

    if isinstance(X, tuple):
        assert len(X) == 4

        if output_pixel_vals:
            out_dtype = 'uint8'
            channel_defaults = [0, 0, 0, 255]
        else:
            # A tuple has no ``.dtype``; use the common dtype of the
            # channels that are actually present.
            present = [channel.dtype for channel in X if channel is not None]
            out_dtype = (numpy.result_type(*present) if present
                         else numpy.float64)
            channel_defaults = [0., 0., 0., 1.]

        out_array = numpy.zeros((out_shape[0], out_shape[1], 4),
                                dtype=out_dtype)

        for i, channel in enumerate(X):
            if channel is None:
                # Missing channel: fill the whole plane with its default.
                out_array[:, :, i] = channel_defaults[i]
            else:
                # Tile this channel on its own and store it in plane i.
                out_array[:, :, i] = tile_raster_images(
                    channel, img_shape, tile_shape, tile_spacing,
                    scale_rows_to_unit_interval, output_pixel_vals)
        return out_array

    H, W = img_shape
    Hs, Ws = tile_spacing
    n_tile_rows, n_tile_cols = tile_shape

    out_dtype = 'uint8' if output_pixel_vals else X.dtype
    out_array = numpy.zeros(out_shape, dtype=out_dtype)
    # Pixel output maps the [0, 1] range onto [0, 255].
    c = 255 if output_pixel_vals else 1

    for tile_row in range(n_tile_rows):
        for tile_col in range(n_tile_cols):
            idx = tile_row * n_tile_cols + tile_col
            if idx >= X.shape[0]:
                # Fewer images than tiles: leave the remaining slots blank.
                continue
            this_img = X[idx].reshape(img_shape)
            if scale_rows_to_unit_interval:
                # Each image is scaled on its own, not the whole batch.
                this_img = scale_to_unit_interval(this_img)
            top = tile_row * (H + Hs)
            left = tile_col * (W + Ws)
            out_array[top:top + H, left:left + W] = this_img * c
    return out_array
136
137

# NOTE(review): the ``def`` line that should introduce the two statements
# below is missing from this file, so this function's name and signature
# cannot be recovered here. ``X`` is presumably its (only) parameter.
# Restore the header before relying on this code.
#
# Flattens a 4-D array X of shape (K, I, H, W) into a 2-D array of shape
# (I * K, H * W). swapaxes reorders the axes to (I, K, H, W), so output
# row i * K + k holds X[k, i] flattened. That gives one flattened image per
# row, i.e. the layout ``tile_raster_images`` expects (TODO confirm that is
# the intended consumer).
K, I, H, W = X.shape




# The intermediate reshape to (I, K, -1) does not change the result; the
# final reshape alone would yield the same (I * K, H * W) array.
return numpy.swapaxes( X, 0, 1 ).reshape((I, K, -1)).reshape(I*K, -1)
146