寻找数组中第K大的元素

2018/12/25 算法 leetcode

问题

Find the kth largest element in an unsorted array. Note that it is the kth largest element in the sorted order, not the kth distinct element.

Example 1:

Input: [3,2,1,5,6,4] and k = 2
Output: 5
Example 2:

Input: [3,2,3,1,2,4,5,5,6] and k = 4
Output: 4
Note: 
You may assume k is always valid, 1 ≤ k ≤ array's length.

tag: Medium

分析

这题最简单的做法是将数组排序,然后直接返回第K大的元素。复杂度为:O(NlogN)。但是,很明显,出题者并不想让我们这么做。

如果对数组排序,算法的复杂度起码是 O(NlogN)。那么如果我们不排序,能不能求出第K大元素呢?答案是可以的,我们知道快速排序中有一个步骤是 partition。它选择一个元素作为枢纽(pivot),将所有小于枢纽的元素放到枢纽的左边,将所有大于枢纽的元素放到枢纽的右边。然后返回枢纽在数组中的位置。那么,关键就在这里了。如果此时返回的枢纽元素在数组中的位置刚好是我们所要求的位置,问题就能得到解决了。

我们先选取一个枢纽元素,从数组的第一个元素开始到最后一个元素结束。将大于枢纽的左边元素和小于枢纽的右边元素交换。然后判断当前枢纽元素的 index 是否为 n - k,如果是则直接返回这个枢纽元素。如果 index 大于 n - k,则我们从枢纽的左边继续递归寻找,如果 index 小于 n - k,则我们从枢纽的右边继续递归寻找。

代码

public class KthLargestElementInArray {

    public int findKthLargest(int[] nums, int k) {
        return findKth(nums, nums.length - k /* 第k大,也就是排序后数组中 index 为 n - k 的元素*/, 0, nums.length - 1);
    }

    public int findKth(int[] nums, int k, int low, int high) {
        if (low >= high) {
            return nums[low];
        }
        int pivot = partition(nums, low, high); // 枢纽元素的 index
        if (pivot == k) {
            return nums[pivot];
        } else if (pivot < k) { // 枢纽元素的 index 小于 k,继续从枢纽的右边部分找
            return findKth(nums, k, pivot + 1, high);
        } else { // 枢纽元素的 index 大于 k,继续从枢纽的左边部分找
            return findKth(nums, k, low, pivot - 1);
        }
    }

    int partition(int[] nums, int low, int high) {
        int i = low;
        int j = high + 1;
        int pivot = nums[low];
        while (true) {
            while (less(nums[++i], pivot)) {
                if (i == high) {
                    break;
                }
            }
            while (less(pivot, nums[--j])) {
                if (j == low) {
                    break;
                }
            }
            if (i >= j) {
                break;
            }
            swap(nums, i, j);
        }
        swap(nums, low, j);
        return j;
    }

    boolean less(int i, int j) {
        return i < j;
    }

    void swap(int[] nums, int i, int j) {
        if (i == j) {
            return;
        }
        int temp = nums[i];
        nums[i] = nums[j];
        nums[j] = temp;
    }
}

复杂度分析

复杂度为:O(N) + O(N / 2) + O(N / 4) + … 最终算法复杂度为 O(N)

Search

    Table of Contents