mxnet开车教程series1-mnist上手

mxnet入门中文教程,让我们从mnist果蝇数据集开始开车 > 本文由中南大学较为牛逼的研究生金天同学原创,欢迎转载,但是请保留这段版权信息,如果你对文章有任何疑问,欢迎微信联系我:jintianiloveu。牛逼大神一一为你解答!

前言

最近搭建起了深度学习环境,mxnet被亚马逊钦定为官方的机器学习库,加上mxnet快速,代码清晰的特点,我赶紧乘上了mxnet快车,准备以mxnet为基础开始一些理论研究和产品实现。然而….mxnet搭建过程还是有点麻烦的,尤其是对于对编译过程不是非常熟悉的同学,这一点和caffe有点像,不过这不是问题,在本博客前面几篇文章对此有一个专门的教程,大家可以去看看,欢迎评论转载。这篇文章是mxnet开车教程的第一弹,让我们从果蝇数据集开始下手。

开车!

二话不多说,开始开车,作为一名深度学习老司机,我们应该要学会果蝇数据集的正确下载方式,我这里就不贴了,去Lecun的官网下载。下载之后解压,你将会看到四个文件:

t10k-images.idx3-ubyte
t10k-labels.idx1-ubyte
train-images.idx3-ubyte
train-labels.idx1-ubyte

这就是训练集和测试集的数据和标签,很多人一看不知道这是什么鬼,其实我也不知道这是什么鬼,反正是一种文件格式就对了。不多说了直接上代码,开车之前先导入包:

import struct
import numpy as np
import matplotlib.pyplot as plt
import mxnet as mx
import logging
logging.getLogger().setLevel(logging.DEBUG)

哪个包缺少安装哪个,玩mxnet你不要告诉我你还没有安装mxnet,快去我的另外一片博文看教程安装。

读取mnist数据集的正确姿势

接下来我有必要传授大家读取mnist数据集的正确方式了,网上流传的各种方法都是瞎扯淡,不懂得科学内涵(手动装逼)。正确的读取方式我谢了两个函数,一个读取label,一个读取image:

def read_mnist_label(file_name):
    bin_file = open(file_name, 'rb')
    magic, num = struct.unpack(">II", bin_file.read(8))
    label = np.fromstring(bin_file.read(), dtype=np.int8)
    return label

def read_mnist_image(file_name):
    bin_file = open(file_name, 'rb')
    magic, num, rows, cols = struct.unpack(">IIII", bin_file.read(16))
    image = np.fromstring(bin_file.read(), dtype=np.uint8).reshape(num, rows, cols)
    return image

将我们下载的文件传进去,就能得到label,images的输出,应该都是numpy.array的格式。

测试图片

我们写个显示图片的函数把:

def plot_image(image_array):
    plt.imshow(image_array, cmap='gray')
    plt.show()

输入图片矩阵,画出图片。

val_img = read_mnist_image('t10k-images.idx3-ubyte')
plot_image(val_img[0])

这就把测试集的第一张图片显示出来了。

搭建mxnet网络

这部分直接根据官网的来:

batch_size = 100
train_iter = mx.io.NDArrayIter(to4d(train_img), train_lbl, batch_size, shuffle=True)

val_iter = mx.io.NDArrayIter(to4d(val_img), val_lbl, batch_size)

data = mx.sym.Variable('data')
data = mx.sym.Flatten(data=data)

fc1 = mx.sym.FullyConnected(data=data, name='fc1', num_hidden=128)
act1 = mx.sym.Activation(data=fc1, name='relu1', act_type="relu")

fc2 = mx.sym.FullyConnected(data=act1, name='fc2', num_hidden=64)
act2 = mx.sym.Activation(data=fc2, name='relu2', act_type="relu")

fc3 = mx.sym.FullyConnected(data=act2, name='fc3', num_hidden=10)
mlp = mx.sym.SoftmaxOutput(data=fc3, name='softmax')

shape = {"data": (batch_size, 1, 28, 28)}
# mx.viz.plot_network(symbol=mlp, shape=shape)

model = mx.model.FeedForward(
    ctx=mx.gpu(0),
    symbol=mlp,
    num_epoch=4,
    learning_rate=0.1
)
model.fit(
    X=train_iter,
    eval_data=val_iter,
    batch_end_callback=mx.callback.Speedometer(batch_size, 200)
)

预测

最后最重要的部分来了,那就是预测:

predict_img = val_img[0].astype(np.float32).reshape((1, 1, 28, 28))/255.0
prob = model.predict(predict_img)[0]
print('prob:', prob)
print('Classified as {0} with probability {1}'.format(prob.argmax(), max(prob)))

输出结果如下:

Classified as 7 with probability 0.9959895014762878

说明我们的预测准确度还是非常高的啊!

Lewis Jin avatar
About Lewis Jin
Lewis Jin is a intelligent scientist, maybe he loves make funny AI program.
comments powered by Disqus