Caffe中层参数及数据的可视化

本文将讲解如何可视化caffe网络中的层的参数及数据,即只要输入的规格为(n, height, width)或(n, height, width, 3)都可以通过如下函数可视化。

def vis_square(data):
“””Take an array of shape (n, height, width) or (n, height, width, 3)
and visualize each (height, width) thing in a grid of size approx. sqrt(n) by sqrt(n)”””

    # normalize data for display
    data = (data - data.min()) / (data.max() - data.min())

    # force the number of filters to be square
    n = int(np.ceil(np.sqrt(data.shape[0])))
    padding = (((0, n ** 2 - data.shape[0]),
               (0, 1), (0, 1))                 # add some space between filters
               + ((0, 0),) * (data.ndim - 3))  # don't pad the last dimension (if there is one)
      data = np.pad(data, padding, mode='constant', constant_values=1)  # pad with ones (white)

      # tile the filters into an image
       data = data.reshape((n, n) + data.shape[1:]).transpose((0, 2, 1, 3) + tuple(range(4, data.ndim + 1)))
       data = data.reshape((n * data.shape[1], n * data.shape[3]) + data.shape[4:])

      plt.imshow(data); plt.axis('off')

filters = net.params['conv1'][0].data
vis_square(filters.transpose(0, 2, 3, 1))`

这是官方给的例子,我们假设net.params[‘conv1’][0].data.shape为(96,3,11,11)。
filters.transpose(0, 2, 3, 1) #此时shape为(96,11,11,3)

data = (data - data.min()) / (data.max() - data.min()) #预处理,减去平均值,并除以值的变化范围

n = int(np.ceil(np.sqrt(data.shape[0]))) #n为10

padding = (((0, n * 2 - data.shape[0]),(0, 1), (0, 1))+ ((0, 0),) (data.ndim - 3)) #pad为填充函数,第一维扩展为100,第二第三维分别加1,这是为了留出图之间的空隙,第四维不变

data = np.pad(data, padding, mode=’constant’, constant_values=1) # 开始进行扩展,扩展的值为1,即全为白色,此时shape为(100,12,12,3)

data = data.reshape((n, n) + data.shape[1:]).transpose((0, 2, 1, 3) + tuple(range(4, data.ndim + 1))) #shape由(100,12,12,3)变为(10,12,10,12,3)

data = data.reshape((n data.shape[1], n data.shape[3]) + data.shape[4:]) #shape由(10,12,10,12,3)变为(120,120,3)

plt.imshow(data); plt.axis(‘off’) #显示图像

接下来是全连接层的可视化。
feat = net.blobs[‘fc6’].data[0] #获取全连接层处理前的数据

plt.subplot(2, 1, 1) #指定将要绘制的图像的位置,在这里是两行一列中的第一行第一列的位置。

plt.plot(feat.flat) #绘制图像,feat.flat是numpy中的迭代器,如下图1

plt.subplot(2, 1, 2)

plt.hist(feat.flat[feat.flat > 0], bins=100) #绘制统计直方图,如下图2,统计每个区间内数据的个数