查看原文
其他

C++17: Simplify Code with Fold Expressions

里缪 CppMore 2023-07-27

本篇作为Understanding variadic templates的进阶内容,同时,Fold Expressions也是C++17最常用的特性之一。

1

Fold Expressions的基本概念

C++11中,参数包只能在需要参数列表的上下文展开,比如函数递归。而递归函数需要终止条件,因此往往需要提供一个同名的函数来终止递归。

举个例子:

1void print()
2
{
3    std::cout << '\n';
4}
5
6template<typename F, typename... Args>
7void print(F first, Args... args)
8
{
9    std::cout << first << ' ';
10    print(args...);
11}

我们无法在函数主体中展开,例如不能这样做:

1template<typename... Args>
2void print(Args... args)
3
{
4    std::cout << ... << args << std::endl;
5}


但是可以通过逗号表达式和初始化列表在函数主体展开参数包。

例子如下:

1template<typename... T>
2void print(T&&... args)
3
{
4    std::initializer_list<int>{([](const auto& t) {
5                std::cout << t << std::endl;
6        }(std::forward<T>(args)), 0)...};
7}


短短的一行代码中,便已使用了initializer list、lambda、perfect forwarding、comma operator等多种特性。

由于参数包只能在需要参数列表的上下文展开,因而不用递归,便需要一个支持可变参数的类型,initializer_list刚好满足条件。

此外,我们知道initializer_list只支持单类型,而需要处理的参数却是不同类型的,所以再借助逗号表达式来实现目的。

逗号表达式基于一个基本的事实,看如下代码:

1int a = 1;
2int b = 2;
3int c = 5;
4a = (a = b, c);


a最终的结果将为5,在逗号操作符所连接的表达式中,会从左往右依次执行。本例中,首先会执行(a = b),然后返回c,所以a的值便为c的值。

依据这个事实可以在initializer_list中展开参数包,表面上是在初始化initializer_list,实际上是在初始化过程中利用逗号表达式展开参数包。

这样的代码写起来依旧麻烦,所以在C++17,提供了Fold Expressions(折叠表达式),可以直接在函数主体中展开参数包。

可以简单地通过fold expression来改写上述例子:

1template<typename... Args>
2void print(Args... args)
3
{
4    ((std::cout << args << ' '), ...);
5    std::cout << '\n';
6}


具体细节,见于下节。

2

Fold Expressions的细节疏理

Fold Expression的语法比较简单,可以直接参考cppreference。

语法分为一元操作和二元操作。pack指的就是参数包,op指的是操作类型,有一个op的是一元操作,两个op的是二元操作。

我们先来看一元操作,来写一个最简单的输出函数:

1template<typename... T>
2void print(T&&... args)
3
{
4        (std::cout << ... << args) << '\n';
5}


这是使用的是第(2)个语法,即一元左折叠,需要注意的是:括号也是语法的一部分。

因为cout输出后会返回一个ostream,所以可以不断输出所有参数,展开后的代码如下:

print(1234);
((((std::cout << 1) << 2) << 3) << 4) << '\n';


但是注意不能这样写:

std::cout << (args << ... << '\n');

这是右折叠,别忘了operator<<本义为左移操作符,所以执行结果会出人意料。

上述输出函数的缺点在于输出内容之间没有空格,若要在每个参数的输出之间加入空格,最简单的方式是借助逗号表达式,例子便是上节末尾的代码:

1template<typename... Args>
2void print(Args&&... args)
3
{
4    ((std::cout << args << ' '), ...);
5    std::cout << '\n';
6}
当然还有其它的做法,若你想在处理参数之前或之后添加额外的操作,可以额外定义一个函数:

1template<typename F, typename... Args>
2void print(F first, const Args&... args)
3
{
4    std::cout << first;
5    auto space = [](const auto& arg) {
6        std::cout << ' ' << arg;
7    };
8    (..., space(args));
9    std::cout << '\n';
10}


区别在于,第二种做法不支持0个参数,而逗号表达式可以支持,不过小做修改便可支持:

1template<typename... Args>
2void printer(const Args&... args)
3
{
4    auto space = [](const auto& arg) {
5        std::cout << arg << ' ';
6    };
7
8    (..., space(args));
9    std::cout << '\n';
10}


但是不是任何时候都可折叠0个参数的,比如:

1template<typename... Args>
2auto sum(Args... args)
3
{
4    return (... + args);
5}


这是个求和函数,它必须要返回一个值,所以0参时会报错。

这种情况就应该使用二元折叠,你可以翻回本节头部去查看二元的语法,其中的init的意思就是初始值,面对0参时依旧可以运行。

因此例子可以更改为:

1template<typename... Args>
2auto sum(Args... args)
3
{
4    // binary left fold
5    return (0 + ... + args);
6
7    // binary right fold
8    //return (args + ... + 0);  
9}

这里使用了二元左折叠,在这种情况下和二元右折叠没有区别,但一般都优先使用左折叠。

假如我们要相加字符串,那么左折叠与右折叠便有差异了。

重写一个例子用来连接字符串:

1template<typename... Args>
2auto strcat(Args&&... args)
3
{
4    // unary left fold
5    return (... + args);
6}

若像下面这样调用:

std::cout << strcat(std::string{"Have"}, "a""nice""day!") << '\n';


输出结果将为:

空格问题暂且不论。现在若把参数调用的顺序稍微切换,

std::cout << strcat("Have""a""nice"std::string("day!")) << '\n';

便是完全不同的结果:

现在编译不过了。

这是由于原生字符串不支持operator+,第一个调用中将string放在首位,因为它内部重载了operator+,所以可以进行加法,之后会返回一个string,遂可将所有参数组成完整的字符串。

而第二个调用中,string处于尾部,而实现却为一元左折叠,由于前面的参数不支持operator+而编译出错。

只需将strcat改为一元右折叠便能编译第二个调用,但此时又无法支持第一个调用。

解决之道如下:

template<typename... Args>
auto strcat(Args&&... args)
{
    // unary left fold
    return ((std::string{} + args + " ") + ...);
}


对于每一个字符串,都在开头添加一个空string,这样就能针对所有形式。


3

Filter Fold Expressions

若你想对解包结果添加约束,可以组合逗号表达式和逻辑运算符(&&, ||, !)使用来添加过滤。

举个例子,若想写一个只打印偶数的输出函数,可以这样编写:

1#include <iostream>
2
3template<typename... Args>
4void print_even_number(Args... args)
5
{
6    bool b = false;
7    ((b = [](int arg) { return arg % 2 == 0; }(args) && (std::cout << args << ' ')), ...);
8    std::cout << '\n';
9}
10
11int main()
12
{
13    print_even_number(12345612111819);
14
15    return 0;
16}


输出结果为:

给了一个未使用变量警告,可以使用C++17的[[maybe_unused]]特性来消除。

bool b [[maybe_unused]] = false;


此外,还可以结合type traits来使用,比如我们想实现一个判断容器类型是否一致的功能,可以这样编写:

template<typename H, typename... Ts>
struct is_same_type {
    static constexpr bool value =  (std::is_same_v<H, Ts> && ...);
};

// will be true
is_same_type<intintdecltype(3)>::value;


标准中的array就使用了这个技巧来判断类型是否一致。

4

编译期排序

Fold Expression也可以和一些算法结合起来使用。

这里有个来自网络的编译期排序例子,可以参考一 二。

1#include <iostream>
2#include <array>
3#include <utility>
4
5template<typename Values> struct SortImpl;
6
7template<typename I, I... values>
8constexpr auto sort(std::integer_sequence<I, values...> sequence)
9
{
10    return SortImpl<decltype(sequence)>::sort();
11}
12
13template<typename I, I... values>
14struct SortImpl<std::integer_sequence<I, values...>>
15{

16    static constexpr auto sort() {
17        // 创建4位索引序列
18        return sort(std::make_index_sequence<sizeof...(values)>{});
19    }
20
21    template<std::size_t... index>
22    static constexpr auto sort(std::index_sequence<index...>) {
23        // 创建integer_sequence,用于返回排序好的结果
24        return std::integer_sequence<I, ith<index>()...>{};
25    }
26
27    template<std::size_t i>
28    static constexpr auto ith() {
29        I result{};
30
31        // 利用rankOf计算当前数值所应排列的位置,再和索引比较,相同则返回对应的数值
32        ((i >= rankOf<values>() && i < rankOf<values>() + count<values>()
33                ? result = values : I{}), ...);
34        return result;
35    }
36
37    template<I x> // 排序位置,例如第一次调用:(0 > 5) + (0 > 2) + (0 > 2) == 0,则应排第0个
38    static constexpr auto rankOf() return ((x > values) + ...); }
39
40    template<I x> // 计算相同数值的个数
41    static constexpr auto count() return ((x == values) + ...); }
42};
43
44template<typename I, I... values>
45constexpr auto toArray(std::integer_sequence<I, values...>)
46
{
47    // 转换成数组,以便输出
48    return std::array<I, sizeof...(values)> { values... };
49}
50
51int main()
52
{
53    auto y = toArray(sort(std::index_sequence<0522>{}));
54
55    for(auto& elem : y) {
56        std::cout << elem << ' ';
57    }
58    std::cout << '\n';
59
60    return 0;
61}

关键位置我已经提供了注释,主要是利用了integer_sequence和fold expression,前者用于提供索引,后者用于找出该索引所应对应的值的位置和计算重复次数,索引与位置重复次数一比较,就能找到索引对应的值。

运行结果如下:


5

Pretty-print std::tuple

我们还可以利用fold expressions来方便地打印tuple的元素,例子如下:

1#include <iostream>
2#include <tuple>
3#include <utility>
4
5template<typename T, std::size_t... Is>
6void print_tuple(const T& tup, std::index_sequence<Is...>)
7
{
8    std::cout << "(";
9    (..., (std::cout << (Is == 0 ? "" : ", ") << std::get<Is>(tup)));
10    std::cout << ")\n";
11}
12
13template<typename... T>
14void print_tuple(const std::tuple<T...>& tup)
15
{
16    print_tuple(tup, std::make_index_sequence<sizeof...(T)>());
17}
18
19int main()
20
{
21    print_tuple(std::make_tuple("apple""pineapple""cherry""lemon""mango"));
22
23    return 0;
24}

输出如下:

std::tuple提供的get<I>(tup)是编译期的,若需要运行期访问tuple,也很简单,代码如下:

1template<typename T, std::size_t... Is>
2void get_tuple(std::size_t i, const T& tup, std::index_sequence<Is...>)
3
{
4    (..., ((Is == i) && (std::cout << std::get<Is>(tup))));
5    std::cout << '\n';
6}
7
8template<typename... T>
9void get_tuple(std::size_t i, const std::tuple<T...>& tup)
10
{
11    get_tuple(i, tup, std::make_index_sequence<sizeof...(T)>());
12}
13
14// call
15auto tup = std::make_tuple("apple""pineapple""cherry""lemon""mango");
16get_tuple(1, tup);

输出如下:

6

折叠函数调用

Fold expression也可以用折叠访问任意基类的共有成员,一个小例子:

1#include <iostream>
2
3template<typename... Bases>
4struct Foo : private Bases...
5{
6    void print() {
7        (..., Bases::print());
8    }
9};
10
11struct A {
12    void print() std::cout << "A::print()\n"; }
13};
14
15struct B {
16    void print() std::cout << "B::print()\n"; }
17};
18
19struct C {
20    void print() std::cout << "C::print()\n"; }
21};
22
23int main()
24
{
25    Foo<A, B, C> foo;
26    foo.print();
27
28    return 0;
29}

输出将为:

7

总结

Fold expression可以简化代码,完成一些本来实现起来较麻烦的操作。

本篇介绍了基本概念与用法,并提供了大量的例子,相信大家读完之后对其已不陌生。

实际上,它的用法还有很多,有一些用法相当复杂,大家平时可以再多思考下它的其它用处。此外,C++20的一些特性也能和它结合使用,等介绍到的时候会补充。

您可能也对以下帖子感兴趣

文章有问题?点此查看未经处理的缓存