By this ranking, fuzzy or smudgy images are the outliers and images with cleaner lines are central. Or perhaps it's simply that 4s are central. Here's the code:
# working with the MNIST data set
#
# (c) Tariq Rashid, 2016
# license is GPLv2
import numpy as np
import matplotlib.pyplot as py
%matplotlib inline
# Min-max rescaling helper. NumPy has no single function for this
# (scikit-learn's minmax_scale is the usual library equivalent), so define one.
def normalize(v):
    """Linearly rescale the 1-D sequence *v* so its minimum maps to 0.0 and
    its maximum maps to 1.0.

    Parameters
    ----------
    v : list or 1-D numpy.ndarray of numbers

    Returns
    -------
    The rescaled values. Lists and floating-point arrays are updated in place
    and the same object is returned. A NumPy array with a non-floating dtype
    cannot hold the fractional results, so a rescaled float copy is returned
    and the input array is left untouched.

    Empty input is returned unchanged. If every element is equal (zero
    range), every element becomes 0.0 instead of dividing by zero.
    """
    if isinstance(v, np.ndarray) and not np.issubdtype(v.dtype, np.floating):
        # Assigning floats back into an integer/bool array would truncate
        # them, so rescale a float copy instead.
        v = v.astype(float)
    if len(v) == 0:
        return v
    # Use the true extremes; fixed sentinel start values are wrong for any
    # data lying beyond them.
    lo = min(v)
    hi = max(v)
    span = hi - lo
    # Zero range: treat the span as 1 so every value maps to 0.0 (the same
    # convention scikit-learn's MinMaxScaler uses for constant features).
    scale = 1.0 / span if span != 0 else 1.0
    for i, x in enumerate(v):
        v[i] = (x - lo) * scale
    return v
# Open the CSV file and read its lines into a list; the `with` block closes
# the file even if reading raises.
with open("mnist_dataset/mnist_train_100.csv", 'r') as data_file:
    data_list = data_file.readlines()
rows = len(data_list)

# One flattened 28x28 image per CSV line. Column 0 is skipped (in the MNIST
# CSV format it holds the digit label). The remaining values are scaled by
# 0.99/255 and offset by 0.01, so 0..255 maps onto 0.01..1.0.
image_mat = np.zeros((rows, 28 * 28))
for row, line in enumerate(data_list):
    all_values = line.split(',')
    # np.asfarray was removed in NumPy 2.0; asarray(..., dtype=float) is its
    # documented replacement and converts the numeric strings the same way.
    image_vector = np.asarray(all_values[1:], dtype=float)
    image_mat[row] = (image_vector / 255.0 * 0.99) + 0.01
# Reduced SVD: with full_matrices=False, u is (rows, k) and s is (k,) where
# k = min(rows, 784). That makes u.dot(s) well-formed for any number of
# images, and it skips computing the unused 784x784 right-singular matrix.
# For rows <= 784, u has the same shape as in the full decomposition.
(u, s, v) = np.linalg.svd(image_mat, full_matrices=False)
# One scalar score per image: its left-singular coordinates weighted by the
# singular values, rescaled onto [0, 1].
# NOTE(review): singular vectors are only determined up to sign, so the
# direction of this ranking may differ between LAPACK builds.
row_features = normalize(u.dot(s))
# Uncomment to inspect the distribution of the scores:
# py.plot(np.sort(row_features))
keys = np.argsort(row_features)  # image indices, lowest score first
# Lay the images out in score order, `grid` thumbnails per row, each titled
# with its normalized score.
grid = 10
# Round up so a final partial row still gets axes (and never ask for zero
# rows). squeeze=False keeps `axes` 2-D even when there is only one row.
n_grid_rows = max(1, -(-rows // grid))
fig, axes = py.subplots(nrows=n_grid_rows, ncols=grid, squeeze=False)
fig.set_figheight(15)
fig.set_figwidth(15)
py.subplots_adjust(top=1.1)
for row in range(rows):
    ax = axes[row // grid][row % grid]
    ax.set_title("{0:.2f}".format(row_features[keys[row]]), fontsize=10)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.imshow(image_mat[keys[row]].reshape(28, 28), cmap='Greys', interpolation='None')
# Hide any empty slots left over in the last row.
for ax in axes.ravel()[rows:]:
    ax.set_axis_off()
fig.savefig('foo.png', bbox_inches='tight')
No comments:
Post a Comment