其他
讲透一个强大的算法模型,LSTM!!
LSTM 可以捕捉序列中的长期依赖关系,因此特别适合处理时间序列、自然语言处理等领域中的序列数据。
LSTM通过引入称为“门控机制”的结构来控制信息的流动,确保网络能够选择性地记住和遗忘信息。
LSTM的基本结构
1. 遗忘门
遗忘门决定前一时间步的记忆哪些需要被遗忘,哪些需要保留。
是遗忘门的输出。 是权重矩阵, 是偏置项。 是前一个时间步的隐藏状态, 是当前时间步的输入。 是 Sigmoid 函数,它将输出压缩到 (0,1) 区间。
2. 输入门(Input Gate)
输入门包括两个部分:
输入门 决定哪些信息会更新到记忆单元。 新的候选信息 被添加到记忆单元。
3. 输出门(Output Gate)
案例分享
数据准备
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import yfinance as yf
from sklearn.preprocessing import MinMaxScaler
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, LSTM
# 下载股票数据 (以苹果公司AAPL为例)
df = yf.download('AAPL', start='2010-01-01', end='2021-01-01')
data = df.filter(['Close'])
dataset = data.values
training_data_len = int(np.ceil( len(dataset) * .8 ))
scaler = MinMaxScaler(feature_range=(0,1))
scaled_data = scaler.fit_transform(dataset)
train_data = scaled_data[0:int(training_data_len), :]
x_train = []
y_train = []
for i in range(60, len(train_data)):
x_train.append(train_data[i-60:i, 0])
y_train.append(train_data[i, 0])
x_train, y_train = np.array(x_train), np.array(y_train)
x_train = np.reshape(x_train, (x_train.shape[0], x_train.shape[1], 1))
构建LSTM模型
# 创建LSTM模型
model = Sequential()
model.add(LSTM(units=50, return_sequences=True, input_shape=(x_train.shape[1], 1)))
model.add(LSTM(units=50, return_sequences=False))
model.add(Dense(units=25))
model.add(Dense(units=1))
# 编译模型
model.compile(optimizer='adam', loss='mean_squared_error')
# 训练模型
model.fit(x_train, y_train, batch_size=64, epochs=10)
创建测试数据集并进行预测
# 创建测试数据集
test_data = scaled_data[training_data_len - 60:, :]
x_test = []
y_test = dataset[training_data_len:, :]
for i in range(60, len(test_data)):
x_test.append(test_data[i-60:i, 0])
x_test = np.array(x_test)
x_test = np.reshape(x_test, (x_test.shape[0], x_test.shape[1], 1))
# 使用模型进行预测
predictions = model.predict(x_test)
# 将数据反归一化
predictions = scaler.inverse_transform(predictions)
可视化结果
最后,我们绘制实际股价和预测股价的图表进行对比。
# 绘制数据
train = data[:training_data_len]
valid = data[training_data_len:]
valid['Predictions'] = predictions
# 可视化数据
plt.figure(figsize=(16,8))
plt.title('Model')
plt.xlabel('Date')
plt.ylabel('Close Price USD ($)')
plt.plot(train['Close'])
plt.plot(valid[['Close', 'Predictions']])
plt.legend(['Train', 'Val', 'Predictions'], loc='lower right')
plt.show()
最后
—
「进群方式:加我微信,备注 “python”」
往期回顾
Fashion-MNIST 服装图片分类-Pytorch实现