终于把机器学习中的混淆矩阵搞懂了!
混淆矩阵是用于评估分类模型性能的表格。它通过将实际(真实)标签与预测标签进行比较,提供分类问题的预测结果摘要。
True Positive (TP, 真阳性):实际为正类,预测也为正类的数量。
True Negative (TN, 真阴性):实际为负类,预测也为负类的数量。
False Positive (FP, 假阳性):实际为负类,预测却为正类的数量,通常称为"Type I 错误"或"误报"。
False Negative (FN, 假阴性):实际为正类,预测却为负类的数量,通常称为"Type II 错误"或"漏报"。
为什么要使用混淆矩阵?
混淆矩阵是评估分类模型性能的基本工具。
错误分析
它有助于识别模型所犯的错误类型,无论模型更容易出现假阳性还是假阴性,这在应用范围内(例如在医学诊断中)可能至关重要。
模型改进
通过分析混淆矩阵,你可以专注于改进模型的特定方面,例如减少误报或提高召回率。 类别不平衡处理
在类别不平衡的情况下,一个类别出现的频率高于另一个类别,单凭准确率可能会产生误导。 混淆矩阵可让你更好地了解模型在每个类别中的表现。 性能指标计算
分类中的评估指标
1.准确率
准确率是分类任务中最简单的评估指标之一,用来衡量模型预测正确的比例。
准确率的局限性
例如,在 95% 的样本属于同一类的数据集中,预测所有实例为多数类的模型的准确率为 95%,但在识别少数类时则无效。
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix, accuracy_score
# Example true labels (ytest) and predicted labels (ypred)
ytest = [0, 1, 0, 1, 0, 1, 0, 0, 1, 1]
ypred = [0, 1, 0, 0, 0, 1, 0, 1, 1, 1]
# Calculate confusion matrix
cm = confusion_matrix(ytest, ypred)
# Create a heatmap
plt.figure(figsize=(6, 4))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', cbar=False,
xticklabels=['1', '0'],
yticklabels=['1', '0'])
# Add labels and title
plt.xlabel('Predicted Classes')
plt.ylabel('Actual Classes')
plt.title('Confusion Matrix')
# Calculate and display accuracy
accuracy = accuracy_score(ytest, ypred)
plt.text(2.3, 1.5, f'Accuracy: {accuracy:.2f}', fontsize=14, color='black', weight='bold')
plt.show()
2.精度
import seaborn as sns
from sklearn.metrics import confusion_matrix, precision_score
# Example true labels (ytest) and predicted labels (ypred)
ytest = ['spam', 'spam', 'ham', 'spam', 'ham', 'spam', 'spam', 'ham', 'spam', 'spam', 'ham', 'spam', 'ham', 'ham', 'ham']
ypred = ['spam', 'spam', 'spam', 'spam', 'ham', 'spam', 'spam', 'ham', 'spam', 'spam', 'ham', 'ham', 'ham', 'ham', 'ham']
# Calculate the confusion matrix
cm = confusion_matrix(ytest, ypred, labels=['spam', 'ham'])
print("Confusion Matrix:\n", cm)
# Calculate precision
precision = precision_score(ytest, ypred, pos_label='spam')
print("Precision:", precision)
# Create a heatmap for the confusion matrix
plt.figure(figsize=(8, 6))
ax = sns.heatmap(cm, annot=True, fmt='d', cmap='viridis', cbar=False,
xticklabels=['Predicted Spam', 'Predicted Ham'],
yticklabels=['Actual Spam', 'Actual Ham'])
# Set labels and title
plt.xlabel('Predicted Classes')
plt.ylabel('Actual Classes')
plt.title(f'Confusion Matrix\nPrecision: {precision:.2f}')
# Show the plot
plt.show()
3.召回率
import seaborn as sns
from sklearn.metrics import confusion_matrix, recall_score
# Example true labels (ytest) and predicted labels (ypred)
ytest = ['positive', 'positive', 'negative', 'positive', 'negative']
ypred = ['positive', 'negative', 'negative', 'positive', 'positive']
# Calculate the confusion matrix
cm = confusion_matrix(ytest, ypred, labels=['positive', 'negative'])
# Calculate recall
recall = recall_score(ytest, ypred, pos_label='positive')
# Create a heatmap for the confusion matrix
plt.figure(figsize=(6, 4))
ax = sns.heatmap(cm, annot=True, fmt='d', cmap='viridis', cbar=False,
xticklabels=['Predicted Positive', 'Predicted Negative'],
yticklabels=['Actual Positive', 'Actual Negative'])
# Set labels and title
plt.xlabel('Predicted Classes')
plt.ylabel('Actual Classes')
plt.title(f'Confusion Matrix\nRecall: {recall:.2f}')
# Show the plot
plt.show()
4.F1-score
—
「进群方式:加我微信,备注 “python”」
往期回顾
Fashion-MNIST 服装图片分类-Pytorch实现