STL_Algorithm

死亡谷 -- 美国

STL_Algorithm

0. 文件结构

算法部分总共分为了 5 个部分:

  • algo
  • algobase
  • algoset
  • heap
  • numeric

1. algo

第一部分:

adjacent_find, count, count_if, find, find_if, find_end, find_first_of, for_each, generate, generate_n, remove, remove_copy, remove_if, remove_copy_if, replace, replace_copy,replace_if, replace_copy_if, reverse, reverse_copy, rotate, rotate_copy, search, search_n, swap_range, transform, max_element, min_element, includes, merge, partition, unique, unique_copy

第二部分:

lower_bound, upper_bound, binary_search, next_permutation, prev_permutation, random_shuffle, partial_sort, sort, stable_sort, equal_range, random_shuffle, nth_element, stable_partition

1.1 第一部分

直接上例子

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
#include <algorithm>
#include <vector>
#include <functional>
#include <iostream>

using namespace std;

template<class T>
struct display {
void operator()(const T& x) {
cout << x << ' ';
}
};
struct even {
bool operator()(int x) const {
return x % 2 ? false : true;
}
};
class even_by_two {
public:
int operator()() const {
return _x += 2;
}
private:
static int _x;
};
int even_by_two::_x = 0;

int main() {
int ia[] = {0,1,2,3,4,5,6,6,6,7,8};
vector<int> iv(ia, ia + sizeof(ia) / sizeof(int));

// 找到 iv 之中相邻元素值相等的第一个元素
cout << *adjacent_find(iv.begin(), iv.end()) << endl; // 6
cout << *adjacent_find(iv.begin(), iv.end(), equal_to<int>()) << endl; // 6

// 找到 iv 之中元素值为 6 的元素个数
cout << count(iv.begin(), iv.end(), 6) << endl; // 3
// 找到 iv 之中小于 7 的个数
cout << count_if(iv.begin(), iv.end(), bind2nd(less<int>(), 7)) << endl; // 9

cout << *find(iv.begin(), iv.end(), 4) << endl;
cout << *find_if(iv.begin(), iv.end(), bind2nd(greater<int>(), 2)) << endl;

// 找出 iv 序列之中子序列 iv2 所出现的最后一个位置(再往后的三个位置的值)
vector<int> iv2(ia + 6, ia + 8); // {6 6}
cout << *(find_end(iv.begin(), iv.end(), iv2.begin(), iv2.end()) + 3) << endl; // 8

// 找出 iv 序列之中子序列 iv2 所出现的第一个位置(再往后的三个位置的值)
cout << *(find_first_of(iv.begin(), iv.end(), iv2.begin(), iv2.end()) + 3) << endl; // 7

// 迭代遍历整个 iv 区间,对每个元素施行 display 操作(不得改变元素内容)
for_each(iv.begin(), iv.end(), display<int>()); // iv : 0 1 2 3 4 5 6 6 6 7 8
cout << endl;

// generate
generate(iv2.begin(), iv2.end(), even_by_two()); // 每次执行一次: _x + 2
for_each(iv2.begin(), iv2.end(), display<int>()); // iv2 : 2 4
cout << endl;

generate_n(iv.begin(), 3, even_by_two()); // 前三个会覆盖 6 8 10
for_each(iv.begin(), iv.end(), display<int>()); // iv : 6 8 10 3 4 5 6 6 6 7 8
cout << endl;

// remove
remove(iv.begin(), iv.end(), 6); // __ 代表残余数据
for_each(iv.begin(), iv.end(), display<int>());
// iv : 8 10 3 4 5 7 8 _6_6_7_8_
cout << endl;

vector<int> iv3(12);
remove_copy(iv.begin(), iv.end(), iv3.begin(), 6);
for_each(iv3.begin(), iv3.end(), display<int>());
// iv3 : 8 10 3 4 5 7 8 7 8 _0_0_0_
cout << endl;

remove_if(iv.begin(), iv.end(), bind2nd(less<int>(), 6));
for_each(iv.begin(), iv.end(), display<int>());
// iv : 8 10 7 8 6 6 7 8 _7_8_
cout << endl;

remove_copy_if(iv.begin(), iv.end(), iv3.begin(), bind2nd(less<int>(), 7));
for_each(iv3.begin(), iv3.end(), display<int>());
// iv3 : 8 10 7 8 7 8 7 8 8_0_0_0_
cout << endl;

// replace
replace(iv.begin(), iv.end(), 6, 3);
for_each(iv.begin(), iv.end(), display<int>());
// iv : 8 10 7 8 3 3 7 8 3 7 8
cout << endl;

replace_copy(iv.begin(), iv.end(), iv3.begin(), 3, 5);
for_each(iv3.begin(), iv3.end(), display<int>());
// iv3 : 8 10 7 8 5 5 7 8 5 7 8 _0_
cout << endl;

replace_if(iv.begin(), iv.end(), bind2nd(less<int>(), 5), 2);
for_each(iv.begin(), iv.end(), display<int>());
// iv : 8 10 7 8 2 2 7 8 2 7 8
cout << endl;

replace_copy_if(iv.begin(), iv.end(), iv3.begin(), bind2nd(equal_to<int>(), 8), 9);
for_each(iv3.begin(), iv3.end(), display<int>());
// iv3 : 9 10 7 9 2 2 7 9 2 7 9 _0_
cout << endl;

// reverse
reverse(iv.begin(), iv.end());
for_each(iv.begin(), iv.end(), display<int>());
// iv : 8 7 2 8 7 2 2 8 7 10 8
cout << endl;

reverse_copy(iv.begin(), iv.end(), iv3.begin());
for_each(iv3.begin(), iv3.end(), display<int>());
// iv3: 8 10 7 8 2 2 7 8 2 7 8 _0_
cout << endl;

// rotate
rotate(iv.begin(), iv.begin() + 4, iv.end());
for_each(iv.begin(), iv.end(), display<int>());
// iv : 7 2 2 8 7 10 8 8 7 2 8
cout << endl;

rotate_copy(iv.begin(), iv.begin() + 5, iv.end(), iv3.begin());
for_each(iv3.begin(), iv3.end(), display<int>());
// iv3: 10 8 8 7 2 8 7 2 2 8 7 _0_
cout << endl;

// search
int ia2[3] = {2, 8};
vector<int> iv4(ia2, ia2 + 2); // iv4 : {2 8}
cout << *search(iv.begin(), iv.end(), iv4.begin(), iv4.end()) << endl; // 2
cout << *search_n(iv.begin(), iv.end(), 2, 8) << endl; // 8
cout << *search_n(iv.begin(), iv.end(), 3, 8, less<int>()) << endl; // 7

// swap_range
swap_ranges(iv4.begin(), iv4.end(), iv.begin());
for_each(iv.begin(), iv.end(), display<int>()); // iv : 2 8 2 8 7 10 8 8 7 2 8
cout << endl;
for_each(iv4.begin(), iv4.end(), display<int>()); // iv4 : 7 2
cout << endl;

//transform
transform(iv.begin(), iv.end(), iv.begin(), bind2nd(minus<int>(), 2));
for_each(iv.begin(), iv.end(), display<int>());
cout << endl;
// 第一个 begin() 表示输入区间开始,第二个 begin() 第二个区间开始,第三个begin()表示输出区间开始
transform(iv.begin(), iv.end(), iv.begin(), iv.begin(), plus<int>());
for_each(iv.begin(), iv.end(), display<int>());
cout << endl;

// ******************************
vector<int> iv5(ia, ia + sizeof(ia) / sizeof(int));
vector<int> iv6(ia + 4, ia + 8);
vector<int> iv7(15);
for_each(iv5.begin(), iv5.end(), display<int>());
// iv5 : 0 1 2 3 4 5 6 6 6 7 8
cout << endl;
for_each(iv6.begin(), iv6.end(), display<int>()); // iv6 : 4 5 6 6
cout << endl;
// max_element, min_element
cout << *max_element(iv5.begin(), iv5.end()) << endl; // 8
cout << *min_element(iv5.begin(), iv5.end()) << endl; // 0

// includes
cout << includes(iv5.begin(), iv5.end(), iv6.begin(), iv6.end()) << endl; // 1

// merge
merge(iv5.begin(), iv5.end(), iv6.begin(), iv6.end(), iv7.begin());
for_each(iv7.begin(), iv7.end(), display<int>());
// iv7 : 0 1 2 3 4 4 5 5 6 6 6 6 6 7 8
cout << endl;

// partition
partition(iv7.begin(), iv7.end(), even());
for_each(iv7.begin(), iv7.end(), display<int>());
// iv7 : 0 8 2 6 4 4 6 6 6 6 6 5 5 3 7 1
cout << endl;

// unique
unique(iv5.begin(), iv5.end());
for_each(iv5.begin(), iv5.end(), display<int>());
// iv5 : 0 1 2 3 4 5 6 7 8 _7_8_
cout << endl;

unique_copy(iv5.begin(), iv5.end(), iv7.end());
for_each(iv7.begin(), iv7.end(), display<int>());
// 0 1 2 3 4 5 6 7 8 _7_8_5_3_7_1_
cout << endl;

return 0;
}

《STL源码剖析》 – P338 - P343 (例子出处)

《STL源码剖析》 – P343 - P372 (具体细节)


1.2 第二部分

直接上例子:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
#include <iostream>
#include <algorithm>
#include <functional>
#include <vector>
#include <iterator>
using namespace std;

struct even {
bool operator() (int x) const {
return x % 2 ? false : true;
}
};
int main() {
int ia[] = {12,17,20,22,23,30,33,40};
vector<int> iv(ia, ia + sizeof(ia) / sizeof(int));
// lower_bound : 结果严格 >= 参数
cout << *lower_bound(iv.begin(), iv.end(), 21) << endl; // 22
cout << *lower_bound(iv.begin(), iv.end(), 22) << endl; // 22
// upper_bound : 结果严格 < 参数
cout << *upper_bound(iv.begin(), iv.end(), 21) << endl; // 22
cout << *upper_bound(iv.begin(), iv.end(), 22) << endl; // 23

// 二分查找
cout << binary_search(iv.begin(), iv.end(), 33) << endl; // ture
cout << binary_search(iv.begin(), iv.end(), 34) << endl; // false

ostream_iterator<int> oite(cout, " ");

// 下一个排列组合
next_permutation(iv.begin(), iv.end());
copy(iv.begin(), iv.end(), oite);
cout << endl;

// 上一个排列组合
prev_permutation(iv.begin(), iv.end());
copy(iv.begin(), iv.end(), oite);
cout << endl;

// 随即重排
// shuffle(iv.begin(), iv.end());
// copy(iv.begin(), iv.end(), oite);

// 部分区间排序
partial_sort(iv.begin(), iv.begin() + 4, iv.end());
copy(iv.begin(), iv.end(), oite);
cout << endl;

// sort 增序
sort(iv.begin(), iv.end());
copy(iv.begin(), iv.end(), oite);
cout << endl;
// 降序
sort(iv.begin(), iv.end(), greater<int>());
copy(iv.begin(), iv.end(), oite);
cout << endl;

// 在 iv 尾端附加三个新元素
iv.push_back(22);
iv.push_back(30);
iv.push_back(17);

// 稳定版本的 sort
stable_sort(iv.begin(), iv.end());
copy(iv.begin(), iv.end(), oite);
cout << endl;

// equal_range 面对有序区间
// 返回该子区间的首尾迭代器
// 如果没有找到,就返回可插入的区域
pair<vector<int>::iterator, vector<int>::iterator> pairIte;
pairIte = equal_range(iv.begin(), iv.end(), 22);
cout << *(pairIte.first) << endl;
cout << *(pairIte.second) << endl;

pairIte = equal_range(iv.begin(), iv.end(), 25);
cout << *(pairIte.first) << endl;
cout << *(pairIte.second) << endl;

// random_shuffle(iv.begin(), iv.end());
// copy(iv.begin(), iv.end(), oite);
// cout << endl;

// 将小于 iv.begin() + 5 的元素置于该元素左边
nth_element(iv.begin(), iv.begin() + 5, iv.end());
copy(iv.begin(), iv.end(), oite);
cout << endl;

// 将大于 iv.begin() + 5 的元素置于该元素左边
nth_element(iv.begin(), iv.begin() + 5, iv.end(), greater<int>());
copy(iv.begin(), iv.end(), oite);
cout << endl;

stable_partition(iv.begin(), iv.end(), even());
copy(iv.begin(), iv.end(), oite);
cout << endl;

return 0;
}

《STL源码剖析》 – P372 - P374


1.2.1 sort

快速排序会为了一些小序列产生大量的递归调用,quick sort 在效率上反而低于 insertion sort,STL里面的设计了两个 threshold_lg, _stl_threshold

_lg 用以控制分割恶化的情况,用以选择 quick_sort 或者 heap_sort

_stl_threshold 默认为 16 用以作为选择quick_sort 或者 insertion sort;

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
// sort 对外接口
template <class RandomAccessIterator>
inline void sort(RandomAccessIterator first, RandomAccessIterator last) {
if (first != last) { // _lg 设定递归深度上限
__introsort_loop(first, last, value_type(first), __lg(last - first) * 2);
__final_insertion_sort(first, last); // 插入排序
}
}
// _lg:用以控制分割恶化的情况,找出 2^k <= n 的最大值 k
// 举例而言,当 size = 40 时,_lg = 5
// __introsort_loop 的最后一个参数为 5*2,即最多允许分割 5*2 层
template <class Size>
inline Size __lg(Size n) {
Size k;
for (k = 0; n > 1; n >>= 1) ++k;
return k;
}
// __introsort_loop:intosort的具体实现
// 子序列长度大于 16 使用快速排序,子序列长度小于 16 使用插入排序
// (递归深度未达到上限,使用快速排序,递归深度达到上限,使用堆排序,后直接返回)
template <class RandomAccessIterator, class T, class Size>
void __introsort_loop(RandomAccessIterator first, RandomAccessIterator last, T *, Size depth_limit) {
//__STL_threshold是一个定义为 16 的全局常数
while (last - first > /*__stl_threshold*/ 16) {
if (depth_limit == 0) { // 已经产生了分割恶化
partial_sort(first, last, last); // 改用 heap-sort
return;
}
--depth_limit; // 每次递归就会越接近上限
RandomAccessIterator cut = __unguarded_partition(first, last,
T(__median(*first, *(first + (last - first) / 2), *(last - 1))));
__introsort_loop(cut, last, value_type(first), depth_limit); // 递归
last = cut; // 回归 while,执行左侧排序
}
// last - first <= 16 会 return, 回到主函数
// 接着调用 __final_insertion_sort (插入排序)
}
// 插入排序...
void __final_insertion_sort(RandomAccessIterator first, RandomAccessIterator last){}

《STL源码剖析》 – P389 - P400


2. algobase

STL 标准规格中并没有区分基本算法或复杂算法,然而SGI却把常用的一些算法定义与<stl_algobase.h> 之中,其他算法定义与 <stl_algo.h>中,以下一一列举这些所谓的基本算法。以下分为两个部分,第一部分:equal,fill,fill_n,iter_swap,lexicographical_compare,max,min,mismatch,swap

第二部分:copy, copy_backward

2.1 第一部分

直接上例子:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
#include <algorithm>
#include <vector>
#include <functional>
#include <iostream>
#include <iterator>
#include <string>

using namespace std;

template <class T>
struct display {
void operator()(const T& x) const {
cout << x << ' ';
}
};

int main() {
int ia[9] = {0,1,2,3,4,5,6,7,8};
vector<int> iv1(ia, ia + 5);
vector<int> iv2(ia, ia + 9);
// {0,1,2,3,4} v.s {0,1,2,3,4,5,6,7,8}
// 该操作很危险,应该先判断返回的迭代器是否不等于容器的 end()
// cout << *(mismatch(iv1.begin(), iv1.end(), iv2.begin()).first);
// cout << *(mismatch(iv1.begin(), iv1.end(), iv2.begin()).second);

// 如果第二序列的元素较多,多出来的不予考虑
cout << equal(iv1.begin(), iv1.end(), iv2.begin()) << endl; // 1, true
// {0,1,2,3,4} 不等于 {3,4,5,6,7}
cout << equal(iv1.begin(), iv1.end(), &ia[3]) << endl; // 0, false

fill(iv1.begin(), iv1.end(), 9); // 9 9 9 9 9
for_each(iv1.begin(), iv1.end(), display<int>()); cout << endl;

fill_n(iv1.begin(), 3, 7); // 7 7 7 9 9
for_each(iv1.begin(), iv1.end(), display<int>()); cout << endl;

vector<int>::iterator ite1 = iv1.begin(); // 指向 7
vector<int>::iterator ite2 = ite1;
advance(ite2, 3); // 指向 9

iter_swap(ite1, ite2); // 9 7 7 9 9
for_each(iv1.begin(), iv1.end(), display<int>()); cout << endl;

cout << max(*ite1, *ite2) << endl; // 9
cout << min(*ite1, *ite2) << endl; // 7

// 以下是错误形式,比较的迭代器本身的大小,与所指向的元素无关
cout << *max(ite1, ite2) << endl;
cout << *min(ite1, ite2) << endl;

// iv1 : 9 7 7 7 9; iv2 : 0 1 2 3 4 5 6 7 8
swap(*iv1.begin(), *iv2.begin());
for_each(iv1.begin(), iv1.end(), display<int>()); cout << endl;
// 0 7 7 7 9
for_each(iv2.begin(), iv2.end(), display<int>()); cout << endl;
// 9 1 2 3 4 5 6 7 8

string stra1[] = {"jamie", "JJHou", "Jason"};
string stra2[] = {"jamie", "JJHou", "Jerry"};

cout << lexicographical_compare(stra1, stra1 + 2, stra2, stra2 + 2) << endl; // 1
cout << lexicographical_compare(stra1, stra1 + 2, stra2, stra2 + 2, greater<string>()) << endl; // 0


return 0;
}

《STL源码剖析》 – P306 - P307


2.2 第二部分

copy() 的操作脉络如下所示:

上述的各个版本的的测试见:

《STL源码剖析》 – P321 - P324

**注意:**copy(),以及copy_backward() 需要注意区间重叠的情况。

《STL源码剖析》 – P301, P326


3. algoset

STL 一共提供了四种与 set 相关的算法,分别是并集、交集、差集和对称差集。本节的四个算法所接受的 set,必须是有序区间,元素值可以重复出现。

直接上例子:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
#include <set>
#include <iostream>
#include <algorithm>
#include <iterator>
using namespace std;

template<class T>
struct display {
void operator()(const T& x) {
cout << x << ' ';
}
};

int main() {
int ia1[6] = {1,3,5,7,9,11};
int ia2[7] = {1,1,2,3,5,8,13};
multiset<int> S1(ia1, ia1 + 6);
multiset<int> S2(ia2, ia2 + 7);
// {1,3,5,7,9,11}
for_each(S1.begin(), S1.end(), display<int>()); cout << endl;
// {1,1,2,3,5,8,13}
for_each(S2.begin(), S2.end(), display<int>()); cout << endl;

multiset<int>::iterator first1 = S1.begin();
multiset<int>::iterator last1 = S1.end();
multiset<int>::iterator first2 = S2.begin();
multiset<int>::iterator last2 = S2.end();

ostream_iterator<int> oite(cout, " ");
cout << "Union of S1 and S2: ";
set_union(first1, last1, first2, last2, oite); cout << endl;
// Union of S1 and S2: 1 1 2 3 5 7 8 9 11 13
// 重复元素为 max(m, n) 次, m 是 S1 中出现的次数, n 是 S2 中出现的次数。

first1 = S1.begin();
first2 = S2.begin();
cout << "Intersection of S1 and S3: ";
set_intersection(first1, last1, first2, last2, oite); cout << endl;
// Intersection of S1 and S3: 1 3 5
// 重复元素为 min(m, n) 次, m 是 S1 中出现的次数, n 是 S2 中出现的次数。

first1 = S1.begin();
first2 = S2.begin();
cout << "Difference of S1 and S2 (S1 - S2): ";
set_difference(first1, last1, first2, last2, oite); cout << endl;
// Difference of S1 and S2 (S1 - S2): 7 9 11
// 重复元素为 max(n - m, 0) 次, m 是 S1 中出现的次数, n 是 S2 中出现的次数。
// set_difference 表示 S1 中有 S2 中没有,表示 -> S1 - S1 ∩ S2

first1 = S1.begin();
first2 = S2.begin();
cout << "Difference of S2 and S1 (S2 - S1): ";
set_difference(first2, last2, first1, last1, oite); cout << endl;

first1 = S1.begin();
first2 = S2.begin();
cout << "symmetric difference of S1 and S2: ";
set_symmetric_difference(first1, last1, first2, last2, oite); cout << endl;
// set_symmetric_difference 表示 -> S1 U S2 - S1 ∩ S2

return 0;
}

《STL源码剖析》 – P329 - P330


4 heap

包含以下四个函数:make_heap, pop_heap, push_heap, sort_heap


5. numeric

该部分的算法,统称为数值算法,欲使用他们必须要包含头文件。

1
#include <numeric>

总共包含了 6 个算法:

  • accumulate
  • adjacent_difference
  • partial_sum
  • inner_product
  • power
  • iota

5.1 accumulate

1
2
3
4
5
6
7
8
9
10
11
12
13
// ~ 累积算法
template <class InputIterator, class T> // ~ 如果没有给定初值,会默认初始化一个
T accumulate(InputIterator first, InputIterator last, T init = T()) {
for (; first != last; ++first) init += *first;
return init; // ~ 返回临时对象
}

template <class InputIterator, class T, class BinaryOperation>
// ~ 二元操作符,不必满足交换律、结合律
T accumulate(InputIterator first, InputIterator last, T init, BinaryOperation binary_op) {
for (; first != last; ++first) binary_op(init, *first); // ~ 可以传入二元仿函数
return init;
}

**使用:**详细代码见 5.7

1
2
3
4
vector<int> iv{2, 3, 4, 5, 6};
cout << accumulate(iv.begin(), iv.end(), 0) << endl;
// output : 20
// 0 + 2 + 3 + 4 + 5 + 6

《STL源码剖析》 – P299 - P300


5.2 adjacent_difference

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
template <class InputIterator, class OutputIterator>  // ! 版本一
OutputIterator adjacent_difference(InputIterator first, InputIterator last, OutputIterator result) {
if (first == last) return first;
*result = *first; // ! 记录第一个元素
return __adjacent_difference(first, last, result, value_type(first));
}

template <class InputIterator, class OutputIterator, class T>
OutputIterator __adjacent_difference(InputIterator first, InputIterator last, OutputIterator result, T *) {
T value = *result;
while (++first != last) {
T temp = *first;
*++result = temp - value; // ! 先++,首部元素还是原来的数据
value = temp;
}
return ++result;
}

template <class InputIterator, class OutputIterator, class BinaryOperation> // ! 版本二
OutputIterator adjacent_difference(InputIterator first, InputIterator last, OutputIterator result, BinaryOperation binary_op) { // ! 传入仿函数的版本
if (first == last) return first;
*result = *first;
return __adjacent_difference(first, last, result, value_type(first), binary_op);
}

template <class InputIterator, class OutputIterator, class T,
class BinaryOperation>
OutputIterator __adjacent_difference(InputIterator first, InputIterator last, OutputIterator result, T *,BinaryOperation binary_op) {
T value = *result;
while (++first != last) {
T temp = *first;
*++result = binary_op(temp, value); // ! 下一个元素在运算符左侧
value = temp;
}
return ++result;
}

存在两个版本,分别针对默认的邻接差值版本,以及自定义仿函数版本。

使用:

1
2
3
4
5
6
7
8
vector<int> iv{2, 3, 4, 5, 6};
ostream_iterator<int> oite(cout, " ");
adjacent_difference(iv.begin(), iv.end(), oite);
adjacent_difference(iv.begin(), iv.end(), oite, plus<int>());
// output1 : 2 1 1 1 1
// (2, 3-2, 4-3, 5-4, 6-5)
// output2 : 2 5 7 9 11
// (2, 3+2, 4+3, 5+4, 6+5)

《STL源码剖析》 – P300 - P301


5.3 partial_sum

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
template <class InputIterator, class OutputIterator> // @ 版本一
OutputIterator partial_sum(InputIterator first, InputIterator last, OutputIterator result) {
if (first == last) return first;
*result = *first;
return __partial_sum(first, last, result, value_type(first));
}

template <class InputIterator, class OutputIterator, class T>
OutputIterator __partial_sum(InputIterator first, InputIterator last,
OutputIterator result, T *) {
T value = *first;
while (++first != last) {
value += *first; // @ 默认加法
*++result = value;
}
return ++result;
}

template <class InputIterator, class OutputIterator, class BinaryOperation> // @ 版本二
OutputIterator partial_sum(InputIterator first, InputIterator last,
OutputIterator result, BinaryOperation binary_op) {
if (first == last) return first;
*result = *first;
return __partial_sum(first, last, result, value_type(first), binary_op);
}

template <class InputIterator, class OutputIterator, class T,
class BinaryOperation>
OutputIterator __partial_sum(InputIterator first, InputIterator last,
OutputIterator result, T *,
BinaryOperation binary_op) {
T value = *result;
while (++first != last) {
value = binary_op(value, *first); // ! 下一个元素在运算符右
*++result = value;
}
return ++result;
}

存在两个版本,分别针对默认累计加法版本,以及自定义仿函数版本。

使用:

1
2
3
4
5
6
7
8
vector<int> iv{2, 3, 4, 5, 6};
ostream_iterator<int> oite(cout, " ");
partial_sum(iv.begin(), iv.end(), oite);
partial_sum(iv.begin(), iv.end(), oite, minus<int>());
// output1 : 2 5 9 14 20
// (2, 2 + 3, 2 + 3 + 4, 2 + 3 + 4 + 5, 2 + 3 + 4 + 5 + 6) // 类似accumulate
// output2 : 2 -1 -5 -10 -16
// (2, 2 - 3, 2 - 3 - 4, 2 - 3 - 4 - 5, 2 - 3 - 4 - 5 - 6)

其实 partial_sumadjacent_difference 为互逆运算。

1
2
3
4
5
6
7
vector<int> iv{2, 3, 4, 5, 6};
vector<int> iv2{2, 5, 9, 14, 20};
ostream_iterator<int> oite(cout, " ");
partial_sum(iv.begin(), iv.end(), oite);
adjacent_difference(iv2.begin(), iv2.end(), oite);
// output1 : 2 5 9 14 20
// output2 : 2 3 4 5 6

《STL源码剖析》 – P303 - P304


5.4 inner_product

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
template <class InputIterator1, class InputIterator2, class T>  
T inner_product(InputIterator1 first1, InputIterator2 last1, InputIterator2 first2, T init = T()) {
for (; first1 != last1; ++first1, ++first2) init += (*first1) * (*first2);
return init;
} // @ 注意如果传入的长度不相同,也是可以通过编译的,就是结果不对

template <class InputIterator1, class IutputIterator2, class T,
class BinaryOperation1, class BinaryOperation2>
T inner_product(InputIterator1 first1, InputIterator1 last1,
IutputIterator2 first2, T init, BinaryOperation1 binary_op1,
BinaryOperation2 binary_op2) {
for (; first1 != last1; ++first1, ++first2) // @ 可以传入两个二元仿函数
// @ op1 替换原来的 +,op2 替换原来的 *
init = binary_op1(init, binary_op2(*first1, *first2));
// @ first1 位于操作符左侧
return init;
}

存在两个版本,分别针对默认内积版本,以及自定义仿函数版本。

使用:

1
2
3
4
5
6
vector<int> iv{2,3,4,5,6};
vector<int> iv2{2,5,9,14,20};
cout << inner_product(iv.begin(), iv.end(), iv2.begin(), 0) << endl;
cout << inner_product(iv.begin(), iv.end(), iv2.begin(), 0, plus<int>(), minus<int>()) << endl;
// output1 : 245 (2*2 + 3*5 + 4*9 + 5*14 + 6*20)
// output2 : -30 ((2-2) + (3-5) + (4-9) + (5-14) + (6-20))

《STL源码剖析》 – P301 - P302


5.5 power

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
// $ 对x反复执行n次算数操作
template <class T, class Integer>
inline T power(T x, Integer n) {
return power(x, n, mutiplies<T>()); // $ 默认的情况下使用乘法
}
// $ 支持传入二元运算符,该运算必须支持结合律
template <class T, class Integer, class MonoidOperation>
inline T power(T x, Integer n, MonoidOperation op) {
// $ 快速幂算法
// $ 该算法主要针对 n >= 0 的情况
if (n == 0)
return identity_element(op);
// $ 取出证同元素(该元素与A作运算将得到A自身,例如加法的证同为0,乘法为1,等等)
else {
while ((n & 1) == 0) { // $ 剥离偶数因子
n >>= 1;
x = op(x, x);
}
T result = x;
n >>= 1; // $ 右移
while (n != 0) {
x = op(x, x);
if ((n & 1) != 0) // $ 依然存在奇数位
result = op(result, x);
n >>= 1;
}
return result;
}
}

剑指offer16使用了同样的思想。

使用:

1
2
3
4
cout << power(10, 3) << endl;
cout << power(10, 3, plus<int>()) << endl;
// output1 : 1000 (10*10*10)
// output2 : 30 (10+10+10)

《STL源码剖析》 – P304 - P305

剑指offer16


5.6 itoa

1
2
3
4
5
6
7
// ~ 赋值操作
// ~ vector<int> vec(5);
// ~ iota(vec.begin(), vec.end(), 0); --> vec : [0,1,2,3,4]
template <class ForwardIterator, class T>
void iota(ForwardIterator first, ForwardIterator last, T value) {
while (first != last) *first++ = value++;
}

使用:

1
2
3
4
5
int n = 3;
iota(iv.begin(), iv.end(), n); // ! 在指定区间内填入 n,n+1,...
for (const auto& v : iv) {
cout << v << ' ';
}

《STL源码剖析》 – P305


5.7 综合例子

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
#include <iostream>
#include <vector>
#include <iterator>
#include <functional>
#include <ext/numeric>
using namespace std;
// 虽然GCC编译器完美集成了SGI STL,但是在使用GCC编译器编译使用SGI STL的源码时,需要注意:如iota、power等函数是非C++标准的函数,是SGI专属的函数,对于这类函数需要修改包含的头文件.
using namespace __gnu_cxx;

int main() {
vector<int> iv{2,3,4,5,6};
vector<int> iv2{2,5,9,14,20};
cout << "accumulate test : " << endl;
cout << accumulate(iv.begin(), iv.end(), 0) << endl;

ostream_iterator<int> oite(cout, " ");
cout << "adjacent_difference test : " << endl;
adjacent_difference(iv.begin(), iv.end(), oite); cout << endl;
adjacent_difference(iv.begin(), iv.end(), oite, plus<int>());


cout << endl << "partial_sum test : " << endl;
partial_sum(iv.begin(), iv.end(), oite); cout << endl;
partial_sum(iv.begin(), iv.end(), oite, minus<int>()); cout << endl;
adjacent_difference(iv2.begin(), iv2.end(), oite);

cout << endl << "inner_product test : " << endl;
cout << inner_product(iv.begin(), iv.end(), iv2.begin(), 0) << endl;
cout << inner_product(iv.begin(), iv.end(), iv2.begin(), 0, plus<int>(), minus<int>()) << endl;

cout << "power test : " << endl;
cout << power(10, 3) << endl;
cout << power(10, 3, plus<int>()) << endl;

cout << "iota test : " << endl;
int n = 3;
iota(iv.begin(), iv.end(), n);
for (const auto& v : iv) {
cout << v << ' ';
}
cout << endl;

return 0;
}

运行结果如下:

《STL源码剖析》 – P298 - P299


5 参考资料

《STL源码剖析》

MiniSTL / Algorithm 部分



----------- 本文结束 -----------




0%