200 行 Rust 代码编写一个向量搜索库,代码已开源!
由于 AI 与机器学习的快速发展,如今向量数据库已无处不在。虽然向量搜索能够支持复杂的 AI 或机器学习应用程序,但这个概念本身并不难。在文本中,我们来一起看看向量数据库的工作原理,并使用不到 200 行 Rust 代码构建一个简单的向量搜索库。
GitHub 完整的代码:https://github.com/fennel-ai/fann/?ref=fennel.ai
本文使用的方法基于流行库 annoy 中使用的一种名叫“Locality Sensitive Hashing”(局部敏感哈希,LSH)的系列算法。本文的目的不是介绍一种新奇的算法或库,而是分享如何编写一个向量搜索。
在写代码之前,首先我们来介绍一下什么是向量搜索。
向量(又名嵌入)介绍
我们很难使用传统的数据库表示和查询文档、图像、视频等复杂的非结构化数据,尤其是无法查找“相似”数据。那么,YouTube 又是如何选择下一个播放视频的呢?Spotify 又是如何根据当前歌曲自定义播放列表的呢?
2010 年代 AI 的进步(从 Word2Vec 和 GloVe 开始)使我们能够构建这些对象的语义表示,即使用笛卡尔空间中的点表示这些对象。假设一个视频被映射到点 [0.1, -5.1, 7.55],另一个被映射到点 [5.3, -0.1, 2.7]。这些机器学习算法的神奇之处在于,这些点的选择保留了语义信息:两个视频的相似度越高,它们的向量之间的距离越小。
请注意,这些向量(或更专业的叫法是“嵌入”)不一定是三维的,它们可以、而且通常确实位于更高维的空间中(比如 128 维或 750 维)。而且距离也不一定是欧几里得距离,其他形式的距离也可以,比如点积。无论采用哪种形式,重要的是它们之间的距离对应于相似性。
下面,假设我们可以访问所有 YouTube 视频的此类向量。那么我们如何找到与某个视频相似度最高的其他视频呢?很简单。遍历所有视频,计算它们之间的距离,并选择距离最小的视频,这个过程又称之为查找视频的“最近邻居”。这种方法确实有效,除了一个问题,线性O(N)搜索的成本过高。所以,我们需要一种更快的次线性方法来查找视频的最近邻居。这通常是不可能的,必须付出一些代价。
然而,在实际情况下,我们并不需要找到“最近”的视频,只要足够接近就可以了。这就是“近似最近邻”(Approximate Nearest Neighbor,ANN)算法,又称为向量搜索。我们的目标是通过次线性方法找到空间中任何足够近的最近邻。那么,实际如何实现呢?
如何找到近似最近邻?
向量搜索算法背后的基本思想都是相同的:通过一些预处理来识别彼此足够接近的点(有点类似于建立索引)。查询时,使用这个“索引”来排除大部分点。只留下少量点,然后再进行线性扫描。
然而,这个简单的想法有很多可以实现的方法。目前有几种最先进的向量搜索算法,例如 HNSW(Hierarchical Navigable Small World,分层的可导航小世界)是一种图,它连接了互相接近的顶点,而且还保存了距离固定入口点的长距离边。此外还有一些开源项目,例如 Facebook 的 FAISS,以及一些高可用向量数据库的 PaaS 产品,例如 Pinecone 和 Weaviate。
在本文中,假设有“N”个指定的点,我们要在这些点上构建一个简化的向量搜索索引,步骤如下:
随机抽取 2 个向量 A 和 B。
计算这两个向量之间的中点 C。
构建一个通过 C 并垂直于连接 A 和 B 的线段的超平面(即更高维度的“线”)。
根据相对于超平面的位置:“上方”或“下方”,将所有向量分为两组。
分别对两组向量做以下处理:如果组的大小大于可配置参数“最大节点数”,则在该组之上递归调用此过程,构建子树。否则,使用所有向量(或它们的唯一 ID)构建一个叶节点。
我们将使用这个随机过程来构建一棵树,其中的每个节点都是一个超平面定义,左边的子树是超平面“下方”的所有向量,右边的子树是超平面“上方”的所有向量。向量集不断递归拆分,直到叶节点包含的向量不超过“最大节点数”。下图是包含5个点的示例:
图1:利用随机超平面分割空间
我们随机选择向量 A1=(4,2)和 B1=(5,7),二者的中点是 (4.5,4.5),我们构建一条线,要求通过中点且垂直于线 (A1, B1)。这条线 x + 5y=27(图中蓝色的线)给了我们两组向量:一组包含 2 个向量,另一组包含 4 个。假设“最大节点数”设置为 2。那么,第一组向量不需要进一步拆分,而第二组向量需要进一步递归——如图所示,我们又构建了一个红色超平面。对于大型数据集,如此重复拆分可以将超空间拆分为几个不同的区域,如下所示。
图2:被许多超平面分割的空间(来自 https://t.co/K0Xlht8GwQ,作者:Erik Bernhardsson)
此处的每个区域代表一个叶节点,而且感觉足够接近的点很可能最终会出现在同一个叶节点中。接下来,给定一个查询点,我们可以在对数时间内自上而下遍历树,找到它所属的叶子,然后对叶节点中的所有(数量不多)点进行线性扫描。很明显,这个算法并不是万无一失的,即便是距离非常近的两个点也完全有可能被一个超平面分开,最终导致二者彼此相距很远。但是,这个问题可以通过构建多棵独立的树来解决,这样,如果两个点之间的距离足够近,它们出现在某些树同一个叶节点中的概率就很高。在查询时,我们自上而下遍历所有树,定位相关的叶节点,然后求所有叶节点的所有候选节点的并集,并对所有节点进行线性扫描。
好了,理论的介绍到此为止。下面,我们来编写代码。
首先,我们为 Rust 的 Vector 类型定义一些工具函数,包括求点积、均值、哈希值以及平方 L2 距离。感谢 Rust 的优雅的类型系统,我们可以传递泛型类型参数 N,以在编译时强制索引中的所有向量具有相同的维度。
#[derive(Eq, PartialEq, Hash)]
pub struct HashKey<const N: usize>([u32; N]);
#[derive(Copy, Clone)]
pub struct Vector<const N: usize>(pub [f32; N]);
impl<const N: usize> Vector<N> {
pub fn subtract_from(&self, vector: &Vector<N>) -> Vector<N> {
let mapped = self.0.iter().zip(vector.0).map(|(a, b)| b - a);
let coords: [f32; N] = mapped.collect::<Vec<_>>().try_into().unwrap();
return Vector(coords);
}
pub fn avg(&self, vector: &Vector<N>) -> Vector<N> {
let mapped = self.0.iter().zip(vector.0).map(|(a, b)| (a + b) / 2.0);
let coords: [f32; N] = mapped.collect::<Vec<_>>().try_into().unwrap();
return Vector(coords);
}
pub fn dot_product(&self, vector: &Vector<N>) -> f32 {
let zipped_iter = self.0.iter().zip(vector.0);
return zipped_iter.map(|(a, b)| a * b).sum::<f32>();
}
pub fn to_hashkey(&self) -> HashKey<N> {
// f32 in Rust doesn't implement hash. We use bytes to dedup. While it
// can't differentiate ~16M ways NaN is written, it's safe for us
let bit_iter = self.0.iter().map(|a| a.to_bits());
let data: [u32; N] = bit_iter.collect::<Vec<_>>().try_into().unwrap();
return HashKey::<N>(data);
}
pub fn sq_euc_dis(&self, vector: &Vector<N>) -> f32 {
let zipped_iter = self.0.iter().zip(vector.0);
return zipped_iter.map(|(a, b)| (a - b).powi(2)).sum();
}
}
构建好这些核心工具函数之后,下面我们来定义超平面:
struct HyperPlane<const N: usize> {
coefficients: Vector<N>,
constant: f32,
}
impl<const N: usize> HyperPlane<N> {
pub fn point_is_above(&self, point: &Vector<N>) -> bool {
self.coefficients.dot_product(point) + self.constant >= 0.0
}
}
接下来,我们来生成随机超平面,构建最近邻树的森林。我们应该如何表示树中的点呢?
我们可以直接将 D 维向量存储在叶节点中。但是,如果这个维度(D)非常大,那么会导致内存碎片化(严重影响到性能),而且当多棵树引用同一个向量时,还会造成森林重复保存。因此,我们将向量存储在全局可访问的连续位置中,并在叶节点中保存类型为 usize 的索引(在 64 位系统上仅占用 8 字节,相反如果直接保存四维向量,每个维度的 f32 类型就会占用 4 字节)。以下是用于表示树内部节点和叶节点的数据类型。
enum Node<const N: usize> {
Inner(Box<InnerNode<N>>),
Leaf(Box<LeafNode<N>>),
}
struct LeafNode<const N: usize>(Vec<usize>);
struct InnerNode<const N: usize> {
hyperplane: HyperPlane<N>,
left_node: Node<N>,
right_node: Node<N>,
}
pub struct ANNIndex<const N: usize> {
trees: Vec<Node<N>>,
ids: Vec<i32>,
values: Vec<Vector<N>>,
}
那么,我们应该如何找到正确的超平面呢?
我们从索引中随机采样两个,分别对应于向量 A 和 B,计算 n = A - B,并找到 A 和 B 的中点(point_on_plane)。存储超平面使用的结构包含系数(向量 n)和常量(n 和 point_on_plane 的点积),形式为:n(x-x0) = nx - nx0。我们可以求向量和 n 之间的点积,然后减去常数,就可以将向量放在超平面的“上方”或“下方”。由于树中的内部节点包含超平面定义,叶节点包含向量 ID,因此我们可以使用 ADT 对树进行类型检查:
impl<const N: usize> ANNIndex<N> {
fn build_hyperplane(
indexes: &Vec<usize>,
all_vecs: &Vec<Vector<N>>,
) -> (HyperPlane<N>, Vec<usize>, Vec<usize>) {
let sample: Vec<_> = indexes
.choose_multiple(&mut rand::thread_rng(), 2)
.collect();
// cartesian eq for hyperplane n * (x - x_0) = 0
// n (normal vector) is the coefs x_1 to x_n
let (a, b) = (*sample[0], *sample[1]);
let coefficients = all_vecs[a].subtract_from(&all_vecs[b]);
let point_on_plane = all_vecs[a].avg(&all_vecs[b]);
let constant = -coefficients.dot_product(&point_on_plane);
let hyperplane = HyperPlane::<N> {
coefficients,
constant,
};
let (mut above, mut below) = (vec![], vec![]);
for &id in indexes.iter() {
if hyperplane.point_is_above(&all_vecs[id]) {
above.push(id)
} else {
below.push(id)
};
}
return (hyperplane, above, below);
}
}
接下来,我们定义递归过程,根据索引时的“最大节点数”构建树:
impl<const N: usize> ANNIndex<N> {
fn build_a_tree(
max_size: i32,
indexes: &Vec<usize>,
all_vecs: &Vec<Vector<N>>,
) -> Node<N> {
if indexes.len() <= (max_size as usize) {
return Node::Leaf(Box::new(LeafNode::<N>(indexes.clone())));
}
let (plane, above, below) = Self::build_hyperplane(indexes, all_vecs);
let node_above = Self::build_a_tree(max_size, &above, all_vecs);
let node_below = Self::build_a_tree(max_size, &below, all_vecs);
return Node::Inner(Box::new(InnerNode::<N> {
hyperplane: plane,
left_node: node_below,
right_node: node_above,
}));
}
}
请注意,由于该算法不允许重复,构建两点之间的超平面要求这两个点是唯一的,因此在建立索引之前,我们必须去除向量集中的重复项。
因此整个索引(树的森林)可以这样构建:
impl<const N: usize> ANNIndex<N> {
fn deduplicate(
vectors: &Vec<Vector<N>>,
ids: &Vec<i32>,
dedup_vectors: &mut Vec<Vector<N>>,
ids_of_dedup_vectors: &mut Vec<i32>,
) {
let mut hashes_seen = HashSet::new();
for i in 1..vectors.len() {
let hash_key = vectors[i].to_hashkey();
if !hashes_seen.contains(&hash_key) {
hashes_seen.insert(hash_key);
dedup_vectors.push(vectors[i]);
ids_of_dedup_vectors.push(ids[i]);
}
}
}
pub fn build_index(
num_trees: i32,
max_size: i32,
vecs: &Vec<Vector<N>>,
vec_ids: &Vec<i32>,
) -> ANNIndex<N> {
let (mut unique_vecs, mut ids) = (vec![], vec![]);
Self::deduplicate(vecs, vec_ids, &mut unique_vecs, &mut ids);
// Trees hold an index into the [unique_vecs] list which is not
// necessarily its id, if duplicates existed
let all_indexes: Vec<usize> = (0..unique_vecs.len()).collect();
let trees: Vec<_> = (0..num_trees)
.map(|_| Self::build_a_tree(max_size, &all_indexes, &unique_vecs))
.collect();
return ANNIndex::<N> {
trees,
ids,
values: unique_vecs,
};
}
}
查询时间
索引建立之后,下一步我们如何使用它搜索某棵树中输入向量的 K 个近似最近邻?我们的超平面存储在非叶节点处,因此我们可以从树的根开始搜索:“这个向量是在这个超平面的上方还是下方?”我们可以用点积计算,复杂度为 O(D)。接下来,我们可以根据响应,递归搜索左子树或右子树,直到找到叶节点。请记住,叶节点中存储的向量数最大为“最大节点数”,这些向量位于输入向量的近似邻域中(它们将落入所有超平面下超空间的同一分区中)。如果这个叶节点的向量索引数超过 K,我们就需要根据到输入向量的 L2 距离,对所有这些向量进行排序,并返回最接近的 K 个向量。
假设我们的索引最后得到了一棵平衡树,对于维度 D、向量数量 N 和最大节点数 M << N,搜索耗费的时间为 O(Dlog(N) + DM + Mlog(M)),最坏的情况下我们需要比较 log(N) 个超平面(即树的高度)才能找到叶节点,每次比较耗费的时间为 O(D) 个点积,计算一个叶节点的所有候选向量的 L2 距离需要 O(DM),最后在 O(Mlog(M)) 时间内对它们进行排序,并返回 前 K 个。
但是,如果我们找到的叶节点少于 K 个向量,会发生什么情况?如果最大节点数太小,或超平面分割不均匀,则会导致某个子树中的向量非常少。为了解决这个问题,我们可以在树搜索中添加一些简单的回溯功能。例如,如果返回的候选向量数量不足,我们可以在内部节点再递归调用一次,访问另一个分枝。代码如下:
1impl<const N: usize> ANNIndex<N> {
2 fn tree_result(
3 query: Vector<N>,
4 n: i32,
5 tree: &Node<N>,
6 candidates: &mut HashSet<usize>,
7 ) -> i32 {
8 // take everything in node, if still needed, take from alternate subtree
9 match tree {
10 Node::Leaf(box_leaf) => {
11 let leaf_values = &(box_leaf.0);
12 let num_candidates_found = min(n as usize, leaf_values.len());
13 for i in 0..num_candidates_found {
14 candidates.insert(leaf_values[i]);
15 }
16 return num_candidates_found as i32;
17 }
18 Node::Inner(inner) => {
19 let above = (*inner).hyperplane.point_is_above(&query);
20 let (main, backup) = match above {
21 true => (&(inner.right_node), &(inner.left_node)),
22 false => (&(inner.left_node), &(inner.right_node)),
23 };
24 match Self::tree_result(query, n, main, candidates) {
25 k if k < n => {
26 k + Self::tree_result(query, n - k, backup, candidates)
27 }
28 k => k,
29 }
30 }
31 }
32 }
33}
请注意,为了进一步优化递归调用,我们还可以保存子树中的向量总数,并在每个内部节点中保存指向所有叶节点的指针列表,但为了简单起见,此处略过。
将此搜索扩展到一片森林非常简单,只需从所有树木中单独收集前 K 个候选者,按距离对它们进行排序,然后返回总体的前 K 个匹配项。请注意,随着树的数量增加,内存的使用和搜索时间都会呈线性增长,但我们可以获得更好的“更近”邻居,因为我们收集了来自不同树的候选者。
1impl<const N: usize> ANNIndex<N> {
2 pub fn search_approximate(
3 &self,
4 query: Vector<N>,
5 top_k: i32,
6 ) -> Vec<(i32, f32)> {
7 let mut candidates = HashSet::new();
8 for tree in self.trees.iter() {
9 Self::tree_result(query, top_k, tree, &mut candidates);
10 }
11 candidates
12 .into_iter()
13 .map(|idx| (idx, self.values[idx].sq_euc_dis(&query)))
14 .sorted_by(|a, b| a.1.partial_cmp(&b.1).unwrap())
15 .take(top_k as usize)
16 .map(|(idx, dis)| (self.ids[idx], dis))
17 .collect()
18 }
19}
以上就是一个简单的向量搜索索引的 200 行 Rust 代码!
基准测试
其实,这个实现非常简单,以至于我们一度怀疑相较于最先进的算法,这个算法非常拙略。因此,我们做了一些基准测试来证实我们的怀疑。
算法可以通过延迟和质量来评估。质量通常通过召回率来衡量:求近似最近邻搜索返回结果与线性搜索返回结果的百分比。严格来说,有时返回的结果并不在前 K,但非常接近 K,因此也没关系,为了量化这一点,我们可以计算邻居的平均欧几里德距离,并将其与平均距离进行比较。
测量延迟很简单,我们可以检查执行查询所需的时间。
所有的基准测试都是在搭载了 2.3 GHz 四核英特尔酷睿 i5 处理器的机器上运行的,使用了 999,994 条维基百科数据 FastText 嵌入,300 个维度。我们设置返回“前 K 个”查询结果。
作为参考,我们拿 FAISS HNSW 索引(ef_search=16、ef_construction=40、max_node_size=15)与 Rust 索引的小型版本(num_trees=3、max_node_size=15)进行了比较。我们使用 Rust 实现了穷举搜索,而 FAISS 库有 HNSW 的 C++ 源代码。原始延迟数突出了近似搜索的优势:
算法 | 延迟 | QPS |
穷举搜索 | 675.25 毫秒 | 1.48 |
FAISS HNSW 索引 | 355.36 微秒 | 2814.05 |
自定义 Rust 索引 | 112.02 微秒 | 8926.98 |
两种近似最近邻方法快了三个数量级。看起来在这个微型基准测试中,我们的 Rust 实现比 HNSW 快 3 倍。
在分析质量时,我们的测试数据为:返回“river”的 10 个最近的邻居。
穷举搜索 | FAISS HNSW 索引 | 自定义 Rust 索引 | |||
单词 | 欧几里得距离 | 单词 | 欧几里得距离 | 单词 | 欧几里得距离 |
river | 0 | river | 0 | river | 0 |
River | 1.39122 | River | 1.39122 | creek | 1.63744 |
rivers | 1.47646 | river- | 1.58342 | river. | 1.73224 |
river- | 1.58342 | swift-flowing | 1.62413 | lake | 1.75655 |
swift-flowing | 1.62413 | flood-swollen | 1.63798 | sea | 1.87368 |
creek | 1.63744 | river.The | 1.68156 | up-river | 1.92088 |
flood-swollen | 1.63798 | river-bed | 1.68510 | shore | 1.92266 |
river.The | 1.68156 | unfordable | 1.69245 | brook | 2.01973 |
river-bed | 1.68510 | River- | 1.69512 | hill | 2.03419 |
unfordable | 1.69245 | River.The | 1.69539 | pond | 2.04376 |
再来一个例子,这次我们搜索“war”。
穷举搜索 | FAISS HNSW 索引 | 自定义 Rust 索引 | |||
单词 | 欧几里得距离 | 单词 | 欧几里得距离 | 单词 | 欧几里得距离 |
war | 0 | war | 0 | war | 0 |
war-- | 1.38416 | war-- | 1.38416 | war-- | 1.38416 |
war--a | 1.44906 | war--a | 1.44906 | wars | 1.45859 |
wars | 1.45859 | wars | 1.45859 | quasi-war | 1.59712 |
war--and | 1.45907 | war--and | 1.45907 | war-footing | 1.69175 |
war.It | 1.46991 | war.It | 1.46991 | check-mate | 1.74982 |
war.In | 1.49632 | war.In | 1.49632 | ill-begotten | 1.75498 |
unwinable | 1.51296 | unwinable | 1.51296 | subequent | 1.76617 |
war.And | 1.51830 | war.And | 1.51830 | humanitary | 1.77464 |
hostilities | 1.54783 | Iraw | 1.54906 | execution | 1.77992 |
我们使用的这个语料库包含 999,994 个单词,我们可视化了在 HNSW 和我们的自定义 Rust 索引下,每个单词与其前 K=20 个近似邻居的平均欧几里得距离的分布:
最先进的 HNSW 索引提供的邻居确实比我们的示例索引更近,其平均距离和中值距离分别为 1.31576 和 1.20230(而与我们的示例索引分别为 1.47138 和 1.35620)。在大小为 1 万的子集上,HNSW 的前 K=20 的召回率为 58.2%,而我们的示例索引在不同配置下的测试结果如下:
树的数量 | 最大节点数 | 平均运行时间 | QPS | 召回率 |
3 | 5 | 129.48微秒 | 7723 | 0.11465 |
3 | 15 | 112.02微秒 | 8297 | 0.11175 |
3 | 30 | 114.48微秒 | 8735 | 0.09265 |
9 | 5 | 16.77毫秒 | 60 | 0.22095 |
9 | 15 | 1.54毫秒 | 649 | 0.20985 |
9 | 30 | 370.80微秒 | 2697 | 0.16835 |
15 | 5 | 35.45毫秒 | 28 | 0.29825 |
15 | 15 | 7.34毫秒 | 136 | 0.28520 |
15 | 30 | 2.19毫秒 | 457 | 0.23115 |
为什么我们的算法如此之快?
通过上面的数字可以看出,虽然我们的算法在质量上无法与尖端技术相媲美,但它的速度非常快。为什么会这样?
老实说,在构建这个算法时,我们兴奋得昏了头脑,性能优化也只是为了好玩。下面是一些重要的优化:
将去除文档的重复数据的过程转移到索引冷路径上。通过索引(而不是浮点数组)引用向量可以大幅提高搜索速度,因为跨树查找唯一候选者只需要对 8 字节索引进行散列(而不是 300 维 f32 数据)。
在将点添加到全局候选列表之前,散列并查找唯一向量,通过递归搜索调用中的可变引用传递数据,以避免栈帧之间和栈帧内的复制。
将 N 作为通用类型参数传递进去,这样类型检查就会将 300 维的数据当作长度为 300 的 f32 数组,这不仅可以改善缓存局部性,还可以减少内存占用。
我们怀疑内部操作(例如点积)被 Rust 编译器向量化了,但我们没有检查。
一些真实世界应用的考虑
有一些问题这个示例没有考虑到,但在生产环境的向量搜索中非常关键:
当搜索涉及多个树时应当使用并行。不要顺序收集候选者,而是应该并行进行,因为每个树都会访问完全不同的内存,因此每个树都可以在各自的线程上单独运行,候选者通过通道发送到主线程。线程可以在创建索引时建立,并使用模拟搜索进行预热(将树加载到缓存),从而减小搜索的额外开销。这样搜索就不会根据树的数量线性增长了。
大型树可能无法完整地加载到内存中,可能需要从磁盘中读取,所以可以将特定的子图放到磁盘上,并精心设计算法,保证搜索正常工作的同时,尽可能减小文件I/O。
继续深入,如果树无法保存在一个实例的磁盘上,可以将子树分布到多个实例中,并让递归搜索在本地数据不存在时发出RPC请求。
树包含许多内存重定向(基于指针的树对L1缓存不友好)。平衡树可以写成数组,但我们的树只是接近平衡,其超平面是随机的。是否可能采用新的数据结构?
上述问题的解决方案,对于新数据即时编制索引也是成立的(可能需要对大型树进行分片)。如果特定索引序列会导致非常不平衡的树,那么是否应该重建树?
这些都是我们未来该深刻思考的问题。
原文链接:https://fennel.ai/blog/vector-search-in-200-lines-of-rust/
声明:本文为 CSDN 翻译,未经允许,禁止转载。