Sunday, April 15, 2018

Fun with SVD: MNIST Digits

Singular Value Decomposition (SVD) is a powerful linear algebra operation. It factorizes a matrix into three matrices that have some interesting properties. Among other things, SVD gives you a way to rank how 'typical' or 'normal' each row (or column) is compared with the other rows (or columns). I created a 100x784 matrix of gray-scale values, where each row is a sample image and each column is one pixel position in the 28x28 raster. The following chart shows the 100 digits ranked from "outlier" to "central", with the normalized rank above each digit.


By this ranking, fuzzy or smudgy images are outliers and images with cleaner lines are central. Or, read another way, 4s are central. Here's the code:

# (variation of code from Tariq's book)
# python notebook for Make Your Own Neural Network
# working with the MNIST data set
#
# (c) Tariq Rashid, 2016
# license is GPLv2

import numpy as np
import matplotlib.pyplot as py
# Jupyter/IPython magic: remove this line when running as a plain Python script
%matplotlib inline

# strangely, does not exist in numpy: min-max scale a vector into [0, 1]
def normalize(v):
    v = np.asarray(v, dtype=float)
    return (v - v.min()) / (v.max() - v.min())

# open the CSV file and read its contents into a list
with open("mnist_dataset/mnist_train_100.csv", 'r') as data_file:
    data_list = data_file.readlines()
rows = len(data_list)
image_mat = np.zeros((rows, 28 * 28))
for row in range(rows):
    all_values = data_list[row].split(',')
    dig = int(all_values[0])  # digit label (not used for the ranking)
    # np.asfarray was removed in NumPy 2.0; asarray with dtype=float is equivalent
    image_vector = np.asarray(all_values[1:], dtype=float)
    # scale pixels from 0..255 into 0.01..1.00
    image_mat[row] = (image_vector / 255.0 * 0.99) + 0.01
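# thin SVD: u is rows x rows, s holds the singular values, vh is V transposed
(u, s, vh) = np.linalg.svd(image_mat, full_matrices=False)
# u.dot(s) sums each row of U*Sigma, giving one score per image
row_features = normalize(u.dot(s))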
# py.plot(np.sort(row_features))
keys = np.argsort(row_features)

# lay the digits out on a 10x10 grid, lowest normalized score first
grid=10
fig,axes = py.subplots(nrows=rows//grid, ncols=grid)
fig.set_figheight(15)
fig.set_figwidth(15)
py.subplots_adjust(top=1.1)
for row in range(rows):
    ax = axes[row//grid][row%grid]
    ax.set_title("{0:.2f}".format(row_features[keys[row]]), fontsize=10)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.imshow(image_mat[keys[row]].reshape(28,28), cmap='Greys', interpolation='none')
fig.savefig('foo.png', bbox_inches='tight')
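To see the three factors SVD produces, here is a minimal sketch on a small random matrix. The matrix `A` is a hypothetical stand-in for `image_mat`, not MNIST data, and `full_matrices=False` requests the thin SVD. The assertions check the properties that make the factors useful: they multiply back to the original matrix, `U` and `V` have orthonormal columns, and the singular values are non-negative and sorted in decreasing order.

```python
import numpy as np

# Hypothetical stand-in for the 100x784 image matrix: 6 "images" of 10 "pixels".
rng = np.random.default_rng(0)
A = rng.random((6, 10))

# Thin SVD: U is 6x6, s holds 6 singular values, Vt (V transposed) is 6x10.
U, s, Vt = np.linalg.svd(A, full_matrices=False)

# The three factors multiply back to the original matrix.
assert np.allclose(A, U @ np.diag(s) @ Vt)

# U has orthonormal columns and Vt has orthonormal rows.
assert np.allclose(U.T @ U, np.eye(6))
assert np.allclose(Vt @ Vt.T, np.eye(6))

# Singular values are non-negative and sorted largest first.
assert np.all(s >= 0) and np.all(np.diff(s) <= 0)
print(U.shape, s.shape, Vt.shape)
```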

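The ranking step can also be tried without the MNIST CSV. This sketch is a stand-in under stated assumptions: it uses random data instead of real digits, a vectorized `normalize`, and `full_matrices=False` (which yields the same `u` and `s` here). It checks that `u @ s` equals summing each image's projections onto the right singular vectors. It also shows that the score depends on the arbitrary sign of each singular vector pair, so which end of the chart counts as "outlier" could plausibly differ between NumPy/LAPACK builds.

```python
import numpy as np

def normalize(v):
    """Min-max scale a 1-D array into [0, 1] (returns a new array)."""
    v = np.asarray(v, dtype=float)
    return (v - v.min()) / (v.max() - v.min())

# Hypothetical stand-in data: 100 "images" of 784 pixels in [0.01, 1.0).
rng = np.random.default_rng(42)
image_mat = rng.random((100, 784)) * 0.99 + 0.01

u, s, vh = np.linalg.svd(image_mat, full_matrices=False)

# u @ s sums each row of U*Sigma. Since image_mat @ vh.T == U*Sigma,
# this is each image projected onto every right singular vector, summed.
score = u @ s
assert np.allclose(score, (image_mat @ vh.T).sum(axis=1))

row_features = normalize(score)
keys = np.argsort(row_features)  # indices ordered from lowest to highest score

# Singular vectors are only defined up to sign: flipping u[:, j] and vh[j]
# together leaves the factorization intact but changes u @ s.
u2, vh2 = u.copy(), vh.copy()
u2[:, 1] *= -1
vh2[1] *= -1
assert np.allclose(image_mat, u2 @ np.diag(s) @ vh2)
print(np.allclose(u @ s, u2 @ s))  # False: the score depends on the sign convention
```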