【TensorFlow 2.0教程】对结构化数据分类

  • A+
所属分类:TensorFlow 2.0

本教程将介绍如何对结构化数据进行分类,例如CSV中的表格数据。我们将使用Keras定义模型,并使用   作为桥梁,将CSV中的列映射到用于训练模型的特征。本教程将包含如下几方面的完整的代码演示:

  • 使用Pandas加载CSV数据
  • 构建一个输入管道(pipeline),使用  API对数据进行批处理和洗牌。
  • 使用  API将CSV中的列映射到用来训练模型的特征。
  • 使用Keras构建、训练和评估模型。

数据集介绍

我们将使用克利夫兰心脏病临床基金会提供的一个较小的数据集。CSV文件中有几百行,每一行描述一个病人,每一列为一个特征。我们将使用这些信息来预测患者是否患有心脏病,这是一个二元分类问题。

下面是对该数据集的描述。注意,有些列是数值型的,有些列是分类列。

ColumnDescriptionFeature TypeData Type
AgeAge in yearsNumericalinteger
Sex(1 = male; 0 = female)Categoricalinteger
CPChest pain type (0, 1, 2, 3, 4)Categoricalinteger
TrestbpdResting blood pressure (in mm Hg on admission to the hospital)Numericalinteger
CholSerum cholestoral in mg/dlNumericalinteger
FBS(fasting blood sugar > 120 mg/dl) (1 = true; 0 = false)Categoricalinteger
RestECGResting electrocardiographic results (0, 1, 2)Categoricalinteger
ThalachMaximum heart rate achievedNumericalinteger
ExangExercise induced angina (1 = yes; 0 = no)Categoricalinteger
OldpeakST depression induced by exercise relative to restNumericalinteger
SlopeThe slope of the peak exercise ST segmentNumericalfloat
CANumber of major vessels (0-3) colored by flourosopyNumericalinteger
Thal3 = normal; 6 = fixed defect; 7 = reversable defectCategoricalstring
TargetDiagnosis of heart disease (1 = true; 0 = false)Classificationinteger

导入TensorFlow和其他库

使用Pandas导入数据集

Pandas是一个Python库,有许多用于加载和处理结构化数据的实用工具。我们将使用Pandas从一个URL下载数据集,并将其加载到dataframe中。

 agesexcptrestbpscholfbsrestecgthalachexangoldpeakslopecathaltarget
063111452331215002.330fixed0
167141602860210811.523normal1
267141202290212912.622reversible0
337131302500018703.530normal0
441021302040217201.410normal0

将数据集拆分为训练、验证和测试集

我们下载的数据集在一个单一的CSV文件中,我们将把它们拆分为训练集,验证集和测试集。

使用tf.data创建输入管道

接下来,我们将用  包装dataframe。这将使我们能够使用TensorFlow的   作为桥梁,将Pandas的dataframe中的列映射到用于训练模型的特征。如果我们处理的是一个非常大的CSV文件(大到不能通过内存存储),我们将使用  直接从磁盘读取数据。本教程不讨论这一点。

将数据集拆分成小批次(5)来演示我们创建的  的数据集:

进一步了解输入管道

现在我们已经创建了输入管道,让我们调用它来查看它返回的数据的格式。我们使用了一个小的批次大小(5)来保持输出的可读性。

我们可以看到  数据集返回一个字典,key为列名(来自dataframe),值映射到dataframe中的所有行的列值。

演示几种类型的特征列(feature column)

TensorFlow提供了许多类型的特征列。在本节中,我们将创建几种类型的特征列,并演示它们如何从dataframe的列转换为TensorFlow的特征列。

我们将使用第一批次的训练数据进行演示:

下面的工具函数用于创建特征列,并转换一个批次的数据:

数值列(Numeric columns)

特征列的输出会作为模型的输入(使用上面定义的工具函数,我们将清楚地看到来自dataframe的每一列是如何转换成特征列的)。数值列(   )是最简单的特征列类型,它用于表示真实的数值特征。当使用这种特征列时,您的模型将原封不动地从dataframe接收列的值。

在本教程使用的心脏病数据集中,dataframe中的大多数列都是数值列。

桶列(Bucketized columns)

通常,您不希望将数值直接输入模型,而是根据数值范围将其值划分为不同的类别。考虑代表一个人年龄的数值数据,我们可以把年龄分成几个阶段,每个阶段称为一个桶,形成所谓的“桶列”(  )。

注意上面输出的是one-hot数组,数组中的每一行表示数据中某一行数据代表的某个人的年龄属于哪个年龄范围。

分类列(Categorical columns)

在这个数据集中,“thal”列的值是一个字符串(例如:“fixed”、“normal”或“reversible”)。我们不能将字符串直接提供给模型。相反,我们必须首先将它们转换为数值。分类词汇表列(categorical vocabulary columns)提供了一种将字符串表示为一个one-hot向量的方法(就像您在上面看到的年龄桶一样)。可以使用  ,将词汇表作为列表传递给该函数。也可以使用  从文件中加载词汇表。

在更复杂的数据集中,许多列都可能是类似的分类列。在处理这种类型的数据时,使用   API是最有效的。

嵌入列(Embedding columns)

假设不是只有几个可能的字符串,而是每个类别有数千个(或更多)值。由于许多原因,随着类别数量的增加,使用one-hot编码训练神经网络将变得不再可行。我们可以使用嵌入列来克服这个限制。嵌入列(  )不是将数据表示为有很多维度的one-hot向量,而是将该数据表示为一个低维度、密集的向量,其中每个单元格可以包含任意数字,而不仅仅是0或1。嵌入的大小(在下面的例子中是8)是一个必须调优的参数。

关键点:当分类列有比较多的可能值时,使用嵌入列是最好的选择。

我们在这里简单演示一下如何这种方法:

散列特征列(Hased feature columns)

另一种表示具有大量值的分类列的方法是使用  ,它计算输入的哈希值,然后选择一个合适的桶大小(  )对字符串进行编码。在使用这种类型的特征列时,您不需要提供词汇表,同时您可以选择让散列桶的数量比实际类别的数量小得多,从而节省空间。

关键点:这种技术的一个重要缺点是可能会有冲突,我们可能会遇到不同的字符串被映射到同一个散列桶的冲突。但实际上,即使存在这种冲突,对于某些数据集,这种方法依然表现得很好。

交叉特征列(Crossed feature columns)

将多个特征组合成一个特征,通常称为特征交叉(feature crosses)。模型能够为每个组合后的特征学习到单独的权重。在这里,我们将使用  创建一个新的特征,它是年龄和thal的组合特征。

注意 不会构建所有可能组合的完整表(它可能非常大)。相反,它实际上使用了  ,因此您可以选择表的大小。

选择合适的特征列

我们已经了解了如何使用几种常见类型的特征列。现在我们将用它们来训练一个模型。本教程的目标是向您展示处理特征列所需的完整代码,以及其机制。我们将随意选择一些列来训练我们的模型。

关键点:如果您的目标是构建一个精确的模型,那么你需要尝试更大的数据集,并仔细考虑哪些特征是最有意义的,以及它们应该被如何表示。

创建一个特征层

现在我们已经定义了特征列,接下来,我们将使用  创建一个层,该层将被输入到我们的模型。

在前面,我们使用了一个小的批次大小(5)来演示特征列是如何工作的。这里,我们将创建一个新的输入管道,具有更大批次大小(32)。

创建、编译和训练模型

接下来,我们创建一个模型,并编译和训练它:

关键点:通常情况下,深度学习在更大更复杂的数据集中才会得到的最佳结果。在处理像本教程的小数据集时,我们建议使用决策树或随机森林作为基线。本教程的目标不是训练一个精确的模型,而是演示处理结构化数据的机制,从而在将来处理自己的数据集时,可以使用这里的代码作为一个起点。

下一步

了解结构化数据分类的最佳方法是亲自尝试。我们建议您寻找另一个数据集,并使用类似于上面的代码训练一个模型,然后对其进行分类。为了提高精确度,请仔细考虑模型中应该包含哪些特征,以及它们应该被如何表示。

本教程的完整代码请参考这里

weinxin
关注微信
如有疑问,欢迎扫一扫左侧二维码添加微信好友进行咨询,我会第一时间回复您!
yglong

发表评论

:?: :razz: :sad: :evil: :!: :smile: :oops: :grin: :eek: :shock: :???: :cool: :lol: :mad: :twisted: :roll: :wink: :idea: :arrow: :neutral: :cry: :mrgreen: