Zer0e's Blog

TopK问题与快速选择算法

字数统计: 1.9k阅读时长: 8 min
2020/05/13 Share

前言

TopK问题通常指的是寻找数组中第k大(小)的数或前k大(小)的数组。最简单的方法当然是使用排序,本文将从TopK问题入手,讲讲常见的几种解决方法并详细讲解快速选择算法。

实践1

解决TopK问题的解决方案有以下几种:

  • 排序
  • 快速选择算法

我们使用剑指offer中给出的一道题目作为例子。
输入整数数组 arr ,找出其中最小的 k 个数。例如,输入4、5、1、6、2、7、3、8这8个数字,则最小的4个数字是1、2、3、4。
接下来我将详细讲解这几种解决方案来解决上述问题。

排序

排序是最容易想到也是比较简单的方法,许多语言中都内置了排序方法,当然自己实现排序也是可以的。

1
2
3
4
class Solution:
def getLeastNumbers(self, arr, k)
arr.sort()
return arr[:k]

由于python中的排序使用的是快速排序,所以平均时间复杂度为O(nlogn)。

我们使用大根堆来解决上述问题。由于上述题目是寻找前k小的数,所以我们使用大根堆,poll出n-k个数,留下的就是前k小的数。详细思路为:将k个数插入大根堆中,从第k+1个数开始,如果当前数小于堆顶的数,把堆顶数弹出,再插入当前数。最后留在堆中的数即为前k小的数。
在java当中,可以使用PriorityQueue并重写比较器来实现一个大根堆,而python中因为heapq模块只支持小根堆,我们需要将数组中的数取反,才能使用小根堆来获得前k个最小值。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
class Solution {
public int[] getLeastNumbers(int[] arr, int k) {
if (k == 0 || arr.length == 0) {
return new int[0];
}
Queue<Integer> pq = new PriorityQueue<>((v1, v2) -> v2 - v1);
for (int num: arr) {
if (pq.size() < k) {
pq.offer(num);
} else if (num < pq.peek()) {
pq.poll();
pq.offer(num);
}
}

int[] res = new int[pq.size()];
int idx = 0;
for(int num: pq) {
res[idx++] = num;
}
return res;
}
}

使用小根堆

1
2
3
4
5
6
7
8
9
10
11
12
13
class Solution:
def getLeastNumbers(self, arr, k)
if k == 0:
return list()

pq = [-x for x in arr[:k]]
heapq.heapify(hp)
for i in range(k, len(arr)):
if -pq[0] > arr[i]:
heapq.heappop(pq)
heapq.heappush(pq, -arr[i])
ans = [-x for x in pq]
return ans

使用堆的平均时间复杂度为O(nlogk),空间复杂度为O(k)。

快速选择

快速选择算法其实是快速排序的思想,我们可以先回忆下快排的思想。使用快排思想可以将数组分隔为左右两边,数组下标为[0,a)与[a,n),如果a刚好等于k-1的话,那么[0,a)就是我们要的前k小的数,如果a小于k-1则在右区间继续寻找a,如果a大于k-1的话则在左区间寻找。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
class Solution:
def getLeastNumbers(self, arr: List[int], k: int) -> List[int]:
def quickSelect(arr,left,right,k):
if k < 0:
return []

if left <= right:
i = left
j = right
key = arr[left]
while i < j:
while i < j and arr[j] > key:
j -= 1
while i < j and arr[i] <= key:
i += 1
if i < j:
arr[i],arr[j] = arr[j],arr[i]
arr[left],arr[j] = arr[j],arr[left]

if j == k:
return arr[:j+1]

if j > k:
return quickSelect(arr,left,j-1,k)
else:
return quickSelect(arr,j+1,right,k)

return quickSelect(arr,0,len(arr)-1,k-1)

这个算法的改进之处与快排的改进之处一致,在于每次对于key的选取,如果数组本身有序,并且key总是取左边一个数作为对比,或者说key的选取总是最大值或最小值,那么可能导致时间复杂度退化为O(n^2),并且由于快速选择相较于快速排序,只需要对左区间或者右区间进行partition,而不是左右区间都要partition,因此时间复杂度为N + N/2 + N/4 + … + N/N = 2N,即O(N)时间复杂度。

实践2

上面一道题我们解决了前k小的数,而TopK其实说的是top,即第k个大的数。我们使用leetcode第215题。
在未排序的数组中找到第 k 个最大的元素。请注意,你需要找的是数组排序后的第 k 个最大的元素,而不是第 k 个不同的元素。

我们照样使用堆和快速选择来解决这个问题。

这里我们使用小根堆,将数全部入堆,如果堆大小超过k,则poll出堆顶元素,最后在堆顶的就是第k大的数。

1
2
3
4
5
6
7
8
9
10
11
12
class Solution {
public int findKthLargest(int[] nums, int k) {
PriorityQueue<Integer> hp =
new PriorityQueue<Integer>();
for (int n: nums) {
hp.add(n);
if (hp.size() > k)
hp.poll();
}
return hp.poll();
}
}

在python的heapq模块中,我们可以使用nlargest方法来获取前k个大的数,并返回最后一个

1
2
3
class Solution:
def findKthLargest(self, nums, k):
return heapq.nlargest(k, nums)[-1]

快速选择

我们可以完全复制上一道题的代码,只需改动些许地方。1.当j==k时,返回的是一个数。2.由于上一道题代码是找第k个小的数,所以刚好是下标与k-1相等时返回,也就是说寻找第k大相当于寻找第n-k+1小的数。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
class Solution:
def findKthLargest(self, nums: List[int], k: int) -> int:
def quickSelect(arr,left,right,k):
if k < 0:
return []

if left <= right:
i = left
j = right
key = arr[left]
while i < j:
while i < j and arr[j] > key:
j -= 1
while i < j and arr[i] <= key:
i += 1
if i < j:
arr[i],arr[j] = arr[j],arr[i]
arr[left],arr[j] = arr[j],arr[left]

if j == k:
return arr[j]

if j > k:
return quickSelect(arr,left,j-1,k)
else:
return quickSelect(arr,j+1,right,k)
# len(nums)-k 是数组下标
return quickSelect(nums,0,len(nums)-1,len(nums)-k)

老生常谈的优化,对于key的选择很关键,在LeetCode中,如果key总是为左边那个数,则时间耗时1100ms,而如果使用随机下标与left进行交换,则时间降至50ms以内。并且减少了递归所需要的内存消耗。具体部分代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
....
if left <= right:
i = left
j = right
a = random.randint(left, right)
arr[left],arr[a] = arr[a],arr[left]
key = arr[left]
while i < j:
while i < j and arr[j] > key:
j -= 1
while i < j and arr[i] <= key:
i += 1
if i < j:
arr[i],arr[j] = arr[j],arr[i]
arr[left],arr[j] = arr[j],arr[left]
...

partition思路

在官方答案中是将右边作为起始点,其思想大同小异,这里简单讲讲与我思路的不同。
首先随机选取一个pivot,并将这个数与最右边那个数进行一次交换。
第二步,定义i,j指针,初始化为left,循环退出条件为j指针等于最右边数的下标。查看nums[j]是否小于等于pivot,如果不是,则j向右移动。如果是,交换i,j位置的元素,并且i,j都向右移动。
第三步,重复第二步,直到j==right,此时交换i与j的元素,此时,i左边元素都小于它,右边元素都大于它。
以上就是另一种partition的思路。
这篇文章中的partition思路与那篇快速排序的文章相同。

总结

快速选择算法与快速排序思想一致,通过对数组进行partition来获取前k小的数,通过写这篇文章,再一次复习了快速排序算法,并对两种算法有了自己的认识与理解。

原文作者:Zer0e

原文链接:https://re0.top/2020/05/13/topk/

发表日期:五月 13日 2020, 4:10:00 下午

更新日期:May 13th 2020, 11:54:29 pm

版权声明:本文采用知识共享署名-非商业性使用 4.0 国际许可协议进行许可

CATALOG
  1. 1. 前言
  2. 2. 实践1
    1. 2.1. 排序
    2. 2.2.
    3. 2.3. 快速选择
  3. 3. 实践2
    1. 3.1.
    2. 3.2. 快速选择
    3. 3.3. partition思路
  4. 4. 总结