mxnet开车教程series1-mnist上手
mxnet入门中文教程,让我们从mnist果蝇数据集开始开车 > 本文由中南大学较为牛逼的研究生金天同学原创,欢迎转载,但是请保留这段版权信息,如果你对文章有任何疑问,欢迎微信联系我:jintianiloveu。牛逼大神一一为你解答!
前言
最近搭建起了深度学习环境,mxnet被亚马逊钦定为官方的机器学习库,加上mxnet快速,代码清晰的特点,我赶紧乘上了mxnet快车,准备以mxnet为基础开始一些理论研究和产品实现。然而….mxnet搭建过程还是有点麻烦的,尤其是对于对编译过程不是非常熟悉的同学,这一点和caffe有点像,不过这不是问题,在本博客前面几篇文章对此有一个专门的教程,大家可以去看看,欢迎评论转载。这篇文章是mxnet开车教程的第一弹,让我们从果蝇数据集开始下手。
开车!
二话不多说,开始开车,作为一名深度学习老司机,我们应该要学会果蝇数据集的正确下载方式,我这里就不贴了,去Lecun的官网下载。下载之后解压,你将会看到四个文件:
t10k-images.idx3-ubyte
t10k-labels.idx1-ubyte
train-images.idx3-ubyte
train-labels.idx1-ubyte
这就是训练集和测试集的数据和标签,很多人一看不知道这是什么鬼,其实我也不知道这是什么鬼,反正是一种文件格式就对了。不多说了直接上代码,开车之前先导入包:
import struct
import numpy as np
import matplotlib.pyplot as plt
import mxnet as mx
import logging
logging.getLogger().setLevel(logging.DEBUG)
哪个包缺少安装哪个,玩mxnet你不要告诉我你还没有安装mxnet,快去我的另外一片博文看教程安装。
读取mnist数据集的正确姿势
接下来我有必要传授大家读取mnist数据集的正确方式了,网上流传的各种方法都是瞎扯淡,不懂得科学内涵(手动装逼)。正确的读取方式我谢了两个函数,一个读取label,一个读取image:
def read_mnist_label(file_name):
bin_file = open(file_name, 'rb')
magic, num = struct.unpack(">II", bin_file.read(8))
label = np.fromstring(bin_file.read(), dtype=np.int8)
return label
def read_mnist_image(file_name):
bin_file = open(file_name, 'rb')
magic, num, rows, cols = struct.unpack(">IIII", bin_file.read(16))
image = np.fromstring(bin_file.read(), dtype=np.uint8).reshape(num, rows, cols)
return image
将我们下载的文件传进去,就能得到label,images的输出,应该都是numpy.array的格式。
测试图片
我们写个显示图片的函数把:
def plot_image(image_array):
plt.imshow(image_array, cmap='gray')
plt.show()
输入图片矩阵,画出图片。
val_img = read_mnist_image('t10k-images.idx3-ubyte')
plot_image(val_img[0])
这就把测试集的第一张图片显示出来了。
搭建mxnet网络
这部分直接根据官网的来:
batch_size = 100
train_iter = mx.io.NDArrayIter(to4d(train_img), train_lbl, batch_size, shuffle=True)
val_iter = mx.io.NDArrayIter(to4d(val_img), val_lbl, batch_size)
data = mx.sym.Variable('data')
data = mx.sym.Flatten(data=data)
fc1 = mx.sym.FullyConnected(data=data, name='fc1', num_hidden=128)
act1 = mx.sym.Activation(data=fc1, name='relu1', act_type="relu")
fc2 = mx.sym.FullyConnected(data=act1, name='fc2', num_hidden=64)
act2 = mx.sym.Activation(data=fc2, name='relu2', act_type="relu")
fc3 = mx.sym.FullyConnected(data=act2, name='fc3', num_hidden=10)
mlp = mx.sym.SoftmaxOutput(data=fc3, name='softmax')
shape = {"data": (batch_size, 1, 28, 28)}
# mx.viz.plot_network(symbol=mlp, shape=shape)
model = mx.model.FeedForward(
ctx=mx.gpu(0),
symbol=mlp,
num_epoch=4,
learning_rate=0.1
)
model.fit(
X=train_iter,
eval_data=val_iter,
batch_end_callback=mx.callback.Speedometer(batch_size, 200)
)
预测
最后最重要的部分来了,那就是预测:
predict_img = val_img[0].astype(np.float32).reshape((1, 1, 28, 28))/255.0
prob = model.predict(predict_img)[0]
print('prob:', prob)
print('Classified as {0} with probability {1}'.format(prob.argmax(), max(prob)))
输出结果如下:
Classified as 7 with probability 0.9959895014762878
说明我们的预测准确度还是非常高的啊!