【TensorFlow 2.0教程】对影视评论进行文本分类

  • A+
所属分类:TensorFlow 2.0

本文将对电影评论文本进行分类,分为正面影评和负面影评,这是一个在机器学习问题中非常重要且常见的二分类问题。

本文演示使用TensorFlow Hub和Keras进行转移学习的基本应用。

我们将使用IMDB数据集,其中包含了来自互联网电影数据库的50,000篇电影评论的文本,它们被分成25000个训练评论文本和25000个测试评论文本。训练集和测试集中的评论类型是比较平衡的,这意味着它们包含相同数量的正面和负面评论。

我们同样使用keras高级API,用于在TensorFlow中构建和训练模型。TensorFlow Hub是一个用于转移学习的库和平台,我们将使用其中已经训练好的文本嵌入模型。

首先,导入本文将用到的python库:

下载IMDB数据集

TensorFlow datasets库提供了IMDB数据集,下面的代码使用datasets库下载该数据集:

输出如下:

展开

探索数据

首先,让我们花点时间来看看数据集的数据格式。每个样本都包含一段电影评论文本,以及相应的标签。电影评论文本没有经过任何预处理,标签为0或1的整数值,其中0表示负面评论,1表示正面评论。

让我们打印头10个样本看看:

展开

对面的头10个样本的标签:

构建模型

神经网络模型一般是由多个层叠加起来构成的,我们一般需要考虑如下三个主要因素来构建模型:

  • 如何表示文本?
  • 模型中应该使用多少层?
  • 每层应该有多少个隐藏单元(即神经元,也称为节点)?

在本例中,输入数据由一段文本构成的句子组成,要预测的标签是0或1。

表示文本的一种方法是将句子转换成嵌入向量。我们可以使用一个预先训练好的文本嵌入模型作为第一层。使用已训练好的文本嵌入模型有以下三个优点:

  • 我们不需要考虑文本预处理
  • 我们可以从转移学习中受益
  • 嵌入的大小是固定的,所以处理起来更简单。

在本例中,我们将使用一个来自TensorFlow Hub的预先训练好的文本嵌入模型,名为 google/tf2-preview/gnews-swivel-20dim/1

TensorFlow Hub中还有另外三个其他的训练好的模型,也可以用于本例的测试:

 

接下来,我们首先创建一个Keras层,我们使用TensorFlow Hub模型来进行文本嵌入,并对几个输入样本进行测试。注意,无论输入文本的长度如何,文本嵌入的输出形状都是固定的,大小为  。

现在,我们可以使用上面创建的层来构建完整的神经网络模型了:

可以看到,我们把几个层依次叠加起来,构成了我们的分类器模型:

  1. 第一层是一个TensorFlow Hub层。我们使用了一个预先训练好的被保存起来的模型,通过它将一个句子映射成一个文本嵌入向量。这个训练好的模型将句子分割成标记(token),然后嵌入每个标记,然后组合嵌入形成嵌入向量。结果的维度是:  。
  2. 第一层输出的固定长度的向量紧接着通过一个有16个隐藏单元的全连接(  )层。
  3. 最后一层是一个单节点输出层,使用  激活函数,这个值是一个介于0和1之间的浮点数,表示一个概率或置信级别。

 

接下来我们需要编译模型。

选择损失函数和优化器

模型需要指定一个损失函数和一个用于训练的优化器。由于这是一个二元分类问题,并且模型输出一个概率,所以我们将使用  损失函数。

这不是损失函数的唯一选择,例如,你可以选择  (均方误差)。但是,一般来说,  更适合处理概率,它测量概率分布之间的“差距”,在我们的例子中,它表示测量的真实分布和预测之间的“差距”。

当我们研究回归问题(例如,预测房价)时,我们可以使用均方误差损失函数。

现在,为我们的模型配置优化器和损失函数,同时指定检测参数:

训练模型

我们对模型进行20次迭代训练,也就是使用训练数据中的所有样本进行20次迭代。在训练过程中,对来自验证集的10000个样本进行验证,从而监测模型的损失和准确性:

展开

评估模型

使用  函数来验证模型对测试集的预测情况。该函数将返回两个值:损失值和准确率。

可以看到,使用这种相当简单的方法可以达到约87%的准确率。如果采用更高级的方法,模型的准确率应该可以接近95%。

延伸阅读

要了解处理文本输入的更一般的方法,以及在训练时对准确率和损失值的更详细分析过程,请看这里。本文完整代码请参考这里

weinxin
关注微信
如有疑问,欢迎扫一扫左侧二维码添加微信好友进行咨询,我会第一时间回复您!
yglong

发表评论

:?: :razz: :sad: :evil: :!: :smile: :oops: :grin: :eek: :shock: :???: :cool: :lol: :mad: :twisted: :roll: :wink: :idea: :arrow: :neutral: :cry: :mrgreen: