【TensorFlow 2.0教程】初学者入门指南

  • A+
所属分类:TensorFlow 2.0

本文是TensorFlow 2.0入门示例,使用TensorFlow 2.0对MNIST手写数字进行识别,从而展示了基于TensorFlow 2.0进行开发的最简单的流程。

首先导入TensorFlow 2.0

加载并准备MNIST手写数字数据集,并将样本从整数转换为浮点数:

首次运行时会下载数据集,从输出的log中可以看到。当已经下载后,不会再重复下载。

该数据集中有60000个训练样本,10000个测试样本,每个样本都是28x28的二维数组,数组中每个数字代表图片的一个像素值。由于像素值最大为255,为了对数据归一化,我们对每个值除以255.

接下来,构建一个最基本的层叠模型(Sequential),并选择一个优化器和损失函数进行训练:

通过tf.keras.Sequential创建模型后,调用模型的compile函数编译模型,编译时指定优化器,损失函数,和监测的指标,这里只监测了准确率(Accuracy)。

模型创建并编译后,开始训练模型:

训练模型使用模型的fit函数,传入训练样本数据,并指定训练迭代次数,这里只迭代了5次,即对所有训练样本重复进行了5次训练。

训练完成后,我们可以使用测试数据对模型进行评估:

可以看到训练好的模型,对于测试数据达到了接近98%的准确率。

以上就是使用TensorFlow 2.0最简单的入门示例,也展示了使用TensorFlow 2.0进行开发的基本流程。TensorFlow 2.0使用了keras作为高阶API,相对于TensorFlow 1.x在编码以及开发效率上简化了很多。

本文完整源代码请参考这里

  • 微信
  • 如有疑问,请加个人微信联系
  • weinxin
  • 关注公众号:新码农客栈
  • 有趣的灵魂在等你
  • weinxin
yglong

发表评论

:?: :razz: :sad: :evil: :!: :smile: :oops: :grin: :eek: :shock: :???: :cool: :lol: :mad: :twisted: :roll: :wink: :idea: :arrow: :neutral: :cry: :mrgreen: