查看原文
其他

第3.2节 逻辑回归(多分类任务)

空字符 月来客栈 2024-01-19

各位朋友大家好,欢迎来到月来客栈,我是掌柜空字符。

本期推送内容目录如下,如果你觉得本期内容对你所有帮助欢迎点个赞、关个注、下回更新不迷路。

  • 3.2 多分类任务
    • 3.2.1多分类逻辑回归
    • 3.2.2 多分类示例代码
    • 3.2.3 小结

3.2 多分类任务

3.2.1多分类逻辑回归

在3.1节中对于逻辑回归的介绍都仅仅局限在二分类任务中,但是在实际任务里,更多则是多分类的任务场景,也就是说最终的分类结果中类别数会大于2。对于这样的问题该如何解决呢?

通常情况下在用逻辑回归处理多分类任务时,都会采样一种称为One-vs-all(也叫作 One-vs-rest)的方法,两者的缩写分别为ova与ovr。这种策略的核心思想就是每次将其中一个类和剩余的其他类看作一个二分类任务进行训练,最后在预测过程中选择输出概率值最大那个类作为该样本点所属的类别。

如图3-5所示,此图为一个可视化的数据集,它一共包含3个类别。

图 3-5 多分类问题

当利用One-vs-all的分类思想来解决图3-5中的多分类问题时,可以可视化成如图3-6所示的情况。

图 3-6 Onevsall思想

在图3-6中,以从左往右的划分方式划分数据集,然后分别训练3个二分类的逻辑回归模型,分别表示样本属于第1、第2和第3共3个类别的概率,最后在预测的时候只要选择概率最大时分类模型所对应的类别即可。

3.2.2 多分类示例代码

在sklearn中,可以借助LogisticRegression类中的multi_class='ovr'参数来完成整个多分类的建模任务,完整代码见Book/Chapter03/02_one_vs_all_train.py文件。

1. 载入数据集

在这里,笔者同样使用了sklearn中内置的一个分类数据集iris进行示例。首先需要载入这个数据集,代码如下:

from sklearn.datasets import load_iris
def load_data():
    data = load_iris()
    x, y = data.data, data.target
    return x, y

iris数据集一共包含3个类别,每个类别中有50个样本,并且每个样本有4个特征维度。同时,sklearn中也内置了很多丰富的其他数据集来方便初学者使用,具体信息可以参见官网 [1]。

2. 训练模型

在数据集载入完成后,便可以通过sklearn中的LogisticRegression完成整个建模求解过程,代码如下:

def train(x, y): 
    model = LogisticRegression(multi_class='ovr')
    model.fit(x,y)
    print("得分: ", model.score(x, y))
# 得分:0.95

到此,对于多变量逻辑回归的分类方法与建模过程就介绍完了。不过细心的读者可能会发现,上面代码中的最后一行输出了一个0.95的得分,它表示什么含义呢?这里的0.95其实指的模型分类的准确率,意思是有95%的样本被模型正确分类了,具体计算原理可见3.3节内容。

3.2.3 小结

在本节内容中,笔者首先以图示的方式介绍了如何用Onevsall的思想来用逻辑回归模型解决多分类的任务场景,然后介绍了如何借助开源库sklearn来完成整个多分类任务的建模过程。接下来,我们将开始学习分类模型中的常见评估指标。

引用

[1] https://scikitlearn.org/stable

继续滑动看下一个

第3.2节 逻辑回归(多分类任务)

空字符 月来客栈
向上滑动看下一个

您可能也对以下帖子感兴趣

文章有问题?点此查看未经处理的缓存