查看原文
其他

实战:使用机器学习算法进行航班票价预测!

The following article is from 机器学习社区 Author 机器猫

大家好,实战是学习机器学习的最好方法。今天在本文中,我们将使用机器学习方法来对航班票价进行预测。为方便大家实操,文末提供完整版代码和数据。

关于数据集

数据来kaggle比赛数据,我首先对数据集的字段进行说明,方便后续分析和理解

  • Airline:所有类型的航空公司,例如 Indigo、Jet Airways、Air India
  • Date_of_Journey:乘客旅程的开始日期
  • Source:乘客旅程开始的地点名称
  • Destination:乘客想要前往的地点的名称
  • Route:乘客选择从他/她的来源到目的地的路线是什么
  • Arrival_Time:乘客到达目的地的时间
  • Duration:持续时间是航班完成从源头到目的地的旅程的整个时间
  • Total_Stops:整个旅程中将在多少地方停留
  • Additional_Info:获得有关食物、食物种类和其他便利设施的信息
  • Price:完整旅程的航班价格,包括登机前的所有费用

引入库

为方便后续航空价格做预测,我们先引入对应库

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error as mse
from sklearn.metrics import r2_score
from math import sqrt
from sklearn.linear_model import Ridge
from sklearn.linear_model import Lasso
from sklearn.tree import DecisionTreeRegressor
from sklearn.ensemble import RandomForestRegressor
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import KFold
from sklearn.model_selection import train_test_split
from sklearn.model_selection import GridSearchCV
from sklearn.model_selection import RandomizedSearchCV
from prettytable import PrettyTable

读取数据

train_df = pd.read_excel("Data_Train.xlsx")
train_df.head(10)

探索性数据分析(EDA)

我们查看数据集所具有的列类型。

train_df.columns
#Output
Index(['Airline''Date_of_Journey''Source''Destination''Route',
       'Dep_Time''Arrival_Time''Duration''Total_Stops',
       'Additional_Info''Price'],
      dtype='object')

在这里,我们可以获得数据集的更多信息

train_df.info()

了解有关数据集的更多信息

train_df.describe()

现在在使用 IsNull 函数时,我们将看到数据集中空值的数量

train_df.isnull().head()

现在在使用 IsNull 函数和 sum 函数时,我们将看到数据集中空值的数量

train_df.isnull().sum()
#output
Airline            0
Date_of_Journey    0
Source             0
Destination        0
Route              1
Dep_Time           0
Arrival_Time       0
Duration           0
Total_Stops        1
Additional_Info    0
Price              0
dtype: int64

删除 NAN 值

train_df.dropna(inplace = True)

重复值

train_df[train_df.duplicated()].head()

在这里,我们将从数据集中删除那些重复的值,并保持原地属性为真,这样就不会发生任何变化。

train_df.drop_duplicates(keep='first',inplace=True)
train_df.head()

当前数据集数量

train_df.shape
#output
(1046211)

检查 Additional_info 列并计算唯一类型的值。

train_df["Additional_Info"].value_counts()
#output
No info                         8182
In-flight meal not included     1926
No check-in baggage included     318
1 Long layover                    19
Change airports                    7
Business class                     4
No Info                            3
1 Short layover                    1
2 Long layover                     1
Red-eye flight                     1
Name:
 Additional_Info, dtype: int64

检查不同的航空公司

train_df["Airline"].unique()
#output
array(['IndiGo''Air India''Jet Airways''SpiceJet',
       'Multiple carriers''GoAir''Vistara''Air Asia',
       'Vistara Premium economy''Jet Airways Business',
       'Multiple carriers Premium economy''Trujet'], dtype=object)

检查不同的航线

train_df["Route"].unique()
# output

现在让我们看看我们的测试数据集

test_df = pd.read_excel("Test_set.xlsx")
test_df.head(10)

测试数据所具有的列类型

test_df.columns
#output
Index(['Airline''Date_of_Journey''Source''Destination''Route',
       'Dep_Time''Arrival_Time''Duration''Total_Stops',
       'Additional_Info'],
      dtype='object')

有关数据集的信息

test_df.info()

了解有关测试数据集的更多信息

test_df.describe()

现在在使用 IsNull 函数和 sum 函数时,我们将看到测试数据中空值的数量

test_df.isnull().sum()
# output
Airline            0
Date_of_Journey    0
Source             0
Destination        0
Route              0
Dep_Time           0
Arrival_Time       0
Duration           0
Total_Stops        0
Additional_Info    0
dtype: int64

绘制价格(Price)与航空公司(Airline)图在猫图的帮助下,我们试图绘制航班价格和航空公司价格之间的箱线图,我们可以得出结论,Jet Airways 在价格方面的异常值最多。

绘制价格与来源的小提琴图

sns.catplot(y = "Price", x = "Source", data = train_df.sort_values("Price", ascending = False), kind="violin", height = 4, aspect = 3)
plt.show()

现在仅借助猫图,我们在航班价格和源地之间绘制箱线图,即乘客将从哪里前往目的地,我们可以看到作为源地的班格罗尔(Banglore)拥有最多异常值,而金奈(Chennai)最少。

绘制价格与目的地的箱线图

sns.catplot(y = "Price", x = "Destination", data = train_df.sort_values("Price", ascending = False), kind="box", height = 4, aspect = 3)
plt.show()

我们在航班价格和乘客旅行目的地之间的猫图的帮助下绘制箱线图,并发现新德里(New Delhi)的异常值最多,加尔各答(Kolkata)的异常值最少。

特征工程

先看看数据

train_df.head()

在这里,我们首先划分特征和标签,然后将小时转换为分钟。

train_df['Duration'] = train_df['Duration'].str.replace("h"'*60').str.replace(' ','+').str.replace('m','*1').apply(eval)
test_df['Duration'] = test_df['Duration'].str.replace("h"'*60').str.replace(' ','+').str.replace('m','*1').apply(eval)

Date_of_Journey:标准化旅程日期的格式,以便在模型阶段进行更好的预处理。

train_df["Journey_day"] = train_df['Date_of_Journey'].str.split('/').str[0].astype(int)
train_df["Journey_month"] = train_df['Date_of_Journey'].str.split('/').str[1].astype(int)
train_df.drop(["Date_of_Journey"], axis = 1, inplace = True)

Dep_Time:将出发时间转换为小时和分钟

train_df["Dep_hour"] = pd.to_datetime(train_df["Dep_Time"]).dt.hour
train_df["Dep_min"] = pd.to_datetime(train_df["Dep_Time"]).dt.minute
train_df.drop(["Dep_Time"], axis = 1, inplace = True)

Arrival_Time:将到达时间转换为小时和分钟。

train_df["Arrival_hour"] = pd.to_datetime(train_df.Arrival_Time).dt.hour
train_df["Arrival_min"] = pd.to_datetime(train_df.Arrival_Time).dt.minute
train_df.drop(["Arrival_Time"], axis = 1, inplace = True)

在最后的预处理之后,让我们看看数据集绘制月份(持续时间)与航班数量的条形图

plt.figure(figsize = (105))
plt.title('Count of flights month wise')
ax=sns.countplot(x = 'Journey_month', data = train_df)
plt.xlabel('Month')
plt.ylabel('Count of flights')
for p in ax.patches:
    ax.annotate(int(p.get_height()), (p.get_x()+0.25, p.get_height()+1), va='bottom', color= 'black')

在上图中,我们绘制了一个月旅程与几个航班的计数图,并看到五月的航班数量最多。

绘制航空公司类型与航班数量的条形图

plt.figure(figsize = (20,5))
plt.title('Count of flights with different Airlines')
ax=sns.countplot(x = 'Airline', data =train_df)
plt.xlabel('Airline')
plt.ylabel('Count of flights')
plt.xticks(rotation = 45)
for p in ax.patches:
    ax.annotate(int(p.get_height()), (p.get_x()+0.25, p.get_height()+1), va='bottom', color= 'black')

从上图中我们可以看到,在航空公司类型和航班数量之间,我们可以看到 Jet Airways 登机的航班最多。

绘制机票价格 VS 航空公司

plt.figure(figsize = (15,4))
plt.title('Price VS Airlines')
plt.scatter(train_df['Airline'], train_df['Price'])
plt.xticks
plt.xlabel('Airline')
plt.ylabel('Price of ticket')
plt.xticks(rotation = 90)

所有特征之间的相关性

绘制相关性

plt.figure(figsize = (15,15))
sns.heatmap(train_df.corr(), annot = True, cmap = "RdYlGn")
plt.show()

删除价格列,因为它没有用

data = train_df.drop(["Price"], axis=1)

处理分类数据和数值数据

train_categorical_data = data.select_dtypes(exclude=['int64''float','int32'])
train_numerical_data = data.select_dtypes(include=['int64''float','int32'])

test_categorical_data = test_df.select_dtypes(exclude=['int64''float','int32','int32'])
test_numerical_data  = test_df.select_dtypes(include=['int64''float','int32'])
train_categorical_data.head()

分类列的标签编码和热编码

le = LabelEncoder()
train_categorical_data = train_categorical_data.apply(LabelEncoder().fit_transform)
test_categorical_data = test_categorical_data.apply(LabelEncoder().fit_transform)
train_categorical_data.head()

连接分类数据和数值数据

X = pd.concat([train_categorical_data, train_numerical_data], axis=1)
y = train_df['Price']
test_set = pd.concat([test_categorical_data, test_numerical_data], axis=1)
X.head()

构建模型

1、构建评估模型的平均绝对百分比误差
# Calculating Mean Absolute Percentage Error
def mean_absolute_percentage_error(y_true, y_pred): 
    y_true, y_pred = np.array(y_true), np.array(y_pred)
    return np.mean(np.abs((y_true - y_pred) / y_true)) * 100
2、切分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size = 0.3, random_state = 42)

数据统计如下

3、模型训练整个流程

Ridge RegressionLasso Regression 算法整个流程决策树算法整体流程

比较所有模型
# Comparing all the models
models = pd.DataFrame({
    'Model': [ 'Ridge Regression''Lasso Regression','Decision Tree Regressor'],
    'Score': [ ridge_score, lasso_score, decision_score],
    'Test Score': [ ridge_score_test, lasso_score_test, decision_score_test]})
models.sort_values(by='Test Score', ascending=False)

通过比较所有模型,我们可以得出结论,决策树回归和随机森林回归表现最好。

结论

正如我们所看到的,我们已经完成了一个整个模型开发流程,包括数据洞察、特征工程和数据可视化、用机器学习模型制作步骤进行预测等。当然你也可以用更复杂的模型来做,文章涉及的数据,我也会提供给大家。

代码、数据获取
  • 1. 关注下方公众号,点击右上角;

  • 2. 在下方后台回复关键词「航班」快速下载:

您可能也对以下帖子感兴趣

文章有问题?点此查看未经处理的缓存