深度学习100问-10:如何部署一个轻量级的深度学习项目?
深度学习100问
Author:louwill
Machine Learning Lab
无论是写论文还是打比赛,我们的深度学习模型始终局限于一种实验的状态。作为一名算法工程师,除了搭建网络和训练模型之外,将深度学习模型应用落地到生产环境也是我们应该要做的。虽然深度学习算法工程师不是后端开发工程师,但懂一些后端开发技术是尤为必要的。
本文以ResNet50预训练模型为例,旨在展示一个轻量级的深度学习模型部署,写一个较为简单的图像分类的REST API。主要技术框架为Keras+Flask+Redis。其中Keras作为模型框架、Flask作为后端Web框架、Redis则是方便以键值形式存储图像的数据库。各主要package版本:
tensorflow 1.14
keras 2.2.4
flask 1.1.1
redis 3.3.8
先简单说一下Web服务,一个Web应用的本质无非就是客户端发送一个HTTP请求,然后服务器收到请求后生成一个HTML文档作为响应返回给客户端的过程。在部署深度学习模型时,大多时候我们不需要搞一个前端页面出来,一般是以REST API的形式提供给开发调用。那么什么是API呢?很简单,如果一个URL返回的不是HTML,而是机器能直接解析的数据,这样的一个URL就可以看作是一个API。
先开启Redis服务:
redis-server
定义一些配置参数:
IMAGE_WIDTH = 224
IMAGE_HEIGHT = 224
IMAGE_CHANS = 3
IMAGE_DTYPE = "float32"
IMAGE_QUEUE = "image_queue"
BATCH_SIZE = 32
SERVER_SLEEP = 0.25
CLIENT_SLEEP = 0.25
指定输入图像大小、类型、batch_size大小以及Redis图像队列名称。
然后创建Flask对象实例,建立Redis数据库连接:
app = flask.Flask(__name__)
db = redis.StrictRedis(host="localhost", port=6379, db=0)
model = None
因为图像数据作为numpy数组不能直接存储到Redis中,所以图像存入到数据库之前需要将其序列化编码,从数据库取出时再将其反序列化解码即可。分别定义编码和解码函数:
def base64_encode_image(img):
return base64.b64encode(img).decode("utf-8")
def base64_decode_image(img, dtype, shape):
if sys.version_info.major == 3:
img = bytes(img, encoding="utf-8")
img = np.frombuffer(base64.decodebytes(img), dtype=dtype)
img = img.reshape(shape)
return img
另外待预测图像还需要进行简单的预处理,定义预处理函数如下:
def prepare_image(image, target):
# if the image mode is not RGB, convert it
if image.mode != "RGB":
image = image.convert("RGB")
# resize the input image and preprocess it
image = image.resize(target)
image = img_to_array(image)
# expand image as one batch like shape (1, c, w, h)
image = np.expand_dims(image, axis=0)
image = imagenet_utils.preprocess_input(image)
# return the processed image
return image
准备工作完毕之后,接下来就是主要的两大部分:模型预测部分和app后端相应部分。先定义模型预测函数如下:
def classify_process():
# 导入模型
print("* Loading model...")
model = ResNet50(weights="imagenet")
print("* Model loaded")
while True:
# 从数据库中创建预测图像队列
queue = db.lrange(IMAGE_QUEUE, 0, BATCH_SIZE - 1)
imageIDs = []
batch = None
# 遍历队列
for q in queue:
# 获取队列中的图像并反序列化解码
q = json.loads(q.decode("utf-8"))
image = base64_decode_image(q["image"], IMAGE_DTYPE,
(1, IMAGE_HEIGHT, IMAGE_WIDTH, IMAGE_CHANS))
# 检查batch列表是否为空
if batch is None:
batch = image
# 合并batch
else:
batch = np.vstack([batch, image])
# 更新图像ID
imageIDs.append(q["id"])
if len(imageIDs) > 0:
print("* Batch size: {}".format(batch.shape))
preds = model.predict(batch)
results = imagenet_utils.decode_predictions(preds)
# 遍历图像ID和预测结果并打印
for (imageID, resultSet) in zip(imageIDs, results):
# initialize the list of output predictions
output = []
# loop over the results and add them to the list of
# output predictions
for (imagenetID, label, prob) in resultSet:
r = {"label": label, "probability": float(prob)}
output.append(r)
# 保存结果到数据库
db.set(imageID, json.dumps(output))
# 从队列中删除已预测过的图像
db.ltrim(IMAGE_QUEUE, len(imageIDs), -1)
time.sleep(SERVER_SLEEP)
然后定义app服务:
@app.route("/predict", methods=["POST"])
def predict():
# 初始化数据字典
data = {"success": False}
# 确保图像上传方式正确
if flask.request.method == "POST":
if flask.request.files.get("image"):
# 读取图像数据
image = flask.request.files["image"].read()
image = Image.open(io.BytesIO(image))
image = prepare_image(image, (IMAGE_WIDTH, IMAGE_HEIGHT))
# 将数组以C语言存储顺序存储
image = image.copy(order="C")
# 生成图像ID
k = str(uuid.uuid4())
d = {"id": k, "image": base64_encode_image(image)}
db.rpush(IMAGE_QUEUE, json.dumps(d))
# 运行服务
while True:
# 获取输出结果
output = db.get(k)
if output is not None:
output = output.decode("utf-8")
data["predictions"] = json.loads(output)
db.delete(k)
break
time.sleep(CLIENT_SLEEP)
data["success"] = True
return flask.jsonify(data)
Flask使用Python装饰器在内部自动将请求的URL和目标函数关联了起来,这样方便我们快速搭建一个Web服务。
服务搭建好了之后我们可以用一张图片来测试一下效果:
curl -X POST -F image=@test.jpg 'http://127.0.0.1:5000/predict'
模型端的返回:
预测结果返回:
最后我们可以给搭建好的服务进行一个压力测试,看看服务的并发等性能如何,定义一个压测文件stress_test.py 如下:
from threading import Thread
import requests
import time
# 请求的URL
KERAS_REST_API_URL = "http://127.0.0.1:5000/predict"
# 测试图片
IMAGE_PATH = "test.jpg"
# 并发数
NUM_REQUESTS = 500
# 请求间隔
SLEEP_COUNT = 0.05
def call_predict_endpoint(n):
# 上传图像
image = open(IMAGE_PATH, "rb").read()
payload = {"image": image}
# 提交请求
r = requests.post(KERAS_REST_API_URL, files=payload).json()
# 确认请求是否成功
if r["success"]:
print("[INFO] thread {} OK".format(n))
else:
print("[INFO] thread {} FAILED".format(n))
# 多线程进行
for i in range(0, NUM_REQUESTS):
# 创建线程来调用api
t = Thread(target=call_predict_endpoint, args=(i,))
t.daemon = True
t.start()
time.sleep(SLEEP_COUNT)
time.sleep(300)
测试效果如下:
有学术和技术问题的同学可以加我微信进入机器学习实验室读者交流群。加微信后说明来意,最好做个简单的自我介绍,让我有个印象。
参考资料:
往期精彩:
深度学习100问-9:为什么EfficientNet号称是最好的分类网络?
深度学习100问-8:什么是Batch Normalization?
深度学习100问-3:深度学习应掌握哪些Linux开发技术?
一个算法工程师的成长之路
长按二维码.关注机器学习实验室