如何用 Python 实现堆和优先队列?
The following article is from 古明地觉的编程教室 Author 古明地觉
楔子
在工作中我们经常会遇到这样一个需求,就是获取可迭代对象中的前 K 个最大或最小的元素。我们之前介绍过排序,所以一个最简单的办法就是先排序,排完了再选择前 K 个元素即可。
data = [3, 1, 2, 5, 4]
# 选择前 3 个最大的元素
# 和前 3 个最小的元素
sorted_data = sorted(data)
print(f"TOP 3 MAX:{sorted_data[-3:]}")
print(f"TOP 3 MIN:{sorted_data[: 3]}")
"""
TOP 3 MAX:[3, 4, 5]
TOP 3 MIN:[1, 2, 3]
"""
# 或者这么做
sorted_data = sorted(data, reverse=True)
print(f"TOP 3 MAX:{sorted_data[: 3]}")
print(f"TOP 3 MIN:{sorted_data[-3:]}")
"""
TOP 3 MAX:[5, 4, 3]
TOP 3 MIN:[3, 2, 1]
"""
显然这是一种解决办法,但如果列表的长度非常大,排序就会带来不小的开销。而且有时我们只要前几个元素即可,比如长度为 10000 的列表,我们只想要前三个最大或最小的元素,那么此时对整个列表进行排序显然会存在性能上的浪费。
sorted 函数的时间复杂度是 O(NlogN)
所以接下来我们要介绍一个模块叫 heapq,通过该模块我们能快速地获取前 K 个元素。
import random
import heapq
data = [random.randint(10, 10000)
for _ in range(1000)]
# 获取前 3 个最大的元素
print(heapq.nlargest(3, data))
print(sorted(data, reverse=True)[: 3])
"""
[9972, 9972, 9966]
[9972, 9972, 9966]
"""
# 获取前 3 个最小的元素
print(heapq.nsmallest(3, data))
print(sorted(data)[: 3])
"""
[17, 21, 31]
[17, 21, 31]
"""
得到的结果是一样的,但是性能差异如何呢?我们来测一下:
可以看到性能差异还是蛮大的,并且列表长度越大,性能差距越明显。而根本原因就在于 sorted 会对列表进行全局排序,而 heapq 没有。
因此在获取前 K 个元素、并且 K 和列表长度差距比较大的时候,不妨使用 heapq 的 nsmallest 和 nlargest 函数,性能会有明显提升。
但如果 K 和列表长度相差不大,那么先 sorted 排序,再使用切片的方式会更好一些。
nsmallest 和 nlargest 这两个函数都接收 3 个参数,第一个参数表示要获取前多少个元素、第二个参数表示可迭代对象(一般是列表)、第三个参数是 key(和 sorted 函数里面的 key 含义相同)。
import random
import heapq
data = [{"number": random.randint(1, 10000)}
for _ in range(1000)]
# data 内部都是字典
# 获取前 3 个 number 字段的值最大的字典
print(heapq.nlargest(3, data, key=lambda x: x["number"]))
"""
[{'number': 9991}, {'number': 9984}, {'number': 9970}]
"""
print(sorted(data, key=lambda x: x["number"], reverse=True)[: 3])
"""
[{'number': 9991}, {'number': 9984}, {'number': 9970}]
"""
特别提示,如果 K 为 1,那么使用内置函数 min 和 max 是最佳选择。
data = [{"number": random.randint(1, 10000)}
for _ in range(1000)]
print(
heapq.nlargest(1, data, key=lambda x: x["number"])
) # [{'number': 9979}]
print(
sorted(data, key=lambda x: x["number"])[-1]
) # {'number': 9979}
print(
max(data, key=lambda x: x["number"])
) # {'number': 9979}
所以结论如下,当获取最大或最小元素的个数为 K,列表(可迭代对象)长度为 L 时:
K 等于 1,使用内置函数 max 或 min;
K 不等于 1、且远小于 L,使用 heapq 模块的 nlargest 或 nsmallest 函数;
K 和 L 差别不大,使用 sorted 先全局排序、然后再通过切片方式截取;
当然啦,以上都属于基础知识,比较简单。其实选择前 K 个元素就是我们常说的 TOP K 问题,如果只是单纯地想解决 TOP K 问题的话,上面已经给出了方案。这里我主要是想通过 TOP K 来引出一种数据结构,也就是堆。
堆是一种非常高效的数据结构,我们可以用它实现优先队列,堆实现的优先队列在元素入队、出队的时间复杂度上均为 O(logN)。
什么是堆?
首先堆本身是一棵树,如果这棵树是二叉树,那么实现的堆就被称为二叉堆。当然除了二叉堆,还有三叉堆等等,只不过二叉堆是一种最主流的堆的实现方式。因此,堆(二叉堆)就是一棵满足一些特殊性质的二叉树,那么问题来了,它都满足哪些性质呢?
堆是一棵完全二叉树;
堆里的每一个节点都大于等于(或小于等于)它的孩子节点;
如果每个节点都大于等于它的孩子节点,或者说每个节点都不大于它的父节点,那么这个堆就是大根堆;
如果每个节点都小于等于它的孩子节点,或者说每个节点都不小于它的父节点,那么这个堆就是小根堆;
注意:堆要求的是每个节点和其孩子节点之间要满足相应的大小关系,如果两个节点之间没有父子关系,那么它们谁大谁小无关紧要。比如图上的大根堆,第三层的最后一个节点是 13,可第四层的节点却都比它大,但它们之间没有父子关系,所以我们当前这个堆是成立的。
正是因为堆的这个性质,我们可以使用数组来表示堆,直接按照层序遍历的方式将每一层的元素放在数组中即可,比如:
[62, 41, 30, 28, 16, 22, 13, 19, 17, 15]
很明显,堆顶(数组索引为 0)的元素永远是值最大或最小的元素,如果构建的是大根堆,堆顶元素最大;构建的是小根堆,堆顶元素最小。
但是问题来了,如果我有一个节点,要如何找到它的父节点或者孩子节点呢?结论如下,假设当前节点所在的索引为 i:
父节点的索引:(i - 1) / 2;
左孩子节点的索引:2 * i + 1;
右孩子节点的索引:2 * i + 2;
我们以索引为 3 这个元素(值为 28)为例,它父节点的索引是 (3 - 1) / 2 = 1,也就是 41 这个元素;左孩子节点的索引是 2 * 3 + 1 = 7,也就是 19 这个元素;右孩子节点的索引显然是 8,也就是 17 这个元素。可以对照上图,检验一下是否有误,或者你也可以创建一个更大的堆,自己测试一下,但前提必须是完全二叉树才具备这个性质。
显然通过这种方式,我们就不需要两个指针来维持节点之间的父子关系了,使用数组索引即可,并且通过索引定位元素的速度也会更快。
向堆中添加元素(Sift Up)
我们来看看如何往堆中添加元素,首先堆是一个完全二叉树,往堆中添加一个元素,从树的层面来看,就是往最后一层的最右端添加一个元素,如果最后一层已经满了,那么就新加一层。如果从数组的层面来看,就相当于 append 一个元素。
假设我们添加一个 52,那么堆的示意图就会变成如下这样:
添加的过程非常简单,因为往堆里面添加一个节点,就是往数组里面 append 一个元素,但显然还没有结束。因为堆有两个性质,虽然我们添加元素之后仍然满足是一棵完全二叉树,但是不满足子节点都不大于它的父节点(这里我们构建的是大根堆),因为 52 明显大于它的父节点 16。
所以我们还要进行调整,将新添加的节点放到属于它的位置,具体过程也很简单:将该节点和它的父节点进行比较,如果比它的父节点大,那么就进行交换;交换之后再和它新的父节点进行比较,如果还大于新的父节点则继续交换,直到不大于为止。
所以从尾部添加的节点,一直向上浮动,直到找到属于它的位置,因此这个过程也被称为 Sift Up(上浮),具体示意图如下:
当交换之后,发现不大于它的父节点,那么该节点就可以停下来了。可能有人问,它父节点之上的节点该怎么办?比如爷爷节点。答案是不需要关心,因为大根堆的性质就是每个子节点不大于父节点。所以当新添加的节点不大于它的父节点时,也更不可能大于父节点之上的爷爷节点。
下面我们就编写代码实现一下:
class BinaryHeap:
"""
大根堆
"""
def __init__(self):
# 通过数组来模拟堆,为避免直接修改堆
# 这个数组不对外暴露,而是专门提供一个接口
self.__data = []
def show_heap(self):
return self.__data
@staticmethod
def get_parent(i: int):
# 根据节点的索引找到其父节点的索引
return (i - 1) // 2
def heappush(self, item: int):
# 往堆中添加一个节点,对于数组而言,直接 append 即可
self.__data.append(item)
# 但是还没有结束,添加完之后不满足堆的性质
# 我们还要对堆进行调整,由 sift_up 函数负责,它接收一个索引
# 表示对指定索引的节点进行上浮,显然这里是最后一个
self.sift_up(len(self.__data) - 1)
def sift_up(self, i: int):
# 对指定索引位置的节点进行上浮
while i > 0:
parent = self.get_parent(i)
# 当该元素不是根节点的时候,将其和父节点进行比较
# 如果大于父节点,两者进行交换
if self.__data[i] > self.__data[parent]:
self.__data[i], self.__data[parent] = self.__data[parent], self.__data[i]
# 交换之后该节点成为了父节点,然后将 parent 赋值为 i
# 因为它还要继续作为新的子节点和新的父节点比较
i = parent
else:
# 如果不大于父节点,说明该元素已经找到属于它的位置了
# 直接将循环结束掉即可
break
heap = BinaryHeap()
for item in [62, 41, 30, 28, 16, 22, 13, 19, 17, 15]:
heap.heappush(item)
print(heap.show_heap())
"""
[62, 41, 30, 28, 16, 22, 13, 19, 17, 15]
62
41 30
28 16 22 13
19 17 15
"""
# 这个时候再添加一个元素 52
heap.heappush(52)
print(heap.show_heap())
"""
[62, 52, 30, 28, 41, 22, 13, 19, 17, 15, 16]
62
52 30
28 41 22 13
19 17 15 16
"""
可以看到结果是没有问题的,以上我们添加元素就成功了,下面我们再来看看如何从堆中取出元素。
从堆中取出元素(Sift Down)
正如添加节点从堆底添加,取出节点只能从堆顶取出(也就是只能取根节点),不能取其它位置的节点。
但问题是,如果直接将堆顶的节点取走的话,就会形成两个独立的堆,两个堆的根节点分别是它的左右节点。于是我们还要手动将两个堆合并在一起,会比较麻烦,所以我们可以换个思路,将堆顶和堆底的元素进行交换。交换之后,弹出堆底的元素,这样就得到了最大值。
但该做法同时也破坏了堆的第二个性质,因为之前的堆底元素现在跑到了堆顶,肯定不满足父节点和子节点之间的大小关系,所以我们还要进行调整。
对于大根堆而言,将该节点和左右子节点中大的那一个进行比较,如果比子节点小,那么进行交换。交换之后再和它新的子节点进行比较,如果还小于新的子节点则继续交换,直到不小于为止。
所以从顶部的节点,一直向下沉,直到找到属于它的位置,因此这个过程也被称为 Sift Down(下沉),具体示意图如下:
注意:堆顶节点和堆底节点交换之后,就被弹出了,所以图中的 62 不再是堆节点,因此我们颜色刻意画的淡了一些。
下面完善一下之前的代码:
import random
class BinaryHeap:
"""
大根堆
"""
def __init__(self):
# 通过数组来模拟堆,为避免直接修改堆
# 这个数组不对外暴露,而是专门提供一个接口
self.__data = []
def show_heap(self):
return self.__data
@staticmethod
def get_parent(i: int):
# 根据节点的索引找到其父节点的索引
return (i - 1) // 2
@staticmethod
def get_left_child(i: int):
# 根据节点的索引找到左孩子节点的索引
return 2 * i + 1
@staticmethod
def get_right_child(i: int):
# 根据节点的索引找到右孩子节点的索引
return 2 * i + 2
def heappush(self, item: int):
self.__data.append(item)
self.sift_up(len(self.__data) - 1)
def sift_up(self, i: int):
while i > 0:
parent = self.get_parent(i)
if self.__data[i] > self.__data[parent]:
self.__data[i], self.__data[parent] = self.__data[parent], self.__data[i]
i = parent
else:
break
def heappop(self):
# 弹出堆顶元素
if len(self.__data) == 0:
raise ValueError("heap is empty")
# 只需要将第一个元素和最后一个元素进行交换,然后返回即可
self.__data[0], self.__data[-1] = self.__data[-1], self.__data[0]
result = self.__data.pop()
# 不过在返回之前,记得调整一下堆,由 sift_down 函数负责
# 此函数接收一个索引,表示对指定节点的索引进行下沉
# 显然这里是第一个
self.sift_down(0)
return result
def sift_down(self, i: int):
# 对索引为 i 的节点进行下沉,这里需要判断孩子节点是否存在的情况
# 如果左孩子节点的索引越界,说明该节点已经是叶子节点了
while self.get_left_child(i) < len(self.__data):
left_child = self.get_left_child(i)
right_child = self.get_right_child(i)
# 获取子节点大的那一个,注意:需要考虑右节点是否存在的情况
child = (right_child
if right_child < len(self.__data) and
self.__data[left_child] < self.__data[right_child]
else left_child)
# 将该节点和孩子节点进行比较,如果比孩子节点小,那么交换位置
# 继续和新的孩子节点进行比较
if self.__data[i] < self.__data[child]:
self.__data[i], self.__data[child] = self.__data[child], self.__data[i]
i = child
# 否则直接跳出循环
else:
break
heap = BinaryHeap()
data = [random.randint(1, 20) for _ in range(10)]
print(data)
"""
[3, 9, 3, 12, 12, 14, 5, 18, 20, 11]
"""
# 依次添加到堆中
for item in data:
heap.heappush(item)
# 从堆中弹出,由于每次都会弹出最大值
# 所以得到的新列表是降序排序的
sorted_data = [heap.heappop() for _ in range(10)]
print(sorted_data)
"""
[20, 18, 14, 12, 12, 11, 9, 5, 3, 3]
"""
显然是没有问题的,因此我们这里就实现了一个堆排序,只不过这个堆排序还不太完美,不完美之处有两个地方:
1. 默认是从大到小排序的,应该提供一个参数供外界选择究竟是从大到小还是从小到大;
2. 这里开辟了一个额外的数组,合适的做法应该是接收一个数组,然后原地排序;
那么下面我们完善一下堆排序。
import random
def get_left_child(i: int):
return 2 * i + 1
def get_right_child(i: int):
return 2 * i + 2
def sift_down_large(data, i: int, length: int):
# 大根堆下沉,但是参数多了一个 length,这是为啥呢?
# 首先我们之前是将堆顶和堆底的元素交换之后,就将堆底的元素弹出去了
# 以至于我们需要单独开辟一个数组去接收
# 但很明显,我们这里要求原地排序,那么交换之后的元素在堆底不可以动
# 因此每 sift_down 一次,length 要减去 1
while get_left_child(i) < length:
left_child = get_left_child(i)
right_child = get_right_child(i)
# 判断是否有右孩子,如果有右孩子
# 那么选择值较大的那一个孩子节点
child = (right_child
if right_child < length and
data[left_child] < data[right_child]
else left_child)
# 如果比孩子节点的值小,那么两者进行交换
# 因为大根堆要求父节点不小于子节点
if data[i] < data[child]:
data[i], data[child] = data[child], data[i]
i = child
else:
break
def sift_down_small(data, i: int, length: int):
# 小根堆下沉
while get_left_child(i) < length:
left_child = get_left_child(i)
right_child = get_right_child(i)
# 判断是否有右孩子,如果有右孩子
# 那么选择值较小的那一个孩子节点
child = (right_child
if right_child < length and
data[left_child] > data[right_child]
else left_child)
# 如果比孩子节点的值大,那么两者进行交换
# 因为大根堆要求父节点不大于子节点
if data[i] > data[child]:
data[i], data[child] = data[child], data[i]
i = child
else:
break
def heapify_large(data):
# 将一个数组整理成大根堆的形状
# 从最后一个非叶子节点进行 sift_down 即可
for i in range((len(data) - 1) >> 1, -1, -1):
sift_down_large(data, i, len(data))
def heapify_small(data):
# 将一个数组整理成小根堆的形状
for i in range((len(data) - 1) >> 1, -1, -1):
sift_down_small(data, i, len(data))
def heap_sort(data, reverse=False):
# 堆排序
# 首先将其整理成堆的形状
if reverse:
heapify_small(data)
else:
heapify_large(data)
# i 从最后一个元素开始
for i in range(len(data) - 1, -1, -1):
# 交换完之后的元素就不可以动了
data[0], data[i] = data[i], data[0]
# 并且也不能再参与后续的 sift_down
# 因此依旧调整堆,但是范围变了
# 比如第一次交换,那么最后一个元素为最大值
# sift_down 的时候,整个范围就是 [0: len(data) - 1]
# 同理第二次 sift_down 的时候,范围就是 [0: len(data) - 2]
if reverse:
sift_down_small(data, 0, i)
else:
sift_down_large(data, 0, i)
data = [random.randint(1, 20) for _ in range(10)]
print(data)
"""
[17, 16, 10, 3, 13, 15, 11, 9, 12, 9]
"""
heap_sort(data)
print(data)
"""
[3, 9, 9, 10, 11, 12, 13, 15, 16, 17]
"""
data = [random.randint(1, 20) for _ in range(10)]
print(data)
"""
[3, 1, 14, 20, 1, 10, 7, 8, 3, 15]
"""
heap_sort(data, reverse=True)
print(data)
"""
[20, 15, 14, 10, 8, 7, 3, 3, 1, 1]
"""
以上我们就实现了堆排序,那么问题来了,你觉得 heapq 模块里的 nlargest 和 nsmallest 是怎么实现的呢?
假设我们要选取 k 个最小的元素,那么首先我们可以从数组中截取前 k 个元素,构建一个大根堆。然后从第 k + 1 个元素开始遍历数组,如果当前元素大于等于堆顶元素,那么它肯定就不是前 k 小的元素,如果当前元素小于堆顶的元素,那么两者进行交换,然后进行一次 Sift Down 操作。当数组遍历完毕之后,堆中的 k 个元素就是最小的前 k 个元素。同理,如果想选择前 k 个最大的元素,那么就构建一个小根堆。
或者将整个数组构建成一个堆,然后heappop k 次即可,这样也能选择前 k 个元素。
优先队列
其实在排序的时候,堆排序不是效率最高的排序,它比三路快排要慢一些。但是堆存在的目的绝不仅仅是为了排序,由于它可以动态添加元素、删除元素,并且时间复杂度都为 O(logN) 级别,所以堆的强大之处就在于非常适合实现优先队列。
事实上 heapq 也已经为我们提供了堆的相关操作:
"""
heapq.heapify(data)
将数组 data 整理成堆的形状,只支持小根堆
heapq.heappush(data, item)
向堆中添加元素,并维护堆的形状
要求 data 已经是一个小根堆
heapq.heappop(data, item)
从堆中弹出元素,并维护堆的形状
要求 data 已经是一个小根堆
"""
而 Python 的优先队列,底层就是借助于 heapq 实现的,我们看一下:
里面的 item 是一个元组,第一个元素是优先级(值越小、优先级越高),第二个元素是具体的数据,这就是优先队列,是不是比你想象中的要简单许多呢?
小结
堆是一种非常高效的数据结构,它可以动态地添加、删除元素,并且时间复杂度均为 O(logN) 级别。这个特性就决定了它非常适合实现优先队列,维护一个堆,在往堆中添加元素的时候,只需要加一个优先级即可,也就是将优先级和数据组合成一个元组添加到堆中。如果构建的是小根堆,值越小、优先级越高;构建的是大根堆,值越大,优先级越高。
另外我们说,当获取最大值或最小值时,推荐使用内置函数 max 和 min。但如果数组 data 一直在动态变化,并且要随时获取里面的最大值或最小值,那么相比使用内置函数 max、min,更好的做法是将 data 维护成一个堆。然后添加元素使用 heappush,获取元素直接 data[0] 即可。因为这整体是一个 O(logN) 的操作,而是 min、max 是一个 O(N) 的操作。
最后,堆可以用来实现排序,效率也很高,但相比三路快排还差了那么一点。但堆存在的目的不在于排序,而在于它的动态性。优先队列就不必说了,还有 TOP K,三路快排和堆都可以实现 TOP K,但前者要求数据必须一次性全部给出,而堆则没有这个要求,换句话说堆可以满足对流式数据的处理。
比如 1T 的文件,一行就是一串数字,如果想在 16G 内存的机器上查找最大的 100 个数字,用快排是无法实现的,因为无法将文件一次性加载到内存中。
但堆可以实现这个需求,先读取 100 行维护一个小根堆,然后从 101 行继续读取,依次和堆顶进行比较。如果小于堆顶元素,那么它一定不是前 100 个最大的数字;如果大于堆顶元素,那么就替换掉,然后 sift_down,维护堆的形状。这样总有一刻,能够选出最大的 100 个数字。
所以当数组 data 不断地变化时,将其维护成一个堆,然后通过 heappush 添加元素、heappop 弹出堆顶元素、data[0] 获取堆顶元素,往往是最佳选择。并且添加和弹出都是 logN 级别的时间复杂度,也正是这个特性,它适合优先队列以及流式数据(数据无法一次性全部给出)的处理。
还不过瘾?试试它们