当前位置:  首页>> 技术小册>> Python机器学习基础教程(上)

1.7 第一个应用:鸢尾花分类

在机器学习领域,鸢尾花(Iris)分类问题是一个经典的入门级案例,它以其简单而富有教育意义的特点,成为了学习分类算法的首选。在本章中,我们将通过Python及其强大的机器学习库scikit-learn,来构建并训练一个模型,用于预测鸢尾花的种类。这不仅能帮助你理解机器学习的基本流程,还能掌握数据预处理、模型选择、训练及评估等关键步骤。

1.7.1 引言

鸢尾花数据集(Iris dataset)是统计学和机器学习中常用的数据集之一,由R.A. Fisher于1936年收集。该数据集包含了150个样本,分别属于三种不同的鸢尾花种类:Setosa、Versicolour和Virginica。每个样本有四个特征:萼片长度(sepal length)、萼片宽度(sepal width)、花瓣长度(petal length)和花瓣宽度(petal width),所有这些特征均为连续值。

1.7.2 数据准备

在开始建模之前,首先需要加载并探索数据。scikit-learn库中已经内置了鸢尾花数据集,我们可以直接调用它来获取数据。

  1. from sklearn.datasets import load_iris
  2. from sklearn.model_selection import train_test_split
  3. # 加载数据
  4. iris = load_iris()
  5. X = iris.data # 特征数据
  6. y = iris.target # 目标标签(0: Setosa, 1: Versicolour, 2: Virginica)
  7. # 划分训练集和测试集
  8. X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)

1.7.3 数据探索

数据探索是理解数据特性、发现潜在问题的重要步骤。通过简单的统计分析和可视化,我们可以获得对数据集的初步认识。

  1. import pandas as pd
  2. import seaborn as sns
  3. import matplotlib.pyplot as plt
  4. # 将数据转换为DataFrame以便于分析
  5. df = pd.DataFrame(X, columns=iris.feature_names)
  6. df['species'] = iris.target_names[y]
  7. # 描述性统计
  8. print(df.describe())
  9. # 绘制特征分布图
  10. sns.pairplot(df, hue='species')
  11. plt.show()

通过上述代码,我们可以观察到不同种类鸢尾花在特征上的分布差异,这为后续选择模型提供了直观的依据。

1.7.4 选择模型

在鸢尾花分类问题中,由于数据集相对较小且特征维度不高,我们可以选择多种分类算法进行尝试。这里,我们以逻辑回归(Logistic Regression)、决策树(Decision Tree)和K近邻(K-Nearest Neighbors, KNN)为例,展示不同算法在相同数据集上的表现。

逻辑回归

  1. from sklearn.linear_model import LogisticRegression
  2. # 创建逻辑回归模型
  3. lr = LogisticRegression(max_iter=200)
  4. lr.fit(X_train, y_train)
  5. # 预测测试集
  6. y_pred_lr = lr.predict(X_test)

决策树

  1. from sklearn.tree import DecisionTreeClassifier
  2. # 创建决策树模型
  3. dt = DecisionTreeClassifier()
  4. dt.fit(X_train, y_train)
  5. # 预测测试集
  6. y_pred_dt = dt.predict(X_test)

K近邻

  1. from sklearn.neighbors import KNeighborsClassifier
  2. # 创建K近邻模型
  3. knn = KNeighborsClassifier(n_neighbors=3)
  4. knn.fit(X_train, y_train)
  5. # 预测测试集
  6. y_pred_knn = knn.predict(X_test)

1.7.5 模型评估

模型评估是判断模型好坏的关键步骤。常用的评估指标包括准确率(Accuracy)、精确率(Precision)、召回率(Recall)和F1分数(F1 Score)等。对于鸢尾花分类这种多分类问题,我们主要关注准确率。

  1. from sklearn.metrics import accuracy_score
  2. # 计算各模型准确率
  3. accuracy_lr = accuracy_score(y_test, y_pred_lr)
  4. accuracy_dt = accuracy_score(y_test, y_pred_dt)
  5. accuracy_knn = accuracy_score(y_test, y_pred_knn)
  6. print(f"逻辑回归准确率: {accuracy_lr:.2f}")
  7. print(f"决策树准确率: {accuracy_dt:.2f}")
  8. print(f"K近邻准确率: {accuracy_knn:.2f}")

1.7.6 结果分析与讨论

通过上述实验,我们可以发现不同模型在鸢尾花数据集上的表现存在差异。逻辑回归虽然简单,但在该数据集上表现不俗;决策树由于其强大的非线性拟合能力,往往能取得较高的准确率;而K近邻的准确率则受到K值选择的影响,需要通过交叉验证等方式来确定最优K值。

此外,我们还可以进一步探讨模型的过拟合与欠拟合问题,通过调整模型参数(如决策树的深度、逻辑回归的正则化强度等)或使用更复杂的模型(如随机森林、梯度提升树等)来改进模型性能。

1.7.7 结论

通过本章的学习,我们不仅掌握了使用Python和scikit-learn进行机器学习项目的基本流程,还通过鸢尾花分类这一具体案例,深入理解了数据预处理、模型选择、训练及评估等关键环节。更重要的是,我们学会了如何根据数据集的特点选择合适的模型,并通过实验来验证模型的有效性。这些知识和技能将为你后续深入学习机器学习打下坚实的基础。


该分类下的相关小册推荐: